web-dev-qa-db-ja.com

TensorFlowのDataset APIを使用してデータセットを数回反復する方法は?

データセットの値を複数回出力する方法は? (データセットはTensorFlowのデータセットAPIによって作成されます)

import tensorflow as tf

dataset = tf.contrib.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
sess = tf.Session()
Epoch = 10

for i in range(Epoch):
   for j in range(100):
      value = sess.run(next_element)
      assert j == value
      print(j)

エラーメッセージ:

tensorflow.python.framework.errors_impl.OutOfRangeError: End of sequence
 [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[]], output_types=[DT_INT64], _device="/job:localhost/replica:0/task:0/cpu:0"](OneShotIterator)]]

これを機能させる方法は?

9
void

まず、 Data Set Guide を読むことをお勧めします。 DataSet APIのすべての詳細が説明されています。

あなたの質問は、データを数回繰り返すことです。そのための2つのソリューションを次に示します。

  1. すべてのエポックを一度に繰り返し、個々のエポックの終了に関する情報なし
import tensorflow as tf

Epoch   = 10
dataset = tf.data.Dataset.range(100)
dataset = dataset.repeat(Epoch)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
sess = tf.Session()

num_batch = 0
j = 0
while True:
    try:
        value = sess.run(next_element)
        assert j == value
        j += 1
        num_batch += 1
        if j > 99: # new Epoch
            j = 0
    except tf.errors.OutOfRangeError:
        break

print ("Num Batch: ", num_batch)
  1. 2番目のオプションは、各エポックの終了について通知します。検証損失の確認:
import tensorflow as tf

Epoch = 10
dataset = tf.data.Dataset.range(100)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
sess = tf.Session()

num_batch = 0

for e in range(Epoch):
    print ("Epoch: ", e)
    j = 0
    sess.run(iterator.initializer)
    while True:
        try:
            value = sess.run(next_element)
            assert j == value
            j += 1
            num_batch += 1
        except tf.errors.OutOfRangeError:
            break

print ("Num Batch: ", num_batch)
22
melgor89

テンソルフローのバージョンが1.3以上の場合、高レベルAPI _tf.train.MonitoredTrainingSession_をお勧めします。このAPIによって作成されたsessは、sess.should_stop()で_tf.errors.OutOfRangeError_を自動的に検出できます。ほとんどのトレーニング状況では、データをシャッフルし、各ステップでバッチを取得する必要があります。これらを次のコードに追加しました。

_import tensorflow as tf

Epoch = 10
dataset = tf.data.Dataset.range(100)
dataset = dataset.shuffle(buffer_size=100) # comment this line if you don't want to shuffle data
dataset = dataset.batch(batch_size=32)     # batch_size=1 if you want to get only one element per step
dataset = dataset.repeat(Epoch)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

num_batch = 0
with tf.train.MonitoredTrainingSession() as sess:
    while not sess.should_stop():
        value = sess.run(next_element)
        num_batch += 1
        print("Num Batch: ", num_batch)
_
3
Tom

これを試して

while True:
  try:
    print(sess.run(value))
  except tf.errors.OutOfRangeError:
    break

データセットイテレータがデータの最後に到達するたびに、tf.errors.OutOfRangeErrorが発生します。それをexceptでキャッチし、データセットを最初から開始できます。

3
Grigor Carran