Tiny Recurrent Spiking Neural Network

Why tiny? I wanted to make a model small enough to fit in a browser thread so I could visualise all its moving parts.

Scroll down or use the arrow keys to move through the slides.

↓ Slide 2 / 10

Dataset

We will classify blinking patterns.

We use a dataset of blinking patterns, where each sequence briefly shows a horizontal line, a vertical line, or a cross in one frame. The rest of the frames are blank, and each sequence is labeled by its pattern type.

Why blinking? Because we want to give the recurrent connections something to do.
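
As a rough sketch, generating one sequence could look like the following, assuming 5×5 frames (matching the 25 input neurons described later); names like makeSequence are illustrative, not the demo's actual code:

    // A frame is a flat array of 25 pixels (5x5), matching the input layer.
    type Frame = number[];

    const SIZE = 5;    // 5x5 frames
    const FRAMES = 15; // frames per sequence

    function blankFrame(): Frame {
      return new Array(SIZE * SIZE).fill(0);
    }

    // Draw one pattern type into a single frame.
    function patternFrame(kind: "horizontal" | "vertical" | "cross"): Frame {
      const f = blankFrame();
      const mid = Math.floor(SIZE / 2);
      for (let i = 0; i < SIZE; i++) {
        if (kind !== "vertical") f[mid * SIZE + i] = 1;   // horizontal stroke
        if (kind !== "horizontal") f[i * SIZE + mid] = 1; // vertical stroke
      }
      return f;
    }

    // The pattern blinks in exactly one frame; picking that frame
    // at random is an assumption, not stated in the text.
    function makeSequence(kind: "horizontal" | "vertical" | "cross"): Frame[] {
      const blinkAt = Math.floor(Math.random() * FRAMES);
      return Array.from({ length: FRAMES }, (_, t) =>
        t === blinkAt ? patternFrame(kind) : blankFrame()
      );
    }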

[Example sequences: horizontal, vertical, cross]

↓ Slide 3 / 10

Noisy Dataset

Making the dataset non-trivial.

Noise is added to make the task harder, challenging the model to detect the pattern under uncertainty.

We randomly flip pixels in each frame to add the noise.

The training dataset has 720 samples, each with 15 frames (720 samples over 5 epochs gives the 3,600 training iterations shown in the control panel); a separate held-out set of samples is used for testing.
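
A minimal sketch of the pixel-flipping noise; the flip probability is a placeholder, not the demo's actual value:

    // Flip each binary pixel independently with probability flipProb.
    // flipProb = 0.05 is an arbitrary placeholder.
    function addNoise(frame: number[], flipProb = 0.05): number[] {
      return frame.map(p => (Math.random() < flipProb ? 1 - p : p));
    }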

[Noisy example sequences: four each of horizontal, vertical, and cross]

↓ Slide 4 / 10

Network Training

The input layer size is matched to the flattened size of the input images (25 neurons), the hidden layer has 24 neurons, and the output layer has 3 neurons, one per class in the dataset.
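
In code, the shapes might be set up as follows; the sizes come from the text, while the random initialisation is a guess:

    const N_IN = 25;  // 5x5 input pixels, flattened
    const N_HID = 24; // hidden LIF neurons
    const N_OUT = 3;  // one output per class

    // Small random weights; the demo's actual initialisation may differ.
    function randMatrix(rows: number, cols: number): number[][] {
      return Array.from({ length: rows }, () =>
        Array.from({ length: cols }, () => (Math.random() - 0.5) * 0.2)
      );
    }

    const wIn = randMatrix(N_HID, N_IN);   // input -> hidden
    const wRec = randMatrix(N_HID, N_HID); // hidden -> hidden (recurrent)
    const wOut = randMatrix(N_OUT, N_HID); // hidden spike counts -> class logits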

Use the control panel below to start the training.
[Control panel: Epoch 0/5 · Iterations 0/3600 · Status: Loading Dataset]
↓ Slide 5 / 10

Weights

The weights are trained using a simplified variant of the backpropagation through time algorithm, adapted for spiking neural networks.

The weights for all three layers are shown here; they become visible once training has started.

↓ Slide 6 / 10

Accuracy and Loss

As the network trains, the accuracy and loss are updated: the loss is calculated every iteration, while the accuracy is calculated once per epoch.

↓ Slide 7 / 10

Recurrent Layer Spikes

Here we plot the spikes of the hidden neurons. The x-axis is the global time step and the y-axis is the neuron index. Each dot represents a spike at a given time step for a given neuron.

Each input is 15 frames long, and the model is trained on one input per iteration, so we plot the hidden-layer spikes for the current iteration.

Since the pattern blinks in only one frame of each input, most of the neurons will not spike at all.
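
The raster is just a list of (time step, neuron index) dots; collecting them could look like this:

    type SpikeEvent = { t: number; neuron: number };

    // spikesPerStep[t][n] is 1 if neuron n fired at global time step t.
    function collectSpikes(spikesPerStep: number[][]): SpikeEvent[] {
      const events: SpikeEvent[] = [];
      spikesPerStep.forEach((spikes, t) =>
        spikes.forEach((s, neuron) => {
          if (s > 0) events.push({ t, neuron });
        })
      );
      return events;
    }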

↓ Slide 8 / 10

Classification Post Training

Once the network is trained, we can use it to classify the test set. It will not classify well before training, but afterwards it should reach a noticeably higher accuracy on the test set.
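
Measuring that is a short loop; classify here is a hypothetical stand-in for a full forward pass over one sequence:

    type Sample = { frames: number[][]; label: number };

    // Fraction of test sequences whose predicted class matches the label.
    function accuracy(
      testSet: Sample[],
      classify: (frames: number[][]) => number // hypothetical forward pass
    ): number {
      let correct = 0;
      for (const { frames, label } of testSet) {
        if (classify(frames) === label) correct++;
      }
      return correct / testSet.length;
    }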

↓ Slide 9 / 10

Discussion

Each input is a short sequence of grayscale image frames where a specific pattern—like a line or cross—briefly blinks in a single frame. The rest of the sequence is mostly blank or noisy, forcing the model to detect the blink despite added pixel noise. Each sequence is labeled by the type of blinking pattern.

The model is a recurrent spiking neural network (RSNN) composed of an input layer, a hidden layer of Leaky Integrate-and-Fire (LIF) neurons, and a dense output layer. Each input frame is flattened and fed into the hidden layer, where neurons integrate input over time and emit spikes based on their membrane potential.

The hidden layer also has recurrent connections, allowing neurons to retain temporal context across frames. After the full sequence, the number of spikes from each hidden neuron is summed and passed through the output layer to produce class logits, which are converted to probabilities using softmax.

The Leaky Integrate-and-Fire (LIF) neuron simulates temporal integration with a decaying membrane potential. At each timestep $t$, the membrane voltage is updated as:

$$V_t = \lambda V_{t-1} + I_t$$

where $\lambda$ is the leak rate, $I_t$ is the input current, and $V_{t-1}$ is the previous voltage. If $V_t$ exceeds a threshold $\theta$, the neuron emits a spike and resets its voltage to zero.
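
One LIF step as code; the λ and θ values are placeholders, not the demo's settings:

    const LEAK = 0.9;      // λ, placeholder value
    const THRESHOLD = 1.0; // θ, placeholder value

    // One LIF update: leak the old voltage, add the input current,
    // then spike and reset to zero if the threshold is crossed.
    function lifStep(vPrev: number, current: number): { v: number; spike: number } {
      let v = LEAK * vPrev + current;
      let spike = 0;
      if (v > THRESHOLD) {
        spike = 1;
        v = 0;
      }
      return { v, spike };
    }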

The input to the network consists of 15 frames. Each frame is multiplied by the input weights and fed to the hidden layer, which updates according to:

$$I_i^t = \sum_j W^{\mathrm{in}}_{ij}\, x_j^t + \sum_k W^{\mathrm{rec}}_{ik}\, s_k^{t-1}$$

where $I_i^t$ is the input current to neuron $i$ at time $t$, $W^{\mathrm{in}}$ are the input weights, $W^{\mathrm{rec}}$ are the recurrent weights, $x_j^t$ is the normalized intensity of pixel $j$, and $s_k^{t-1}$ is the spike of neuron $k$ from the previous time step.
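
A direct translation of that update into code:

    // I_i = sum_j wIn[i][j] * x[j]  +  sum_k wRec[i][k] * sPrev[k]
    function hiddenCurrents(
      wIn: number[][], wRec: number[][],
      x: number[], sPrev: number[]
    ): number[] {
      return wIn.map((rowIn, i) => {
        let current = 0;
        for (let j = 0; j < x.length; j++) current += rowIn[j] * x[j];
        for (let k = 0; k < sPrev.length; k++) current += wRec[i][k] * sPrev[k];
        return current;
      });
    }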

The output layer is a dense layer that takes the summed spikes from the hidden layer and applies a softmax activation function to produce class probabilities. These are then compared to the true labels using a cross-entropy loss function.
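
Sketched in code; the small epsilon in the loss is an added numerical guard, and the function names are illustrative:

    // Class logits from the summed hidden spike counts.
    function logitsFromCounts(wOut: number[][], counts: number[]): number[] {
      return wOut.map(row => row.reduce((z, w, h) => z + w * counts[h], 0));
    }

    function softmax(logits: number[]): number[] {
      const m = Math.max(...logits); // subtract the max for numerical stability
      const exps = logits.map(z => Math.exp(z - m));
      const sum = exps.reduce((a, b) => a + b, 0);
      return exps.map(e => e / sum);
    }

    // Cross-entropy against a one-hot label, given as a class index.
    function crossEntropy(probs: number[], label: number): number {
      return -Math.log(probs[label] + 1e-12); // epsilon guards against log(0)
    }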

We use a simplified backpropagation through time algorithm to update the weights. Error from the output is backpropagated to each hidden spike using:

$$\frac{\partial L}{\partial c_h} = \sum_o \left(p_o - y_o\right) W^{\mathrm{out}}_{oh}$$

where $c_h$ is the spike count of hidden neuron $h$, $p_o$ is the predicted probability for class $o$, and $y_o$ is the one-hot true label.

Since real spikes are non-differentiable, the model uses a surrogate gradient, a smooth stand-in for the spike's derivative, for example:

$$\frac{\partial s^t}{\partial V^t} \approx \frac{1}{\left(1 + \left|V^t - \theta\right|\right)^2}$$

This allows gradient-based optimization to update the weights despite the binary nature of spikes, enabling the network to learn from temporal patterns in the input.
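
As a function, the surrogate above looks like this (the fast-sigmoid shape is one common choice, shown here as an example):

    // Fast-sigmoid style surrogate: largest near the threshold,
    // decaying smoothly away from it on both sides.
    function surrogateGrad(v: number, threshold = 1.0): number {
      const d = Math.abs(v - threshold);
      return 1 / ((1 + d) * (1 + d));
    }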

The training loop cycles through the dataset multiple times, where each cycle is called an epoch, updating the model weights at each step. The accuracy of the model is evaluated on a separate validation set after each epoch to monitor performance and catch overfitting.
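
Putting it together, the outer loop is roughly the following; trainStep and evaluate are hypothetical helpers standing in for the forward/backward pass and the epoch-end evaluation:

    type Sample = { frames: number[][]; label: number };

    // Hypothetical helpers, not the demo's real API.
    declare function trainStep(sample: Sample): number; // returns the loss
    declare function evaluate(set: Sample[]): number;   // returns accuracy
    declare const trainSet: Sample[];
    declare const validationSet: Sample[];

    const EPOCHS = 5;

    for (let epoch = 0; epoch < EPOCHS; epoch++) {
      for (const sample of trainSet) {
        const loss = trainStep(sample); // loss recorded every iteration
      }
      const acc = evaluate(validationSet); // accuracy computed once per epoch
    }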

↓ Slide 10 / 10

Appendix: Logs

Every good system needs logging.

[Training log output appears here]