自動微分と最適化
ベクトル微分/行列微分の実践的応用
はじめに
PyTorch、JAX、TensorFlowなどの深層学習フレームワークは「自動微分」という機能を提供している。 これにより、複雑なニューラルネットワークの勾配を手で計算することなく、 コードを書くだけで自動的に勾配が得られる。
自動微分は便利な道具であるが、その内部で何が行われているかを理解していないと、 予期しない挙動に遭遇したり、効率の悪いコードを書いてしまったりすることがある。 本ページでは、ベクトル微分/行列微分の公式集や テンソル微分入門で学んだ知識を踏まえ、 自動微分の仕組みと最適化への応用を解説する。
自動微分の基本
自動微分(Automatic Differentiation, AD)とは何かを明確にしておく。
数値微分・記号微分との違い
微分を計算する方法には、主に3つのアプローチがある。
| 方法 | 特徴 | 問題点 |
|---|---|---|
| 数値微分 | $\displaystyle\dfrac{f(x+h) - f(x)}{h}$ で近似 | 丸め誤差と打ち切り誤差のトレードオフ、多変数で計算量が大きい |
| 記号微分 | 数式を代数的に微分 | 式が膨張しやすい、条件分岐やループを含む関数に適用しにくい |
| 自動微分 | 関数を合成として捉え、連鎖律を自動適用 | 計算グラフの構築が必要 |
自動微分は、数値微分のような近似誤差がなく、記号微分のような式の膨張も起きない。 関数を基本演算の合成として捉え、各演算の局所的な微分を連鎖律で接続することで、 効率的かつ正確に微分を計算する。
自動微分の本質
自動微分の核心は、以下の考え方にある。
関数を合成の列として捉え、各局所微分を連鎖律で正確につなぎ合わせる。 Jacobi 行列や高階テンソル(全結合層の重みに関する 4 階 Jacobi など)そのものを明示的に構築するのではなく、 必要な「作用」(ベクトルとの積)だけを計算する。
ここで「作用」という言葉が重要である。 行列微分では Jacobi行列 $\boldsymbol{J}$ という行列が登場するが、 自動微分ではこの行列を陽に構築しない。 代わりに、「Jacobi行列をベクトルに掛けた結果」だけを計算する。
理論と実装の乖離:テンソル微分 vs ベクトル化
深層学習の計算を数学的に記述すると、しばしば高階テンソルの微分が登場する。 たとえば全結合層 $\boldsymbol{Y} = \boldsymbol{X}\boldsymbol{W}$($\boldsymbol{X}$: $m \times n$、$\boldsymbol{W}$: $n \times p$、$\boldsymbol{Y}$: $m \times p$)において、 $\boldsymbol{Y}$ の $\boldsymbol{W}$ に関する Jacobi 行列は4階テンソルになる。
しかし、実際のフレームワーク(PyTorch、JAX、TensorFlowなど)は、 このような高階テンソルを明示的に構築しない。 代わりに、テンソルの各要素をベクトルとみなし、ベクトル微分として計算を行う。
| 観点 | 理論(数学的記述) | 実装(フレームワーク内部) |
|---|---|---|
| 微分の表現 | 4階以上のテンソル(ヤコビアン) | ベクトル化による行列・ベクトル演算 |
| 勾配の形状 | $\displaystyle\dfrac{\partial L}{\partial \boldsymbol{W}}$ は $\boldsymbol{W}$ と同じ形状 | 同じ(reshape で戻す) |
| メモリ消費 | $O(n^4)$(明示的テンソル) | $O(n^2)$(要素ごとの逆伝播) |
なぜベクトル化するのか
- 計算効率:GPU/TPU は行列積(GEMM)に最適化されている。 テンソル縮約を直接実装するより、ベクトル化して行列積に帰着させる方が高速である。
- メモリ効率:4階テンソルのヤコビアンを明示的に保持すると $O(n^4)$ のメモリが必要になるが、 要素ごとの逆伝播なら $O(n^2)$ で済む。
- 実装の単純化:各演算の「ローカルな勾配ルール」を定義しておけば、 連鎖律で自動的に接続できる。テンソル全体のヤコビアンを陽に構築する必要がない。
具体例:全結合層の勾配
全結合層 $\boldsymbol{Y} = \boldsymbol{X}\boldsymbol{W}$ を考える。 理論上の勾配は4階テンソルだが、実装では以下のように行列積で計算される。 ただしここでは ML 実装の慣用に従い、損失 $L$ に対する勾配を $\mathtt{dA} := \partial L / \partial \boldsymbol{A}$ と略記する(数学の微分形式 $d\boldsymbol{A}$ とは別物)。 $\mathtt{@}$ は行列積、$\mathtt{.T}$ は転置(Python / NumPy / PyTorch の慣用)。
# 逆伝播で上流から dY = ∂L/∂Y(Y と同じ形状)が来たとする
dW = X.T @ dY # ∂L/∂W = Xᵀ (∂L/∂Y)
dX = dY @ W.T # ∂L/∂X = (∂L/∂Y) Wᵀ
この計算結果は、テンソル微分で導出した結果と一致する。 フレームワークは「テンソル微分の結果と同じ形状の勾配」を「ベクトル化+行列演算」で効率的に計算しているのである。
要するに、深層学習は数学的にはテンソル微分であるが、 実装はその構造を利用してショートカットしている。 理論を理解していれば、なぜそのショートカットが正しいのかがわかる。
VJPとJVP
自動微分には2つのモードがある。それぞれの意味と使い分けを理解することが重要である。
JVP(Jacobian-Vector Product)— 順方向モード
関数 $\boldsymbol{f}: \mathbb{R}^n \to \mathbb{R}^m$ と入力方向ベクトル $\boldsymbol{v} \in \mathbb{R}^n$ に対し、 JVP は $$ \text{JVP}(\boldsymbol{v}) = \boldsymbol{J} \boldsymbol{v} = D\boldsymbol{f}(\boldsymbol{x})[\boldsymbol{v}] $$ を計算する。ここで $\boldsymbol{J}$ は Jacobi行列である。
直感的には、入力方向 $\boldsymbol{v}$ に沿って関数がどのように変化するかを表している。 入力次元 $n$ が小さいときに効率的であり、1回のJVP計算で Jacobi行列の1列分の情報が得られる。
VJP(Vector-Jacobian Product)— 逆方向モード
出力側のcotangent(随伴)ベクトル $\boldsymbol{w} \in \mathbb{R}^m$ に対し、 VJP は $$ \text{VJP}(\boldsymbol{w}) = \boldsymbol{J}^\top \boldsymbol{w} = D\boldsymbol{f}(\boldsymbol{x})^*[\boldsymbol{w}] $$ を計算する。
直感的には、出力側からやってくる感度情報を、入力側に「逆伝播」させる操作である。 出力次元 $m$ が小さいときに効率的であり、1回のVJP計算で Jacobi行列の1行分の情報が得られる。
深層学習でVJPが使われる理由
深層学習の学習では、最終出力は損失関数という1次元のスカラーである。 一方、入力(パラメータ)は数百万〜数十億次元になることがある。
- JVP(順方向): 入力次元の回数だけ計算が必要 → 数百万回
- VJP(逆方向): 出力次元の回数だけ計算が必要 → 1回
このため、深層学習ではほぼ例外なくVJP(逆方向モード)が使われる。 これがバックプロパゲーション(誤差逆伝播法)の数学的正体である。
行列微分との対応
自動微分と行列微分の関係を明確にする。
分母レイアウトとの一致
テンソル微分入門で解説したように、 分母レイアウトでは連鎖律が縮約として自然に表現できる。 VJP はまさにこの縮約操作を行っており、分母レイアウトの考え方と完全に一致する。
| 行列微分の概念 | 自動微分での対応 |
|---|---|
| Jacobi行列 $\boldsymbol{J}$ | 明示的に構築しない |
| 勾配 $\nabla f$ | VJP$(1)$ の結果 |
| 連鎖律による縮約 | VJP の逐次適用 |
| 勾配の形状 = 入力の形状 | フレームワークの標準仕様 |
PyTorchの .backward() やJAXの grad() が返す勾配が
入力と同じ形状になるのは、分母レイアウトの自然な帰結である。
転置が「現れない」理由
行列微分の教科書では $\boldsymbol{J}^\top$ という転置がよく登場するが、 自動微分の実装では転置を明示的に行うことは少ない。 これは、VJP が「行列を作ってから転置して掛ける」のではなく、 最初から「転置後の掛け算に相当する操作」を直接計算しているためである。
行列微分を学ぶ際に転置が混乱の原因になることがあるが、 自動微分の観点から見ると、転置は「行列表現の副産物」であり本質ではない。 本質は「作用としての微分」であり、それをVJP/JVP という形で計算している。
最適化アルゴリズムとの接続
自動微分で得られた勾配は、最適化アルゴリズムの入力として使われる。 代表的なアルゴリズムを整理する。
最適化の一般形
多くの最適化アルゴリズムは、以下の形で統一的に表現できる。
$$ \boldsymbol{x}_{k+1} = \boldsymbol{x}_k - \alpha_k \boldsymbol{P}_k \boldsymbol{g}_k $$ここで $\boldsymbol{g}_k = \nabla f(\boldsymbol{x}_k)$ は勾配、 $\boldsymbol{P}_k$ は前処理行列(preconditioner)、 $\alpha_k$ はステップ長である。 アルゴリズムの違いは、$\boldsymbol{P}_k$ をどう設計するかの違いと見なせる。
主要なアルゴリズム
| アルゴリズム | 前処理 $\boldsymbol{P}_k$ | 必要な微分情報 | 特徴 |
|---|---|---|---|
| SGD | $\boldsymbol{I}$(単位行列) | 勾配のみ | シンプル、ノイズに強い |
| Momentum | 過去の勾配の指数移動平均 | 勾配のみ | 振動を抑制 |
| Adam | $\text{diag}(1/\sqrt{\boldsymbol{v}_k + \epsilon})$ | 勾配のみ | 適応的な学習率 |
| L-BFGS | Hessian逆行列の低ランク近似 | 勾配のみ(差分から構築) | 二次収束に近い |
| Newton法 | $\boldsymbol{H}^{-1}$ | 勾配 + Hessian | 二次収束、大規模では非実用的 |
Adamが広く使われる理由
Adam は、勾配の1次モーメント(平均)と2次モーメント(分散)を追跡し、 成分ごとに学習率を調整する。 これは対角Hessianの近似と解釈でき、二次情報を暗黙的に利用している。
必要なのは勾配だけであり、VJP 1回で計算できる。 かつ、パラメータごとのスケールの違いを自動的に吸収してくれる。 これが深層学習で Adam が標準的に使われる理由である。
Hessian-Vector Product(HVP)
Newton法を直接適用するには Hessian $\boldsymbol{H}$ の計算と逆行列が必要で、 大規模問題では現実的でない。 しかし、Hessian を「作用」として使うことは可能である。
$$ \text{HVP}(\boldsymbol{v}) = \boldsymbol{H} \boldsymbol{v} = \nabla(\nabla f \cdot \boldsymbol{v}) $$これは「勾配とベクトルの内積」の勾配として計算できる。 VJP と JVP を組み合わせることで、Hessian を明示的に構築せずに Hessian-ベクトル積だけを効率的に計算できる。 共役勾配法と組み合わせれば、近似的なNewton法が実現できる。
数値安定性
理論的に正しい数式でも、有限精度の浮動小数点演算で計算すると 予期しない結果になることがある。数値安定性の基本を理解しておくことは重要である。
浮動小数点演算の特性
- 相対誤差は一定: 大きい数も小さい数も、同程度の相対精度で表現される
- 桁落ち: 近い値どうしの引き算で有効桁数が減る
- オーバーフロー/アンダーフロー: 表現可能な範囲を超えると無限大やゼロになる
深層学習で起きやすい問題
勾配爆発・勾配消失
深いネットワークでは、連鎖律により勾配が多数の行列(またはその作用)の積になる。 各層での増幅/減衰が累積し、勾配が極端に大きく(爆発)または小さく(消失)なることがある。
条件数の問題
Hessian の条件数(最大固有値と最小固有値の比)が大きいと、 最適化が困難になる。これは損失関数の等高線が細長い楕円になっている状態に対応し、 勾配降下法が収束しにくくなる。
安定な計算の工夫
log-sum-exp トリック
$\log \displaystyle\sum_i e^{x_i}$ をそのまま計算すると、$x_i$ が大きいときにオーバーフローする。 以下のように書き換えると安定になる。
$$ \log \displaystyle\sum_i e^{x_i} = m + \log \displaystyle\sum_i e^{x_i - m} \quad \text{where} \quad m = \max_i x_i $$softmax + cross-entropy の融合
softmax と cross-entropy を別々に計算すると、中間結果で数値的な問題が生じやすい。
両者を融合した形で計算することで、安定性が向上する。
PyTorch の CrossEntropyLoss は内部でこの融合を行っている。
正規化の役割
Batch Normalization や Layer Normalization といった正規化層は、 単なる「学習を速くするテクニック」ではなく、 数値安定性と最適化の観点から本質的な役割を果たしている。
たとえば Batch Normalization は、ミニバッチ内での平均 $\mu$ と分散 $\sigma^2$ を用いて $$\hat{x} = \dfrac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}$$ と入力を標準化した後、学習可能なスケール $\gamma$ とシフト $\beta$ を掛けて $y = \gamma \hat{x} + \beta$ を出力する。 以下では、この操作がなぜ最適化に効くのかを 3 つの観点から整理する。
なぜ正規化が必要か
1. 条件数の改善
入力のスケールが層ごとに大きく異なると、損失関数の等高線が歪み、 最適化が困難になる。正規化により各層の出力を一定のスケールに保つと、 Hessian の条件数が改善され、勾配降下法が効率的に動作するようになる。
2. 勾配のダイナミックレンジ制御
正規化により、各層を通過する勾配の大きさが適切な範囲に保たれ、 勾配爆発や勾配消失が起きにくくなる。 これは自動微分の連鎖律において、各項のスケールが制御されることを意味する。
3. 損失面の平滑化
正規化は損失関数の地形を平滑化する効果があることが知られている。 平滑な損失面では、勾配がより良い方向を指しやすく、 大きな学習率でも安定して学習できる。
Adam の $\epsilon$ の意味
Adam の更新式には、分母に $\epsilon$(通常 $10^{-8}$ 程度)が加えられている。
$$ \boldsymbol{x} \leftarrow \boldsymbol{x} - \alpha \dfrac{\boldsymbol{m}}{\sqrt{\boldsymbol{v}} + \epsilon} $$直感的には、勾配の2次モーメント $v_i$ が極端に小さいときに更新幅が爆発するのを防ぐ「安全装置」である。 より数学的に言えば、前処理行列 $\text{diag}(1/\sqrt{\boldsymbol{v}_k + \epsilon})$ の最大値に上限を設けることで、成分ごとの条件数の悪化を抑えている。 単なるゼロ除算回避ではなく、最適化の安定性そのものを担う構成要素である。
まとめ
- 自動微分は、連鎖律を計算グラフ上で自動適用する技術である
- 深層学習では、出力がスカラー(損失)であるため、VJP(逆方向モード)が効率的
- VJP は分母レイアウトの連鎖律と対応しており、勾配が入力と同じ形状になる
- 最適化アルゴリズムは、勾配に前処理をどう施すかの違いとして統一的に理解できる
- 数値安定性のために、計算式の工夫や正規化が重要である
- 正規化は、条件数の改善・勾配スケールの制御・損失面の平滑化を通じて最適化を支援する
- ベクトル微分/行列微分入門 - 基本概念と分野ごとの表記法
- ベクトル微分/行列微分の公式集 - 具体的な公式と証明
- テンソル微分入門 - 高階テンソルへの一般化
参考文献
- Griewank, A., & Walther, A. (2008). Evaluating Derivatives: Principles and Techniques of Algorithmic Differentiation (2nd ed.). SIAM. doi:10.1137/1.9780898717761
- Baydin, A. G., Pearlmutter, B. A., Radul, A. A., & Siskind, J. M. (2018). Automatic Differentiation in Machine Learning: a Survey. Journal of Machine Learning Research, 18(153), 1–43. arXiv:1502.05767
- Pearlmutter, B. A. (1994). Fast Exact Multiplication by the Hessian. Neural Computation, 6(1), 147–160. doi:10.1162/neco.1994.6.1.147
- Nocedal, J., & Wright, S. J. (2006). Numerical Optimization (2nd ed.). Springer. doi:10.1007/978-0-387-40065-5
- Higham, N. J. (2002). Accuracy and Stability of Numerical Algorithms (2nd ed.). SIAM. doi:10.1137/1.9780898718027