#!pip install ANNarchy
Miconi network
Reward-modulated recurrent network based on:
Miconi T. (2017). Biologically plausible learning in recurrent neural networks reproduces neural dynamics observed during cognitive tasks. eLife 6:e20899. doi:10.7554/eLife.20899
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import ANNarchy as ann
ANNarchy 5.0 (5.0.0) on darwin (posix).
Each neuron in the reservoir obeys the following equations:
\tau \frac{dx(t)}{dt} + x(t) = \sum_\text{input} W^\text{IN} \, r^\text{IN}(t) + \sum_\text{rec} W^\text{REC} \, r(t) + \xi(t)
r(t) = \tanh(x(t))
where \xi(t) is a random perturbation at 3 Hz, with an amplitude randomly sampled between -A and +A.
We additionally keep track of the mean firing rate with a sliding average:
\tilde{x}(t) = \alpha \, \tilde{x}(t-1) + (1 - \alpha) \, x(t)
The first three neurons keep a constant rate throughout learning (+1 or -1) to provide a bias to the other neurons.
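To make the discretization explicit, here is a minimal NumPy sketch of a single Euler step of these dynamics for one neuron, outside ANNarchy (ANNarchy's default time step dt = 1 ms is assumed, and the variable names are purely illustrative):

import numpy as np

# Parameter values taken from the Neuron definition below
dt, tau, alpha, f, A = 1.0, 30.0, 0.05, 3.0, 16.0

rng = np.random.default_rng(42)
x, x_mean = 0.0, 0.0
total_input = 0.1   # stands for sum(in) + sum(exc), assumed constant here

perturbation = 1.0 if rng.uniform() < f / 1000. else 0.0    # ~3 Hz with dt = 1 ms
noise = A * rng.uniform(-1.0, 1.0) if perturbation > 0.5 else 0.0
x += dt * (total_input - x + noise) / tau
r = np.tanh(x)
delta_x = x - x_mean                        # deviation from the sliding mean
x_mean = alpha * x_mean + (1 - alpha) * x   # sliding mean of the potential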
neuron = ann.Neuron(
    parameters = dict(
        tau = 30.0, # Time constant
        constant = ann.Parameter(0.0), # The first three neurons have constant rates
        alpha = 0.05, # To compute the sliding mean
        f = 3.0, # Frequency of the perturbation
        A = 16., # Perturbation amplitude. dt*A/tau should be 0.5...
    ),
    equations = [
        # Perturbation
        'perturbation = if Uniform(0.0, 1.0) < f/1000.: 1.0 else: 0.0',
        'noise = if perturbation > 0.5: A * Uniform(-1.0, 1.0) else: 0.0',
        # ODE for x
        'x += dt*(sum(in) + sum(exc) - x + noise)/tau',
        # Output r
        'rprev = r', # store r at previous time step
        'r = if constant == 0.0: tanh(x) else: tanh(constant)',
        # Sliding mean
        'delta_x = x - x_mean',
        'x_mean = alpha * x_mean + (1 - alpha) * x',
    ]
)
The learning rule is defined by a trace e_{i, j}(t) for each synapse i \rightarrow j, incremented at each time step with the cube of the product between the presynaptic rate and the deviation of the postsynaptic potential from its sliding mean:
e_{i, j}(t) = e_{i, j}(t-1) + \left( r_i(t-1) \, (x_j(t) - \tilde{x}_j(t)) \right)^3
At the end T of a trial, the reward R is delivered and all weights are updated using:
\Delta w_{i, j} = \eta \, e_{i, j}(T) \, |R_\text{mean}| \, (R - R_\text{mean})
where R_\text{mean} is the running mean of the reward for that trial type. Here the reward is defined as the opposite of the absolute prediction error between the target and the output.
All traces are then reset to 0 for the next trial. Weight changes are clamped between -0.0003 and 0.0003.
As ANNarchy applies the synaptic equations at every time step, we introduce a global boolean flag learning_phase: while it is false, the trace is integrated; when it is set to true (for a single simulation step at the end of the trial), the weights are updated.
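Written for a single synapse in plain NumPy, the rule amounts to the following sketch (made-up numbers, not the ANNarchy implementation defined below):

import numpy as np

eta, max_weight_change = 0.5, 0.0003
w = 0.01   # some initial weight

# During the trial (learning_phase is False): accumulate the trace at every step
trace = 0.0
for r_pre, delta_x_post in [(0.2, 0.05), (0.4, -0.02), (0.1, 0.01)]:   # made-up samples
    trace += (r_pre * delta_x_post) ** 3

# At the end of the trial (learning_phase is True): clipped weight update
R, R_mean = -0.3, -0.5   # made-up reward and mean reward
delta_w = np.clip(eta * trace * abs(R_mean) * (R - R_mean),
                  -max_weight_change, max_weight_change)
w += delta_w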
synapse = ann.Synapse(
    parameters = dict(
        eta = 0.5, # Learning rate
        max_weight_change = 0.0003, # Clip the weight changes
        # Flag to allow learning only at the end of a trial
        learning_phase = ann.Parameter(False, 'global', 'bool'),
        reward = 0.0, # Reward received
        mean_reward = 0.0, # Mean reward received
    ),
    equations = [
        # Trace
        """
        trace += if not(learning_phase):
                    power(pre.rprev * (post.delta_x), 3)
                 else:
                    0.0
        """,
        # Weight update only at the end of the trial
        ann.Variable("""
        delta_w = if learning_phase:
                    eta * trace * fabs(mean_reward) * (reward - mean_reward)
                  else:
                    0.0
        """, min='-max_weight_change', max='max_weight_change'),
        # Weight update
        "w += delta_w",
    ]
)
We implement the network as a class deriving from ann.Network. The network has two inputs A and B, so we create the corresponding static population. The reservoir has 200 neurons, 3 of which have constant rates to serve as biases for the other neurons.
Input weights are uniformly distributed between -1 and 1.
The recurrent weights are normally distributed, with a coupling strength of g=1.5 (edge of chaos). In the original paper, the projection is fully connected (but self-connections are avoided). Using a sparse (0.1) connectivity matrix leads to similar results and is much faster.
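The 1/\sqrt{p\,N} scaling keeps the variance of the total recurrent input, and hence the spectral radius of the weight matrix (approximately g), independent of the connection probability p. A quick NumPy check of this scaling, outside ANNarchy:

import numpy as np

N, g, p = 200, 1.5, 0.1
rng = np.random.default_rng(0)

# Random matrix with the same statistics as the sparse recurrent projection
mask = rng.uniform(size=(N, N)) < p
W = mask * rng.normal(0.0, g / np.sqrt(p * N), size=(N, N))
np.fill_diagonal(W, 0.0)   # no self-connections

print(np.max(np.abs(np.linalg.eigvals(W))))   # spectral radius, close to g = 1.5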
class MiconiNetwork (ann.Network):
def __init__(self, N, g, sparseness):
# Input population
self.inp = self.create(2, ann.Neuron("r=0.0"))
# Recurrent population
self.pop = self.create(N, neuron)
# Biases
self.pop[0].constant = 1.0
self.pop[1].constant = 1.0
self.pop[2].constant = -1.0
# Input weights
self.Wi = self.connect(self.inp, self.pop, 'in')
self.Wi.all_to_all(weights=ann.Uniform(-1.0, 1.0))
# Recurrent weights
self.Wrec = self.connect(self.pop, self.pop, 'exc', synapse)
if sparseness == 1.0:
self.Wrec.all_to_all(weights=ann.Normal(0., g/np.sqrt(N)))
else:
            self.Wrec.fixed_probability(
                probability=sparseness,
                weights=ann.Normal(0., g/np.sqrt(sparseness*N))
            )
# Monitor
self.m = self.monitor(self.pop, ['r'], start=False)
net = MiconiNetwork(N=200, g=1.5, sparseness=0.1)
net.compile()
Compiling network 1... OK
The output of the reservoir is chosen to be the neuron of index 100.
OUTPUT_NEURON = 100
Parameters defining the task:
# Durations
d_stim = 200
d_delay = 200
d_response = 200
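Each trial therefore lasts d_stim + d_delay + d_stim + d_delay + d_response = 1000 ms (1000 steps with ANNarchy's default dt = 1 ms), and only the last 200 ms (the response window) are used to compute the reward.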
Definition of a DNMS trial (AA, AB, BA, BB):
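In this delayed non-match-to-sample (DNMS) task, the target output is +1 when the two successive stimuli differ (AB, BA) and -1 when they are identical (AA, BB); the targets passed below are ±0.98 rather than exactly ±1.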
def dnms_trial(trial_number, input, target, R_mean, record=False, perturbation=True):
    # Switch off perturbations if needed
    if not perturbation:
        old_A = net.pop.A
        net.pop.A = 0.0

    # Reinitialize network
    net.pop.x = ann.Uniform(-0.1, 0.1).get_values(net.pop.size)
    net.pop.r = np.tanh(net.pop.x)
    net.pop[0].r = np.tanh(1.0)
    net.pop[1].r = np.tanh(1.0)
    net.pop[2].r = np.tanh(-1.0)

    if record: net.m.resume()

    # First input
    net.inp[input[0]].r = 1.0
    net.simulate(d_stim)

    # Delay
    net.inp.r = 0.0
    net.simulate(d_delay)

    # Second input
    net.inp[input[1]].r = 1.0
    net.simulate(d_stim)

    # Delay
    net.inp.r = 0.0
    net.simulate(d_delay)

    # Response
    if not record: net.m.resume()
    net.inp.r = 0.0
    net.simulate(d_response)

    # Read the output
    net.m.pause()
    recordings = net.m.get('r')

    # Response is over the last 200 ms
    output = recordings[-int(d_response):, OUTPUT_NEURON] # neuron 100 over the last 200 ms

    # Compute the reward as the opposite of the absolute error
    reward = - np.mean(np.abs(target - output))

    # The first 25 trials do not learn, to let R_mean get realistic values
    if trial_number > 25:
        # Apply the learning rule
        net.Wrec.learning_phase = True
        net.Wrec.reward = reward
        net.Wrec.mean_reward = R_mean

        # Learn for one step
        net.step()

        # Reset the traces
        net.Wrec.learning_phase = False
        net.Wrec.trace = 0.0
        #_ = m.get() # to flush the recording of the last step

    # Switch perturbations back on if needed
    if not perturbation:
        net.pop.A = old_A

    return recordings, reward
Let’s visualize the activity of the output neuron during the first four trials.
# Perform the four different trials successively
initialAA, errorAA = dnms_trial(0, [0, 0], -0.98, 0.0, record=True)
initialAB, errorAB = dnms_trial(0, [0, 1], +0.98, 0.0, record=True)
initialBA, errorBA = dnms_trial(0, [1, 0], +0.98, 0.0, record=True)
initialBB, errorBB = dnms_trial(0, [1, 1], -0.98, 0.0, record=True)

plt.figure(figsize=(12, 10))
ax = plt.subplot(221)
ax.plot(initialAA[:, OUTPUT_NEURON])
ax.set_ylim((-1., 1.))
ax.set_title('Output AA -1')
ax = plt.subplot(222)
ax.plot(initialBA[:, OUTPUT_NEURON])
ax.set_ylim((-1., 1.))
ax.set_title('Output BA +1')
ax = plt.subplot(223)
ax.plot(initialAB[:, OUTPUT_NEURON])
ax.set_ylim((-1., 1.))
ax.set_title('Output AB +1')
ax = plt.subplot(224)
ax.plot(initialBB[:, OUTPUT_NEURON])
ax.set_ylim((-1., 1.))
ax.set_title('Output BB -1')
plt.show()
We can now run the training loop for 10000 trials of each type. Beware, this can take around 15 minutes.
# Compute the mean reward per trial type
R_mean = - np.ones((2, 2))
alpha = 0.75

# Many trials of each type
record_rewards = []

for trial in (t := tqdm(range(10000))):

    # Perform the four different trials successively
    _, rewardAA = dnms_trial(trial, [0, 0], -0.98, R_mean[0, 0])
    _, rewardAB = dnms_trial(trial, [0, 1], +0.98, R_mean[0, 1])
    _, rewardBA = dnms_trial(trial, [1, 0], +0.98, R_mean[1, 0])
    _, rewardBB = dnms_trial(trial, [1, 1], -0.98, R_mean[1, 1])

    # Rewards obtained for each trial type
    reward = np.array([[rewardAA, rewardAB], [rewardBA, rewardBB]])

    # Update the mean reward
    R_mean = alpha * R_mean + (1. - alpha) * reward

    record_rewards.append(R_mean)

    t.set_description(
        f'AA: {R_mean[0, 0]:.2f} AB: {R_mean[0, 1]:.2f} BA: {R_mean[1, 0]:.2f} BB: {R_mean[1, 1]:.2f}'
    )
AA: -0.07 AB: -0.28 BA: -0.28 BB: -0.06: 100%|██████████| 10000/10000 [13:32<00:00, 12.31it/s]
record_rewards = np.array(record_rewards)

plt.figure(figsize=(10, 6))
plt.plot(record_rewards[:, 0, 0], label='AA')
plt.plot(record_rewards[:, 0, 1], label='AB')
plt.plot(record_rewards[:, 1, 0], label='BA')
plt.plot(record_rewards[:, 1, 1], label='BB')
plt.plot(record_rewards.mean(axis=(1, 2)), label='mean')
plt.xlabel("Trials")
plt.ylabel("Mean reward")
plt.legend()
plt.show()
# Perform the four different trials without perturbation for testing
recordsAA, errorAA = dnms_trial(0, [0, 0], -0.98, 0.0, record=True, perturbation=False)
recordsAB, errorAB = dnms_trial(0, [0, 1], +0.98, 0.0, record=True, perturbation=False)
recordsBA, errorBA = dnms_trial(0, [1, 0], +0.98, 0.0, record=True, perturbation=False)
recordsBB, errorBB = dnms_trial(0, [1, 1], -0.98, 0.0, record=True, perturbation=False)

plt.figure(figsize=(12, 10))
plt.subplot(221)
plt.plot(initialAA[:, OUTPUT_NEURON], label='before')
plt.plot(recordsAA[:, OUTPUT_NEURON], label='after')
plt.legend()
plt.ylim((-1., 1.))
plt.title('Trial AA : t=-1')
plt.subplot(222)
plt.plot(initialBA[:, OUTPUT_NEURON], label='before')
plt.plot(recordsBA[:, OUTPUT_NEURON], label='after')
plt.ylim((-1., 1.))
plt.title('Trial BA : t=+1')
plt.subplot(223)
plt.plot(initialAB[:, OUTPUT_NEURON], label='before')
plt.plot(recordsAB[:, OUTPUT_NEURON], label='after')
plt.ylim((-1., 1.))
plt.title('Trial AB : t=+1')
plt.subplot(224)
plt.plot(initialBB[:, OUTPUT_NEURON], label='before')
plt.plot(recordsBB[:, OUTPUT_NEURON], label='after')
plt.ylim((-1., 1.))
plt.title('Trial BB : t=-1')
plt.show()