Skip to content

Commit 2cd4ca6

Browse files
author
rbodo
committed
Improved temporal pattern code: Correctly handle non-normalized inputs and linear activation functions; implement low-precision activations; more efficient computation of the number of ops.
1 parent 420c5c8 commit 2cd4ca6

File tree

3 files changed

+60
-59
lines changed

3 files changed

+60
-59
lines changed

snntoolbox/simulation/backends/inisim/temporal_pattern.py

+18-47
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
@author: rbodo
1414
"""
1515

16-
import numpy as np
1716
import tensorflow as tf
1817
from tensorflow.keras.layers import Dense, Flatten, AveragePooling2D, Layer, \
1918
MaxPooling2D, Conv2D, Concatenate, DepthwiseConv2D, Reshape, ZeroPadding2D
@@ -80,31 +79,37 @@ def spike_call(self, x, call):
8079
self._a = tf.Variable(lambda: tf.zeros_like(x), name='activation',
8180
trainable=False)
8281

83-
# In case of centered input layer, some x values could be negative.
82+
# If not using ReLU, some x values could be negative.
8483
# Remove and store signs to apply after binarization.
8584
signs = tf.sign(x)
8685
x = tf.abs(x)
8786

88-
# Make sure x is normalized before binarization.
89-
x_max = tf.reduce_max(x)
90-
x = tf.divide(x, x_max)
87+
# Make sure input is normalized before binarization. Hidden layers are
88+
# normalized during parsing.
89+
if self.is_first_spiking:
90+
x_max = tf.reduce_max(x)
91+
x = tf.divide(x, x_max)
92+
else:
93+
x_max = 1
9194

9295
# Transform x into binary format here. Effective batch_size increases
9396
# from 1 to num_bits.
94-
x_b = self.to_binary(x)
97+
x = self.to_binary(x)
9598

9699
# Apply signs and rescale back to original range.
97-
x_b = tf.multiply(x_b, signs * x_max)
100+
x = tf.multiply(x, signs * x_max)
98101

99102
# Perform layer operation, e.g. convolution, on every power of 2.
100-
x_b = call(self, x_b)
103+
y = call(self, x)
101104

102105
# Add up the weighted powers of 2 to recover the activation values.
103-
y = tf.reduce_sum(x_b, 0, keepdims=True)
106+
y = tf.reduce_sum(y, 0, keepdims=True)
104107

105108
# Apply non-linearity.
106-
y = tf.nn.softmax(y) if self.activation_str == 'softmax' \
107-
else tf.nn.relu(y)
109+
if self.activation_str == 'softmax':
110+
y = tf.nn.softmax(y)
111+
elif self.activation_str == 'relu':
112+
y = tf.nn.relu(y)
108113

109114
self.spikerates.assign(y)
110115

@@ -130,7 +135,8 @@ def to_binary(self, x):
130135
``x`` is distributed across the first dimension of ``x_binary``.
131136
"""
132137

133-
self._a.assign(x)
138+
n = 2 ** self.num_bits - 1
139+
self._a.assign(tf.divide(tf.round(tf.multiply(x, n)), n))
134140

135141
for i in tf.range(self.num_bits):
136142
mask = tf.cast(tf.greater(self._a, self.powers[i]), tf.float32)
@@ -143,41 +149,6 @@ def to_binary(self, x):
143149
return self._x_binary
144150

145151

146-
def to_binary_numpy(x, num_bits):
147-
"""Transform an array of floats into binary representation.
148-
149-
Parameters
150-
----------
151-
152-
x: ndarray
153-
Input array containing float values. The first dimension has to be of
154-
length 1.
155-
num_bits: int
156-
The fixed point precision to be used when converting to binary.
157-
158-
Returns
159-
-------
160-
161-
binary_array: ndarray
162-
Output boolean array. The first dimension of x is expanded to length
163-
``bits``. The binary representation of each value in ``x`` is
164-
distributed across the first dimension of ``binary_array``.
165-
"""
166-
167-
x_binary = np.zeros([num_bits] + list(x.shape[1:]))
168-
169-
powers = [2**-(i+1) for i in range(num_bits)]
170-
171-
a = np.copy(x)
172-
173-
for i in range(num_bits):
174-
mask = np.greater(a, powers[i])
175-
x_binary[i] = mask
176-
a -= mask * powers[i]
177-
178-
return x_binary
179-
180-
181152
class SpikeConcatenate(Concatenate):
182153
"""Spike merge layer"""
183154

