Tensorflowの正しいバッチ正規化機能とは何ですか?
Tensorflow 1.4で、バッチ正規化を行う2つの関数を見つけましたが、それらは同じように見えます。
どの機能を使用すればよいですか?どちらがより安定していますか?
リストに追加するだけで、テンソルフローでバッチノルムを実行する方法がいくつかあります。
tf.nn.batch_normalization
は低レベルの操作です。呼び出し元は、mean
およびvariance
テンソル自体を処理する責任があります。tf.nn.fused_batch_norm
は、前のものと同様の別の低レベルopです。違いは、4D入力テンソル用に最適化されていることです。これは、畳み込みニューラルネットワークの通常の場合です。tf.nn.batch_normalization
は、1より大きいランクのテンソルを受け入れます。tf.layers.batch_normalization
は、前のopsに対する高レベルのラッパーです。最大の違いは、実行中の平均テンソルと分散テンソルの作成と管理を行い、可能な場合は高速融合演算を呼び出すことです。通常、これはデフォルトの選択肢である必要があります。tf.contrib.layers.batch_norm
は、コアAPI(つまり、tf.layers
)に移行する前のバッチ標準の初期実装です。将来のリリースで削除される可能性があるため、使用は推奨されません。tf.nn.batch_norm_with_global_normalization
はもう1つの推奨されないopです。現在、tf.nn.batch_normalization
への呼び出しを委任していますが、将来的には削除される可能性があります。- 最後に、Kerasレイヤー
keras.layers.BatchNormalization
もあります。これは、テンソルフローバックエンドの場合にtf.nn.batch_normalization
を呼び出します。
doc 、tf.contrib
は、揮発性または実験的なコードを含む貢献モジュールです。 function
が完了すると、このモジュールから削除されます。履歴バージョンとの互換性を保つために、現在2つあります。
したがって、前者tf.layers.batch_normalization
がおすすめ。