私はデータ入力パイプラインを持っています:
tf.Tensor
_にキャストできないタイプの入力データポイント(dictsとwhatnot)私はこれを_tf.data
_パイプラインに適合させようとしましたが、複数のデータポイントの前処理を並行して実行することに行き詰まっています。これまでのところ、これを試しました:
Dataset.from_generator(gen)
を使用して、ジェネレーターで前処理を行います。これは機能しますが、prefetch
と偽のmap
呼び出しの配置に関係なく、各データポイントを順次処理します。並行してプリフェッチすることは不可能ですか?tf.py_function
_にカプセル化して、データセット上で並列にmap
できるようにしましたが、py_function
_の実行は(単一プロセス)pythonインタプリタに渡されるので、python =あまり役に立たないGILinterleave
を使用していくつかのトリックを実行できることを確認しましたが、最初の2つのアイデアから問題のないものは見つかりませんでした。ここで何か不足していますか?グラフで実行できるように前処理を変更する必要がありますか、それともマルチプロセスする方法はありますか?
これを行う以前の方法はkeras.Sequenceを使用することでしたが、うまくいきましたが、_tf.data
_ APIへのアップグレードを推進している人が多すぎます。 (地獄、tf 2.2でkeras.Sequenceを試してみても_WARNING:tensorflow:multiprocessing can interact badly with TensorFlow, causing nondeterministic deadlocks. For high performance data pipelines tf.data is recommended.
_)が生成されます
注:tf 2.2rc3を使用しています
入力パイプラインのbatch()
の前にmap()
を追加してみてください。
これは通常、小さなマップ関数のマップ関数呼び出しのオーバーヘッドを減らすことを目的としています。ここを参照してください https://www.tensorflow.org/guide/data_performance#vectorizing_mapping
ただし、これを使用して、マップへの入力のバッチを_py_function
_取得し、python multiprocessing
を使用して高速化することもできます。
これにより、GILの制限を回避して、tf.data.map()
の_num_parallel_calls
_を_py_function
_マップ関数で使用できなくすることができます。