#!pip install ANNarchy
Convolutions and pooling
This simple example in examples/image demonstrates how to load images directly into the firing rates of a population and apply basic linear filters to them.
It relies on the ANNarchy extensions image and convolution, which must be explicitly imported:
import numpy as np
import ANNarchy as ann
from ANNarchy.extensions.image import ImagePopulation
from ANNarchy.extensions.convolution import Convolution, Pooling
ann.clear()
ANNarchy 4.8 (4.8.2) on darwin (posix).
ANNarchy.extensions.image depends on the Python bindings of OpenCV, which must be installed before running the script.
We first create an ImagePopulation that will load images:
image = ImagePopulation(geometry=(480, 640, 3))
Its geometry specifies the size of the images that can be loaded, here 640x480 RGB images. Note that the geometry must be of the form (height, width, channels), where channels is 1 for grayscale images and 3 for color images.
The next step is to reduce the size of the image, which can be done with the Pooling class of the convolution extension.
We define a dummy artificial neuron whose firing rate r is simply the sum of its excitatory inputs (clipped to be non-negative, although this should always be the case). We then create a smaller population pooled with this neuron type and connect it to the ImagePopulation using mean-pooling:
# Simple ANN
LinearNeuron = ann.Neuron(equations="r=sum(exc): min=0.0")

# Subsampling population
pooled = ann.Population(geometry=(48, 64, 3), neuron=LinearNeuron)

# Mean-pooling projection
pool_proj = Pooling(pre=image, post=pooled, target='exc', operation='mean')
pool_proj.connect_pooling()
<ANNarchy.extensions.convolution.Pooling.Pooling at 0x117fe1010>
The pooled population reduces the size of the image by a factor of ten (defined by the size of the population) by averaging the pixel values over 10x10 regions (operation is set to 'mean', but one could use 'max' or 'min'). The connect_pooling() connector creates the “fake” connection pattern (as no weights are involved).
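As a rough cross-check of what mean-pooling computes, the same reduction can be sketched in plain NumPy, independently of ANNarchy. The helper mean_pool and the random test image are illustrative assumptions, not part of the extension:

```python
import numpy as np

def mean_pool(img, factor):
    """Average non-overlapping factor x factor blocks, separately per channel."""
    h, w, c = img.shape
    assert h % factor == 0 and w % factor == 0, "image size must be divisible by factor"
    # Split each spatial axis into (block index, offset inside the block)
    blocks = img.reshape(h // factor, factor, w // factor, factor, c)
    # Average over the two offset axes
    return blocks.mean(axis=(1, 3))

rng = np.random.default_rng(0)
img = rng.random((480, 640, 3))   # stand-in for a loaded (height, width, channels) image
pooled_ref = mean_pool(img, 10)
print(pooled_ref.shape)           # (48, 64, 3)
```

Each entry of pooled_ref is the average of one 10x10 patch of one channel, which is the quantity the text attributes to the pooled population.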
Let’s now apply a 3x3 box filter to each channel of the pooled population:
# Smoothing population
smoothed = ann.Population(geometry=(48, 64, 3), neuron=LinearNeuron)

# Box filter projection
box_filter = np.ones((3, 3, 1))/9.
smooth_proj = Convolution(pre=pooled, post=smoothed, target='exc')
smooth_proj.connect_filter(weights=box_filter)
<ANNarchy.extensions.convolution.Convolve.Convolution at 0x13016d2b0>
To perform a convolution operation on the population (or, more precisely, a cross-correlation), we call the connect_filter() connector method of the Convolution projection. It requires a kernel (weights) that will be convolved over the input population. Here we use a simple box filter, but any filter can be used.
As the pooled population has three dimensions and we want to smooth the activities per color channel, we need to define a (3, 3, 1) kernel. If we also wanted to smooth over the color channels, we could have used a (3, 3) filter: the resulting population would have the shape (48, 64).
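To make the per-channel smoothing concrete, here is a minimal NumPy sketch of a 3x3 box cross-correlation applied to each channel separately. The function name is hypothetical, and the zero padding at the borders is an assumption made for this sketch; the border handling of the Convolution projection may differ:

```python
import numpy as np

def box_filter_per_channel(x, size=3):
    """Cross-correlate every channel of x with a size x size box kernel.

    Borders are zero-padded (an assumption of this sketch).
    """
    pad = size // 2
    # Pad only the two spatial axes, never the channel axis
    padded = np.pad(x, ((pad, pad), (pad, pad), (0, 0)))
    out = np.zeros_like(x, dtype=float)
    # Sum the size*size shifted copies, then normalize
    for dy in range(size):
        for dx in range(size):
            out += padded[dy:dy + x.shape[0], dx:dx + x.shape[1], :]
    return out / size**2

x = np.ones((48, 64, 3))
y = box_filter_per_channel(x)
print(y.shape)  # (48, 64, 3)
```

With a constant input, interior pixels keep their value while border pixels are attenuated by the zero padding. Because a box kernel is symmetric, cross-correlation and convolution give the same result here.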
We now apply a bank of three filters, each selective to a particular color (red/green/blue). These filters do not have a spatial extent (1x1 convolution), but sum over the third dimension (the color channels):
# Convolution population
filtered = ann.Population(geometry=(48, 64, 3), neuron=LinearNeuron)

# Red/Green/Blue filter bank (channel order R, G, B)
filter_bank = np.array([
    [[ [2.0, -1.0, -1.0] ]] , # Red filter
    [[ [-1.0, 2.0, -1.0] ]] , # Green filter
    [[ [-1.0, -1.0, 2.0] ]]   # Blue filter
])
filter_proj = Convolution(pre=smoothed, post=filtered, target='exc')
filter_proj.connect_filters(weights=filter_bank)
<ANNarchy.extensions.convolution.Convolve.Convolution at 0x1300f1810>
Each of the three filters has the shape (1, 1, 3). The result of each convolution would then be (48, 64), but as there are three filters, the output population is (48, 64, 3). The last dimension does not correspond to the number of color channels, but to the number of filters in the bank: if you add a filter, the population will have to be (48, 64, 4).
Banks of filters require connect_filters() instead of connect_filter().
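The shape bookkeeping of the filter bank can be checked with a small NumPy sketch. The helper apply_bank is hypothetical and only mimics the arithmetic described above (a weighted sum over the color channels, clipped at zero like the min=0.0 flag of the neuron); it does not reproduce how the projection is implemented:

```python
import numpy as np

# Same layout as in the text: (number of filters, 1, 1, channels)
bank = np.array([
    [[[ 2.0, -1.0, -1.0]]],   # responds to red
    [[[-1.0,  2.0, -1.0]]],   # responds to green
    [[[-1.0, -1.0,  2.0]]],   # responds to blue
])
print(bank.shape)  # (3, 1, 1, 3)

def apply_bank(x, bank):
    """Apply a bank of 1x1 filters to an (H, W, C) array, giving (H, W, F)."""
    weights = bank[:, 0, 0, :]                         # (filters, channels)
    summed = np.einsum('hwc,fc->hwf', x, weights)      # weighted sum over channels
    return np.maximum(summed, 0.0)                     # clip at zero, like min=0.0

pure_red = np.zeros((48, 64, 3))
pure_red[:, :, 0] = 1.0
out = apply_bank(pure_red, bank)
print(out.shape)  # (48, 64, 3): the last axis indexes filters, not colors
```

For a purely red input only the first filter responds; the other two sums are negative and are clipped to zero. A gray input (equal R, G and B) gives zero for all three filters, since each filter's weights sum to zero.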
ann.compile()
Compiling ... OK
After compilation, we can load an image into the input population:
image.set_image('test.jpg')
To see the result, we need to simulate for four time steps (4 milliseconds, as dt=1.0).
- Step 1: The image population loads the image.
- Step 2: The pooled population subsamples the image.
- Step 3: The smoothed population filters the pooled image.
- Step 4: The bank of filters is applied by filtered.
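The one-step delay per stage listed above can be illustrated with a toy model in plain Python. This is only a sketch of the latency, under the assumption that every population reads what its predecessor held at the previous step; it is not a description of ANNarchy's scheduler, and the names step, stages and history are made up for the example:

```python
def step(stages, stimulus):
    """One synchronous update: each stage takes what its predecessor held before the step."""
    return [stimulus] + stages[:-1]

names = ["image", "pooled", "smoothed", "filtered"]
stages = [None] * len(names)   # nothing has been computed yet
history = []

for t in range(4):
    stages = step(stages, "img")
    history.append(list(stages))
    reached = [n for n, s in zip(names, stages) if s is not None]
    print(f"after step {t + 1}: {reached}")
```

In this toy model the last stage only holds data after the fourth update, which matches the four simulation steps used in the text.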
ann.simulate(4.0)
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(15, 20))

# Row 1: original image
plt.subplot(532)
plt.imshow(image.r)
plt.title('Original')

# Row 2: color channels of the input image
plt.subplot(534)
plt.imshow(image.r[:,:,0], cmap='gray', interpolation='nearest', vmin= 0.0, vmax=1.0)
plt.title('image R')
plt.subplot(535)
plt.imshow(image.r[:,:,1], cmap='gray', interpolation='nearest', vmin= 0.0, vmax=1.0)
plt.title('image G')
plt.subplot(536)
plt.imshow(image.r[:,:,2], cmap='gray', interpolation='nearest', vmin= 0.0, vmax=1.0)
plt.title('image B')

# Row 3: pooled population
plt.subplot(537)
plt.imshow(pooled.r[:,:,0], cmap='gray', interpolation='nearest', vmin= 0.0, vmax=1.0)
plt.title('pooled R')
plt.subplot(538)
plt.imshow(pooled.r[:,:,1], cmap='gray', interpolation='nearest', vmin= 0.0, vmax=1.0)
plt.title('pooled G')
plt.subplot(539)
plt.imshow(pooled.r[:,:,2], cmap='gray', interpolation='nearest', vmin= 0.0, vmax=1.0)
plt.title('pooled B')

# Row 4: smoothed population
plt.subplot(5, 3, 10)
plt.imshow(smoothed.r[:,:,0], cmap='gray', interpolation='nearest', vmin= 0.0, vmax=1.0)
plt.title('smoothed R')
plt.subplot(5, 3, 11)
plt.imshow(smoothed.r[:,:,1], cmap='gray', interpolation='nearest', vmin= 0.0, vmax=1.0)
plt.title('smoothed G')
plt.subplot(5, 3, 12)
plt.imshow(smoothed.r[:,:,2], cmap='gray', interpolation='nearest', vmin= 0.0, vmax=1.0)
plt.title('smoothed B')

# Row 5: output of the filter bank
plt.subplot(5, 3, 13)
plt.imshow(filtered.r[:,:,0], cmap='gray', interpolation='nearest', vmin= 0.0, vmax=1.0)
plt.title('filtered R')
plt.subplot(5, 3, 14)
plt.imshow(filtered.r[:,:,1], cmap='gray', interpolation='nearest', vmin= 0.0, vmax=1.0)
plt.title('filtered G')
plt.subplot(5, 3, 15)
plt.imshow(filtered.r[:,:,2], cmap='gray', interpolation='nearest', vmin= 0.0, vmax=1.0)
plt.title('filtered B')

plt.show()