#!pip install ANNarchy
ANN-to-SNN conversion - MLP
This notebook demonstrates how to transform a fully-connected neural network trained with tensorflow/keras into a spiking neural network (SNN) usable in ANNarchy.
The methods are adapted from the original models used in:
Diehl et al. (2015) “Fast-classifying, high-accuracy spiking deep networks through weight and threshold balancing” Proceedings of IJCNN. doi: 10.1109/IJCNN.2015.7280696
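The core idea of the conversion is to reuse the trained ReLU weights in integrate-and-fire neurons, after rescaling them so that the largest observed activations map onto the firing threshold ("weight and threshold balancing"). Purely as an illustration of the data-based normalization idea from the paper (this is not what the conversion tool below does internally, and the helper name is made up):
import numpy as np

def normalize_weights(weights, activations):
    # Illustrative data-based weight normalization (hypothetical helper):
    # `weights` is a list of weight matrices, one per layer, and `activations`
    # the corresponding ReLU activations recorded on training data.
    # Each layer is rescaled so that its largest activation corresponds
    # to the firing threshold (assumed to be 1).
    normalized = []
    previous_factor = 1.0
    for W, A in zip(weights, activations):
        factor = A.max()
        normalized.append(W * previous_factor / factor)
        previous_factor = factor
    return normalized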
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
print(f"Tensorflow {tf.__version__}")
Tensorflow 2.16.2
First we need to download and process the MNIST dataset provided by tensorflow.
# Download data
(X_train, t_train), (X_test, t_test) = tf.keras.datasets.mnist.load_data()

# Normalize inputs
X_train = X_train.reshape(X_train.shape[0], 784).astype('float32') / 255.
X_test = X_test.reshape(X_test.shape[0], 784).astype('float32') / 255.

# One-hot output vectors
T_train = tf.keras.utils.to_categorical(t_train, 10)
T_test = tf.keras.utils.to_categorical(t_test, 10)
Training an ANN in tensorflow/keras
The tensorflow.keras network is built using the functional API.
The network has two fully-connected hidden layers of 128 ReLU neurons each (without bias), dropout at 0.5 after each hidden layer, and a softmax output layer with 10 neurons. We use the standard SGD optimizer and the categorical crossentropy loss for classification.
def create_mlp():
    # Model
    inputs = tf.keras.layers.Input(shape=(784,))
    x = tf.keras.layers.Dense(128, use_bias=False, activation='relu')(inputs)
    x = tf.keras.layers.Dropout(0.5)(x)
    x = tf.keras.layers.Dense(128, use_bias=False, activation='relu')(x)
    x = tf.keras.layers.Dropout(0.5)(x)
    x = tf.keras.layers.Dense(10, use_bias=False, activation='softmax')(x)

    model = tf.keras.Model(inputs, x)

    # Optimizer
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.05)

    # Loss function
    model.compile(
        loss='categorical_crossentropy', # loss function
        optimizer=optimizer,             # learning rule
        metrics=['accuracy']             # show accuracy
    )

    print(model.summary())

    return model
We can now train the network and save the trained model to disk in the Keras format.
# Create model
model = create_mlp()

# Train model
history = model.fit(
    X_train, T_train,     # training data
    batch_size=128,       # batch size
    epochs=20,            # Maximum number of epochs
    validation_split=0.1, # Percentage of training data used for validation
)

model.save("runs/mlp.keras")

# Test model
predictions_keras = model.predict(X_test, verbose=0)
test_loss, test_accuracy = model.evaluate(X_test, T_test, verbose=0)
print(f"Test accuracy: {test_accuracy}")
Model: "functional"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ input_layer (InputLayer)        │ (None, 784)            │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense (Dense)                   │ (None, 128)            │       100,352 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout (Dropout)               │ (None, 128)            │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_1 (Dense)                 │ (None, 128)            │        16,384 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout_1 (Dropout)             │ (None, 128)            │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_2 (Dense)                 │ (None, 10)             │         1,280 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 118,016 (461.00 KB)
Trainable params: 118,016 (461.00 KB)
Non-trainable params: 0 (0.00 B)
None
Epoch 1/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - accuracy: 0.4845 - loss: 1.5137 - val_accuracy: 0.9093 - val_loss: 0.3401
Epoch 2/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.8202 - loss: 0.5868 - val_accuracy: 0.9325 - val_loss: 0.2403
Epoch 3/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - accuracy: 0.8576 - loss: 0.4765 - val_accuracy: 0.9397 - val_loss: 0.2055
Epoch 4/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.8795 - loss: 0.4142 - val_accuracy: 0.9478 - val_loss: 0.1804
Epoch 5/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.8931 - loss: 0.3656 - val_accuracy: 0.9528 - val_loss: 0.1633
Epoch 6/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - accuracy: 0.9040 - loss: 0.3385 - val_accuracy: 0.9572 - val_loss: 0.1492
Epoch 7/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - accuracy: 0.9083 - loss: 0.3170 - val_accuracy: 0.9597 - val_loss: 0.1424
Epoch 8/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.9148 - loss: 0.3010 - val_accuracy: 0.9650 - val_loss: 0.1317
Epoch 9/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - accuracy: 0.9192 - loss: 0.2857 - val_accuracy: 0.9632 - val_loss: 0.1269
Epoch 10/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - accuracy: 0.9237 - loss: 0.2655 - val_accuracy: 0.9648 - val_loss: 0.1218
Epoch 11/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.9240 - loss: 0.2662 - val_accuracy: 0.9670 - val_loss: 0.1189
Epoch 12/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - accuracy: 0.9265 - loss: 0.2539 - val_accuracy: 0.9677 - val_loss: 0.1108
Epoch 13/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.9303 - loss: 0.2385 - val_accuracy: 0.9693 - val_loss: 0.1092
Epoch 14/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.9335 - loss: 0.2353 - val_accuracy: 0.9713 - val_loss: 0.1049
Epoch 15/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - accuracy: 0.9342 - loss: 0.2312 - val_accuracy: 0.9707 - val_loss: 0.1053
Epoch 16/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - accuracy: 0.9347 - loss: 0.2231 - val_accuracy: 0.9708 - val_loss: 0.1041
Epoch 17/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.9398 - loss: 0.2122 - val_accuracy: 0.9718 - val_loss: 0.0996
Epoch 18/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.9402 - loss: 0.2086 - val_accuracy: 0.9747 - val_loss: 0.0946
Epoch 19/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - accuracy: 0.9394 - loss: 0.2090 - val_accuracy: 0.9723 - val_loss: 0.0969
Epoch 20/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.9406 - loss: 0.2005 - val_accuracy: 0.9728 - val_loss: 0.0944
Test accuracy: 0.9666000008583069
plt.figure(figsize=(12, 6))

plt.subplot(121)
plt.plot(history.history['loss'], '-r', label="Training")
plt.plot(history.history['val_loss'], '-b', label="Validation")
plt.xlabel('Epoch #')
plt.ylabel('Loss')
plt.legend()

plt.subplot(122)
plt.plot(history.history['accuracy'], '-r', label="Training")
plt.plot(history.history['val_accuracy'], '-b', label="Validation")
plt.xlabel('Epoch #')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
Initialize the ANN-to-SNN converter
We first create an instance of the ANN-to-SNN conversion object. It receives the input_encoding parameter, which selects the type of input encoding. By default, intrinsically bursting (IB), phase shift oscillation (PSO) and Poisson (poisson) encodings are available.
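To build an intuition for rate-based input encodings, the following sketch shows how a normalized pixel intensity could be turned into a Poisson spike train. This is only an illustration in plain numpy, not ANNarchy's implementation, and the parameter names are made up:
def poisson_encode(intensities, duration=200., dt=1.0, max_rate=100.0):
    # Each pixel fires with a rate proportional to its intensity in [0, 1].
    # Returns a (timesteps, n_pixels) boolean spike matrix.
    steps = int(duration / dt)
    rates = intensities * max_rate   # firing rate in Hz
    p_spike = rates * dt * 1e-3      # spike probability per time step (dt in ms)
    return np.random.rand(steps, intensities.size) < p_spike

# Example: a 200 ms Poisson spike train for the first test image
# spikes = poisson_encode(X_test[0])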
from ANNarchy.extensions.ann_to_snn_conversion import ANNtoSNNConverter

snn_converter = ANNtoSNNConverter(
    input_encoding='IB',
    hidden_neuron='IaF',
    read_out='spike_count',
)
ANNarchy 5.0 (5.0.0) on darwin (posix).
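For intuition, the hidden ReLU layers of the ANN are replaced by layers of integrate-and-fire neurons: each neuron accumulates its weighted input spikes and emits a spike whenever its membrane potential crosses a threshold. A rough sketch of such a layer in plain numpy (again only an illustration, not ANNarchy's neuron model; the reset-by-subtraction rule and the threshold value are assumptions):
def iaf_layer(input_spikes, W, threshold=1.0):
    # input_spikes: (timesteps, n_in) boolean spike matrix.
    # W: (n_out, n_in) weight matrix taken from the trained ANN.
    steps = input_spikes.shape[0]
    v = np.zeros(W.shape[0])                  # membrane potentials
    output_spikes = np.zeros((steps, W.shape[0]), dtype=bool)
    for t in range(steps):
        v += W @ input_spikes[t]              # integrate the weighted input spikes
        output_spikes[t] = v >= threshold     # spike when the threshold is crossed
        v[output_spikes[t]] -= threshold      # reset by subtraction
    return output_spikes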
After that, we provide the TensorFlow model stored as a .keras file to the conversion tool. The print-out of the structure of the imported network is suppressed when show_info=False is passed to load_keras_model.
net = snn_converter.load_keras_model("runs/mlp.keras", show_info=True)
WARNING: Dense representation is an experimental feature for spiking models, we greatly appreciate bug reports.
* Input layer: input_layer, (784,)
* InputLayer skipped.
* Dense layer: dense, 128
weights: (128, 784)
mean -0.004177612718194723, std 0.05281704664230347
min -0.3429253101348877, max 0.22064846754074097
* Dropout skipped.
* Dense layer: dense_1, 128
weights: (128, 128)
mean 0.005270183552056551, std 0.10235019028186798
min -0.28134748339653015, max 0.39932867884635925
* Dropout skipped.
* Dense layer: dense_2, 10
weights: (10, 128)
mean 0.00408650329336524, std 0.21635150909423828
min -0.5984256267547607, max 0.46855056285858154
When the network has been built successfully, we can perform a test using all MNIST test samples. Using duration_per_sample, the duration (in milliseconds) simulated for each image can be specified. Here, 200 ms seems to be enough.
predictions_snn = snn_converter.predict(X_test, duration_per_sample=200)
100%|██████████| 10000/10000 [00:56<00:00, 178.57it/s]
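With read_out='spike_count', the predicted class of each image is, presumably, the output neuron that emitted the most spikes during the 200 ms presentation, i.e. an argmax over the per-class spike counts. A minimal illustration of this decision rule (the counts below are made-up numbers):
# Hypothetical spike counts of the 10 output neurons for one image
spike_counts = np.array([2, 0, 35, 3, 1, 0, 4, 1, 2, 0])
predicted_class = int(np.argmax(spike_counts))  # -> 2, the neuron that spiked most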
Using the recorded predictions, we can now compute the accuracy using scikit-learn for all presented samples.
from sklearn.metrics import classification_report, accuracy_score
print(classification_report(t_test, predictions_snn))
print("Test accuracy of the SNN:", accuracy_score(t_test, predictions_snn))
precision recall f1-score support
0 0.96 0.99 0.98 980
1 0.98 0.98 0.98 1135
2 0.96 0.97 0.96 1032
3 0.96 0.95 0.96 1010
4 0.97 0.96 0.96 982
5 0.96 0.95 0.95 892
6 0.96 0.97 0.96 958
7 0.97 0.96 0.96 1028
8 0.94 0.95 0.95 974
9 0.97 0.94 0.96 1009
accuracy 0.96 10000
macro avg 0.96 0.96 0.96 10000
weighted avg 0.96 0.96 0.96 10000
Test accuracy of the SNN: 0.9627
For comparison, here is the performance of the original ANN in keras:
print(classification_report(t_test, predictions_keras.argmax(axis=1)))
print("Test accuracy of the ANN:", accuracy_score(t_test, predictions_keras.argmax(axis=1)))
precision recall f1-score support
0 0.97 0.99 0.98 980
1 0.98 0.99 0.98 1135
2 0.97 0.97 0.97 1032
3 0.97 0.96 0.97 1010
4 0.96 0.96 0.96 982
5 0.96 0.95 0.96 892
6 0.95 0.97 0.96 958
7 0.96 0.97 0.96 1028
8 0.96 0.96 0.96 974
9 0.98 0.94 0.96 1009
accuracy 0.97 10000
macro avg 0.97 0.97 0.97 10000
weighted avg 0.97 0.97 0.97 10000
Test accuracy of the ANN: 0.9666