web-dev-qa-db-ja.com

Pythonで.pbファイルからTensorflowモデルを復元する方法は?

私はpython DNNにロードしたいtensorflow .pbファイルを持っています。グラフを復元して予測を取得します。作成した.pbファイルが通常のSaver.save()モデルと同様の予測。

私の基本的な問題は、上記の.pbファイルを使用してAndroid=

私の.pbファイル作成コード:

frozen_graph = tf.graph_util.convert_variables_to_constants(
        session,
        session.graph_def,
        ['outputLayer/Softmax']
    )
with open('frozen_model.pb', 'wb') as f:
  f.write(frozen_graph.SerializeToString())

だから私は2つの大きな懸念があります:

  1. 上記の.pbファイルをpython Tensorflowモデルにロードするにはどうすればよいですか?
  2. pythonとAndroid?
13
vizsatiz

次のコードはモデルを読み取り、グラフ内のノードの名前を出力します。

import tensorflow as tf
from tensorflow.python.platform import gfile
GRAPH_PB_PATH = './frozen_model.pb'
with tf.Session() as sess:
   print("load graph")
   with gfile.FastGFile(GRAPH_PB_PATH,'rb') as f:
       graph_def = tf.GraphDef()
   graph_def.ParseFromString(f.read())
   sess.graph.as_default()
   tf.import_graph_def(graph_def, name='')
   graph_nodes=[n for n in graph_def.node]
   names = []
   for t in graph_nodes:
      names.append(t.name)
   print(names)

グラフを適切にフリーズしているため、異なる結果が得られるのは、基本的にモデルに重みが保存されないためです。 freeze_graph.pylink )を使用して、適切に保存されたグラフを取得できます。

19
sahu