2017年3月14日火曜日

Chainerのソースを解析。学習ループは Trainer に

MNIST のサンプルの 大体 の流れを追ってみます。(詳しい処理はこちらで)

全体の流れはこうなっており

  1. ① モデルの生成
  2. ② オプティマイザの生成
  3. ③ 学習データのダウンロード
  4. ④ trainer、updater の生成
  5. ⑤ Extension の登録
  6. ⑥ 学習ループ

ソースはこんな感じです。

では見ていきます。

学習ループ

run

学習ループはTrainerrunメソッドの中にあります。

run の中でいろいろやっているようなので見ていきます。

run → Updater

上のソースで updater とあるのは Updater の サブクラスStandardUpdaterです。そのインスタンスの update メソッドが学習ループの中で呼び出されています。後でソースを確認しますが、update メソッド の中でOptimizer → 順伝播、誤差逆伝播が実行されます。

データの管理は Iterator(のサブクラス)が行っており、バッチサイズ毎にデータを切り出して渡してくれます。
では、その辺の流れを StandardUpdater のソースで確認します。

Trainer から update が呼び出され、update_core の中でバッチ 1 回分の学習を行っています。繰り返しを管理するのは Trainer 側です。

Trainer と Extension

ログ出力や進捗バー表示といった拡張機能がExtension(の サブクラス) として部品化されています。
Extension は

  • Trainer に登録され
  • Trainer から起動され

ます。

起動条件となる トリガは登録時に設定されます。

Extension の登録、起動部分のソースを見てみます。

extend メソッドで登録、run の学習ループの中で起動しています。

評価機能も Extension

テストデータによる評価も Extension 化されており、クラス名は Evaluator となっています。
他の Extension と同じく Trainer から起動され、起動後は順伝播の処理を行います。

ログ出力

学習 時

ログ情報は Trainer のインスタンス変数 observation(辞書型)に一旦書き込まれ、その内容が LogReport クラスによって集計されファイルに出力されます。

ログ情報の中身は損失(loss)精度(accuracy)です。この情報を observation へ実際に書き込むのは Classifier です。

評価 時

テストデータを使った評価の際も、同様に observation 経由でファイルに出力されます。
ただし、Evaluator の observation は いったん Trainer の observation に統合され、その後、まとめてファイルに出力されます。

まとめ

以上を図にしてみます。
煩雑になるので、Extension は一部だけにとどめます。

学習ループやデータ管理は、以前はアプリケーション側の実装になっていました。それが、抽象化されてフレームワーク内部に隠蔽 された格好です。

これで単純な実装ミスは減ると思います。でも、なんか難しくなった気がしないでもない、です。

0 件のコメント:

コメントを投稿