Proofs Chapter 14: Matrix Chain Rule (Basic Formulas)
In this chapter we prove the matrix chain rule. The chain rule is the mathematical foundation of the backpropagation algorithm in deep learning, governing the propagation of gradients through the layers of a multilayer neural network. We derive the transformation from the component form of the chain rule to the trace form, and clarify the correspondence with automatic differentiation (forward mode and reverse mode). The results of this chapter provide the theoretical basis for efficient gradient computation based on computational graphs.
Prerequisites: Chapter 3 (Derivative of a vector with respect to a vector), Chapter 4 (Basic formulas of matrix differentiation). Related chapters: Chapter 6 (Hadamard product and activation functions), Chapter 11 (Matrix powers and composite functions).
14. Matrix Chain Rule
Unless otherwise stated, the formulas in this chapter hold under the following conditions:
- All formulas are based on the denominator layout
- Each entry of the intermediate variable matrix $\boldsymbol{U}$ is differentiable with respect to the entries of $\boldsymbol{X}$
- The trace form of the chain rule applies to the composition of scalar functions
Suppose the matrix $\boldsymbol{U} = f(\boldsymbol{X})$ is a function of the matrix $\boldsymbol{X}$, and let $g(\boldsymbol{U})$ be a scalar-valued function of $\boldsymbol{U}$. We derive a formula for differentiating the composite function $g(f(\boldsymbol{X}))$ with respect to $\boldsymbol{X}$.
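As a concrete running example (our own illustrative choice, not taken from the text), one can take $f(\boldsymbol{X}) = \boldsymbol{X}\boldsymbol{B}$ for a fixed matrix $\boldsymbol{B}$ and $g(\boldsymbol{U}) = \sum_{k,l} U_{kl}^2$; a minimal NumPy sketch of this setup:

```python
import numpy as np

rng = np.random.default_rng(42)
X = rng.standard_normal((2, 3))   # input matrix X
B = rng.standard_normal((3, 4))   # fixed parameter matrix (illustrative)

def f(X):
    """Matrix-valued intermediate: U = f(X) = X B."""
    return X @ B

def g(U):
    """Scalar-valued function of the intermediate matrix U."""
    return float(np.sum(U**2))

U = f(X)
print(U.shape)  # (2, 4), so M = 2, N = 4 in the notation below
```

The same two functions are reused in the numerical checks later in this chapter.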
14.1 Matrix Chain Rule (Component Form)
Proof
Consider a scalar function $g(\boldsymbol{U})$, where $\boldsymbol{U} = f(\boldsymbol{X})$ is a function of the matrix $\boldsymbol{X}$.
The function $g$ yields a scalar value that depends on all entries $U_{kl}$ ($k = 0, \ldots, M-1$, $l = 0, \ldots, N-1$) of the intermediate variable matrix $\boldsymbol{U}$. Writing $g$ explicitly as a function of the entries of $\boldsymbol{U}$,
\begin{equation}g = g(U_{00}, U_{01}, \ldots, U_{M-1,N-1}) \label{eq:14-1-1}\end{equation}
Consider perturbing the entry $X_{ij}$ of $\boldsymbol{X}$ by a small amount $\Delta X_{ij}$. This perturbation causes each entry $U_{kl}$ of the intermediate variable $\boldsymbol{U}$ to change as well.
\begin{equation}\Delta U_{kl} = \dfrac{\partial U_{kl}}{\partial X_{ij}} \Delta X_{ij} + O((\Delta X_{ij})^2) \label{eq:14-1-2}\end{equation}
We compute the change $\Delta g$ in $g$. Since $g$ depends on all entries of $\boldsymbol{U}$, we apply the total differential of a multivariate function.
\begin{equation}\Delta g = \displaystyle\sum_{k=0}^{M-1} \displaystyle\sum_{l=0}^{N-1} \dfrac{\partial g}{\partial U_{kl}} \Delta U_{kl} + O((\Delta U_{kl})^2) \label{eq:14-1-3}\end{equation}
Substituting $\eqref{eq:14-1-2}$ into $\eqref{eq:14-1-3}$,
\begin{equation}\Delta g = \displaystyle\sum_{k=0}^{M-1} \displaystyle\sum_{l=0}^{N-1} \dfrac{\partial g}{\partial U_{kl}} \dfrac{\partial U_{kl}}{\partial X_{ij}} \Delta X_{ij} + O((\Delta X_{ij})^2) \label{eq:14-1-4}\end{equation}
Dividing $\eqref{eq:14-1-4}$ by $\Delta X_{ij}$ and taking the limit as $\Delta X_{ij} \to 0$,
\begin{equation}\dfrac{\partial g}{\partial X_{ij}} = \lim_{\Delta X_{ij} \to 0} \dfrac{\Delta g}{\Delta X_{ij}} = \displaystyle\sum_{k=0}^{M-1} \displaystyle\sum_{l=0}^{N-1} \dfrac{\partial g}{\partial U_{kl}} \dfrac{\partial U_{kl}}{\partial X_{ij}} \label{eq:14-1-5}\end{equation}
Equation $\eqref{eq:14-1-5}$ is the component form of the matrix chain rule.
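The component form can be verified numerically. Below is a sanity check (our own sketch, assuming the illustrative choice $f(\boldsymbol{X}) = \boldsymbol{X}\boldsymbol{B}$, $g(\boldsymbol{U}) = \sum U_{kl}^2$, for which $\partial g/\partial U_{kl} = 2U_{kl}$ and $\partial U_{kl}/\partial X_{ij} = \delta_{ki} B_{jl}$): the double sum in $\eqref{eq:14-1-5}$ is compared against a finite-difference estimate of $\partial g/\partial X_{ij}$.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((2, 3))
B = rng.standard_normal((3, 4))

f = lambda X: X @ B          # U = f(X) = X B (illustrative choice)
g = lambda U: np.sum(U**2)   # scalar function of U

U = f(X)
G = 2 * U                    # dg/dU_kl = 2 U_kl

i, j = 1, 2
# dU_kl/dX_ij = delta_{ki} B_{jl}: only row i of U responds to X_ij
H_ij = np.zeros_like(U)
H_ij[i, :] = B[j, :]

# component-form chain rule: sum over k, l of (dg/dU_kl)(dU_kl/dX_ij)
chain = np.sum(G * H_ij)

# finite-difference estimate of dg/dX_ij
eps = 1e-6
Xp = X.copy()
Xp[i, j] += eps
fd = (g(f(Xp)) - g(f(X))) / eps

print(abs(chain - fd) < 1e-4)  # True
```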
14.2 Matrix Chain Rule (Trace Form)
Proof
From $\eqref{eq:14-1-5}$ in 14.1, the component form of the chain rule is
\begin{equation}\dfrac{\partial g}{\partial X_{ij}} = \displaystyle\sum_{k=0}^{M-1} \displaystyle\sum_{l=0}^{N-1} \dfrac{\partial g}{\partial U_{kl}} \dfrac{\partial U_{kl}}{\partial X_{ij}} \label{eq:14-2-1}\end{equation}
Define two matrices: the gradient matrix $\boldsymbol{G}$ and the derivative matrix $\boldsymbol{H}^{ij}$.
\begin{equation}\boldsymbol{G} = \dfrac{\partial g}{\partial \boldsymbol{U}} \in \mathbb{R}^{M \times N}, \quad G_{kl} = \dfrac{\partial g}{\partial U_{kl}} \label{eq:14-2-2}\end{equation}
\begin{equation}\boldsymbol{H}^{ij} = \dfrac{\partial \boldsymbol{U}}{\partial X_{ij}} \in \mathbb{R}^{M \times N}, \quad H^{ij}_{kl} = \dfrac{\partial U_{kl}}{\partial X_{ij}} \label{eq:14-2-3}\end{equation}
Rewriting $\eqref{eq:14-2-1}$ using $\eqref{eq:14-2-2}$ and $\eqref{eq:14-2-3}$,
\begin{equation}\dfrac{\partial g}{\partial X_{ij}} = \displaystyle\sum_{k=0}^{M-1} \displaystyle\sum_{l=0}^{N-1} G_{kl} H^{ij}_{kl} \label{eq:14-2-4}\end{equation}
The right-hand side of $\eqref{eq:14-2-4}$ is the sum of the element-wise products of the two matrices $\boldsymbol{G}$ and $\boldsymbol{H}^{ij}$ (the Frobenius inner product).
We show that the Frobenius inner product $\langle \boldsymbol{G}, \boldsymbol{H}^{ij} \rangle_F$ can be expressed using the trace. First, we compute the $(p, q)$ entry of the matrix product $\boldsymbol{G}^\top \boldsymbol{H}^{ij}$. By the definition of the transpose $(\boldsymbol{G}^\top)_{pk} = G_{kp}$ and the definition of matrix multiplication,
\begin{equation}(\boldsymbol{G}^\top \boldsymbol{H}^{ij})_{pq} = \displaystyle\sum_{k=0}^{M-1} (\boldsymbol{G}^\top)_{pk} H^{ij}_{kq} = \displaystyle\sum_{k=0}^{M-1} G_{kp} H^{ij}_{kq} \label{eq:14-2-5}\end{equation}
We compute the trace of the matrix $\boldsymbol{G}^\top \boldsymbol{H}^{ij}$. Since $\boldsymbol{G} \in \mathbb{R}^{M \times N}$ and $\boldsymbol{H}^{ij} \in \mathbb{R}^{M \times N}$, we have $\boldsymbol{G}^\top \in \mathbb{R}^{N \times M}$, so the matrix product $\boldsymbol{G}^\top \boldsymbol{H}^{ij} \in \mathbb{R}^{N \times N}$. By the definition of the trace (the sum of the diagonal entries),
\begin{equation}\text{tr}(\boldsymbol{G}^\top \boldsymbol{H}^{ij}) = \displaystyle\sum_{l=0}^{N-1} (\boldsymbol{G}^\top \boldsymbol{H}^{ij})_{ll} \label{eq:14-2-6a}\end{equation}
Setting $p = q = l$ in $\eqref{eq:14-2-5}$,
\begin{equation}(\boldsymbol{G}^\top \boldsymbol{H}^{ij})_{ll} = \displaystyle\sum_{k=0}^{M-1} G_{kl} H^{ij}_{kl} \label{eq:14-2-6b}\end{equation}
Substituting $\eqref{eq:14-2-6b}$ into $\eqref{eq:14-2-6a}$,
\begin{equation}\text{tr}(\boldsymbol{G}^\top \boldsymbol{H}^{ij}) = \displaystyle\sum_{l=0}^{N-1} \displaystyle\sum_{k=0}^{M-1} G_{kl} H^{ij}_{kl} \label{eq:14-2-6}\end{equation}
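The identity just derived, $\text{tr}(\boldsymbol{G}^\top \boldsymbol{H}) = \sum_{k,l} G_{kl} H_{kl}$, is easy to confirm numerically for arbitrary matrices of matching shape (a minimal sketch with random inputs):

```python
import numpy as np

rng = np.random.default_rng(1)
G = rng.standard_normal((3, 5))
H = rng.standard_normal((3, 5))

frobenius = np.sum(G * H)            # sum of element-wise products
trace = np.trace(G.T @ H)            # tr(G^T H)

print(np.isclose(frobenius, trace))  # True
```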
Comparing $\eqref{eq:14-2-6}$ with $\eqref{eq:14-2-4}$, we see that the two expressions are equal.
\begin{equation}\dfrac{\partial g}{\partial X_{ij}} = \text{tr}(\boldsymbol{G}^\top \boldsymbol{H}^{ij}) = \text{tr}\left[\left(\dfrac{\partial g}{\partial \boldsymbol{U}}\right)^\top \dfrac{\partial \boldsymbol{U}}{\partial X_{ij}}\right] \label{eq:14-2-7}\end{equation}
Equation $\eqref{eq:14-2-7}$ is the trace form of the matrix chain rule.
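As a final sanity check (again using the illustrative choice $\boldsymbol{U} = \boldsymbol{X}\boldsymbol{B}$, $g(\boldsymbol{U}) = \sum U_{kl}^2$, which is our assumption rather than part of the proof), the trace form $\eqref{eq:14-2-7}$ can be evaluated entry by entry; for this choice, $\text{tr}(\boldsymbol{G}^\top \boldsymbol{H}^{ij}) = (\boldsymbol{G}\boldsymbol{B}^\top)_{ij}$, so the loop should reproduce the closed form $\boldsymbol{G}\boldsymbol{B}^\top$.

```python
import numpy as np

rng = np.random.default_rng(2)
X = rng.standard_normal((2, 3))
B = rng.standard_normal((3, 4))
U = X @ B
G = 2 * U                          # dg/dU for g(U) = sum(U**2)

# build dg/dX entry by entry via the trace form tr(G^T H^ij)
grad = np.empty_like(X)
for i in range(X.shape[0]):
    for j in range(X.shape[1]):
        H_ij = np.zeros_like(U)
        H_ij[i, :] = B[j, :]       # dU_kl/dX_ij = delta_{ki} B_{jl}
        grad[i, j] = np.trace(G.T @ H_ij)

# the double loop collapses to the closed form G B^T for this choice of f
print(np.allclose(grad, G @ B.T))  # True
```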