データ分析部の島田です。今回はDeepLearningがなぜうまく学習出来ているのか、についてサーベイしてみました(簡単なコード付きです)。
  記事アウトライン
    §   用語の解説
  §   事前知識:NNがうまく学習できなかった理由
  §   DeepLearningがうまく学習できる理由
  §   参考コード
  §   まとめ
  
  
  用語の解説
    以降のsectionで出てくる単語についてまとめておきます。
  Hessian(ヘッセ行列)
    (多変数スカラー値)関数の二階偏導関数全体が作る正方行列で対称行列。
  Hessianの固有値の符号をみることにより極小点や凸性の判定を行えます。
  Hessianの固有値が全て正であれば、凸関数になるので大域解が求まることを保証できたりします。
  次に示す鞍点では、Hessianの固有値が正負の両者を含んでいます。
  鞍点
    最適化対象の関数が鞍のような形になっている部分。
  鞍点はどの方向に対しても平坦に近くなっており、そのため勾配がゼロに近くなり学習が進まなくります。
  鞍点などの停留点に到達して学習が停滞している様はプラトーと呼ばれています。
  Alec Radfordさんが作成した下のアニメーションは、DeepLearningで使用されている最適化アルゴリズムにおける鞍点での動きを示しています。
  モーメンタムがないと鞍点に捕まってしまって抜け出せていない様子もわかります。
  
  出典:http://imgur.com/a/Hqolp
  モーメンタム(モーメンタム法)
    SGDを最適化の進行方向に加速させて、振動を抑制し学習の停滞を少なくする手法です。
  物理でいうところの慣性です。
  一時点前の進行方向に対してある程度の速度成分を維持していることになります。
  これにより鞍点から脱する事が出来ます。
  事前知識
    まず、DeepLearningがうまく学習出来ている理由について見る前に、うまく学習出来ていなかった理由について言及します。
  ニューラルネットワークの層を深くするとうまく学習出来なかった理由としては以下があげられます。
  §   勾配消失
  §   過学習
  §   鞍点の増加
  従来問題となっていた、勾配消失1や過学習を解決したのがReluやDropoutのような手法でした。
  正則化という観点では、FC層ではなくConv.層を使うことで共有されるパラメタが増え(パラメタ数自体は減る)正則化の効果がある2という見方もあります。
  次のsubsectionで紹介する得居さんの資料に詳しく説明がありますので、ご興味のある方はそちらを参考にして頂きたく。
  最近では、計算量や使用メモリ削減のためにFC層の使用をさけて全てConv.層で構築されたネットワーク(FCN)も出てきています。
  上の問題がクリアされた事により、層を深くし、パラメタ数を多くする事が可能になったのですが、次は高次元になると鞍点が増加しやすくなりうまく学習が出来なくなってしまうのではないか、という懸念が生まれてきます。
  鞍点の増加については次のsubsectionで説明します。
  鞍点の増加
    最適化から見たディープラーニングの考え方,  得居, OR学会誌'15には以下のような記述があります。
  ヘッセ行列の負の固有値の個数を指数というが,停留点は指数によって分類できる.そこで[11]  では,高次元のガウス確率場から生成された関数は,ほとんどの停留点が鞍点であることは,直感的に次のように理解できる.目的関数がランダムであり,停留点のヘッセ行列の要素が平均 0  の同じ分布に従うならば,その固有値の分布はウィグナーの半円分布に従う.これは 0 を平均・中央値とする半楕円形をした分布で,特に正負の値が半々の確率で生成される.このとき,停留点の指数が 0  となる確率は,次元を高くするにつれて指数的に減少する.
  つまり、次元数が高く(パラメタ数が多く)なればなるほど鞍点が発生する確率が高くなります。
  鞍点は局所解よりも性能が悪いので、鞍点を素早く抜けられる事も収束性能に関わってきます。
  鞍点を抜ける原始的な方法としてはモーメンタムがあります。
  今回は実験条件をある程度固定したいので、学習係数係数を一定としてmomentumSGDを使用しています。
  上の図で見て頂いたようにmomentumSGDより賢い最適化関数がありますので、通常はそちらを使用されるのが良いと思います。
  DeepLearningがうまく学習できる理由
    理論的解析
    loss関数がだいたい凸なため大域解に到達する事が可能、というのが理論的な解析で明らかになってきているようです。
  Deep Learning without Poor Local Minima,Kenji Kawaguchi, NIPS'16などの最新の理論研究では、いくつかの仮定が付いてはいますが局所解が大域解になるという事が証明されています。
  まだ完全に一般化されているわけではないようですが、こういった理論解析が進んできているようです。
  Why does Deep Learning work?というblogにはspin glassの観点からみた場合の話がされています。
  エネルギー関数の形状の図などがあり視覚的にもわかりやすいと思います。
  実験的解析
    今回は実験的に学習過程を解析しているQualitatively characterizing neural network optimization problems, Goodfellow+, ICLR'15を例に実験してみました。
  論文中では2次元のデータで多くの情報を語ってくれるのはloss curveである、という記載がありloss curveによる可視化が行われています。
  この論文がどのようなことを示したかと言うと、学習の始点と結果的な終点を線形で中割りしてその経路に沿ってloss curveを見るとなめらかな下り坂だったということです。
  "学習の始点と結果的な終点を線形で中割り"というのは、具体的な操作として、パラメタの初期値と学習の終点でのパラメタを線形的に補間することに対応します。
  パラメタを線形的に変化させている部分を式で書くと となります。
となります。
  ここで、 は初期状態のパラメタ、
