#!pip install ANNarchy
Homeostatic STDP: SORF model
Reimplementation of the SORF model published in:
Carlson, K.D.; Richert, M.; Dutt, N.; Krichmar, J.L., “Biologically plausible models of homeostasis and STDP: Stability and learning in spiking neural networks,” in Proc. 2013 International Joint Conference on Neural Networks (IJCNN), pp. 1-8, 4-9 Aug. 2013. doi: 10.1109/IJCNN.2013.6706961
import numpy as np
import matplotlib.pyplot as plt
import ANNarchy as ann
from tqdm import tqdm
ANNarchy 5.0 (5.0.0) on darwin (posix).
Hyperparameters:
nb_neuron = 4      # Number of exc and inh neurons
size = (32, 32)    # Input size (rows, cols) of the ON/OFF input populations
freq = 1.2         # Grating cycles per half-image (the grid spans [-1, 1])
nb_stim = 40       # Number of gratings (orientations) per epoch
nb_epochs = 20     # Number of epochs
max_freq = 28.     # Max frequency of the poisson neurons
T = 10000.         # Period for averaging the firing rate
Neuron type:
# Izhikevich Coba neuron with AMPA, NMDA and GABA receptors
#
# Izhikevich neuron (a=0.02, b=0.2, c=-65, d=8: regular-spiking values)
# driven by four conductance-based synaptic channels. Each conductance decays
# exponentially with its own time constant and pulls v towards its own
# reversal potential; the NMDA current is additionally gated by the
# voltage-dependent factor nmda(v, -80, 60) declared in `functions`.
RSNeuron = ann.Neuron(
    parameters = dict(
        # Izhikevich parameters
        a = 0.02,
        b = 0.2,
        c = -65.,          # value v is reset to after a spike
        d = 8.,            # increment of u after a spike
        # Conductance decay time constants
        tau_ampa = 5.,
        tau_nmda = 150.,
        tau_gabaa = 6.,
        tau_gabab = 150.,
        # Reversal potentials
        vrev_ampa = 0.0,
        vrev_nmda = 0.0,
        vrev_gabaa = -70.0,
        vrev_gabab = -90.0,
    ),
    equations = [
        # Inputs: total synaptic current from the four conductances
        ann.Variable("""
            I = g_ampa * (vrev_ampa - v) + g_nmda * nmda(v, -80.0, 60.0) * (vrev_nmda -v) + g_gabaa * (vrev_gabaa - v) + g_gabab * (vrev_gabab -v)
        """),
        # Midpoint scheme
        ann.Variable("dv/dt = (0.04 * v + 5.0) * v + 140.0 - u + I", init=-65., min=-90., method='midpoint'),
        # u starts at b * v_init = 0.2 * -65 = -13
        ann.Variable("du/dt = a * (b*v - u)", init=-13., method='midpoint'),
        # Conductances
        ann.Variable("tau_ampa * dg_ampa/dt = -g_ampa", method='exponential'),
        ann.Variable("tau_nmda * dg_nmda/dt = -g_nmda", method='exponential'),
        ann.Variable("tau_gabaa * dg_gabaa/dt = -g_gabaa", method='exponential'),
        ann.Variable("tau_gabab * dg_gabab/dt = -g_gabab", method='exponential'),
    ],
    # Emit a spike when v reaches 30
    spike = "v >= 30.",
    # Izhikevich reset; all synaptic conductances are also cleared
    reset = """
        v = c
        u += d
        g_ampa = 0.0
        g_nmda = 0.0
        g_gabaa = 0.0
        g_gabab = 0.0
    """,
    # Sigmoidal voltage dependence of NMDA: ((v-t)/s)^2 / (1 + ((v-t)/s)^2)
    functions = """
        nmda(v, t, s) = ((v-t)/(s))^2 / (1.0 + ((v-t)/(s))^2)
    """,
    refractory = 1.0,
)
Synapse:
# STDP with homeostatic regulation
#
# Plastic synapse combining nearest-neighbour STDP with a homeostatic term
# that scales weights towards a target post-synaptic firing rate:
#   w += (alpha * w * (1 - R/Rtarget) + beta * stdp) * K,   w clipped to [0, 10]
# with R = post.r (the post-synaptic populations call compute_firing_rate(T)
# to provide it) and K = R / (T * (1 + |1 - R/Rtarget| * gamma)).
homeo_stdp = ann.Synapse(
    parameters = dict(
        # STDP
        tau_plus = 60.,       # decay time constant of the LTP trace
        tau_minus = 90.,      # decay time constant of the LTD trace
        A_plus = 0.000045,    # value written into ltp on each pre-synaptic spike
        A_minus = 0.00003,    # value written into ltd on each post-synaptic spike
        # Homeostatic regulation
        alpha = 0.1,
        beta = 50.0, # <- Difference with the original implementation
        gamma = 50.0,
        Rtarget = 10.,        # target post-synaptic firing rate
        T = 10000.,           # same value as the firing-rate averaging period T
    ),
    equations = [
        # Homeostatic values (computed once per post-synaptic neuron)
        ann.Variable("R = post.r", locality='semiglobal'),
        ann.Variable("K = R/(T * (1. + fabs(1. - R / Rtarget) * gamma))", locality='semiglobal'),
        # Nearest-neighbour: LTP trace if the last post spike is not earlier
        # than the last pre spike, otherwise the negated LTD trace.
        ann.Variable("stdp = if t_post >= t_pre: ltp else: - ltd"),
        ann.Variable("w += (alpha * w * (1- R/Rtarget) + beta * stdp ) * K", min=0.0, max=10.0),
        # Traces
        ann.Variable("tau_plus * dltp/dt = -ltp", method="exponential"),
        ann.Variable("tau_minus * dltd/dt = -ltd", method="exponential"),
    ],
    # Transmit the weight, then reset (not accumulate) the LTP trace, which is
    # what makes the rule nearest-neighbour rather than all-to-all.
    pre_spike = """
        g_target += w
        ltp = A_plus
    """,
    post_spike = "ltd = A_minus",
)
Network:
# Network
net = ann.Network()

