MNIST のサンプルの 大体 の流れを追ってみます。(詳しい処理はこちらで)
全体の流れはこうなっており
- ① モデルの生成
- ② オプティマイザの生成
- ③ 学習データのダウンロード
- ④ trainer、updater の生成
- ⑤ Extension の登録
- ⑥ 学習ループ
ソースはこんな感じです。
では見ていきます。
学習ループ
run
学習ループはTrainerのrunメソッドの中にあります。
run の中でいろいろやっているようなので見ていきます。
run → Updater
上のソースで updater とあるのは Updater の サブクラスStandardUpdaterです。そのインスタンスの update メソッドが学習ループの中で呼び出されています。後でソースを確認しますが、update メソッド の中でOptimizer → 順伝播、誤差逆伝播が実行されます。
データの管理は Iterator(のサブクラス)が行っており、バッチサイズ毎にデータを切り出して渡してくれます。
では、その辺の流れを StandardUpdater のソースで確認します。
Trainer から update が呼び出され、update_core の中でバッチ 1 回分の学習を行っています。繰り返しを管理するのは Trainer 側です。
Trainer と Extension
ログ出力や進捗バー表示といった拡張機能がExtension(の サブクラス) として部品化されています。
Extension は
- Trainer に登録され
- Trainer から起動され
ます。
起動条件となる トリガは登録時に設定されます。
Extension の登録、起動部分のソースを見てみます。
extend メソッドで登録、run の学習ループの中で起動しています。
評価機能も Extension
テストデータによる評価も Extension 化されており、クラス名は Evaluator となっています。
他の Extension と同じく Trainer から起動され、起動後は順伝播の処理を行います。
ログ出力
学習 時
ログ情報は Trainer のインスタンス変数 observation(辞書型)に一旦書き込まれ、その内容が LogReport クラスによって集計されファイルに出力されます。
ログ情報の中身は損失(loss)と精度(accuracy)です。この情報を observation へ実際に書き込むのは Classifier です。
評価 時
テストデータを使った評価の際も、同様に observation 経由でファイルに出力されます。
ただし、Evaluator の observation は いったん Trainer の observation に統合され、その後、まとめてファイルに出力されます。
まとめ
以上を図にしてみます。
煩雑になるので、Extension は一部だけにとどめます。
学習ループやデータ管理は、以前はアプリケーション側の実装になっていました。それが、抽象化されてフレームワーク内部に隠蔽 された格好です。
これで単純な実装ミスは減ると思います。でも、なんか難しくなった気がしないでもない、です。
0 件のコメント:
コメントを投稿