ニューラルネットワーク

深層学習における各層の逆伝播(バックプロパゲーション)公式。

表記規約
本ページの公式は分母レイアウト(denominator layout)に基づく。詳細はレイアウト規約を参照。

ニューラルネットワークの微分

深層学習における各層の逆伝播(バックプロパゲーション)公式を導出する。損失 $L$ の各パラメータに対する勾配を求め、勾配降下法でパラメータを更新する。

活性化関数と損失関数

ニューラルネットワークの活性化関数・損失関数の微分公式。第6章の証明番号を保持。

6.2 softmaxのJacobian

公式:$\displaystyle\dfrac{\partial\, \mathrm{softmax}(\boldsymbol{x})}{\partial \boldsymbol{x}} = \mathrm{diag}(\boldsymbol{p}) - \boldsymbol{p}\boldsymbol{p}^\top$
($\boldsymbol{p} = \mathrm{softmax}(\boldsymbol{x})$)
条件:$\boldsymbol{x} \in \mathbb{R}^N$、$p_i = e^{x_i} / \displaystyle\sum_{k} e^{x_k} > 0$、$\displaystyle\sum_i p_i = 1$
証明

softmax関数の $i$ 番目の出力は次のように定義される。

\begin{equation} p_i = \dfrac{e^{x_i}}{S}, \qquad S = \displaystyle\sum_{k=0}^{N-1} e^{x_k} \label{eq:6-2-1} \end{equation}

場合1:$i = j$ のとき。商の微分法を適用する。

\begin{equation} \dfrac{\partial p_i}{\partial x_i} = \dfrac{e^{x_i} \cdot S - e^{x_i} \cdot e^{x_i}}{S^2} = \dfrac{e^{x_i}}{S} - \dfrac{e^{x_i}}{S} \cdot \dfrac{e^{x_i}}{S} = p_i - p_i^2 = p_i(1 - p_i) \label{eq:6-2-2} \end{equation}

場合2:$i \neq j$ のとき。分子 $e^{x_i}$ は $x_j$ に依存しないので、微分されるのは分母のみである。

\begin{equation} \dfrac{\partial p_i}{\partial x_j} = \dfrac{0 \cdot S - e^{x_i} \cdot e^{x_j}}{S^2} = -\dfrac{e^{x_i}}{S} \cdot \dfrac{e^{x_j}}{S} = -p_i \, p_j \label{eq:6-2-3} \end{equation}

両方の場合を Kroneckerのデルタ $\delta_{ij}$ を用いてまとめると次のようになる。

\begin{equation} \dfrac{\partial p_i}{\partial x_j} = p_i(\delta_{ij} - p_j) \label{eq:6-2-4} \end{equation}

これを行列形式で書くと、Jacobi行列の $(i, j)$ 成分が $p_i \delta_{ij} - p_i p_j$ であるから次のようになる。

\begin{equation} \dfrac{\partial \boldsymbol{p}}{\partial \boldsymbol{x}} = \mathrm{diag}(\boldsymbol{p}) - \boldsymbol{p}\boldsymbol{p}^\top \label{eq:6-2-5} \end{equation}

補足:softmaxの出力は確率分布をなすため $\displaystyle\sum_i p_i = 1$ が成り立つ。Jacobi行列の各列の和は 0 になる。これは softmax 出力の総和が常に 1 であるという制約の微分的表現である。

6.3 シグモイド関数の微分

公式:$\displaystyle\dfrac{\partial \sigma(\boldsymbol{x})}{\partial \boldsymbol{x}} = \mathrm{diag}(\sigma(\boldsymbol{x}) \odot (1 - \sigma(\boldsymbol{x})))$
($\sigma(x) = 1 / (1 + e^{-x})$)
条件:$\boldsymbol{x} \in \mathbb{R}^N$、$\sigma$ は要素ごとに適用
証明

$\sigma$ は要素ごとに適用されるので、Jacobi行列は対角行列になる。成分 $\sigma(x_i)$ の微分を計算すればよい。

\begin{equation} \sigma(x) = \dfrac{1}{1 + e^{-x}} = (1 + e^{-x})^{-1} \label{eq:6-3-1} \end{equation}

合成関数の微分法(連鎖律)を適用する。$u = 1 + e^{-x}$ とおくと、$\sigma = u^{-1}$ なので次のようになる。

\begin{equation} \dfrac{d\sigma}{dx} = -u^{-2} \cdot \dfrac{du}{dx} = -(1 + e^{-x})^{-2} \cdot (-e^{-x}) = \dfrac{e^{-x}}{(1 + e^{-x})^2} \label{eq:6-3-2} \end{equation}

これを $\sigma(x)$ を用いて書き直す。$\sigma(x) = 1/(1+e^{-x})$ より次の関係が得られる。

\begin{equation} 1 - \sigma(x) = 1 - \dfrac{1}{1+e^{-x}} = \dfrac{e^{-x}}{1+e^{-x}} \label{eq:6-3-3} \end{equation}

したがって次のようになる。

\begin{equation} \dfrac{d\sigma}{dx} = \dfrac{1}{1+e^{-x}} \cdot \dfrac{e^{-x}}{1+e^{-x}} = \sigma(x)(1 - \sigma(x)) \label{eq:6-3-4} \end{equation}

ベクトルに対して要素ごとに適用すると、Jacobi行列は対角行列になる。

\begin{equation} \dfrac{\partial \sigma(\boldsymbol{x})}{\partial \boldsymbol{x}} = \mathrm{diag}(\sigma(\boldsymbol{x}) \odot (1 - \sigma(\boldsymbol{x}))) \label{eq:6-3-5} \end{equation}

補足:$\sigma'(x) = \sigma(x)(1-\sigma(x))$ は $x = 0$ で最大値 $1/4$ をとり、$|x| \to \infty$ で指数的に 0 に近づく。これが深層学習における勾配消失問題の一因となる。

6.4 tanh関数の微分

公式:$\displaystyle\dfrac{\partial \tanh(\boldsymbol{x})}{\partial \boldsymbol{x}} = \mathrm{diag}(1 - \tanh^2(\boldsymbol{x}))$
条件:$\boldsymbol{x} \in \mathbb{R}^N$、$\tanh$ は要素ごとに適用
証明

$\tanh$ の定義とシグモイドとの関係を確認する。

\begin{equation} \tanh(x) = \dfrac{e^x - e^{-x}}{e^x + e^{-x}} \label{eq:6-4-1} \end{equation}

商の微分法を適用する。$f = e^x - e^{-x}$、$g = e^x + e^{-x}$ とおくと次のようになる。

\begin{equation} \dfrac{d}{dx}\tanh(x) = \dfrac{f'g - fg'}{g^2} = \dfrac{(e^x + e^{-x})(e^x + e^{-x}) - (e^x - e^{-x})(e^x - e^{-x})}{(e^x + e^{-x})^2} \label{eq:6-4-2} \end{equation}

分子を展開する。

\begin{equation} (e^x + e^{-x})^2 - (e^x - e^{-x})^2 = 4 e^x e^{-x} = 4 \label{eq:6-4-3} \end{equation}

ここで恒等式 $(a+b)^2 - (a-b)^2 = 4ab$ を用いた。$a = e^x$, $b = e^{-x}$ として $ab = 1$ である。したがって次のようになる。

\begin{equation} \dfrac{d}{dx}\tanh(x) = \dfrac{4}{(e^x + e^{-x})^2} = 1 - \tanh^2(x) \label{eq:6-4-4} \end{equation}

最後の等号は $\tanh^2(x) = (e^x - e^{-x})^2/(e^x + e^{-x})^2$ を代入して確認できる。ベクトルに対して要素ごとに適用すると次のようになる。

\begin{equation} \dfrac{\partial \tanh(\boldsymbol{x})}{\partial \boldsymbol{x}} = \mathrm{diag}(1 - \tanh^2(\boldsymbol{x})) \label{eq:6-4-5} \end{equation}

補足:$\tanh(x) = 2\sigma(2x) - 1$ の関係を用いて、シグモイドの微分(6.3)から導くこともできる。$\tanh'(0) = 1$ であり、シグモイドより勾配消失が緩やかである。

6.5 ReLU関数の微分

公式:$\displaystyle\dfrac{\partial\, \mathrm{ReLU}(\boldsymbol{x})}{\partial \boldsymbol{x}} = \mathrm{diag}(\mathbf{1}_{x_i > 0})$
($\mathrm{ReLU}(x) = \max(0, x)$)
条件:$\boldsymbol{x} \in \mathbb{R}^N$、$x = 0$ では劣勾配 $[0, 1]$ から慣例として 0 を選択
証明

ReLU は要素ごとに適用される区分線形関数である。各成分について場合分けする。

\begin{equation} \mathrm{ReLU}(x) = \max(0, x) = \begin{cases} x & (x > 0) \\ 0 & (x \leq 0) \end{cases} \label{eq:6-5-1} \end{equation}

$x > 0$ のとき $\mathrm{ReLU}(x) = x$ なので微分は 1 である。$x < 0$ のとき $\mathrm{ReLU}(x) = 0$ なので微分は 0 である。

\begin{equation} \dfrac{d}{dx}\mathrm{ReLU}(x) = \begin{cases} 1 & (x > 0) \\ 0 & (x < 0) \end{cases} = \mathbf{1}_{x > 0} \label{eq:6-5-2} \end{equation}

$x = 0$ では ReLU は微分不可能であるが、劣勾配(subdifferential)$[0, 1]$ が存在する。深層学習の実装では慣例的に $\mathrm{ReLU}'(0) = 0$ とする。

ベクトルに対して要素ごとに適用すると次のようになる。

\begin{equation} \dfrac{\partial\, \mathrm{ReLU}(\boldsymbol{x})}{\partial \boldsymbol{x}} = \mathrm{diag}(\mathbf{1}_{x_i > 0}) \label{eq:6-5-3} \end{equation}

6.6 Leaky ReLU関数の微分

公式:$\displaystyle\dfrac{\partial\, \mathrm{LeakyReLU}(\boldsymbol{x})}{\partial \boldsymbol{x}} = \mathrm{diag}(\mathbf{1}_{x_i > 0} + \alpha \cdot \mathbf{1}_{x_i \leq 0})$
($\mathrm{LeakyReLU}(x) = \max(\alpha x, x)$、$0 < \alpha < 1$)
条件:$\boldsymbol{x} \in \mathbb{R}^N$、$\alpha \in (0, 1)$ は定数(典型値 $\alpha = 0.01$)
証明

Leaky ReLU は ReLU の変種であり、負の領域でも小さな勾配 $\alpha$ を持つ。

\begin{equation} \mathrm{LeakyReLU}(x) = \begin{cases} x & (x > 0) \\ \alpha x & (x \leq 0) \end{cases} \label{eq:6-6-1} \end{equation}

各領域で微分すると次のようになる。

\begin{equation} \dfrac{d}{dx}\mathrm{LeakyReLU}(x) = \begin{cases} 1 & (x > 0) \\ \alpha & (x \leq 0) \end{cases} = \mathbf{1}_{x > 0} + \alpha \cdot \mathbf{1}_{x \leq 0} \label{eq:6-6-2} \end{equation}

$x = 0$ では左微分 $\alpha$ と右微分 $1$ が異なるため厳密には微分不可能であるが、実装では慣例的に $\alpha$ を用いる($0 < \alpha < 1$)。

ベクトルに対して要素ごとに適用すると次のようになる。

\begin{equation} \dfrac{\partial\, \mathrm{LeakyReLU}(\boldsymbol{x})}{\partial \boldsymbol{x}} = \mathrm{diag}(\mathbf{1}_{x_i > 0} + \alpha \cdot \mathbf{1}_{x_i \leq 0}) \label{eq:6-6-3} \end{equation}

補足:$\alpha = 0$ のとき ReLU(6.5)に帰着する。$\alpha$ を学習可能パラメータとする変種は PReLU(Parametric ReLU)と呼ばれる。

6.7 クロスエントロピー損失の微分(softmax + CE)

公式:$\displaystyle\dfrac{\partial}{\partial \boldsymbol{x}} \bigl(-\boldsymbol{y}^\top \log \boldsymbol{p}\bigr) = \boldsymbol{p} - \boldsymbol{y}$
($\boldsymbol{p} = \mathrm{softmax}(\boldsymbol{x})$)
条件:$\boldsymbol{x} \in \mathbb{R}^N$、$\boldsymbol{y} \in \mathbb{R}^N$ は one-hot ベクトル($y_c = 1$, 他は 0)または確率分布($\displaystyle\sum_i y_i = 1$, $y_i \geq 0$)
証明

損失関数を成分で書き下す。

\begin{equation} L = -\boldsymbol{y}^\top \log \boldsymbol{p} = -\displaystyle\sum_{i=0}^{N-1} y_i \log p_i \label{eq:6-7-1} \end{equation}

$L$ を $x_j$ で偏微分する。連鎖律を適用すると次のようになる。

\begin{equation} \dfrac{\partial L}{\partial x_j} = -\displaystyle\sum_{i=0}^{N-1} y_i \cdot \dfrac{1}{p_i} \cdot \dfrac{\partial p_i}{\partial x_j} \label{eq:6-7-2} \end{equation}

softmaxのJacobian(6.2、式\eqref{eq:6-2-4})を代入する。$\partial p_i / \partial x_j = p_i(\delta_{ij} - p_j)$ であるから次のようになる。