は初期状態のパラメタ、 は学習を打ち切った時点でのパラメタです。
は学習を打ち切った時点でのパラメタです。
  これは学習が障壁なく進んでいるとすればパラメタを線形に変化させるとlossも線形的に変化するはず、という仮定のもとlossを観測していることになります。
  通常はlossを使ってweightを更新していくので、通常のtrainingとは逆の操作をしている事になります。
  論文中の実験結果を見ると、ほとんどの場合大きな障壁なくlossがスムーズに減少している様子が伺えます。
  つまり多数の局所解や鞍点により学習が極端に難しいと長く考えられていたが、実験してみると学習の経路以外も意外とスムーズな下りの関数であったということです。
  今回は論文内容の追試も兼ねて、以下の2パターンについて実験しました(論文中ではもっと多くのバリエーションで実験が行われています)。
| 層数 | 隠れ層UNIT数 | 図中のLABEL | 
| 4 | 100 | 4nn_hlunit=100 | 
| 3 | 1000 | 3nn_hlunit=1000 | 
  通常のtraining時のloss-epoch curveは以下になります。
  どちらの構造でも綺麗に収束しています。
  
  線形的に補間したパラメタの値を用いてlossを計算したloss-alpha curveが以下になります。
  論文中ではパラメタの初期値と収束後のパラメタを線形的に変化させているのですが、今回は1epoch時点のパラメタと200epoch時点のパラメタを使用し、alphaを[0, 1]で変化させてlossを計算しています。
  そのため、スタート地点でのlossの値が小さ目に出ています。
  下の図を見ると、4nn_hlunit=100では一本の滑らかなスロープを障壁なしに下っていっているような様子がわかります。
  最適化対象の関数が凸に近い形状になっていると判断できます。
  一方、3nn_hlunit=1000では線形的にlossが落ちておらず、alpha=[0.0, 0.3]付近では最適化対象の関数が非凸な形状になっている事が予想されます。
  上の図では停滞している様子は見られなかったため、非凸な形状の部分を上手く避けて(通り抜けて)学習が進んだと考えられます。
  
  参考コード
    今回はMNISTデータセットを使用し、chainerのMNIST   exampleのコードを変更したもので実験を行っています。
  コードとしては主に、
  §   MNISTデータでの訓練用script
  §   学習済みモデルのパラメタから学習過程を眺めるためのscript
  の2つを作成しています。
  他、MNIST用のNNモデルのコードを分離して2つのscriptで使いまわしています。
  こちらも以下に貼り付けてあります。
  実行環境はPython 3.5.1、Chainer 1.13.0です。
  MNIST用の4層NNモデル
    一般的な全結合のNNです。
  Dropoutのハイパパラメタはdefault値の0.5のまま使用しています。
  lossの計算をしていないのは、chainerのclassifierでwrapしているためです。
  classifierではこちらで指定しなければ、softmax_cross_entropyを計算してくれるようになっています。
| 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 | 
 
 
     
 
 
 
 
 
 
 
 
   
 
   
 
 
 | 
  MNISTデータでの訓練用script
    PFNさんが用意してくださっているtrain_mnist.pyをほぼそのまま使用しています。
  今回はtrainerではなくmodelのsnapshotが欲しかったので、snapshot_objectでmodelのobjectを保存するようにしています。
  SGDのmomentumはdefaultの0.9としています。
| 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 | 
 
 
 
 
 
     
 
 
 
 
 
   
 
 
 
   
 
 
   
 
   
 
 
   
 
 
   
 
 
   
   
 
 
 
 
 
 
 
 
 
 
   
 
 
   
 
 
 
 
 
 
 
 
 
 
 | 
  学習済みモデルのパラメタから学習過程を眺めるためのscript
    classifierを使用しているので、weightを取得する際に一度predictorを経由する必要があります。
  この点、少しわかりわかりづらくなっていると思います。
| 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 | 
 
 
 
 
 
 
 
 
 
 
     
 
   
 
     
 
 
 
 
 
 
 
 
 
 
     
 
 
 
 
 
 
 
 
 
     
 
   
 
 
 
 
   
   
 
 
 
 
 
   
 
 
 
 
 
 
   
 
 
 
 
 
 
 
   
 
 
   
 
 
     
 
 
   
 
 
 
   
 
 
 
 
 
 
 
 
 
 
 | 
  まとめ
    無駄にパラメタ数を増やしてみると完全にlinearではなくloss関数が歪んでいることがわかりました。
  loss関数が非凸であっても学習に対する大きな障壁がなくスムーズに学習が進んでいる様子も確認出来ました。
  今回紹介した内容とは異なりますが、DeepLearningのモデルが画像のどこをみて答えを出したか などの研究3も進んできているので、学習過程などDeepLearningの気持ちを理解する研究が更に進めば、お客さんに対する説明力も更に高まると考えています。
  今後もDeepLearningの可視化や理論的/実験的解析の研究に注目していきたいと思います。
  1.     低層のパラメタ勾配がほぼゼロになってしまう現象。活性化関数の微分の積を繰り返す事によって発生する事が多い。重みの初期値に依存する場合もありうる。 
  2.     パラメタを共有すると、少ないパラメタ(表現力の低いモデル)で必要な情報を抽出しようとするため正則化効果が見込まれる。 
  3.     "Why Should I Trust You?"   Explaining the Predictions of Any Classifier, Ribeiro+, '16やOBJECT   DETECTORS EMERGE IN DEEP SCENE CNNS, Zhou+, '15 
0 件のコメント:
コメントを投稿