Kerasでは、model.fit
の出力を次のように履歴に戻すことができます。
history = model.fit(X_train, y_train,
batch_size=batch_size,
nb_Epoch=nb_Epoch,
validation_data=(X_test, y_test))
今、さらに使用するために履歴をファイルに保存する方法(たとえば、エポックに対してaccまたはlossのプロットを描く)?
私が使用しているのは次のとおりです。
with open('/trainHistoryDict', 'wb') as file_pi:
pickle.dump(history.history, file_pi)
この方法で、後で損失または精度をプロットしたい場合に備えて、履歴を辞書として保存します。
model
履歴は、次のようにファイルに保存できます
import json
hist = model.fit(X_train, y_train, epochs=5, batch_size=batch_size,validation_split=0.1)
with open('file.json', 'w') as f:
json.dump(hist.history, f)
history
オブジェクトにはhistory
フィールドがあり、すべてのトレーニングエポックにまたがるさまざまなトレーニングメトリックを保持する辞書です。例えばhistory.history['loss'][99]
は、トレーニングの100エポックでモデルの損失を返します。これを保存するには、この辞書をpickle
するか、この辞書の異なるリストを適切なファイルに保存します。
history.history
はdict
であるため、それをpandas
DataFrame
オブジェクトに変換することもできます。このオブジェクトは、必要に応じて保存できます。
ステップバイステップ:
import pandas as pd
# assuming you stored your model.fit results in a 'history' variable:
history = model.fit(x_train, y_train, epochs=10)
# convert the history.history dict to a pandas DataFrame:
hist_df = pd.DataFrame(history.history)
# save to json:
hist_json_file = 'history.json'
with open(hist_json_file, mode='w') as f:
hist_df.to_json(f)
# or save to csv:
hist_csv_file = 'history.csv'
with open(hist_csv_file, mode='w') as f:
hist_df.to_csv(f)
私は、kerasのリスト内の値がjson直列化可能でないという問題に遭遇しました。したがって、この2つの便利な関数は、使用目的のために作成しました。
import json,codecs
import numpy as np
def saveHist(path,history):
new_hist = {}
for key in list(history.history.keys()):
if type(history.history[key]) == np.ndarray:
new_hist[key] == history.history[key].tolist()
Elif type(history.history[key]) == list:
if type(history.history[key][0]) == np.float64:
new_hist[key] = list(map(float, history.history[key]))
print(new_hist)
with codecs.open(path, 'w', encoding='utf-8') as f:
json.dump(new_hist, f, separators=(',', ':'), sort_keys=True, indent=4)
def loadHist(path):
with codecs.open(path, 'r', encoding='utf-8') as f:
n = json.loads(f.read())
return n
saveHistは、jsonファイルを保存する場所へのパスと、keras fit
またはfit_generator
メソッドから返される履歴オブジェクトを取得する必要があります。
これを行うには多くの方法があると確信していますが、いじって、自分のバージョンを思いつきました。
まず、カスタムコールバックにより、すべてのエポックの終わりに履歴を取得して更新できます。そこには、モデルを保存するためのコールバックもあります。これらはどちらも便利です。クラッシュしたりシャットダウンしたりすると、最後に完了したエポックでトレーニングを受けることができるからです。
class LossHistory(Callback):
# https://stackoverflow.com/a/53653154/852795
def on_Epoch_end(self, Epoch, logs = None):
new_history = {}
for k, v in logs.items(): # compile new history from logs
new_history[k] = [v] # convert values into lists
current_history = loadHist(history_filename) # load history from current training
current_history = appendHist(current_history, new_history) # append the logs
saveHist(history_filename, current_history) # save history from current training
model_checkpoint = ModelCheckpoint(model_filename, verbose = 0, period = 1)
history_checkpoint = LossHistory()
callbacks_list = [model_checkpoint, history_checkpoint]
次に、彼らが言うことを正確に行うための「ヘルパー」関数があります。これらはすべてLossHistory()
コールバックから呼び出されます。
# https://stackoverflow.com/a/54092401/852795
import json, codecs
def saveHist(path, history):
with codecs.open(path, 'w', encoding='utf-8') as f:
json.dump(history, f, separators=(',', ':'), sort_keys=True, indent=4)
def loadHist(path):
n = {} # set history to empty
if os.path.exists(path): # reload history if it exists
with codecs.open(path, 'r', encoding='utf-8') as f:
n = json.loads(f.read())
return n
def appendHist(h1, h2):
if h1 == {}:
return h2
else:
dest = {}
for key, value in h1.items():
dest[key] = value + h2[key]
return dest
その後、必要なのは、history_filename
をdata/model-history.json
のようなものに設定し、model_filesname
をdata/model.h5
のようなものに設定することだけです。トレーニングの終了時に履歴を台無しにしないようにするための最後の微調整の1つは、停止して開始し、コールバックを維持することを前提としていることです。
new_history = model.fit(X_train, y_train,
batch_size = batch_size,
nb_Epoch = nb_Epoch,
validation_data=(X_test, y_test),
callbacks=callbacks_list)
history = appendHist(history, new_history.history)
必要なときはいつでも、history = loadHist(history_filename)
は履歴を戻します。
ファンキーさはjsonとリストに由来しますが、繰り返して変換せずに機能させることはできませんでした。とにかく、私はこれが何日もクランクをつけてきたので、これが機能することを知っています。 https://stackoverflow.com/a/44674337/852795 のpickled.dump
の回答の方が良いかもしれませんが、それが何であるかはわかりません。ここで何かを見逃した場合、または機能させることができない場合は、お知らせください。