\begin{equation} \dfrac{\partial L}{\partial x_j} = -\displaystyle\sum_{i} y_i \cdot \dfrac{1}{p_i} \cdot p_i(\delta_{ij} - p_j) = -\displaystyle\sum_{i} y_i (\delta_{ij} - p_j) \label{eq:6-7-3} \end{equation}

和を展開する。$\displaystyle\sum_i y_i \delta_{ij} = y_j$ であるから次のようになる。

\begin{equation} \dfrac{\partial L}{\partial x_j} = -y_j + p_j \displaystyle\sum_{i} y_i \label{eq:6-7-4} \end{equation}

$\boldsymbol{y}$ が確率分布のとき $\displaystyle\sum_i y_i = 1$ であるので次のようになる。

\begin{equation} \dfrac{\partial L}{\partial x_j} = p_j - y_j \label{eq:6-7-5} \end{equation}

ベクトル形式でまとめると次のようになる。

\begin{equation} \dfrac{\partial L}{\partial \boldsymbol{x}} = \boldsymbol{p} - \boldsymbol{y} \label{eq:6-7-6} \end{equation}

補足:この簡潔な結果は softmax と交差エントロピーの組み合わせに特有であり、数値的にも安定である。実装では $\log \mathrm{softmax}$ を直接計算する log-sum-exp トリックが用いられる。

6.8 二値クロスエントロピー損失の微分(sigmoid + BCE)

公式:$\displaystyle\dfrac{\partial}{\partial x} \mathrm{BCE}(y, \sigma(x)) = \sigma(x) - y$
($\mathrm{BCE} = -y\log\sigma(x) - (1-y)\log(1-\sigma(x))$)
条件:$x \in \mathbb{R}$(スカラ)、$y \in \{0, 1\}$ はラベル
証明

損失関数を書き下す。$p = \sigma(x)$ とおく。

\begin{equation} L = -y \log p - (1 - y)\log(1 - p) \label{eq:6-8-1} \end{equation}

連鎖律を適用する。$dp/dx = \sigma(x)(1 - \sigma(x)) = p(1-p)$(6.3)を用いる。

\begin{equation} \dfrac{dL}{dx} = \left(-\dfrac{y}{p} + \dfrac{1 - y}{1 - p}\right) \dfrac{dp}{dx} = \left(-\dfrac{y}{p} + \dfrac{1 - y}{1 - p}\right) p(1 - p) \label{eq:6-8-2} \end{equation}

各項を整理する。

\begin{equation} \dfrac{dL}{dx} = -y(1 - p) + (1 - y)p = -y + yp + p - yp = p - y \label{eq:6-8-3} \end{equation}

したがって次のようになる。

\begin{equation} \dfrac{dL}{dx} = \sigma(x) - y \label{eq:6-8-4} \end{equation}

補足:多クラスの場合(6.7)と同じ $\boldsymbol{p} - \boldsymbol{y}$ の形になる。二値分類は $N = 2$ の softmax の特殊ケースとみなせる。

6.9 GELU関数の微分

公式:$\displaystyle\dfrac{\partial\, \mathrm{GELU}(\boldsymbol{x})}{\partial \boldsymbol{x}} = \mathrm{diag}(\Phi(\boldsymbol{x}) + \boldsymbol{x} \odot \phi(\boldsymbol{x}))$
($\mathrm{GELU}(x) = x \cdot \Phi(x)$、$\Phi$: 標準正規CDF、$\phi$: 標準正規PDF)
条件:$\boldsymbol{x} \in \mathbb{R}^N$、$\Phi(x) = \dfrac{1}{2}\bigl[1 + \mathrm{erf}(x/\sqrt{2})\bigr]$、$\phi(x) = \dfrac{1}{\sqrt{2\pi}} e^{-x^2/2}$
証明

GELU(Gaussian Error Linear Unit)の定義を確認する。

\begin{equation} \mathrm{GELU}(x) = x \, \Phi(x) \label{eq:6-9-1} \end{equation}

積の微分法を適用する。$\Phi'(x) = \phi(x)$(CDFの微分はPDF)であるから次のようになる。

\begin{equation} \dfrac{d}{dx}\mathrm{GELU}(x) = \Phi(x) + x \, \phi(x) \label{eq:6-9-2} \end{equation}

ベクトルに対して要素ごとに適用すると次のようになる。

\begin{equation} \dfrac{\partial\, \mathrm{GELU}(\boldsymbol{x})}{\partial \boldsymbol{x}} = \mathrm{diag}(\Phi(\boldsymbol{x}) + \boldsymbol{x} \odot \phi(\boldsymbol{x})) \label{eq:6-9-3} \end{equation}

補足:GELUは BERT、GPT 等の Transformer モデルで広く用いられる。$x \to +\infty$ で $\mathrm{GELU}(x) \to x$(恒等写像)、$x \to -\infty$ で $\mathrm{GELU}(x) \to 0$ という漸近的挙動を持ち、ReLU を滑らかにした関数とみなせる。

6.10 Swish (SiLU) 関数の微分

公式:$\displaystyle\dfrac{\partial\, \mathrm{Swish}(\boldsymbol{x})}{\partial \boldsymbol{x}} = \mathrm{diag}(\sigma(\boldsymbol{x}) + \boldsymbol{x} \odot \sigma(\boldsymbol{x}) \odot (1 - \sigma(\boldsymbol{x})))$
($\mathrm{Swish}(x) = x \cdot \sigma(x)$)
条件:$\boldsymbol{x} \in \mathbb{R}^N$、$\sigma$ はシグモイド関数
証明

Swish(SiLU: Sigmoid Linear Unit)の定義を確認する。

\begin{equation} \mathrm{Swish}(x) = x \, \sigma(x) \label{eq:6-10-1} \end{equation}

積の微分法を適用する。$\sigma'(x) = \sigma(x)(1 - \sigma(x))$(6.3)であるから次のようになる。

\begin{equation} \dfrac{d}{dx}\mathrm{Swish}(x) = \sigma(x) + x \, \sigma'(x) = \sigma(x) + x \, \sigma(x)(1 - \sigma(x)) \label{eq:6-10-2} \end{equation}

$\sigma(x)$ でくくると、次のようにも書ける。

\begin{equation} \dfrac{d}{dx}\mathrm{Swish}(x) = \sigma(x)\bigl[1 + x(1 - \sigma(x))\bigr] \label{eq:6-10-3} \end{equation}

ベクトルに対して要素ごとに適用すると次のようになる。

\begin{equation} \dfrac{\partial\, \mathrm{Swish}(\boldsymbol{x})}{\partial \boldsymbol{x}} = \mathrm{diag}(\sigma(\boldsymbol{x}) + \boldsymbol{x} \odot \sigma(\boldsymbol{x}) \odot (1 - \sigma(\boldsymbol{x}))) \label{eq:6-10-4} \end{equation}

補足:Swish は GELU と類似の形状を持つ。GELU が $x \cdot \Phi(x)$ であるのに対し、Swish は $x \cdot \sigma(x)$ であり、$\sigma(x)$ が $\Phi(1.702 x)$ の近似であることから両者は近い関数となる。EfficientNet 等で採用された。

全結合層・バッチ正規化・Attention

ニューラルネットワークの基本構成要素の逆伝播公式。

17.1 全結合層の重み勾配

公式:$\displaystyle\dfrac{\partial L}{\partial \boldsymbol{W}} = \boldsymbol{X}^\top \displaystyle\dfrac{\partial L}{\partial \boldsymbol{Y}}$
条件:順伝播 $\boldsymbol{Y} = \boldsymbol{X}\boldsymbol{W} + \boldsymbol{1}_N \boldsymbol{b}^\top$、$\boldsymbol{X} \in \mathbb{R}^{N \times D_{\text{in}}}$、$\boldsymbol{W} \in \mathbb{R}^{D_{\text{in}} \times D_{\text{out}}}$
証明

全結合層の順伝播は行列積で表される。

\begin{equation}\boldsymbol{Y} = \boldsymbol{X}\boldsymbol{W} + \boldsymbol{1}_N \boldsymbol{b}^\top \label{eq:17-1-1}\end{equation}

$\eqref{eq:17-1-1}$ を成分で書くと、出力の $(n, j)$ 成分は次のようになる。

\begin{equation}Y_{nj} = \displaystyle\sum_{k=0}^{D_{\text{in}}-1} X_{nk} W_{kj} + b_j \label{eq:17-1-2}\end{equation}

$\eqref{eq:17-1-2}$ において、$Y_{nj}$ が $W_{ij}$ に依存するのは $k = i$ のときのみである。したがって偏微分を計算する。

\begin{equation}\dfrac{\partial Y_{nj}}{\partial W_{ij}} = \dfrac{\partial}{\partial W_{ij}} \left( \displaystyle\sum_{k} X_{nk} W_{kj} + b_j \right) = X_{ni} \label{eq:17-1-3}\end{equation}

損失 $L$ は出力 $\boldsymbol{Y}$ を通じて $\boldsymbol{W}$ に依存する。連鎖律(1.26)を適用する。

\begin{equation}\dfrac{\partial L}{\partial W_{ij}} = \displaystyle\sum_{n,k} \dfrac{\partial L}{\partial Y_{nk}} \dfrac{\partial Y_{nk}}{\partial W_{ij}} \label{eq:17-1-4}\end{equation}

$\eqref{eq:17-1-3}$ より、$\displaystyle\dfrac{\partial Y_{nk}}{\partial W_{ij}}$ は $k = j$ のときのみ非零である。したがって $\eqref{eq:17-1-4}$ は次のように簡約される。

\begin{equation}\dfrac{\partial L}{\partial W_{ij}} = \displaystyle\sum_n \dfrac{\partial L}{\partial Y_{nj}} \cdot X_{ni} \label{eq:17-1-5}\end{equation}

$\eqref{eq:17-1-5}$ を行列積の成分として解釈する。$(\boldsymbol{X}^\top)_{in} = X_{ni}$ であるから

\begin{equation}\dfrac{\partial L}{\partial W_{ij}} = \displaystyle\sum_n (\boldsymbol{X}^\top)_{in} \left(\dfrac{\partial L}{\partial \boldsymbol{Y}}\right)_{nj} = \left( \boldsymbol{X}^\top \dfrac{\partial L}{\partial \boldsymbol{Y}} \right)_{ij} \label{eq:17-1-6}\end{equation}

$\eqref{eq:17-1-6}$ はすべての $(i, j)$ について成り立つので、行列形式で最終結果を得る。

\begin{equation}\dfrac{\partial L}{\partial \boldsymbol{W}} = \boldsymbol{X}^\top \dfrac{\partial L}{\partial \boldsymbol{Y}} \label{eq:17-1-7}\end{equation}

補足:$\eqref{eq:17-1-7}$ は入力 $\boldsymbol{X}$ と出力勾配 $\displaystyle\dfrac{\partial L}{\partial \boldsymbol{Y}}$ の行列積である。バッチサイズ $N$ 個のサンプルに対する勾配が自動的に累積される。

17.2 全結合層のバイアス勾配

公式:$\displaystyle\dfrac{\partial L}{\partial \boldsymbol{b}} = \displaystyle\sum_{n=0}^{N-1} \left(\displaystyle\dfrac{\partial L}{\partial \boldsymbol{Y}}\right)_{n,:}^\top$
条件:順伝播 $\boldsymbol{Y} = \boldsymbol{X}\boldsymbol{W} + \boldsymbol{1}_N \boldsymbol{b}^\top$、$\boldsymbol{b} \in \mathbb{R}^{D_{\text{out}}}$
証明

17.1 の $\eqref{eq:17-1-2}$ より、出力の成分表示は次のようになる。

\begin{equation}Y_{nj} = \displaystyle\sum_{k} X_{nk} W_{kj} + b_j \label{eq:17-2-1}\end{equation}

$\eqref{eq:17-2-1}$ において、$Y_{nj}$ が $b_i$ に依存するのは $j = i$ のときのみである。偏微分を計算する。

\begin{equation}\dfrac{\partial Y_{nj}}{\partial b_i} = \dfrac{\partial}{\partial b_i} \left( \displaystyle\sum_{k} X_{nk} W_{kj} + b_j \right) = \delta_{ij} \label{eq:17-2-2}\end{equation}

ここで $\delta_{ij}$ はクロネッカーのデルタであり、$j = i$ のとき 1、それ以外は 0 である。

損失 $L$ は出力 $\boldsymbol{Y}$ を通じて $\boldsymbol{b}$ に依存する。連鎖律(1.26)を適用する。

\begin{equation}\dfrac{\partial L}{\partial b_i} = \displaystyle\sum_{n,j} \dfrac{\partial L}{\partial Y_{nj}} \dfrac{\partial Y_{nj}}{\partial b_i} \label{eq:17-2-3}\end{equation}

$\eqref{eq:17-2-2}$ を $\eqref{eq:17-2-3}$ に代入する。$\delta_{ij}$ により $j = i$ の項のみ残る。

\begin{equation}\dfrac{\partial L}{\partial b_i} = \displaystyle\sum_{n,j} \dfrac{\partial L}{\partial Y_{nj}} \delta_{ij} = \displaystyle\sum_n \dfrac{\partial L}{\partial Y_{ni}} \label{eq:17-2-4}\end{equation}

