SVD・Fisher情報・強化学習・NLP

高度な機械学習手法におけるベクトル微分/行列微分の応用。

表記規約
本ページの公式は分母レイアウト(denominator layout)に基づく。詳細はレイアウト規約を参照。

18.1 特異値分解(SVD)と行列微分

特異値分解を含む損失関数の最適化において行列微分が必要となる。低ランク近似、行列補完、主成分分析などで使用。

18.1 低ランク近似の勾配

公式:$\displaystyle\frac{\partial}{\partial \boldsymbol{X}} \|\boldsymbol{X} - \boldsymbol{U}\boldsymbol{V}^\top\|_F^2 = 2(\boldsymbol{X} - \boldsymbol{U}\boldsymbol{V}^\top)$
条件:$\boldsymbol{X} \in \mathbb{R}^{m \times n}$、$\boldsymbol{U} \in \mathbb{R}^{m \times k}$、$\boldsymbol{V} \in \mathbb{R}^{n \times k}$($k$ はランク)
証明

$\boldsymbol{A} = \boldsymbol{U}\boldsymbol{V}^\top$ とおくと、12.6 より直接得られる。

\begin{equation}\frac{\partial}{\partial \boldsymbol{X}} \|\boldsymbol{X} - \boldsymbol{A}\|_F^2 = 2(\boldsymbol{X} - \boldsymbol{A}) = 2(\boldsymbol{X} - \boldsymbol{U}\boldsymbol{V}^\top) \label{eq:18-1-1}\end{equation}

補足:行列補完問題では、観測された要素についてのみこの勾配を計算する。Netflix問題などの推薦システムで使用。

18.2 因子行列 $\boldsymbol{U}$ の勾配

公式:$\displaystyle\frac{\partial}{\partial \boldsymbol{U}} \|\boldsymbol{X} - \boldsymbol{U}\boldsymbol{V}^\top\|_F^2 = -2(\boldsymbol{X} - \boldsymbol{U}\boldsymbol{V}^\top)\boldsymbol{V}$
条件:$\boldsymbol{X}$ は定数行列
証明

$\boldsymbol{E} = \boldsymbol{X} - \boldsymbol{U}\boldsymbol{V}^\top$ とおく。損失関数は $L = \|\boldsymbol{E}\|_F^2 = \text{tr}(\boldsymbol{E}^\top\boldsymbol{E})$ である。

\begin{equation}L = \text{tr}((\boldsymbol{X} - \boldsymbol{U}\boldsymbol{V}^\top)^\top(\boldsymbol{X} - \boldsymbol{U}\boldsymbol{V}^\top)) \label{eq:18-2-1}\end{equation}

$\eqref{eq:18-2-1}$ を展開する。

\begin{equation}L = \text{tr}(\boldsymbol{X}^\top\boldsymbol{X}) - 2\text{tr}(\boldsymbol{X}^\top\boldsymbol{U}\boldsymbol{V}^\top) + \text{tr}(\boldsymbol{V}\boldsymbol{U}^\top\boldsymbol{U}\boldsymbol{V}^\top) \label{eq:18-2-2}\end{equation}

$\boldsymbol{U}$ に関する項のみを考える。第2項は $\text{tr}(\boldsymbol{V}^\top\boldsymbol{X}^\top\boldsymbol{U})$(トレースの巡回性)。

3.1 より $\displaystyle\frac{\partial}{\partial \boldsymbol{U}}\text{tr}(\boldsymbol{V}^\top\boldsymbol{X}^\top\boldsymbol{U}) = \boldsymbol{X}\boldsymbol{V}$。

第3項は 3.5 の形式で、$\displaystyle\frac{\partial}{\partial \boldsymbol{U}}\text{tr}(\boldsymbol{V}\boldsymbol{U}^\top\boldsymbol{U}\boldsymbol{V}^\top) = 2\boldsymbol{U}\boldsymbol{V}^\top\boldsymbol{V}$。

\begin{equation}\frac{\partial L}{\partial \boldsymbol{U}} = -2\boldsymbol{X}\boldsymbol{V} + 2\boldsymbol{U}\boldsymbol{V}^\top\boldsymbol{V} = -2(\boldsymbol{X} - \boldsymbol{U}\boldsymbol{V}^\top)\boldsymbol{V} \label{eq:18-2-3}\end{equation}

