web-dev-qa-db-ja.com

特定の値でtf.data.Datasetをフィルタリングするにはどうすればよいですか?

TFRecordsを読み取ってデータセットを作成し、値をマッピングして特定の値のデータセットをフィルター処理したいのですが、結果はテンソルを含むディクテーションであるため、テンソルの実際の値を取得したり、チェックしたりできません。 tf.cond()/tf.equal。どうやってやるの?

def mapping_func(serialized_example):
    feature = { 'label': tf.FixedLenFeature([1], tf.string) }
    features = tf.parse_single_example(serialized_example, features=feature)
    return features

def filter_func(features):
    # this doesn't work
    #result = features['label'] == 'some_label_value'
    # neither this
    result = tf.reshape(tf.equal(features['label'], 'some_label_value'), [])
    return result

def main():
    file_names = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
    dataset = tf.contrib.data.TFRecordDataset(file_names)
    dataset = dataset.map(mapping_func)
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.filter(filter_func)
    dataset = dataset.repeat()
    iterator = dataset.make_one_shot_iterator()
    sample = iterator.get_next()
11
tsveti_iko

私は自分の質問に答えています。問題が見つかりました!

私がする必要があるのは、次のようなラベルtf.unstack()です。

_label = tf.unstack(features['label'])
label = label[0]
_

tf.equal()に渡す前に:

_result = tf.reshape(tf.equal(label, 'some_label_value'), [])
_

問題は、ラベルが文字列tf.FixedLenFeature([1], tf.string)の1つの要素を持つ配列として定義されていたため、最初の単一の要素を取得するために、(リストを作成する)それをアンパックする必要がありました。インデックス0の要素を取得し、間違っている場合は修正してください。

4
tsveti_iko

そもそも、ラベルを1次元配列にする必要はないと思います。

と:

feature = {'label': tf.FixedLenFeature((), tf.string)}

filter_funcでラベルをアンスタックする必要はありません

0