# Input population: ON and OFF Poisson populations with the image geometry,
# firing at a baseline of 1.0 until a stimulus sets their rates.
OnPoiss = net.create(ann.PoissonPopulation(size, rates=1.0))
OffPoiss = net.create(ann.PoissonPopulation(size, rates=1.0))

# RS neuron for the input buffers
OnBuffer = net.create(size, RSNeuron)
OffBuffer = net.create(size, RSNeuron)

# Connect the buffers: each Poisson unit drives exactly one buffer neuron
# through AMPA and NMDA. No synapse model is given, so these weights are not
# plastic; they are drawn uniformly from [0.2, 0.6].
OnPoissBuffer = net.connect(OnPoiss, OnBuffer, ['ampa', 'nmda'])
OnPoissBuffer.one_to_one(ann.Uniform(0.2, 0.6))
OffPoissBuffer = net.connect(OffPoiss, OffBuffer, ['ampa', 'nmda'])
OffPoissBuffer.one_to_one(ann.Uniform(0.2, 0.6))
# Excitatory and inhibitory neurons
Exc = net.create(nb_neuron, RSNeuron)
Inh = net.create(nb_neuron, RSNeuron)
# Maintain the averaged firing rate `r` (period T) in both populations; the
# homeostatic synapses read it through `post.r`.
Exc.compute_firing_rate(T)
Inh.compute_firing_rate(T)

# Input connections: plastic (homeo_stdp) all-to-all projections from both
# buffers onto every excitatory neuron, initialised uniformly in [0.004, 0.015].
OnBufferExc = net.connect(OnBuffer, Exc, ['ampa', 'nmda'], homeo_stdp)
OnBufferExc.all_to_all(ann.Uniform(0.004, 0.015))
OffBufferExc = net.connect(OffBuffer, Exc, ['ampa', 'nmda'], homeo_stdp)
OffBufferExc.all_to_all(ann.Uniform(0.004, 0.015))
# Competition: plastic all-to-all Exc -> Inh projection ...
ExcInh = net.connect(Exc, Inh, ['ampa', 'nmda'], homeo_stdp)
ExcInh.all_to_all(ann.Uniform(0.116, 0.403))

# ... whose homeo_stdp parameters are overridden for this projection only.
# A_plus and A_minus are negative here, so the STDP term has the opposite
# sign to the one used on the feedforward input weights.
ExcInh.Rtarget = 75.
ExcInh.tau_plus = 51.
ExcInh.tau_minus = 78.
ExcInh.A_plus = -0.000041
ExcInh.A_minus = -0.000015

# Inhibitory feedback onto the excitatory neurons through GABA-A and GABA-B
# (no synapse model given, so these weights are not plastic).
InhExc = net.connect(Inh, Exc, ['gabaa', 'gabab'])
InhExc.all_to_all(ann.Uniform(0.065, 0.259))

