最近、TensorflowのデータセットAPIを調べていますが、分散計算用のメソッドdataset.shard()
があります。
これは、Tensorflowのドキュメントに記載されている内容です。
Creates a Dataset that includes only 1/num_shards of this dataset.
d = tf.data.TFRecordDataset(FLAGS.input_file)
d = d.shard(FLAGS.num_workers, FLAGS.worker_index)
d = d.repeat(FLAGS.num_epochs)
d = d.shuffle(FLAGS.shuffle_buffer_size)
d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)
このメソッドは、元のデータセットの一部を返すと言われています。 2人の労働者がいる場合、次のことを行う必要があります。
d_0 = d.shard(FLAGS.num_workers, worker_0)
d_1 = d.shard(FLAGS.num_workers, worker_1)
......
iterator_0 = d_0.make_initializable_iterator()
iterator_1 = d_1.make_initializable_iterator()
for worker_id in workers:
with tf.device(worker_id):
if worker_id == 0:
data = iterator_0.get_next()
else:
data = iterator_1.get_next()
......
ドキュメントには後続の呼び出しの方法が指定されていないため、ここでは少し混乱しています。
ありがとう!
それがどのように機能するかをよりよく理解するために、最初に Distributed TensorFlowのチュートリアル を見る必要があります。
複数のワーカーがあり、それぞれが同じコードを実行しますが、わずかな違いがあります。各ワーカーのFLAGS.worker_index
は異なります。
tf.data.Dataset.shard
を使用する場合、このワーカーインデックスを指定すると、データはワーカー間で均等に分割されます。
これは3人の労働者の例です。
dataset = tf.data.Dataset.range(6)
dataset = dataset.shard(FLAGS.num_workers, FLAGS.worker_index)
iterator = dataset.make_one_shot_iterator()
res = iterator.get_next()
# Suppose you have 3 workers in total
with tf.Session() as sess:
for i in range(2):
print(sess.run(res))
出力があります:
0, 3
ワーカー01, 4
ワーカー12, 5
ワーカー2