$\eqref{eq:17-2-4}$ はバッチ方向($n = 0, 1, \ldots, N-1$)の和である。ベクトル形式で書くと、$\left(\displaystyle\dfrac{\partial L}{\partial \boldsymbol{Y}}\right)_{n,:}$ は勾配行列の第 $n$ 行(行ベクトル)を表す。

\begin{equation}\dfrac{\partial L}{\partial \boldsymbol{b}} = \displaystyle\sum_{n=0}^{N-1} \left(\dfrac{\partial L}{\partial \boldsymbol{Y}}\right)_{n,:}^\top \label{eq:17-2-5}\end{equation}

補足:$\eqref{eq:17-2-5}$ は、バッチ内の全サンプルからの勾配寄与を累積していることを示す。NumPy では np.sum(dY, axis=0) で計算できる。

17.3 全結合層の入力勾配

公式:$\displaystyle\dfrac{\partial L}{\partial \boldsymbol{X}} = \displaystyle\dfrac{\partial L}{\partial \boldsymbol{Y}} \boldsymbol{W}^\top$
条件:順伝播 $\boldsymbol{Y} = \boldsymbol{X}\boldsymbol{W}$、$\boldsymbol{X} \in \mathbb{R}^{N \times D_{\text{in}}}$
証明

全結合層の順伝播を成分で書く(バイアス項は入力勾配に影響しないため省略)。

\begin{equation}Y_{nj} = \displaystyle\sum_{k=0}^{D_{\text{in}}-1} X_{nk} W_{kj} \label{eq:17-3-1}\end{equation}

$\eqref{eq:17-3-1}$ において、$Y_{nj}$ が $X_{mi}$ に依存するのは $n = m$ かつ $k = i$ のときのみである。偏微分を計算する。

\begin{equation}\dfrac{\partial Y_{nj}}{\partial X_{mi}} = \delta_{nm} W_{ij} \label{eq:17-3-2}\end{equation}

ここで $\delta_{nm}$ はクロネッカーのデルタであり、$n = m$ のとき 1、それ以外は 0 である。

損失 $L$ は出力 $\boldsymbol{Y}$ を通じて $\boldsymbol{X}$ に依存する。連鎖律(1.26)を適用する。

\begin{equation}\dfrac{\partial L}{\partial X_{mi}} = \displaystyle\sum_{n,j} \dfrac{\partial L}{\partial Y_{nj}} \dfrac{\partial Y_{nj}}{\partial X_{mi}} \label{eq:17-3-3}\end{equation}

$\eqref{eq:17-3-2}$ を $\eqref{eq:17-3-3}$ に代入する。$\delta_{nm}$ により $n = m$ の項のみ残る。

\begin{equation}\dfrac{\partial L}{\partial X_{mi}} = \displaystyle\sum_{n,j} \dfrac{\partial L}{\partial Y_{nj}} \delta_{nm} W_{ij} = \displaystyle\sum_j \dfrac{\partial L}{\partial Y_{mj}} W_{ij} \label{eq:17-3-4}\end{equation}

$\eqref{eq:17-3-4}$ を行列積の成分として解釈する。$(\boldsymbol{W}^\top)_{ji} = W_{ij}$ であるから

\begin{equation}\dfrac{\partial L}{\partial X_{mi}} = \displaystyle\sum_j \left(\dfrac{\partial L}{\partial \boldsymbol{Y}}\right)_{mj} (\boldsymbol{W}^\top)_{ji} = \left( \dfrac{\partial L}{\partial \boldsymbol{Y}} \boldsymbol{W}^\top \right)_{mi} \label{eq:17-3-5}\end{equation}

$\eqref{eq:17-3-5}$ はすべての $(m, i)$ について成り立つので、行列形式で最終結果を得る。

\begin{equation}\dfrac{\partial L}{\partial \boldsymbol{X}} = \dfrac{\partial L}{\partial \boldsymbol{Y}} \boldsymbol{W}^\top \label{eq:17-3-6}\end{equation}

補足:$\eqref{eq:17-3-6}$ は逆伝播の核心である。出力勾配 $\displaystyle\dfrac{\partial L}{\partial \boldsymbol{Y}}$ を重み行列の転置 $\boldsymbol{W}^\top$ で変換し、前の層に伝播させる。多層ネットワークでは、各層でこの操作を繰り返す。

17.4 バッチ正規化のスケール勾配

公式:$\displaystyle\dfrac{\partial L}{\partial \gamma} = \displaystyle\sum_n \displaystyle\dfrac{\partial L}{\partial y_n} \hat{x}_n$
条件:$y_n = \gamma \hat{x}_n + \beta$、$\hat{x}_n = \displaystyle\dfrac{x_n - \mu}{\sqrt{\sigma^2 + \epsilon}}$
証明

バッチ正規化の出力は、正規化された値 $\hat{x}_n$ をスケール $\gamma$ とシフト $\beta$ で変換したものである。

\begin{equation}y_n = \gamma \hat{x}_n + \beta \label{eq:17-4-1}\end{equation}

$\eqref{eq:17-4-1}$ において、$\gamma$ は全サンプルで共有されるパラメータである。$y_n$ を $\gamma$ で偏微分する。

\begin{equation}\dfrac{\partial y_n}{\partial \gamma} = \dfrac{\partial}{\partial \gamma} (\gamma \hat{x}_n + \beta) = \hat{x}_n \label{eq:17-4-2}\end{equation}

損失 $L$ は出力 $y_n$($n = 0, 1, \ldots, N-1$)を通じて $\gamma$ に依存する。連鎖律(1.26)を適用する。

\begin{equation}\dfrac{\partial L}{\partial \gamma} = \displaystyle\sum_{n=0}^{N-1} \dfrac{\partial L}{\partial y_n} \dfrac{\partial y_n}{\partial \gamma} \label{eq:17-4-3}\end{equation}

$\eqref{eq:17-4-2}$ を $\eqref{eq:17-4-3}$ に代入して最終結果を得る。

\begin{equation}\dfrac{\partial L}{\partial \gamma} = \displaystyle\sum_{n=0}^{N-1} \dfrac{\partial L}{\partial y_n} \hat{x}_n \label{eq:17-4-4}\end{equation}

補足:$\eqref{eq:17-4-4}$ は、出力勾配 $\displaystyle\dfrac{\partial L}{\partial y_n}$ と正規化された入力 $\hat{x}_n$ の内積(バッチ方向の和)である。

17.5 バッチ正規化のシフト勾配

公式:$\displaystyle\dfrac{\partial L}{\partial \beta} = \displaystyle\sum_n \displaystyle\dfrac{\partial L}{\partial y_n}$
条件:$y_n = \gamma \hat{x}_n + \beta$
証明

17.4 の $\eqref{eq:17-4-1}$ より、バッチ正規化の出力は次のように表される。

\begin{equation}y_n = \gamma \hat{x}_n + \beta \label{eq:17-5-1}\end{equation}

$\eqref{eq:17-5-1}$ において、$\beta$ は全サンプルで共有されるパラメータである。$y_n$ を $\beta$ で偏微分する。

\begin{equation}\dfrac{\partial y_n}{\partial \beta} = \dfrac{\partial}{\partial \beta} (\gamma \hat{x}_n + \beta) = 1 \label{eq:17-5-2}\end{equation}

損失 $L$ は出力 $y_n$($n = 0, 1, \ldots, N-1$)を通じて $\beta$ に依存する。連鎖律(1.26)を適用する。

\begin{equation}\dfrac{\partial L}{\partial \beta} = \displaystyle\sum_{n=0}^{N-1} \dfrac{\partial L}{\partial y_n} \dfrac{\partial y_n}{\partial \beta} \label{eq:17-5-3}\end{equation}

$\eqref{eq:17-5-2}$ を $\eqref{eq:17-5-3}$ に代入して最終結果を得る。

\begin{equation}\dfrac{\partial L}{\partial \beta} = \displaystyle\sum_{n=0}^{N-1} \dfrac{\partial L}{\partial y_n} \cdot 1 = \displaystyle\sum_{n=0}^{N-1} \dfrac{\partial L}{\partial y_n} \label{eq:17-5-4}\end{equation}

補足:$\eqref{eq:17-5-4}$ は出力勾配のバッチ方向の単純な総和である。17.2 のバイアス勾配と同様の構造を持つ。

17.6 バッチ正規化の入力勾配

公式:$\displaystyle\dfrac{\partial L}{\partial x_i} = \displaystyle\dfrac{\gamma}{\sqrt{\sigma^2 + \epsilon}} \left( \displaystyle\dfrac{\partial L}{\partial y_i} - \displaystyle\dfrac{1}{N}\displaystyle\sum_j \displaystyle\dfrac{\partial L}{\partial y_j} - \displaystyle\dfrac{\hat{x}_i}{N}\displaystyle\sum_j \displaystyle\dfrac{\partial L}{\partial y_j}\hat{x}_j \right)$
条件:$\mu = \displaystyle\dfrac{1}{N}\displaystyle\sum_n x_n$、$\sigma^2 = \displaystyle\dfrac{1}{N}\displaystyle\sum_n (x_n - \mu)^2$、$\hat{x}_n = \displaystyle\dfrac{x_n - \mu}{\sqrt{\sigma^2 + \epsilon}}$
証明

バッチ正規化では、入力 $x_i$ がバッチ統計量 $\mu$, $\sigma^2$ および正規化された値 $\hat{x}_i$ のすべてに影響する。計算グラフは次のようになる。

\begin{equation}\mu = \dfrac{1}{N}\displaystyle\sum_{n=0}^{N-1} x_n \label{eq:17-6-1}\end{equation}

\begin{equation}\sigma^2 = \dfrac{1}{N}\displaystyle\sum_{n=0}^{N-1} (x_n - \mu)^2 \label{eq:17-6-2}\end{equation}

\begin{equation}\hat{x}_n = \dfrac{x_n - \mu}{\sqrt{\sigma^2 + \epsilon}} \label{eq:17-6-3}\end{equation}

\begin{equation}y_n = \gamma \hat{x}_n + \beta \label{eq:17-6-4}\end{equation}

【ステップ1:$\hat{x}_i$ に関する勾配】

$\eqref{eq:17-6-4}$ より、$y_i$ を $\hat{x}_i$ で偏微分する。

\begin{equation}\dfrac{\partial y_i}{\partial \hat{x}_i} = \gamma \label{eq:17-6-5}\end{equation}

連鎖律(1.26)より

\begin{equation}\dfrac{\partial L}{\partial \hat{x}_i} = \dfrac{\partial L}{\partial y_i} \dfrac{\partial y_i}{\partial \hat{x}_i} = \dfrac{\partial L}{\partial y_i} \gamma \label{eq:17-6-6}\end{equation}

【ステップ2:$\sigma^2$ に関する勾配】

$\eqref{eq:17-6-3}$ より、$\hat{x}_n$ を $\sigma^2$ で偏微分する。$\sqrt{\sigma^2 + \epsilon} = (\sigma^2 + \epsilon)^{1/2}$ なので

\begin{equation}\dfrac{\partial \hat{x}_n}{\partial \sigma^2} = (x_n - \mu) \cdot \left(-\dfrac{1}{2}\right)(\sigma^2 + \epsilon)^{-3/2} \label{eq:17-6-7}\end{equation}

すべての $n$ について連鎖律(1.26)を適用し、総和を取る。

\begin{equation}\dfrac{\partial L}{\partial \sigma^2} = \displaystyle\sum_{n=0}^{N-1} \dfrac{\partial L}{\partial \hat{x}_n} \dfrac{\partial \hat{x}_n}{\partial \sigma^2} = \displaystyle\sum_{n=0}^{N-1} \dfrac{\partial L}{\partial \hat{x}_n} (x_n - \mu) \cdot \left(-\dfrac{1}{2}\right)(\sigma^2 + \epsilon)^{-3/2} \label{eq:17-6-8}\end{equation}

【ステップ3:$\mu$ に関する勾配】

$\mu$ は $\hat{x}_n$ に直接影響するとともに、$\sigma^2$ を経由しても影響する。まず $\hat{x}_n$ への直接の影響を計算する。$\eqref{eq:17-6-3}$ より

\begin{equation}\dfrac{\partial \hat{x}_n}{\partial \mu} = \dfrac{-1}{\sqrt{\sigma^2 + \epsilon}} \label{eq:17-6-9}\end{equation}

次に $\sigma^2$ を経由する影響を計算する。$\eqref{eq:17-6-2}$ より

\begin{equation}\dfrac{\partial \sigma^2}{\partial \mu} = \dfrac{1}{N}\displaystyle\sum_{n=0}^{N-1} 2(x_n - \mu)(-1) = \dfrac{-2}{N}\displaystyle\sum_{n=0}^{N-1}(x_n - \mu) \label{eq:17-6-10}\end{equation}

$\displaystyle\sum_n (x_n - \mu) = \displaystyle\sum_n x_n - N\mu = N\mu - N\mu = 0$ であるから、$\eqref{eq:17-6-10}$ は 0 になる。

\begin{equation}\dfrac{\partial \sigma^2}{\partial \mu} = 0 \label{eq:17-6-11}\end{equation}

したがって、$\mu$ に関する勾配は直接経路のみとなる。

\begin{equation}\dfrac{\partial L}{\partial \mu} = \displaystyle\sum_{n=0}^{N-1} \dfrac{\partial L}{\partial \hat{x}_n} \dfrac{\partial \hat{x}_n}{\partial \mu} = \displaystyle\sum_{n=0}^{N-1} \dfrac{\partial L}{\partial \hat{x}_n} \cdot \dfrac{-1}{\sqrt{\sigma^2 + \epsilon}} \label{eq:17-6-12}\end{equation}