snntoolbox/simulation/target_simulators/INI_temporal_mean_rate_target_sim.py

+3
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from tensorflow import keras
1111
import numpy as np
1212

13+
from snntoolbox.parsing.utils import get_inbound_layers_with_params
1314
from snntoolbox.simulation.utils import AbstractSNN, remove_name_counter
1415

1516
remove_classifier = False
@@ -84,6 +85,8 @@ def add_layer(self, layer):
8485

8586
spike_layer = spike_layer_name(**layer_kwargs)
8687
spike_layer.activation_str = activation_str
88+
spike_layer.is_first_spiking = \
89+
len(get_inbound_layers_with_params(layer)) == 0
8790
self._spiking_layers[layer.name] = spike_layer(inbound)
8891

8992
def build_dense(self, layer):

snntoolbox/simulation/target_simulators/INI_temporal_pattern_target_sim.py

+39-12
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,8 @@ def simulate(self, **kwargs):
8989
# Excludes Input, Flatten, Concatenate, etc:
9090
if hasattr(layer, 'spikerates') and layer.spikerates is not None:
9191
spikerates_b_l = layer.spikerates.numpy()
92-
spiketrains_b_l_t = self.spikerates_to_trains(spikerates_b_l)
92+
spiketrains_b_l_t = to_binary_numpy(spikerates_b_l,
93+
self.num_bits)
9394
self.set_spikerates(spikerates_b_l, i)
9495
self.set_spiketrains(spiketrains_b_l_t, i)
9596
if self.synaptic_operations_b_t is not None:
@@ -127,14 +128,40 @@ def set_neuron_operations(self, i):
127128

128129
def set_synaptic_operations(self, spiketrains_b_l_t, i):
129130
for t in range(self.synaptic_operations_b_t.shape[-1]):
130-
self.synaptic_operations_b_t[:, t] += 2 * \
131-
get_layer_synaptic_operations(
132-
spiketrains_b_l_t[Ellipsis, t], self.fanout[i + 1])
133-
134-
def spikerates_to_trains(self, spikerates_b_l):
135-
x = self.sim.to_binary_numpy(spikerates_b_l, self.num_bits)
136-
shape = [self.num_bits] + [1] * (x.ndim - 1)
137-
x *= np.resize(np.arange(self.num_bits), shape)
138-
perm = (1, 2, 3, 0) if len(x.shape) > 2 else (1, 0)
139-
spiketrains_b_l_t = np.expand_dims(np.transpose(x, perm), 0)
140-
return spiketrains_b_l_t
131+
ops = get_layer_synaptic_operations(spiketrains_b_l_t[Ellipsis, t],
132+
self.fanout[i + 1])
133+
self.synaptic_operations_b_t[:, t] += 2 * ops
134+
135+
136+
def to_binary_numpy(x, num_bits):
137+
"""Transform an array of floats into binary representation.
138+
139+
Parameters
140+
----------
141+
142+
x: ndarray
143+
Input array containing float values. The first dimension has to be of
144+
length 1.
145+
num_bits: int
146+
The fixed point precision to be used when converting to binary.
147+
148+
Returns
149+
-------
150+
151+
y: ndarray
152+
Output array with same shape as ``x`` except that an axis is added to
153+
the last dimension with size ``num_bits``. The binary representation of
154+
each value in ``x`` is distributed across the last dimension of ``y``.
155+
"""
156+
157+
n = 2 ** num_bits - 1
158+
a = np.round(x * n) / n
159+
160+
y = np.zeros(list(x.shape) + [num_bits])
161+
for i in range(num_bits):
162+
p = 2 ** -(i + 1)
163+
b = np.greater(a, p) * p
164+
y[Ellipsis, i] = b
165+
a -= b
166+
167+
return y

0 commit comments

Comments
 (0)