web-dev-qa-db-ja.com

kerasでtrain_on_batch()の使用は何ですか?

Train_on_batch()とfit()の違いは何ですか? train_on_batch()を使用する必要がある場合はどうなりますか?

21
Dhairya Verma

fitはKerasで一般的に利用可能なAPI関数ではないため、train_on_batchtrain(およびfit_generatorのようなバリエーション)と比較するつもりだと思います。

この質問の場合、それは 主著者からの簡単な回答

Fit_generatorを使用すると、検証データにもジェネレーターを使用できます。一般に、fit_generatorを使用することをお勧めしますが、train_on_batchを使用しても問題なく機能します。これらのメソッドは、さまざまなユースケースでの利便性のためだけに存在し、「正しい」メソッドはありません。

train_on_batchを使用すると、固定バッチサイズに関係なく、指定したサンプルのコレクションに基づいて明示的に重みを更新できます。これは、サンプルの明示的なコレクションでトレーニングする場合に使用します。このアプローチを使用して、従来のトレーニングセットの複数のバッチで独自の反復を維持できますが、fitまたはfit_generatorでバッチを反復できるようにする方が簡単です。

train_on_batchを使用するのが適切な場合の1つは、単一の新しいサンプルバッチで事前に訓練されたモデルを更新する場合です。モデルのトレーニングと展開をすでに完了しており、しばらくしてから、以前は使用したことがなかった新しいトレーニングサンプルのセットを受け取ったとします。 train_on_batchを使用して、これらのサンプルでのみ既存のモデルを直接更新できます。他の方法でもこれを行うことができますが、この場合はtrain_on_batchを使用することがかなり明示的です。

このような特別なケース(別のトレーニングバッチ間で独自のカーソルを維持する何らかの教育的理由がある場合、または特別なバッチでのセミオンライントレーニング更新の場合)を除き、常に常にfitを使用することをお勧めします(メモリに収まるデータ用)またはfit_generator(ジェネレーターとしてデータのバッチをストリーミングするため)。

26
ely

train_on_batch()は、たとえば、ステートフルLSTMを使用し、model.reset_states()への呼び出しを制御する必要がある場合など、LSTMの状態をより細かく制御できます。複数シリーズのデータ​​があり、各シリーズの後に状態をリセットする必要がある場合がありますが、train_on_batch()で実行できますが、.fit()を使用すると、ネットワークはすべてのシリーズのデータ​​でトレーニングされます状態をリセットします。正しいか間違っているかはありません。使用しているデータ、およびネットワークの動作方法によって異なります。

10
BigBadMe

Train_on_batchは、大きなデータセットを使用し、tfrecordsに書き込むための簡単にシリアル化可能なデータ(上位のnumpy配列など)を持たない場合、fit and fit generatorよりもパフォーマンスが向上します。

この場合、配列全体をnumpyファイルとして保存し、セット全体がメモリに収まらない場合に、それらの小さなサブセット(traina.npy、trainb.npyなど)をメモリにロードできます。その後、tf.data.Dataset.from_tensor_slicesを使用してから、サブデータセットでtrain_on_batchを使用し、別のデータセットをロードして、バッチで再度トレーニングを呼び出すなど、今ではセット全体でトレーニングを行い、正確にどのくらい、何を制御できるデータセットのがモデルをトレーニングします。その後、データセットから取得する単純なループと関数を使用して、独自のエポック、バッチサイズなどを定義できます。

1
Adam Collins