私は https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/examples/learn/wide_n_deep_tutorial.py のチュートリアルコードを使用し、コードを作成するまで問題なく動作します単にそれを評価する代わりに予測。次のような予測用の別の関数を作成しようとしました(パラメーターyを削除するだけです)。
def input_fn_predict(data_file, num_epochs, shuffle):
"""Input builder function."""
df_data = pd.read_csv(
tf.gfile.Open(data_file),
names=CSV_COLUMNS,
skipinitialspace=True,
engine="python",
skiprows=1)
# remove NaN elements
df_data = df_data.dropna(how="any", axis=0)
labels = df_data["income_bracket"].apply(lambda x: ">50K" in x).astype(int)
return tf.estimator.inputs.pandas_input_fn( #removed paramter y
x=df_data,
batch_size=100,
num_epochs=num_epochs,
shuffle=shuffle,
num_threads=5)
そしてそれをこのように呼ぶには:
predictions = m.predict(
input_fn=input_fn_predict(test_file_name, num_epochs=1, shuffle=True)
)
for i, p in enumerate(predictions):
print(i, p)
{'確率':array([0.78595656、0.21404342]、dtype = float32)、 'logits':array([-1.3007226]、dtype = float32)、 'classes':array(['0']、dtype = object) 、 'class_ids':array([0])、 'logistic':array([0.21404341]、dtype = float32)}
どうやって読むの?
新しいラベルを予測するには、データの順序を維持する必要があるため、shuffle=False
を設定する必要があります。
以下は、予測を実行するためのコードです(テストしました)。入力ファイルはテストデータ(csv形式)に似ていますが、ラベル列はありません。
def predict_input_fn(data_file):
global CSV_COLUMNS
CSV_COLUMNS = CSV_COLUMNS[:-1]
df_data = pd.read_csv(
tf.gfile.Open(data_file),
names=CSV_COLUMNS,
skipinitialspace=True,
engine='python',
skiprows=1
)
# remove NaN elements
df_data = df_data.dropna(how='any', axis=0)
return tf.estimator.inputs.pandas_input_fn(
x=df_data,
num_epochs=1,
shuffle=False
)
それを呼び出すには:
predict_file_name = 'tutorials/data/adult.predict'
results = m.predict(
input_fn=predict_input_fn(predict_file_name)
)
for result in results:
print 'result: {}'.format(result)
1つのサンプルの予測結果は次のとおりです。
{
'probabilities': array([0.78595656, 0.21404342], dtype = float32),
'logits': array([-1.3007226], dtype = float32),
'classes': array(['0'], dtype = object),
'class_ids': array([0]),
'logistic': array([0.21404341], dtype = float32)
}
各フィールドの意味は