SVD・Fisher情報・強化学習・NLP
高度な機械学習手法におけるベクトル微分/行列微分の応用。
本ページの公式は分母レイアウト(denominator layout)に基づく。詳細はレイアウト規約を参照。
18.1 特異値分解(SVD)と行列微分
特異値分解を含む損失関数の最適化において行列微分が必要となる。低ランク近似、行列補完、主成分分析などで使用。
18.1 低ランク近似の勾配
証明
$\boldsymbol{A} = \boldsymbol{U}\boldsymbol{V}^\top$ とおくと、12.6 より直接得られる。
\begin{equation}\frac{\partial}{\partial \boldsymbol{X}} \|\boldsymbol{X} - \boldsymbol{A}\|_F^2 = 2(\boldsymbol{X} - \boldsymbol{A}) = 2(\boldsymbol{X} - \boldsymbol{U}\boldsymbol{V}^\top) \label{eq:18-1-1}\end{equation}
18.2 因子行列 $\boldsymbol{U}$ の勾配
証明
$\boldsymbol{E} = \boldsymbol{X} - \boldsymbol{U}\boldsymbol{V}^\top$ とおく。損失関数は $L = \|\boldsymbol{E}\|_F^2 = \text{tr}(\boldsymbol{E}^\top\boldsymbol{E})$ である。
\begin{equation}L = \text{tr}((\boldsymbol{X} - \boldsymbol{U}\boldsymbol{V}^\top)^\top(\boldsymbol{X} - \boldsymbol{U}\boldsymbol{V}^\top)) \label{eq:18-2-1}\end{equation}
$\eqref{eq:18-2-1}$ を展開する。
\begin{equation}L = \text{tr}(\boldsymbol{X}^\top\boldsymbol{X}) - 2\text{tr}(\boldsymbol{X}^\top\boldsymbol{U}\boldsymbol{V}^\top) + \text{tr}(\boldsymbol{V}\boldsymbol{U}^\top\boldsymbol{U}\boldsymbol{V}^\top) \label{eq:18-2-2}\end{equation}
$\boldsymbol{U}$ に関する項のみを考える。第2項は $\text{tr}(\boldsymbol{V}^\top\boldsymbol{X}^\top\boldsymbol{U})$(トレースの巡回性)。
3.1 より $\displaystyle\frac{\partial}{\partial \boldsymbol{U}}\text{tr}(\boldsymbol{V}^\top\boldsymbol{X}^\top\boldsymbol{U}) = \boldsymbol{X}\boldsymbol{V}$。
第3項は 3.5 の形式で、$\displaystyle\frac{\partial}{\partial \boldsymbol{U}}\text{tr}(\boldsymbol{V}\boldsymbol{U}^\top\boldsymbol{U}\boldsymbol{V}^\top) = 2\boldsymbol{U}\boldsymbol{V}^\top\boldsymbol{V}$。
\begin{equation}\frac{\partial L}{\partial \boldsymbol{U}} = -2\boldsymbol{X}\boldsymbol{V} + 2\boldsymbol{U}\boldsymbol{V}^\top\boldsymbol{V} = -2(\boldsymbol{X} - \boldsymbol{U}\boldsymbol{V}^\top)\boldsymbol{V} \label{eq:18-2-3}\end{equation}
18.3 因子行列 $\boldsymbol{V}$ の勾配
証明
18.2 と同様の手順で導出。$\boldsymbol{V}$ に関する微分を計算する。
\begin{equation}\frac{\partial L}{\partial \boldsymbol{V}} = -2\boldsymbol{X}^\top\boldsymbol{U} + 2\boldsymbol{V}\boldsymbol{U}^\top\boldsymbol{U} = -2(\boldsymbol{X}^\top - \boldsymbol{V}\boldsymbol{U}^\top)\boldsymbol{U} \label{eq:18-3-1}\end{equation}
$(\boldsymbol{X}^\top - \boldsymbol{V}\boldsymbol{U}^\top) = (\boldsymbol{X} - \boldsymbol{U}\boldsymbol{V}^\top)^\top$ より公式を得る。
18.4 核ノルムの劣勾配
証明
核ノルムの定義は $\|\boldsymbol{X}\|_* = \sum_{i} \sigma_i = \text{tr}(\boldsymbol{\Sigma})$ である。
$\boldsymbol{X}$ がフルランクのとき、劣勾配は一意で $\boldsymbol{U}\boldsymbol{V}^\top$ となる。
ランク落ちの場合、零空間への射影成分 $\boldsymbol{W}$ の自由度がある。
18.5 Fisher情報行列
パラメータ推定の精度を表すFisher情報行列は、対数尤度の2階微分から計算される。自然勾配法、信頼領域法で使用。
18.5 Fisher情報行列の定義
証明
対数尤度のスコア関数を $\boldsymbol{s}(\boldsymbol{\theta}) = \nabla_{\boldsymbol{\theta}} \log p(\boldsymbol{x}|\boldsymbol{\theta})$ と定義する。
\begin{equation}\boldsymbol{F}(\boldsymbol{\theta}) = \mathbb{E}[\boldsymbol{s}(\boldsymbol{\theta})\boldsymbol{s}(\boldsymbol{\theta})^\top] = \text{Cov}(\boldsymbol{s}(\boldsymbol{\theta})) \label{eq:18-5-1}\end{equation}
$\mathbb{E}[\boldsymbol{s}(\boldsymbol{\theta})] = \boldsymbol{0}$ であることを示す。
\begin{equation}\mathbb{E}[\boldsymbol{s}(\boldsymbol{\theta})] = \int \frac{\nabla_{\boldsymbol{\theta}} p(\boldsymbol{x}|\boldsymbol{\theta})}{p(\boldsymbol{x}|\boldsymbol{\theta})} p(\boldsymbol{x}|\boldsymbol{\theta}) d\boldsymbol{x} = \nabla_{\boldsymbol{\theta}} \int p(\boldsymbol{x}|\boldsymbol{\theta}) d\boldsymbol{x} = \nabla_{\boldsymbol{\theta}} 1 = \boldsymbol{0} \label{eq:18-5-2}\end{equation}
18.6 Fisher情報とHessianの関係
証明
$\nabla_{\boldsymbol{\theta}} \log p = \displaystyle\frac{\nabla_{\boldsymbol{\theta}} p}{p}$ を再度微分する。
\begin{equation}\nabla_{\boldsymbol{\theta}}^2 \log p = \frac{\nabla_{\boldsymbol{\theta}}^2 p}{p} - \frac{\nabla_{\boldsymbol{\theta}} p (\nabla_{\boldsymbol{\theta}} p)^\top}{p^2} \label{eq:18-6-1}\end{equation}
期待値を取ると、第1項は $\nabla_{\boldsymbol{\theta}}^2 \int p \, d\boldsymbol{x} = \boldsymbol{0}$。
\begin{equation}\mathbb{E}[\nabla_{\boldsymbol{\theta}}^2 \log p] = -\mathbb{E}\left[\frac{\nabla_{\boldsymbol{\theta}} p (\nabla_{\boldsymbol{\theta}} p)^\top}{p^2}\right] = -\boldsymbol{F}(\boldsymbol{\theta}) \label{eq:18-6-2}\end{equation}
18.7 自然勾配
証明
パラメータ空間がRiemann多様体であると考える。Fisher情報行列は計量テンソルとなる。
ユークリッド勾配 $\nabla_{\boldsymbol{\theta}} L$ を計量 $\boldsymbol{F}$ で変換して自然勾配を得る。
\begin{equation}\tilde{\nabla}_{\boldsymbol{\theta}} L = \boldsymbol{F}(\boldsymbol{\theta})^{-1} \nabla_{\boldsymbol{\theta}} L \label{eq:18-7-1}\end{equation}
これはKL距離を制約とした最急降下方向に対応する。
18.8 Cramér-Raoの下界
証明
不偏推定量の定義より $\mathbb{E}[\hat{\boldsymbol{\theta}}] = \boldsymbol{\theta}$。両辺を $\boldsymbol{\theta}$ で微分する。
\begin{equation}\nabla_{\boldsymbol{\theta}} \mathbb{E}[\hat{\boldsymbol{\theta}}] = \boldsymbol{I} \label{eq:18-8-1}\end{equation}
Cauchy-Schwarzの不等式を適用して下界を得る。
\begin{equation}\text{Cov}(\hat{\boldsymbol{\theta}}) \succeq \boldsymbol{F}(\boldsymbol{\theta})^{-1} \label{eq:18-8-2}\end{equation}
18.9 強化学習と方策勾配
強化学習における方策勾配法では、期待累積報酬の勾配を推定する。行列微分は方策ネットワークの勾配計算に使用。
18.9 方策勾配定理
証明
期待累積報酬の定義は次の通り。
\begin{equation}J(\boldsymbol{\theta}) = \mathbb{E}_{\pi_{\boldsymbol{\theta}}}\left[\sum_{t=0}^{\infty} \gamma^t r_t\right] \label{eq:18-9-1}\end{equation}
状態分布 $d^{\pi}(s)$ を用いて書き直す。
\begin{equation}J(\boldsymbol{\theta}) = \sum_s d^{\pi}(s) \sum_a \pi_{\boldsymbol{\theta}}(a|s) Q^{\pi}(s, a) \label{eq:18-9-2}\end{equation}
$\boldsymbol{\theta}$ で微分する。$\nabla_{\boldsymbol{\theta}} \pi = \pi \nabla_{\boldsymbol{\theta}} \log \pi$ を使う。
\begin{equation}\nabla_{\boldsymbol{\theta}} J = \sum_s d^{\pi}(s) \sum_a \pi_{\boldsymbol{\theta}}(a|s) \nabla_{\boldsymbol{\theta}} \log \pi_{\boldsymbol{\theta}}(a|s) Q^{\pi}(s, a) \label{eq:18-9-3}\end{equation}
18.10 ベースラインによる分散低減
証明
ベースライン項の期待値がゼロになることを示す。
\begin{equation}\mathbb{E}_{a \sim \pi}\left[\nabla_{\boldsymbol{\theta}} \log \pi_{\boldsymbol{\theta}}(a|s) \cdot b(s)\right] = b(s) \sum_a \nabla_{\boldsymbol{\theta}} \pi_{\boldsymbol{\theta}}(a|s) \label{eq:18-10-1}\end{equation}
$\sum_a \pi_{\boldsymbol{\theta}}(a|s) = 1$ より $\sum_a \nabla_{\boldsymbol{\theta}} \pi_{\boldsymbol{\theta}}(a|s) = 0$。
よって $\eqref{eq:18-10-1}$ はゼロとなり、ベースラインは勾配のバイアスを変えない。
18.11 Actor-Criticの勾配
証明
18.10 でベースライン $b(s) = V^{\pi}(s)$ を使用すると得られる。
\begin{equation}\nabla_{\boldsymbol{\theta}} J = \mathbb{E}\left[\nabla_{\boldsymbol{\theta}} \log \pi_{\boldsymbol{\theta}}(a|s) \cdot (Q^{\pi}(s, a) - V^{\pi}(s))\right] \label{eq:18-11-1}\end{equation}
$A^{\pi}(s, a) = Q^{\pi}(s, a) - V^{\pi}(s)$ をアドバンテージ関数という。
18.12 PPOのクリップ目的関数
証明
重要度サンプリングにより、古い方策のサンプルで新しい方策の勾配を推定する。
\begin{equation}\nabla_{\boldsymbol{\theta}} J = \mathbb{E}_{\pi_{\text{old}}}\left[r_t(\boldsymbol{\theta}) \nabla_{\boldsymbol{\theta}} \log \pi_{\boldsymbol{\theta}}(a_t|s_t) A_t\right] \label{eq:18-12-1}\end{equation}
クリップ関数により $r_t$ が $[1-\epsilon, 1+\epsilon]$ 外に出ると勾配が停止し、方策の急激な変化を防ぐ。
18.13 自然言語処理とAttention
TransformerのAttention機構における行列微分。ソフトマックス、マルチヘッド注意の勾配計算。
18.13 Attention重みの勾配
証明
Scaled Dot-Product Attention の定義を確認する。
\begin{equation}\boldsymbol{O} = \text{softmax}\left(\frac{\boldsymbol{Q}\boldsymbol{K}^\top}{\sqrt{d_k}}\right)\boldsymbol{V} = \boldsymbol{A}\boldsymbol{V} \label{eq:18-13-1}\end{equation}
連鎖律を適用し、ソフトマックスのJacobian(6.2)を使用する。
\begin{equation}\frac{\partial L}{\partial \boldsymbol{Q}} = \frac{\partial L}{\partial \boldsymbol{O}} \cdot \frac{\partial \boldsymbol{O}}{\partial \boldsymbol{A}} \cdot \frac{\partial \boldsymbol{A}}{\partial \boldsymbol{S}} \cdot \frac{\partial \boldsymbol{S}}{\partial \boldsymbol{Q}} \label{eq:18-13-2}\end{equation}
ここで $\boldsymbol{S} = \boldsymbol{Q}\boldsymbol{K}^\top / \sqrt{d_k}$ はスコア行列。
18.14 マルチヘッドAttentionの勾配
証明
マルチヘッドAttentionの出力は各ヘッドの連結と出力射影からなる。
\begin{equation}\text{MultiHead}(\boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V}) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\boldsymbol{W}^O \label{eq:18-14-1}\end{equation}
各ヘッドは $\text{head}_i = \text{Attention}(\boldsymbol{X}\boldsymbol{W}_i^Q, \boldsymbol{X}\boldsymbol{W}_i^K, \boldsymbol{X}\boldsymbol{W}_i^V)$。
$\boldsymbol{W}_i^Q$ に対する勾配は 3.1 の形式で得られる。
18.15 Layer Normalizationの勾配
証明
Layer Normalizationは特徴次元に沿って正規化する。
\begin{equation}\text{LayerNorm}(\boldsymbol{x}) = \gamma \odot \hat{\boldsymbol{x}} + \beta \label{eq:18-15-1}\end{equation}
$\mu$ と $\sigma$ が $\boldsymbol{x}$ の全成分に依存するため、勾配計算は複雑になる。
17.5 と類似の手順で導出。
18.16 埋め込み層の勾配
証明
埋め込み層はone-hotベクトル $\boldsymbol{o}_i$ との行列積である。
\begin{equation}\boldsymbol{e} = \boldsymbol{E}^\top \boldsymbol{o}_i = \boldsymbol{E}_{[i,:]} \label{eq:18-16-1}\end{equation}
勾配は選択された行にのみ伝播する。
\begin{equation}\frac{\partial L}{\partial \boldsymbol{E}_{[j,:]}} = \begin{cases} \frac{\partial L}{\partial \boldsymbol{e}} & \text{if } j = i \\ \boldsymbol{0} & \text{otherwise} \end{cases} \label{eq:18-16-2}\end{equation}
18.17 交差エントロピー損失の勾配(ロジット)
証明
$\boldsymbol{p} = \text{softmax}(\boldsymbol{z})$ とおく。交差エントロピーは $L = -\sum_i y_i \log p_i$。
連鎖律を適用する。
\begin{equation}\frac{\partial L}{\partial z_j} = \sum_i \frac{\partial L}{\partial p_i} \frac{\partial p_i}{\partial z_j} \label{eq:18-17-1}\end{equation}
$\displaystyle\frac{\partial L}{\partial p_i} = -\displaystyle\frac{y_i}{p_i}$ と softmax の Jacobian を代入して計算すると $p_j - y_j$ を得る。
18.18 その他の高度なトピック
行列微分の追加の応用例。
18.18 ガウス過程の対数周辺尤度勾配
証明
ガウス過程の対数周辺尤度は次の通り。
\begin{equation}\log p(\boldsymbol{y}|\boldsymbol{X}, \theta) = -\frac{1}{2}\boldsymbol{y}^\top\boldsymbol{K}^{-1}\boldsymbol{y} - \frac{1}{2}\log|\boldsymbol{K}| - \frac{n}{2}\log(2\pi) \label{eq:18-18-1}\end{equation}
18.19 独立成分分析(ICA)の勾配
証明
ICAの目的関数は非ガウス性の最大化。FastICAアルゴリズムの固定点反復を導出する。
$\boldsymbol{g}(y) = \tanh(y)$ または $\boldsymbol{g}(y) = y \exp(-y^2/2)$ がよく使われる。
18.20 自然言語処理の勾配公式
単語埋め込み(Word2Vec、GloVe)や対照学習(InfoNCE)で使用される目的関数の勾配。
18.20 Skip-gram(負例サンプリング)の勾配
証明
負例サンプリング付きSkip-gramの損失関数は次のように定義される。
\begin{equation}L = -\log\sigma(\boldsymbol{w}_c^\top \boldsymbol{w}_o) - \sum_{k=1}^{K}\log\sigma(-\boldsymbol{w}_c^\top \boldsymbol{w}_k) \label{eq:18-20-1}\end{equation}
$\sigma'(x) = \sigma(x)(1-\sigma(x))$ と $\frac{d}{dx}\log\sigma(x) = 1 - \sigma(x)$ を用いて $\boldsymbol{w}_c$ で微分する。
18.21 GloVeの勾配
証明
GloVeの損失関数は共起行列の対数を予測するように設計されている。
\begin{equation}J = \sum_{i,j} f(X_{ij})(\boldsymbol{w}_i^\top \tilde{\boldsymbol{w}}_j + b_i + \tilde{b}_j - \log X_{ij})^2 \label{eq:18-21-1}\end{equation}
$\boldsymbol{w}_i$ に関する二次形式として微分し、勾配を得る。
18.22 InfoNCE損失関数の勾配
証明
クエリ $\boldsymbol{q}$ に関する勾配:
\begin{equation}\frac{\partial \mathcal{L}_{\text{NCE}}}{\partial \boldsymbol{q}} = \frac{1}{\tau}\left(-\boldsymbol{k}^+ + \sum_{i} p_i \boldsymbol{k}_i\right) \label{eq:18-22-1}\end{equation}
ここで $p_i = \text{softmax}(\boldsymbol{q}^\top \boldsymbol{k}_i / \tau)$ である。
正例キー $\boldsymbol{k}^+$ に関する勾配:
\begin{equation}\frac{\partial \mathcal{L}_{\text{NCE}}}{\partial \boldsymbol{k}^+} = \frac{1}{\tau}(p_+ - 1)\boldsymbol{q} \label{eq:18-22-2}\end{equation}