Spiking Neural Network classifying MNIST data

This notebook describes an educational project of mine to build a spiking neural network (SNN) simulation that performs digit classification on the MNIST dataset. Using only NumPy for the implementation (no deep learning frameworks), the model achieves over 90% classification accuracy on the MNIST test set. The network uses leaky integrate-and-fire (LIF) neurons and is trained with surrogate gradient descent. The implementation includes time-resolved Poisson encoding of images, a two-layer feedforward architecture, and logging of membrane potentials and spike activity throughout simulation and training.

Github repository: https://github.com/eoinmurray/spiking-neural-network-on-mnist


Dataset and encoding

The network operates on the MNIST dataset, consisting of 28×28 grayscale images of handwritten digits. Each image is encoded into a binary spike train using a Poisson encoder.

def poisson_encode_speed(images, num_time_steps):
    images = images.numpy() # Convert to numpy array from pytorch
    images = images.reshape(images.shape[0], -1) # Flatten the images
    spike_probabilities = np.repeat(images[:, None, :], num_time_steps, axis=1) # Repeat for time steps
    return np.random.rand(*spike_probabilities.shape) < spike_probabilities # Generate spikes

For a given number of simulation time steps (default: 100), each pixel is treated as an independent Poisson process whose firing probability is proportional to its intensity. This generates a binary tensor of shape [batch_size, num_time_steps, num_pixels] representing spike events over time. Figure 1 below shows an example MNIST image and its corresponding spike train encoding, demonstrating how the static image is transformed into a temporal pattern of spikes.
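As a quick sanity check of the encoder's output shape, here is a minimal usage sketch (the torch.rand batch is purely illustrative and not part of the repository; pixel intensities are assumed to lie in [0, 1]):

import torch

# Hypothetical batch of 32 images with intensities in [0, 1)
images = torch.rand(32, 28, 28)

encoded_spikes = poisson_encode_speed(images, num_time_steps=100)
print(encoded_spikes.shape)  # (32, 100, 784): [batch_size, num_time_steps, num_pixels]
print(encoded_spikes.dtype)  # bool: True marks a spike event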

Figure 1: An MNIST image (left) and its corresponding spike train (right). The spike train shows neuron activity across 100 time steps (horizontal axis) for each of the 784 input neurons (vertical axis), with brighter points indicating spike events.

Network architecture

The SNN is structured as a fully connected feedforward network with the following layers:

  1. An input layer of 784 neurons, one per pixel of the flattened 28×28 image
  2. A hidden layer of LIF neurons (num_hidden_neurons in the code)
  3. An output layer of 10 LIF neurons, one per digit class

Connections between layers are initialized with uniform random weights drawn from a variance-scaled distribution to promote stable forward propagation:

weights_input_to_hidden = np.random.uniform(
    -np.sqrt(1 / num_input_neurons), np.sqrt(1 / num_input_neurons), 
    size=(num_input_neurons, num_hidden_neurons)
)
weights_hidden_to_output = np.random.uniform(
    -np.sqrt(1 / num_hidden_neurons), np.sqrt(1 / num_hidden_neurons), 
    size=(num_hidden_neurons, num_output_neurons)
)

The weight distributions before and after training are shown in Figure 5 (at the end of this document), illustrating how learning reshapes them.

LIF dynamics

Each image is presented for $T = 100$ time steps with $\Delta t = 1$ ms per step. The LIF neurons integrate input spikes over time with a decay constant $\tau$ and reset on spiking:

$$\tau \frac{dV(t)}{dt} = -V(t) + I(t).$$

In discrete time:

$$V(t + 1) = \alpha V(t) + I(t), \quad \text{with } \alpha = e^{-1/\tau}.$$
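As a brief aside on where $\alpha$ comes from: integrating the membrane equation exactly over one step of length $\Delta t$, with $I$ held constant during the step, gives

$$V(t + \Delta t) = e^{-\Delta t/\tau}\, V(t) + \bigl(1 - e^{-\Delta t/\tau}\bigr)\, I(t),$$

so with $\Delta t = 1$ the decay factor is $\alpha = e^{-1/\tau}$. The update above adds the input current without the $(1-\alpha)$ scaling, a common simplification whose effect is absorbed into the learned weights.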

This behavior is implemented in the lif_step function:

def lif_step(input_current, membrane_potential, membrane_decay, surrogate_grad_steepness, membrane_threshold = 1.5):
    # Update membrane potential with decay and input current
    membrane_potential = membrane_decay * membrane_potential + input_current
    
    # Determine which neurons spike (V > threshold)
    spikes = membrane_potential > membrane_threshold
    
    # Calculate surrogate gradient for backward pass
    exp_term = np.exp(-surrogate_grad_steepness * membrane_potential - 1.0)
    grad_surrogate = surrogate_grad_steepness * exp_term / (1.0 + exp_term) ** 2
    
    # Reset membrane potential for neurons that spiked
    membrane_potential[spikes] = 0.0
    
    return spikes.astype(float), membrane_potential, grad_surrogate
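A minimal usage sketch for a single simulation step (the batch size, layer size, and decay constant below are illustrative, not the training configuration):

import numpy as np

batch_size, num_hidden_neurons = 4, 100
membrane_potential = np.zeros((batch_size, num_hidden_neurons))
input_current = np.random.rand(batch_size, num_hidden_neurons)  # would normally be spikes @ weights

spikes, membrane_potential, grad_surrogate = lif_step(
    input_current,
    membrane_potential,
    membrane_decay=np.exp(-1 / 20.0),   # alpha = e^{-1/tau} with an illustrative tau = 20
    surrogate_grad_steepness=5.0,
)
print(spikes.shape, membrane_potential.shape, grad_surrogate.shape)  # all (4, 100)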

The thresholding step (spike generation) is non-differentiable, so we use the surrogate gradient during training.

Surrogate gradient

We use a shifted sigmoid as the surrogate:

$$\sigma(V) = \frac{1}{1+e^{-\beta V - 1}}$$

with derivative

$$\sigma'(V) = \frac{\beta\, e^{-\beta V - 1}}{\bigl(1 + e^{-\beta V - 1}\bigr)^2}$$

where $\beta$ is the steepness parameter (set to 5.0 in the implementation). This surrogate gradient is implemented in the lif_step function via the exp_term and grad_surrogate calculations.
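To make that connection explicit, here is a small standalone sketch (not from the repository) that evaluates the surrogate and checks its derivative against the grad_surrogate expression used in lif_step:

import numpy as np

beta = 5.0  # surrogate_grad_steepness

def surrogate(v):
    # sigma(V) = 1 / (1 + exp(-beta*V - 1)), as defined above
    return 1.0 / (1.0 + np.exp(-beta * v - 1.0))

def surrogate_grad(v):
    # same expression as exp_term / grad_surrogate in lif_step
    exp_term = np.exp(-beta * v - 1.0)
    return beta * exp_term / (1.0 + exp_term) ** 2

v = np.linspace(-2.0, 2.0, 9)
numeric = (surrogate(v + 1e-5) - surrogate(v - 1e-5)) / 2e-5  # central difference
print(np.allclose(numeric, surrogate_grad(v), atol=1e-6))     # True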

Loss Function

We use the cross‑entropy loss on the time‑integrated output firing rates:

$$\mathcal{L} = -\sum_{k=1}^{N} y_k \,\log\bigl(\hat y_k\bigr),$$

where $N$ is the number of classes, $y_k$ the one-hot target, and $\hat y_k$ the softmax probability for class $k$. The implementation computes softmax probabilities from time-integrated spike counts and compares them with one-hot encoded targets:

# Accumulate spikes over all time steps
output_spike_accumulator += spikes_output

# After time simulation, compute softmax probabilities
softmax_numerators = np.exp(output_spike_accumulator - np.max(output_spike_accumulator, axis=1, keepdims=True))
probabilities = softmax_numerators / softmax_numerators.sum(axis=1, keepdims=True)

# Create one-hot targets
one_hot_targets = np.zeros_like(probabilities)
one_hot_targets[np.arange(len(label_batch)), label_batch.numpy()] = 1

# Compute cross-entropy loss
loss = -np.mean(np.sum(one_hot_targets * np.log(probabilities + 1e-9), axis=1))

# Compute gradient of loss with respect to logits
grad_logits = (probabilities - one_hot_targets) / batch_size

Training protocol

Using pure NumPy (without deep learning frameworks), the model achieves >90% classification accuracy on MNIST in less than 5 minutes on an M4 MacBook Pro CPU. Training is carried out over a fixed number of epochs (default: 5) using mini-batches of size 128. As shown in Figure 2, the network improves from random performance (~10% accuracy) to over 90% accuracy within 10 epochs. The loss curve correspondingly shows a steady decrease.

Figure 2: Accuracy vs epochs showing >90% accuracy after 10 epochs (left) and loss vs epochs (right).

For each batch, the weights are updated by gradient descent once the forward simulation and the backward pass (described in the sections below) have produced the weight gradients:

# Apply weight updates
weights_input_to_hidden -= learning_rate * grad_weights_input
weights_hidden_to_output -= learning_rate * grad_weights_output
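Classification accuracy (as reported in Figure 2) follows directly from the accumulated output spikes: the predicted digit is the output neuron that fired most over the simulation window, which is equivalent to the argmax of the softmax probabilities. A minimal sketch, reusing the variables from the loss snippet above (not a verbatim excerpt from the repository):

# Predicted class = output neuron with the most accumulated spikes
predictions = np.argmax(output_spike_accumulator, axis=1)
accuracy = np.mean(predictions == label_batch.numpy())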

Figure 3 below shows the spiking activity and membrane potentials for a test image after training, visualizing how information flows through the network and how membrane potentials evolve to generate classification outputs.

Figure 3: Spike trains and membrane potentials for the last image (see Figure 1) in the evaluation run with the trained model. Left: Spike trains for the hidden layer. Middle: Spike trains for the output layer, where neuron 6 spikes, corresponding to classifying the image as digit 6. Right: Membrane potentials for the output layer, showing decay for all neurons except neuron 6, which is spiking.

Backpropagation Through Time (BPTT)

Because the spiking network evolves over time, we use backpropagation through time (BPTT) to train it. This involves unrolling the network across time steps and computing gradients by summing contributions from each time step. For each weight, we accumulate partial derivatives of the loss with respect to the membrane potential and spike history at each time step:

$$\frac{\partial \mathcal{L}}{\partial W} = \sum_t \frac{\partial \mathcal{L}}{\partial W}(t)$$

As shown in the code below, we accumulate these gradients by iterating through the time steps in reverse order. Figure 4 illustrates the L2 norms of the gradients for both hidden and output layers throughout training, providing insight into the training dynamics and stability.

Figure 4: L2 norms of the hidden-layer and output-layer weight gradients, recorded every 50 iterations.

Temporal Dependencies in SNNs

A challenge in SNNs is that the network state at time $t$ depends on the history of spikes and membrane potentials from previous time steps. Due to the leaky integrate-and-fire dynamics:

$$V(t + 1) = \alpha V(t) + I(t)$$

the current membrane potential depends on all previous inputs and spikes. This creates a complex temporal dependency chain that must be correctly backpropagated during training.
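Concretely, unrolling the discrete update between resets makes the reach of this dependency explicit (a small expansion of the equation above, not from the repository):

$$V(t) = \sum_{s=t_0}^{t-1} \alpha^{\,t-1-s}\, I(s),$$

where $t_0$ is the time of the neuron's last reset. A weight that shaped the input current at time $s$ therefore still influences the membrane potential, and hence possible spikes, at every later step until the next reset.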

Handling Recurrent Dependencies

During BPTT, we treat the SNN as a recurrent neural network unrolled in time, where:

  1. Forward pass: Propagate activity sequentially through time steps $1$ to $T$.
  2. Backward pass: Propagate gradients backwards from time step $T$ to $1$.

For each time step $t$, gradients flow through two paths, detailed below.

This recurrent dependency is captured in the implementation by storing the complete history of spikes and membrane states during the forward pass (using the spike_history_hidden and input_current_history_hidden arrays) and then processing them in reverse temporal order during backpropagation (using the for t in reversed(range(num_time_steps)) loop).

The recurrent dependency significantly affects how gradients flow through time during backpropagation. When a neuron spikes at time $t$, it influences:

  1. The immediate forward connections to the next layer at time $t$
  2. Its own future membrane potential at time $t+1, t+2, \dots$

For each backward step through time, gradients must be propagated through both of these paths.

Handling Discontinuities with Surrogate Gradients

The spike function is discontinuous (step function), making the network non-differentiable. The surrogate gradient approach replaces this non-differentiable function with a differentiable approximation (sigmoid) during the backward pass only:

$$\frac{\partial \text{Spike}}{\partial V} \approx \sigma'(V)$$

This allows gradient flow while preserving the discrete spiking behavior in the forward pass. As shown in the code implementation, we recalculate the surrogate gradients during the backward pass (using the lif_step function to get the grad_output and grad_hidden values).

Implementation of BPTT

For a network with $N$ neurons simulated for $T$ time steps, we need to store, at every time step:

  1. The hidden-layer spikes and input currents
  2. The hidden and output membrane potentials

The implementation handles these requirements by:

  1. Storing activity history during the forward pass:
# During forward pass - storing spike activity and input currents
spike_history_hidden = []
input_current_history_hidden = []
membrane_potential_history_output = []
membrane_potential_history_hidden = []


for t in range(num_time_steps):
    current_input_hidden = encoded_spikes[:, t, :] @ weights_input_to_hidden
    spikes_hidden, membrane_potential_hidden, grad_hidden = lif_step(current_input_hidden, membrane_potential_hidden, membrane_decay, surrogate_grad_steepness, membrane_threshold)
    spike_history_hidden.append(spikes_hidden)
    input_current_history_hidden.append(current_input_hidden)
    current_input_output = spikes_hidden @ weights_hidden_to_output
    spikes_output, membrane_potential_output, grad_output = lif_step(current_input_output, membrane_potential_output, membrane_decay, surrogate_grad_steepness, membrane_threshold)
    output_spike_accumulator += spikes_output
    membrane_potential_history_output.append(membrane_potential_output.copy())
    membrane_potential_history_hidden.append(membrane_potential_hidden.copy())
  2. Backpropagating gradients through time in reverse order:
# During backward pass - iterating backward through time steps
for t in reversed(range(num_time_steps)):
    spikes_hidden = spike_history_hidden[t]
    input_current_hidden = input_current_history_hidden[t]
    current_input_output = spikes_hidden @ weights_hidden_to_output

    # Calculate output gradients
    _, _, grad_output = lif_step(current_input_output, membrane_potential_history_output[t], membrane_decay, surrogate_grad_steepness, membrane_threshold)
    grad_output_current = grad_logits * grad_output
    grad_weights_output += spikes_hidden.T @ grad_output_current
    
    # Backpropagate to hidden layer
    grad_spikes_hidden = grad_output_current @ weights_hidden_to_output.T
    
    # Calculate hidden gradients
    _, _, grad_hidden = lif_step(input_current_hidden, membrane_potential_history_hidden[t], membrane_decay, surrogate_grad_steepness, membrane_threshold)
    grad_input_hidden = grad_spikes_hidden * grad_hidden
    grad_weights_input += encoded_spikes[:, t, :].T @ grad_input_hidden

Gradient w.r.t. Hidden→Output Weights $W_{ho}^{kj}$

We want

$$\frac{\partial \mathcal{L}}{\partial W_{ho}^{kj}} = \sum_{t} \frac{\partial \mathcal{L}}{\partial o_k(t)} \,\frac{\partial o_k(t)}{\partial V_o^k(t)} \,\frac{\partial V_o^k(t)}{\partial W_{ho}^{kj}}.$$
  1. Error at the spike output

    $$\frac{\partial \mathcal{L}}{\partial o_k(t)} = \hat y_k - y_k.$$
  2. Surrogate spike derivative

    $$\frac{\partial o_k(t)}{\partial V_o^k(t)} \approx \sigma'\bigl(V_o^k(t)\bigr).$$
  3. Voltage w.r.t. weight

    $$V_o^k(t) = \sum_{j} W_{ho}^{kj}\,h_j(t) \quad\Longrightarrow\quad \frac{\partial V_o^k(t)}{\partial W_{ho}^{kj}} = h_j(t).$$

Putting it together:

$$\boxed{ \frac{\partial \mathcal{L}}{\partial W_{ho}^{kj}} = \sum_{t} \bigl(\hat y_k - y_k\bigr) \,\sigma'\bigl(V_o^k(t)\bigr)\, h_j(t) }.$$

Gradient w.r.t. Input→Hidden Weights $W_{ih}^{ji}$

We want

$$\frac{\partial \mathcal{L}}{\partial W_{ih}^{ji}} = \sum_{t} \frac{\partial \mathcal{L}}{\partial h_j(t)} \,\frac{\partial h_j(t)}{\partial V_h^j(t)} \,\frac{\partial V_h^j(t)}{\partial W_{ih}^{ji}}.$$
  1. Backpropagated error into hidden spike

    $$\frac{\partial \mathcal{L}}{\partial h_j(t)} = \sum_{k} \frac{\partial \mathcal{L}}{\partial o_k(t)} \,\frac{\partial o_k(t)}{\partial V_o^k(t)} \,\frac{\partial V_o^k(t)}{\partial h_j(t)} = \sum_{k} (\hat y_k - y_k)\, \sigma'\bigl(V_o^k(t)\bigr)\, W_{ho}^{kj}.$$
  2. Surrogate at hidden layer

    $$\frac{\partial h_j(t)}{\partial V_h^j(t)} \approx \sigma'\bigl(V_h^j(t)\bigr).$$
  3. Voltage w.r.t. weight

    $$V_h^j(t) = \sum_{i} W_{ih}^{ji}\,x_i(t) \quad\Longrightarrow\quad \frac{\partial V_h^j(t)}{\partial W_{ih}^{ji}} = x_i(t).$$

Together:

$$\boxed{ \frac{\partial \mathcal{L}}{\partial W_{ih}^{ji}} = \sum_{t} \Bigl[\sum_{k} (\hat y_k - y_k)\, \sigma'\bigl(V_o^k(t)\bigr)\, W_{ho}^{kj} \Bigr] \,\sigma'\bigl(V_h^j(t)\bigr)\, x_i(t) }.$$

These equations are implemented in the backward pass code, where we:

  1. First calculate the output layer gradients (grad_output_current = grad_logits * grad_output)
  2. Then backpropagate to the hidden layer (grad_spikes_hidden = grad_output_current @ weights_hidden_to_output.T)
  3. Finally accumulate the weight gradients for both layers (grad_weights_output += ... and grad_weights_input += ...)

The effect of these weight updates is visualized in Figure 5, showing how the weight distributions change from their initial random state to more structured patterns after training.

Figure 5: Weight distributions before and after training. Top left: Input layer weights before training. Top right: Input layer weights after training. Bottom left: Hidden layer weights before training. Bottom right: Hidden layer weights after training.

Conclusion

This simulation was a fun project for training a spiking neural network on an image classification task using biologically inspired encoding, LIF dynamics, and surrogate gradient descent. NumPy was chosen over higher-level frameworks not for performance, but to understand the underlying mechanics of SNNs. The model successfully achieved over 90% accuracy on the MNIST test set, and future implementations will try to generalize the model, starting with FashionMNIST.