Backpropagation Through Time (BPTT) in a Spiking Neural Network (SNN)
Overview
This model trains a two-layer spiking neural network using Backpropagation Through Time (BPTT) with surrogate gradients. The SNN consists of:
- Input layer: Poisson-encoded spike trains from image pixels
- Hidden layer: Leaky Integrate-and-Fire (LIF) neurons
- Output layer: LIF neurons, with spike counts integrated over time
- Decoder: Softmax classifier on accumulated output spikes
The spiking neurons produce binary outputs $S \in \{0, 1\}$ and are trained with surrogate gradients that approximate the derivative of the non-differentiable spike function.
1. Spiking Neuron Dynamics
1.1. Membrane Potential Update (LIF Model)
The membrane potential $V_i^t$ of neuron $i$ at time $t$ evolves as:
$$V_i^t = \alpha V_i^{t-1} + I_i^{t-1}$$
- $\alpha = \exp(-\Delta t / \tau_m)$: membrane decay factor
- $I_i^{t-1}$: input current at the previous timestep
- At a spike, $V_i^t$ resets to zero (or to a baseline, depending on the model)
1.2. Spike Generation
A neuron emits a spike if its membrane potential exceeds a threshold $V_{th}$:
$$S_i^t = H(V_i^t - V_{th})$$
where $H(x)$ is the Heaviside step function:
$$H(x) = \begin{cases} 1 & \text{if } x \ge 0 \\ 0 & \text{if } x < 0 \end{cases}$$
This function is non-differentiable, so we approximate its derivative for learning.
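To make the dynamics concrete, here is a minimal NumPy sketch of a single LIF step, assuming reset-to-zero after a spike; the function name `lif_step_demo` and the argument names are illustrative, not taken from the actual implementation:

```python
import numpy as np

def lif_step_demo(v, i_in, alpha=0.9, v_th=1.0):
    """One LIF timestep: leak + integrate, threshold, reset-to-zero."""
    v = alpha * v + i_in                   # V^t = alpha * V^{t-1} + I^{t-1}
    spikes = (v >= v_th).astype(float)     # S^t = H(V^t - V_th)
    v = v * (1.0 - spikes)                 # spiking neurons reset to zero
    return v, spikes
```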
2. Surrogate Gradient Approximation
To allow gradient-based optimization, we use a surrogate derivative for $H$. One common choice is a fast sigmoid:
$$\frac{\partial S_i^t}{\partial V_i^t} \approx \sigma'(V_i^t - V_{th}) = \frac{1}{\left(1 + \beta\,|V_i^t - V_{th}|\right)^2}$$
- $\beta$: controls the steepness of the surrogate gradient
- Implemented via `lif_step()` in the simulation
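A minimal sketch of that surrogate derivative in NumPy (the helper name `surrogate_grad` and the default `beta` value are assumptions for illustration):

```python
import numpy as np

def surrogate_grad(v, v_th=1.0, beta=10.0):
    """Fast-sigmoid surrogate: dS/dV ~ 1 / (1 + beta * |V - V_th|)^2."""
    return 1.0 / (1.0 + beta * np.abs(v - v_th)) ** 2
```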
3. Forward Pass (Unrolled in Time)
Let:
- $X \in \mathbb{R}^{B \times T \times N_{in}}$: encoded input spike trains
- $W_{in} \in \mathbb{R}^{N_{in} \times N_{hid}}$: input-to-hidden weights
- $W_{out} \in \mathbb{R}^{N_{hid} \times N_{out}}$: hidden-to-output weights
At each timestep $t = 1, \dots, T$:
- Hidden layer input current:
  $$I_{hid}^t = X^t \cdot W_{in}$$
- Hidden layer LIF step:
  $$V_{hid}^t = \alpha V_{hid}^{t-1} + I_{hid}^t$$
  $$S_{hid}^t = H(V_{hid}^t - V_{th})$$
- Output layer input current:
  $$I_{out}^t = S_{hid}^t \cdot W_{out}$$
- Output layer LIF step:
  $$V_{out}^t = \alpha V_{out}^{t-1} + I_{out}^t$$
  $$S_{out}^t = H(V_{out}^t - V_{th})$$
- Accumulate spikes across time:
  $$\text{logits} = \sum_{t=1}^{T} S_{out}^t$$
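Below is a minimal sketch of the unrolled forward pass in NumPy. The LIF update is inlined so that the pre-reset membrane potentials can be cached for the backward pass; names such as `snn_forward` and `cache` are illustrative, not taken from the actual code.

```python
import numpy as np

def snn_forward(X, W_in, W_out, alpha=0.9, v_th=1.0):
    """Unrolled forward pass over T timesteps.

    Returns accumulated spike counts (logits) and the per-timestep state
    (inputs, pre-reset potentials, hidden spikes) needed for BPTT."""
    B, T, _ = X.shape
    n_out = W_out.shape[1]
    v_hid = np.zeros((B, W_in.shape[1]))
    v_out = np.zeros((B, n_out))
    logits = np.zeros((B, n_out))
    cache = []
    for t in range(T):
        x_t = X[:, t, :]
        v_hid = alpha * v_hid + x_t @ W_in        # hidden LIF integration
        s_hid = (v_hid >= v_th).astype(float)     # hidden spikes
        v_hid_pre = v_hid.copy()                  # pre-reset potential (for surrogate)
        v_hid = v_hid * (1.0 - s_hid)             # reset-to-zero

        v_out = alpha * v_out + s_hid @ W_out     # output LIF integration
        s_out = (v_out >= v_th).astype(float)     # output spikes
        v_out_pre = v_out.copy()
        v_out = v_out * (1.0 - s_out)

        logits += s_out                           # accumulate spike counts
        cache.append((x_t, v_hid_pre, s_hid, v_out_pre))
    return logits, cache
```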
4. Loss Function
We use softmax cross-entropy loss on the accumulated output spikes:
- Softmax probabilities:
  $$\hat{y}_i = \frac{\exp(z_i)}{\sum_j \exp(z_j)}, \qquad z_i = \text{logits}_i$$
- Cross-entropy loss:
  $$\mathcal{L} = -\frac{1}{B} \sum_{i=1}^{B} \sum_{c=1}^{C} y_{ic} \log \hat{y}_{ic}$$
- $y_{ic}$: ground-truth one-hot labels
- $\hat{y}_{ic}$: predicted probability for class $c$
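A short NumPy sketch of this loss on the accumulated spike counts, with the usual max-subtraction trick for numerical stability (the small epsilon in the log is an illustrative safeguard):

```python
import numpy as np

def softmax_xent(logits, y_onehot):
    """Softmax cross-entropy averaged over the batch; returns (loss, probs)."""
    z = logits - logits.max(axis=1, keepdims=True)       # numerical stability
    exp_z = np.exp(z)
    probs = exp_z / exp_z.sum(axis=1, keepdims=True)     # softmax probabilities
    loss = -np.mean(np.sum(y_onehot * np.log(probs + 1e-12), axis=1))
    return loss, probs
```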
5. Backpropagation Through Time (BPTT)
We compute the gradient of the loss with respect to the weights, unrolling the network over $T$ timesteps.
5.1. Derivative of Loss w.r.t. Logits
$$\frac{\partial \mathcal{L}}{\partial z_i} = \hat{y}_i - y_i$$
This error is the same at every timestep, since $z_i = \sum_t S_{out,i}^t$.
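In code this is a one-liner; dividing by the batch size matches the $1/B$ factor in the loss (the helper name is illustrative):

```python
def loss_grad_wrt_logits(probs, y_onehot):
    """dL/dz = (y_hat - y) / B, shared by every timestep of the backward pass."""
    return (probs - y_onehot) / probs.shape[0]
```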
5.2. Backward Pass Over Time
Loop over $t = T, T-1, \dots, 1$:
Output Layer
- Recompute:
  $$I_{out}^t = S_{hid}^t \cdot W_{out}$$
- Get the surrogate gradient:
  $$\frac{\partial S_{out}^t}{\partial I_{out}^t} \approx \sigma'(V_{out}^t - V_{th})$$
- Chain rule for the gradient of the loss w.r.t. the output weights:
  $$\delta_{out}^t = \frac{\partial \mathcal{L}}{\partial z} \cdot \frac{\partial S_{out}^t}{\partial I_{out}^t}$$
  $$\frac{\partial \mathcal{L}}{\partial W_{out}} \mathrel{+}= (S_{hid}^t)^\top \cdot \delta_{out}^t$$
- Backpropagate into the hidden spike space:
  $$\delta_{hid}^t = \delta_{out}^t \cdot W_{out}^\top$$
Hidden Layer
- Get the surrogate gradient:
  $$\frac{\partial S_{hid}^t}{\partial I_{hid}^t} \approx \sigma'(V_{hid}^t - V_{th})$$
- Chain rule:
  $$\Delta W_{in} \mathrel{+}= (X^t)^\top \cdot \left( \delta_{hid}^t \circ \frac{\partial S_{hid}^t}{\partial I_{hid}^t} \right)$$
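Below is a sketch of the whole backward loop, assuming the `cache` produced by the `snn_forward` sketch above and the fast-sigmoid surrogate from Section 2. Following the per-timestep formulas above, it does not propagate gradients through the membrane leak term (a common truncation in simple BPTT implementations):

```python
import numpy as np

def snn_backward(grad_logits, cache, W_out, v_th=1.0, beta=10.0):
    """Accumulate dL/dW_in and dL/dW_out over all timesteps, newest first."""
    grad_W_in, grad_W_out = 0.0, 0.0      # scalar zeros broadcast on first accumulation
    for x_t, v_hid_pre, s_hid, v_out_pre in reversed(cache):
        # Output layer: surrogate gradient and local error
        sg_out = 1.0 / (1.0 + beta * np.abs(v_out_pre - v_th)) ** 2
        delta_out = grad_logits * sg_out                 # dL/dz (elementwise) dS_out/dI_out
        grad_W_out = grad_W_out + s_hid.T @ delta_out    # dL/dW_out += S_hid^T . delta_out

        # Backpropagate into the hidden spike space
        delta_hid = delta_out @ W_out.T

        # Hidden layer: surrogate gradient and input-weight gradient
        sg_hid = 1.0 / (1.0 + beta * np.abs(v_hid_pre - v_th)) ** 2
        grad_W_in = grad_W_in + x_t.T @ (delta_hid * sg_hid)
    return grad_W_in, grad_W_out
```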
6. Weight Update (SGD)
After gradients have been accumulated across all timesteps:
$$W_{in} \leftarrow W_{in} - \eta \cdot \frac{\partial \mathcal{L}}{\partial W_{in}}$$
$$W_{out} \leftarrow W_{out} - \eta \cdot \frac{\partial \mathcal{L}}{\partial W_{out}}$$
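Tying the pieces together, a single training step might look like the sketch below, reusing the `snn_forward`, `softmax_xent`, and `snn_backward` helpers sketched earlier (all hyperparameter values are illustrative):

```python
def train_step(X, y_onehot, W_in, W_out, alpha=0.9, v_th=1.0, beta=10.0, eta=1e-3):
    """One full BPTT step: forward, loss, backward over time, SGD update."""
    logits, cache = snn_forward(X, W_in, W_out, alpha, v_th)
    loss, probs = softmax_xent(logits, y_onehot)
    grad_logits = (probs - y_onehot) / X.shape[0]          # dL/dz = y_hat - y (batch mean)
    grad_W_in, grad_W_out = snn_backward(grad_logits, cache, W_out, v_th, beta)
    W_in = W_in - eta * grad_W_in                          # W_in  <- W_in  - eta * dL/dW_in
    W_out = W_out - eta * grad_W_out                       # W_out <- W_out - eta * dL/dW_out
    return loss, W_in, W_out
```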
Summary
- The network is unfolded in time; membrane and spike histories are stored.
- Surrogate gradients allow training through discrete spike events.
- Gradients are propagated backward in time from output to input.
- The learning rule is effectively a spike-based BPTT with temporal credit assignment.
Implementation Notes
- `lif_step()` encapsulates the membrane update, spike generation, and surrogate gradient.
- Gradient flow: $\text{loss} \rightarrow \text{logits} \rightarrow S_{out}^t \rightarrow S_{hid}^t \rightarrow X^t$
- All steps are implemented explicitly in NumPy; no autograd is used.