Derivative of Cross-Entropy Loss w.r.t. Output Spike Count
The cross-entropy loss function is commonly used in classification tasks,
especially in the context of neural networks.
The loss function itself is used along with its gradient to update the
weights of the network during training.
Here we derive the derivative of the cross-entropy loss with respect to the
output spike count for a spiking neural network.
First we do it manually, and then we use Sympy to verify our results.
Manual Derivation
To derive the derivative of the cross-entropy loss with respect to the output spike count, we
assume the following:
- We have a spike count output $s_i$ for each output neuron $i$, computed over some time window.
- We apply a softmax activation function to the output spike counts $s_i$ to get the predicted class probabilities $\hat{y}_i$.
- The true label $y_i$ is one-hot encoded.
The cross-entropy loss is defined as:
$$L = -\sum_{i=1}^{C} y_i \log(\hat{y}_i)$$
The softmax function is defined as:
$$\hat{y}_i = \frac{e^{s_i}}{\sum_{j=1}^{C} e^{s_j}} = \frac{e^{s_i}}{Z}$$
where $Z = \sum_{j=1}^{C} e^{s_j}$ is the partition function.
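As a quick illustration of the definition, here is a small NumPy sketch (the spike counts are made-up example values, not from any real network) that turns spike counts into softmax probabilities:

```python
import numpy as np

# Made-up spike counts for C = 3 output neurons, accumulated over a time window.
s = np.array([12.0, 3.0, 7.0])

# Numerically stable softmax: subtracting max(s) cancels in the ratio e^{s_i} / Z,
# so it leaves y_hat unchanged while avoiding overflow for large counts.
z = np.exp(s - s.max())
y_hat = z / z.sum()

print(y_hat)        # predicted class probabilities
print(y_hat.sum())  # 1.0
```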
Now we substitute the softmax expression into the cross-entropy loss:
$$
\begin{aligned}
L &= -\sum_{i=1}^{C} y_i \log(\hat{y}_i) \\
  &= -\sum_{i=1}^{C} y_i \log\!\left(\frac{e^{s_i}}{Z}\right) \\
  &= -\sum_{i=1}^{C} y_i \left(s_i - \log(Z)\right) \\
  &= -\sum_{i=1}^{C} y_i s_i + \log(Z) \cdot \sum_{i=1}^{C} y_i \\
  &= -\sum_{i=1}^{C} y_i s_i + \log(Z)
\end{aligned}
$$
since $\sum_{i=1}^{C} y_i = 1$.
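To check this simplification numerically, a small sketch (reusing the made-up spike counts from above and an arbitrary one-hot label) can compare both forms of the loss:

```python
import numpy as np

s = np.array([12.0, 3.0, 7.0])   # made-up spike counts
y = np.array([0.0, 1.0, 0.0])    # one-hot true label

Z = np.exp(s).sum()              # partition function
y_hat = np.exp(s) / Z            # softmax probabilities

loss_direct     = -(y * np.log(y_hat)).sum()   # -sum_i y_i * log(y_hat_i)
loss_simplified = -(y * s).sum() + np.log(Z)   # -sum_i y_i * s_i + log(Z)

print(np.isclose(loss_direct, loss_simplified))  # True
```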
Now we can compute the derivative of the loss with respect to an output spike count $s_k$:
$$\frac{\partial L}{\partial s_k} = -\frac{\partial}{\partial s_k}\left(\sum_i y_i s_i\right) + \frac{\partial}{\partial s_k}\log(Z)$$
First term:
$$\frac{\partial}{\partial s_k}\left(\sum_i y_i s_i\right) = y_k$$
Second term:
$$\frac{\partial}{\partial s_k}\log(Z) = \frac{1}{Z}\cdot\frac{\partial Z}{\partial s_k} = \frac{1}{Z}\cdot\frac{\partial}{\partial s_k}\left(\sum_{j=1}^{C} e^{s_j}\right) = \frac{1}{Z}\cdot e^{s_k} = \frac{e^{s_k}}{Z} = \hat{y}_k$$
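Before combining the two terms, this second term can also be checked in isolation with Sympy; the sketch below (the symbol names are my own choice) confirms that the derivative of $\log(Z)$ with respect to each $s_k$ is the corresponding softmax output:

```python
import sympy as sp

s = sp.symbols('s0:3', real=True)      # logits s0, s1, s2 for C = 3 classes

Z = sum(sp.exp(s_j) for s_j in s)      # partition function
log_Z = sp.log(Z)

# d/ds_k log(Z) should equal exp(s_k) / Z, i.e. the softmax output y_hat_k.
for k in range(3):
    print(sp.simplify(sp.diff(log_Z, s[k]) - sp.exp(s[k]) / Z))  # 0 for each k
```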
Combining both terms:
$$\frac{\partial L}{\partial s_k} = -y_k + \hat{y}_k$$
This means that the derivative of the cross-entropy loss with respect to the output spike count is given by:
$$\frac{\partial L}{\partial s_k} = \hat{y}_k - y_k$$
This is our final result.
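As a quick numerical cross-check of this result (again with made-up values), the analytic gradient $\hat{y}_k - y_k$ can be compared against central finite differences of the loss:

```python
import numpy as np

def cross_entropy(s, y):
    """Cross-entropy of softmax(s) against a one-hot label y."""
    y_hat = np.exp(s - s.max())
    y_hat /= y_hat.sum()
    return -(y * np.log(y_hat)).sum()

s = np.array([12.0, 3.0, 7.0])   # made-up spike counts
y = np.array([0.0, 1.0, 0.0])    # one-hot true label

# Analytic gradient from the derivation: y_hat - y.
y_hat = np.exp(s - s.max())
y_hat /= y_hat.sum()
grad_analytic = y_hat - y

# Central finite differences as an independent check.
eps = 1e-5
grad_numeric = np.zeros_like(s)
for k in range(len(s)):
    e = np.zeros_like(s)
    e[k] = eps
    grad_numeric[k] = (cross_entropy(s + e, y) - cross_entropy(s - e, y)) / (2 * eps)

print(np.allclose(grad_analytic, grad_numeric, atol=1e-6))  # True
```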
Derivation with Sympy
Here's what the code does (a minimal sketch of such a script follows the list):
- Defines symbols for the logits `s = [s0, s1, s2]` and the true-class one-hot vector `y = [y0, y1, y2]`.
- Computes the softmax predictions `y_hat[i] = exp(si) / Z`.
- Defines the cross-entropy loss `L = -sum(y[i] * log(y_hat[i]))`.
- Substitutes `Z = sum(exp(si))` into `L` to get a purely `s`-dependent expression.
- Differentiates the loss w.r.t. each logit `s_i`, getting `dL/ds_i`.
- Simplifies the result and compares it against the known identity $\frac{\partial L}{\partial s_k} = \hat{y}_k - y_k$.
- Substitutes the one-hot vector constraint `sum(y) = 1` so Sympy can simplify.
- Checks that `grad_k - (y_hat_k - y_k)` simplifies to zero (i.e. confirms the identity).
- `checks_with_constraint` returns `[0, 0, 0]`, showing the identity holds.
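The original script is not reproduced here, but a minimal Sympy sketch following the steps above could look like this (apart from `checks_with_constraint`, which the list mentions, the variable names are my own choice):

```python
import sympy as sp

# Symbols for the logits (spike counts) and the one-hot label components.
s = sp.symbols('s0:3', real=True)
y = sp.symbols('y0:3', real=True, nonnegative=True)

# Softmax predictions y_hat[i] = exp(s_i) / Z with Z = sum_j exp(s_j) substituted in,
# so the loss below depends only on s and y.
Z = sum(sp.exp(s_j) for s_j in s)
y_hat = [sp.exp(s_i) / Z for s_i in s]

# Cross-entropy loss L = -sum_i y_i * log(y_hat_i).
L = -sum(y_i * sp.log(yh_i) for y_i, yh_i in zip(y, y_hat))

# Differentiate w.r.t. each logit and compare against the identity y_hat_k - y_k.
grads = [sp.diff(L, s_k) for s_k in s]
residuals = [sp.simplify(g - (yh - y_k)) for g, yh, y_k in zip(grads, y_hat, y)]

# The residuals still depend on y0 + y1 + y2; substitute the one-hot constraint
# sum(y) = 1 (here as y2 = 1 - y0 - y1) so Sympy can finish simplifying.
constraint = {y[2]: 1 - y[0] - y[1]}
checks_with_constraint = [sp.simplify(r.subs(constraint)) for r in residuals]

print(checks_with_constraint)  # [0, 0, 0]
```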