補足:最適性条件 $\nabla_{\boldsymbol{U}} L = 0$ より $\boldsymbol{U} = \boldsymbol{X}\boldsymbol{V}(\boldsymbol{V}^\top\boldsymbol{V})^{-1}$。$\boldsymbol{V}$ が正規直交なら $\boldsymbol{U} = \boldsymbol{X}\boldsymbol{V}$。

18.3 因子行列 $\boldsymbol{V}$ の勾配

公式:$\displaystyle\frac{\partial}{\partial \boldsymbol{V}} \|\boldsymbol{X} - \boldsymbol{U}\boldsymbol{V}^\top\|_F^2 = -2(\boldsymbol{X} - \boldsymbol{U}\boldsymbol{V}^\top)^\top\boldsymbol{U}$
条件:$\boldsymbol{X}$ は定数行列
証明

18.2 と同様の手順で導出。$\boldsymbol{V}$ に関する微分を計算する。

\begin{equation}\frac{\partial L}{\partial \boldsymbol{V}} = -2\boldsymbol{X}^\top\boldsymbol{U} + 2\boldsymbol{V}\boldsymbol{U}^\top\boldsymbol{U} = -2(\boldsymbol{X}^\top - \boldsymbol{V}\boldsymbol{U}^\top)\boldsymbol{U} \label{eq:18-3-1}\end{equation}

$(\boldsymbol{X}^\top - \boldsymbol{V}\boldsymbol{U}^\top) = (\boldsymbol{X} - \boldsymbol{U}\boldsymbol{V}^\top)^\top$ より公式を得る。

補足:交互最小二乗法(ALS)では $\boldsymbol{U}$ と $\boldsymbol{V}$ を交互に更新する。各ステップは凸最適化問題となり、収束が保証される。

18.4 核ノルムの劣勾配

公式:$\partial \|\boldsymbol{X}\|_* = \{\boldsymbol{U}\boldsymbol{V}^\top + \boldsymbol{W} : \boldsymbol{U}^\top\boldsymbol{W} = \boldsymbol{0}, \boldsymbol{W}\boldsymbol{V} = \boldsymbol{0}, \|\boldsymbol{W}\|_2 \leq 1\}$
条件:$\boldsymbol{X} = \boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{V}^\top$ がSVD、$\|\cdot\|_*$ は核ノルム(特異値の和)、$\|\cdot\|_2$ はスペクトルノルム
証明

核ノルムの定義は $\|\boldsymbol{X}\|_* = \sum_{i} \sigma_i = \text{tr}(\boldsymbol{\Sigma})$ である。

$\boldsymbol{X}$ がフルランクのとき、劣勾配は一意で $\boldsymbol{U}\boldsymbol{V}^\top$ となる。

ランク落ちの場合、零空間への射影成分 $\boldsymbol{W}$ の自由度がある。

補足:核ノルム最小化は行列補完問題の凸緩和として使われる。近接勾配法で解く際にこの劣勾配を使用。

18.5 Fisher情報行列

パラメータ推定の精度を表すFisher情報行列は、対数尤度の2階微分から計算される。自然勾配法、信頼領域法で使用。

18.5 Fisher情報行列の定義

公式:$\boldsymbol{F}(\boldsymbol{\theta}) = \mathbb{E}\left[\nabla_{\boldsymbol{\theta}} \log p(\boldsymbol{x}|\boldsymbol{\theta}) \nabla_{\boldsymbol{\theta}} \log p(\boldsymbol{x}|\boldsymbol{\theta})^\top\right]$
条件:$p(\boldsymbol{x}|\boldsymbol{\theta})$ は確率密度関数、$\boldsymbol{\theta}$ はパラメータベクトル
証明

対数尤度のスコア関数を $\boldsymbol{s}(\boldsymbol{\theta}) = \nabla_{\boldsymbol{\theta}} \log p(\boldsymbol{x}|\boldsymbol{\theta})$ と定義する。

