web-dev-qa-db-ja.com

TensorFlow:dataset.train.next_batchはどのように定義されていますか?

私はTensorFlowを学び、例を研究しようとしています: https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/3_NeuralNetworks/autoencoder.ipynb

次に、以下のコードにいくつか質問があります。

for Epoch in range(training_epochs):
    # Loop over all batches
    for i in range(total_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        # Run optimization op (backprop) and cost op (to get loss value)
        _, c = sess.run([optimizer, cost], feed_dict={X: batch_xs})
    # Display logs per Epoch step
    if Epoch % display_step == 0:
        print("Epoch:", '%04d' % (Epoch+1),
              "cost=", "{:.9f}".format(c))

Mnistは単なるデータセットであるため、mnist.train.next_batch意味? dataset.train.next_batch定義済み?

ありがとう!

13
Edamame

mnistオブジェクトは、tf.contrib.learnモジュールで定義されている read_data_sets()関数 から返されます。 mnist.train.next_batch(batch_size)メソッドは here で実装され、2つの配列のタプルを返します。最初の配列はbatch_size MNISTイメージのバッチを表し、2番目は配列のバッチを表しますbatch-sizeこれらの画像に対応するラベル。

画像はサイズ[batch_size, 784](MNIST画像に784ピクセルがあるため)の2次元NumPy配列として返され、ラベルはサイズ[batch_size]の1次元NumPy配列として返されます。 (read_data_sets()one_hot=Falseで呼び出された場合)またはサイズ[batch_size, 10]の2-D NumPy配列(read_data_sets()one_hot=Trueで呼び出された場合)。

25
mrry