tf.estimator.Estimator
でトレーニングを管理したいのですが、 tf.data
APIと一緒に使用するのに問題があります。
私はこのようなものを持っています:
def model_fn(features, labels, params, mode):
# Defines model's ops.
# Initializes with tf.train.Scaffold.
# Returns an tf.estimator.EstimatorSpec.
def input_fn():
dataset = tf.data.TextLineDataset("test.txt")
# map, shuffle, padded_batch, etc.
iterator = dataset.make_initializable_iterator()
return iterator.get_next()
estimator = tf.estimator.Estimator(model_fn)
estimator.train(input_fn)
ユースケースにmake_one_shot_iterator
を使用できないため、私の問題は、input_fn
にmodel_fn
内で初期化する必要のあるイテレータが含まれていることです(ここでは、 tf.train.Scaffold
ローカル操作を初期化します)。
また、input_fn = iterator.get_next
だけを使用することはできないことを理解しました。そうしないと、他の操作が同じグラフに追加されません。
イテレータを初期化するための推奨される方法は何ですか?
TensorFlow 1.5以降、input_fn
にtf.data.Dataset
を返すことができます。例:
def input_fn():
dataset = tf.data.TextLineDataset("test.txt")
# map, shuffle, padded_batch, etc.
return dataset
c294fcfd を参照してください。
以前のバージョンでは、イテレーターの初期化子をtf.GraphKeys.TABLE_INITIALIZERS
コレクションに追加し、デフォルトの初期化子に依存することができます。
tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)