\begin{equation}\boldsymbol{F}(\boldsymbol{\theta}) = \mathbb{E}[\boldsymbol{s}(\boldsymbol{\theta})\boldsymbol{s}(\boldsymbol{\theta})^\top] = \text{Cov}(\boldsymbol{s}(\boldsymbol{\theta})) \label{eq:18-5-1}\end{equation}

$\mathbb{E}[\boldsymbol{s}(\boldsymbol{\theta})] = \boldsymbol{0}$ であることを示す。

\begin{equation}\mathbb{E}[\boldsymbol{s}(\boldsymbol{\theta})] = \int \frac{\nabla_{\boldsymbol{\theta}} p(\boldsymbol{x}|\boldsymbol{\theta})}{p(\boldsymbol{x}|\boldsymbol{\theta})} p(\boldsymbol{x}|\boldsymbol{\theta}) d\boldsymbol{x} = \nabla_{\boldsymbol{\theta}} \int p(\boldsymbol{x}|\boldsymbol{\theta}) d\boldsymbol{x} = \nabla_{\boldsymbol{\theta}} 1 = \boldsymbol{0} \label{eq:18-5-2}\end{equation}

補足:Fisher情報行列は正半定値である。対角成分 $F_{ii}$ はパラメータ $\theta_i$ に関する情報量を表す。

18.6 Fisher情報とHessianの関係

公式:$\boldsymbol{F}(\boldsymbol{\theta}) = -\mathbb{E}\left[\nabla_{\boldsymbol{\theta}}^2 \log p(\boldsymbol{x}|\boldsymbol{\theta})\right]$
条件:正則条件(微分と積分の交換可能性)が成立
証明

$\nabla_{\boldsymbol{\theta}} \log p = \displaystyle\frac{\nabla_{\boldsymbol{\theta}} p}{p}$ を再度微分する。

\begin{equation}\nabla_{\boldsymbol{\theta}}^2 \log p = \frac{\nabla_{\boldsymbol{\theta}}^2 p}{p} - \frac{\nabla_{\boldsymbol{\theta}} p (\nabla_{\boldsymbol{\theta}} p)^\top}{p^2} \label{eq:18-6-1}\end{equation}

期待値を取ると、第1項は $\nabla_{\boldsymbol{\theta}}^2 \int p \, d\boldsymbol{x} = \boldsymbol{0}$。

\begin{equation}\mathbb{E}[\nabla_{\boldsymbol{\theta}}^2 \log p] = -\mathbb{E}\left[\frac{\nabla_{\boldsymbol{\theta}} p (\nabla_{\boldsymbol{\theta}} p)^\top}{p^2}\right] = -\boldsymbol{F}(\boldsymbol{\theta}) \label{eq:18-6-2}\end{equation}

補足:この関係により、Fisher情報行列は対数尤度のHessianの期待値の符号反転として計算できる。ニュートン法との接続点となる。

18.7 自然勾配

公式:$\tilde{\nabla}_{\boldsymbol{\theta}} L = \boldsymbol{F}(\boldsymbol{\theta})^{-1} \nabla_{\boldsymbol{\theta}} L$
条件:$L$ は損失関数、$\boldsymbol{F}(\boldsymbol{\theta})$ は正則
証明

パラメータ空間がRiemann多様体であると考える。Fisher情報行列は計量テンソルとなる。

ユークリッド勾配 $\nabla_{\boldsymbol{\theta}} L$ を計量 $\boldsymbol{F}$ で変換して自然勾配を得る。

\begin{equation}\tilde{\nabla}_{\boldsymbol{\theta}} L = \boldsymbol{F}(\boldsymbol{\theta})^{-1} \nabla_{\boldsymbol{\theta}} L \label{eq:18-7-1}\end{equation}

これはKL距離を制約とした最急降下方向に対応する。

補足:自然勾配法はパラメータの再パラメータ化に対して不変である。TRPO、PPOなどの強化学習アルゴリズムの理論的基盤。

18.8 Cramér-Raoの下界

公式:$\text{Var}(\hat{\theta}_i) \geq [\boldsymbol{F}(\boldsymbol{\theta})^{-1}]_{ii}$
条件:$\hat{\boldsymbol{\theta}}$ は不偏推定量
証明

