最近、テンソルフローを学び始めました。
違いがあるかどうかわからない
x = np.array([[1],[2],[3],[4],[5]])
dataset = tf.data.Dataset.from_tensor_slices(x)
ds.shuffle(buffer_size=4)
ds.batch(4)
そして
x = np.array([[1],[2],[3],[4],[5]])
dataset = tf.data.Dataset.from_tensor_slices(x)
ds.batch(4)
ds.shuffle(buffer_size=4)
また、なぜ使えないのかわかりません
dataset = dataset.shuffle_batch(buffer_size=2,batch_size=BATCH_SIZE)
エラーが発生するため
dataset = dataset.shuffle_batch(buffer_size=2,batch_size=BATCH_SIZE)
AttributeError: 'TensorSliceDataset' object has no attribute 'shuffle_batch'
ありがとうございました!
TL; DR:はい、違いがあります。ほとんどの場合、 Dataset.shuffle()
beforeDataset.batch()
)を呼び出します。 。 tf.data.Dataset
クラスにはshuffle_batch()
メソッドはありません。データセットをシャッフルしてバッチ処理するには、2つのメソッドを別々に呼び出す必要があります。
tf.data.Dataset
の変換は、呼び出されるのと同じ順序で適用されます。 Dataset.batch()
は、入力の連続する要素を、出力の単一のバッチ要素に結合します。次の2つのデータセットを検討することで、操作の順序の影響を確認できます。
tf.enable_eager_execution() # To simplify the example code.
# Batch before shuffle.
dataset = tf.data.Dataset.from_tensor_slices([0, 0, 0, 1, 1, 1, 2, 2, 2])
dataset = dataset.batch(3)
dataset = dataset.shuffle(9)
for elem in dataset:
print(elem)
# Prints:
# tf.Tensor([1 1 1], shape=(3,), dtype=int32)
# tf.Tensor([2 2 2], shape=(3,), dtype=int32)
# tf.Tensor([0 0 0], shape=(3,), dtype=int32)
# Shuffle before batch.
dataset = tf.data.Dataset.from_tensor_slices([0, 0, 0, 1, 1, 1, 2, 2, 2])
dataset = dataset.shuffle(9)
dataset = dataset.batch(3)
for elem in dataset:
print(elem)
# Prints:
# tf.Tensor([2 0 2], shape=(3,), dtype=int32)
# tf.Tensor([2 1 0], shape=(3,), dtype=int32)
# tf.Tensor([0 1 1], shape=(3,), dtype=int32)
最初のバージョン(シャッフル前のバッチ)では、各バッチの要素は入力からの3つの連続した要素です。一方、2番目のバージョン(バッチの前にシャッフル)では、入力からランダムにサンプリングされます。通常、ミニバッチ(のいくつかの変形)によるトレーニング 確率的勾配降下法 の場合、各バッチの要素は、入力全体から可能な限り均一にサンプリングする必要があります。そうしないと、ネットワークが入力データに含まれていた構造に適合しすぎて、結果のネットワークがそれほど高い精度を達成できない可能性があります。
@mrryに完全に同意しますが、バッチ処理を実行したい場合が1つありますbeforeシャッフル。 RNNにフィードされるテキストデータを処理しているとします。ここでは、各文が1つのシーケンスとして扱われ、1つのバッチに複数のシーケンスが含まれます。文の長さは可変であるため、バッチで文を均一な長さにpadする必要があります。これを行う効率的な方法は、同じ長さの文をグループ化するバッチ処理を行ってから、シャッフルすることです。そうしないと、<pad>
トークンでいっぱいのバッチになってしまう可能性があります。