Tensorflow API は、事前にトレーニングされたモデルをほとんど提供していないため、任意のデータセットでモデルをトレーニングできました。
1つのテンソルフローセッションで複数のグラフを初期化して使用する方法を知りたいです。トレーニング済みの2つのモデルを2つのグラフにインポートし、それらをオブジェクト検出に使用したいのですが、1つのセッションで複数のグラフを実行しようとすると失望します。
1つのセッションで複数のグラフを操作する特定の方法はありますか?.
別の問題は、2つの異なるグラフに対して2つの異なるセッションを作成し、それらを操作しようとしても、最初のインスタンス化されたセッションの2番目のセッションで同様の結果が得られることです。
各Session
には、単一のGraph
のみを含めることができます。そうは言っても、具体的に何をしようとしているかに応じて、いくつかのオプションがあります。
最初のオプションは、2つの個別のセッションを作成し、各セッションに1つのグラフを読み込むことです。これについては、 こちらのドキュメント で説明しています。そのアプローチの各セッションで予想外に同様の結果が得られるとおっしゃいましたが、詳細がなければ、具体的に問題が何であるかを把握するのは困難です。各セッションに同じグラフがロードされたのか、各セッションを個別に実行しようとしたときに同じセッションが2回実行されているのではないかと思われますが、詳細なしではわかりません。
2番目のオプションは、両方のグラフをメインセッショングラフのサブグラフとしてロードすることです。グラフ内に2つのスコープを作成し、そのスコープ内で読み込むグラフごとにグラフを作成できます。次に、それらの間に関連性がないため、それらを独立したグラフとして扱うことができます。通常のグラフグローバル関数を実行する場合、それらの関数が適用されるスコープを指定する必要があります。たとえば、オプティマイザーを使用してサブグラフの1つで更新を実行する場合、 this answer に示されているようなものを使用して、そのサブグラフのスコープのトレーニング可能な変数のみを取得する必要があります。
TensorFlowグラフ内で何らかの形で相互作用できるように2つのグラフを明示的に必要としない限り、サブグラフが必要とする余分なフープをジャンプする必要がないように、最初のアプローチをお勧めします(フィルタリングする必要があるなど)任意の時点での作業範囲、および2つの間で共有されるグラフグローバルの可能性)。
1つのセッションのグラフ引数は、Noneまたはグラフのインスタンスでなければなりません。
ソースコード は次のとおりです。
class BaseSession(SessionInterface):
"""A class for interacting with a TensorFlow computation.
The BaseSession enables incremental graph building with inline
execution of Operations and evaluation of Tensors.
"""
def __init__(self, target='', graph=None, config=None):
"""Constructs a new TensorFlow session.
Args:
target: (Optional) The TensorFlow execution engine to connect to.
graph: (Optional) The graph to be used. If this argument is None,
the default graph will be used.
config: (Optional) ConfigProto proto used to configure the session.
Raises:
tf.errors.OpError: Or one of its subclasses if an error occurs while
creating the TensorFlow session.
TypeError: If one of the arguments has the wrong type.
"""
if graph is None:
self._graph = ops.get_default_graph()
else:
if not isinstance(graph, ops.Graph):
raise TypeError('graph must be a tf.Graph, but got %s' % type(graph))
また、以下のスニペットから、リストにできないことがわかります。
if graph is None:
self._graph = ops.get_default_graph()
else:
if not isinstance(graph, ops.Graph):
raise TypeError('graph must be a tf.Graph, but got %s' % type(graph))
そして ops.Graph (find by help(ops.Graph))オブジェクトから、複数のグラフにできないことがわかります。
more については、シーションとグラフについて:
If no `graph` argument is specified when constructing the session, the default graph will be launched in the session. If you are using more than one graph (created with `tf.Graph()` in the same process, you will have to use different sessions for each graph, but each graph can be used in multiple sessions. In this case, it is often clearer to pass the graph to be launched explicitly to the session constructor.