私は現在TensorFlowを学んでいますが、このコード内で混乱に遭遇しました:
dataset = dataset.shuffle(buffer_size = 10 * batch_size)
dataset = dataset.repeat(num_epochs).batch(batch_size)
return dataset.make_one_shot_iterator().get_next()
最初にデータセットがすべてのデータを保持することを知っていますが、shuffle()、repeat()、batch()はデータセットに何をしますか?例とともに説明をお願いします
あなたがデータセットを持っていると想像してください:_[1, 2, 3, 4, 5, 6]
_、そして:
ds.shuffle()の仕組み
dataset.shuffle(buffer_size=3)
は、ランダムエントリを選択するためにサイズ3のバッファを割り当てます。このバッファはソースデータセットに接続されます。次のようにイメージできます。
_Random buffer
|
| Source dataset where all other elements live
| |
↓ ↓
[1,2,3] <= [4,5,6]
_
エントリ_2
_がランダムバッファから取得されたと仮定しましょう。空き領域は、ソースバッファの次の要素、つまり_4
_によって埋められます。
_2 <= [1,3,4] <= [5,6]
_
何もなくなるまで読み続けます。
_1 <= [3,4,5] <= [6]
5 <= [3,4,6] <= []
3 <= [4,6] <= []
6 <= [4] <= []
4 <= [] <= []
_
ds.repeat()の仕組み
データセットからすべてのエントリが読み取られ、次の要素を読み取ろうとすると、データセットはエラーをスローします。そこでds.repeat()
が登場します。データセットを再初期化し、再び次のようにします。
_[1,2,3] <= [4,5,6]
_
ds.batch()が生成するもの
ds.batch()
は最初に_batch_size
_エントリを取得し、それらからバッチを作成します。したがって、サンプルデータセットのバッチサイズ3は、2つのバッチレコードを生成します。
_[2,1,5]
[3,6,4]
_
バッチの前にds.repeat()
があるため、データの生成は続行されます。ただし、ds.random()
により、要素の順序は異なります。考慮すべきことは、ランダムバッファのサイズのために、_6
_が最初のバッチに存在しないことです。
Tf.Datasetの次のメソッド:
repeat( count=0 )
このメソッドは、データセットをcount
回繰り返します。shuffle( buffer_size, seed=None, reshuffle_each_iteration=None)
メソッドは、データセット内のサンプルをシャッフルします。 _buffer_size
_は、ランダム化されて_tf.Dataset
_として返されるサンプルの数です。batch(batch_size,drop_remainder=False)
_batch_size
_として指定されたバッチサイズでデータセットのバッチを作成します。これはバッチの長さでもあります。