ANNarchy 5.0.0

ANN-to-SNN conversion - MLP

This notebook demonstrates how to transform a fully-connected neural network trained with tensorflow/keras into a spiking neural network (SNN) usable in ANNarchy.

The methods are adapted from the original models used in:

Diehl et al. (2015) “Fast-classifying, high-accuracy spiking deep networks through weight and threshold balancing” Proceedings of IJCNN. doi: 10.1109/IJCNN.2015.7280696
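The weight balancing idea of the cited paper can be illustrated with a minimal NumPy sketch of data-based normalization, assuming bias-free ReLU layers. The weights of layer l are multiplied by λ(l-1)/λ(l), where λ(l) is the maximum activation of layer l over a data set and λ(0) = 1. As a result, no rescaled activation exceeds 1, which keeps neurons with a firing threshold of 1 below their maximum rate of one spike per time step. The helper names and toy data are hypothetical, and this notebook does not show whether the ANNarchy converter applies exactly this procedure.

```python
import numpy as np

def data_based_normalization(weights, X):
    """Rescale bias-free ReLU layers (hypothetical helper, illustration only).

    weights: list of kernels with Keras layout (n_in, n_out).
    X: (n_samples, n_in) data used to measure the maximum activations.
    """
    normalized = []
    previous_lambda = 1.0      # lambda_0 = 1 for the input layer
    activations = X
    for W in weights:
        # Activations of the ORIGINAL network for this layer
        activations = np.maximum(0.0, activations @ W)
        current_lambda = max(activations.max(), 1e-12)   # avoid division by zero
        # Multiply the weights by lambda_(l-1) / lambda_l
        normalized.append(W * previous_lambda / current_lambda)
        previous_lambda = current_lambda
    return normalized

def forward(weights, X):
    """Return the ReLU activations of every layer."""
    outputs = []
    for W in weights:
        X = np.maximum(0.0, X @ W)
        outputs.append(X)
    return outputs

rng = np.random.default_rng(0)
X = rng.random((100, 20))
weights = [rng.normal(0.0, 0.5, (20, 16)), rng.normal(0.0, 0.5, (16, 5))]

normalized = data_based_normalization(weights, X)
# Maximum activation of each rescaled layer
print([round(float(a.max()), 6) for a in forward(normalized, X)])
# Positive rescaling does not change the predicted class
print(np.array_equal(forward(weights, X)[-1].argmax(axis=1),
                     forward(normalized, X)[-1].argmax(axis=1)))
```

Because each layer is rescaled by a positive constant, the ranking of the output units, and therefore the classification, is preserved.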

#!pip install ANNarchy
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
print(f"Tensorflow {tf.__version__}")
Tensorflow 2.16.2

First, we download the MNIST dataset provided by tensorflow, flatten each 28x28 image into a 784-dimensional vector with values in [0, 1], and one-hot encode the labels.

# Download data
(X_train, t_train), (X_test, t_test) = tf.keras.datasets.mnist.load_data()

# Normalize inputs
X_train = X_train.reshape(X_train.shape[0], 784).astype('float32') / 255.
X_test = X_test.reshape(X_test.shape[0], 784).astype('float32') / 255.

# One-hot output vectors
T_train = tf.keras.utils.to_categorical(t_train, 10)
T_test = tf.keras.utils.to_categorical(t_test, 10)
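The same preprocessing can be reproduced on a toy batch in plain NumPy, which makes the resulting shapes and value ranges explicit. The two images and labels below are made up; indexing the identity matrix with `np.eye(10)[labels]` yields the same one-hot rows as `to_categorical` for integer labels between 0 and 9.

```python
import numpy as np

# Hypothetical stand-ins for two 28x28 uint8 MNIST images and their labels
images = np.array([np.full((28, 28), 255, dtype=np.uint8),
                   np.zeros((28, 28), dtype=np.uint8)])
labels = np.array([3, 7])

# Flatten each image to a 784-dimensional vector and scale pixels to [0, 1]
X = images.reshape(images.shape[0], 784).astype('float32') / 255.

# One-hot encoding: row c of the identity matrix encodes class c
T = np.eye(10, dtype='float32')[labels]

print(X.shape, X.min(), X.max())
print(T)
```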

Training an ANN in tensorflow/keras

The tensorflow.keras network is built using the functional API.

The network has two hidden fully-connected layers of 128 ReLU units each, without bias and each followed by dropout with a rate of 0.5, and a softmax output layer with 10 neurons (also without bias). We use the standard SGD optimizer (learning rate 0.05) and the categorical cross-entropy loss for classification.

def create_mlp():
    # Model
    inputs = tf.keras.layers.Input(shape=(784,))
    x = tf.keras.layers.Dense(128, use_bias=False, activation='relu')(inputs)
    x = tf.keras.layers.Dropout(0.5)(x)
    x = tf.keras.layers.Dense(128, use_bias=False, activation='relu')(x)
    x = tf.keras.layers.Dropout(0.5)(x)
    x = tf.keras.layers.Dense(10, use_bias=False, activation='softmax')(x)

    model = tf.keras.Model(inputs, x)

    # Optimizer
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.05)

    # Loss function
    model.compile(
        loss='categorical_crossentropy', # loss function
        optimizer=optimizer, # learning rule
        metrics=['accuracy'] # show accuracy
    )
    print(model.summary())

    return model
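The parameter counts in the model summary follow directly from the layer sizes: since no layer has a bias, each Dense layer owns only an n_in x n_out kernel, and the dropout layers have no parameters at all. A quick check in plain Python, assuming 4 bytes per float32 weight:

```python
# Layer sizes of the MLP: input, two hidden layers, output
layer_sizes = [784, 128, 128, 10]

# With use_bias=False, each Dense layer only has an (n_in, n_out) kernel
params = [n_in * n_out for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])]
total = sum(params)

# float32 weights occupy 4 bytes each
print(params, total, total * 4 / 1024, "KB")   # → [100352, 16384, 1280] 118016 461.0 KB
```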

We can now train the network and save the trained model (architecture and weights) in the native Keras format (.keras).

import os

# Create model
model = create_mlp()

# Train model
history = model.fit(
    X_train, T_train,       # training data
    batch_size=128,         # batch size
    epochs=20,              # number of epochs
    validation_split=0.1,   # fraction of training data used for validation
)

# Save the trained model (the target directory must exist)
os.makedirs("runs", exist_ok=True)
model.save("runs/mlp.keras")

# Test model
predictions_keras = model.predict(X_test, verbose=0)
test_loss, test_accuracy = model.evaluate(X_test, T_test, verbose=0)
print(f"Test accuracy: {test_accuracy}")
Model: "functional"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ input_layer (InputLayer)        │ (None, 784)            │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense (Dense)                   │ (None, 128)            │       100,352 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout (Dropout)               │ (None, 128)            │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_1 (Dense)                 │ (None, 128)            │        16,384 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout_1 (Dropout)             │ (None, 128)            │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_2 (Dense)                 │ (None, 10)             │         1,280 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 118,016 (461.00 KB)
 Trainable params: 118,016 (461.00 KB)
 Non-trainable params: 0 (0.00 B)
None

Epoch 1/20

422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - accuracy: 0.4845 - loss: 1.5137 - val_accuracy: 0.9093 - val_loss: 0.3401

Epoch 2/20

422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.8202 - loss: 0.5868 - val_accuracy: 0.9325 - val_loss: 0.2403

Epoch 3/20

422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - accuracy: 0.8576 - loss: 0.4765 - val_accuracy: 0.9397 - val_loss: 0.2055

Epoch 4/20

422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.8795 - loss: 0.4142 - val_accuracy: 0.9478 - val_loss: 0.1804

Epoch 5/20

422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.8931 - loss: 0.3656 - val_accuracy: 0.9528 - val_loss: 0.1633

Epoch 6/20

422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - accuracy: 0.9040 - loss: 0.3385 - val_accuracy: 0.9572 - val_loss: 0.1492

Epoch 7/20

422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - accuracy: 0.9083 - loss: 0.3170 - val_accuracy: 0.9597 - val_loss: 0.1424

Epoch 8/20

422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.9148 - loss: 0.3010 - val_accuracy: 0.9650 - val_loss: 0.1317

Epoch 9/20

422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - accuracy: 0.9192 - loss: 0.2857 - val_accuracy: 0.9632 - val_loss: 0.1269

Epoch 10/20

422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - accuracy: 0.9237 - loss: 0.2655 - val_accuracy: 0.9648 - val_loss: 0.1218

Epoch 11/20

422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.9240 - loss: 0.2662 - val_accuracy: 0.9670 - val_loss: 0.1189

Epoch 12/20

422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - accuracy: 0.9265 - loss: 0.2539 - val_accuracy: 0.9677 - val_loss: 0.1108

Epoch 13/20

422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.9303 - loss: 0.2385 - val_accuracy: 0.9693 - val_loss: 0.1092

Epoch 14/20

422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.9335 - loss: 0.2353 - val_accuracy: 0.9713 - val_loss: 0.1049

Epoch 15/20

422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - accuracy: 0.9342 - loss: 0.2312 - val_accuracy: 0.9707 - val_loss: 0.1053

Epoch 16/20

422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - accuracy: 0.9347 - loss: 0.2231 - val_accuracy: 0.9708 - val_loss: 0.1041

Epoch 17/20

422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.9398 - loss: 0.2122 - val_accuracy: 0.9718 - val_loss: 0.0996

Epoch 18/20

422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.9402 - loss: 0.2086 - val_accuracy: 0.9747 - val_loss: 0.0946

Epoch 19/20

422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - accuracy: 0.9394 - loss: 0.2090 - val_accuracy: 0.9723 - val_loss: 0.0969

Epoch 20/20

422/422 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - accuracy: 0.9406 - loss: 0.2005 - val_accuracy: 0.9728 - val_loss: 0.0944

Test accuracy: 0.9666000008583069
plt.figure(figsize=(12, 6))
plt.subplot(121)
plt.plot(history.history['loss'], '-r', label="Training")
plt.plot(history.history['val_loss'], '-b', label="Validation")
plt.xlabel('Epoch #')
plt.ylabel('Loss')
plt.legend()

plt.subplot(122)
plt.plot(history.history['accuracy'], '-r', label="Training")
plt.plot(history.history['val_accuracy'], '-b', label="Validation")
plt.xlabel('Epoch #')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

Initialize the ANN-to-SNN converter

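We first create an instance of the ANN-to-SNN converter. The constructor receives the input_encoding parameter, which selects how the input images are encoded into spike trains. It also receives hidden_neuron, the neuron model used in the hidden layers (here 'IaF', integrate-and-fire), and read_out, the method used to determine the predicted class (here 'spike_count').

Out of the box, three input encodings are available: intrinsically bursting ('IB'), phase shift oscillation ('PSO') and Poisson ('poisson').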
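For intuition, Poisson encoding can be sketched as follows: each pixel emits spikes at a rate proportional to its intensity, and in every simulation step a spike occurs with probability rate x dt. The helper poisson_encode, the maximum rate of 100 Hz and the 1 ms time step are assumptions for illustration only; ANNarchy's own encoders, including the 'IB' encoding used below, may rely on other parameters and mechanisms.

```python
import numpy as np

def poisson_encode(image, duration=200.0, dt=1.0, max_rate=100.0, rng=None):
    """Poisson rate coding (hypothetical helper, not the ANNarchy implementation).

    image: pixel intensities in [0, 1]; duration and dt in ms; max_rate in Hz.
    Returns a boolean array of shape (n_steps, n_pixels), True where a pixel spikes.
    """
    rng = np.random.default_rng() if rng is None else rng
    n_steps = int(duration / dt)
    # Spike probability per step = rate (Hz) * dt (ms converted to s)
    p_spike = np.asarray(image) * max_rate * dt / 1000.0
    return rng.random((n_steps, p_spike.size)) < p_spike

rng = np.random.default_rng(42)
image = np.array([0.0, 0.5, 1.0])          # three toy pixel intensities
spikes = poisson_encode(image, rng=rng)
rates = spikes.sum(axis=0) / 0.2           # spike count / 200 ms = rate in Hz
print(spikes.shape, rates)
```

The measured rates fluctuate around 0, 50 and 100 Hz; a black pixel never spikes.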

from ANNarchy.extensions.ann_to_snn_conversion import ANNtoSNNConverter

snn_converter = ANNtoSNNConverter(
    input_encoding='IB', 
    hidden_neuron='IaF',
    read_out='spike_count',
)
ANNarchy 5.0 (5.0.0) on darwin (posix).
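Integrate-and-fire neurons can stand in for ReLU units because, for a constant positive input current, a non-leaky integrate-and-fire neuron with reset by subtraction fires at a rate proportional to that input, and it never fires for a negative input. The sketch below checks this numerically and then applies a spike-count read-out (the neuron with the most spikes wins), which is what the name read_out='spike_count' suggests. The function, the threshold of 1.0 and the 200 time steps are assumptions; ANNarchy's actual IaF equations may differ.

```python
import numpy as np

def if_spike_counts(inputs, weights, n_steps=200, threshold=1.0):
    """Non-leaky integrate-and-fire layer (hypothetical, not ANNarchy's IaF model).

    inputs: (n_in,) analog input presented at every step; weights: (n_out, n_in).
    Returns the number of spikes emitted by each neuron during n_steps steps.
    """
    current = weights @ inputs                 # constant input current per step
    v = np.zeros(weights.shape[0])             # membrane potentials
    counts = np.zeros(weights.shape[0], dtype=int)
    for _ in range(n_steps):
        v += current                           # integrate
        fired = v >= threshold                 # spike when the threshold is reached
        counts += fired
        v[fired] -= threshold                  # reset by subtraction keeps the residual
    return counts

rng = np.random.default_rng(0)
x = rng.random(20)                             # toy input activations in [0, 1)
W = rng.normal(0.0, 0.05, size=(5, 20))        # toy weights, small enough to stay below threshold

counts = if_spike_counts(x, W)
print(counts / 200)                            # firing rate in spikes per step
print(np.maximum(0.0, W @ x))                  # ReLU activations of the same layer
print("spike-count read-out:", counts.argmax())
```

Reset by subtraction is the important design choice here: the residual potential above threshold is kept, so the spike count over n steps equals the integer part of n times the input current and the rate error stays below 1/n.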

Next, we pass the trained model, stored as a .keras file, to the converter with load_keras_model. With show_info=True, as below, the structure of the imported network is printed; passing show_info=False suppresses this print-out.

net = snn_converter.load_keras_model("runs/mlp.keras", show_info=True)
WARNING: Dense representation is an experimental feature for spiking models, we greatly appreciate bug reports. 
* Input layer: input_layer, (784,)
* InputLayer skipped.
* Dense layer: dense, 128 
    weights: (128, 784)
    mean -0.004177612718194723, std 0.05281704664230347
    min -0.3429253101348877, max 0.22064846754074097
* Dropout skipped.
* Dense layer: dense_1, 128 
    weights: (128, 128)
    mean 0.005270183552056551, std 0.10235019028186798
    min -0.28134748339653015, max 0.39932867884635925
* Dropout skipped.
* Dense layer: dense_2, 10 
    weights: (10, 128)
    mean 0.00408650329336524, std 0.21635150909423828
    min -0.5984256267547607, max 0.46855056285858154
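Two details of this log can be checked with plain NumPy. First, the weight matrices are reported with shape (post, pre), e.g. (128, 784), whereas a Keras Dense kernel is stored as (inputs, units) = (784, 128); the transposed matrix computes exactly the same layer. Second, the Dropout layers can be skipped because Keras applies inverted dropout: during training the kept units are scaled by 1/(1 - rate), so the expected activation is unchanged and dropout is the identity at inference time. The random matrices below are stand-ins, not the trained weights.

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.random(784)                               # one flattened toy input image
kernel = rng.normal(0.0, 0.05, size=(784, 128))   # Keras Dense kernel layout: (inputs, units)

# Keras computes x @ kernel; a (post, pre) = (128, 784) matrix gives the same result
W_post_pre = kernel.T
assert np.allclose(x @ kernel, W_post_pre @ x)

# Summary statistics in the style of the converter's log
print(W_post_pre.shape)
print(f"mean {W_post_pre.mean()}, std {W_post_pre.std()}")
print(f"min {W_post_pre.min()}, max {W_post_pre.max()}")

# Inverted dropout at training time: drop units with probability rate,
# scale the kept ones by 1 / (1 - rate) so that the mean activation is preserved
rate = 0.5
h = np.ones(100_000)
mask = rng.random(h.size) >= rate
h_train = h * mask / (1.0 - rate)
print(h_train.mean())   # close to 1.0
```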

Once the network has been built successfully, we can evaluate it on all 10,000 MNIST test samples. The duration_per_sample argument sets how long each image is simulated; here, 200 ms per image seems to be enough.

predictions_snn = snn_converter.predict(X_test, duration_per_sample=200)
100%|██████████| 10000/10000 [00:56<00:00, 178.57it/s]
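Why does the simulated duration matter? With rate-based codes, an activation is effectively read out as a spike count, so its resolution improves with the duration, while the simulation cost grows linearly with it. The sketch below assumes Poisson-like spikes at a hypothetical rate of 50 Hz and a 1 ms time step, and measures how well the spike count over T ms estimates that rate. The notebook itself uses the 'IB' encoding, so this only illustrates the general trade-off.

```python
import numpy as np

rng = np.random.default_rng(0)
true_rate = 50.0        # Hz, hypothetical firing rate to be estimated
dt = 1.0                # ms per simulation step (assumption)
n_trials = 2000

errors = []
for duration in [25, 50, 100, 200, 400]:           # simulated ms per sample
    n_steps = int(duration / dt)
    # Bernoulli approximation of a Poisson spike train, one draw per step
    spikes = rng.random((n_trials, n_steps)) < true_rate * dt / 1000.0
    estimated = spikes.sum(axis=1) / (duration / 1000.0)   # spike count -> Hz
    errors.append(np.abs(estimated - true_rate).mean() / true_rate)
    print(f"{duration:4d} ms: mean relative error {errors[-1]:.3f}")

# Simulated time for the whole test set at 200 ms per image
print(10000 * 200 / 1000, "s")
```

The error shrinks roughly with the square root of the duration, whereas the 10,000 test images at 200 ms each already amount to 2,000 s of simulated time.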

Using the predicted labels, we can compute the per-class precision and recall as well as the overall accuracy on all presented test samples with scikit-learn.

from sklearn.metrics import classification_report, accuracy_score

print(classification_report(t_test, predictions_snn))
print("Test accuracy of the SNN:", accuracy_score(t_test, predictions_snn))
              precision    recall  f1-score   support

           0       0.96      0.99      0.98       980
           1       0.98      0.98      0.98      1135
           2       0.96      0.97      0.96      1032
           3       0.96      0.95      0.96      1010
           4       0.97      0.96      0.96       982
           5       0.96      0.95      0.95       892
           6       0.96      0.97      0.96       958
           7       0.97      0.96      0.96      1028
           8       0.94      0.95      0.95       974
           9       0.97      0.94      0.96      1009

    accuracy                           0.96     10000
   macro avg       0.96      0.96      0.96     10000
weighted avg       0.96      0.96      0.96     10000

Test accuracy of the SNN: 0.9627
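The accuracy and the precision/recall columns of the report can also be computed directly in NumPy, which makes their definitions explicit. The function name is hypothetical and the labels below are toy data; applied to t_test and predictions_snn, the accuracy is by definition the same quantity returned by accuracy_score. Unlike scikit-learn, this sketch does not handle classes that are never predicted (the mean over an empty selection yields nan).

```python
import numpy as np

def per_class_scores(y_true, y_pred, n_classes=10):
    """Accuracy plus per-class precision and recall (hypothetical helper)."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    accuracy = np.mean(y_true == y_pred)
    # Precision of class c: fraction of samples predicted as c that truly are c
    precision = np.array([np.mean(y_true[y_pred == c] == c) for c in range(n_classes)])
    # Recall of class c: fraction of samples of class c that were predicted as c
    recall = np.array([np.mean(y_pred[y_true == c] == c) for c in range(n_classes)])
    return accuracy, precision, recall

# Toy labels for three classes
y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 0])
acc, prec, rec = per_class_scores(y_true, y_pred, n_classes=3)
print(acc, prec, rec)
```

Here 4 of 6 samples are correct, the precisions are 1/2, 2/3 and 1, and the recalls are 1/2, 1 and 1/2.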

For comparison, here is the performance of the original ANN in keras:

print(classification_report(t_test, predictions_keras.argmax(axis=1)))
print("Test accuracy of the ANN:", accuracy_score(t_test, predictions_keras.argmax(axis=1)))
              precision    recall  f1-score   support

           0       0.97      0.99      0.98       980
           1       0.98      0.99      0.98      1135
           2       0.97      0.97      0.97      1032
           3       0.97      0.96      0.97      1010
           4       0.96      0.96      0.96       982
           5       0.96      0.95      0.96       892
           6       0.95      0.97      0.96       958
           7       0.96      0.97      0.96      1028
           8       0.96      0.96      0.96       974
           9       0.98      0.94      0.96      1009

    accuracy                           0.97     10000
   macro avg       0.97      0.97      0.97     10000
weighted avg       0.97      0.97      0.97     10000

Test accuracy of the ANN: 0.9666
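The two accuracies differ by 0.9666 - 0.9627 = 0.0039, i.e. the SNN misclassifies 39 more of the 10,000 test images. Accuracy alone does not tell whether both networks fail on the same images; an agreement rate shows how often the SNN reproduces the ANN's decision. With the notebook's variables this would be `np.mean(predictions_snn == predictions_keras.argmax(axis=1))`; the helper name and the toy arrays below are made up for illustration.

```python
import numpy as np

def agreement_rate(pred_a, pred_b):
    """Fraction of samples on which two classifiers predict the same label."""
    return np.mean(np.asarray(pred_a) == np.asarray(pred_b))

# Hypothetical stand-ins for keras class probabilities and SNN predictions
probs_ann = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]])
pred_ann = probs_ann.argmax(axis=1)          # [0, 1, 0, 1]
pred_snn = np.array([0, 1, 1, 1])
print(agreement_rate(pred_snn, pred_ann))    # 0.75

# Difference in the number of correctly classified test images reported above
print(round((0.9666 - 0.9627) * 10000))      # 39
```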

Copyright Julien Vitay, Helge Ülo Dinkelbach, Fred Hamker