This notebook describes an educational project of mine to build a spiking neural network (SNN) simulation that performs digit classification on the MNIST dataset. Using only NumPy for the implementation (no deep learning frameworks), the model achieves over 90% classification accuracy on the MNIST test set. The network uses leaky integrate-and-fire (LIF) neurons and is trained with surrogate gradient descent. The implementation includes time-resolved Poisson encoding of images, a two-layer feedforward architecture, and logging of membrane potentials and spike activity throughout simulation and training.
Github repository: https://github.com/eoinmurray/spiking-neural-network-on-mnist
The network operates on the MNIST dataset, consisting of 28×28 grayscale images of handwritten digits. Each image is encoded into a binary spike train using a Poisson encoder.
def poisson_encode_speed(images, num_time_steps):
    images = images.numpy()  # Convert to numpy array from pytorch
    images = images.reshape(images.shape[0], -1)  # Flatten the images
    spike_probabilities = np.repeat(images[:, None, :], num_time_steps, axis=1)  # Repeat for time steps
    return np.random.rand(*spike_probabilities.shape) < spike_probabilities  # Generate spikes
For a given number of simulation time steps (default: 100), each pixel is treated as an independent Poisson process whose firing probability is proportional to its intensity. This generates a binary tensor of shape [batch_size, num_time_steps, num_pixels] representing spike events over time. Figure 1 below shows an example MNIST image and its corresponding spike train encoding, demonstrating how the static image is transformed into a temporal pattern of spikes.
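The shapes involved are easy to check with a quick usage sketch of my own, substituting a random tensor for a real MNIST batch:

# Usage sketch: encode a stand-in batch and inspect the resulting spike tensor.
import numpy as np
import torch

images = torch.rand(4, 28, 28)  # stand-in for a batch of MNIST images in [0, 1]
encoded_spikes = poisson_encode_speed(images, num_time_steps=100)

print(encoded_spikes.shape)  # (4, 100, 784) -> [batch_size, num_time_steps, num_pixels]
print(encoded_spikes.mean())  # overall firing rate tracks the mean pixel intensity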
The SNN is structured as a fully connected feedforward network with the following layers:

- an input layer of 784 neurons (one per pixel of the 28×28 image), driven by the Poisson-encoded spike trains,
- a hidden layer of LIF neurons, and
- an output layer of 10 LIF neurons, one per digit class.
Connections between layers are initialized with uniform random weights drawn from a variance-scaled distribution to promote stable forward propagation:
weights_input_to_hidden = np.random.uniform(
    -np.sqrt(1 / num_input_neurons), np.sqrt(1 / num_input_neurons),
    size=(num_input_neurons, num_hidden_neurons)
)
weights_hidden_to_output = np.random.uniform(
    -np.sqrt(1 / num_hidden_neurons), np.sqrt(1 / num_hidden_neurons),
    size=(num_hidden_neurons, num_output_neurons)
)
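The scale of this initialization can be checked numerically: a uniform distribution on $[-a, a]$ has standard deviation $a/\sqrt{3}$, so with $a = \sqrt{1/n}$ the weight variance shrinks with the fan-in $n$. A quick sketch (the hidden size of 100 is illustrative, not taken from the repo):

# Sanity check of the initialization scale (hidden size 100 is illustrative).
import numpy as np

num_input_neurons, num_hidden_neurons = 784, 100
a = np.sqrt(1 / num_input_neurons)
w = np.random.uniform(-a, a, size=(num_input_neurons, num_hidden_neurons))
print(w.std(), a / np.sqrt(3))  # empirical std matches a / sqrt(3)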
The weight distributions before and after training are shown in Figure 5 (at the end of this document), illustrating how learning reshapes the weight distributions during training.
Each image is presented for $T$ time steps (default: 100), each lasting $\Delta t$ ms. The LIF neurons integrate input spikes over time with a decay constant $\beta$ and reset on spiking:

- In discrete time, the membrane potential updates as $V[t] = \beta\,V[t-1] + I[t]$ (membrane_potential = membrane_decay * membrane_potential + input_current in the code below).
- A neuron spikes when its potential exceeds the threshold, $V[t] > V_{\text{th}}$ (spikes = membrane_potential > membrane_threshold).
- The potential of a spiking neuron is reset to zero (membrane_potential[spikes] = 0.0).
- The leak is set by the decay constant $\beta$ (the membrane_decay term in the update equation).

This behavior is implemented in the lif_step function:
def lif_step(input_current, membrane_potential, membrane_decay, surrogate_grad_steepness, membrane_threshold=1.5):
    # Update membrane potential with decay and input current
    membrane_potential = membrane_decay * membrane_potential + input_current

    # Determine which neurons spike (V > threshold)
    spikes = membrane_potential > membrane_threshold

    # Calculate surrogate gradient for backward pass
    exp_term = np.exp(-surrogate_grad_steepness * membrane_potential - 1.0)
    grad_surrogate = surrogate_grad_steepness * exp_term / (1.0 + exp_term) ** 2

    # Reset membrane potential for neurons that spiked
    membrane_potential[spikes] = 0.0

    return spikes.astype(float), membrane_potential, grad_surrogate
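To build intuition for these dynamics, the sketch below drives a single neuron with a constant input current and records when it spikes. The current, decay, and number of steps are illustrative values of my own, not the training hyperparameters.

# Minimal sketch: one LIF neuron driven by a constant input current.
membrane_potential = np.zeros(1)
spike_times = []
for t in range(50):
    spikes, membrane_potential, _ = lif_step(
        input_current=np.array([0.4]),
        membrane_potential=membrane_potential,
        membrane_decay=0.9,
        surrogate_grad_steepness=5.0,
        membrane_threshold=1.5,
    )
    if spikes.item() > 0:
        spike_times.append(t)
print(spike_times)  # the neuron settles into a regular integrate-spike-reset cycle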
This thresholding behavior is non-differentiable, so we use a surrogate gradient during training. We use a fast sigmoid as the surrogate, approximating the hard threshold $\Theta(V - V_{\text{th}})$ with

$$S(V) \approx \sigma(kV + 1) = \frac{1}{1 + e^{-(kV + 1)}}$$

with derivative

$$\frac{\partial S}{\partial V} \approx \frac{k\,e^{-(kV + 1)}}{\bigl(1 + e^{-(kV + 1)}\bigr)^{2}},$$

where $k$ is the steepness parameter (set to 5.0 in the implementation). You can see the implementation of this surrogate gradient in the lif_step function, in the lines that calculate exp_term and grad_surrogate.
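To see what this buys us, the snippet below evaluates the same expression used in lif_step across a range of membrane potentials; instead of the step function's derivative, which is zero almost everywhere, we get a smooth bump that lets error signals pass through.

# Evaluate the surrogate gradient from lif_step over a range of membrane potentials.
surrogate_grad_steepness = 5.0
membrane_potentials = np.linspace(-1.0, 3.0, 9)
exp_term = np.exp(-surrogate_grad_steepness * membrane_potentials - 1.0)
grad_surrogate = surrogate_grad_steepness * exp_term / (1.0 + exp_term) ** 2
print(np.round(grad_surrogate, 4))  # smooth, bell-shaped pseudo-derivative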
We use the cross-entropy loss on the time-integrated output firing rates:

$$\mathcal{L} = -\sum_{i=1}^{C} y_i \log p_i$$

where $C$ is the number of classes, $y_i$ the one-hot target, and $p_i$ the softmax probability for class $i$. The implementation computes softmax probabilities from time-integrated spike counts and compares them with one-hot encoded targets:
# Accumulate spikes over all time steps
output_spike_accumulator += spikes_output
# After time simulation, compute softmax probabilities
softmax_numerators = np.exp(output_spike_accumulator - np.max(output_spike_accumulator, axis=1, keepdims=True))
probabilities = softmax_numerators / softmax_numerators.sum(axis=1, keepdims=True)
# Create one-hot targets
one_hot_targets = np.zeros_like(probabilities)
one_hot_targets[np.arange(len(label_batch)), label_batch.numpy()] = 1
# Compute cross-entropy loss
loss = -np.mean(np.sum(one_hot_targets * np.log(probabilities + 1e-9), axis=1))
# Compute gradient of loss with respect to logits
grad_logits = (probabilities - one_hot_targets) / batch_size
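As a small worked example of this block (the spike counts and labels are made up, not network output), with two samples and three classes:

# Tiny worked example of the loss computation above: 2 samples, 3 classes.
output_spike_accumulator = np.array([[5.0, 1.0, 0.0],
                                     [2.0, 2.0, 6.0]])
labels = np.array([0, 2])
batch_size = len(labels)

softmax_numerators = np.exp(output_spike_accumulator - np.max(output_spike_accumulator, axis=1, keepdims=True))
probabilities = softmax_numerators / softmax_numerators.sum(axis=1, keepdims=True)

one_hot_targets = np.zeros_like(probabilities)
one_hot_targets[np.arange(batch_size), labels] = 1

loss = -np.mean(np.sum(one_hot_targets * np.log(probabilities + 1e-9), axis=1))
grad_logits = (probabilities - one_hot_targets) / batch_size
print(round(loss, 4), grad_logits.shape)  # scalar loss and a per-class gradient for each sample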
Using pure NumPy (without deep learning frameworks), the model reaches >90% classification accuracy on MNIST in less than 5 minutes on an M4 MacBook Pro CPU. Training is carried out over a fixed number of epochs (default: 5) using mini-batches of size 128. As shown in Figure 2, the network improves from random performance (~10% accuracy) to over 90% accuracy within 10 epochs. The loss curve correspondingly shows a steady decrease.
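Accuracy here means the fraction of images whose most active output neuron matches the label; a minimal sketch of that check, reusing names from the snippets in this post:

# Sketch of the accuracy metric: predict the class whose output neuron spiked most.
# Assumes output_spike_accumulator and label_batch as in the snippets above.
predictions = np.argmax(output_spike_accumulator, axis=1)
batch_accuracy = np.mean(predictions == label_batch.numpy())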
For each batch:

1. Encode the images into spike trains (using the poisson_encode_speed function).
2. Run the forward pass over all simulation time steps (for t in range(num_time_steps)).
3. Run backpropagation through time in reverse temporal order (for t in reversed(range(num_time_steps))).
4. Apply the weight updates:

# Apply weight updates
weights_input_to_hidden -= learning_rate * grad_weights_input
weights_hidden_to_output -= learning_rate * grad_weights_output
Figure 3 below shows the spiking activity and membrane potentials for a test image after training, visualizing how information flows through the network and how membrane potentials evolve to generate classification outputs.
Because the spiking network evolves over time, we use backpropagation through time (BPTT) to train it. This involves unrolling the network across time steps and computing gradients by summing contributions from each time step. For each weight, we accumulate partial derivatives of the loss with respect to the membrane potential and spike history at each time step:

$$\frac{\partial \mathcal{L}}{\partial W} = \sum_{t=1}^{T} \frac{\partial \mathcal{L}}{\partial S[t]}\,\frac{\partial S[t]}{\partial V[t]}\,\frac{\partial V[t]}{\partial W}$$
As shown in the code implementation below, we accumulate gradients by iterating through each time step in reverse order using a backward loop. Figure 4 illustrates the L2 norms of the gradients for both hidden and output layers throughout training, providing insight into the training dynamics and stability.
A challenge in SNNs is that the network state at time $t$ depends on the history of spikes and membrane potentials from previous time steps. Due to the leaky integrate-and-fire dynamics with reset,

$$V[t] = \beta\,V[t-1]\,\bigl(1 - S[t-1]\bigr) + I[t],$$

the current membrane potential depends on all previous inputs and spikes. This creates a complex temporal dependency chain that must be backpropagated correctly during training.
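To make the dependency concrete, the toy calculation below unrolls the recurrence for a neuron that never spikes: the potential at time $t$ is a geometrically weighted sum of every past input, $V[t] = \sum_{s \le t} \beta^{\,t-s} I[s]$ (the numbers are illustrative):

# Toy illustration: absent spikes, V[t] is a decaying sum of all past inputs.
import numpy as np

beta = 0.9
input_currents = np.array([0.1, 0.0, 0.3, 0.2, 0.0])

V = 0.0
for current in input_currents:
    V = beta * V + current

closed_form = sum(beta ** (len(input_currents) - 1 - s) * input_currents[s]
                  for s in range(len(input_currents)))
print(np.isclose(V, closed_form))  # True: the state carries the whole input history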
During BPTT, we treat the SNN as a recurrent neural network unrolled in time, where the membrane potential acts as a hidden state carried from one time step to the next and the spikes are the outputs emitted at each step.

For each time step $t$, gradients flow through two paths: the layer-to-layer path at that time step (from output spikes back to hidden spikes via the weights), and the temporal path through the membrane potential carried forward to later time steps.
This recurrent dependency is captured in the implementation by storing the complete history of spikes and membrane states
during the forward pass (using the spike_history_hidden
and input_current_history_hidden
arrays) and then processing them
in reverse temporal order during backpropagation (using the for t in reversed(range(num_time_steps))
loop).
The recurrent dependency significantly affects how gradients flow through time during backpropagation. When a neuron spikes at time $t$, it influences the input current of the next layer at that same time step and, through the reset of its own membrane potential, the dynamics at later time steps. For each backward step through time, gradients must therefore be propagated through both of these paths.
The spike function is discontinuous (a step function), making the network non-differentiable. The surrogate gradient approach replaces this non-differentiable function with a differentiable approximation (the fast sigmoid defined earlier) during the backward pass only. This allows gradients to flow while preserving the discrete spiking behavior in the forward pass. As shown in the code implementation, we recalculate the surrogate gradients during the backward pass (using the lif_step function to obtain the grad_output and grad_hidden values).
For a network with $N$ neurons simulated for $T$ time steps, we need to store the hidden-layer spikes, the hidden-layer input currents, and the membrane potentials of both layers at every time step. The implementation handles these requirements by storing these histories during the forward pass and replaying them in reverse during the backward pass:
# During forward pass - storing spike activity and input currents
spike_history_hidden = []
input_current_history_hidden = []
membrane_potential_history_output = []
membrane_potential_history_hidden = []

for t in range(num_time_steps):
    current_input_hidden = encoded_spikes[:, t, :] @ weights_input_to_hidden
    spikes_hidden, membrane_potential_hidden, grad_hidden = lif_step(current_input_hidden, membrane_potential_hidden, membrane_decay, surrogate_grad_steepness, membrane_threshold)
    spike_history_hidden.append(spikes_hidden)
    input_current_history_hidden.append(current_input_hidden)

    current_input_output = spikes_hidden @ weights_hidden_to_output
    spikes_output, membrane_potential_output, grad_output = lif_step(current_input_output, membrane_potential_output, membrane_decay, surrogate_grad_steepness, membrane_threshold)
    output_spike_accumulator += spikes_output

    membrane_potential_history_output.append(membrane_potential_output.copy())
    membrane_potential_history_hidden.append(membrane_potential_hidden.copy())
# During backward pass - iterating backward through time steps
for t in reversed(range(num_time_steps)):
    spikes_hidden = spike_history_hidden[t]
    input_current_hidden = input_current_history_hidden[t]
    current_input_output = spikes_hidden @ weights_hidden_to_output

    # Calculate output gradients
    _, _, grad_output = lif_step(current_input_output, membrane_potential_history_output[t], membrane_decay, surrogate_grad_steepness, membrane_threshold)
    grad_output_current = grad_logits * grad_output
    grad_weights_output += spikes_hidden.T @ grad_output_current

    # Backpropagate to hidden layer
    grad_spikes_hidden = grad_output_current @ weights_hidden_to_output.T

    # Calculate hidden gradients
    _, _, grad_hidden = lif_step(input_current_hidden, membrane_potential_history_hidden[t], membrane_decay, surrogate_grad_steepness, membrane_threshold)
    grad_input_hidden = grad_spikes_hidden * grad_hidden
    grad_weights_input += encoded_spikes[:, t, :].T @ grad_input_hidden
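The gradient norms plotted in Figure 4 can be logged right after this loop by taking the L2 norm of each accumulated gradient; a minimal sketch (grad_norm_history is a name of my own, assumed to be initialized before the training loop):

# Log the L2 norm of each accumulated gradient for a Figure 4-style plot.
# grad_norm_history is a hypothetical list created before the training loop.
grad_norm_history.append((np.linalg.norm(grad_weights_input),
                          np.linalg.norm(grad_weights_output)))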
We want the gradient of the loss with respect to the hidden-to-output weights, $\frac{\partial \mathcal{L}}{\partial W_{\text{ho}}}$:

- Error at the spike output: $\frac{\partial \mathcal{L}}{\partial S_{\text{out}}} = p - y$ (the softmax cross-entropy derivative, computed as grad_logits).
- Surrogate spike derivative: $\frac{\partial S_{\text{out}}}{\partial V_{\text{out}}} \approx \sigma'$ (computed as grad_output).
- Voltage w.r.t. weight: $\frac{\partial V_{\text{out}}[t]}{\partial W_{\text{ho}}} = S_{\text{hid}}[t]$ (the hidden spikes at time $t$).

Putting it together:

$$\frac{\partial \mathcal{L}}{\partial W_{\text{ho}}} = \sum_{t=1}^{T} S_{\text{hid}}[t]^{\top}\,\Bigl[(p - y) \odot \frac{\partial S_{\text{out}}[t]}{\partial V_{\text{out}}[t]}\Bigr]$$
We want the gradient with respect to the input-to-hidden weights, $\frac{\partial \mathcal{L}}{\partial W_{\text{ih}}}$:

- Backpropagated error into the hidden spikes: $\frac{\partial \mathcal{L}}{\partial S_{\text{hid}}} = \Bigl[(p - y) \odot \frac{\partial S_{\text{out}}}{\partial V_{\text{out}}}\Bigr] W_{\text{ho}}^{\top}$ (computed as grad_spikes_hidden).
- Surrogate at the hidden layer: $\frac{\partial S_{\text{hid}}}{\partial V_{\text{hid}}} \approx \sigma'$ (computed as grad_hidden).
- Voltage w.r.t. weight: $\frac{\partial V_{\text{hid}}[t]}{\partial W_{\text{ih}}} = X[t]$ (the Poisson-encoded input spikes at time $t$).

Together:

$$\frac{\partial \mathcal{L}}{\partial W_{\text{ih}}} = \sum_{t=1}^{T} X[t]^{\top}\,\Bigl[\frac{\partial \mathcal{L}}{\partial S_{\text{hid}}[t]} \odot \frac{\partial S_{\text{hid}}[t]}{\partial V_{\text{hid}}[t]}\Bigr]$$
These equations are implemented in the backward pass code, where we:

- multiply the softmax error by the output surrogate gradient (grad_output_current = grad_logits * grad_output),
- backpropagate it to the hidden spikes through the output weights (grad_spikes_hidden = grad_output_current @ weights_hidden_to_output.T), and
- accumulate the weight gradients over time (grad_weights_output += ... and grad_weights_input += ...).

The effect of these weight updates is visualized in Figure 5, showing how the weight distributions change from their initial random state to more structured patterns after training.
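A Figure 5-style comparison can be produced by histogramming a saved copy of the initial weights next to the trained ones; a minimal matplotlib sketch, where initial_weights_input_to_hidden is a hypothetical copy made before training:

# Sketch of a Figure 5-style comparison of weight distributions.
# initial_weights_input_to_hidden is a hypothetical copy saved before training.
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(8, 3), sharey=True)
axes[0].hist(initial_weights_input_to_hidden.ravel(), bins=50)
axes[0].set_title("Input-to-hidden weights (before training)")
axes[1].hist(weights_input_to_hidden.ravel(), bins=50)
axes[1].set_title("Input-to-hidden weights (after training)")
plt.tight_layout()
plt.show()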
This simulation was a fun project for training a spiking neural network on an image classification task using biologically inspired encoding, LIF dynamics, and surrogate gradient descent. NumPy was chosen over higher-level frameworks not for performance, but to understand the underlying mechanics of SNNs. The model achieved over 90% accuracy on the MNIST dataset, and future implementations will try to generalize the model, starting with FashionMNIST.