ANN-to-SNN conversion - CNN


This notebook demonstrates how to transform a CNN trained with tensorflow/keras into a spiking neural network (SNN) usable in ANNarchy.

The CNN is adapted from the original model used in:

Diehl et al. (2015) “Fast-classifying, high-accuracy spiking deep networks through weight and threshold balancing” Proceedings of IJCNN. doi: 10.1109/IJCNN.2015.7280696
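
The central idea of Diehl et al. (2015) is weight and threshold balancing: after training, the weights of each ReLU layer are rescaled so that the activations map onto firing rates the spiking neurons can represent. The ANNtoSNNConverter used below takes care of the conversion details; purely as an illustration, a data-based weight normalization could be sketched as follows (normalize_weights is a hypothetical helper, assuming inputs in [0, 1] and per-layer activations recorded on the training set, not ANNarchy's actual implementation):

import numpy as np

def normalize_weights(weights, activations):
    """Rescale each layer so that its maximum activation becomes 1.

    weights: list of per-layer weight matrices of a trained ReLU network.
    activations: list of per-layer activation arrays recorded on training data.
    """
    normalized = []
    previous_max = 1.0  # inputs are assumed to lie in [0, 1]
    for W, A in zip(weights, activations):
        current_max = A.max()
        # Scale this layer's output down to at most 1, compensating for the
        # rescaling already applied to the previous layer.
        normalized.append(W * previous_max / current_max)
        previous_max = current_max
    return normalized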

#!pip install ANNarchy
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
print(f"Tensorflow {tf.__version__}")
Tensorflow 2.16.2
# Download data
(X_train, t_train), (X_test, t_test) = tf.keras.datasets.mnist.load_data()

# Normalize inputs
X_train = X_train.astype('float32') / 255.
X_test = X_test.astype('float32') / 255.

# One-hot output vectors
T_train = tf.keras.utils.to_categorical(t_train, 10)
T_test = tf.keras.utils.to_categorical(t_test, 10)

Training an ANN in tensorflow/keras

The convolutional network is built with the tf.keras functional API.

The CNN has three 5×5 convolutional layers (16, 64 and 64 filters) with ReLU activation and no bias terms, each followed by 2×2 max-pooling. After the last pooling layer, dropout with rate 0.25 is applied, the feature maps are flattened, and a softmax output layer with 10 neurons (also without bias) produces the class probabilities. We use the standard SGD optimizer and the categorical cross-entropy loss for classification.

def create_cnn():

    inputs = tf.keras.Input(shape=(28, 28, 1))

    # Three convolutional blocks: 5x5 convolution (ReLU, no bias) + 2x2 max-pooling
    x = tf.keras.layers.Conv2D(
        16,
        kernel_size=(5, 5),
        activation='relu',
        padding='same',
        use_bias=False)(inputs)
    x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(x)
    x = tf.keras.layers.Conv2D(
        64,
        kernel_size=(5, 5),
        activation='relu',
        padding='same',
        use_bias=False)(x)
    x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(x)
    x = tf.keras.layers.Conv2D(
        64,
        kernel_size=(5, 5),
        activation='relu',
        padding='same',
        use_bias=False)(x)
    x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(x)

    # Regularization and read-out
    x = tf.keras.layers.Dropout(0.25)(x)
    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(
        10,
        activation='softmax',
        use_bias=False)(x)

    # Create the functional model
    model = tf.keras.Model(inputs, x)

    # Optimizer: plain SGD
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

    # Loss function and metric
    model.compile(
        loss='categorical_crossentropy',  # loss function
        optimizer=optimizer,              # learning rule
        metrics=['accuracy']              # track accuracy
    )
    model.summary()

    return model
# Create the model
model = create_cnn()

# Train the model
history = model.fit(
    X_train, T_train,      # training data
    batch_size=128,        # batch size
    epochs=20,             # maximum number of epochs
    validation_split=0.1,  # fraction of the training data used for validation
)

model.save("runs/cnn.keras")

# Test model
predictions_keras = model.predict(X_test, verbose=0)
test_loss, test_accuracy = model.evaluate(X_test, T_test, verbose=0)
print(f"Test accuracy: {test_accuracy}")
Model: "functional"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ input_layer (InputLayer)        │ (None, 28, 28, 1)      │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d (Conv2D)                 │ (None, 28, 28, 16)     │           400 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d (MaxPooling2D)    │ (None, 14, 14, 16)     │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_1 (Conv2D)               │ (None, 14, 14, 64)     │        25,600 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_1 (MaxPooling2D)  │ (None, 7, 7, 64)       │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_2 (Conv2D)               │ (None, 7, 7, 64)       │       102,400 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_2 (MaxPooling2D)  │ (None, 3, 3, 64)       │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout (Dropout)               │ (None, 3, 3, 64)       │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ flatten (Flatten)               │ (None, 576)            │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense (Dense)                   │ (None, 10)             │         5,760 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 134,160 (524.06 KB)
 Trainable params: 134,160 (524.06 KB)
 Non-trainable params: 0 (0.00 B)
Epoch 1/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 16s 37ms/step - accuracy: 0.3226 - loss: 2.0190 - val_accuracy: 0.9087 - val_loss: 0.3620
Epoch 2/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 17s 40ms/step - accuracy: 0.8745 - loss: 0.4172 - val_accuracy: 0.9560 - val_loss: 0.1607
Epoch 3/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 16s 38ms/step - accuracy: 0.9303 - loss: 0.2364 - val_accuracy: 0.9662 - val_loss: 0.1238
Epoch 4/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 16s 38ms/step - accuracy: 0.9459 - loss: 0.1799 - val_accuracy: 0.9745 - val_loss: 0.0967
Epoch 5/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 17s 40ms/step - accuracy: 0.9554 - loss: 0.1512 - val_accuracy: 0.9742 - val_loss: 0.0908
Epoch 6/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 17s 41ms/step - accuracy: 0.9601 - loss: 0.1332 - val_accuracy: 0.9775 - val_loss: 0.0792
Epoch 7/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 16s 38ms/step - accuracy: 0.9636 - loss: 0.1189 - val_accuracy: 0.9788 - val_loss: 0.0732
Epoch 8/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 16s 39ms/step - accuracy: 0.9671 - loss: 0.1061 - val_accuracy: 0.9783 - val_loss: 0.0713
Epoch 9/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 17s 40ms/step - accuracy: 0.9692 - loss: 0.0988 - val_accuracy: 0.9797 - val_loss: 0.0690
Epoch 10/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 18s 43ms/step - accuracy: 0.9712 - loss: 0.0937 - val_accuracy: 0.9825 - val_loss: 0.0615
Epoch 11/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 17s 39ms/step - accuracy: 0.9724 - loss: 0.0876 - val_accuracy: 0.9820 - val_loss: 0.0617
Epoch 12/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 18s 42ms/step - accuracy: 0.9745 - loss: 0.0821 - val_accuracy: 0.9833 - val_loss: 0.0569
Epoch 13/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 19s 44ms/step - accuracy: 0.9760 - loss: 0.0780 - val_accuracy: 0.9810 - val_loss: 0.0594
Epoch 14/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 16s 39ms/step - accuracy: 0.9763 - loss: 0.0768 - val_accuracy: 0.9832 - val_loss: 0.0581
Epoch 15/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 18s 42ms/step - accuracy: 0.9777 - loss: 0.0730 - val_accuracy: 0.9818 - val_loss: 0.0633
Epoch 16/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 17s 40ms/step - accuracy: 0.9783 - loss: 0.0688 - val_accuracy: 0.9830 - val_loss: 0.0535
Epoch 17/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 19s 46ms/step - accuracy: 0.9783 - loss: 0.0680 - val_accuracy: 0.9840 - val_loss: 0.0540
Epoch 18/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 17s 39ms/step - accuracy: 0.9810 - loss: 0.0623 - val_accuracy: 0.9853 - val_loss: 0.0500
Epoch 19/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 16s 38ms/step - accuracy: 0.9792 - loss: 0.0634 - val_accuracy: 0.9850 - val_loss: 0.0517
Epoch 20/20
422/422 ━━━━━━━━━━━━━━━━━━━━ 16s 38ms/step - accuracy: 0.9812 - loss: 0.0615 - val_accuracy: 0.9850 - val_loss: 0.0500
Test accuracy: 0.9857000112533569
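
The training history lets us visualize how the loss and accuracy evolve on the training and validation sets: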
plt.figure(figsize=(12, 6))
plt.subplot(121)
plt.plot(history.history['loss'], '-r', label="Training")
plt.plot(history.history['val_loss'], '-b', label="Validation")
plt.xlabel('Epoch #')
plt.ylabel('Loss')
plt.legend()

plt.subplot(122)
plt.plot(history.history['accuracy'], '-r', label="Training")
plt.plot(history.history['val_accuracy'], '-b', label="Validation")
plt.xlabel('Epoch #')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

Initialize the ANN-to-SNN converter

We now create an instance of the ANN-to-SNN converter. The inputs are encoded by intrinsically bursting ('IB') neurons, the hidden layers use integrate-and-fire ('IaF') neurons, and with the 'spike_count' read-out, the predicted class is the output neuron that emits the most spikes.

from ANNarchy.extensions.ann_to_snn_conversion import ANNtoSNNConverter

snn_converter = ANNtoSNNConverter(
    input_encoding='IB', 
    hidden_neuron='IaF',
    read_out='spike_count',
)
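
With read_out='spike_count', the predicted class of a sample is simply the index of the output neuron that emitted the most spikes during the presentation. A toy illustration with made-up spike counts:

# Made-up spike counts of the 10 output neurons for one sample
counts = np.array([0, 2, 31, 1, 0, 3, 0, 5, 2, 0])
predicted_class = np.argmax(counts)  # -> class 2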
ANNarchy 4.8 (4.8.1) on darwin (posix).
net = snn_converter.load_keras_model("runs/cnn.keras", show_info=True)
WARNING: Dense representation is an experimental feature for spiking models, we greatly appreciate bug reports. 
* Input layer: input_layer, (28, 28, 1)
* InputLayer skipped.
* Conv2D layer: conv2d, (28, 28, 16) 
* MaxPooling2D layer: max_pooling2d, (14, 14, 16) 
* Conv2D layer: conv2d_1, (14, 14, 64) 
* MaxPooling2D layer: max_pooling2d_1, (7, 7, 64) 
* Conv2D layer: conv2d_2, (7, 7, 64) 
* MaxPooling2D layer: max_pooling2d_2, (3, 3, 64) 
* Dropout skipped.
* Flatten skipped.
* Dense layer: dense, 10 
    weights: (10, 576)
    mean 5.032593981013633e-05, std 0.06920499354600906
    min -0.26682066917419434, max 0.22082802653312683
Each test image is presented to the spiking network for 200 ms of simulated time, and the class with the highest spike count is recorded as the prediction. To keep the simulation short, only the first 300 test images are evaluated.

predictions_snn = snn_converter.predict(X_test[:300], duration_per_sample=200)
300/300

Using the recorded predictions, we can now compute the accuracy over the presented samples with scikit-learn.

from sklearn.metrics import classification_report, accuracy_score

print(classification_report(t_test[:300], predictions_snn))
print("Test accuracy of the SNN:", accuracy_score(t_test[:300], predictions_snn))
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        24
           1       1.00      0.95      0.97        41
           2       0.97      1.00      0.98        32
           3       1.00      1.00      1.00        24
           4       1.00      0.95      0.97        37
           5       1.00      1.00      1.00        29
           6       1.00      1.00      1.00        24
           7       1.00      1.00      1.00        34
           8       0.91      1.00      0.95        21
           9       0.97      1.00      0.99        34

    accuracy                           0.99       300
   macro avg       0.99      0.99      0.99       300
weighted avg       0.99      0.99      0.99       300

Test accuracy of the SNN: 0.9866666666666667
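
As a sanity check, one can also evaluate the Keras model on the same 300 images and measure how often the two networks agree. A minimal sketch reusing predictions_keras and predictions_snn from the cells above (the confusion matrix is optional):

from sklearn.metrics import accuracy_score, confusion_matrix

# Class predicted by the ANN (argmax of the softmax output)
ann_classes = np.argmax(predictions_keras[:300], axis=1)
print("Test accuracy of the ANN on the same subset:", accuracy_score(t_test[:300], ann_classes))

# Fraction of samples where the ANN and the SNN predict the same class
print("ANN/SNN agreement:", np.mean(ann_classes == np.asarray(predictions_snn)))

# Confusion matrix of the SNN predictions (rows: true digit, columns: predicted)
print(confusion_matrix(t_test[:300], predictions_snn))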

Copyright Julien Vitay, Helge Ülo Dinkelbach, Fred Hamker