#!pip install ANNarchy
STDP - single synapse
This notebook demonstrates the online implementation of the spike-timing-dependent plasticity (STDP) rule for a pair of neurons.
import numpy as np
import matplotlib.pyplot as plt
import ANNarchy as ann
ANNarchy 5.0 (5.0.0) on darwin (posix).
The STDP learning rule maintains exponentially decaying traces for the pre-synaptic and post-synaptic spikes.
\tau^+ \, \frac{d x(t)}{dt} = -x(t)
\tau^- \, \frac{d y(t)}{dt} = -y(t)
LTP and LTD occur at spike times depending on the corresponding traces.
- When a pre-synaptic spike occurs, x(t) is incremented and LTD is applied proportionally to y(t).
- When a post-synaptic spike occurs, y(t) is incremented and LTP is applied proportionally to x(t).
# STDP synapse with two exponentially decaying traces:
#   x: pre-synaptic trace (time constant tau_plus), incremented on pre spikes;
#   y: post-synaptic trace (time constant tau_minus), incremented on post spikes.
# A pre-synaptic spike depresses the weight by y (LTD), and a post-synaptic
# spike potentiates it by x (LTP). The weight is always clipped to
# [w_min, w_max].
STDP = ann.Synapse(
    parameters = dict(
        tau_plus = 20.0,   # decay time constant of the pre-synaptic trace x
        tau_minus = 20.0,  # decay time constant of the post-synaptic trace y
        A_plus = 0.01,     # x increment per pre spike, as a fraction of w_max
        A_minus = 0.01,    # y increment per post spike, as a fraction of w_max
        w_min = 0.0,       # lower bound of the weight
        w_max = 2.0,       # upper bound of the weight
    ),
    equations = [
        # Pre-synaptic trace
        ann.Variable('tau_plus * dx/dt = -x', method='event-driven'),
        # Post-synaptic trace
        ann.Variable('tau_minus * dy/dt = -y', method='event-driven'),
    ],
    # On a pre-synaptic spike: transmit w to the post-synaptic target,
    # bump the pre trace, then depress the weight by the post trace.
    pre_spike="""
        g_target += w
        x += A_plus * w_max
        w = clip(w - y, w_min , w_max) # LTD
    """,
    # On a post-synaptic spike: bump the post trace, then potentiate the
    # weight by the pre trace.
    post_spike="""
        y += A_minus * w_max
        w = clip(w + x, w_min , w_max) # LTP
    """
)
We create two dummy populations with one neuron each, whose spike times we can control.
# Network holding two single-neuron spike sources whose firing times are set
# explicitly: `pre` initially fires at 0 ms and `post` at 50 ms.
net = ann.Network()
pre = net.create(ann.SpikeSourceArray([[0.]]))
post = net.create(ann.SpikeSourceArray([[50.]]))
We connect the populations using an STDP synapse.
# Connect pre -> post on the 'exc' target with the STDP synapse defined above,
# using all-to-all connectivity (here a single synapse) with weight 1.0.
proj = net.connect(pre, post, 'exc', STDP)
proj.all_to_all(1.0)
<ANNarchy.core.Projection.Projection at 0x1125f0e30>
# Generate and compile the simulation code for the network before running it.
net.compile()
Compiling network 1... OK
In each trial, the pre-synaptic neuron fires once at a time between 0 and 100 ms, while the post-synaptic neuron always fires at 50 ms.
# Sweep the pre-synaptic spike time from 100 ms down to 0 ms in 1 ms steps
# (101 trials) while the post-synaptic spike stays at 50 ms. Record the weight
# change produced by each single pre/post pairing.
# NOTE: the weight itself is not reset between trials; each delta_w is
# measured relative to the weight left by the previous trial.
pre_times = np.linspace(100.0, 0.0, 101)
weight_changes = []
for t_pre in pre_times:
    # Reset the populations and schedule this trial's spikes
    pre.clear()
    post.clear()
    pre.spike_times = [[t_pre]]
    post.spike_times = [[50.0]]

    # Reset the traces so that earlier trials do not carry over into this one
    proj.x = 0.0
    proj.y = 0.0

    # Weight before the simulation (first dendrite, first synapse)
    w_before = proj[0].w[0]

    # Simulate long enough: 105 ms extends past the latest spike time used (100 ms)
    net.simulate(105.0)

    # Record weight change
    delta_w = proj[0].w[0] - w_before
    weight_changes.append(delta_w)
We can now plot the classical STDP figure:
# Classical STDP window: weight change as a function of t_post - t_pre
# (the post-synaptic spike is at 50 ms). The black lines mark delta_w = 0 and
# t_post - t_pre = 0.
plt.figure(figsize=(10, 8))
plt.plot(50. - pre_times, weight_changes, "*")
plt.plot([-50, 50], [0, 0], 'k')
plt.plot([0, 0], [min(weight_changes), max(weight_changes)], 'k')
plt.xlabel("t_post - t_pre")
plt.ylabel("delta_w")
plt.show()