イテレータを使用してネットワークをトレーニングするモデルがあります。現在Googleが推奨している新しいデータセットAPIパイプラインモデルに従います。
Tfrecordファイルを読み取り、データをネットワークにフィードし、適切にトレーニングしました。すべてが順調に進んでいます。トレーニングの最後にモデルを保存して、後で推論を実行できるようにします。コードの簡略版は次のとおりです。
_""" Training and saving """
training_dataset = tf.contrib.data.TFRecordDataset(training_record)
training_dataset = training_dataset.map(ds._path_records_parser)
training_dataset = training_dataset.batch(BATCH_SIZE)
with tf.name_scope("iterators"):
training_iterator = Iterator.from_structure(training_dataset.output_types, training_dataset.output_shapes)
next_training_element = training_iterator.get_next()
training_init_op = training_iterator.make_initializer(training_dataset)
def train(num_epochs):
# compute for the number of epochs
for e in range(1, num_epochs+1):
session.run(training_init_op) #initializing iterator here
while True:
try:
images, labels = session.run(next_training_element)
session.run(optimizer, feed_dict={x: images, y_true: labels})
except tf.errors.OutOfRangeError:
saver_name = './saved_models/ucf-model'
print("Finished Training Epoch {}".format(e))
break
""" Restoring """
# restoring the saved model and its variables
session = tf.Session()
saver = tf.train.import_meta_graph(r'saved_models\ucf-model.meta')
saver.restore(session, tf.train.latest_checkpoint('.\saved_models'))
graph = tf.get_default_graph()
# restoring relevant tensors/ops
accuracy = graph.get_tensor_by_name("accuracy/Mean:0") #the tensor that when evaluated returns the mean accuracy of the batch
testing_iterator = graph.get_operation_by_name("iterators/Iterator") #my iterator used in testing.
next_testing_element = graph.get_operation_by_name("iterators/IteratorGetNext") #the GetNext operator for my iterator
# loading my testing set tfrecords
testing_dataset = tf.contrib.data.TFRecordDataset(testing_record_path)
testing_dataset = testing_dataset.map(ds._path_records_parser, num_threads=4, output_buffer_size=BATCH_SIZE*20)
testing_dataset = testing_dataset.batch(BATCH_SIZE)
testing_init_op = testing_iterator.make_initializer(testing_dataset) #to initialize the dataset
with tf.Session() as session:
session.run(testing_init_op)
while True:
try:
images, labels = session.run(next_testing_element)
accuracy = session.run(accuracy, feed_dict={x: test_images, y_true: test_labels}) #error here, x, y_true not defined
except tf.errors.OutOfRangeError:
break
_
私の問題は主にモデルを復元するときです。テストデータをネットワークにフィードする方法は?
testing_iterator = graph.get_operation_by_name("iterators/Iterator")
、next_testing_element = graph.get_operation_by_name("iterators/IteratorGetNext")
を使用してイテレータを復元すると、次のエラーが発生します:GetNext() failed because the iterator has not been initialized. Ensure that you have run the initializer operation for this iterator before getting the next element.
testing_init_op = testing_iterator.make_initializer(testing_dataset))
を使用してデータセットを初期化しようとしました。このエラーが発生しました:_AttributeError: 'Operation' object has no attribute 'make_initializer'
_もう1つの問題は、イテレーターが使用されているため、イテレーターがデータをグラフに直接フィードするため、training_modelでプレースホルダーを使用する必要がないことです。しかし、このように、データを「精度」操作にフィードするときに、最後から3行目のfeed_dictキーを復元するにはどうすればよいですか?
編集:誰かがイテレーターとネットワーク入力の間にプレースホルダーを追加する方法を提案できる場合は、プレースホルダーにデータを供給し、イテレーターを完全に無視しながら、「精度」テンソルを評価することでグラフを実行してみることができます。
保存されたメタグラフを復元する場合、名前を使用して初期化操作を復元し、それを再度使用して、推論のために入力パイプラインを初期化できます。
つまり、グラフを作成するときに、次のことができます。
dataset_init_op = iterator.make_initializer(dataset, name='dataset_init')
次に、次の手順を実行してこの操作を復元します。
dataset_init_op = graph.get_operation_by_name('dataset_init')
これは、復元の前後でランダムに初期化されたモデルの結果を比較する自己完結型のコードスニペットです。
np.random.seed(42)
data = np.random.random([4, 4])
X = tf.placeholder(dtype=tf.float32, shape=[4, 4], name='X')
dataset = tf.data.Dataset.from_tensor_slices(X)
iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
dataset_next_op = iterator.get_next()
# name the operation
dataset_init_op = iterator.make_initializer(dataset, name='dataset_init')
w = np.random.random([1, 4])
W = tf.Variable(w, name='W', dtype=tf.float32)
output = tf.multiply(W, dataset_next_op, name='output')
sess = tf.Session()
saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())
sess.run(dataset_init_op, feed_dict={X:data})
while True:
try:
print(sess.run(output))
except tf.errors.OutOfRangeError:
saver.save(sess, 'tmp/', global_step=1002)
break
そして、次のように推論のために同じモデルを復元できます。
np.random.seed(42)
data = np.random.random([4, 4])
tf.reset_default_graph()
sess = tf.Session()
saver = tf.train.import_meta_graph('tmp/-1002.meta')
ckpt = tf.train.get_checkpoint_state(os.path.dirname('tmp/checkpoint'))
saver.restore(sess, ckpt.model_checkpoint_path)
graph = tf.get_default_graph()
# Restore the init operation
dataset_init_op = graph.get_operation_by_name('dataset_init')
X = graph.get_tensor_by_name('X:0')
output = graph.get_tensor_by_name('output:0')
sess.run(dataset_init_op, feed_dict={X:data})
while True:
try:
print(sess.run(output))
except tf.errors.OutOfRangeError:
break
この目的のために正確に設計された tf.contrib.data.make_saveable_from_iterator
を使用することをお勧めします。冗長性がはるかに低く、既存のコード、特にイテレータの定義方法を変更する必要はありません。
手順5の完了後にすべてを保存する場合の実例。どのシードが使用されているのかわからないことに注意してください。
import tensorflow as tf
iterator = (
tf.data.Dataset.range(100)
.shuffle(10)
.make_one_shot_iterator())
batch = iterator.get_next(name='batch')
saveable_obj = tf.contrib.data.make_saveable_from_iterator(iterator)
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable_obj)
saver = tf.train.Saver()
with tf.Session() as sess:
tf.global_variables_initializer().run()
for step in range(10):
print('{}: {}'.format(step, sess.run(batch)))
if step == 5:
saver.save(sess, './foo', global_step=step)
# 0: 1
# 1: 6
# 2: 7
# 3: 3
# 4: 8
# 5: 10
# 6: 12
# 7: 14
# 8: 5
# 9: 17
その後、手順6から再開すると、同じ出力が得られます。
import tensorflow as tf
saver = tf.train.import_meta_graph('./foo-5.meta')
with tf.Session() as sess:
saver.restore(sess, './foo-5')
for step in range(6, 10):
print('{}: {}'.format(step, sess.run('batch:0')))
# 6: 12
# 7: 14
# 8: 5
# 9: 17
イテレータの初期化に関連する問題を解決できませんでしたが、 map メソッドを使用してデータセットを前処理し、Pythonラップされた操作で定義された変換を適用するため) py_func を使用すると、保存/復元のためにシリアル化できません。とにかく復元する場合は、データセットを初期化する必要があります。
したがって、残っている問題は、グラフを復元するときにデータをグラフにフィードする方法です。イテレータ出力とネットワーク入力の間にtf.identityノードを配置しました。復元時に、データをIDノードにフィードします。後で発見したより良い解決策は、 この回答 で説明されているように、placeholder_with_default()
を使用することです。
CheckpointInputPipelineHook CheckpointInputPipelineHook を確認することをお勧めします。これは、tf.Estimatorを使用してさらにトレーニングするためにイテレーター状態を保存することを実装しています。