不偏推定量の定義より $\mathbb{E}[\hat{\boldsymbol{\theta}}] = \boldsymbol{\theta}$。両辺を $\boldsymbol{\theta}$ で微分する。

\begin{equation}\nabla_{\boldsymbol{\theta}} \mathbb{E}[\hat{\boldsymbol{\theta}}] = \boldsymbol{I} \label{eq:18-8-1}\end{equation}

Cauchy-Schwarzの不等式を適用して下界を得る。

\begin{equation}\text{Cov}(\hat{\boldsymbol{\theta}}) \succeq \boldsymbol{F}(\boldsymbol{\theta})^{-1} \label{eq:18-8-2}\end{equation}

補足:この下界を達成する推定量を有効推定量という。最尤推定量は漸近的に有効である。

18.9 強化学習と方策勾配

強化学習における方策勾配法では、期待累積報酬の勾配を推定する。行列微分は方策ネットワークの勾配計算に使用。

18.9 方策勾配定理

公式:$\nabla_{\boldsymbol{\theta}} J(\boldsymbol{\theta}) = \mathbb{E}_{\pi_{\boldsymbol{\theta}}}\left[\nabla_{\boldsymbol{\theta}} \log \pi_{\boldsymbol{\theta}}(a|s) \cdot Q^{\pi_{\boldsymbol{\theta}}}(s, a)\right]$
条件:$J(\boldsymbol{\theta})$ は期待累積報酬、$\pi_{\boldsymbol{\theta}}(a|s)$ は方策、$Q^{\pi}(s, a)$ は行動価値関数
証明

期待累積報酬の定義は次の通り。

\begin{equation}J(\boldsymbol{\theta}) = \mathbb{E}_{\pi_{\boldsymbol{\theta}}}\left[\sum_{t=0}^{\infty} \gamma^t r_t\right] \label{eq:18-9-1}\end{equation}

状態分布 $d^{\pi}(s)$ を用いて書き直す。

\begin{equation}J(\boldsymbol{\theta}) = \sum_s d^{\pi}(s) \sum_a \pi_{\boldsymbol{\theta}}(a|s) Q^{\pi}(s, a) \label{eq:18-9-2}\end{equation}

$\boldsymbol{\theta}$ で微分する。$\nabla_{\boldsymbol{\theta}} \pi = \pi \nabla_{\boldsymbol{\theta}} \log \pi$ を使う。

\begin{equation}\nabla_{\boldsymbol{\theta}} J = \sum_s d^{\pi}(s) \sum_a \pi_{\boldsymbol{\theta}}(a|s) \nabla_{\boldsymbol{\theta}} \log \pi_{\boldsymbol{\theta}}(a|s) Q^{\pi}(s, a) \label{eq:18-9-3}\end{equation}

補足:REINFORCE アルゴリズムはこの勾配をモンテカルロサンプリングで推定する。分散低減のためベースラインを使用。

18.10 ベースラインによる分散低減

公式:$\nabla_{\boldsymbol{\theta}} J(\boldsymbol{\theta}) = \mathbb{E}_{\pi_{\boldsymbol{\theta}}}\left[\nabla_{\boldsymbol{\theta}} \log \pi_{\boldsymbol{\theta}}(a|s) \cdot (Q^{\pi}(s, a) - b(s))\right]$
条件:$b(s)$ は状態 $s$ のみに依存するベースライン関数
証明

ベースライン項の期待値がゼロになることを示す。

\begin{equation}\mathbb{E}_{a \sim \pi}\left[\nabla_{\boldsymbol{\theta}} \log \pi_{\boldsymbol{\theta}}(a|s) \cdot b(s)\right] = b(s) \sum_a \nabla_{\boldsymbol{\theta}} \pi_{\boldsymbol{\theta}}(a|s) \label{eq:18-10-1}\end{equation}

$\sum_a \pi_{\boldsymbol{\theta}}(a|s) = 1$ より $\sum_a \nabla_{\boldsymbol{\theta}} \pi_{\boldsymbol{\theta}}(a|s) = 0$。

よって $\eqref{eq:18-10-1}$ はゼロとなり、ベースラインは勾配のバイアスを変えない。

