Gradient of a Matrix-Matrix Multiplication

July 28, 2018 - 3 minute read
Machine Learning Backpropagation Matrix Calculus

This is just matrix multiplication.

It’s good to understand how to derive gradients for your neural network. It gets a little hairy when you have a matrix-matrix multiplication, such as $WX + b$. When I was reviewing backpropagation in CS231n, they hand-waved over the derivation of the gradient of the loss function $L$ with respect to the weight matrix $W$:


  • $X$: an $(m, n)$ input matrix with $m$ features and $n$ samples
  • $W$: an $(H, m)$ weight matrix with $H$ neurons
  • $D = WX$
    • an $(H, n)$ matrix
  • $L = f(D)$
    • a scalar value, where $f$ is an arbitrary loss function
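To make the shapes concrete, here is a minimal sketch of the forward pass in plain Python (matrices stored as lists of rows, no libraries). The small matrices and the choice of loss $f(D) = \frac{1}{2}\sum_{ij} D_{ij}^2$ are assumptions for illustration only, since the post leaves $f$ arbitrary; `matmul`, `shape`, and `f` are hypothetical helpers.

```python
def matmul(A, B):
    """Multiply an (r, k) matrix A by a (k, c) matrix B, both stored as lists of rows."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

def shape(M):
    """Return (rows, columns) of a list-of-rows matrix."""
    return (len(M), len(M[0]))

def f(D):
    """Hypothetical loss: L = 0.5 * sum_ij D_ij^2 (any scalar function of D would do)."""
    return 0.5 * sum(v * v for row in D for v in row)

# m = 3 features, n = 2 samples, H = 4 neurons (made-up values)
X = [[1.0, 2.0],
     [0.5, -1.0],
     [3.0, 0.0]]            # (m, n) = (3, 2): columns are samples
W = [[0.1, 0.2, 0.3],
     [0.0, -0.1, 0.4],
     [1.0, 0.5, -0.5],
     [0.2, 0.2, 0.2]]       # (H, m) = (4, 3): one row per neuron

D = matmul(W, X)            # (H, n) = (4, 2)
L = f(D)                    # scalar
print(shape(X), shape(W), shape(D))   # (3, 2) (4, 3) (4, 2)
```

With NumPy arrays the same forward pass would simply be `D = W @ X`.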

Note that others may use $D = XW$, where the rows of $X$ are samples and the columns are feature dimensions. That’s fine: you can follow the same math with the indices swapped and arrive at the equivalent result, $\frac{\partial L}{\partial W} = X^T \frac{\partial L}{\partial D}$.
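Under the $D = XW$ convention (with $X$ of shape $(n, m)$ and $W$ of shape $(m, H)$), the gradient comes out as $X^T \frac{\partial L}{\partial D}$ rather than $\frac{\partial L}{\partial D} X^T$. The sketch below checks this numerically against central finite differences of $L$ for every weight. It assumes made-up matrices and an illustrative loss $L = \frac{1}{2}\sum_{ij} D_{ij}^2$, so that $\frac{\partial L}{\partial D} = D$; the helper names are hypothetical.

```python
def matmul(A, B):
    """Multiply an (r, k) matrix A by a (k, c) matrix B, both stored as lists of rows."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

def transpose(M):
    return [list(col) for col in zip(*M)]

def loss(X, W):
    """Hypothetical loss on D = X W: 0.5 * sum of squared entries."""
    D = matmul(X, W)
    return 0.5 * sum(v * v for row in D for v in row)

X = [[1.0, 0.5, 3.0],
     [2.0, -1.0, 0.0]]             # (n, m) = (2, 3): rows are samples
W = [[0.1, 0.0, 1.0, 0.2],
     [0.2, -0.1, 0.5, 0.2],
     [0.3, 0.4, -0.5, 0.2]]        # (m, H) = (3, 4)

dL_dD = matmul(X, W)                      # dL/dD = D for this particular loss
analytic = matmul(transpose(X), dL_dD)    # X^T dL/dD, shape (m, H) like W

# Central finite differences, one weight at a time
eps = 1e-6
numeric = [[0.0] * len(W[0]) for _ in W]
for r in range(len(W)):
    for c in range(len(W[0])):
        orig = W[r][c]
        W[r][c] = orig + eps
        up = loss(X, W)
        W[r][c] = orig - eps
        down = loss(X, W)
        W[r][c] = orig                    # restore the weight
        numeric[r][c] = (up - down) / (2 * eps)

max_err = max(abs(a - b) for ra, rb in zip(analytic, numeric) for a, b in zip(ra, rb))
print(max_err < 1e-6)   # True
```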

The canonical neuron is $\mathrm{ReLU}(D + b)$, but to keep things simple we’ll ignore the nonlinearity and the bias and let $L$ take $D$ directly instead of $\mathrm{ReLU}(D + b)$. We want the gradient of $L$ with respect to $W$ so that we can do gradient descent.

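We want to find $\frac{\partial L}{\partial W}$, so let’s start by looking at a single weight $W_{dc}$. This way we can reason about the gradient of $L$ for one weight and then extrapolate to all of $W$. Since $L$ depends on $W_{dc}$ only through the entries of $D$, the chain rule gives a sum over every entry $D_{ij}$:

$$\frac{\partial L}{\partial W_{dc}} = \sum_{i=1}^{H} \sum_{j=1}^{n} \frac{\partial L}{\partial D_{ij}} \frac{\partial D_{ij}}{\partial W_{dc}}$$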

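Let’s look more closely at the partial of $D_{ij}$ with respect to $W_{dc}$. Because $D_{ij} = \sum_{k=1}^{m} W_{ik} X_{kj}$ is the dot product of row $i$ of $W$ and column $j$ of $X$, it does not involve $W_{dc}$ unless $i = d$, so $\frac{\partial D_{ij}}{\partial W_{dc}} = 0$ whenever $i \neq d$. This means the chain-rule sum over all entries $D_{ij}$ collapses to the terms with $i = d$, the only ones that can be nonzero:

$$\frac{\partial L}{\partial W_{dc}} = \sum_{j=1}^{n} \frac{\partial L}{\partial D_{dj}} \frac{\partial D_{dj}}{\partial W_{dc}}$$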

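Finally, what is $\frac{\partial D_{dj}}{\partial W_{dc}}$? Writing out $D_{dj} = \sum_{k=1}^{m} W_{dk} X_{kj}$, the only term that contains $W_{dc}$ is $W_{dc} X_{cj}$, so

$$\frac{\partial D_{dj}}{\partial W_{dc}} = X_{cj}$$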
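Both facts, $\frac{\partial D_{ij}}{\partial W_{dc}} = 0$ for $i \neq d$ and $\frac{\partial D_{dj}}{\partial W_{dc}} = X_{cj}$, can be checked numerically. This sketch uses made-up matrices in plain Python and bumps one weight by 1; because $D = WX$ is linear in $W$, that difference equals the derivative up to floating-point rounding. The choice of $(d, c)$ and the `matmul` helper are illustrative assumptions.

```python
def matmul(A, B):
    """Multiply an (r, k) matrix A by a (k, c) matrix B, both stored as lists of rows."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

X = [[1.0, 2.0],
     [0.5, -1.0],
     [3.0, 0.0]]            # (m, n) = (3, 2)
W = [[0.1, 0.2, 0.3],
     [0.0, -0.1, 0.4],
     [1.0, 0.5, -0.5],
     [0.2, 0.2, 0.2]]       # (H, m) = (4, 3)
d, c = 2, 1                 # arbitrary weight W_dc to perturb

D = matmul(W, X)
W_bumped = [row[:] for row in W]
W_bumped[d][c] += 1.0       # unit step; exact derivative since D is linear in W
D_bumped = matmul(W_bumped, X)

# Fact 1: rows of D other than row d do not change
other_rows_unchanged = all(D_bumped[i] == D[i] for i in range(len(D)) if i != d)

# Fact 2: in row d, D_dj changes by X_cj for every sample j
row_d_change = [D_bumped[d][j] - D[d][j] for j in range(len(X[0]))]
expected = [X[c][j] for j in range(len(X[0]))]
row_d_error = max(abs(a - b) for a, b in zip(row_d_change, expected))

print(other_rows_unchanged)   # True
print(row_d_error < 1e-12)    # True
```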

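So, putting it all together, we have:

$$\frac{\partial L}{\partial W_{dc}} = \sum_{j=1}^{n} \frac{\partial L}{\partial D_{dj}} X_{cj}$$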
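Here is a sketch that evaluates $\frac{\partial L}{\partial W_{dc}} = \sum_j \frac{\partial L}{\partial D_{dj}} X_{cj}$ for one weight and compares it with a central finite difference of $L$. For illustration only, it assumes the loss $L = \frac{1}{2}\sum_{ij} D_{ij}^2$ (so $\frac{\partial L}{\partial D} = D$) and small made-up matrices; `matmul` and `loss` are hypothetical helpers.

```python
def matmul(A, B):
    """Multiply an (r, k) matrix A by a (k, c) matrix B, both stored as lists of rows."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

def loss(W, X):
    """Hypothetical loss on D = W X: 0.5 * sum of squared entries."""
    D = matmul(W, X)
    return 0.5 * sum(v * v for row in D for v in row)

X = [[1.0, 2.0],
     [0.5, -1.0],
     [3.0, 0.0]]            # (m, n) = (3, 2)
W = [[0.1, 0.2, 0.3],
     [0.0, -0.1, 0.4],
     [1.0, 0.5, -0.5],
     [0.2, 0.2, 0.2]]       # (H, m) = (4, 3)
d, c = 2, 1                 # arbitrary weight W_dc

dL_dD = matmul(W, X)        # dL/dD = D for this loss
formula = sum(dL_dD[d][j] * X[c][j] for j in range(len(X[0])))

# Central finite difference of L with respect to W_dc alone
eps = 1e-6
W_up = [row[:] for row in W]
W_up[d][c] += eps
W_down = [row[:] for row in W]
W_down[d][c] -= eps
numeric = (loss(W_up, X) - loss(W_down, X)) / (2 * eps)

print(formula)                        # -1.625
print(abs(formula - numeric) < 1e-6)  # True
```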

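Now how can we simplify this? One quick way is to see that, if we rewrite $X_{cj}$ as $X^T_{jc}$, the sum over $j$ is a dot product of row $d$ of $\frac{\partial L}{\partial D}$ with column $c$ of $X^T$:

$$\frac{\partial L}{\partial W_{dc}} = \sum_{j=1}^{n} \frac{\partial L}{\partial D_{dj}} X^T_{jc} = \left(\frac{\partial L}{\partial D} X^T\right)_{dc}$$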

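Now we want this for all weights in $W$, so we can generalize it to:

$$\frac{\partial L}{\partial W} = \frac{\partial L}{\partial D} X^T$$

As a sanity check on the shapes, $\frac{\partial L}{\partial D}$ is $(H, n)$ and $X^T$ is $(n, m)$, so the product is $(H, m)$, the same shape as $W$.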
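A sketch of the matrix form in plain Python: it computes $\frac{\partial L}{\partial D} X^T$ as one matrix product and compares it with the entry-by-entry sum $\sum_j \frac{\partial L}{\partial D_{dj}} X_{cj}$ for every $(d, c)$. It assumes, for illustration only, the loss $L = \frac{1}{2}\sum_{ij} D_{ij}^2$ (so $\frac{\partial L}{\partial D} = D$) and made-up matrices; `matmul` and `transpose` are hypothetical helpers.

```python
def matmul(A, B):
    """Multiply an (r, k) matrix A by a (k, c) matrix B, both stored as lists of rows."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

def transpose(M):
    return [list(col) for col in zip(*M)]

X = [[1.0, 2.0],
     [0.5, -1.0],
     [3.0, 0.0]]            # (m, n) = (3, 2)
W = [[0.1, 0.2, 0.3],
     [0.0, -0.1, 0.4],
     [1.0, 0.5, -0.5],
     [0.2, 0.2, 0.2]]       # (H, m) = (4, 3)

dL_dD = matmul(W, X)                       # (H, n); equals D for this loss
dL_dW = matmul(dL_dD, transpose(X))        # (H, n) x (n, m) -> (H, m), same shape as W

# Entry-by-entry version of the same gradient
loop = [[sum(dL_dD[d][j] * X[c][j] for j in range(len(X[0]))) for c in range(len(W[0]))]
        for d in range(len(W))]

print(len(dL_dW), len(dL_dW[0]))   # 4 3
print(dL_dW == loop)               # True
```

In NumPy the product would be `dL_dD @ X.T`; the nested-list version is only there to keep the example dependency-free.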