2018年2月28日水曜日

数式で書き下す長短期記憶(LSTM)

RNNでは時間を遡ってバックプロパゲーションをすることにより、時系列に含まれる情報を隠れ層に保持しておくことができるとされていますが、一方で、遡る時間 ττ は理論上は無限にできても、実用上は τ=3τ=3 くらいが限度という問題がありました。(これをtruncated BPTTと呼びました。)
この問題に対処すべく考案されたのが 長短期記憶Long Short Term Memory: LSTM)です。LSTMは、隠れ層の各ニューロンに対して、
  • CEC (Constant Error Carousel)
  • 入力ゲート
  • 忘却ゲート
  • 出力ゲート
の4つを足すことで、 LSTMブロック (LSTM block) あるいは LSTM記憶ブロック (LSTM memory block) と呼ばれる回路のような仕組みを構成し、時系列データの学習をより効率的に行えるようにしたモデルです。 グラフィカルモデルで表すと、下図のようになります。
この点線で囲まれた部分がLSTMブロックになります。 f,gf,g は活性化関数を表し、x(t),h(t)x(t),h(t) はそれぞれ入力値、LSTMの出力値(=隠れ層の値)を、 i(t),f(t),o(t)i(t),f(t),o(t) はそれぞれ入力ゲート、忘却ゲート、出力ゲートの値を表しています。 もはや隠れ層の各ユニットはひとつのニューロンではなく、回路で構成されていることが分かります。

さて、見た目上はとても複雑な作りをしているLSTMですが、考えるべきことは他のモデルと変わりありません。順伝播・逆伝播をきちんと数式で書けさえすれば、実装に落とし込むことができます。上図に沿って数式で表してみましょう。まずは順伝播ですが、これは下のように書くことができます。入力層からの値と、各ゲートの値およびセル(CEC)の値を考える必要があることに注意してください。
a(t)=f(Wcx(t)+Uch(t1)+bc)i(t)=σ(Wix(t)+Uih(t1)+bi)f(t)=σ(Wfx(t)+Ufh(t1)+bf)o(t)=σ(Wox(t)+Uoh(t1)+bo)(1)(2)(3)(4)(1)a(t)=f(Wcx(t)+Uch(t−1)+bc)(2)i(t)=σ(Wix(t)+Uih(t−1)+bi)(3)f(t)=σ(Wfx(t)+Ufh(t−1)+bf)(4)o(t)=σ(Wox(t)+Uoh(t−1)+bo)
ただし、W,UW,U は重み行列を、bbはバイアスベクトルを表し、σσ はゲートにおける活性化関数を表しています(シグモイド関数でなくても問題ありません)。
ここで、式の整理のために下記の定義をしておきます。
a(t):=f(^a(t))i(t):=σ(^i(t))f(t):=σ(^f(t))o(t):=σ(^o(t))(5)(6)(7)(8)(5)a(t):=f(a^(t))(6)i(t):=σ(i^(t))(7)f(t):=σ(f^(t))(8)o(t):=σ(o^(t))
すると、式(1)~(4) は下記の式でまとまります。
s(t):=⎛⎜
⎜⎝^a(t)^i(t)^f(t)^o(t)⎞⎟
⎟⎠=⎛⎜
⎜⎝WcUcbcWiUibiWfUfbfWoUobo⎞⎟
⎟⎠⎛⎜⎝x(t)h(t1)1⎞⎟⎠(9)(9)s(t):=(a^(t)i^(t)f^(t)o^(t))=(WcUcbcWiUibiWfUfbfWoUobo)(x(t)h(t−1)1)
また、誤差を記憶しておくのに重要なセルの値は、下式で表せます。
c(t)=i(t)a(t)+f(t)c(t1)(10)(10)c(t)=i(t)a(t)+f(t)c(t1)
ただし、⊙⊙ はベクトルの要素積を表します。これらにより、隠れ層の出力 h(t)h(t) は下式となります。
h(t)=o(t)g(c(t))(11)(11)h(t)=o(t)g(c(t))
以上が順伝播となります。深層学習のモデルとしては、更にここに出力層を加えて学習を行うことになります。

続いて、逆伝播を考えていきます。他のディープラーニングのモデルと同様、誤差関数を EE としたときの各モデルパラメータ(重み W,UW,U および バイアス bb)に対する勾配を求めていくことになりますが、LSTMブロックにはセルおよびゲートがあるので、これらの誤差を求める必要があります。
出力層における誤差はすでに伝播してきているため、隠れ層の出力における誤差
eh(t):=∂Eh(t)(12)(12)eh(t):=∂E∂h(t)
は既知であることに注意してください。
まず、出力ゲートの誤差は式(11)より
eo(t):=∂Eo(t)=∂Eh(t)∂h(t)∂o(t)=eh(t)g(c(t))(13)(13)eo(t):=∂E∂o(t)=∂E∂h(t)∂h(t)∂o(t)=eh(t)g(c(t))
となります。セルの誤差は少し注意が必要で、勾配は
Ec(t)=∂Eh(t)∂h(t)∂c(t)=eh(t)o(t)g(c(t))(14)(14)∂E∂c(t)=∂E∂h(t)∂h(t)∂c(t)=eh(t)o(t)g′(c(t))
となりますが、セルの誤差はその場にとどまり続けるので、
ec(t)+=∂Ec(t)(15)(15)ec(t)+=∂E∂c(t)
と、式で表現しようとした場合は ++ がつくことが他の誤差項とは異なります。 また、t1t−1までの誤差も考慮する必要があるので、下記も求めておきます。
Ec(t−1)=∂Ec(t)∂c(t)∂c(t−1)=ec(t)f(t)(16)(16)∂E∂c(t−1)=∂E∂c(t)∂c(t)∂c(t−1)=ec(t)f(t)
ここで、式(10)を式変形に用いました。

