web-dev-qa-db-ja.com

tensorflowデータセットシャッフルしてからバッチまたはバッチしてからシャッフル

最近、テンソルフローを学び始めました。

違いがあるかどうかわからない

x = np.array([[1],[2],[3],[4],[5]])
dataset = tf.data.Dataset.from_tensor_slices(x)
ds.shuffle(buffer_size=4)
ds.batch(4)

そして

x = np.array([[1],[2],[3],[4],[5]])
dataset = tf.data.Dataset.from_tensor_slices(x)
ds.batch(4)
ds.shuffle(buffer_size=4)

また、なぜ使えないのかわかりません

dataset = dataset.shuffle_batch(buffer_size=2,batch_size=BATCH_SIZE)

エラーが発生するため

dataset = dataset.shuffle_batch(buffer_size=2,batch_size=BATCH_SIZE)
AttributeError: 'TensorSliceDataset' object has no attribute 'shuffle_batch'

ありがとうございました!

7
Lim Kaizhuo

TL; DR:はい、違いがあります。ほとんどの場合、 Dataset.shuffle()beforeDataset.batch())を呼び出します。tf.data.Dataset クラスにはshuffle_batch()メソッドはありません。データセットをシャッフルしてバッチ処理するには、2つのメソッドを別々に呼び出す必要があります。


tf.data.Datasetの変換は、呼び出されるのと同じ順序で適用されます。 Dataset.batch()は、入力の連続する要素を、出力の単一のバッチ要素に結合します。次の2つのデータセットを検討することで、操作の順序の影響を確認できます。

tf.enable_eager_execution()  # To simplify the example code.

# Batch before shuffle.
dataset = tf.data.Dataset.from_tensor_slices([0, 0, 0, 1, 1, 1, 2, 2, 2])
dataset = dataset.batch(3)
dataset = dataset.shuffle(9)

for elem in dataset:
  print(elem)

# Prints:
# tf.Tensor([1 1 1], shape=(3,), dtype=int32)
# tf.Tensor([2 2 2], shape=(3,), dtype=int32)
# tf.Tensor([0 0 0], shape=(3,), dtype=int32)

# Shuffle before batch.
dataset = tf.data.Dataset.from_tensor_slices([0, 0, 0, 1, 1, 1, 2, 2, 2])
dataset = dataset.shuffle(9)
dataset = dataset.batch(3)

for elem in dataset:
  print(elem)

# Prints:
# tf.Tensor([2 0 2], shape=(3,), dtype=int32)
# tf.Tensor([2 1 0], shape=(3,), dtype=int32)
# tf.Tensor([0 1 1], shape=(3,), dtype=int32)

最初のバージョン(シャッフル前のバッチ)では、各バッチの要素は入力からの3つの連続した要素です。一方、2番目のバージョン(バッチの前にシャッフル)では、入力からランダムにサンプリングされます。通常、ミニバッチ(のいくつかの変形)によるトレーニング 確率的勾配降下法 の場合、各バッチの要素は、入力全体から可能な限り均一にサンプリングする必要があります。そうしないと、ネットワークが入力データに含まれていた構造に適合しすぎて、結果のネットワークがそれほど高い精度を達成できない可能性があります。

7
mrry

@mrryに完全に同意しますが、バッチ処理を実行したい場合が1つありますbeforeシャッフル。 RNNにフィードされるテキストデータを処理しているとします。ここでは、各文が1つのシーケンスとして扱われ、1つのバッチに複数のシーケンスが含まれます。文の長さは可変であるため、バッチで文を均一な長さにpadする必要があります。これを行う効率的な方法は、同じ長さの文をグループ化するバッチ処理を行ってから、シャッフルすることです。そうしないと、<pad>トークンでいっぱいのバッチになってしまう可能性があります。

3
R. Zhu