補足:最適ベースラインは $b^*(s) = \displaystyle\frac{\mathbb{E}[\|\nabla \log \pi\|^2 Q]}{\mathbb{E}[\|\nabla \log \pi\|^2]}$。実用上は価値関数 $V(s)$ がよく使われる。

18.11 Actor-Criticの勾配

公式:$\nabla_{\boldsymbol{\theta}} J = \mathbb{E}\left[\nabla_{\boldsymbol{\theta}} \log \pi_{\boldsymbol{\theta}}(a|s) \cdot A^{\pi}(s, a)\right]$
条件:$A^{\pi}(s, a) = Q^{\pi}(s, a) - V^{\pi}(s)$ はアドバンテージ関数
証明

18.10 でベースライン $b(s) = V^{\pi}(s)$ を使用すると得られる。

\begin{equation}\nabla_{\boldsymbol{\theta}} J = \mathbb{E}\left[\nabla_{\boldsymbol{\theta}} \log \pi_{\boldsymbol{\theta}}(a|s) \cdot (Q^{\pi}(s, a) - V^{\pi}(s))\right] \label{eq:18-11-1}\end{equation}

$A^{\pi}(s, a) = Q^{\pi}(s, a) - V^{\pi}(s)$ をアドバンテージ関数という。

補足:Actor は方策 $\pi_{\boldsymbol{\theta}}$、Critic は価値関数 $V_{\boldsymbol{w}}$ を学習する。A2C、A3C、PPO などで使用。

18.12 PPOのクリップ目的関数

公式:$L^{\text{CLIP}}(\boldsymbol{\theta}) = \mathbb{E}\left[\min\left(r_t(\boldsymbol{\theta}) A_t, \text{clip}(r_t(\boldsymbol{\theta}), 1-\epsilon, 1+\epsilon) A_t\right)\right]$
条件:$r_t(\boldsymbol{\theta}) = \displaystyle\frac{\pi_{\boldsymbol{\theta}}(a_t|s_t)}{\pi_{\boldsymbol{\theta}_{\text{old}}}(a_t|s_t)}$ は確率比、$\epsilon \approx 0.2$
証明

重要度サンプリングにより、古い方策のサンプルで新しい方策の勾配を推定する。

\begin{equation}\nabla_{\boldsymbol{\theta}} J = \mathbb{E}_{\pi_{\text{old}}}\left[r_t(\boldsymbol{\theta}) \nabla_{\boldsymbol{\theta}} \log \pi_{\boldsymbol{\theta}}(a_t|s_t) A_t\right] \label{eq:18-12-1}\end{equation}

クリップ関数により $r_t$ が $[1-\epsilon, 1+\epsilon]$ 外に出ると勾配が停止し、方策の急激な変化を防ぐ。

補足:PPO(Proximal Policy Optimization)は実装が簡単で安定した学習が可能。OpenAI Fiveなどで成功。

18.13 自然言語処理とAttention

TransformerのAttention機構における行列微分。ソフトマックス、マルチヘッド注意の勾配計算。

18.13 Attention重みの勾配

公式:$\displaystyle\frac{\partial}{\partial \boldsymbol{Q}} \text{Attention}(\boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V}) = \displaystyle\frac{1}{\sqrt{d_k}} (\boldsymbol{I} - \boldsymbol{A})\text{diag}(\boldsymbol{A}) \cdot \displaystyle\frac{\partial L}{\partial \boldsymbol{O}} \cdot \boldsymbol{V}^\top \cdot \boldsymbol{K}$
条件:$\boldsymbol{A} = \text{softmax}(\boldsymbol{Q}\boldsymbol{K}^\top / \sqrt{d_k})$、$\boldsymbol{O} = \boldsymbol{A}\boldsymbol{V}$ は出力
証明

Scaled Dot-Product Attention の定義を確認する。

\begin{equation}\boldsymbol{O} = \text{softmax}\left(\frac{\boldsymbol{Q}\boldsymbol{K}^\top}{\sqrt{d_k}}\right)\boldsymbol{V} = \boldsymbol{A}\boldsymbol{V} \label{eq:18-13-1}\end{equation}