【ステップ4:$x_i$ に関する勾配】

$x_i$ は $\hat{x}_i$、$\mu$、$\sigma^2$ のすべてに影響する。それぞれの経路からの寄与を計算する。

$\hat{x}_i$ への直接の影響:$\eqref{eq:17-6-3}$ より

\begin{equation}\dfrac{\partial \hat{x}_i}{\partial x_i} = \dfrac{1}{\sqrt{\sigma^2 + \epsilon}} \label{eq:17-6-13}\end{equation}

$\mu$ への影響:$\eqref{eq:17-6-1}$ より

\begin{equation}\dfrac{\partial \mu}{\partial x_i} = \dfrac{1}{N} \label{eq:17-6-14}\end{equation}

$\sigma^2$ への影響:$\eqref{eq:17-6-2}$ より

\begin{equation}\dfrac{\partial \sigma^2}{\partial x_i} = \dfrac{1}{N} \cdot 2(x_i - \mu) = \dfrac{2(x_i - \mu)}{N} \label{eq:17-6-15}\end{equation}

連鎖律(1.26)ですべての経路を合計する。

\begin{equation}\dfrac{\partial L}{\partial x_i} = \dfrac{\partial L}{\partial \hat{x}_i} \dfrac{\partial \hat{x}_i}{\partial x_i} + \dfrac{\partial L}{\partial \mu} \dfrac{\partial \mu}{\partial x_i} + \dfrac{\partial L}{\partial \sigma^2} \dfrac{\partial \sigma^2}{\partial x_i} \label{eq:17-6-16}\end{equation}

$\eqref{eq:17-6-6}$、$\eqref{eq:17-6-12}$、$\eqref{eq:17-6-8}$、$\eqref{eq:17-6-13}$、$\eqref{eq:17-6-14}$、$\eqref{eq:17-6-15}$ を $\eqref{eq:17-6-16}$ に代入する。

\begin{equation}\dfrac{\partial L}{\partial x_i} = \dfrac{\partial L}{\partial \hat{x}_i} \cdot \dfrac{1}{\sqrt{\sigma^2 + \epsilon}} + \dfrac{\partial L}{\partial \mu} \cdot \dfrac{1}{N} + \dfrac{\partial L}{\partial \sigma^2} \cdot \dfrac{2(x_i - \mu)}{N} \label{eq:17-6-17}\end{equation}

$\eqref{eq:17-6-17}$ を整理する。$\eqref{eq:17-6-6}$ より $\displaystyle\dfrac{\partial L}{\partial \hat{x}_i} = \gamma \displaystyle\dfrac{\partial L}{\partial y_i}$ である。また $\hat{x}_i = \displaystyle\dfrac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$ より $(x_i - \mu) = \hat{x}_i \sqrt{\sigma^2 + \epsilon}$ である。これらを代入し、計算を整理すると最終結果を得る。

\begin{equation}\dfrac{\partial L}{\partial x_i} = \dfrac{\gamma}{\sqrt{\sigma^2 + \epsilon}} \left( \dfrac{\partial L}{\partial y_i} - \dfrac{1}{N}\displaystyle\sum_{j=0}^{N-1} \dfrac{\partial L}{\partial y_j} - \dfrac{\hat{x}_i}{N}\displaystyle\sum_{j=0}^{N-1} \dfrac{\partial L}{\partial y_j}\hat{x}_j \right) \label{eq:17-6-18}\end{equation}

補足:$\eqref{eq:17-6-18}$ の3つの項は、それぞれ直接経路、平均経路、分散経路からの寄与を表す。この複雑な勾配は、正規化された値 $\hat{x}_i$ だけでなく、バッチ統計量 $\mu$, $\sigma^2$ を通じた間接的な依存を考慮している。

17.7 レイヤー正規化の入力勾配

公式:$\displaystyle\dfrac{\partial L}{\partial x_i} = \displaystyle\dfrac{\gamma}{\sqrt{\sigma^2 + \epsilon}} \left( \displaystyle\dfrac{\partial L}{\partial y_i} - \displaystyle\dfrac{1}{D}\displaystyle\sum_j \displaystyle\dfrac{\partial L}{\partial y_j} - \displaystyle\dfrac{\hat{x}_i}{D}\displaystyle\sum_j \displaystyle\dfrac{\partial L}{\partial y_j}\hat{x}_j \right)$
条件:$\mu$, $\sigma^2$ はサンプル内の統計量(特徴方向の平均・分散)、$D$ は特徴次元
証明

レイヤー正規化では、統計量をバッチ方向ではなく特徴方向で計算する。1つのサンプル内で次の計算を行う。

\begin{equation}\mu = \dfrac{1}{D}\displaystyle\sum_{d=0}^{D-1} x_d \label{eq:17-7-1}\end{equation}

\begin{equation}\sigma^2 = \dfrac{1}{D}\displaystyle\sum_{d=0}^{D-1} (x_d - \mu)^2 \label{eq:17-7-2}\end{equation}

\begin{equation}\hat{x}_d = \dfrac{x_d - \mu}{\sqrt{\sigma^2 + \epsilon}} \label{eq:17-7-3}\end{equation}

\begin{equation}y_d = \gamma_d \hat{x}_d + \beta_d \label{eq:17-7-4}\end{equation}

導出過程は 17.6 と同じパターンである。バッチ正規化との違いは、統計量の計算方向のみである。

17.6 の $\eqref{eq:17-6-1}$ から $\eqref{eq:17-6-18}$ の導出において、バッチサイズ $N$ を特徴次元 $D$ に置き換える。各ステップの計算は同一であり、最終結果として次を得る。

\begin{equation}\dfrac{\partial L}{\partial x_i} = \dfrac{\gamma_i}{\sqrt{\sigma^2 + \epsilon}} \left( \dfrac{\partial L}{\partial y_i} - \dfrac{1}{D}\displaystyle\sum_{j=0}^{D-1} \dfrac{\partial L}{\partial y_j} - \dfrac{\hat{x}_i}{D}\displaystyle\sum_{j=0}^{D-1} \dfrac{\partial L}{\partial y_j}\hat{x}_j \right) \label{eq:17-7-5}\end{equation}

補足:Layer Normalization は Transformer で標準的に使われる。$\eqref{eq:17-7-5}$ においてスケールパラメータ $\gamma_i$ は特徴次元ごとに異なる値を持つ点がバッチ正規化と異なる。バッチサイズに依存しないため、可変長シーケンスの処理に適している。

17.8 Attention の Value 勾配

公式:$\displaystyle\dfrac{\partial L}{\partial \boldsymbol{V}} = \boldsymbol{A}^\top \displaystyle\dfrac{\partial L}{\partial \boldsymbol{O}}$
条件:$\boldsymbol{O} = \boldsymbol{A}\boldsymbol{V}$、$\boldsymbol{A} = \text{softmax}(\boldsymbol{Q}\boldsymbol{K}^\top / \sqrt{d_k}) \in \mathbb{R}^{n \times n}$、$\boldsymbol{V} \in \mathbb{R}^{n \times d_v}$
証明

Attention の出力は、Attention 重み $\boldsymbol{A}$ と Value $\boldsymbol{V}$ の行列積で定義される。

\begin{equation}\boldsymbol{O} = \boldsymbol{A}\boldsymbol{V} \label{eq:17-8-1}\end{equation}

$\eqref{eq:17-8-1}$ を成分で書くと、出力の $(i, j)$ 成分は次のようになる。

\begin{equation}O_{ij} = \displaystyle\sum_{k=0}^{n-1} A_{ik} V_{kj} \label{eq:17-8-2}\end{equation}

$\eqref{eq:17-8-2}$ において、$O_{ij}$ が $V_{pq}$ に依存するのは $k = p$ かつ $j = q$ のときのみである。偏微分を計算する。

\begin{equation}\dfrac{\partial O_{ij}}{\partial V_{pq}} = A_{ip} \delta_{jq} \label{eq:17-8-3}\end{equation}

ここで $\delta_{jq}$ はクロネッカーのデルタであり、$j = q$ のとき 1、それ以外は 0 である。

損失 $L$ は出力 $\boldsymbol{O}$ を通じて $\boldsymbol{V}$ に依存する。連鎖律(1.26)を適用する。

\begin{equation}\dfrac{\partial L}{\partial V_{pq}} = \displaystyle\sum_{i,j} \dfrac{\partial L}{\partial O_{ij}} \dfrac{\partial O_{ij}}{\partial V_{pq}} \label{eq:17-8-4}\end{equation}

$\eqref{eq:17-8-3}$ を $\eqref{eq:17-8-4}$ に代入する。$\delta_{jq}$ により $j = q$ の項のみ残る。

\begin{equation}\dfrac{\partial L}{\partial V_{pq}} = \displaystyle\sum_{i,j} \dfrac{\partial L}{\partial O_{ij}} A_{ip} \delta_{jq} = \displaystyle\sum_i A_{ip} \dfrac{\partial L}{\partial O_{iq}} \label{eq:17-8-5}\end{equation}

$\eqref{eq:17-8-5}$ を行列積の成分として解釈する。$(\boldsymbol{A}^\top)_{pi} = A_{ip}$ であるから

\begin{equation}\dfrac{\partial L}{\partial V_{pq}} = \displaystyle\sum_i (\boldsymbol{A}^\top)_{pi} \left(\dfrac{\partial L}{\partial \boldsymbol{O}}\right)_{iq} = \left( \boldsymbol{A}^\top \dfrac{\partial L}{\partial \boldsymbol{O}} \right)_{pq} \label{eq:17-8-6}\end{equation}

$\eqref{eq:17-8-6}$ はすべての $(p, q)$ について成り立つので、行列形式で最終結果を得る。

\begin{equation}\dfrac{\partial L}{\partial \boldsymbol{V}} = \boldsymbol{A}^\top \dfrac{\partial L}{\partial \boldsymbol{O}} \label{eq:17-8-7}\end{equation}

補足:$\eqref{eq:17-8-7}$ は 17.1 の重み勾配と同様の構造を持つ。Attention 重み $\boldsymbol{A}$ の転置が入力 $\boldsymbol{X}$ の役割を果たす。

17.9 Attention 重みの勾配

公式:$\displaystyle\dfrac{\partial L}{\partial \boldsymbol{A}} = \displaystyle\dfrac{\partial L}{\partial \boldsymbol{O}} \boldsymbol{V}^\top$
条件:$\boldsymbol{O} = \boldsymbol{A}\boldsymbol{V}$
証明

17.8 の $\eqref{eq:17-8-2}$ より、Attention 出力の成分表示は次のようになる。

\begin{equation}O_{ij} = \displaystyle\sum_{k=0}^{n-1} A_{ik} V_{kj} \label{eq:17-9-1}\end{equation}

$\eqref{eq:17-9-1}$ において、$O_{ij}$ が $A_{pq}$ に依存するのは $i = p$ かつ $k = q$ のときのみである。偏微分を計算する。

\begin{equation}\dfrac{\partial O_{ij}}{\partial A_{pq}} = \delta_{ip} V_{qj} \label{eq:17-9-2}\end{equation}

ここで $\delta_{ip}$ はクロネッカーのデルタである。

損失 $L$ は出力 $\boldsymbol{O}$ を通じて $\boldsymbol{A}$ に依存する。連鎖律(1.26)を適用する。

\begin{equation}\dfrac{\partial L}{\partial A_{pq}} = \displaystyle\sum_{i,j} \dfrac{\partial L}{\partial O_{ij}} \dfrac{\partial O_{ij}}{\partial A_{pq}} \label{eq:17-9-3}\end{equation}

$\eqref{eq:17-9-2}$ を $\eqref{eq:17-9-3}$ に代入する。$\delta_{ip}$ により $i = p$ の項のみ残る。

\begin{equation}\dfrac{\partial L}{\partial A_{pq}} = \displaystyle\sum_{i,j} \dfrac{\partial L}{\partial O_{ij}} \delta_{ip} V_{qj} = \displaystyle\sum_j \dfrac{\partial L}{\partial O_{pj}} V_{qj} \label{eq:17-9-4}\end{equation}

$\eqref{eq:17-9-4}$ を行列積の成分として解釈する。$(\boldsymbol{V}^\top)_{qj} = V_{qj}$ であるから

\begin{equation}\dfrac{\partial L}{\partial A_{pq}} = \displaystyle\sum_j \left(\dfrac{\partial L}{\partial \boldsymbol{O}}\right)_{pj} (\boldsymbol{V}^\top)_{qj} = \left( \dfrac{\partial L}{\partial \boldsymbol{O}} \boldsymbol{V}^\top \right)_{pq} \label{eq:17-9-5}\end{equation}

$\eqref{eq:17-9-5}$ はすべての $(p, q)$ について成り立つので、行列形式で最終結果を得る。

\begin{equation}\dfrac{\partial L}{\partial \boldsymbol{A}} = \dfrac{\partial L}{\partial \boldsymbol{O}} \boldsymbol{V}^\top \label{eq:17-9-6}\end{equation}

補足:$\eqref{eq:17-9-6}$ は 17.3 の入力勾配と同様の構造を持つ。この勾配は softmax を通じて Query と Key に逆伝播される。

17.10 softmax 前のスコア勾配

