.fit_generator()
または.fit()
を使用して画像分類子をトレーニングし、辞書をclass_weight=
に引数として渡します。
TF1.xでエラーが発生したことはありませんが、2.1ではトレーニングを開始すると次の出力が表示されます。
WARNING:tensorflow:sample_weight modes were coerced from
...
to
['...']
...
から['...']
に強制変換するとはどういう意味ですか?
tensorflow
のリポジトリに関するこの警告のソースは here であり、コメントは次のとおりです。
Sample_weight_modesをターゲット構造に強制変換しようとしました。これは、Modelが内部表現の出力を平坦化するという事実に暗黙的に依存しています。
これは偽のメッセージのようです。 TensorFlow 2.1にアップグレードした後も同じ警告メッセージが表示されますが、クラスの重みまたはサンプルの重みをまったく使用していません。私はこのようなタプルを返すジェネレータを使用します:
return inputs, targets
そして今、私はそれを次のように変更して警告を消しました:
return inputs, targets, [None]
これが関連するかどうかはわかりませんが、私のモデルでは3つの入力を使用しているため、inputs
変数は実際には3つのnumpy配列のリストです。 targets
は単一のnumpy配列です。
いずれにせよ、それは単なる警告です。トレーニングはどちらの方法でも問題なく機能します。
このバグはTensorFlow 2.2で修正されたようで、すばらしいです。ただし、上記の修正はTF 2.2では失敗します。これは、サンプルの重みの形状を取得しようとするためで、明らかにAttributeError: 'NoneType' object has no attribute 'shape'
で失敗します。そのため、2.2にアップグレードするときに上記の修正を元に戻します。
これはテンソルフローのバグであり、デフォルトのパラメータsample_weight_mode=None
を使用してmodel.compile()
を呼び出し、次にsample_weight
またはclass_weight
を指定してmodel.fit()
を呼び出すと発生すると考えられます。
Tensorflowリポジトリから:
fit()
は最終的に_process_training_inputs()
を呼び出します_process_training_inputs()
setssample_weight_modes = [None]
に基づいてmodel.sample_weight_mode = None
を作成し、sample_weight_modes = [None]
でDataAdapter
を作成しますDataAdapter
は、 初期化 中にsample_weight_modes = [None]
を使用してbroadcast_sample_weight_modes()
を呼び出しますbroadcast_sample_weight_modes()
期待しているようですsample_weight_modes = None
を受け取りますが、[None]
を受け取ります[None]
がsample_weight
/class_weight
とは異なる構造であると断言し、sample_weight
/class_weight
の構造に適合させることによってNone
に上書きし、警告を出力しますただし、DataAdapter
のsample_weight_modes
がNone
に設定されているため、これはfit()
には影響しません。
Tensorflow documentation は、sample_weight
はnumpy-arrayでなければならないことを述べていることに注意してください。代わりにfit()
をsample_weight.tolist()
で呼び出すと警告は表示されませんが、_process_numpy_inputs()
が呼び出されたときにsample_weight
がNone
に暗黙的に上書きされます preprocessing し、1より大きい長さの入力を受け取ります。
私はあなたの要旨を取り、TFAの代わりにTensorflow 2.0をインストールしましたが、そのような警告なしで機能しました。
これが完全なコードの Gist です。 Tensorflowをインストールするためのコードを以下に示します。
!pip install tensorflow==2.0
成功した実行のスクリーンショットを以下に示します。
更新:このバグはTensorflow Version 2.2.
で修正されています