連鎖律を適用し、ソフトマックスのJacobian(6.2)を使用する。

\begin{equation}\frac{\partial L}{\partial \boldsymbol{Q}} = \frac{\partial L}{\partial \boldsymbol{O}} \cdot \frac{\partial \boldsymbol{O}}{\partial \boldsymbol{A}} \cdot \frac{\partial \boldsymbol{A}}{\partial \boldsymbol{S}} \cdot \frac{\partial \boldsymbol{S}}{\partial \boldsymbol{Q}} \label{eq:18-13-2}\end{equation}

ここで $\boldsymbol{S} = \boldsymbol{Q}\boldsymbol{K}^\top / \sqrt{d_k}$ はスコア行列。

補足:実装ではFlash Attentionなどのメモリ効率的な手法が使われ、$\boldsymbol{A}$ を明示的に保存せずに勾配を計算する。

18.14 マルチヘッドAttentionの勾配

公式:$\displaystyle\frac{\partial L}{\partial \boldsymbol{W}_i^Q} = \boldsymbol{X}^\top \displaystyle\frac{\partial L}{\partial \boldsymbol{Q}_i}$
条件:$\boldsymbol{Q}_i = \boldsymbol{X}\boldsymbol{W}_i^Q$ は第 $i$ ヘッドのQuery
証明

マルチヘッドAttentionの出力は各ヘッドの連結と出力射影からなる。

\begin{equation}\text{MultiHead}(\boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V}) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\boldsymbol{W}^O \label{eq:18-14-1}\end{equation}

各ヘッドは $\text{head}_i = \text{Attention}(\boldsymbol{X}\boldsymbol{W}_i^Q, \boldsymbol{X}\boldsymbol{W}_i^K, \boldsymbol{X}\boldsymbol{W}_i^V)$。

$\boldsymbol{W}_i^Q$ に対する勾配は 3.1 の形式で得られる。

補足:典型的なTransformerでは $h = 8$ または $h = 12$ ヘッド。$d_k = d_v = d_{\text{model}} / h$。

18.15 Layer Normalizationの勾配

公式:$\displaystyle\frac{\partial L}{\partial \boldsymbol{x}} = \displaystyle\frac{\gamma}{\sigma}\left(\displaystyle\frac{\partial L}{\partial \hat{\boldsymbol{x}}} - \displaystyle\frac{1}{d}\mathbf{1}\mathbf{1}^\top\displaystyle\frac{\partial L}{\partial \hat{\boldsymbol{x}}} - \displaystyle\frac{\hat{\boldsymbol{x}}}{d}\hat{\boldsymbol{x}}^\top\displaystyle\frac{\partial L}{\partial \hat{\boldsymbol{x}}}\right)$
条件:$\hat{\boldsymbol{x}} = (\boldsymbol{x} - \mu\mathbf{1})/\sigma$、$\mu = \displaystyle\frac{1}{d}\sum_i x_i$、$\sigma = \sqrt{\displaystyle\frac{1}{d}\sum_i (x_i - \mu)^2 + \epsilon}$
証明

Layer Normalizationは特徴次元に沿って正規化する。

\begin{equation}\text{LayerNorm}(\boldsymbol{x}) = \gamma \odot \hat{\boldsymbol{x}} + \beta \label{eq:18-15-1}\end{equation}

$\mu$ と $\sigma$ が $\boldsymbol{x}$ の全成分に依存するため、勾配計算は複雑になる。

17.5 と類似の手順で導出。

補足:TransformerではBatch Normalizationの代わりにLayer Normalizationを使用。シーケンス長が可変でもバッチ統計に依存しない。

18.16 埋め込み層の勾配

公式:$\displaystyle\frac{\partial L}{\partial \boldsymbol{E}_{[i,:]}} = \displaystyle\frac{\partial L}{\partial \boldsymbol{e}}$ ($i$ 番目のトークンが選択されたとき)
条件:$\boldsymbol{E} \in \mathbb{R}^{V \times d}$ は埋め込み行列、$V$ は語彙サイズ、$d$ は埋め込み次元
証明

埋め込み層はone-hotベクトル $\boldsymbol{o}_i$ との行列積である。