公式:$\displaystyle\dfrac{\partial L}{\partial \boldsymbol{S}} = \boldsymbol{A} \odot \left( \displaystyle\dfrac{\partial L}{\partial \boldsymbol{A}} - \text{rowsum}\left(\displaystyle\dfrac{\partial L}{\partial \boldsymbol{A}} \odot \boldsymbol{A}\right) \boldsymbol{1}^\top \right)$
条件:$\boldsymbol{A} = \text{softmax}(\boldsymbol{S})$(行ごとに softmax)、$\boldsymbol{S} = \boldsymbol{Q}\boldsymbol{K}^\top / \sqrt{d_k}$
証明

Attention 重み $\boldsymbol{A}$ はスコア $\boldsymbol{S}$ の行ごとの softmax で計算される。

\begin{equation}\boldsymbol{A} = \text{softmax}(\boldsymbol{S}) \label{eq:17-10-1}\end{equation}

softmax の定義より、行 $i$ の成分 $A_{ij}$ は次のように表される。

\begin{equation}A_{ij} = \dfrac{e^{S_{ij}}}{\displaystyle\sum_{l=0}^{n-1} e^{S_{il}}} \label{eq:17-10-2}\end{equation}

softmax のヤコビ行列を計算する。$A_{ij}$ を $S_{ik}$ で偏微分する。$j = k$ の場合と $j \neq k$ の場合で分けて計算する。

$j = k$ の場合:商の微分法則(1.28)を適用する。

\begin{equation}\dfrac{\partial A_{ij}}{\partial S_{ij}} = \dfrac{e^{S_{ij}} \displaystyle\sum_l e^{S_{il}} - e^{S_{ij}} e^{S_{ij}}}{(\displaystyle\sum_l e^{S_{il}})^2} = A_{ij} - A_{ij}^2 = A_{ij}(1 - A_{ij}) \label{eq:17-10-3}\end{equation}

$j \neq k$ の場合:

\begin{equation}\dfrac{\partial A_{ij}}{\partial S_{ik}} = \dfrac{0 \cdot \displaystyle\sum_l e^{S_{il}} - e^{S_{ij}} e^{S_{ik}}}{(\displaystyle\sum_l e^{S_{il}})^2} = -A_{ij} A_{ik} \label{eq:17-10-4}\end{equation}

$\eqref{eq:17-10-3}$ と $\eqref{eq:17-10-4}$ を統一して書くと、クロネッカーのデルタを用いて次のように表される。

\begin{equation}\dfrac{\partial A_{ij}}{\partial S_{ik}} = A_{ij}(\delta_{jk} - A_{ik}) \label{eq:17-10-5}\end{equation}

損失 $L$ は $\boldsymbol{A}$ を通じて $\boldsymbol{S}$ に依存する。行 $i$ について連鎖律(1.26)を適用する。

\begin{equation}\dfrac{\partial L}{\partial S_{ik}} = \displaystyle\sum_{j=0}^{n-1} \dfrac{\partial L}{\partial A_{ij}} \dfrac{\partial A_{ij}}{\partial S_{ik}} \label{eq:17-10-6}\end{equation}

$\eqref{eq:17-10-5}$ を $\eqref{eq:17-10-6}$ に代入する。

\begin{equation}\dfrac{\partial L}{\partial S_{ik}} = \displaystyle\sum_j \dfrac{\partial L}{\partial A_{ij}} A_{ij}(\delta_{jk} - A_{ik}) \label{eq:17-10-7}\end{equation}

$\eqref{eq:17-10-7}$ を展開する。$\delta_{jk}$ により $j = k$ の項のみ残る第1項と、全体の和となる第2項に分ける。

\begin{equation}\dfrac{\partial L}{\partial S_{ik}} = \dfrac{\partial L}{\partial A_{ik}} A_{ik} - A_{ik} \displaystyle\sum_j \dfrac{\partial L}{\partial A_{ij}} A_{ij} \label{eq:17-10-8}\end{equation}

$\eqref{eq:17-10-8}$ を $A_{ik}$ でくくる。

\begin{equation}\dfrac{\partial L}{\partial S_{ik}} = A_{ik} \left( \dfrac{\partial L}{\partial A_{ik}} - \displaystyle\sum_j \dfrac{\partial L}{\partial A_{ij}} A_{ij} \right) \label{eq:17-10-9}\end{equation}

$\eqref{eq:17-10-9}$ を行列形式で書く。$\displaystyle\sum_j \displaystyle\dfrac{\partial L}{\partial A_{ij}} A_{ij}$ は $\displaystyle\dfrac{\partial L}{\partial \boldsymbol{A}} \odot \boldsymbol{A}$ の行 $i$ の和である。これを $\text{rowsum}(\cdot)$ で表す。

\begin{equation}\dfrac{\partial L}{\partial \boldsymbol{S}} = \boldsymbol{A} \odot \left( \dfrac{\partial L}{\partial \boldsymbol{A}} - \text{rowsum}\left(\dfrac{\partial L}{\partial \boldsymbol{A}} \odot \boldsymbol{A}\right) \boldsymbol{1}^\top \right) \label{eq:17-10-10}\end{equation}

補足:$\eqref{eq:17-10-10}$ において $\boldsymbol{1}^\top$ は行ベクトルであり、$\text{rowsum}(\cdot) \boldsymbol{1}^\top$ は各行の和を全列に複製する操作である。この公式は FlashAttention などの効率的な実装で使用される。

17.11 Query の勾配

公式:$\displaystyle\dfrac{\partial L}{\partial \boldsymbol{Q}} = \displaystyle\dfrac{1}{\sqrt{d_k}} \displaystyle\dfrac{\partial L}{\partial \boldsymbol{S}} \boldsymbol{K}$
条件:$\boldsymbol{S} = \boldsymbol{Q}\boldsymbol{K}^\top / \sqrt{d_k}$
証明

Attention スコア $\boldsymbol{S}$ は Query と Key の内積をスケーリングしたものである。

\begin{equation}\boldsymbol{S} = \dfrac{1}{\sqrt{d_k}} \boldsymbol{Q}\boldsymbol{K}^\top \label{eq:17-11-1}\end{equation}

$\eqref{eq:17-11-1}$ を成分で書くと、スコアの $(i, j)$ 成分は次のようになる。

\begin{equation}S_{ij} = \dfrac{1}{\sqrt{d_k}} \displaystyle\sum_{k=0}^{d_k-1} Q_{ik} K_{jk} \label{eq:17-11-2}\end{equation}

$\eqref{eq:17-11-2}$ において、$S_{ij}$ が $Q_{pq}$ に依存するのは $i = p$ かつ $k = q$ のときのみである。偏微分を計算する。

\begin{equation}\dfrac{\partial S_{ij}}{\partial Q_{pq}} = \dfrac{1}{\sqrt{d_k}} \delta_{ip} K_{jq} \label{eq:17-11-3}\end{equation}

損失 $L$ はスコア $\boldsymbol{S}$ を通じて $\boldsymbol{Q}$ に依存する。連鎖律(1.26)を適用する。

\begin{equation}\dfrac{\partial L}{\partial Q_{pq}} = \displaystyle\sum_{i,j} \dfrac{\partial L}{\partial S_{ij}} \dfrac{\partial S_{ij}}{\partial Q_{pq}} \label{eq:17-11-4}\end{equation}

$\eqref{eq:17-11-3}$ を $\eqref{eq:17-11-4}$ に代入する。$\delta_{ip}$ により $i = p$ の項のみ残る。

\begin{equation}\dfrac{\partial L}{\partial Q_{pq}} = \dfrac{1}{\sqrt{d_k}} \displaystyle\sum_{i,j} \dfrac{\partial L}{\partial S_{ij}} \delta_{ip} K_{jq} = \dfrac{1}{\sqrt{d_k}} \displaystyle\sum_j \dfrac{\partial L}{\partial S_{pj}} K_{jq} \label{eq:17-11-5}\end{equation}

$\eqref{eq:17-11-5}$ を行列積の成分として解釈する。

\begin{equation}\dfrac{\partial L}{\partial Q_{pq}} = \dfrac{1}{\sqrt{d_k}} \displaystyle\sum_j \left(\dfrac{\partial L}{\partial \boldsymbol{S}}\right)_{pj} K_{jq} = \dfrac{1}{\sqrt{d_k}} \left( \dfrac{\partial L}{\partial \boldsymbol{S}} \boldsymbol{K} \right)_{pq} \label{eq:17-11-6}\end{equation}

$\eqref{eq:17-11-6}$ はすべての $(p, q)$ について成り立つので、行列形式で最終結果を得る。

\begin{equation}\dfrac{\partial L}{\partial \boldsymbol{Q}} = \dfrac{1}{\sqrt{d_k}} \dfrac{\partial L}{\partial \boldsymbol{S}} \boldsymbol{K} \label{eq:17-11-7}\end{equation}

補足:$\eqref{eq:17-11-7}$ において、スケーリング因子 $\displaystyle\dfrac{1}{\sqrt{d_k}}$ は順伝播と同様に逆伝播にも現れる。

17.12 Key の勾配

公式:$\displaystyle\dfrac{\partial L}{\partial \boldsymbol{K}} = \displaystyle\dfrac{1}{\sqrt{d_k}} \left(\displaystyle\dfrac{\partial L}{\partial \boldsymbol{S}}\right)^\top \boldsymbol{Q}$
条件:$\boldsymbol{S} = \boldsymbol{Q}\boldsymbol{K}^\top / \sqrt{d_k}$
証明

17.11 の $\eqref{eq:17-11-2}$ より、スコアの成分表示は次のようになる。

\begin{equation}S_{ij} = \dfrac{1}{\sqrt{d_k}} \displaystyle\sum_{k=0}^{d_k-1} Q_{ik} K_{jk} \label{eq:17-12-1}\end{equation}

$\eqref{eq:17-12-1}$ において、$S_{ij}$ が $K_{pq}$ に依存するのは $j = p$ かつ $k = q$ のときのみである。偏微分を計算する。

\begin{equation}\dfrac{\partial S_{ij}}{\partial K_{pq}} = \dfrac{1}{\sqrt{d_k}} Q_{iq} \delta_{jp} \label{eq:17-12-2}\end{equation}

損失 $L$ はスコア $\boldsymbol{S}$ を通じて $\boldsymbol{K}$ に依存する。連鎖律(1.26)を適用する。

\begin{equation}\dfrac{\partial L}{\partial K_{pq}} = \displaystyle\sum_{i,j} \dfrac{\partial L}{\partial S_{ij}} \dfrac{\partial S_{ij}}{\partial K_{pq}} \label{eq:17-12-3}\end{equation}

$\eqref{eq:17-12-2}$ を $\eqref{eq:17-12-3}$ に代入する。$\delta_{jp}$ により $j = p$ の項のみ残る。

\begin{equation}\dfrac{\partial L}{\partial K_{pq}} = \dfrac{1}{\sqrt{d_k}} \displaystyle\sum_{i,j} \dfrac{\partial L}{\partial S_{ij}} Q_{iq} \delta_{jp} = \dfrac{1}{\sqrt{d_k}} \displaystyle\sum_i \dfrac{\partial L}{\partial S_{ip}} Q_{iq} \label{eq:17-12-4}\end{equation}

$\eqref{eq:17-12-4}$ を行列積の成分として解釈する。$\left(\left(\displaystyle\dfrac{\partial L}{\partial \boldsymbol{S}}\right)^\top\right)_{pi} = \left(\displaystyle\dfrac{\partial L}{\partial \boldsymbol{S}}\right)_{ip}$ であるから

\begin{equation}\dfrac{\partial L}{\partial K_{pq}} = \dfrac{1}{\sqrt{d_k}} \displaystyle\sum_i \left(\left(\dfrac{\partial L}{\partial \boldsymbol{S}}\right)^\top\right)_{pi} Q_{iq} = \dfrac{1}{\sqrt{d_k}} \left( \left(\dfrac{\partial L}{\partial \boldsymbol{S}}\right)^\top \boldsymbol{Q} \right)_{pq} \label{eq:17-12-5}\end{equation}

$\eqref{eq:17-12-5}$ はすべての $(p, q)$ について成り立つので、行列形式で最終結果を得る。

\begin{equation}\dfrac{\partial L}{\partial \boldsymbol{K}} = \dfrac{1}{\sqrt{d_k}} \left(\dfrac{\partial L}{\partial \boldsymbol{S}}\right)^\top \boldsymbol{Q} \label{eq:17-12-6}\end{equation}

補足:$\eqref{eq:17-12-6}$ と $\eqref{eq:17-11-7}$ を比較すると、Query と Key の勾配は対称的な構造を持つことがわかる。これは $\boldsymbol{S} = \boldsymbol{Q}\boldsymbol{K}^\top$ において Query と Key が転置の関係にあることを反映している。

17.13 畳み込み層のフィルタ勾配

公式:$\displaystyle\dfrac{\partial L}{\partial \boldsymbol{F}} = \boldsymbol{X} \star \displaystyle\dfrac{\partial L}{\partial \boldsymbol{Y}}$
条件:$\boldsymbol{Y} = \boldsymbol{X} * \boldsymbol{F}$(畳み込み)、$\star$ は相互相関(cross-correlation)
証明

2D 畳み込みを考える。入力 $\boldsymbol{X}$ とフィルタ $\boldsymbol{F}$ の畳み込みで出力 $\boldsymbol{Y}$ が計算される。

