web-dev-qa-db-ja.com

tf.estimator input_fn内でtf.data初期化可能イテレーターを使用するにはどうすればよいですか?

tf.estimator.Estimator でトレーニングを管理したいのですが、 tf.data APIと一緒に使用するのに問題があります。

私はこのようなものを持っています:

def model_fn(features, labels, params, mode):
  # Defines model's ops.
  # Initializes with tf.train.Scaffold.
  # Returns an tf.estimator.EstimatorSpec.

def input_fn():
  dataset = tf.data.TextLineDataset("test.txt")
  # map, shuffle, padded_batch, etc.

  iterator = dataset.make_initializable_iterator()

  return iterator.get_next()

estimator = tf.estimator.Estimator(model_fn)
estimator.train(input_fn)

ユースケースにmake_one_shot_iteratorを使用できないため、私の問題は、input_fnmodel_fn内で初期化する必要のあるイテレータが含まれていることです(ここでは、 tf.train.Scaffold ローカル操作を初期化します)。

また、input_fn = iterator.get_nextだけを使用することはできないことを理解しました。そうしないと、他の操作が同じグラフに追加されません。

イテレータを初期化するための推奨される方法は何ですか?

10
guillaumekln

TensorFlow 1.5以降、input_fntf.data.Datasetを返すことができます。例:

def input_fn():
  dataset = tf.data.TextLineDataset("test.txt")
  # map, shuffle, padded_batch, etc.
  return dataset

c294fcfd を参照してください。


以前のバージョンでは、イテレーターの初期化子をtf.GraphKeys.TABLE_INITIALIZERSコレクションに追加し、デフォルトの初期化子に依存することができます。

tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
13
guillaumekln