\begin{equation}\boldsymbol{e} = \boldsymbol{E}^\top \boldsymbol{o}_i = \boldsymbol{E}_{[i,:]} \label{eq:18-16-1}\end{equation}

勾配は選択された行にのみ伝播する。

\begin{equation}\frac{\partial L}{\partial \boldsymbol{E}_{[j,:]}} = \begin{cases} \frac{\partial L}{\partial \boldsymbol{e}} & \text{if } j = i \\ \boldsymbol{0} & \text{otherwise} \end{cases} \label{eq:18-16-2}\end{equation}

補足:同一トークンがバッチ内で複数回出現する場合、勾配は加算される。スパース勾配更新により効率化可能。

18.17 交差エントロピー損失の勾配(ロジット)

公式:$\displaystyle\frac{\partial L}{\partial \boldsymbol{z}} = \text{softmax}(\boldsymbol{z}) - \boldsymbol{y}$
条件:$L = -\sum_i y_i \log(\text{softmax}(\boldsymbol{z})_i)$、$\boldsymbol{y}$ はone-hotラベル
証明

$\boldsymbol{p} = \text{softmax}(\boldsymbol{z})$ とおく。交差エントロピーは $L = -\sum_i y_i \log p_i$。

連鎖律を適用する。

\begin{equation}\frac{\partial L}{\partial z_j} = \sum_i \frac{\partial L}{\partial p_i} \frac{\partial p_i}{\partial z_j} \label{eq:18-17-1}\end{equation}

$\displaystyle\frac{\partial L}{\partial p_i} = -\displaystyle\frac{y_i}{p_i}$ と softmax の Jacobian を代入して計算すると $p_j - y_j$ を得る。

補足:この勾配は非常にシンプルで、予測確率とターゲットの差となる。言語モデルの次トークン予測で使用。

18.18 その他の高度なトピック

行列微分の追加の応用例。

18.18 ガウス過程の対数周辺尤度勾配

公式:$\displaystyle\frac{\partial}{\partial \theta} \log p(\boldsymbol{y}|\boldsymbol{X}, \theta) = \displaystyle\frac{1}{2}\text{tr}\left((\boldsymbol{\alpha}\boldsymbol{\alpha}^\top - \boldsymbol{K}^{-1})\displaystyle\frac{\partial \boldsymbol{K}}{\partial \theta}\right)$
条件:$\boldsymbol{\alpha} = \boldsymbol{K}^{-1}\boldsymbol{y}$、$\boldsymbol{K}$ はカーネル行列、$\theta$ はハイパーパラメータ
証明

ガウス過程の対数周辺尤度は次の通り。

\begin{equation}\log p(\boldsymbol{y}|\boldsymbol{X}, \theta) = -\frac{1}{2}\boldsymbol{y}^\top\boldsymbol{K}^{-1}\boldsymbol{y} - \frac{1}{2}\log|\boldsymbol{K}| - \frac{n}{2}\log(2\pi) \label{eq:18-18-1}\end{equation}

7.28.1 を使って微分する。

補足:ガウス過程回帰のハイパーパラメータ最適化に使用。計算量は $O(n^3)$ だが、低ランク近似で削減可能。

18.19 独立成分分析(ICA)の勾配

公式:$\nabla_{\boldsymbol{W}} J = \left(\boldsymbol{I} - \mathbb{E}[\boldsymbol{g}(\boldsymbol{y})\boldsymbol{y}^\top]\right)\boldsymbol{W}$
条件:$\boldsymbol{y} = \boldsymbol{W}\boldsymbol{x}$ は分離信号、$\boldsymbol{g}$ は非線形関数
証明

ICAの目的関数は非ガウス性の最大化。FastICAアルゴリズムの固定点反復を導出する。

$\boldsymbol{g}(y) = \tanh(y)$ または $\boldsymbol{g}(y) = y \exp(-y^2/2)$ がよく使われる。

補足:ブラインド信号分離、脳波解析などで使用。白色化を前処理として適用することが多い。

18.20 自然言語処理の勾配公式

単語埋め込み(Word2Vec、GloVe)や対照学習(InfoNCE)で使用される目的関数の勾配。

18.20 Skip-gram(負例サンプリング)の勾配

