私はPyTorchで訓練されたモデルを保存するための代替方法を探していました。これまでのところ、2つの選択肢があります。
私はこの 議論 に遭遇しました。そこでアプローチ2はアプローチ1よりも推奨されます。
私の質問は、なぜ2番目のアプローチが好ましいのかということです。それは、 torch.nn モジュールがそれらの2つの機能を持っているからであり、それらを使用することが推奨されますか?
このページ をgithubレポジトリで見つけました。ここにコンテンツを貼り付けるだけです。
モデルのシリアル化と復元には、主に2つの方法があります。
最初の(推奨)はモデルパラメータのみを保存してロードします。
torch.save(the_model.state_dict(), PATH)
じゃあ後で:
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
2番目はモデル全体を保存してロードします。
torch.save(the_model, PATH)
じゃあ後で:
the_model = torch.load(PATH)
ただし、この場合、シリアル化されたデータは特定のクラスと使用される正確なディレクトリ構造にバインドされているため、他のプロジェクトで使用したり、重大なリファクタリングをしたりすると、さまざまな方法で破損します。
それはあなたがやりたいことによります。
ケース#1:モデルを保存して推論に使用する:モデルを保存して復元し、モデルを評価モードに変更します。 。これは、デフォルトでBatchNorm
とDropout
レイヤーがデフォルトで構築モードであるために行われます。
torch.save(model.state_dict(), filepath)
#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()
ケース#2:後でトレーニングを再開するためにモデルを保存する:保存しようとしているモデルのトレーニングを継続する必要がある場合は、より多く保存する必要があります。モデルだけです。また、オプティマイザの状態、エポック、スコアなどを保存する必要があります。これは、次のようにして行います。
state = {
'Epoch': Epoch,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
...
}
torch.save(state, filepath)
トレーニングを再開するには、次のようにします。state = torch.load(filepath)
、そして、各オブジェクトの状態を復元するには、次のようにします。
model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])
トレーニングを再開しているので、ロード時に状態を復元したらNOTmodel.eval()
を呼び出さないでください。
ケース#3:自分のコードにアクセスできない他の人が使うモデル:Tensorflowでは、アーキテクチャとファイルの両方を定義する.pb
ファイルを作成できます。モデルの重みこれは特にTensorflow serve
を使うときにとても便利です。 Pytorchでこれを行うのと同じ方法は次のようになります。
torch.save(model, filepath)
# Then later:
model = torch.load(filepath)
この方法はまだ完全な証拠ではありませんし、pytorchはまだ多くの変更を受けているので、私はそれをお勧めしません。
pickle Pythonライブラリは、Pythonオブジェクトをシリアライズおよびデシリアライズするためのバイナリプロトコルを実装しています。
import torch
(またはPyTorchを使用するとき)はあなたに代わってimport pickle
を生成します。pickle.dump()
とpickle.load()
を直接呼び出す必要はありません。これらはオブジェクトを保存してロードするためのメソッドです。
実際、torch.save()
とtorch.load()
はpickle.dump()
とpickle.load()
をラップします。
もう1つの回答であるstate_dict
には、さらにいくつかのメモが必要です。
PyTorchの内部にはどんなstate_dict
がありますか?実際には2つのstate_dict
があります。
PyTorchモデルはtorch.nn.Module
が学習可能なパラメータを取得するためのmodel.parameters()
呼び出しを持っています(wとb)。これらの学習可能なパラメータは、いったんランダムに設定されると、学習するにつれて徐々に更新されます。学習可能なパラメータは最初のstate_dict
です。
2番目のstate_dict
はオプティマイザの状態辞書です。オプティマイザもモデルの一部です。あなたは、オプティマイザが私たちの学習可能なパラメータを改善するために使われていることを思い出してください。しかし、オプティマイザstate_dict
は固定されています。そこで学ぶことは何もありません。
state_dict
オブジェクトはPythonの辞書なので、保存、更新、変更、復元が簡単にでき、PyTorchモデルとオプティマイザに非常に多くのモジュール性を追加します。
これを説明するために、超簡単なモデルを作成しましょう。
import torch
import torch.optim as optim
model = torch.nn.Linear(5, 2)
# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
print("Model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor].size())
print("Model weight:")
print(model.weight)
print("Model bias:")
print(model.bias)
print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
print(var_name, "\t", optimizer.state_dict()[var_name])
このコードは以下を出力します。
Model's state_dict:
weight torch.Size([2, 5])
bias torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328, 0.1360, 0.1553, -0.1838, -0.0316],
[ 0.0479, 0.1760, 0.1712, 0.2244, 0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]
これは最小限のモデルです。あなたはシーケンシャルのスタックを追加しようとするかもしれません
model = torch.nn.Sequential(
torch.nn.Linear(D_in, H),
torch.nn.Conv2d(A, B, C)
torch.nn.Linear(H, D_out),
)
学習可能なパラメータを持つレイヤ(たたみ込みレイヤ、線形レイヤなど)と登録バッファ(batchnormレイヤ)だけがモデルのstate_dict
にエントリを持つことに注意してください。
学習できないことは、オプティマイザオブジェクトstate_dict
に属しています。これには、オプティマイザの状態と使用されているハイパーパラメータに関する情報が含まれています。
ストーリーの残りの部分は同じです。予測のための推論段階(これは訓練後にモデルを使用する段階である)。学習したパラメータに基づいて予測を行います。そのため、推論のために、パラメータmodel.state_dict()
を保存する必要があります。
torch.save(model.state_dict(), filepath)
そして後で使うためにmodel.load_state_dict(torch.load(filepath))model.eval()
注:最後の行model.eval()
を忘れないでください。これはモデルを読み込んだ後に重要になります。
またtorch.save(model.parameters(), filepath)
を保存しようとしないでください。 model.parameters()
は単なるジェネレータオブジェクトです。
反対に、torch.save(model, filepath)
はモデルオブジェクト自体を保存しますが、モデルはオプティマイザのstate_dict
を持っていないことを覚えておいてください。オプティマイザの状態辞書を保存するために@Jadiel de Armasによる他の優れた答えをチェックしてください。
一般的なPyTorchの慣例は、.ptか.pthファイル拡張子を使ってモデルを保存することです。
モデル全体を保存/読み込み保存:
path = "username/directory/lstmmodelgpu.pth"
torch.save(trainer, path)
負荷:
model = torch.load(PATH)
model.eval()