現在、トレーニング済みのTensorFlowモデルをProtoBufファイルとしてエクスポートして、AndroidのTensorFlow C++ APIで使用しようとしています。したがって、私は _freeze_graph.py
_ スクリプトを使用しています。
_tf.train.write_graph
_を使用してモデルをエクスポートしました:
tf.train.write_graph(graph_def, FLAGS.save_path, out_name, as_text=True)
_tf.train.Saver
_で保存されたチェックポイントを使用しています。
スクリプトの冒頭で説明したように、_freeze_graph.py
_を呼び出します。コンパイルした後、実行します
_bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=<path_to_protobuf_file> \
--input_checkpoint=<model_name>.ckpt-10000 \
--output_graph=<output_protobuf_file_path> \
--output_node_names=dropout/mul_1
_
これにより、次のエラーメッセージが表示されます。
_TypeError: Cannot interpret feed_dict key as Tensor: The name 'save/Const:0' refers to a Tensor which does not exist. The operation, 'save/Const', does not exist in the graph.
_
エラーが示すように、エクスポートしたモデルにテンソル_save/Const:0
_がありません。ただし、_freeze_graph.py
_のコードは、フラグ_filename_tensor_name
_によってこのテンソル名を指定できることを示しています。残念ながら、このテンソルがどうあるべきか、モデルに正しく設定する方法についての情報は見つかりません。
エクスポートされたProtoBufモデルで_save/Const:0
_テンソルを生成する方法、またはフラグ_filename_tensor_name
_を正しく設定する方法を誰かに教えてもらえますか?
_--filename_tensor_name
_フラグは、モデルの _tf.train.Saver
_ を構築するときに作成されるプレースホルダーテンソルの名前を指定するために使用されます。*
元のプログラムでは、_saver.saver_def.filename_tensor_name
_の値を出力して、このフラグに渡す必要がある値を取得できます。 _saver.saver_def.restore_op_name
_の値を出力して、_--restore_op_name
_フラグの値を取得することもできます(デフォルトではグラフが正しくないと思われるため)。
または、 _tf.train.SaverDef
_ protocol buffer には、これらのフラグの関連情報を再構築するために必要なすべての情報が含まれています。必要に応じて、_saver.saver_def
_をファイルに書き込み、そのファイルの名前を_--input_saver
_フラグとして_freeze_graph.py
_に渡すことができます。
* _tf.train.Saver
_のデフォルトの名前スコープは_"save/"
_であり、プレースホルダーは 実際にはtf.constant()
であり、その名前はデフォルトで_"Const:0"
_にデフォルト設定されます。フラグがデフォルトで_"save/Const:0"
_になる理由。
次のようにコードを配置するとエラーが発生することに気づきました。
sess = tf.Session()
tf.train.write_graph(sess.graph_def, '', '/tmp/train.pbtxt')
init = tf.initialize_all_variables()
saver = tf.train.Saver()
sess.run(init)
私がこのようにコードレイアウトを変更した後、それはうまくいきました:
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
init = tf.initialize_all_variables()
sess = tf.Session()
tf.train.write_graph(sess.graph_def, '', '/tmp/train.pbtxt')
sess.run(init)
なぜかはよくわかりません。 @mrryもう少し説明してもらえますか?
@ Drag0の回答のフォローアップと、新しいコードレイアウトがエラーを修正した理由。
saver = tf.train.Saver()
を呼び出すとき、_'save/Const:0'
_などのtf.train.Saver()
に関連するさまざまな変数をデフォルトのグラフに追加します。
最初のコード配置では、グラフはtf.train.Saver()
変数なしで以前に保存されています。 2番目のコード配置では、後で保存されるので、操作_save/Const
_がグラフに存在します。
私はこれらが削除されたのを見ることができたので、それは最新のfreeze_graph.pyで問題になるべきではありません:
_del restore_op_name, filename_tensor_name # Unused by updated loading code.
_ source:freeze_graph.py
以前のバージョンでは、restore_opを使用してモデルを復元していました
sess.run([restore_op_name], {filename_tensor_name: input_checkpoint})
したがって、以前のバージョンでは、セーバーオペレーションをインスタンス化する前に.pbファイルにグラフを書き込んでいると、問題が発生します。例えば。:
_tf.train.write_graph(sess.graph_def, "./logs", "test2.pb", False)
saver = tf.train.Saver()
saver.save(sess, "./logs/hello_ck.ckpt", meta_graph_suffix='meta', write_meta_graph=True)
_
これは、グラフにモデルの復元のための保存/復元操作がないためです。それを解決するには、.ckptファイルを保存してからグラフを作成します
_saver = tf.train.Saver()
saver.save(sess, "./logs/hello_ck.ckpt", meta_graph_suffix='meta', write_meta_graph=True)
tf.train.write_graph(sess.graph_def, "./logs", "test2.pb", False)
_
@mrry、何か間違ったことを解釈した場合はご案内ください。最近、tensorflowコードに飛び込み始めました。