Spiking Neural Network classifying MNIST data

    This notebook documents an educational project in which I built a spiking neural network (SNN) to classify digits from the MNIST dataset. The model is implemented entirely in NumPy, with PyTorch used solely for dataset loading. After 5 epochs of training, the network achieves a classification accuracy of 94% on the MNIST test set. It uses leaky integrate-and-fire (LIF) neurons and is trained via surrogate gradient descent and backpropagation through time.

    GitHub repository: https://github.com/eoinmurray/spiking-neural-network-on-mnist

    Notation

    - $T$: number of simulation time steps per image (default 100), indexed by $t$, with $\Delta t = 1$ ms per step
    - $V(t)$: membrane potential; $I(t)$: input current
    - $\tau$: membrane time constant; $\alpha = e^{-1/\tau}$: membrane decay factor (membrane_decay in the code)
    - $V_{\text{thresh}}$: spike threshold (0.6 by default); $\beta$: surrogate gradient steepness (surrogate_grad_steepness)
    - $x_i(t)$, $h_j(t)$, $o_k(t)$: input, hidden, and output spikes at time $t$
    - $W_{ih}$, $W_{ho}$: input→hidden and hidden→output weight matrices
    - $y_k$: one-hot target; $\hat{y}_k$: softmax probability for class $k$

    Dataset and encoding

    The network operates on the MNIST dataset, consisting of 28×28 grayscale images of handwritten digits. Each image is encoded into a binary spike train using a Poisson encoder.

    def poisson_encode_speed(images, num_time_steps):
        images = images.numpy()                                      # torch tensor -> NumPy array
        images = images.reshape(images.shape[0], -1)                 # flatten to [batch_size, 784]
        images = images / images.max()                               # normalise to [0, 1]
        spike_probabilities = np.repeat(images[:, None, :], num_time_steps, axis=1)
        return (np.random.rand(*spike_probabilities.shape) < spike_probabilities)
    

    For a given number of simulation time steps (default: 100), each pixel is treated as an independent Poisson process whose firing probability is proportional to its intensity. This generates a binary tensor of shape [batch_size, num_time_steps, num_pixels] representing spike events over time. Figure 1 below shows an example MNIST image and its corresponding spike train encoding, demonstrating how the static image is transformed into a temporal pattern of spikes.

    Figure 1: An MNIST image (left) and its corresponding spike train (right). The spike train shows neuron activity across 100 time steps (horizontal axis) for each of the 784 input neurons (vertical axis), with brighter points indicating spike events.
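
    To make the shapes concrete, the sketch below loads one MNIST batch with torchvision (PyTorch is used only for data loading, as noted in the introduction; the exact loader settings in the repository may differ) and passes it through the encoder defined above:

    import numpy as np
    import torch
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

    train_set = datasets.MNIST(root="./data", train=True, download=True,
                               transform=transforms.ToTensor())
    train_loader = DataLoader(train_set, batch_size=128, shuffle=True)

    images, labels = next(iter(train_loader))              # images: [128, 1, 28, 28]
    spike_trains = poisson_encode_speed(images, num_time_steps=100)
    print(spike_trains.shape)                               # (128, 100, 784)
    print(spike_trains.mean())                              # average firing probability ≈ mean pixel intensity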

    Network architecture

    The SNN is structured as a fully connected feedforward network with the following layers:

    - an input layer of 784 neurons (one per pixel), driven by the Poisson spike trains,
    - a single hidden layer of LIF neurons, and
    - an output layer of 10 LIF neurons, one per digit class.

    Connections between layers are initialized with normally distributed random weights with zero mean and a fixed standard deviation of 0.2:

    weights_input_to_hidden  = np.random.normal(0.0, 0.20, (num_input_neurons, num_hidden_neurons))
    weights_hidden_to_output = np.random.normal(0.0, 0.20, (num_hidden_neurons, num_output_neurons))
    

    The weight distributions before and after training are shown in Figure 5 (at the end of this document), illustrating how learning reshapes them.

    LIF dynamics

    Each image is presented for $T = 100$ time steps with $\Delta t = 1$ ms per step. The leaky integrate-and-fire (LIF) neurons integrate input spikes over time with a decay constant $\tau$ and reset on spiking:

    $$\tau \frac{dV(t)}{dt} = -V(t) + I(t).$$

    In discrete time:

    $$V(t + 1) = \alpha V(t) + (1 - \alpha) I(t), \quad \text{with } \alpha = e^{-1/\tau}.$$

    This is implemented in the lif_step function as:

    # LIF dynamics: V(t+1) = α*V(t) + (1-α)*I(t)
    membrane_potential = membrane_decay * membrane_potential + (1.0 - membrane_decay) * input_current
    

    Where:

    - membrane_potential is $V(t)$,
    - membrane_decay is $\alpha = e^{-1/\tau}$, and
    - input_current is $I(t)$, the weighted sum of incoming spikes.

    The complete LIF neuron behavior includes:

    - leaky integration of the input current,
    - emission of a spike whenever the membrane potential exceeds the threshold, and
    - reset of the membrane potential to zero after a spike.

    The membrane threshold $V_{\text{thresh}}$ is set to 0.6 by default.

    This behavior is implemented in the lif_step function:

    def lif_step(input_current, membrane_potential, membrane_decay, surrogate_grad_steepness, membrane_threshold=0.6):
        membrane_potential = membrane_decay * membrane_potential + (1.0 - membrane_decay) * input_current
        spikes = membrane_potential > membrane_threshold
    
        # surrogate gradient (sigmoid)
        exp_term = np.exp(-surrogate_grad_steepness * (membrane_potential - membrane_threshold))
        grad_surrogate = surrogate_grad_steepness * exp_term / (1.0 + exp_term) ** 2
    
        v_pre_reset = membrane_potential.copy()
        membrane_potential[spikes] = 0.0
        return spikes.astype(float), membrane_potential, grad_surrogate, v_pre_reset
    

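    A single LIF update can be exercised in isolation; the sketch below uses the lif_step function above with arbitrary decay and steepness values chosen purely for illustration:

    import numpy as np

    batch_size, num_neurons = 4, 10
    input_current = 10.0 * np.random.rand(batch_size, num_neurons)
    membrane_potential = np.zeros((batch_size, num_neurons))

    spikes, membrane_potential, grad_surrogate, v_pre_reset = lif_step(
        input_current, membrane_potential,
        membrane_decay=0.9, surrogate_grad_steepness=5.0, membrane_threshold=0.6)

    # spikes is a 0/1 float array; neurons whose potential crossed 0.6 emitted a spike
    # and had their membrane potential reset to 0
    print(spikes.shape, spikes.sum())
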
    The spiking threshold is non-differentiable, so we use a surrogate gradient during training.

    Surrogate gradient

    We use a fast sigmoid as the surrogate gradient function:

    $$\sigma(V) = \frac{1}{1 + e^{-\beta (V - V_{\text{thresh}})}}$$

    with derivative

    $$\sigma'(V) = \frac{\beta\, e^{-\beta (V - V_{\text{thresh}})}}{\bigl(1 + e^{-\beta (V - V_{\text{thresh}})}\bigr)^2}$$

    This is implemented in the lif_step function as:

    # Surrogate gradient (sigmoid derivative):
    # σ'(V) = β*exp(-β(V-Vthresh))/(1+exp(-β(V-Vthresh)))²
    exp_term = np.exp(-surrogate_grad_steepness * (membrane_potential - membrane_threshold))
    grad_surrogate = surrogate_grad_steepness * exp_term / (1.0 + exp_term) ** 2
    

    Where:

    - surrogate_grad_steepness is the steepness $\beta$,
    - membrane_potential is $V$, and
    - membrane_threshold is $V_{\text{thresh}}$.

    Loss Function

    We use the cross‑entropy loss on the time‑integrated output firing rates:

    $$\mathcal{L} = -\sum_{k=1}^{N} y_k \log\bigl(\hat{y}_k\bigr),$$

    where $N$ is the number of classes, $y_k$ is the one-hot target, and $\hat{y}_k$ is the softmax probability for class $k$.

    This is implemented in the code by:

    1. First accumulating spikes over all time steps (rate coding)
    2. Computing softmax probabilities from these spike counts
    3. Calculating cross-entropy against one-hot encoded targets

    # During forward pass: accumulate spikes over time
    spike_accumulator += s_output
    
    # After time simulation, compute logits and softmax
    logits = spike_accumulator
    logits -= logits.max(1, keepdims=True)  # Numerical stability
    exp_logits = np.exp(logits)
    probabilities = exp_logits / exp_logits.sum(1, keepdims=True)
    
    # Create one-hot targets
    one_hot_encoded = np.zeros_like(probabilities)
    one_hot_encoded[np.arange(this_batch_size), labels.numpy()] = 1.0
    
    # Compute cross-entropy loss
    loss = -np.mean(np.sum(one_hot_encoded * np.log(probabilities + 1e-9), axis=1))
    
    # Compute gradient of loss with respect to logits
    grad_logits = (probabilities - one_hot_encoded)  # ∂L/∂o_k(t)
    

    Where:

    - spike_accumulator holds the summed output spikes over all time steps (the logits),
    - probabilities are the softmax probabilities $\hat{y}_k$,
    - one_hot_encoded is the one-hot target $y_k$, and
    - grad_logits is the gradient of the loss with respect to the logits, $\hat{y}_k - y_k$.

    The gradient of this loss function with respect to the output spikes becomes the starting point for backpropagation.

    Training protocol

    Using pure NumPy (without deep learning frameworks), the model achieves 94% classification accuracy on MNIST in less than 5 minutes on an M4 MacBook Pro CPU. Training is carried out over a fixed number of epochs (default: 10) using mini-batches of size 128. As shown in Figure 2, the network improves from random performance (~10% accuracy) to 94% accuracy within 5 epochs, and the loss curve shows a corresponding steady decrease.

    Figure 2: Accuracy vs epochs showing >90% accuracy after a few epochs (left) and loss vs epochs (right).

    For each batch:

    1. Encode the images into spike trains with the Poisson encoder.
    2. Run the forward pass over all $T$ time steps, accumulating output spikes.
    3. Compute the cross-entropy loss on the accumulated spike counts.
    4. Backpropagate through time to obtain the weight gradients.
    5. Clip the gradients, average them over the batch, and apply the update, as shown below:

    # Apply clipped gradient updates
    np.clip(grad_weights_input_to_hidden,  -5, 5, out=grad_weights_input_to_hidden)
    np.clip(grad_weights_hidden_to_output, -5, 5, out=grad_weights_hidden_to_output)
    
    grad_weights_input_to_hidden  /= this_batch_size
    grad_weights_hidden_to_output /= this_batch_size
    
    weights_input_to_hidden  -= learning_rate * grad_weights_input_to_hidden
    weights_hidden_to_output -= learning_rate * grad_weights_hidden_to_output
    

    Figure 3 below shows the spiking activity and membrane potentials for a test image after training, visualizing how information flows through the network and how membrane potentials evolve to generate classification outputs.

    Figure 3: Spike trains and membrane potentials for the last image (see Figure 1) in the evaluation run with the trained model. Left: spike trains for the hidden layer. Middle: spike trains for the output layer, with neuron 6 spiking, corresponding to classifying the image as digit 6. Right: membrane potentials for the output layer, showing decay for all neurons except neuron 6, which is spiking.
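
    The classification readout itself is plain rate coding: the predicted digit is the output neuron with the largest accumulated spike count. A minimal sketch of the accuracy computation, assuming spike_accumulator and labels are available from running the forward pass above on a test batch (names follow the training code; the repository's evaluation loop may differ):

    # Predicted class = output neuron that spiked most over the T time steps
    predictions = spike_accumulator.argmax(axis=1)           # shape: [batch_size]
    accuracy = (predictions == labels.numpy()).mean()
    print(f"batch accuracy: {accuracy:.3f}")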

    Backpropagation Through Time (BPTT)

    Because the spiking network evolves over time, we use backpropagation through time (BPTT) to train it [1,2]. This involves unrolling the network across time steps and computing gradients by summing contributions from each time step. For each weight, we accumulate partial derivatives of the loss with respect to the membrane potential and spike history at each time step:

    $$\frac{\partial \mathcal{L}}{\partial W} = \sum_t \frac{\partial \mathcal{L}}{\partial W}(t)$$

    This is implemented in the code as gradient accumulators:

    # Gradient accumulators for BPTT: ∂L/∂W = ∑_t ∂L/∂W(t)
    grad_weights_input_to_hidden = np.zeros_like(weights_input_to_hidden)
    grad_weights_hidden_to_output = np.zeros_like(weights_hidden_to_output)
    

    And gradients are accumulated in the backward pass:

    # Accumulate gradients for the weights
    grad_weights_hidden_to_output += spike_history_hidden[t].T @ dV_output
    grad_weights_input_to_hidden += I_history_hidden[t].T @ dI_hidden
    

    Where:

    - spike_history_hidden[t] holds the hidden spikes $h_j(t)$ recorded during the forward pass,
    - I_history_hidden[t] holds the input spikes $x_i(t)$, and
    - dV_output and dI_hidden are the backpropagated gradients $\partial \mathcal{L}/\partial V_o^k(t)$ and $\partial \mathcal{L}/\partial I_h^j(t)$ at time step $t$.

    As shown in the code, we accumulate gradients by iterating through each time step in reverse order using a backward loop (for t in reversed(range(num_time_steps))). Figure 4 illustrates the L2 norms of the gradients for both hidden and output layers throughout training, providing insight into the training dynamics and stability.

    Figure 4: L2 norms of the hidden-layer and output-layer gradients, recorded every 50 iterations.

    Temporal Dependencies in SNNs

    A challenge in SNNs is that the network state at time $t$ depends on the history of spikes and membrane potentials from previous time steps. Due to the leaky integrate-and-fire dynamics:

    $$V(t + 1) = \alpha V(t) + (1 - \alpha) I(t)$$

    the current membrane potential depends on all previous inputs and spikes. This creates a complex temporal dependency chain that must be correctly backpropagated during training.

    During BPTT, we treat the SNN as a recurrent neural network unrolled in time, where:

    1. Forward pass: Propagate activity sequentially through time steps $1$ to $T$.
    2. Backward pass: Propagate gradients backwards from time step $T$ to $1$.

    For each time step $t$, gradients flow through two paths:

    1. a spatial path, through the spike emitted at time $t$ (via the surrogate gradient) into the next layer, and
    2. a temporal path, through the membrane potential carried forward to time $t+1$ by the decay factor $\alpha$.

    This is implemented in the code by:

    # Store history during forward pass
    spike_history_hidden.append(spikes_hidden)
    grad_surrogate_hidden.append(d_sigma_hidden)
    I_history_hidden.append(encoded_images[:, t, :])
    grad_surrogate_output.append(d_sigma_output)
    
    # In backward pass, recurrent dependencies are handled with dV_next terms
    dV_output = dS_output * grad_surrogate_output[t] + dV_next_output * membrane_decay
    dV_hidden = dS_hidden * grad_surrogate_hidden[t] + dV_next_hidden * membrane_decay
    
    # Update recurrent gradients for next time step (going backward)
    dV_next_hidden = dV_hidden * membrane_decay  # ∂L/∂V_h(t-1)
    dV_next_output = dV_output * membrane_decay  # ∂L/∂V_o(t-1)
    

    The recurrent dependency significantly affects how gradients flow through time during backpropagation. When a neuron spikes at time $t$, it influences:

    1. The immediate forward connections to the next layer at time $t$
    2. Its own future membrane potential at times $t+1, t+2, \ldots$

    For each backward step through time, gradients must be propagated through both these paths:

    1. the surrogate-gradient term, $\partial \mathcal{L}/\partial S(t)\,\sigma'(V(t))$, for the spike emitted at time $t$, and
    2. the decay term, $\alpha\,\partial \mathcal{L}/\partial V(t+1)$, carrying the influence on future membrane potentials.

    This creates a multiplicative effect where gradients from later time steps flow through a chain of dependencies. Since $\frac{\partial V(t+1)}{\partial V(t)} = \alpha$ and we apply this recursively, gradients from time step $t+k$ to time step $t$ are scaled by $\alpha^k$.

    The implementation uses the membrane_decay factor ($\alpha$ in the equations) to propagate these temporal dependencies. Small $\alpha$ (high leakage) attenuates gradients over long time horizons, while values of $\alpha$ close to 1 (low leakage) can make them unstable.
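
    To get a feel for this scaling, the snippet below evaluates $\alpha^k$ for a few horizons $k$; the $\tau$ values are chosen purely for illustration and are not the repository's defaults:

    import numpy as np

    for tau in (2.0, 5.0, 20.0):               # illustrative membrane time constants
        alpha = np.exp(-1.0 / tau)             # α = e^{-Δt/τ} with Δt = 1 ms
        scales = {k: alpha**k for k in (10, 50, 100)}
        print(f"tau={tau:5.1f}  alpha={alpha:.3f}  "
              + "  ".join(f"alpha^{k}={v:.2e}" for k, v in scales.items()))

    For a leaky neuron with $\tau = 2$ the gradient contribution from 100 steps back is negligible, whereas for $\tau = 20$ it is still of order $10^{-3}$.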

    Handling Discontinuities with Surrogate Gradients

    The spike function is discontinuous (step function), making the network non-differentiable. The surrogate gradient approach replaces this non-differentiable function with a differentiable approximation (sigmoid) during the backward pass only:

    $$\frac{\partial\, \text{Spike}}{\partial V} \approx \sigma'(V)$$

    This is implemented in the lif_step function where we have:

    # Spike if membrane potential exceeds threshold
    spikes = membrane_potential > membrane_threshold  # Non-differentiable step function
    
    # Surrogate gradient (sigmoid derivative): 
    # σ'(V) = β*exp(-β(V-Vthresh))/(1+exp(-β(V-Vthresh)))²
    exp_term = np.exp(-surrogate_grad_steepness * (membrane_potential - membrane_threshold))
    grad_surrogate = surrogate_grad_steepness * exp_term / (1.0 + exp_term) ** 2
    

    Where:

    - surrogate_grad_steepness is the steepness $\beta$, and
    - membrane_threshold is $V_{\text{thresh}}$, as before.

    These surrogate gradient values are stored during the forward pass and then used during the backward pass:

    # Store surrogate gradients during forward pass
    grad_surrogate_hidden.append(d_sigma_hidden)
    grad_surrogate_output.append(d_sigma_output)
    
    # Use them in backward pass
    dV_output = dS_output * grad_surrogate_output[t] + dV_next_output * membrane_decay
    dV_hidden = dS_hidden * grad_surrogate_hidden[t] + dV_next_hidden * membrane_decay
    

    Where:

    - grad_surrogate_hidden[t] and grad_surrogate_output[t] store $\sigma'(V)$ for each layer at time step $t$,
    - dS_hidden and dS_output are the gradients of the loss with respect to the spikes, and
    - dV_next_hidden and dV_next_output carry the temporal dependency from time step $t+1$.

    This allows gradient flow while preserving the discrete spiking behavior in the forward pass.
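
    As a quick numerical illustration (0.6 is the document's default threshold; $\beta = 5$ is an assumed steepness), the surrogate derivative is largest near the threshold and decays quickly away from it, so gradient mainly flows through neurons that were close to firing:

    import numpy as np

    beta, v_thresh = 5.0, 0.6                                 # assumed steepness, default threshold
    V = np.array([0.0, 0.4, 0.6, 0.8, 1.2])                   # sample membrane potentials
    exp_term = np.exp(-beta * (V - v_thresh))
    grad_surrogate = beta * exp_term / (1.0 + exp_term) ** 2
    print(grad_surrogate.round(3))                            # peaks at β/4 = 1.25 when V == V_thresh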

    Figure 5: Weight distributions before and after training. Top left: Input layer weights before training. Top right: Input layer weights after training. Bottom left: Hidden layer weights before training. Bottom right: Hidden layer weights after training.

    Implementation of BPTT

    For a network with $N$ neurons simulated for $T$ time steps, we need to store:

    - the hidden-layer spikes $h_j(t)$,
    - the surrogate gradients $\sigma'(V_h^j(t))$ and $\sigma'(V_o^k(t))$ for both layers, and
    - the input spike trains $x_i(t)$,

    all for every time step.

    The implementation handles these requirements by:

    1. Storing activity history during the forward pass:
    # Initialize history arrays for BPTT
    spike_history_hidden = []   # Store h_j(t) for all t
    grad_surrogate_hidden = []  # Store σ'(V_h^j(t)) for all t
    I_history_hidden = []       # Store x_i(t) for all t
    grad_surrogate_output = []  # Store σ'(V_o^k(t)) for all t
    
    # FORWARD PASS: Propagate activity through time steps 1 to T
    for t in range(num_time_steps):
        # Input → Hidden: I_h = x(t) · W_ih
        I_hidden = encoded_images[:, t, :] @ weights_input_to_hidden
        
        # Process hidden layer with LIF neurons
        # h(t), V_h(t+1) = LIF(I_h(t), V_h(t))
        spikes_hidden, V_hidden, d_sigma_hidden, _ = lif_step(I_hidden, V_hidden, membrane_decay,
                                            surrogate_grad_steepness, membrane_threshold)
        
        # Hidden → Output: I_o = h(t) · W_ho
        I_output = spikes_hidden @ weights_hidden_to_output
        
        # Process output layer with LIF neurons
        # o(t), V_o(t+1) = LIF(I_o(t), V_o(t))
        s_output, V_output, d_sigma_output, _ = lif_step(I_output, V_output, membrane_decay,
                                            surrogate_grad_steepness, membrane_threshold)
        
        # Accumulate output spikes across time (rate coding for classification)
        spike_accumulator += s_output
    
        # Store history for BPTT
        spike_history_hidden.append(spikes_hidden)           # h_j(t)
        grad_surrogate_hidden.append(d_sigma_hidden)         # σ'(V_h^j(t))
        I_history_hidden.append(encoded_images[:, t, :])     # x_i(t)
        grad_surrogate_output.append(d_sigma_output)         # σ'(V_o^k(t))
    
    2. Backpropagating gradients through time in reverse order:
    # Initialize recurrent gradient terms for time dependency
    dV_next_hidden = np.zeros_like(V_hidden)  # ∂L/∂V_h(t+1)
    dV_next_output = np.zeros_like(V_output)  # ∂L/∂V_o(t+1)
    
    # Iterate backward through time (BPTT)
    for t in reversed(range(num_time_steps)):
        # Output layer backpropagation
        # 1. ∂L/∂o_k(t) = p_k - y_k (from grad_logits)
        dS_output = grad_logits  # ∂L/∂o_k(t)
        
        # 2. ∂L/∂V_o^k(t) = ∂L/∂o_k(t) * σ'(V_o^k(t)) + ∂L/∂V_o^k(t+1) * α
        # First term: direct influence on current spike
        # Second term: influence on future membrane potentials through decay
        dV_output = dS_output * grad_surrogate_output[t] + dV_next_output * membrane_decay
    
        # 3. ∂L/∂W_ho^kj = ∑_t ∂L/∂V_o^k(t) * h_j(t)
        grad_weights_hidden_to_output += spike_history_hidden[t].T @ dV_output
    
        # Hidden layer backpropagation
        # 4. ∂L/∂h_j(t) = ∑_k ∂L/∂V_o^k(t) * W_ho^kj
        dS_hidden = dV_output @ weights_hidden_to_output.T
        
        # 5. ∂L/∂V_h^j(t) = ∂L/∂h_j(t) * σ'(V_h^j(t)) + ∂L/∂V_h^j(t+1) * α
        dV_hidden = dS_hidden * grad_surrogate_hidden[t] + dV_next_hidden * membrane_decay
        
        # 6. ∂L/∂I_h^j(t) = ∂L/∂V_h^j(t) * (1-α)
        dI_hidden = dV_hidden * (1.0 - membrane_decay)
        
        # 7. ∂L/∂W_ih^ji = ∑_t ∂L/∂I_h^j(t) * x_i(t)
        grad_weights_input_to_hidden += I_history_hidden[t].T @ dI_hidden
    
        # Store recurrent gradients for next time step
        # These capture how current membrane potentials affect future time steps
        dV_next_hidden = dV_hidden * membrane_decay  # ∂L/∂V_h(t-1)
        dV_next_output = dV_output * membrane_decay  # ∂L/∂V_o(t-1)
    

    This implementation directly matches the mathematical derivation in the appendix, with each step clearly documented in the code comments.

    Conclusion

    This was a fun project in which I trained a spiking neural network on an image classification task using biologically inspired encoding, LIF dynamics, and surrogate gradient descent. NumPy was chosen over higher-level frameworks not for performance, but to expose the underlying mechanics of SNNs. The model achieved over 90% accuracy on the MNIST dataset, and future work can generalize the model to other datasets such as FashionMNIST, which is already supported as an option in the code.
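
    For reference, swapping in FashionMNIST via torchvision requires only a different dataset class (the repository exposes this as a configuration option; the exact flag is not reproduced here):

    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

    # FashionMNIST has the same 28×28 grayscale format and 10 classes as MNIST,
    # so the Poisson encoder and network architecture are unchanged
    train_set = datasets.FashionMNIST(root="./data", train=True, download=True,
                                      transform=transforms.ToTensor())
    train_loader = DataLoader(train_set, batch_size=128, shuffle=True)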

    References

    1. Emre O. Neftci et al., Surrogate Gradient Learning in Spiking Neural Networks
    2. Jason K. Eshraghian, Training Spiking Neural Networks with snntorch
    3. Derivative of Cross-Entropy Loss w.r.t. Output Spike Count

    Appendix

    Gradient w.r.t. Hidden→Output Weights

    We want

    $$\frac{\partial \mathcal{L}}{\partial W_{ho}^{kj}} = \sum_{t} \frac{\partial \mathcal{L}}{\partial o_k(t)} \,\frac{\partial o_k(t)}{\partial V_o^k(t)} \,\frac{\partial V_o^k(t)}{\partial W_{ho}^{kj}}.$$

    1. Error at the spike output [3]

      $$\frac{\partial \mathcal{L}}{\partial o_k(t)} = \hat{y}_k - y_k.$$

      This is implemented as:

      # Gradient of loss w.r.t. logits: ∂L/∂o_k(t) = p_k - y_k
      grad_logits = (probabilities - one_hot_encoded)
      dS_output = grad_logits  # ∂L/∂o_k(t)
      
    2. Surrogate spike derivative

      $$\frac{\partial o_k(t)}{\partial V_o^k(t)} \approx \sigma'\bigl(V_o^k(t)\bigr).$$

      Implemented as:

      # Surrogate gradient (sigmoid derivative)
      exp_term = np.exp(-surrogate_grad_steepness * (membrane_potential - membrane_threshold))
      grad_surrogate = surrogate_grad_steepness * exp_term / (1.0 + exp_term) ** 2
      
    3. Voltage w.r.t. weight

      $$\frac{\partial V_o^k(t)}{\partial W_{ho}^{kj}} = h_j(t),$$

      since $V_o^k(t) = \sum_{j} W_{ho}^{kj}\, h_j(t)$.

    Putting it together:

    $$\boxed{\, \frac{\partial \mathcal{L}}{\partial W_{ho}^{kj}} = \sum_{t} \bigl(\hat{y}_k - y_k\bigr)\, \sigma'\bigl(V_o^k(t)\bigr)\, h_j(t) \,}.$$

    This is implemented in the code as:

    # 1. ∂L/∂o_k(t) = p_k - y_k (from grad_logits)
    dS_output = grad_logits  # ∂L/∂o_k(t)
    
    # 2. ∂L/∂V_o^k(t) = ∂L/∂o_k(t) * σ'(V_o^k(t)) + ∂L/∂V_o^k(t+1) * α
    dV_output = dS_output * grad_surrogate_output[t] + dV_next_output * membrane_decay
    
    # 3. ∂L/∂W_ho^kj = ∑_t ∂L/∂V_o^k(t) * h_j(t)
    grad_weights_hidden_to_output += spike_history_hidden[t].T @ dV_output
    

    Where:

    - grad_logits is $\hat{y}_k - y_k$,
    - grad_surrogate_output[t] is $\sigma'(V_o^k(t))$,
    - spike_history_hidden[t] is $h_j(t)$, and
    - the dV_next_output * membrane_decay term adds the temporal dependency discussed in the main text.

    Gradient w.r.t. Input→Hidden Weights

    We want

    $$\frac{\partial \mathcal{L}}{\partial W_{ih}^{ji}} = \sum_{t} \frac{\partial \mathcal{L}}{\partial h_j(t)} \,\frac{\partial h_j(t)}{\partial V_h^j(t)} \,\frac{\partial V_h^j(t)}{\partial W_{ih}^{ji}}.$$

    1. Backpropagated error into hidden spike

      $$\frac{\partial \mathcal{L}}{\partial h_j(t)} = \sum_{k} \frac{\partial \mathcal{L}}{\partial o_k(t)} \,\frac{\partial o_k(t)}{\partial V_o^k(t)} \,\frac{\partial V_o^k(t)}{\partial h_j(t)} = \sum_{k} (\hat{y}_k - y_k)\, \sigma'\bigl(V_o^k(t)\bigr)\, W_{ho}^{kj}.$$

      Implemented as:

      # 4. ∂L/∂h_j(t) = ∑_k ∂L/∂V_o^k(t) * W_ho^kj
      dS_hidden = dV_output @ weights_hidden_to_output.T
      
    2. Surrogate at hidden layer

      $$\frac{\partial h_j(t)}{\partial V_h^j(t)} \approx \sigma'\bigl(V_h^j(t)\bigr).$$

      Implemented with the same surrogate gradient function as the output layer:

      # 5. ∂L/∂V_h^j(t) = ∂L/∂h_j(t) * σ'(V_h^j(t)) + ∂L/∂V_h^j(t+1) * α
      dV_hidden = dS_hidden * grad_surrogate_hidden[t] + dV_next_hidden * membrane_decay
      
    3. Voltage w.r.t. weight and input current

      $$V_h^j(t) = \sum_{i} W_{ih}^{ji}\, x_i(t) \quad\Longrightarrow\quad \frac{\partial V_h^j(t)}{\partial W_{ih}^{ji}} = x_i(t).$$

      Implemented as:

      # 6. ∂L/∂I_h^j(t) = ∂L/∂V_h^j(t) * (1-α)
      dI_hidden = dV_hidden * (1.0 - membrane_decay)
      
      # 7. ∂L/∂W_ih^ji = ∑_t ∂L/∂I_h^j(t) * x_i(t)
      grad_weights_input_to_hidden += I_history_hidden[t].T @ dI_hidden
      

    Together:

    $$\boxed{\, \frac{\partial \mathcal{L}}{\partial W_{ih}^{ji}} = \sum_{t} \Bigl[\sum_{k} (\hat{y}_k - y_k)\, \sigma'\bigl(V_o^k(t)\bigr)\, W_{ho}^{kj} \Bigr] \,\sigma'\bigl(V_h^j(t)\bigr)\, x_i(t) \,}.$$

    The code implements this equation with the addition of temporal dependencies:

    # Hidden layer backpropagation
    # 4. ∂L/∂h_j(t) = ∑_k ∂L/∂V_o^k(t) * W_ho^kj
    dS_hidden = dV_output @ weights_hidden_to_output.T
    
    # 5. ∂L/∂V_h^j(t) = ∂L/∂h_j(t) * σ'(V_h^j(t)) + ∂L/∂V_h^j(t+1) * α
    dV_hidden = dS_hidden * grad_surrogate_hidden[t] + dV_next_hidden * membrane_decay
    
    # 6. ∂L/∂I_h^j(t) = ∂L/∂V_h^j(t) * (1-α)
    dI_hidden = dV_hidden * (1.0 - membrane_decay)
    
    # 7. ∂L/∂W_ih^ji = ∑_t ∂L/∂I_h^j(t) * x_i(t)
    grad_weights_input_to_hidden += I_history_hidden[t].T @ dI_hidden
    

    Where:

    - dS_hidden is $\partial \mathcal{L}/\partial h_j(t)$, backpropagated through $W_{ho}$,
    - grad_surrogate_hidden[t] is $\sigma'(V_h^j(t))$,
    - I_history_hidden[t] is the input spike train $x_i(t)$, and
    - the $(1 - \alpha)$ factor accounts for how the input current enters the membrane update $V(t+1) = \alpha V(t) + (1-\alpha) I(t)$.

    The effect of these weight updates is visualized in Figure 5, showing how the weight distributions change from their initial random state to more structured patterns after training.