\begin{equation}\boldsymbol{Y} = \boldsymbol{X} * \boldsymbol{F} \label{eq:17-13-1}\end{equation}

$\eqref{eq:17-13-1}$ を成分で書くと、フィルタのサイズを $H \times W$ として次のようになる。

\begin{equation}Y_{ij} = \displaystyle\sum_{m=0}^{H-1} \displaystyle\sum_{n=0}^{W-1} X_{i+m, j+n} F_{mn} \label{eq:17-13-2}\end{equation}

$\eqref{eq:17-13-2}$ において、$Y_{ij}$ が $F_{pq}$ に依存するのは $m = p$ かつ $n = q$ のときのみである。偏微分を計算する。

\begin{equation}\dfrac{\partial Y_{ij}}{\partial F_{pq}} = X_{i+p, j+q} \label{eq:17-13-3}\end{equation}

損失 $L$ は出力 $\boldsymbol{Y}$ を通じて $\boldsymbol{F}$ に依存する。連鎖律(1.26)を適用する。

\begin{equation}\dfrac{\partial L}{\partial F_{pq}} = \displaystyle\sum_{i,j} \dfrac{\partial L}{\partial Y_{ij}} \dfrac{\partial Y_{ij}}{\partial F_{pq}} \label{eq:17-13-4}\end{equation}

$\eqref{eq:17-13-3}$ を $\eqref{eq:17-13-4}$ に代入する。

\begin{equation}\dfrac{\partial L}{\partial F_{pq}} = \displaystyle\sum_{i,j} \dfrac{\partial L}{\partial Y_{ij}} X_{i+p, j+q} \label{eq:17-13-5}\end{equation}

$\eqref{eq:17-13-5}$ は入力 $\boldsymbol{X}$ と出力勾配 $\displaystyle\dfrac{\partial L}{\partial \boldsymbol{Y}}$ の相互相関(cross-correlation)である。

\begin{equation}\dfrac{\partial L}{\partial \boldsymbol{F}} = \boldsymbol{X} \star \dfrac{\partial L}{\partial \boldsymbol{Y}} \label{eq:17-13-6}\end{equation}

補足:$\eqref{eq:17-13-6}$ において、相互相関 $\star$ は畳み込み $*$ と異なりフィルタの反転を行わない。深層学習フレームワークでは、順伝播で相互相関を使用することが多い。

17.14 畳み込み層の入力勾配

公式:$\displaystyle\dfrac{\partial L}{\partial \boldsymbol{X}} = \displaystyle\dfrac{\partial L}{\partial \boldsymbol{Y}} *_{\text{full}} \text{rot}_{180}(\boldsymbol{F})$
条件:$*_{\text{full}}$ は「full」畳み込み(ゼロパディングあり)、$\text{rot}_{180}$ は180度回転
証明

17.13 の $\eqref{eq:17-13-2}$ より、畳み込みの成分表示は次のようになる。

\begin{equation}Y_{ij} = \displaystyle\sum_{m=0}^{H-1} \displaystyle\sum_{n=0}^{W-1} X_{i+m, j+n} F_{mn} \label{eq:17-14-1}\end{equation}

$\eqref{eq:17-14-1}$ において、入力の要素 $X_{pq}$ が出力 $Y_{ij}$ に寄与する条件を考える。$X_{pq} = X_{i+m, j+n}$ となるのは $i = p - m$、$j = q - n$ のときである。

出力のインデックス $(i, j)$ が有効な範囲にあるためには、$0 \leq p - m$ かつ $0 \leq q - n$ である必要がある。つまり $X_{pq}$ は複数の出力要素に寄与する可能性がある。

偏微分を計算する。$X_{pq}$ が $Y_{ij}$ に寄与するとき

\begin{equation}\dfrac{\partial Y_{ij}}{\partial X_{pq}} = F_{p-i, q-j} \label{eq:17-14-2}\end{equation}

ただし、$(p-i, q-j)$ がフィルタの有効な範囲 $[0, H-1] \times [0, W-1]$ 内にある場合のみである。

損失 $L$ は出力 $\boldsymbol{Y}$ を通じて $\boldsymbol{X}$ に依存する。連鎖律(1.26)を適用する。

\begin{equation}\dfrac{\partial L}{\partial X_{pq}} = \displaystyle\sum_{i,j} \dfrac{\partial L}{\partial Y_{ij}} \dfrac{\partial Y_{ij}}{\partial X_{pq}} \label{eq:17-14-3}\end{equation}

$\eqref{eq:17-14-2}$ を $\eqref{eq:17-14-3}$ に代入する。和は $X_{pq}$ が寄与するすべての $(i, j)$ について取る。

\begin{equation}\dfrac{\partial L}{\partial X_{pq}} = \displaystyle\sum_{i,j: (p-i, q-j) \in [0, H-1] \times [0, W-1]} \dfrac{\partial L}{\partial Y_{ij}} F_{p-i, q-j} \label{eq:17-14-4}\end{equation}

$\eqref{eq:17-14-4}$ を解釈する。変数変換 $m' = p - i$, $n' = q - j$ を行うと、$i = p - m'$, $j = q - n'$ であり、和は $m' \in [0, H-1]$, $n' \in [0, W-1]$ について取られる。

\begin{equation}\dfrac{\partial L}{\partial X_{pq}} = \displaystyle\sum_{m'=0}^{H-1} \displaystyle\sum_{n'=0}^{W-1} \dfrac{\partial L}{\partial Y_{p-m', q-n'}} F_{m', n'} \label{eq:17-14-5}\end{equation}

$\eqref{eq:17-14-5}$ は出力勾配 $\displaystyle\dfrac{\partial L}{\partial \boldsymbol{Y}}$ と 180度回転したフィルタ $\text{rot}_{180}(\boldsymbol{F})$ の「full」畳み込みに相当する。

\begin{equation}\dfrac{\partial L}{\partial \boldsymbol{X}} = \dfrac{\partial L}{\partial \boldsymbol{Y}} *_{\text{full}} \text{rot}_{180}(\boldsymbol{F}) \label{eq:17-14-6}\end{equation}

補足:$\eqref{eq:17-14-6}$ においてフィルタの180度回転は $(\text{rot}_{180}(\boldsymbol{F}))_{mn} = F_{H-1-m, W-1-n}$ で定義される。「full」畳み込みは、出力が入力より大きくなるようにゼロパディングを行う畳み込みである。

17.15 Max Pooling の勾配

公式:$\displaystyle\dfrac{\partial L}{\partial X_{ij}} = \displaystyle\dfrac{\partial L}{\partial Y_k} \cdot \mathbf{1}_{X_{ij} = \max}$
条件:$Y_k = \max_{(i,j) \in \text{pool}_k} X_{ij}$
証明

Max pooling は各プーリング領域内の最大値を取る操作である。

\begin{equation}Y_k = \max_{(i,j) \in \text{pool}_k} X_{ij} \label{eq:17-15-1}\end{equation}

$\eqref{eq:17-15-1}$ において、最大値を達成する位置を $(i^*, j^*)$ とする。

\begin{equation}(i^*, j^*) = \arg\max_{(i,j) \in \text{pool}_k} X_{ij} \label{eq:17-15-2}\end{equation}

max 関数の微分を考える。$Y_k$ は最大値を取る要素 $X_{i^*, j^*}$ にのみ依存し、他の要素には依存しない(局所的には)。したがって偏微分は次のようになる。

\begin{equation}\dfrac{\partial Y_k}{\partial X_{ij}} = \begin{cases} 1 & \text{if } (i, j) = (i^*, j^*) \\ 0 & \text{otherwise} \end{cases} \label{eq:17-15-3}\end{equation}

$\eqref{eq:17-15-3}$ は指示関数 $\mathbf{1}_{X_{ij} = \max}$ を用いて次のように書ける。

\begin{equation}\dfrac{\partial Y_k}{\partial X_{ij}} = \mathbf{1}_{X_{ij} = \max\{X_{pq}: (p,q) \in \text{pool}_k\}} \label{eq:17-15-4}\end{equation}

損失 $L$ は出力 $Y_k$ を通じて $X_{ij}$ に依存する。連鎖律(1.26)を適用する。

\begin{equation}\dfrac{\partial L}{\partial X_{ij}} = \dfrac{\partial L}{\partial Y_k} \dfrac{\partial Y_k}{\partial X_{ij}} \label{eq:17-15-5}\end{equation}

$\eqref{eq:17-15-4}$ を $\eqref{eq:17-15-5}$ に代入して最終結果を得る。

\begin{equation}\dfrac{\partial L}{\partial X_{ij}} = \dfrac{\partial L}{\partial Y_k} \cdot \mathbf{1}_{X_{ij} = \max} \label{eq:17-15-6}\end{equation}

補足:$\eqref{eq:17-15-6}$ より、勾配は最大値位置にのみ伝播し、他の位置には伝播しない。複数の要素が同じ最大値を持つ場合、実装によって異なる(最初の要素のみ、または均等に分配)。

17.16 Average Pooling の勾配

公式:$\displaystyle\dfrac{\partial L}{\partial X_{ij}} = \displaystyle\dfrac{1}{|\text{pool}|} \displaystyle\dfrac{\partial L}{\partial Y_k}$
条件:$Y_k = \displaystyle\dfrac{1}{|\text{pool}_k|} \displaystyle\sum_{(i,j) \in \text{pool}_k} X_{ij}$、$(i,j) \in \text{pool}_k$
証明

Average pooling は各プーリング領域内の平均値を取る操作である。

\begin{equation}Y_k = \dfrac{1}{|\text{pool}_k|} \displaystyle\sum_{(i,j) \in \text{pool}_k} X_{ij} \label{eq:17-16-1}\end{equation}

ここで $|\text{pool}_k|$ はプーリング領域内の要素数である。

$\eqref{eq:17-16-1}$ において、$Y_k$ はプーリング領域内のすべての要素に均等に依存する。偏微分を計算する。

\begin{equation}\dfrac{\partial Y_k}{\partial X_{ij}} = \dfrac{1}{|\text{pool}_k|} \quad \text{for all } (i, j) \in \text{pool}_k \label{eq:17-16-2}\end{equation}

損失 $L$ は出力 $Y_k$ を通じて $X_{ij}$ に依存する。連鎖律(1.26)を適用する。

\begin{equation}\dfrac{\partial L}{\partial X_{ij}} = \dfrac{\partial L}{\partial Y_k} \dfrac{\partial Y_k}{\partial X_{ij}} \label{eq:17-16-3}\end{equation}

$\eqref{eq:17-16-2}$ を $\eqref{eq:17-16-3}$ に代入して最終結果を得る。

\begin{equation}\dfrac{\partial L}{\partial X_{ij}} = \dfrac{1}{|\text{pool}_k|} \dfrac{\partial L}{\partial Y_k} \label{eq:17-16-4}\end{equation}

補足:$\eqref{eq:17-16-4}$ より、Average pooling では勾配がプーリング領域内の全要素に均等に分配される。これは 17.15 の Max pooling と対照的であり、Max pooling は勝者総取り、Average pooling は均等分配である。

17.17 埋め込み層の勾配

公式:$\displaystyle\dfrac{\partial L}{\partial \boldsymbol{E}_{i,:}} = \displaystyle\sum_{n: \text{idx}_n = i} \displaystyle\dfrac{\partial L}{\partial \boldsymbol{o}_n}$
条件:$\boldsymbol{o}_n = \boldsymbol{E}_{\text{idx}_n, :}$(lookup操作)、$\boldsymbol{E} \in \mathbb{R}^{V \times d}$ は埋め込み行列
証明

埋め込み層は lookup テーブルとして機能する。入力インデックス $\text{idx}_n$ に対応する埋め込み行列の行を出力する。

\begin{equation}\boldsymbol{o}_n = \boldsymbol{E}_{\text{idx}_n, :} \label{eq:17-17-1}\end{equation}

ここで $\boldsymbol{E} \in \mathbb{R}^{V \times d}$ は語彙サイズ $V$、埋め込み次元 $d$ の埋め込み行列である。

$\eqref{eq:17-17-1}$ において、出力 $\boldsymbol{o}_n$ は埋め込み行列の $\text{idx}_n$ 行目のコピーである。偏微分を考える。

\begin{equation}\dfrac{\partial (\boldsymbol{o}_n)_j}{\partial E_{ik}} = \begin{cases} 1 & \text{if } i = \text{idx}_n \text{ and } j = k \\ 0 & \text{otherwise} \end{cases} \label{eq:17-17-2}\end{equation}

$\eqref{eq:17-17-2}$ はクロネッカーのデルタを用いて次のように書ける。

\begin{equation}\dfrac{\partial (\boldsymbol{o}_n)_j}{\partial E_{ik}} = \delta_{i, \text{idx}_n} \delta_{jk} \label{eq:17-17-3}\end{equation}

損失 $L$ は出力 $\boldsymbol{o}_n$($n = 0, 1, \ldots, N-1$)を通じて $\boldsymbol{E}$ に依存する。連鎖律(1.26)を適用する。

\begin{equation}\dfrac{\partial L}{\partial E_{ik}} = \displaystyle\sum_n \displaystyle\sum_j \dfrac{\partial L}{\partial (\boldsymbol{o}_n)_j} \dfrac{\partial (\boldsymbol{o}_n)_j}{\partial E_{ik}} \label{eq:17-17-4}\end{equation}

