web-dev-qa-db-ja.com

Pytorchモデルの評価: `with torch.no_grad`対` model.eval() `

検証セットでのモデルのパフォーマンスを評価したい場合、以下を使用することをお勧めします。

  • _with torch.no_grad:_

または

  • model.eval()
22
Tom Hale

TL; DR:

both を使用します。彼らは異なることをし、異なるスコープを持っています。

  • _with torch.no_grad_-autogradのグラデーションの追跡を無効にします。
  • model.eval()は、呼び出されたモジュールのforward()動作を変更します
    • たとえば、ドロップアウトを無効にし、バッチ標準で全人口統計を使用します

_with torch.no_grad_

_torch.autograd.no_grad_のドキュメント はこう言っています:

[シック]勾配計算を無効にするコンテキストマネージャー。

勾配計算を無効にすることは、Tensor.backward()を呼び出さないことが確実な場合に、推論に役立ちます。それ以外の場合は_requires_grad=True_が必要な計算のメモリ消費量が削減されます。このモードでは、入力に_requires_grad=False_が含まれている場合でも、すべての計算の結果に_requires_grad=True_が含まれます。

model.eval()

_nn.Module.eval_のドキュメント はこう言っています:

モジュールを評価モードに設定します。

これは特定のモジュールでのみ効果があります。影響を受ける場合、トレーニング/評価モードでの動作の詳細については、特定のモジュールのドキュメントを参照してください。 DropoutBatchNormなど.


pytorchの作成者は、ドキュメントを更新して両方の使用法を提案する必要があると述べました

24
Tom Hale