web-dev-qa-db-ja.com

Tensorflowの.tfrecordsファイルからレコードの総数を取得する

.tfrecordsファイルからレコードの総数を取得することは可能ですか?これに関連して、一般に、モデルのトレーニング中に経過したエポックの数をどのように追跡しますか? batch_sizenum_of_epochsを指定することは可能ですが、current Epoch、エポックごとのバッチ数などの値を取得するのが簡単かどうかはわかりません。トレーニングの進行状況をより詳細に制御できます。現在、.tfrecordsファイルにあるレコードの数とミニバッチのサイズを事前に知っているので、ダーティハックを使用してこれを計算しています。ヘルプを感謝します。

24
HuckleberryFinn

レコード数を数えるには、 tf.python_io.tf_record_iterator

c = 0
for fn in tf_records_filenames:
  for record in tf.python_io.tf_record_iterator(fn):
     c += 1

モデルのトレーニングを追跡するには、 tensorboard が便利です。

30
drpng

いいえ、不可能です。 TFRecord は、内部に保存されているデータに関するメタデータを保存しません。このファイル

(バイナリ)文字列のシーケンスを表します。この形式はランダムアクセスではないため、大量のデータのストリーミングには適していますが、高速シャーディングまたはその他の非シーケンシャルアクセスが必要な場合には適していません。

必要に応じて、このメタデータを手動で保存するか、 record_iterator を使用して番号を取得できます(所有するすべてのレコードを反復処理する必要があります。

sum(1 for _ in tf.python_io.tf_record_iterator(file_name))

現在のエポックを知りたい場合は、テンソルボードから、またはループから数値を出力することにより、これを行うことができます。

17
Salvador Dali

tf_record_iterator の非推奨警告に従って、積極的な実行を使用してレコードをカウントすることもできます。

#!/usr/bin/env python
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf
import sys

assert len(sys.argv) == 2, \
    "USAGE: {} <file_glob>".format(sys.argv[0])

tf.enable_eager_execution()

input_pattern = sys.argv[1]

# This is where we get the count of records
records_n = sum(1 for record in tf.data.TFRecordDataset(tf.gfile.Glob(input_pattern)))

print("records_n = {}".format(records_n))
2
Russell