web-dev-qa-db-ja.com

PyTorchでzero_grad()を呼び出す必要があるのはなぜですか?

メソッドzero_grad()は、トレーニング中に呼び出す必要があります。しかし、 documentation はあまり役に立ちません

|  zero_grad(self)
|      Sets gradients of all model parameters to zero.

なぜこのメソッドを呼び出す必要があるのですか?

27
user1424739

PyTorch では、PyTorch グラデーションを累積する以降のバックワードパスでバックプロパゲーションを開始する前に、グラデーションをゼロに設定する必要があります。これは、RNNのトレーニング中に便利です。したがって、デフォルトのアクションは、すべてのloss.backward()呼び出しで 勾配を累積(つまり合計) することです。

このため、トレーニングループを開始するときは、理想的には zero out the gradients にして、パラメーターの更新を正しく行う必要があります。そうでない場合、勾配は、最小(または最大化目標の場合は最大)に向かう意図した方向以外の方向を指します。

以下に簡単な例を示します。

import torch
from torch.autograd import Variable
import torch.optim as optim

def linear_model(x, W, b):
    return torch.matmul(x, W) + b

data, targets = ...

W = Variable(torch.randn(4, 3), requires_grad=True)
b = Variable(torch.randn(3), requires_grad=True)

optimizer = optim.Adam([W, b])

for sample, target in Zip(data, targets):
    # clear out the gradients of all Variables 
    # in this optimizer (i.e. W, b)
    optimizer.zero_grad()
    output = linear_model(sample, W, b)
    loss = (output - target) ** 2
    loss.backward()
    optimizer.step()

あるいは、Vanilla gradient descentを実行している場合、次のようになります。

W = Variable(torch.randn(4, 3), requires_grad=True)
b = Variable(torch.randn(3), requires_grad=True)

for sample, target in Zip(data, targets):
    # clear out the gradients of Variables 
    # (i.e. W, b)
    W.grad.data.zero_()
    b.grad.data.zero_()

    output = linear_model(sample, W, b)
    loss = (output - target) ** 2
    loss.backward()

    W -= learning_rate * W.grad.data
    b -= learning_rate * b.grad.data

:勾配の累積(つまりsum)は .backward()は、lossテンソルで呼び出されます

43
kmario23