$\eqref{eq:17-17-3}$ を $\eqref{eq:17-17-4}$ に代入する。$\delta_{i, \text{idx}_n}$ により $\text{idx}_n = i$ の項のみ、$\delta_{jk}$ により $j = k$ の項のみ残る。

\begin{equation}\dfrac{\partial L}{\partial E_{ik}} = \displaystyle\sum_{n: \text{idx}_n = i} \dfrac{\partial L}{\partial (\boldsymbol{o}_n)_k} \label{eq:17-17-5}\end{equation}

$\eqref{eq:17-17-5}$ を行ベクトル形式で書くと最終結果を得る。

\begin{equation}\dfrac{\partial L}{\partial \boldsymbol{E}_{i,:}} = \displaystyle\sum_{n: \text{idx}_n = i} \dfrac{\partial L}{\partial \boldsymbol{o}_n} \label{eq:17-17-6}\end{equation}

補足:$\eqref{eq:17-17-6}$ より、埋め込み行列の各行の勾配は、その行が参照されたすべての位置からの勾配の和である。バッチ内で同じインデックスが複数回参照される場合、勾配は累積される。実装では sparse gradient として効率的に処理される。

17.18 L2正則化の勾配

公式:$\displaystyle\dfrac{\partial}{\partial \boldsymbol{W}} \displaystyle\dfrac{\lambda}{2}\|\boldsymbol{W}\|_F^2 = \lambda \boldsymbol{W}$
条件:$\|\boldsymbol{W}\|_F^2 = \displaystyle\sum_{ij} W_{ij}^2$(Frobeniusノルムの2乗)
証明

L2 正則化項は重み行列の Frobenius ノルムの2乗にスケーリング係数を掛けたものである。

\begin{equation}R = \dfrac{\lambda}{2}\|\boldsymbol{W}\|_F^2 = \dfrac{\lambda}{2} \displaystyle\sum_{i,j} W_{ij}^2 \label{eq:17-18-1}\end{equation}

ここで $\lambda > 0$ は正則化の強さを制御するハイパーパラメータである。

$\eqref{eq:17-18-1}$ を $W_{pq}$ で偏微分する。和の中で $W_{pq}$ を含む項は $(i, j) = (p, q)$ の項のみである。

\begin{equation}\dfrac{\partial R}{\partial W_{pq}} = \dfrac{\partial}{\partial W_{pq}} \left( \dfrac{\lambda}{2} \displaystyle\sum_{i,j} W_{ij}^2 \right) = \dfrac{\lambda}{2} \cdot 2W_{pq} = \lambda W_{pq} \label{eq:17-18-2}\end{equation}

$\eqref{eq:17-18-2}$ はすべての $(p, q)$ について成り立つので、行列形式で最終結果を得る。

\begin{equation}\dfrac{\partial}{\partial \boldsymbol{W}} \dfrac{\lambda}{2}\|\boldsymbol{W}\|_F^2 = \lambda \boldsymbol{W} \label{eq:17-18-3}\end{equation}

補足:$\eqref{eq:17-18-3}$ より、L2 正則化(Weight Decay)を加えた全損失の勾配は $\displaystyle\dfrac{\partial L_{\text{total}}}{\partial \boldsymbol{W}} = \displaystyle\dfrac{\partial L}{\partial \boldsymbol{W}} + \lambda \boldsymbol{W}$ となる。これは重みを原点方向に引き寄せる効果があり、過学習を抑制する。
より簡潔な導出は 12.10 を参照。

17.19 L1正則化の劣勾配

公式:$\displaystyle\dfrac{\partial}{\partial \boldsymbol{W}} \lambda\|\boldsymbol{W}\|_1 = \lambda \cdot \text{sign}(\boldsymbol{W})$
条件:$\|\boldsymbol{W}\|_1 = \displaystyle\sum_{ij} |W_{ij}|$、$\text{sign}(x) = \begin{cases} 1 & x > 0 \\ 0 & x = 0 \\ -1 & x < 0 \end{cases}$(ただし $x = 0$ では $[-1, 1]$ の任意の値が劣勾配)
証明

L1 正則化項は重み行列の L1 ノルムにスケーリング係数を掛けたものである。

\begin{equation}R = \lambda\|\boldsymbol{W}\|_1 = \lambda \displaystyle\sum_{i,j} |W_{ij}| \label{eq:17-19-1}\end{equation}

絶対値関数 $|x|$ の微分を考える。$x > 0$ と $x < 0$ の場合で分ける。

\begin{equation}\dfrac{d|x|}{dx} = \begin{cases} 1 & x > 0 \\ -1 & x < 0 \end{cases} \label{eq:17-19-2}\end{equation}

$x = 0$ では絶対値関数は微分不可能である。しかし、劣勾配(subgradient)の概念を用いると、$x = 0$ での劣勾配は区間 $[-1, 1]$ の任意の値を取ることができる。

\begin{equation}\partial |x| \Big|_{x=0} = [-1, 1] \label{eq:17-19-3}\end{equation}

$\eqref{eq:17-19-2}$ と $\eqref{eq:17-19-3}$ を統一して符号関数 $\text{sign}(x)$ を用いて書く。実装では $x = 0$ のとき $\text{sign}(0) = 0$ とすることが多い。

\begin{equation}\dfrac{d|x|}{dx} = \text{sign}(x) = \begin{cases} 1 & x > 0 \\ 0 & x = 0 \\ -1 & x < 0 \end{cases} \label{eq:17-19-4}\end{equation}

$\eqref{eq:17-19-1}$ を $W_{pq}$ で偏微分する。和の中で $|W_{pq}|$ を含む項は $(i, j) = (p, q)$ の項のみである。

\begin{equation}\dfrac{\partial R}{\partial W_{pq}} = \lambda \dfrac{\partial |W_{pq}|}{\partial W_{pq}} = \lambda \cdot \text{sign}(W_{pq}) \label{eq:17-19-5}\end{equation}

$\eqref{eq:17-19-5}$ はすべての $(p, q)$ について成り立つので、行列形式で最終結果を得る。

\begin{equation}\dfrac{\partial}{\partial \boldsymbol{W}} \lambda\|\boldsymbol{W}\|_1 = \lambda \cdot \text{sign}(\boldsymbol{W}) \label{eq:17-19-6}\end{equation}

補足:$\eqref{eq:17-19-6}$ より、L1 正則化はスパース性を促進する。$W_{ij}$ が 0 に近づくと勾配の符号が変わり、ちょうど 0 で「引っかかる」効果がある。これにより、多くの重みが厳密に 0 になりやすい。
より簡潔な導出は 12.11 を参照。

17.20 Gauss分布のKLダイバージェンス

公式:$D_{\text{KL}}(\mathcal{N}(\boldsymbol{\mu}, \text{diag}(\boldsymbol{\sigma}^2)) \| \mathcal{N}(\boldsymbol{0}, \boldsymbol{I})) = \displaystyle\dfrac{1}{2}\displaystyle\sum_i (\mu_i^2 + \sigma_i^2 - 1 - \log\sigma_i^2)$
条件:$\mathcal{N}(\boldsymbol{\mu}, \text{diag}(\boldsymbol{\sigma}^2))$ は対角共分散を持つ多変量正規分布
証明

KL ダイバージェンスの定義は次のようになる。

\begin{equation}D_{\text{KL}}(p \| q) = \mathbb{E}_p[\log p - \log q] = \displaystyle\int p(x) \log \dfrac{p(x)}{q(x)} dx \label{eq:17-20-1}\end{equation}

まず1次元の場合を考える。$p = \mathcal{N}(\mu, \sigma^2)$、$q = \mathcal{N}(0, 1)$ とする。

正規分布の対数を書き下す。

\begin{equation}\log p(x) = -\dfrac{1}{2}\log(2\pi\sigma^2) - \dfrac{(x-\mu)^2}{2\sigma^2} \label{eq:17-20-2}\end{equation}

\begin{equation}\log q(x) = -\dfrac{1}{2}\log(2\pi) - \dfrac{x^2}{2} \label{eq:17-20-3}\end{equation}

$\eqref{eq:17-20-2}$ と $\eqref{eq:17-20-3}$ の差を計算する。

\begin{equation}\log p(x) - \log q(x) = -\dfrac{1}{2}\log\sigma^2 - \dfrac{(x-\mu)^2}{2\sigma^2} + \dfrac{x^2}{2} \label{eq:17-20-4}\end{equation}

$\eqref{eq:17-20-4}$ の $p$ に関する期待値を計算する。期待値の性質を用いる。

\begin{equation}\mathbb{E}_p[(x-\mu)^2] = \sigma^2 \label{eq:17-20-5}\end{equation}

\begin{equation}\mathbb{E}_p[x^2] = \text{Var}(x) + (\mathbb{E}_p[x])^2 = \sigma^2 + \mu^2 \label{eq:17-20-6}\end{equation}

$\eqref{eq:17-20-5}$ と $\eqref{eq:17-20-6}$ を $\eqref{eq:17-20-4}$ の期待値に代入する。

\begin{equation}D_{\text{KL}} = \mathbb{E}_p[\log p - \log q] = -\dfrac{1}{2}\log\sigma^2 - \dfrac{\sigma^2}{2\sigma^2} + \dfrac{\sigma^2 + \mu^2}{2} \label{eq:17-20-7}\end{equation}

$\eqref{eq:17-20-7}$ を整理する。

\begin{equation}D_{\text{KL}} = -\dfrac{1}{2}\log\sigma^2 - \dfrac{1}{2} + \dfrac{\sigma^2 + \mu^2}{2} = \dfrac{1}{2}(\mu^2 + \sigma^2 - 1 - \log\sigma^2) \label{eq:17-20-8}\end{equation}

多次元で各次元が独立な場合、KL ダイバージェンスは各次元の和となる。

\begin{equation}D_{\text{KL}}(\mathcal{N}(\boldsymbol{\mu}, \text{diag}(\boldsymbol{\sigma}^2)) \| \mathcal{N}(\boldsymbol{0}, \boldsymbol{I})) = \dfrac{1}{2}\displaystyle\sum_{i=1}^{d} (\mu_i^2 + \sigma_i^2 - 1 - \log\sigma_i^2) \label{eq:17-20-9}\end{equation}

補足:$\eqref{eq:17-20-9}$ は VAE(変分オートエンコーダ)の損失関数の正則化項として使われる。潜在変数の分布を標準正規分布に近づける効果がある。

17.21 KLダイバージェンスの平均に関する勾配

公式:$\displaystyle\dfrac{\partial D_{\text{KL}}}{\partial \mu_i} = \mu_i$
条件:$D_{\text{KL}} = \displaystyle\dfrac{1}{2}\displaystyle\sum_i (\mu_i^2 + \sigma_i^2 - 1 - \log\sigma_i^2)$
証明

17.20 の $\eqref{eq:17-20-9}$ より、KL ダイバージェンスは次のように表される。

\begin{equation}D_{\text{KL}} = \dfrac{1}{2}\displaystyle\sum_{j=1}^{d} (\mu_j^2 + \sigma_j^2 - 1 - \log\sigma_j^2) \label{eq:17-21-1}\end{equation}

$\eqref{eq:17-21-1}$ において、$\mu_i$ に関する項は $\displaystyle\dfrac{1}{2}\mu_i^2$ のみである。他の次元 $j \neq i$ の項は $\mu_i$ に依存しない。

$\eqref{eq:17-21-1}$ を $\mu_i$ で偏微分する。

\begin{equation}\dfrac{\partial D_{\text{KL}}}{\partial \mu_i} = \dfrac{\partial}{\partial \mu_i} \dfrac{1}{2}\mu_i^2 = \dfrac{1}{2} \cdot 2\mu_i = \mu_i \label{eq:17-21-2}\end{equation}

補足:$\eqref{eq:17-21-2}$ より、$\mu_i = 0$ のとき勾配は 0 になり、KL ダイバージェンスは $\mu_i$ に関して最小となる。これは事前分布 $\mathcal{N}(0, 1)$ の平均 0 に一致する。

17.22 KLダイバージェンスの標準偏差に関する勾配

公式:$\displaystyle\dfrac{\partial D_{\text{KL}}}{\partial \sigma_i} = \sigma_i - \displaystyle\dfrac{1}{\sigma_i}$
条件:$D_{\text{KL}} = \displaystyle\dfrac{1}{2}\displaystyle\sum_i (\mu_i^2 + \sigma_i^2 - 1 - \log\sigma_i^2)$
証明

17.21 の $\eqref{eq:17-21-1}$ より、KL ダイバージェンスは次のように表される。

\begin{equation}D_{\text{KL}} = \dfrac{1}{2}\displaystyle\sum_{j=1}^{d} (\mu_j^2 + \sigma_j^2 - 1 - \log\sigma_j^2) \label{eq:17-22-1}\end{equation}

$\eqref{eq:17-22-1}$ において、$\sigma_i$ に関する項は $\displaystyle\dfrac{1}{2}(\sigma_i^2 - \log\sigma_i^2)$ である。他の次元 $j \neq i$ の項は $\sigma_i$ に依存しない。

$\log\sigma_i^2 = 2\log\sigma_i$ であることに注意して、$\sigma_i$ で偏微分する。

\begin{equation}\dfrac{\partial}{\partial \sigma_i} \sigma_i^2 = 2\sigma_i \label{eq:17-22-2}\end{equation}

\begin{equation}\dfrac{\partial}{\partial \sigma_i} \log\sigma_i^2 = \dfrac{\partial}{\partial \sigma_i} (2\log\sigma_i) = \dfrac{2}{\sigma_i} \label{eq:17-22-3}\end{equation}

