私はPyTorchのドキュメントを読んでいて、彼らが書いている例を見つけました
gradients = torch.FloatTensor([0.1, 1.0, 0.0001])
y.backward(gradients)
print(x.grad)
ここで、xは初期変数で、そこからyが構築されました(3ベクトル)。問題は、勾配テンソルの0.1、1.0、0.0001引数は何ですか?ドキュメントはそれについてあまり明確ではありません。
PyTorchのWebサイトでもう見つけていない元のコード。
gradients = torch.FloatTensor([0.1, 1.0, 0.0001])
y.backward(gradients)
print(x.grad)
上記のコードの問題には、勾配の計算対象に基づいた機能はありません。これは、パラメーターの数(関数が取る引数)とパラメーターの次元がわからないことを意味します。
これを完全に理解するために、オリジナルに近いいくつかの例を作成しました。
例1:
a = torch.tensor([1.0, 2.0, 3.0], requires_grad = True)
b = torch.tensor([3.0, 4.0, 5.0], requires_grad = True)
c = torch.tensor([6.0, 7.0, 8.0], requires_grad = True)
y=3*a + 2*b*b + torch.log(c)
gradients = torch.FloatTensor([0.1, 1.0, 0.0001])
y.backward(gradients,retain_graph=True)
print(a.grad) # tensor([3.0000e-01, 3.0000e+00, 3.0000e-04])
print(b.grad) # tensor([1.2000e+00, 1.6000e+01, 2.0000e-03])
print(c.grad) # tensor([1.6667e-02, 1.4286e-01, 1.2500e-05])
ご覧のとおり、最初の例では関数はy=3*a + 2*b*b + torch.log(c)
であり、パラメーターは内部に3つの要素を持つテンソルであると想定しています。
しかし、別のオプションがあります:
例2:
import torch
a = torch.tensor(1.0, requires_grad = True)
b = torch.tensor(1.0, requires_grad = True)
c = torch.tensor(1.0, requires_grad = True)
y=3*a + 2*b*b + torch.log(c)
gradients = torch.FloatTensor([0.1, 1.0, 0.0001])
y.backward(gradients)
print(a.grad) # tensor(3.3003)
print(b.grad) # tensor(4.4004)
print(c.grad) # tensor(1.1001)
gradients = torch.FloatTensor([0.1, 1.0, 0.0001])
はアキュムレーターです。
次の例では、同じ結果が得られます。
例3:
a = torch.tensor(1.0, requires_grad = True)
b = torch.tensor(1.0, requires_grad = True)
c = torch.tensor(1.0, requires_grad = True)
y=3*a + 2*b*b + torch.log(c)
gradients = torch.FloatTensor([0.1])
y.backward(gradients,retain_graph=True)
gradients = torch.FloatTensor([1.0])
y.backward(gradients,retain_graph=True)
gradients = torch.FloatTensor([0.0001])
y.backward(gradients)
print(a.grad) # tensor(3.3003)
print(b.grad) # tensor(4.4004)
print(c.grad) # tensor(1.1001)
PyTorch autogradシステムの計算はJacobian製品と同等であると聞きます。
私たちがやったように、あなたが機能を持っている場合:
y=3*a + 2*b*b + torch.log(c)
ヤコビアンは[3, 4*b, 1/c]
になります。ただし、この Jacobian は、PyTorchが特定のポイントで勾配を計算する方法を実行しているわけではありません。
前の関数の場合、PyTorchはδy/δb
に対して、b=1
およびb=1+ε
に対してεが小さくなります。したがって、シンボリック数学のようなものは何も含まれていません。
y.backward()
でグラデーションを使用しない場合:
例4
a = torch.tensor(0.1, requires_grad = True)
b = torch.tensor(1.0, requires_grad = True)
c = torch.tensor(0.1, requires_grad = True)
y=3*a + 2*b*b + torch.log(c)
y.backward()
print(a.grad) # tensor(3.)
print(b.grad) # tensor(4.)
print(c.grad) # tensor(10.)
最初にa
、b
、c
テンソルを設定した方法に基づいて、ポイントで結果を簡単に取得できます。
a
、b
、c
の初期化方法に注意してください。
例5:
a = torch.empty(1, requires_grad = True, pin_memory=True)
b = torch.empty(1, requires_grad = True, pin_memory=True)
c = torch.empty(1, requires_grad = True, pin_memory=True)
y=3*a + 2*b*b + torch.log(c)
gradients = torch.FloatTensor([0.1, 1.0, 0.0001])
y.backward(gradients)
print(a.grad) # tensor([3.3003])
print(b.grad) # tensor([0.])
print(c.grad) # tensor([inf])
torch.empty()
を使用し、pin_memory=True
を使用しない場合、毎回異なる結果になる可能性があります。
また、ノートグラデーションはアキュムレータに似ているため、必要に応じてゼロにします。
例6:
a = torch.tensor(1.0, requires_grad = True)
b = torch.tensor(1.0, requires_grad = True)
c = torch.tensor(1.0, requires_grad = True)
y=3*a + 2*b*b + torch.log(c)
y.backward(retain_graph=True)
y.backward()
print(a.grad) # tensor(6.)
print(b.grad) # tensor(8.)
print(c.grad) # tensor(2.)
最後に、PyTorchが使用するいくつかの用語を述べたいだけです。
PyTorchは、勾配を計算するときに動的計算グラフを作成します。これは木のように見えます。
そのため、このツリーのleavesはinput tensorsとrootはoutput tensorです。
グラデーションは、グラフをルートからリーフまでトレースし、チェーンルールを使用して、すべてのグラデーションを途中で乗算することによって計算されます。
ニューラルネットワークでは、通常loss
を使用して、ネットワークが入力画像(または他のタスク)の分類をどの程度学習したかを評価します。 loss
項は通常、スカラー値です。ネットワークのパラメーターを更新するには、loss
wrtからパラメーターへの勾配を計算する必要があります。これは、実際には計算グラフでleaf node
です(ところで、これらのパラメーターはほとんどが重みですコンボリューション、リニアなどのさまざまなレイヤーのバイアス)。
チェーンルールに従って、loss
wrtのリーフノードへの勾配を計算するために、loss
wrtの導関数を中間変数で計算し、中間変数wrtのリーフ変数への勾配を計算できます。内積し、これらすべてを合計します。
gradient
の- backward()
メソッドのVariable
引数は、変数の各要素の加重和を計算するために使用されます リーフ変数 。これらの重みは、中間変数の各要素に対する最終loss
の派生物です。
これを理解するために、具体的かつ簡単な例を見てみましょう。
from torch.autograd import Variable
import torch
x = Variable(torch.FloatTensor([[1, 2, 3, 4]]), requires_grad=True)
z = 2*x
loss = z.sum(dim=1)
# do backward for first element of z
z.backward(torch.FloatTensor([[1, 0, 0, 0]]), retain_graph=True)
print(x.grad.data)
x.grad.data.zero_() #remove gradient in x.grad, or it will be accumulated
# do backward for second element of z
z.backward(torch.FloatTensor([[0, 1, 0, 0]]), retain_graph=True)
print(x.grad.data)
x.grad.data.zero_()
# do backward for all elements of z, with weight equal to the derivative of
# loss w.r.t z_1, z_2, z_3 and z_4
z.backward(torch.FloatTensor([[1, 1, 1, 1]]), retain_graph=True)
print(x.grad.data)
x.grad.data.zero_()
# or we can directly backprop using loss
loss.backward() # equivalent to loss.backward(torch.FloatTensor([1.0]))
print(x.grad.data)
上記の例では、最初のprint
の結果は
2 0 0 0
[サイズ1x4のTorch.FloatTensor]
これは、z_1 w.r.tのxへの導関数です。
2番目のprint
の結果は次のとおりです。
0 2 0 0
[サイズ1x4のTorch.FloatTensor]
これは、z_2 w.r.tのxへの導関数です。
ここで、[1、1、1、1]の重みを使用して、z w.r.tのxへの導関数を計算すると、結果は1*dz_1/dx + 1*dz_2/dx + 1*dz_3/dx + 1*dz_4/dx
になります。当然のことながら、3番目のprint
の出力は次のとおりです。
2 2 2 2
[サイズ1x4のTorch.FloatTensor]
重みベクトル[1、1、1、1]は、loss
w.r.tのz_1、z_2、z_3、およびz_4の導関数であることに注意してください。 loss
w.r.tからx
への導関数は、次のように計算されます。
d(loss)/dx = d(loss)/dz_1 * dz_1/dx + d(loss)/dz_2 * dz_2/dx + d(loss)/dz_3 * dz_3/dx + d(loss)/dz_4 * dz_4/dx
したがって、4番目のprint
の出力は、3番目のprint
と同じです。
2 2 2 2
[サイズ1x4のTorch.FloatTensor]
通常、計算グラフにはloss
という1つのスカラー出力があります。次に、loss
w.r.tの勾配を計算できます。 loss.backward()
による重み(w
)。 backward()
のデフォルト引数は1.0
です。
出力に複数の値がある場合(例:loss=[loss1, loss2, loss3]
)、損失w.r.tの勾配を計算できます。 loss.backward(torch.FloatTensor([1.0, 1.0, 1.0]))
による重み。
さらに、さまざまな損失に重みまたは重要度を追加する場合は、loss.backward(torch.FloatTensor([-0.1, 1.0, 0.0001]))
を使用できます。
これは、-0.1*d(loss1)/dw, d(loss2)/dw, 0.0001*d(loss3)/dw
を同時に計算することを意味します。
ここで、forward()の出力、つまりyは3つのベクトルです。
3つの値は、ネットワークの出力での勾配です。 yが最終出力の場合、通常は1.0に設定されますが、特にyがより大きなネットワークの一部である場合は、他の値も持つことができます。
例えばxが入力の場合、y = [y1、y2、y3]は最終出力zの計算に使用される中間出力です。
次に、
dz/dx = dz/dy1 * dy1/dx + dz/dy2 * dy2/dx + dz/dy3 * dy3/dx
したがって、ここでは、後方への3つの値は
[dz/dy1, dz/dy2, dz/dy3]
次に、backward()はdz/dxを計算します