モデル入力にtf.placeholderを使用し、tf.Session()。runのfeed_dictパラメーターを使用してデータをフィードする既存のTensorFlowモデルがあります。以前は、データセット全体がメモリに読み込まれ、この方法で渡されていました。
もっと大きなデータセットを使用し、tf.data APIのパフォーマンスの向上を活用したいと思います。 tf.data.TextLineDatasetとそれからのワンショットイテレータを定義しましたが、データをトレーニングしてモデルに取り込む方法を見つけるのに苦労しています。
最初は、feed_dictをプレースホルダーからiterator.get_next()への辞書として定義しようとしましたが、フィードの値をtf.Tensorオブジェクトにすることはできないというエラーが表示されました。さらに掘り下げると、これは、iterator.get_next()によって返されたオブジェクトが、feed_dictにフィードするものとは異なり、すでにグラフの一部であるためであり、とにかくfeed_dictを使用しようとしてはならないことを理解するようになりましたパフォーマンス上の理由。
そこで、入力tf.placeholderを取り除き、モデルを定義するクラスのコンストラクターへのパラメーターに置き換えました。トレーニングコードでモデルを構築するとき、iterator.get_next()の出力をそのパラメーターに渡します。これは、モデルの定義とデータセット/トレーニング手順の間の分離を壊すため、すでに少し不格好に見えます。そして、モデルの入力を表す(信じている)Tensorは、iterator.get_next()のTensorと同じグラフからのものでなければならないというエラーが表示されます。
私はこのアプローチで正しい軌道に乗っており、グラフとセッションのセットアップ方法、またはそのようなことで何か間違ったことをしていますか? (データセットとモデルは両方ともセッション外で初期化され、作成しようとする前にエラーが発生します。)
または、私はこれで完全にオフになっており、Estimator APIを使用して入力関数ですべてを定義するなど、別のことをする必要がありますか?
最小限の例を示すコードを次に示します。
import tensorflow as tf
import numpy as np
class Network:
def __init__(self, x_in, input_size):
self.input_size = input_size
# self.x_in = tf.placeholder(dtype=tf.float32, shape=(None, self.input_size)) # Original
self.x_in = x_in
self.output_size = 3
tf.reset_default_graph() # This turned out to be the problem
self.layer = tf.layers.dense(self.x_in, self.output_size, activation=tf.nn.relu)
self.loss = tf.reduce_sum(tf.square(self.layer - tf.constant(0, dtype=tf.float32, shape=[self.output_size])))
data_array = np.random.standard_normal([4, 10]).astype(np.float32)
dataset = tf.data.Dataset.from_tensor_slices(data_array).batch(2)
model = Network(x_in=dataset.make_one_shot_iterator().get_next(), input_size=dataset.output_shapes[-1])
与えられた元のコードのモデルのコンストラクターのtf.reset_default_graph()
行が原因でした。削除して修正しました。
頭を動かすのにも少し時間がかかりました。あなたは正しい軌道に乗っています。データセット定義全体はグラフの一部にすぎません。通常、Modelクラスとは異なるクラスとして作成し、データセットをModelクラスに渡します。コマンドラインで読み込むデータセットクラスを指定し、そのクラスを動的に読み込むことで、データセットとグラフをモジュール的に分離します。
データセット内のすべてのテンソルに名前を付けることができます(また、そうする必要があります)。これにより、必要なさまざまな変換を介してデータを渡すときに、物事を理解しやすくなります。
iterator.get_next()
からサンプルを取得して表示する簡単なテストケースを書くことができます。sess.run(next_element_tensor)
のようなものがあり、feed_dict
あなたが正しく述べたように。
頭を悩ませたら、おそらくデータセット入力パイプラインが好きになるでしょう。コードを適切にモジュール化することを強制し、単体テストが容易な構造に強制します。
開発者ガイドを必ず読んでください。そこにはたくさんの例があります:
https://www.tensorflow.org/programmers_guide/datasets
もう1つ注意する点は、このパイプラインを使用してトレインを操作し、データセットをテストするのがどれほど簡単かということです。テストデータセットでは実行しないトレーニングデータセットでデータの増強を実行することが多いため、これは重要です。from_string_handle
はそれを可能にし、上記のガイドで明確に説明されています。