Backpropagation Through Time (BPTT) in a Spiking Neural Network (SNN)

Overview

This model trains a two-layer spiking neural network using Backpropagation Through Time (BPTT) with surrogate gradients. The SNN consists of:

  1. An input layer delivering inputs $X^t$ to the hidden layer through weights $W_{\text{in}}$.
  2. A hidden layer of leaky integrate-and-fire (LIF) neurons.
  3. An output layer of LIF neurons driven by the hidden spikes through weights $W_{\text{out}}$.

The spiking neurons produce binary outputs $S \in \{0, 1\}$ and are trained using surrogate gradients to approximate the derivative of the non-differentiable spike function.

1. Spiking Neuron Dynamics

1.1. Membrane Potential Update (LIF Model)

The membrane potential $V_i^t$ of neuron $i$ at time $t$ evolves as:

$$V_i^t = \alpha V_i^{t-1} + I_i^{t-1}$$

where $\alpha$ is the membrane decay (leak) factor and $I_i^{t-1}$ is the input current driving neuron $i$.
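
A minimal sketch of this update as a plain function (the decay value `alpha=0.9` is an illustrative assumption, not taken from the text):

```python
def lif_step(V_prev, I_in, alpha=0.9):
    """LIF membrane update: V^t = alpha * V^{t-1} + I^{t-1}.

    V_prev and I_in may be scalars or NumPy arrays of matching shape.
    """
    return alpha * V_prev + I_in
```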

1.2. Spike Generation

A neuron emits a spike if its membrane potential exceeds a threshold $V_{\text{th}}$:

$$S_i^t = H(V_i^t - V_{\text{th}})$$

where $H(x)$ is the Heaviside step function:

$$H(x) = \begin{cases} 1 & \text{if } x \geq 0 \\ 0 & \text{if } x < 0 \end{cases}$$

This function is non-differentiable, so we approximate its derivative for learning.

2. Surrogate Gradient Approximation

To allow gradient-based optimization, we use a surrogate derivative for $H$. One common choice is a fast sigmoid:

$$\frac{dS_i^t}{dV_i^t} \approx \sigma'(V_i^t - V_{\text{th}}) = \frac{1}{\left(1 + \beta \, |V_i^t - V_{\text{th}}|\right)^2}$$

where $\beta$ controls the sharpness of the surrogate around the threshold.
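
A sketch of the spike nonlinearity together with this surrogate; the threshold and sharpness values below are assumed examples, not values from the text:

```python
import numpy as np

V_TH = 1.0   # firing threshold V_th (assumed value)
BETA = 2.0   # surrogate sharpness beta (assumed value)

def spike_fn(V):
    """Forward pass: Heaviside step, 1 where V >= V_th, else 0."""
    return (V >= V_TH).astype(np.float32)

def surrogate_grad(V):
    """Backward pass: dS/dV ~= 1 / (1 + beta * |V - V_th|)^2."""
    return 1.0 / (1.0 + BETA * np.abs(V - V_TH)) ** 2
```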

3. Forward Pass (Unrolled in Time)

Let $X^t$ denote the input at timestep $t$, $W_{\text{in}}$ the input-to-hidden weight matrix, and $W_{\text{out}}$ the hidden-to-output weight matrix.

At each timestep $t = 1, \dots, T$ (a code sketch follows the list):

  1. Hidden layer input current: $I_{\text{hid}}^t = X^t \cdot W_{\text{in}}$
  2. Hidden layer LIF step: $V_{\text{hid}}^t = \alpha V_{\text{hid}}^{t-1} + I_{\text{hid}}^t$, then $S_{\text{hid}}^t = H(V_{\text{hid}}^t - V_{\text{th}})$
  3. Output layer input current: $I_{\text{out}}^t = S_{\text{hid}}^t \cdot W_{\text{out}}$
  4. Output layer LIF step: $V_{\text{out}}^t = \alpha V_{\text{out}}^{t-1} + I_{\text{out}}^t$, then $S_{\text{out}}^t = H(V_{\text{out}}^t - V_{\text{th}})$
  5. Accumulate spikes across time: $\text{logits} = \sum_{t=1}^T S_{\text{out}}^t$
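
Under these definitions, a minimal NumPy sketch of the unrolled forward pass (reusing `spike_fn` from the Section 2 sketch; the shapes and the `cache` of per-timestep values are illustrative assumptions):

```python
import numpy as np

def forward(X, W_in, W_out, alpha=0.9):
    """X: [T, B, n_in] inputs; returns logits [B, n_out] and a cache for BPTT."""
    T, B, _ = X.shape
    n_hid, n_out = W_in.shape[1], W_out.shape[1]
    V_hid, V_out = np.zeros((B, n_hid)), np.zeros((B, n_out))
    logits = np.zeros((B, n_out))
    cache = []  # per-timestep values needed by the backward pass

    for t in range(T):
        I_hid = X[t] @ W_in               # 1. hidden input current
        V_hid = alpha * V_hid + I_hid     # 2. hidden LIF step
        S_hid = spike_fn(V_hid)
        I_out = S_hid @ W_out             # 3. output input current
        V_out = alpha * V_out + I_out     # 4. output LIF step
        S_out = spike_fn(V_out)
        logits += S_out                   # 5. accumulate output spikes
        cache.append((X[t], V_hid.copy(), S_hid, V_out.copy()))
    return logits, cache
```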

4. Loss Function

We use softmax cross-entropy loss on the accumulated output spikes:

  1. Softmax probabilities: $\hat{y}_i = \dfrac{\exp(z_i)}{\sum_j \exp(z_j)}, \quad z_i = \text{logits}_i$
  2. Cross-entropy loss: $L = -\dfrac{1}{B} \sum_{i=1}^{B} \sum_{c=1}^{C} y_{ic} \log \hat{y}_{ic}$, where $B$ is the batch size, $C$ the number of classes, and $y_{ic}$ the one-hot target.
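
A sketch of this loss on the accumulated spike counts; the integer class-index array `labels` is an assumed input format:

```python
import numpy as np

def softmax_cross_entropy(logits, labels):
    """logits: [B, C] accumulated output spikes; labels: [B] integer classes."""
    z = logits - logits.max(axis=1, keepdims=True)            # stabilize exponentials
    probs = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)  # softmax y_hat
    B = logits.shape[0]
    loss = -np.log(probs[np.arange(B), labels] + 1e-12).mean()
    return loss, probs
```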

5. Backpropagation Through Time (BPTT)

We compute the gradient of the loss with respect to the weights, unrolling the network over TT time steps.

5.1. Derivative of Loss w.r.t. Logits

$$\frac{\partial L}{\partial z_i} = \hat{y}_i - y_i$$

This error is constant across all timesteps, since $z_i = \sum_t S_i^t$.
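
In code, using the `probs` returned by the loss sketch above, this gradient (including the $1/B$ factor from the batch average) is:

```python
import numpy as np

def dloss_dlogits(probs, labels):
    """dL/dz = (y_hat - y) / B, shared by every timestep of the backward pass."""
    B, C = probs.shape
    y_onehot = np.eye(C)[labels]   # one-hot targets y
    return (probs - y_onehot) / B
```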

5.2. Backward Pass Over Time

Loop over $t = T, T-1, \dots, 1$:

Output Layer

  1. Recompute:

    $I_{\text{out}}^t = S_{\text{hid}}^t \cdot W_{\text{out}}$
  2. Get surrogate gradient:

    $\dfrac{\partial S_{\text{out}}^t}{\partial I_{\text{out}}^t} \approx \sigma'(V_{\text{out}}^t - V_{\text{th}})$ (since $\partial V_{\text{out}}^t / \partial I_{\text{out}}^t = 1$, this is the same surrogate as in Section 2)
  3. Chain rule for the gradient of the loss w.r.t. the output weights:

    $\delta_{\text{out}}^t = \dfrac{\partial L}{\partial z} \circ \dfrac{\partial S_{\text{out}}^t}{\partial I_{\text{out}}^t}$

    $\dfrac{\partial L}{\partial W_{\text{out}}} \mathrel{+}= S_{\text{hid}}^{t\top} \cdot \delta_{\text{out}}^t$
  4. Backpropagate to hidden spike space:

    $\delta_{\text{hid}}^t = \delta_{\text{out}}^t \cdot W_{\text{out}}^\top$

Hidden Layer

  1. Get surrogate gradient: $\dfrac{\partial S_{\text{hid}}^t}{\partial I_{\text{hid}}^t} \approx \sigma'(V_{\text{hid}}^t - V_{\text{th}})$
  2. Chain rule: $\dfrac{\partial L}{\partial W_{\text{in}}} \mathrel{+}= X^{t\top} \cdot \left( \delta_{\text{hid}}^t \circ \dfrac{\partial S_{\text{hid}}^t}{\partial I_{\text{hid}}^t} \right)$
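
A sketch of the complete backward loop, reusing the `cache` from the forward sketch and `surrogate_grad` from Section 2 (these names are carried over from the earlier sketches, not from the original text). As in the equations above, gradients flow only through each timestep's input current, not through the $\alpha V^{t-1}$ recurrence:

```python
import numpy as np

def backward(cache, dL_dz, W_out):
    """BPTT backward pass: accumulate dL/dW_in and dL/dW_out over t = T, ..., 1."""
    dW_in, dW_out = 0.0, 0.0
    for X_t, V_hid, S_hid, V_out in reversed(cache):
        # Output layer: delta_out = dL/dz (elementwise) sigma'(V_out - V_th)
        delta_out = dL_dz * surrogate_grad(V_out)
        dW_out = dW_out + S_hid.T @ delta_out
        # Backpropagate into hidden spike space
        delta_hid = delta_out @ W_out.T
        # Hidden layer: chain through its own surrogate gradient
        dW_in = dW_in + X_t.T @ (delta_hid * surrogate_grad(V_hid))
    return dW_in, dW_out
```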

6. Weight Update (SGD)

After gradients have been accumulated across all timesteps:

$$W_{\text{in}} \leftarrow W_{\text{in}} - \eta \, \frac{\partial L}{\partial W_{\text{in}}}, \qquad W_{\text{out}} \leftarrow W_{\text{out}} - \eta \, \frac{\partial L}{\partial W_{\text{out}}}$$

where $\eta$ is the learning rate.
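
A corresponding sketch of the update and of one full training step (all function names are the ones assumed in the earlier sketches; `eta` is the learning rate):

```python
def sgd_update(W_in, W_out, dW_in, dW_out, eta=0.1):
    """Vanilla SGD step on both weight matrices (in place)."""
    W_in -= eta * dW_in
    W_out -= eta * dW_out
    return W_in, W_out

# One full training step:
#   logits, cache = forward(X, W_in, W_out)
#   loss, probs   = softmax_cross_entropy(logits, labels)
#   dL_dz         = dloss_dlogits(probs, labels)
#   dW_in, dW_out = backward(cache, dL_dz, W_out)
#   W_in, W_out   = sgd_update(W_in, W_out, dW_in, dW_out)
```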

Summary

The network is unrolled over $T$ timesteps. In the forward pass, LIF neurons integrate their input currents and emit binary spikes through the Heaviside threshold, and the output spikes are accumulated into logits. In the backward pass, the fast-sigmoid surrogate replaces the Heaviside derivative, per-timestep errors are chained back through both layers, and the resulting gradients are summed over all timesteps before a single SGD weight update.

Implementation Notes

  1. The Heaviside step is used only in the forward pass; the surrogate derivative $\sigma'$ is substituted for it only in the backward pass.
  2. The membrane potentials (or the quantities needed to recompute them, as in Section 5.2) must be available during the backward pass, since the surrogate gradients are evaluated at $V^t - V_{\text{th}}$.
  3. Gradients are accumulated over all $T$ timesteps and applied in a single SGD update per batch.