公式:$\dfrac{\partial L}{\partial \boldsymbol{w}_c} = (\sigma(\boldsymbol{w}_c^\top \boldsymbol{w}_o) - 1)\boldsymbol{w}_o + \sum_{k} \sigma(\boldsymbol{w}_c^\top \boldsymbol{w}_k)\boldsymbol{w}_k$
条件:$\boldsymbol{w}_c$: 中心語ベクトル、$\boldsymbol{w}_o$: 文脈語ベクトル、$\boldsymbol{w}_k$: 負例ベクトル、$\sigma$: シグモイド関数
証明

負例サンプリング付きSkip-gramの損失関数は次のように定義される。

\begin{equation}L = -\log\sigma(\boldsymbol{w}_c^\top \boldsymbol{w}_o) - \sum_{k=1}^{K}\log\sigma(-\boldsymbol{w}_c^\top \boldsymbol{w}_k) \label{eq:18-20-1}\end{equation}

$\sigma'(x) = \sigma(x)(1-\sigma(x))$ と $\frac{d}{dx}\log\sigma(x) = 1 - \sigma(x)$ を用いて $\boldsymbol{w}_c$ で微分する。

補足:Word2Vecの効率的な学習アルゴリズム。完全なsoftmaxの代わりに負例サンプリングを使用することで計算量を削減。

18.21 GloVeの勾配

公式:$\dfrac{\partial J}{\partial \boldsymbol{w}_i} = \sum_{j} f(X_{ij})(\boldsymbol{w}_i^\top \tilde{\boldsymbol{w}}_j + b_i + \tilde{b}_j - \log X_{ij})\tilde{\boldsymbol{w}}_j$
条件:$X_{ij}$: 共起行列、$f$: 重み関数、$\boldsymbol{w}_i, \tilde{\boldsymbol{w}}_j$: 単語ベクトル、$b_i, \tilde{b}_j$: バイアス
証明

GloVeの損失関数は共起行列の対数を予測するように設計されている。

\begin{equation}J = \sum_{i,j} f(X_{ij})(\boldsymbol{w}_i^\top \tilde{\boldsymbol{w}}_j + b_i + \tilde{b}_j - \log X_{ij})^2 \label{eq:18-21-1}\end{equation}

$\boldsymbol{w}_i$ に関する二次形式として微分し、勾配を得る。

補足:GloVeは大域的な共起統計量を活用する単語埋め込み手法。重み関数 $f(x) = \min(1, (x/x_{\max})^{0.75})$ が一般的。

18.22 InfoNCE損失関数の勾配

公式:$\mathcal{L}_{\text{NCE}} = -\log \dfrac{\exp(\boldsymbol{q}^\top \boldsymbol{k}^+ / \tau)}{\exp(\boldsymbol{q}^\top \boldsymbol{k}^+ / \tau) + \sum_{j=1}^{K} \exp(\boldsymbol{q}^\top \boldsymbol{k}^-_j / \tau)}$
条件:$\boldsymbol{q}$: クエリ、$\boldsymbol{k}^+$: 正例キー、$\boldsymbol{k}^-_j$: 負例キー、$\tau$: 温度パラメータ
証明

クエリ $\boldsymbol{q}$ に関する勾配:

\begin{equation}\frac{\partial \mathcal{L}_{\text{NCE}}}{\partial \boldsymbol{q}} = \frac{1}{\tau}\left(-\boldsymbol{k}^+ + \sum_{i} p_i \boldsymbol{k}_i\right) \label{eq:18-22-1}\end{equation}

ここで $p_i = \text{softmax}(\boldsymbol{q}^\top \boldsymbol{k}_i / \tau)$ である。

正例キー $\boldsymbol{k}^+$ に関する勾配:

\begin{equation}\frac{\partial \mathcal{L}_{\text{NCE}}}{\partial \boldsymbol{k}^+} = \frac{1}{\tau}(p_+ - 1)\boldsymbol{q} \label{eq:18-22-2}\end{equation}

補足:対照学習(SimCLR、CLIP、MoCo等)の基礎となる損失関数。温度パラメータ $\tau$ は分布の鋭さを制御する。