net.compile()
Compiling network 1... OK
# Inputs
def get_grating(theta, shape=None, cycles=None):
    """Return the ON and OFF rate maps of a sinusoidal grating.

    The grating ``z = sin(2*pi*cycles*(cos(theta)*x + sin(theta)*y))`` is
    sampled on a regular grid spanning [-1, 1] along both axes. It is then
    split into its positive part (ON map) and the magnitude of its negative
    part (OFF map). Both maps are non-negative and ``on - off == z``.

    Parameters
    ----------
    theta : float
        Orientation of the grating, in radians.
    shape : tuple of two ints, optional
        ``(rows, cols)`` of the returned maps. Defaults to the module-level
        ``size``, the geometry of the Poisson input populations.
    cycles : float, optional
        Number of cycles per half-image (the grid is 2 units wide).
        Defaults to the module-level ``freq``.

    Returns
    -------
    tuple of numpy.ndarray
        ``(on, off)``, each of shape ``shape``.
    """
    if shape is None:
        shape = size
    if cycles is None:
        cycles = freq
    # x varies along columns and y along rows. Sizing them from shape[1] and
    # shape[0] makes meshgrid return arrays of shape (rows, cols), matching
    # the input population geometry even for non-square inputs.
    x = np.linspace(-1., 1., shape[1])
    y = np.linspace(-1., 1., shape[0])
    xx, yy = np.meshgrid(x, y)
    z = np.sin(2. * np.pi * (np.cos(theta) * xx + np.sin(theta) * yy) * cycles)
    return np.maximum(z, 0.), -np.minimum(z, 0.)
# Initial weights, read before any simulation has run
w_on_start = OnBufferExc.w
w_off_start = OffBufferExc.w

# Monitors
m = net.monitor(Exc, 'r')   # averaged firing rate of the excitatory neurons
n = net.monitor(Inh, 'r')   # averaged firing rate of the inhibitory neurons
# Incoming weights of post-synaptic neuron 0 of each plastic projection,
# recorded every 1000 time units.
o = net.monitor(OnBufferExc[0], 'w', period=1000.)
p = net.monitor(ExcInh[0], 'w', period=1000.)
# Learning procedure
from time import time
import random

tstart = time()
stim_order = list(range(nb_stim))
for _ in tqdm(range(nb_epochs)):
    # Present every orientation once per epoch, in a freshly shuffled order.
    random.shuffle(stim_order)
    for stim in stim_order:
        # Grating orientations are evenly spaced in [0, pi)
        rates_on, rates_off = get_grating(np.pi * stim / float(nb_stim))
        # Set it as input to the poisson neurons (peak rate max_freq)
        OnPoiss.rates = max_freq * rates_on
        OffPoiss.rates = max_freq * rates_off
        # Simulate for 2s
        net.simulate(2000.)
        # Relax the Poisson inputs back to their baseline rate
        OnPoiss.rates = 1.
        OffPoiss.rates = 1.
        # Simulate for 500ms
        net.simulate(500.)
print('Done in ', time() - tstart)
# Recordings (indexed below as (recording step, neuron/synapse) arrays)
datae = m.get('r')   # excitatory firing rates
datai = n.get('r')   # inhibitory firing rates
dataw = o.get('w')   # OnBuffer -> Exc weights of Exc neuron 0
datal = p.get('w')   # Exc -> Inh weights of Inh neuron 0
100%|██████████| 20/20 [01:40<00:00, 5.05s/it]
Done in 101.01462006568909
# Final weights, read after the learning procedure
w_on_end = OnBufferExc.w
w_off_end = OffBufferExc.w
# Plot: one column per excitatory neuron. Rows show the ON weights before
# learning, the ON weights after learning and the OFF weights after learning.
plt.figure(figsize=(12, 12))
# suptitle titles the whole figure; plt.title here would instead create a
# stray full-figure axes underneath the subplot grid.
plt.suptitle('Feedforward weights before and after learning')
for i in range(nb_neuron):
    # Neuron i has one weight per buffer neuron, so its weight vector is laid
    # out with the input geometry `size` rather than a hard-coded (32, 32).
    plt.subplot(3, nb_neuron, i + 1)
    plt.imshow(np.reshape(w_on_start[i], size), aspect='auto', cmap='hot')
    plt.subplot(3, nb_neuron, nb_neuron + i + 1)
    plt.imshow(np.reshape(w_on_end[i], size), aspect='auto', cmap='hot')
    plt.subplot(3, nb_neuron, 2 * nb_neuron + i + 1)
    plt.imshow(np.reshape(w_off_end[i], size), aspect='auto', cmap='hot')
# Averaged firing rate `r` of the first excitatory and first inhibitory neuron
plt.figure(figsize=(12, 8))
plt.plot(datae[:, 0], label='Exc')
plt.plot(datai[:, 0], label='Inh')
plt.title('Mean FR of the Exc and Inh neurons')
plt.legend()
# Weight timecourses. After the transpose, each row is presumably one synapse
# and each column one recording step — confirm against the monitor layout.
plt.figure(figsize=(12, 8))
plt.subplot(121)
plt.imshow(np.array(dataw, dtype='float').T, aspect='auto', cmap='hot')
plt.title('Timecourse of feedforward weights')
plt.colorbar()
plt.subplot(122)
# These are the Exc -> Inh weights onto inhibitory neuron 0.
plt.imshow(np.array(datal, dtype='float').T, aspect='auto', cmap='hot')
plt.title('Timecourse of inhibitory weights')
plt.colorbar()
plt.show()