Chainer に LSTM というクラスがあります。名前が示すとおり、RNN(再 帰型ニューラルネット)の LSTM(長・ 短期記憶)を実現するためのクラスで、自然言語処理の公式サンプル ptb などで使用されています。
ソースを追ってみます。
2つの LSTM クラス
LSTM と名前のつくクラスは Chainer に 2 つ存在します。1 つは Chain のサブクラス、もう 1 つは Function のサブクラスです。ここでは便宜上、前者を LSTM (C)、後者を LSTM (F) と書いて区別します。
両クラスは下図の関係にあり、LSTM (C) が LSTM (F) を生成します。順伝播や誤差逆伝播の実際の計算は LSTM (F) で行います。
LSTM (C) の初期状態
LSTM (C) は内部に Linear インスタンスを 2 つ持っており、変数名はそれぞれ "lateral" と "upward" となっています。
(Linear クラスについて詳しくはこちら)
lateral と upward は LSTM (C) の初期化のタイミングで生成されます。
ソースは次のようになっています。
lateral と upward の役割
LSTM は時系列処理で使われます。時系列のデータを扱う場合、データの流れていく方向は下図のように 2 種類あります。1 つは前レイヤーから次レイヤーへの流れ(縦方向)、もう 1 つは時間に沿った t-1 から t+1 への流れ(横方向)です。
データが流れていく際に重みやバイアスを反映しますが、前レイヤーから受け取るデータには upward が、t-1 から受け取るデータには lateral がそれぞれ反映を行います。
LSTM (F) の生成
モデル生成直後、 LSTM (F) のインスタンスはまだ存在していません。Linear (F) のインスタンスも同様に存在していません。では、いつ生成されるかというと、学習ループが始まり順伝播が動き出してからです。
生成後の LSTM (F) と Linear (F) は下図の関係にあり、おおまかに言うと、Linear (F) で重み・バイアスを反映し、 LSTM (F) が LSTM 独自 の処理を行います。
時系列
RNN は時系列データを扱うため、時刻 t の結果を 時刻 t+1 に引き継ぎます。t と t+1 を図示すると
データが横につながっていきます。
ソースで確認すると
self.c はメモリセル値、self.h は出力値です。両方とも次回の__call__で使用されるので、t+1、t+2 ・・・ へとデータは引き継がれていきます。
LSTM (F) の内部
LSTM (F) の内部をのぞいてみます。
順伝播を図解
下図に描いたのは順伝播の処理ですが、メモリセルを中心に 3 つのゲート(入力ゲート、出力ゲート、忘却ゲート) が働いています。図では i ゲート、o ゲート、f ゲートと記載しています。
a、i、f、o の各値は、入力データや t-1 の出力から upward と lateral によってつくられます。
実際にデータを当てはめてみます。極端な例ですが、入力サイズが 2 、出力サイズが 3 の場合、下図のようになります(バッチサイズは 1)。
ptb の場合
Chainer の公式サンプル ptb の場合、出力は 650、バッチサイズは 20 なので、メモリセルのサイズはこうなります。
計算
LSTM (F) の役割は 順伝播と誤差逆伝播の計算です。該当のソースを確認しておきたいと思います。
順伝播
forward メソッドです。
のぞき穴(peephole)は実装されていないようです。
誤差逆伝播
誤差逆伝播は、順伝播で生成されたインスタンスを逆方向にたどりながら計算します。詳しくはこちら
LSTM (F) の backward メソッドは、t+1 と次レイヤー から δ(デルタ)を受け取り、δ を計算します。
Variable の加算
LSTM に関わる Function として、ここまでに LSTM (F) と Linear (F) の 2 つを取り上げましたが、実はもう 1 つ Function が関わっています。Variable に関する Function です。
Variable には __add__ や __sub__の特殊メソッドが定義されており、加算時は Add というクラスが生成される仕組みになっています。この Add が Function を継承しています。
(上の図にはこっそり書き込んであります。lateral と upward の出力を加算する場所)
Variable の加算が動くタイミングはここ、LSTM (C) の__call__メソッドです。
1 | lstm_in += self.lateral(self.h) |
このタイミングで Add クラスの forward が動き、誤差逆伝播で backward が動きます。
0 件のコメント:
コメントを投稿