web-dev-qa-db-ja.com

TensorFlow Datasetでバッチ、リピート、シャッフルは何をしますか?

私は現在TensorFlowを学んでいますが、このコード内で混乱に遭遇しました:

dataset = dataset.shuffle(buffer_size = 10 * batch_size) 
dataset = dataset.repeat(num_epochs).batch(batch_size)
return dataset.make_one_shot_iterator().get_next()

最初にデータセットがすべてのデータを保持することを知っていますが、shuffle()、repeat()、batch()はデータセットに何をしますか?例とともに説明をお願いします

12
blue

あなたがデータセットを持っていると想像してください:_[1, 2, 3, 4, 5, 6]_、そして:

ds.shuffle()の仕組み

dataset.shuffle(buffer_size=3)は、ランダムエントリを選択するためにサイズ3のバッファを割り当てます。このバッファはソースデータセットに接続されます。次のようにイメージできます。

_Random buffer
   |
   |   Source dataset where all other elements live
   |         |
   ↓         ↓
[1,2,3] <= [4,5,6]
_

エントリ_2_がランダムバッファから取得されたと仮定しましょう。空き領域は、ソースバッファの次の要素、つまり_4_によって埋められます。

_2 <= [1,3,4] <= [5,6]
_

何もなくなるまで読み続けます。

_1 <= [3,4,5] <= [6]
5 <= [3,4,6] <= []
3 <= [4,6]   <= []
6 <= [4]      <= []
4 <= []      <= []
_

ds.repeat()の仕組み

データセットからすべてのエントリが読み取られ、次の要素を読み取ろうとすると、データセットはエラーをスローします。そこでds.repeat()が登場します。データセットを再初期化し、再び次のようにします。

_[1,2,3] <= [4,5,6]
_

ds.batch()が生成するもの

ds.batch()は最初に_batch_size_エントリを取得し、それらからバッチを作成します。したがって、サンプルデータセットのバッチサイズ3は、2つのバッチレコードを生成します。

_[2,1,5]
[3,6,4]
_

バッチの前にds.repeat()があるため、データの生成は続行されます。ただし、ds.random()により、要素の順序は異なります。考慮すべきことは、ランダムバッファのサイズのために、_6_が最初のバッチに存在しないことです。

15
Vlad-HC

Tf.Datasetの次のメソッド:

  1. repeat( count=0 )このメソッドは、データセットをcount回繰り返します。
  2. shuffle( buffer_size, seed=None, reshuffle_each_iteration=None)メソッドは、データセット内のサンプルをシャッフルします。 _buffer_size_は、ランダム化されて_tf.Dataset_として返されるサンプルの数です。
  3. batch(batch_size,drop_remainder=False) _batch_size_として指定されたバッチサイズでデータセットのバッチを作成します。これはバッチの長さでもあります。
0
user9477964