関数torch.nn.functional.softmax
は、input
とdim
の2つのパラメーターを取ります。そのドキュメントによると、softmax操作は、指定されたinput
に沿ってdim
のすべてのスライスに適用され、要素が(0, 1)
の範囲にあり、合計が1。
入力を次のようにします。
input = torch.randn((3, 4, 5, 6))
その配列のすべてのエントリが1になるように、次のものが必要だとします。
sum = torch.sum(input, dim = 3) # sum's size is (3, 4, 5, 1)
Softmaxはどのように適用すればよいですか?
softmax(input, dim = 0) # Way Number 0
softmax(input, dim = 1) # Way Number 1
softmax(input, dim = 2) # Way Number 2
softmax(input, dim = 3) # Way Number 3
私の直感は、それが最後のものであることを教えてくれますが、私にはわかりません。英語は私の第一言語ではなく、そのためにalong
の使用が混乱しているように見えました。
「に沿って」が何を意味するのかはあまり明確ではないので、物事を明確にすることができる例を使用します。サイズのテンソル(s1、s2、s3、s4)があり、これを実現したいとします
私が理解できる最も簡単な方法は、形状のテンソルが与えられていると言うことです(s1, s2, s3, s4)
そして、あなたが言及したように、あなたは最後の軸に沿ったすべてのエントリの合計を1にしたいです。
sum = torch.sum(input, dim = 3) # input is of shape (s1, s2, s3, s4)
次に、softmaxを次のように呼び出す必要があります。
softmax(input, dim = 3)
簡単に理解するために、形状の4次元テンソル(s1, s2, s3, s4)
2Dテンソルまたは形状の行列(s1*s2*s3, s4)
。マトリックスの各行(axis = 0)または列(axis = 1)に合計1の値を含める場合、次のように2dテンソルでsoftmax
関数を呼び出すことができます。
softmax(input, dim = 0) # normalizes values along axis 0
softmax(input, dim = 1) # normalizes values along axis 1
Stevenが answer で言及した例を見ることができます。
私はあなたの質問が何を意味するのか100%は確信していませんが、あなたの混乱はあなたがdim
パラメータの意味を理解していないということだけだと思います。だから私はそれを説明し、例を提供します。
私たちが持っている場合:
_m0 = nn.Softmax(dim=0)
_
つまり、_m0
_は、受け取るテンソルの0番目の座標に沿って要素を正規化します。正式には、サイズが_(d0,d1)
_のテンソルb
が与えられた場合、次のことが当てはまります。
_sum^{d0}_{i0=1} b[i0,i1] = 1, forall i1 \in {0,...,d1}
_
pytorchの例でこれを簡単に確認できます。
_>>> b = torch.arange(0,4,1.0).view(-1,2)
>>> b
tensor([[0., 1.],
[2., 3.]])
>>> m0 = nn.Softmax(dim=0)
>>> b0 = m0(b)
>>> b0
tensor([[0.1192, 0.1192],
[0.8808, 0.8808]])
_
_dim=0
_は列_i0 \in {0,1}
_を選択し、その要素(つまり行)を合計すると_i1
_を通過すること(つまり行を通過すること)を意味するため、1を取得する必要があります。
_>>> b0[:,0].sum()
tensor(1.0000)
>>> b0[:,1].sum()
tensor(1.0000)
_
予想通り。
torch.sum(b0,dim=0)
で「行を合計する」ことにより、すべての行の合計が1になることに注意してください。
_>>> torch.sum(b0,0)
tensor([1.0000, 1.0000])
_
より複雑な例を作成して、それが本当に明確であることを確認できます。
_a = torch.arange(0,24,1.0).view(-1,3,4)
>>> a
tensor([[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.]],
[[12., 13., 14., 15.],
[16., 17., 18., 19.],
[20., 21., 22., 23.]]])
>>> a0 = m0(a)
>>> a0[:,0,0].sum()
tensor(1.0000)
>>> a0[:,1,0].sum()
tensor(1.0000)
>>> a0[:,2,0].sum()
tensor(1.0000)
>>> a0[:,1,0].sum()
tensor(1.0000)
>>> a0[:,1,1].sum()
tensor(1.0000)
>>> a0[:,2,3].sum()
tensor(1.0000)
_
したがって、最初の値から最後の値までの最初の座標に沿ってすべての要素を合計すると1になります。したがって、すべてが最初の次元(または最初の座標_i0
_)に沿って正規化されます。
_>>> torch.sum(a0,0)
tensor([[1.0000, 1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 1.0000, 1.0000]])
_
また、次元0に沿って、その次元に沿って座標を変化させ、各要素を考慮することを意味します。最初の座標が取ることができる値を通過するforループを持つようなもの、つまり.
_for i0 in range(0,d0):
a[i0,b,c,d]
_
2次元の例を考えてみましょう
x = [[1,2],
[3,4]]
最終結果が欲しいですか
y = [[0.27,0.73],
[0.27,0.73]]
または
y = [[0.12,0.12],
[0.88,0.88]]
最初のオプションの場合はdim = 1にします。2番目のオプションの場合はdim = 0にします。
2番目の例では列またはゼロ次元が正規化されているため、ゼロ次元に沿って正規化されていることに注意してください。
2018-07-10更新:ゼロ次元がpytorchの列を参照することを反映します。