$\eqref{eq:17-22-2}$ と $\eqref{eq:17-22-3}$ を用いて $D_{\text{KL}}$ を $\sigma_i$ で偏微分する。

\begin{equation}\dfrac{\partial D_{\text{KL}}}{\partial \sigma_i} = \dfrac{1}{2}\left(2\sigma_i - \dfrac{2}{\sigma_i}\right) = \sigma_i - \dfrac{1}{\sigma_i} \label{eq:17-22-4}\end{equation}

補足:$\eqref{eq:17-22-4}$ より、$\sigma_i = 1$ のとき勾配は $1 - 1 = 0$ になり、KL ダイバージェンスは $\sigma_i$ に関して最小となる。これは事前分布 $\mathcal{N}(0, 1)$ の標準偏差 1 に一致する。VAE では数値安定性のため $\log\sigma^2$ を直接出力することが多く、その場合の勾配は $\displaystyle\dfrac{1}{2}(e^{\log\sigma^2} - 1)$ となる。

12.5 正則化の微分

機械学習で頻出する正則化項の勾配計算。

12.10 L2正則化(Weight Decay)

公式:$\displaystyle\dfrac{\partial}{\partial \boldsymbol{W}} \dfrac{\lambda}{2}\|\boldsymbol{W}\|_F^2 = \lambda \boldsymbol{W}$
条件:$\boldsymbol{W} \in \mathbb{R}^{M \times N}$、$\lambda > 0$ は正則化パラメータ
証明

Frobeniusノルムの2乗は $\|\boldsymbol{W}\|_F^2 = \mathrm{tr}(\boldsymbol{W}^\top \boldsymbol{W}) = \displaystyle\sum_{i,j} W_{ij}^2$ である。

\begin{equation} \dfrac{\partial}{\partial W_{kl}} \dfrac{\lambda}{2} \displaystyle\sum_{i,j} W_{ij}^2 = \dfrac{\lambda}{2} \cdot 2 W_{kl} = \lambda W_{kl} \label{eq:12-10-1} \end{equation}

全成分をまとめると行列形式で次のようになる。

\begin{equation} \dfrac{\partial}{\partial \boldsymbol{W}} \dfrac{\lambda}{2}\|\boldsymbol{W}\|_F^2 = \lambda \boldsymbol{W} \label{eq:12-10-2} \end{equation}

補足:L2正則化は勾配降下の各ステップで重みを $1 - \eta\lambda$ 倍に縮小する効果があり、weight decay とも呼ばれる。

12.11 L1正則化(劣勾配)

公式:$\displaystyle\dfrac{\partial}{\partial \boldsymbol{W}} \lambda\|\boldsymbol{W}\|_1 = \lambda \cdot \mathrm{sign}(\boldsymbol{W})$
条件:$\|\boldsymbol{W}\|_1 = \displaystyle\sum_{i,j}|W_{ij}|$、$W_{ij} = 0$ では劣勾配 $[-1, 1]$
証明

$\|\boldsymbol{W}\|_1 = \displaystyle\sum_{i,j} |W_{ij}|$ であり、各成分 $|W_{kl}|$ は他の成分と独立に微分できる。

\begin{equation} \dfrac{\partial |W_{kl}|}{\partial W_{kl}} = \begin{cases} 1 & (W_{kl} > 0) \\ -1 & (W_{kl} < 0) \\ [-1, 1] & (W_{kl} = 0) \end{cases} = \mathrm{sign}(W_{kl}) \label{eq:12-11-1} \end{equation}

ここで $W_{kl} = 0$ の場合は絶対値関数が微分不可能であるが、劣微分(subdifferential)$\partial|W_{kl}| = [-1, 1]$ が存在する。実用上は $\mathrm{sign}(0) = 0$ と定める近似が用いられる。

全成分をまとめると次のようになる。

\begin{equation} \dfrac{\partial}{\partial \boldsymbol{W}} \lambda\|\boldsymbol{W}\|_1 = \lambda \cdot \mathrm{sign}(\boldsymbol{W}) \label{eq:12-11-2} \end{equation}

補足:L1正則化はスパース解を促進する。$\mathrm{sign}(0) = 0$ と近似する代わりに、近接勾配法(proximal gradient method)の soft thresholding 演算子を用いる方法が理論的にはより適切である。

12.12 LASSO勾配(L1正則化付き回帰)

公式:$\displaystyle\dfrac{\partial}{\partial \boldsymbol{\alpha}}\left(\dfrac{1}{2}\|\boldsymbol{x} - \boldsymbol{D}\boldsymbol{\alpha}\|^2 + \lambda\|\boldsymbol{\alpha}\|_1\right) = \boldsymbol{D}^\top(\boldsymbol{D}\boldsymbol{\alpha} - \boldsymbol{x}) + \lambda \cdot \mathrm{sign}(\boldsymbol{\alpha})$
条件:$\boldsymbol{x} \in \mathbb{R}^M$, $\boldsymbol{D} \in \mathbb{R}^{M \times N}$, $\boldsymbol{\alpha} \in \mathbb{R}^N$, $\lambda > 0$。$\alpha_i = 0$ の成分では劣勾配。
証明

目的関数を $L = L_{\text{data}} + L_{\text{reg}}$ と分解する。

\begin{equation} L_{\text{data}} = \dfrac{1}{2}\|\boldsymbol{x} - \boldsymbol{D}\boldsymbol{\alpha}\|^2, \qquad L_{\text{reg}} = \lambda\|\boldsymbol{\alpha}\|_1 \label{eq:12-12-1} \end{equation}

データ項の勾配は12.9 と同様に計算できる。$\boldsymbol{r} = \boldsymbol{D}\boldsymbol{\alpha} - \boldsymbol{x}$ とおくと次のようになる。

\begin{equation} \dfrac{\partial L_{\text{data}}}{\partial \boldsymbol{\alpha}} = \boldsymbol{D}^\top(\boldsymbol{D}\boldsymbol{\alpha} - \boldsymbol{x}) \label{eq:12-12-2} \end{equation}

正則化項の劣勾配は12.11 より次のようになる。

\begin{equation} \dfrac{\partial L_{\text{reg}}}{\partial \boldsymbol{\alpha}} = \lambda \cdot \mathrm{sign}(\boldsymbol{\alpha}) \label{eq:12-12-3} \end{equation}

$\eqref{eq:12-12-2}$ と $\eqref{eq:12-12-3}$ を合わせて公式を得る。

\begin{equation} \dfrac{\partial L}{\partial \boldsymbol{\alpha}} = \boldsymbol{D}^\top(\boldsymbol{D}\boldsymbol{\alpha} - \boldsymbol{x}) + \lambda \cdot \mathrm{sign}(\boldsymbol{\alpha}) \label{eq:12-12-4} \end{equation}

補足:LASSO(Least Absolute Shrinkage and Selection Operator)はスパースコーディング・圧縮センシングの基盤である。非微分点の扱いには ISTA(Iterative Shrinkage-Thresholding Algorithm)やその加速版 FISTA が用いられる。

12.6 応用:画像再構成の正則化

不良設定問題に対する正則化手法の勾配計算。医用画像再構成、信号処理で使用される。

12.13 Tikhonov正則化の勾配

公式:$\displaystyle\dfrac{\partial J}{\partial \boldsymbol{x}} = 2\boldsymbol{A}^\top(\boldsymbol{A}\boldsymbol{x} - \boldsymbol{y}) + 2\lambda\boldsymbol{L}^\top\boldsymbol{L}\boldsymbol{x}$
条件:$J(\boldsymbol{x}) = \|\boldsymbol{A}\boldsymbol{x} - \boldsymbol{y}\|^2 + \lambda\|\boldsymbol{L}\boldsymbol{x}\|^2$
証明

第1項 $\|\boldsymbol{A}\boldsymbol{x} - \boldsymbol{y}\|^2 = (\boldsymbol{A}\boldsymbol{x} - \boldsymbol{y})^\top(\boldsymbol{A}\boldsymbol{x} - \boldsymbol{y})$ の勾配を計算する。

\begin{equation} \dfrac{\partial}{\partial \boldsymbol{x}}\|\boldsymbol{A}\boldsymbol{x} - \boldsymbol{y}\|^2 = 2\boldsymbol{A}^\top(\boldsymbol{A}\boldsymbol{x} - \boldsymbol{y}) \label{eq:12-13-1} \end{equation}

第2項 $\lambda\|\boldsymbol{L}\boldsymbol{x}\|^2 = \lambda\boldsymbol{x}^\top\boldsymbol{L}^\top\boldsymbol{L}\boldsymbol{x}$ の勾配を計算する。

\begin{equation} \dfrac{\partial}{\partial \boldsymbol{x}}\lambda\|\boldsymbol{L}\boldsymbol{x}\|^2 = 2\lambda\boldsymbol{L}^\top\boldsymbol{L}\boldsymbol{x} \label{eq:12-13-2} \end{equation}

$\eqref{eq:12-13-1}$ と $\eqref{eq:12-13-2}$ を合わせて公式を得る。

補足:医用画像再構成では $\boldsymbol{A}$ はRadon変換(CT)またはFourier変換(MRI)、$\boldsymbol{L}$ は微分オペレータなど。

12.14 Tikhonov正則化解

公式:$\boldsymbol{x}^* = (\boldsymbol{A}^\top\boldsymbol{A} + \lambda\boldsymbol{L}^\top\boldsymbol{L})^{-1}\boldsymbol{A}^\top\boldsymbol{y}$
条件:$\boldsymbol{A}^\top\boldsymbol{A} + \lambda\boldsymbol{L}^\top\boldsymbol{L}$ が正則
証明

最適性条件 $\nabla J = 0$ より次式を得る。

\begin{equation} 2\boldsymbol{A}^\top(\boldsymbol{A}\boldsymbol{x} - \boldsymbol{y}) + 2\lambda\boldsymbol{L}^\top\boldsymbol{L}\boldsymbol{x} = \boldsymbol{0} \label{eq:12-14-1} \end{equation}

$\eqref{eq:12-14-1}$ を整理すると次式を得る。

\begin{equation} (\boldsymbol{A}^\top\boldsymbol{A} + \lambda\boldsymbol{L}^\top\boldsymbol{L})\boldsymbol{x} = \boldsymbol{A}^\top\boldsymbol{y} \label{eq:12-14-2} \end{equation}

$\lambda > 0$ のとき $\boldsymbol{A}^\top\boldsymbol{A} + \lambda\boldsymbol{L}^\top\boldsymbol{L}$ は正定値(または半正定値+正定値)で正則となり、解を得る。

\begin{equation} \boldsymbol{x}^* = (\boldsymbol{A}^\top\boldsymbol{A} + \lambda\boldsymbol{L}^\top\boldsymbol{L})^{-1}\boldsymbol{A}^\top\boldsymbol{y} \label{eq:12-14-3} \end{equation}

補足:$\boldsymbol{L} = \boldsymbol{I}$ のとき標準形Tikhonov正則化。$\boldsymbol{A}$ が不良条件でも $\lambda$ により解が安定化される。

12.15 全変動正則化の劣勾配

公式:$\displaystyle\dfrac{\partial}{\partial \boldsymbol{x}}\text{TV}(\boldsymbol{x}) = -\text{div}\left(\displaystyle\dfrac{\nabla \boldsymbol{x}}{|\nabla \boldsymbol{x}|}\right)$
条件:$\text{TV}(\boldsymbol{x}) = \displaystyle\int |\nabla \boldsymbol{x}| \, d\Omega$(連続形式)
証明

等方性全変動は次式で定義される。

\begin{equation} \text{TV}(\boldsymbol{x}) = \displaystyle\int_\Omega |\nabla \boldsymbol{x}| \, d\Omega = \displaystyle\int_\Omega \sqrt{|\partial_1 x|^2 + |\partial_2 x|^2} \, d\Omega \label{eq:12-15-1} \end{equation}

変分法により、$\boldsymbol{x}$ に摂動 $\boldsymbol{x} + \varepsilon\boldsymbol{\phi}$ を加え $\varepsilon$ で微分し $\varepsilon = 0$ で評価する。

\begin{equation} \dfrac{d}{d\varepsilon}\text{TV}(\boldsymbol{x} + \varepsilon\boldsymbol{\phi})\bigg|_{\varepsilon=0} = \displaystyle\int_\Omega \dfrac{\nabla \boldsymbol{x} \cdot \nabla \boldsymbol{\phi}}{|\nabla \boldsymbol{x}|} \, d\Omega \label{eq:12-15-2} \end{equation}

部分積分(Greenの定理)を適用し境界項が消えると次式を得る。

\begin{equation} = -\displaystyle\int_\Omega \text{div}\left(\dfrac{\nabla \boldsymbol{x}}{|\nabla \boldsymbol{x}|}\right) \boldsymbol{\phi} \, d\Omega \label{eq:12-15-3} \end{equation}

よって $\delta \text{TV} / \delta \boldsymbol{x} = -\text{div}(\nabla \boldsymbol{x} / |\nabla \boldsymbol{x}|)$。

補足:$|\nabla \boldsymbol{x}| = 0$ で非微分。実用上は $|\nabla \boldsymbol{x}|_\varepsilon = \sqrt{|\nabla \boldsymbol{x}|^2 + \varepsilon^2}$ で滑らかに近似する。