Derivative of Cross-Entropy Loss w.r.t. Output Spike Count


The cross-entropy loss function is commonly used in classification tasks, especially in the context of neural networks.

During training, the loss and its gradient are used to update the weights of the network.

Here we derive the derivative of the cross-entropy loss with respect to the output spike count for a spiking neural network.

First we do it manually, and then we use SymPy to verify the result.

Manual Derivation

To derive the derivative of the cross-entropy loss with respect to the output spike count, we assume the following:

The cross-entropy loss is defined as:

$$L = -\sum_{i=1}^{C} y_i \log(\hat{y}_i)$$

The softmax function is defined as:

$$\hat{y}_i = \frac{e^{s_i}}{\sum_{j=1}^{C} e^{s_j}} = \frac{e^{s_i}}{Z}$$

where $Z = \sum_{j=1}^{C} e^{s_j}$ is the partition function.
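To make these definitions concrete, here is a small numerical sketch in NumPy (the spike counts, the one-hot target, and the class count $C = 3$ are illustrative values, not part of the derivation):

```python
import numpy as np

# Illustrative spike counts for C = 3 output neurons and a one-hot target.
s = np.array([2.0, 5.0, 1.0])   # output spike counts s_i
y = np.array([0.0, 1.0, 0.0])   # target distribution y_i (one-hot)

Z = np.sum(np.exp(s))           # partition function Z
y_hat = np.exp(s) / Z           # softmax outputs y_hat_i

L = -np.sum(y * np.log(y_hat))  # cross-entropy loss
print(y_hat, L)
```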

Now we plug the softmax, written in terms of the partition function, into the cross-entropy loss:

$$
\begin{aligned}
L &= -\sum_{i=1}^{C} y_i \log(\hat{y}_i) \\
  &= -\sum_{i=1}^{C} y_i \log\left(\frac{e^{s_i}}{Z}\right) \\
  &= -\sum_{i=1}^{C} y_i \left(s_i - \log(Z)\right) \\
  &= -\sum_{i=1}^{C} y_i s_i + \log(Z) \cdot \sum_{i=1}^{C} y_i \\
  &= -\sum_{i=1}^{C} y_i s_i + \log(Z)
\end{aligned}
$$

since $\sum_{i=1}^{C} y_i = 1$.
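As a quick numerical sanity check of this simplification (written self-contained, with the same illustrative values as the sketch above), both forms of the loss agree:

```python
import numpy as np

s = np.array([2.0, 5.0, 1.0])   # illustrative spike counts
y = np.array([0.0, 1.0, 0.0])   # one-hot target
Z = np.sum(np.exp(s))           # partition function

# Original form: L = -sum_i y_i * log(softmax(s)_i)
L_original = -np.sum(y * np.log(np.exp(s) / Z))

# Simplified form: L = -sum_i y_i * s_i + log(Z)
L_simplified = -np.sum(y * s) + np.log(Z)

print(np.isclose(L_original, L_simplified))  # True
```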

Now we can compute the derivative of the loss with respect to the output spike count $s_k$:

$$\frac{\partial L}{\partial s_k} = -\frac{\partial}{\partial s_k} \left(\sum_i y_i s_i\right) + \frac{\partial}{\partial s_k} \log(Z)$$

First term:

$$\frac{\partial}{\partial s_k} \left(\sum_i y_i s_i\right) = y_k$$

Second term:

$$\frac{\partial}{\partial s_k} \log(Z) = \frac{1}{Z} \cdot \frac{\partial Z}{\partial s_k} = \frac{1}{Z} \cdot \frac{\partial}{\partial s_k} \left(\sum_{j=1}^{C} e^{s_j}\right) = \frac{1}{Z} \cdot e^{s_k} = \frac{e^{s_k}}{Z} = \hat{y}_k
$$
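This is the key step, so it is worth checking in isolation. A minimal SymPy sketch (assuming three classes for concreteness) confirms that the derivative of $\log(Z)$ with respect to $s_k$ is the softmax output $\hat{y}_k$:

```python
import sympy as sp

s0, s1, s2 = sp.symbols('s0 s1 s2', real=True)
Z = sp.exp(s0) + sp.exp(s1) + sp.exp(s2)   # partition function for C = 3

# d/ds0 log(Z) should equal the softmax output exp(s0) / Z
lhs = sp.diff(sp.log(Z), s0)
rhs = sp.exp(s0) / Z
print(sp.simplify(lhs - rhs) == 0)  # True
```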

Combine both:

$$\frac{\partial L}{\partial s_k} = -y_k + \hat{y}_k$$

This means that the derivative of the cross-entropy loss with respect to the output spike count is given by:

$$\frac{\partial L}{\partial s_k} = \hat{y}_k - y_k$$

This is our final result.
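A quick way to convince yourself of this result numerically is a finite-difference check (again with illustrative values; the function and variable names here are made up for this sketch):

```python
import numpy as np

def loss(s, y):
    """Cross-entropy of softmax(s) against target y."""
    y_hat = np.exp(s) / np.sum(np.exp(s))
    return -np.sum(y * np.log(y_hat))

s = np.array([2.0, 5.0, 1.0])   # illustrative spike counts
y = np.array([0.0, 1.0, 0.0])   # one-hot target

# Analytical gradient from the derivation: dL/ds_k = y_hat_k - y_k
y_hat = np.exp(s) / np.sum(np.exp(s))
grad_analytic = y_hat - y

# Numerical gradient via central finite differences
eps = 1e-6
grad_numeric = np.zeros_like(s)
for k in range(len(s)):
    e = np.zeros_like(s)
    e[k] = eps
    grad_numeric[k] = (loss(s + e, y) - loss(s - e, y)) / (2 * eps)

print(np.allclose(grad_analytic, grad_numeric))  # True
```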

Derivation with SymPy

Notebook
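In case the embedded notebook does not render, here is a minimal sketch of this kind of SymPy verification (the class count $C = 3$ and the symbol names are assumptions for illustration, not necessarily the notebook's exact code):

```python
import sympy as sp

C = 3  # a small, fixed number of classes for the symbolic check
s = sp.symbols(f's0:{C}', real=True)   # output spike counts s_k
y = sp.symbols(f'y0:{C}', real=True)   # target probabilities y_k

Z = sum(sp.exp(s_j) for s_j in s)             # partition function
y_hat = [sp.exp(s_i) / Z for s_i in s]        # softmax outputs

# Cross-entropy loss L = -sum_i y_i * log(y_hat_i)
L = -sum(y_i * sp.log(yh_i) for y_i, yh_i in zip(y, y_hat))

# Derivative w.r.t. s_0, using the constraint sum_i y_i = 1
dL_ds0 = sp.diff(L, s[0]).subs(y[C - 1], 1 - sum(y[:-1]))

# Expected result from the manual derivation: y_hat_0 - y_0
expected = y_hat[0] - y[0]
print(sp.simplify(dL_ds0 - expected) == 0)  # True
```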

Here's what the code does: