This notebook documents an educational project in which I built a spiking neural network (SNN) to classify digits from the MNIST dataset. The model is implemented entirely in NumPy, with PyTorch used solely for dataset loading. After 5 epochs of training, the network achieves a classification accuracy of 94% on the MNIST test set. It uses leaky integrate-and-fire (LIF) neurons and is trained via surrogate gradient descent and backpropagation through time.
Github repository: https://github.com/eoinmurray/spiking-neural-network-on-mnist
Notation used throughout (math symbol → variable in the code):

- x_i(t): input spikes at time t (encoded_images[:, t, :] in the code)
- h_j(t): hidden-layer spikes (spikes_hidden or spike_history_hidden[t] in the code)
- o_k(t): output-layer spikes (s_output in the code)
- V_h^j(t): hidden membrane potentials (V_hidden in the code)
- V_o^k(t): output membrane potentials (V_output in the code)
- W_ih: input-to-hidden weights (weights_input_to_hidden in the code)
- W_ho: hidden-to-output weights (weights_hidden_to_output in the code)
- σ'(V): surrogate gradient (grad_surrogate in the code)
- α: membrane decay constant (membrane_decay in the code)
- β: surrogate gradient steepness (surrogate_grad_steepness in the code)
- V_thresh: membrane threshold (membrane_threshold in the code)

The network operates on the MNIST dataset, consisting of 28×28 grayscale images of handwritten digits. Each image is encoded into a binary spike train using a Poisson encoder.
def poisson_encode_speed(images, num_time_steps):
    images = images.numpy()
    images = images.reshape(images.shape[0], -1)
    images = images / images.max()  # normalise to [0, 1]
    spike_probabilities = np.repeat(images[:, None, :], num_time_steps, axis=1)
    return np.random.rand(*spike_probabilities.shape) < spike_probabilities
For a given number of simulation time steps (default: 100), each pixel is treated as an independent Poisson process whose firing probability is proportional to its intensity. This generates a binary tensor of shape [batch_size, num_time_steps, num_pixels] representing spike events over time. Figure 1 below shows an example MNIST image and its corresponding spike train encoding, demonstrating how the static image is transformed into a temporal pattern of spikes.
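To make the shapes concrete, here is a minimal usage sketch. The batch of random "images" is made up for illustration; in the notebook the batches come from the PyTorch MNIST DataLoader.

import numpy as np
import torch

dummy_images = torch.rand(32, 1, 28, 28)  # stand-in for a batch of 32 MNIST images in [0, 1]
spike_trains = poisson_encode_speed(dummy_images, num_time_steps=100)

print(spike_trains.shape)  # (32, 100, 784) -> [batch_size, num_time_steps, num_pixels]
print(spike_trains.dtype)  # bool: each entry is a spike / no-spike event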
The SNN is structured as a fully connected feedforward network with the following layers: an input layer of 784 neurons (one per pixel), a hidden layer of LIF neurons, and an output layer of 10 LIF neurons (one per digit class).
Connections between layers are initialized with normal random weights with zero mean and fixed standard deviation:
weights_input_to_hidden = np.random.normal(0.0, 0.20, (num_input_neurons, num_hidden_neurons))
weights_hidden_to_output = np.random.normal(0.0, 0.20, (num_hidden_neurons, num_output_neurons))
The weight distributions before and after training are shown in Figure 5 (at the end of this document), illustrating how learning reshapes the weight distributions during training.
Each image is presented for a fixed number of simulation time steps (default: 100). The leaky integrate-and-fire (LIF) neurons integrate input spikes over time with a decay constant α and reset their membrane potential when they spike. In discrete time:

V(t+1) = α·V(t) + (1−α)·I(t)

This is implemented in the lif_step function as:
# LIF dynamics: V(t+1) = α*V(t) + (1-α)*I(t)
membrane_potential = membrane_decay * membrane_potential + (1.0 - membrane_decay) * input_current
Where:

- membrane_potential corresponds to V(t) (either V_hidden or V_output depending on the layer)
- membrane_decay corresponds to α (set to 0.95 in the code)
- input_current corresponds to I(t) (either I_hidden or I_output in the code)

The complete LIF neuron behavior includes:

- thresholding: a spike is emitted when the membrane potential exceeds the threshold (spikes = membrane_potential > membrane_threshold)
- reset: the membrane potential is set back to zero after a spike (membrane_potential[spikes] = 0.0)

The membrane threshold is set to 0.6 by default.
This behavior is implemented in the lif_step function:
def lif_step(input_current, membrane_potential, membrane_decay, surrogate_grad_steepness, membrane_threshold=0.6):
    membrane_potential = membrane_decay * membrane_potential + (1.0 - membrane_decay) * input_current
    spikes = membrane_potential > membrane_threshold

    # surrogate gradient (sigmoid)
    exp_term = np.exp(-surrogate_grad_steepness * (membrane_potential - membrane_threshold))
    grad_surrogate = surrogate_grad_steepness * exp_term / (1.0 + exp_term) ** 2

    v_pre_reset = membrane_potential.copy()
    membrane_potential[spikes] = 0.0
    return spikes.astype(float), membrane_potential, grad_surrogate, v_pre_reset
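As a quick sanity check, the following toy snippet drives a single LIF neuron with a constant suprathreshold current and prints when it fires and resets. The steepness β = 5.0 here is only an illustrative choice, not necessarily the value used in the repository.

import numpy as np

V = np.zeros((1, 1))              # one neuron, one batch element
I = np.full((1, 1), 2.0)          # constant suprathreshold input current
for step in range(10):
    spike, V, _, v_pre = lif_step(I, V, membrane_decay=0.95,
                                  surrogate_grad_steepness=5.0,
                                  membrane_threshold=0.6)
    print(step, round(float(v_pre), 3), int(spike[0, 0]))
# The membrane potential climbs toward the input current, crosses the 0.6
# threshold around the seventh step, emits a spike, and is reset to zero.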
This behavior is non-differentiable, so we use the surrogate gradient during training.
We use the derivative of a sigmoid as the surrogate gradient function:

σ(V) = 1 / (1 + exp(−β·(V − V_thresh)))

with derivative

σ'(V) = β·exp(−β·(V − V_thresh)) / (1 + exp(−β·(V − V_thresh)))²

This is implemented in the lif_step function as:
# Surrogate gradient (sigmoid derivative):
# σ'(V) = β*exp(-β(V-Vthresh))/(1+exp(-β(V-Vthresh)))²
exp_term = np.exp(-surrogate_grad_steepness * (membrane_potential - membrane_threshold))
grad_surrogate = surrogate_grad_steepness * exp_term / (1.0 + exp_term) ** 2
Where:

- surrogate_grad_steepness corresponds to β
- membrane_potential corresponds to V
- membrane_threshold corresponds to V_thresh
- grad_surrogate corresponds to σ'(V), stored as d_sigma_hidden and d_sigma_output for the respective layers

We use the cross-entropy loss on the time-integrated output firing rates:

L = −∑_k y_k·log(p_k), with p_k = exp(∑_t o_k(t)) / ∑_j exp(∑_t o_j(t))

where K is the number of classes, y_k the one-hot target, and p_k the softmax probability for class k.
This is implemented in the code by:
# During forward pass: accumulate spikes over time
spike_accumulator += s_output
# After time simulation, compute logits and softmax
logits = spike_accumulator
logits -= logits.max(1, keepdims=True) # Numerical stability
exp_logits = np.exp(logits)
probabilities = exp_logits / exp_logits.sum(1, keepdims=True)
# Create one-hot targets
one_hot_encoded = np.zeros_like(probabilities)
one_hot_encoded[np.arange(this_batch_size), labels.numpy()] = 1.0
# Compute cross-entropy loss
loss = -np.mean(np.sum(one_hot_encoded * np.log(probabilities + 1e-9), axis=1))
# Compute gradient of loss with respect to logits
grad_logits = (probabilities - one_hot_encoded) # ∂L/∂o_k(t)
Where:

- s_output corresponds to the output spikes o_k(t)
- spike_accumulator tracks the accumulated spikes ∑_t o_k(t) over time (the logits)
- probabilities corresponds to the softmax probabilities p_k
- one_hot_encoded corresponds to the target labels y_k
- grad_logits corresponds to ∂L/∂o_k = p_k − y_k

The gradient of this loss function with respect to the output spikes becomes the starting point for backpropagation.
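As a small numeric check of the loss and its gradient, the following NumPy-only snippet runs the same computation on made-up spike counts for a single sample with three classes (the numbers are purely illustrative).

import numpy as np

spike_accumulator = np.array([[12., 3., 7.]])   # accumulated output spikes, one sample, 3 classes
labels = np.array([0])                          # true class is 0

logits = spike_accumulator - spike_accumulator.max(1, keepdims=True)   # numerical stability
exp_logits = np.exp(logits)
probabilities = exp_logits / exp_logits.sum(1, keepdims=True)

one_hot_encoded = np.zeros_like(probabilities)
one_hot_encoded[np.arange(1), labels] = 1.0

loss = -np.mean(np.sum(one_hot_encoded * np.log(probabilities + 1e-9), axis=1))
grad_logits = probabilities - one_hot_encoded

print(probabilities.round(3))   # ~[[0.993, 0.000, 0.007]]
print(round(loss, 4))           # ~0.0068: small, since the correct class already dominates
print(grad_logits.round(3))     # ~[[-0.007, 0.000, 0.007]]: pushes class 0 up, class 2 down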
Using pure NumPy (no deep learning framework for the model itself), the network reaches 94% classification accuracy on MNIST in under 5 minutes on an M4 MacBook Pro CPU. Training runs for a fixed number of epochs (default: 10) using mini-batches of size 128. As shown in Figure 2, the network improves from chance performance (~10% accuracy) to 94% accuracy within 5 epochs, and the loss curve shows a corresponding steady decrease.
For each batch:

1. Encode the images into spike trains (using the poisson_encode_speed function).
2. Run the forward pass through all time steps (for t in range(num_time_steps)).
3. Run the backward pass (BPTT) through the time steps in reverse (for t in reversed(range(num_time_steps))).
4. Clip, average, and apply the gradient updates:

# Apply clipped gradient updates
np.clip(grad_weights_input_to_hidden, -5, 5, out=grad_weights_input_to_hidden)
np.clip(grad_weights_hidden_to_output, -5, 5, out=grad_weights_hidden_to_output)
grad_weights_input_to_hidden /= this_batch_size
grad_weights_hidden_to_output /= this_batch_size
weights_input_to_hidden -= learning_rate * grad_weights_input_to_hidden
weights_hidden_to_output -= learning_rate * grad_weights_hidden_to_output
Figure 3 below shows the spiking activity and membrane potentials for a test image after training, visualizing how information flows through the network and how membrane potentials evolve to generate classification outputs.
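For completeness, the classification readout itself follows directly from the rate coding: the predicted digit is the output neuron that accumulated the most spikes. The sketch below shows this on made-up spike counts; the exact evaluation loop in the repository may differ in detail.

import numpy as np

# Dummy accumulated spike counts for a batch of 4 samples and 10 output neurons
spike_accumulator = np.array([
    [2, 0, 1, 9, 0, 1, 0, 0, 1, 0],
    [0, 7, 0, 1, 0, 0, 1, 0, 0, 0],
    [1, 0, 0, 0, 0, 0, 0, 8, 1, 0],
    [0, 0, 5, 0, 1, 0, 0, 0, 0, 4],
])
true_labels = np.array([3, 1, 7, 9])

predicted_labels = spike_accumulator.argmax(axis=1)      # most active output neuron wins
accuracy = (predicted_labels == true_labels).mean()
print(predicted_labels, accuracy)                        # [3 1 7 2] 0.75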
Because the spiking network evolves over time, we use backpropagation through time (BPTT) to train it [1,2]. This involves unrolling the network across time steps and computing gradients by summing contributions from each time step. For each weight, we accumulate partial derivatives of the loss with respect to the membrane potential and spike history at each time step:

∂L/∂W = ∑_t ∂L/∂W(t)
This is implemented in the code as gradient accumulators:
# Gradient accumulators for BPTT: ∂L/∂W = ∑_t ∂L/∂W(t)
grad_weights_input_to_hidden = np.zeros_like(weights_input_to_hidden)
grad_weights_hidden_to_output = np.zeros_like(weights_hidden_to_output)
And gradients are accumulated in the backward pass:
# Accumulate gradients for the weights
grad_weights_hidden_to_output += spike_history_hidden[t].T @ dV_output
grad_weights_input_to_hidden += I_history_hidden[t].T @ dI_hidden
Where:

- grad_weights_input_to_hidden corresponds to ∂L/∂W_ih
- grad_weights_hidden_to_output corresponds to ∂L/∂W_ho
- spike_history_hidden[t] corresponds to the hidden-layer spikes h_j(t)
- I_history_hidden[t] corresponds to the input spikes x_i(t)
- dV_output corresponds to ∂L/∂V_o^k(t)
- dI_hidden corresponds to ∂L/∂I_h^j(t)

As shown in the code, we accumulate gradients by iterating through each time step in reverse order using a backward loop (for t in reversed(range(num_time_steps))). Figure 4 illustrates the L2 norms of the gradients for both hidden and output layers throughout training, providing insight into the training dynamics and stability.
A challenge in SNNs is that the network state at time t depends on the history of spikes and membrane potentials from previous time steps. Due to the leaky integrate-and-fire dynamics

V(t+1) = α·V(t) + (1−α)·I(t),

applied recursively (and interrupted whenever a reset occurs), the current membrane potential depends on all previous inputs and spikes. This creates a temporal dependency chain that must be correctly backpropagated during training.
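A tiny numeric illustration of this dependency (pure leaky integration, no resets): the same single input spike leaves a different trace in the membrane potential depending on how long ago it arrived.

import numpy as np

membrane_decay = 0.95
inputs_early = np.array([1.0, 0.0, 0.0, 0.0, 0.0])   # spike at the first step
inputs_late = np.array([0.0, 0.0, 0.0, 0.0, 1.0])    # spike at the last step

for inputs in (inputs_early, inputs_late):
    V = 0.0
    for I in inputs:
        V = membrane_decay * V + (1.0 - membrane_decay) * I
    print(round(V, 4))
# 0.0407 vs 0.05: the early spike is still visible after five steps,
# attenuated by a factor of α^4.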
During BPTT, we treat the SNN as a recurrent neural network unrolled in time, where the membrane potential acts as the state carried from one time step to the next. For each time step t, gradients flow through two paths: directly through the spike emitted at that step (via the surrogate gradient), and indirectly through the effect of the membrane potential on future time steps (via the decay α).
This is implemented in the code by:
# Store history during forward pass
spike_history_hidden.append(spikes_hidden)
grad_surrogate_hidden.append(d_sigma_hidden)
I_history_hidden.append(encoded_images[:, t, :])
grad_surrogate_output.append(d_sigma_output)
# In backward pass, recurrent dependencies are handled with dV_next terms
dV_output = dS_output * grad_surrogate_output[t] + dV_next_output * membrane_decay
dV_hidden = dS_hidden * grad_surrogate_hidden[t] + dV_next_hidden * membrane_decay
# Update recurrent gradients for next time step (going backward)
dV_next_hidden = dV_hidden * membrane_decay # ∂L/∂V_h(t-1)
dV_next_output = dV_output * membrane_decay # ∂L/∂V_o(t-1)
The recurrent dependency significantly affects how gradients flow through time during backpropagation. When a neuron spikes at time t, it influences both the downstream layer at that time step and, through the membrane dynamics, the activity at later time steps. For each backward step through time, gradients must therefore be propagated through both of these paths:

- the direct path through the spike at the current time step (the dS_output * grad_surrogate_output[t] term)
- the recurrent path through the membrane potential of the next time step (the dV_next_output * membrane_decay term)

This creates a multiplicative effect where gradients from later time steps flow through a chain of dependencies. Since ∂V(t+1)/∂V(t) = α and we apply this recursively, gradients propagated from time step t2 back to time step t1 are scaled by α^(t2−t1).
The implementation uses the membrane_decay factor (α in the equations) to propagate these temporal dependencies. This makes the trade-off explicit: gradients are attenuated over long time horizons for small α (high leakage), while values of α close to 1 (low leakage) can make them unstable.
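To get a feel for this scaling, the snippet below evaluates α^Δt for the decay used in this notebook (α = 0.95) and a few illustrative time gaps:

membrane_decay = 0.95
for delta_t in (1, 10, 50, 100):
    print(delta_t, round(membrane_decay ** delta_t, 4))
# 1 -> 0.95, 10 -> 0.5987, 50 -> 0.0769, 100 -> 0.0059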
The spike function is discontinuous (step function), making the network non-differentiable. The surrogate gradient approach replaces this non-differentiable function with a differentiable approximation (sigmoid) during the backward pass only:
This is implemented in the lif_step function, where we have:
# Spike if membrane potential exceeds threshold
spikes = membrane_potential > membrane_threshold # Non-differentiable step function
# Surrogate gradient (sigmoid derivative):
# σ'(V) = β*exp(-β(V-Vthresh))/(1+exp(-β(V-Vthresh)))²
exp_term = np.exp(-surrogate_grad_steepness * (membrane_potential - membrane_threshold))
grad_surrogate = surrogate_grad_steepness * exp_term / (1.0 + exp_term) ** 2
Where:

- spikes implements the non-differentiable step function (spike if V > V_thresh)
- grad_surrogate corresponds to σ'(V), the surrogate derivative

These surrogate gradient values are stored during the forward pass and then used during the backward pass:
# Store surrogate gradients during forward pass
grad_surrogate_hidden.append(d_sigma_hidden)
grad_surrogate_output.append(d_sigma_output)
# Use them in backward pass
dV_output = dS_output * grad_surrogate_output[t] + dV_next_output * membrane_decay
dV_hidden = dS_hidden * grad_surrogate_hidden[t] + dV_next_hidden * membrane_decay
Where:

- d_sigma_hidden and d_sigma_output store the surrogate gradients σ'(V_h^j(t)) and σ'(V_o^k(t))
- dS_output and dS_hidden correspond to ∂L/∂o_k(t) and ∂L/∂h_j(t)
- dV_output and dV_hidden correspond to ∂L/∂V_o^k(t) and ∂L/∂V_h^j(t)

This allows gradient flow while preserving the discrete spiking behavior in the forward pass.
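The following self-contained snippet contrasts the hard forward nonlinearity with its smooth surrogate derivative over a range of membrane potentials. The steepness β = 5.0 is only an illustrative value, not necessarily the one used in the repository.

import numpy as np

V = np.linspace(0.0, 1.2, 7)
beta, V_thresh = 5.0, 0.6

spikes = (V > V_thresh).astype(float)                     # forward pass: hard threshold
exp_term = np.exp(-beta * (V - V_thresh))
grad_surrogate = beta * exp_term / (1.0 + exp_term) ** 2  # backward pass: smooth bump

for v, s, g in zip(V, spikes, grad_surrogate):
    print(f"V={v:.1f}  spike={s:.0f}  surrogate_grad={g:.3f}")
# The spike output jumps from 0 to 1 at V_thresh = 0.6, while the surrogate
# derivative is a smooth bump peaking at V = V_thresh with value β/4 = 1.25.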
For a network with N neurons simulated for T time steps, we need to store the hidden-layer spikes h_j(t), the input spikes x_i(t), and the surrogate gradients σ'(V_h^j(t)) and σ'(V_o^k(t)) for every time step.
The implementation handles these requirements by:
# Initialize history arrays for BPTT
spike_history_hidden = []      # Store h_j(t) for all t
grad_surrogate_hidden = []     # Store σ'(V_h^j(t)) for all t
I_history_hidden = []          # Store x_i(t) for all t
grad_surrogate_output = []     # Store σ'(V_o^k(t)) for all t

# FORWARD PASS: Propagate activity through time steps 1 to T
for t in range(num_time_steps):
    # Input → Hidden: I_h = x(t) · W_ih
    I_hidden = encoded_images[:, t, :] @ weights_input_to_hidden

    # Process hidden layer with LIF neurons
    # h(t), V_h(t+1) = LIF(I_h(t), V_h(t))
    spikes_hidden, V_hidden, d_sigma_hidden, _ = lif_step(
        I_hidden, V_hidden, membrane_decay,
        surrogate_grad_steepness, membrane_threshold)

    # Hidden → Output: I_o = h(t) · W_ho
    I_output = spikes_hidden @ weights_hidden_to_output

    # Process output layer with LIF neurons
    # o(t), V_o(t+1) = LIF(I_o(t), V_o(t))
    s_output, V_output, d_sigma_output, _ = lif_step(
        I_output, V_output, membrane_decay,
        surrogate_grad_steepness, membrane_threshold)

    # Accumulate output spikes across time (rate coding for classification)
    spike_accumulator += s_output

    # Store history for BPTT
    spike_history_hidden.append(spikes_hidden)          # h_j(t)
    grad_surrogate_hidden.append(d_sigma_hidden)        # σ'(V_h^j(t))
    I_history_hidden.append(encoded_images[:, t, :])    # x_i(t)
    grad_surrogate_output.append(d_sigma_output)        # σ'(V_o^k(t))

# Initialize recurrent gradient terms for time dependency
dV_next_hidden = np.zeros_like(V_hidden)    # ∂L/∂V_h(t+1)
dV_next_output = np.zeros_like(V_output)    # ∂L/∂V_o(t+1)

# Iterate backward through time (BPTT)
for t in reversed(range(num_time_steps)):
    # Output layer backpropagation
    # 1. ∂L/∂o_k(t) = p_k - y_k (from grad_logits)
    dS_output = grad_logits  # ∂L/∂o_k(t)

    # 2. ∂L/∂V_o^k(t) = ∂L/∂o_k(t) * σ'(V_o^k(t)) + ∂L/∂V_o^k(t+1) * α
    #    First term: direct influence on current spike
    #    Second term: influence on future membrane potentials through decay
    dV_output = dS_output * grad_surrogate_output[t] + dV_next_output * membrane_decay

    # 3. ∂L/∂W_ho^kj = ∑_t ∂L/∂V_o^k(t) * h_j(t)
    grad_weights_hidden_to_output += spike_history_hidden[t].T @ dV_output

    # Hidden layer backpropagation
    # 4. ∂L/∂h_j(t) = ∑_k ∂L/∂V_o^k(t) * W_ho^kj
    dS_hidden = dV_output @ weights_hidden_to_output.T

    # 5. ∂L/∂V_h^j(t) = ∂L/∂h_j(t) * σ'(V_h^j(t)) + ∂L/∂V_h^j(t+1) * α
    dV_hidden = dS_hidden * grad_surrogate_hidden[t] + dV_next_hidden * membrane_decay

    # 6. ∂L/∂I_h^j(t) = ∂L/∂V_h^j(t) * (1-α)
    dI_hidden = dV_hidden * (1.0 - membrane_decay)

    # 7. ∂L/∂W_ih^ji = ∑_t ∂L/∂I_h^j(t) * x_i(t)
    grad_weights_input_to_hidden += I_history_hidden[t].T @ dI_hidden

    # Store recurrent gradients for next time step
    # These capture how current membrane potentials affect future time steps
    dV_next_hidden = dV_hidden * membrane_decay    # ∂L/∂V_h(t-1)
    dV_next_output = dV_output * membrane_decay    # ∂L/∂V_o(t-1)
This implementation directly matches the mathematical derivation in the appendix, with each step clearly documented in the code comments.
This simulation was a fun project in training a spiking neural network on an image classification task using biologically inspired encoding, LIF dynamics, and surrogate gradient descent. NumPy was chosen over higher-level frameworks not for performance, but to expose the underlying mechanics of SNNs. The model successfully achieved over 90% accuracy on the MNIST dataset, and future implementations can generalize the model to other datasets such as FashionMNIST, which is already supported as an option in the code.
We want ∂L/∂W_ho^kj, the gradient of the loss with respect to the hidden-to-output weights.

Error at the spike output [3]:

∂L/∂o_k(t) = p_k − y_k
This is implemented as:
# Gradient of loss w.r.t. logits: ∂L/∂o_k(t) = p_k - y_k
grad_logits = (probabilities - one_hot_encoded)
dS_output = grad_logits # ∂L/∂o_k(t)
Surrogate spike derivative:

∂o_k(t)/∂V_o^k(t) ≈ σ'(V_o^k(t))
Implemented as:
# Surrogate gradient (sigmoid derivative)
exp_term = np.exp(-surrogate_grad_steepness * (membrane_potential - membrane_threshold))
grad_surrogate = surrogate_grad_steepness * exp_term / (1.0 + exp_term) ** 2
Voltage w.r.t. weight:

∂V_o^k(t)/∂W_ho^kj ∝ h_j(t)

since the output membrane potential is driven by the input current I_o^k(t) = ∑_j W_ho^kj · h_j(t). (In the code the constant (1−α) prefactor from the membrane update is not applied to the output-weight gradient; it is effectively absorbed into the learning rate.)
Putting it together:

∂L/∂W_ho^kj = ∑_t (p_k − y_k) · σ'(V_o^k(t)) · h_j(t)
This is implemented in the code as:
# 1. ∂L/∂o_k(t) = p_k - y_k (from grad_logits)
dS_output = grad_logits # ∂L/∂o_k(t)
# 2. ∂L/∂V_o^k(t) = ∂L/∂o_k(t) * σ'(V_o^k(t)) + ∂L/∂V_o^k(t+1) * α
dV_output = dS_output * grad_surrogate_output[t] + dV_next_output * membrane_decay
# 3. ∂L/∂W_ho^kj = ∑_t ∂L/∂V_o^k(t) * h_j(t)
grad_weights_hidden_to_output += spike_history_hidden[t].T @ dV_output
Where:

- grad_logits corresponds to p_k − y_k
- dS_output corresponds to ∂L/∂o_k(t)
- grad_surrogate_output[t] corresponds to σ'(V_o^k(t))
- spike_history_hidden[t] corresponds to h_j(t)
- grad_weights_hidden_to_output corresponds to ∂L/∂W_ho^kj
- dV_next_output handles the recurrent dependencies through time

We want ∂L/∂W_ih^ji, the gradient of the loss with respect to the input-to-hidden weights.
Backpropagated error into the hidden spike:

∂L/∂h_j(t) = ∑_k ∂L/∂V_o^k(t) · W_ho^kj
Implemented as:
# 4. ∂L/∂h_j(t) = ∑_k ∂L/∂V_o^k(t) * W_ho^kj
dS_hidden = dV_output @ weights_hidden_to_output.T
Surrogate at the hidden layer:

∂h_j(t)/∂V_h^j(t) ≈ σ'(V_h^j(t))
Implemented with the same surrogate gradient function as the output layer:
# 5. ∂L/∂V_h^j(t) = ∂L/∂h_j(t) * σ'(V_h^j(t)) + ∂L/∂V_h^j(t+1) * α
dV_hidden = dS_hidden * grad_surrogate_hidden[t] + dV_next_hidden * membrane_decay
Voltage w.r.t. weight and input current:

∂V_h^j(t)/∂I_h^j(t) = 1 − α and ∂I_h^j(t)/∂W_ih^ji = x_i(t)
Implemented as:
# 6. ∂L/∂I_h^j(t) = ∂L/∂V_h^j(t) * (1-α)
dI_hidden = dV_hidden * (1.0 - membrane_decay)
# 7. ∂L/∂W_ih^ji = ∑_t ∂L/∂I_h^j(t) * x_i(t)
grad_weights_input_to_hidden += I_history_hidden[t].T @ dI_hidden
Together:

∂L/∂W_ih^ji = ∑_t ∂L/∂h_j(t) · σ'(V_h^j(t)) · (1 − α) · x_i(t)
The code implements this equation with the addition of temporal dependencies:
# Hidden layer backpropagation
# 4. ∂L/∂h_j(t) = ∑_k ∂L/∂V_o^k(t) * W_ho^kj
dS_hidden = dV_output @ weights_hidden_to_output.T
# 5. ∂L/∂V_h^j(t) = ∂L/∂h_j(t) * σ'(V_h^j(t)) + ∂L/∂V_h^j(t+1) * α
dV_hidden = dS_hidden * grad_surrogate_hidden[t] + dV_next_hidden * membrane_decay
# 6. ∂L/∂I_h^j(t) = ∂L/∂V_h^j(t) * (1-α)
dI_hidden = dV_hidden * (1.0 - membrane_decay)
# 7. ∂L/∂W_ih^ji = ∑_t ∂L/∂I_h^j(t) * x_i(t)
grad_weights_input_to_hidden += I_history_hidden[t].T @ dI_hidden
Where:

- dS_hidden corresponds to ∂L/∂h_j(t)
- dV_hidden corresponds to ∂L/∂V_h^j(t)
- dI_hidden corresponds to ∂L/∂I_h^j(t)
- I_history_hidden[t] contains the input spikes x_i(t) for each time step
- grad_surrogate_hidden[t] corresponds to σ'(V_h^j(t))
- weights_hidden_to_output corresponds to W_ho
- grad_weights_input_to_hidden corresponds to ∂L/∂W_ih^ji
- dV_next_hidden and dV_next_output handle the recurrent dependencies through time

The effect of these weight updates is visualized in Figure 5, showing how the weight distributions change from their initial random state to more structured patterns after training.