2018年2月28日水曜日

数式で書き下すリカレントニューラルネットワーク(RNN)

昨今の人工知能ブームのおかげで、 TensorFlow  Deeplearning4j など、気軽にディープラーニングが試せるライブラリがどんどん出てきました。
もちろん、ライブラリを使ってガシガシ実験をし、ガシガシ効果を測定するのは大切なことです。しかし一方で、使っているモデルの理論を理解しないままただライブラリで提供されている手法を試している、ただパラメータチューニングをしている、といったケースも多くなってしまっているのではないでしょうか。
MICIN では、そんな曖昧な理解のまま人工知能技術を使うことがないよう、きちんと数式ベースでモデルの中身を理解することを推奨しています。 今回は、慢性疾患の予防・疾病マネジメントでも大事となってくる時系列データを扱うためのモデルリカレントニューラルネットワークRecurrent Neural Networks: RNN)の中身を3ステップに分けて追っていこうと思います。
 
Step 1
RNNは、「過去の隠れ層の状態も入力に含める」という点が他の一般的なニューラルネットワークと異なっており、これにより時系列データの学習が効率よく行えるモデルになっています。まずは最も単純化したRNNのグラフィカルモデルを描くと、下図のようになります。
時刻 tt において、入力層 x(t)x(t)、隠れ層 s(t)s(t)、出力層 y(t)y(t) と表されますが、隠れ層への入力は、 直前の時刻 t1t1 における隠れ層のニューロンの状態を表す s(t1)s(t1) x(t)x(t)に加わっています。また、 U,V,WU,V,W はそれぞれの層の間の重み行列です。
よって、まずはネットワークの順伝播を式で表すと、下記が得られます。
s(t)=f(Ux(t)+Ws(t1)+b)(1)(1)s(t)=f(Ux(t)+Ws(t1)+b)y(t)=g(Vs(t)+c)(2)(2)y(t)=g(Vs(t)+c)
ここで、 f,gf,g は活性化関数を、b,cb,c はバイアスベクトルを表しています。後の式の整理のため、下記を定義しておきます。
u(t):=Ux(t)+Ws(t1)+b(3)(3)u(t):=Ux(t)+Ws(t1)+bv(t):=Vs(t)+c(4)(4)v(t):=Vs(t)+c
これにより、
s(t)=f(u(t))(5)(5)s(t)=f(u(t))y(t)=g(v(t))(6)(6)y(t)=g(v(t))
と表すことができます。
さて、順伝播が終わったら、次は逆伝播(バックプロパゲーション)を行うことでモデルの学習が進むわけですが、これは誤差関数(評価関数)を設定し、その値を最小化することを考えます。よく用いられる誤差関数のひとつとして、下記の式で与えられる2乗誤差関数があります。
E:=12n|tnyn|2(7)(7)E:=12n|tnyn|2
ここで、 nn はトレーニングデータのインデックスを表し、 tt は教師データのラベルを表します。すなわち2乗誤差関数は、モデルの出力値(予測値)と実際の値との誤差を最小にしたいという、とても直観的な関数になっていることが分かるかと思います。
この誤差関数が最小になるよう、モデルのパラメータを更新していけばいいわけですが、そのためには誤差関数に対する各パラメータの勾配を求める必要があります。RNNのモデルパラメータは V,U,WV,U,W (およびバイアス)なので、これを求めると
EV=∂Ev(t)∂v(t)∂V=∂Ev(t)s(t)=s(t)eo(t)T(8)(8)EV=Ev(t)v(t)V=Ev(t)s(t)=s(t)eo(t)TEU=∂Eu(t)∂u(t)∂U=∂Eu(t)x(t)=x(t)eh(t)T(9)(9)EU=Eu(t)u(t)U=Eu(t)x(t)=x(t)eh(t)TEW=∂Eu(t)∂u(t)∂W=∂Eu(t)s(t−1)=s(t−1)eh(t)T(10)(10)EW=Eu(t)u(t)W=Eu(t)s(t1)=s(t1)eh(t)T
となります。ただし、ここで新しく定義した下記を用いています。
eo(t):=Ev(t)(11)(11)eo(t):=Ev(t)eh(t):=Eu(t)(12)(12)eh(t):=Eu(t)
この eo(t)eo(t) および eh(t)eh(t) はいわゆるバックプロパゲーションにおける 誤差 と呼ばれる項になります。
ちなみに、勾配の式(8), (9), (10)の最後に転置の TT がありますが、厳密にはこれは途中の偏微分式でも付けるべきものです。ここでは式を直観的にわかってもらうために、あえて TT を途中書いていません。以下も途中の式では書かない場合があるので、予めご了承ください。)
この誤差を求めればバックプロパゲーションができるわけですが、これは下記の式変形により求まります。
eo(t)=Ey(t)y(t)v(t)=(y(t)t(t))g(v(t))(13)eo(t)=Ey(t)y(t)v(t)(13)=(y(t)t(t))g(v(t))
eh(t)=Ey(t)y(t)v(t)v(t)s(t)s(t)u(t)=eo(t)TVf(u(t))(14)eh(t)=Ey(t)y(t)v(t)v(t)s(t)s(t)u(t)(14)=eo(t)TVf(u(t))
ここで、  はベクトルの要素積を表します。
よって、式(8), (9), (10), (13), (14)より、下記の式を用いてパラメータの学習ができることになります。
V(t+1)=V(t)ηs(t)eo(t)T(15)(15)V(t+1)=V(t)ηs(t)eo(t)TU(t+1)=U(t)ηx(t)eh(t)T(16)(16)U(t+1)=U(t)ηx(t)eh(t)TW(t+1)=W(t)ηs(t1)eh(t)T(17)(17)W(t+1)=W(t)ηs(t1)eh(t)T
ただし、ηη は学習率です。これがRNNのバックプロパゲーションの基本形となりますが、実はこれではまだ十分ではありません。
Step 2
Step 1 の段階で RNN のバックプロパゲーションを数式で表現することはでき、理論上はこれで学習ができるのですが、式(13)を見ると、かなり計算が煩雑になることが容易に想像がつきます。この煩雑さを避けるために、別の誤差関数を用いて評価することができないかを考えます。
一般的に、活性化関数 gg には softmax関数が用いられますが、これと下記の式で定義される負の交差エントロピー誤差関数を2乗誤差関数の代わりに用いることで、式をかなりすっきりさせることができます。
E:=nktnklogynk(18)(18)E:=nktnklogynk
ここで、 nn は式(7)は同様トレーニングデータのインデックスを表し、kk は出力層のユニット数を表します。(ここはベクトル表記にするとかえって式の見た目が煩雑になるためこのように書きました)
すると、式(13)は下のように書き換えることができ、キレイに式がまとまります。
eo(t)=y(t)t(t)(19)(19)eo(t)=y(t)t(t)
Step 3
さて、前述したとおり、式(15), (16), (17) では学習が十分とは言えません。これらの式は、直近の過去である t1t1 の隠れ状態にしかバックプロパゲーションしていないためです。より長い時間を遡って誤差を反映させ、より長い時間情報を学習するには、下図のように、RNNを時間軸で展開して考える必要があります。
そして、このより長い過去にまで誤差を反映させる手法が BPTTBack Propagation Through Time)という、時間を遡ってバックプロパゲーションするアルゴリズムであり、 RNNがディープラーニングの手法とされているのは、このように何層にも遡って学習をするためです。
ではまず、誤差を過去に反映させるために、1ステップだけ遡って、
eh(t1)=Eu(t1)(20)(20)eh(t1)=Eu(t1)
を求めることを考えてみましょう。これは、下記の式変形により求まります。
eh(t1)=Eu(t)u(t)u(t1)=eh(t)u(t)s(t1)s(t1)u(t1)=eh(t)Wf(u(t1))(21)eh(t1)=Eu(t)u(t)u(t1)=eh(t)u(t)s(t1)s(t1)u(t1)(21)=eh(t)Wf(u(t1))
これで 誤差 eh(t1)eh(t1)  eh(t)eh(t) の式で表すことができたので、式(21) を一般化して、
eh(tτ1)=eh(tτ)Wf(u(tτ1))(22)(22)eh(tτ1)=eh(tτ)Wf(u(tτ1))
が得られます。この ττ がすなわち、過去どれくらいまで遡って学習するかを決めるパラメータです。理想的には ττ は大きければ大きいほどいいのですが、学習コストが莫大になってしまうので、 τ=3τ=3 などが一般的によく用いられます。
以上より、最終的にRNNのパラメータ更新式は下記のようにまとまります。ここで、 VV は過去は関係ないので、式(15)と式(23)は同じ形になっていることに注意してください。
V(t+1)=V(t)ηs(t)eo(t)T(23)(23)V(t+1)=V(t)ηs(t)eo(t)TU(t+1)=U(t)ητz=0x(tz)eh(tz)T(24)(24)U(t+1)=U(t)ηz=0τx(tz)eh(tz)TW(t+1)=W(t)ητz=0s(tz1)eh(tz)T(25)(25)W(t+1)=W(t)ηz=0τs(tz1)eh(tz)T
いかがだったでしょうか。RNNはディープラーニングのモデルの中でもかなり複雑だと思いますが、こうして式で追ってみると、他の手法と計算の仕方自体はあまり変わらないと言えるのではないでしょうか。
数式ベースでモデルを理解することで、他のモデルへの理解も深まると思います。今後も様々なモデルの理論や実装を記事にしていくつもりですが、ぜひ自身でもチャレンジしてみてください。 そして、ぜひ MICIN で一緒に人工知能の開発を進めましょう!
 
 

0 件のコメント:

コメントを投稿