セルの誤差(勾配)が求まったので、残りのパラメータの誤差も求めることができます。忘却ゲートの誤差は、
ef(t):=∂Ef(t)=∂Ec(t)∂c(t)∂f(t)=ec(t)c(t1)(17)(17)ef(t):=∂E∂f(t)=∂E∂c(t)∂c(t)∂f(t)=ec(t)c(t1)
また、入力ゲートの誤差は、
ei(t):=∂Ei(t)=∂Ec(t)∂c(t)∂i(t)=ec(t)a(t)(18)(18)ei(t):=∂E∂i(t)=∂E∂c(t)∂c(t)∂i(t)=ec(t)a(t)
で与えられます。また、LSTMブロックの入力に対する誤差は、
ea(t):=∂Ea(t)=∂Ec(t)∂c(t)∂a(t)=ec(t)i(t)(19)(19)ea(t):=∂E∂a(t)=∂E∂c(t)∂c(t)∂a(t)=ec(t)i(t)
となります。 よって、活性前の項の誤差を求めると、式(5)~(8)より、
e^a(t):=∂E^a(t)=∂Ea(t)∂a(t)∂^a(t)=ea(t)f(^a(t))(20)(20)ea^(t):=∂E∂a^(t)=∂E∂a(t)∂a(t)∂a^(t)=ea(t)f′(a^(t))

e^i(t):=∂E^i(t)=∂Ei(t)∂i(t)∂^i(t)=ei(t)σ(^i(t))(21)(21)ei^(t):=∂E∂i^(t)=∂E∂i(t)∂i(t)∂i^(t)=ei(t)σ′(i^(t))

e^f(t):=∂E^f(t)=∂Ef(t)∂f(t)∂^f(t)=ef(t)σ(^f(t))(22)(22)ef^(t):=∂E∂f^(t)=∂E∂f(t)∂f(t)∂f^(t)=ef(t)σ′(f^(t))

e^o(t):=∂E^o(t)=∂Eo(t)∂o(t)∂^o(t)=eo(t)σ(^o(t))(23)(23)eo^(t):=∂E∂o^(t)=∂E∂o(t)∂o(t)∂o^(t)=eo(t)σ′(o^(t))
が求まり、式(20)~(23)をまとめると、式(9)より、
es(t):=∂Es(t)=⎛⎜
⎜⎝e^a(t)e^i(t)e^f(t)e^o(t)⎞⎟
⎟⎠(24)(24)es(t):=∂E∂s(t)=(ea^(t)ei^(t)ef^(t)eo^(t))
となります。
ここで、式(9)の右辺を更に変形すべく、下記で定義される W,z(t)W,z(t) を用いると、
W:=⎛⎜
⎜⎝WcUcbcWiUibiWfUfbfWoUobo⎞⎟
⎟⎠(25)(25)W:=(WcUcbcWiUibiWfUfbfWoUobo)
z(t):=⎛⎜⎝x(t)h(t1)1⎞⎟⎠(26)(26)z(t):=(x(t)h(t−1)1)
(9)
s(t)=Wz(t)(27)(27)s(t)=Wz(t)
と、一般的な線形活性の式の形となるため、
ez(t):=∂Ez(t)=WTes(t)(28)(28)ez(t):=∂E∂z(t)=WTes(t)
および
EW(t)=es(t)z(t)T(29)(29)∂E∂W(t)=es(t)z(t)T
が得られます。よって、まず
eh(t1):=∂Eh(t1)(30)(30)eh(t−1):=∂E∂h(t−1)
は式(28)のベクトル成分を比較することで求まることが分かります。
また、式(29)より、入力 x(t)x(t)  t=1,...,Tt=1,...,T だとすると、
EW=Tt=1EW(t)(31)(31)∂E∂W=∑t=1T∂E∂W(t)
となり、確率的勾配降下法を用いることでモデルのパラメータを更新することができるようになります。
以上がLSTMの理論(数式)となります。グラフィカルモデルはとても複雑ですが、実は式で表してみると単純な線形代数および偏微分の組み合わせに過ぎないことがわかります。 画像認識に関しては畳み込みニューラルネットワーク(CNN)が用いられるケースが多いですが、時系列のデータを扱うにはRNNLSTMなども積極的に試してみると色々と面白い結果が得られるのではないでしょうか。

医療に対してAI技術を活用している事例に関しては、画像診断の領域が特に活発に見えます。これはおそらく画像認識とディープラーニング(特にCNN)の相性がいいことが主たる理由だと思いますが、弊社では、画像だけでなく、自然言語処理はもちろん、服薬遵守に関する行動といった時系列データについてもデータマイニングを行い、医療に役立つ新たな知見が得られないかアプローチしています。 ミスが許されない医療分野だからこそ、ディープラーニングをはじめとした人工知能のモデルも、ライブラリをただ使うだけではなく、理論をきちんと理解して開発に取り組むことを弊社では心がけています。今後も人工知能やアプリケーションに関する記事をアップデートしていく予定なので、ぜひご期待ください!
 
 

0 件のコメント:

コメントを投稿