Tensorflow v1.3のデータセットAPI で遊んでいます。それは素晴らしい。 ここ の説明に従って、データセットを関数でマップすることができます。追加の引数を持つ関数を渡す方法を知りたいです。たとえば、arg1
:
def _parse_function(example_proto, arg1):
features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
"label": tf.FixedLenFeature((), tf.int32, default_value=0)}
parsed_features = tf.parse_single_example(example_proto, features)
return parsed_features["image"], parsed_features["label"]
もちろん、
dataset = dataset.map(_parse_function)
arg1
を渡す方法がないため、機能しません。
以下は、ラムダ式を使用して、引数を渡したい関数をラップする例です。
_import tensorflow as tf
def fun(x, arg):
return x * arg
my_arg = tf.constant(2, dtype=tf.int64)
ds = tf.data.Dataset.range(5)
ds = ds.map(lambda x: fun(x, my_arg))
_
上記では、map
に提供される関数のシグネチャは、データセットの内容と一致する必要があります。したがって、ラムダ式を記述してそれに一致させる必要があります。データセットに含まれている要素は1つだけなので、ここでは簡単です。0〜4の範囲の要素を含むx
です。
必要に応じて、ds = ds.map(lambda x: my_other_fun(x, arg1, arg2, arg3)
など、データセットの外部から任意の数の外部引数を渡すことができます。
上記が機能することを確認するために、マッピングによって実際に各データセット要素が2倍になることがわかります。
_iterator = ds.make_initializable_iterator()
next_x = iterator.get_next()
with tf.Session() as sess:
sess.run(iterator.initializer)
while True:
try:
print(sess.run(next_x))
except tf.errors.OutOfRangeError:
break
_
出力:
_0
2
4
6
8
_
別の解決策は、クラスラッパーを使用することです。次のコードでは、パラメーターshapeを解析関数に渡しました。
class MyDataSets:
def __init__(self, shape):
self.shape = shape
def parse_sample(self.sample):
features = { ... }
f = tf.parse_example([example], features=features)
image_raw = tf.decode_raw(f['image_raw'], tf.uint8)
image = image.reshape(image_raw, self.shape)
label = tf.cast(f['label'], tf.int32)
return image, label
def init(self):
ds = tf.data.TFRecordDataSets(...)
ds = ds.map(self.parse_sample)
...
return ds.make_initializable_iterator()
代わりにPartial
関数を使用してパラメーターをラップすることもできます。
def _parse_function(arg1, example_proto):
features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
"label": tf.FixedLenFeature((), tf.int32, default_value=0)}
parsed_features = tf.parse_single_example(example_proto, features)
return parsed_features["image"], parsed_features["label"]
関数のパラメーターの順序は、偏りに合うように変更されます。その後、次のようなパラメーター値で関数をラップできます。
from functools import partial
arg1 = ...
dataset = dataset.map(partial(_parse_function, arg1))