web-dev-qa-db-ja.com

Pytorchで「スクイーズ解除」は何をしますか?

私はドキュメントを見ています、そしてここにそれらの例があります。この例がそれらの説明にどのように対応するか理解できません:「指定された位置に挿入されたサイズ1の次元を持つ新しいテンソルを返します。」

>>> x = torch.tensor([1, 2, 3, 4])
>>> torch.unsqueeze(x, 0)
tensor([[ 1,  2,  3,  4]])
>>> torch.unsqueeze(x, 1)
tensor([[ 1],
        [ 2],
        [ 3],
        [ 4]])
19
user6549804

配列の前と後の形状を見ると、その前は_(4,)_であり、その後は_(1, 4)_(2番目のパラメーターが_0_の場合)および_(4, 1)_であることがわかります。 (2番目のパラメーターが_1_の場合)。そのため、2番目のパラメータの値に応じて、_1_が軸_0_または_1_の配列の形で挿入されました。

これは、サイズ_1_(シングルトン)の軸を削除する np.squeeze() (MATLABから借用した命名法)の反対です。

20
norok2

これはPyTorchのレガシーであるため、PyTorchの参照がここで言及されていない理由はわかりません。

torch.squeezeenter image description here

torch.unsqueezeenter image description here

9
prosti

寸法を追加する位置を示します。 torch.unsqueezeは、テンソルに次元を追加します。たとえば、形状のテンソル(3)があるとします。0の位置に次元を追加すると、形状は(1,3)になります。つまり、1行3列です。 1の位置に追加すると、(3,1)になります。これは、3行と1列を意味します。形状(2,2)の2Dテンソルがある場合、0の位置に余分な次元を追加します。これにより、テンソルの形状が(1,2,2)になります。これは、1つのチャネル、2つの行、2を意味します。列。 1の位置に追加すると、形状は(2,1,2)になるため、2つのチャネル、1つの行と2つの列があります。 2の位置に追加すると、テンソルの形状は(2,2,1)になります。これは、2つのチャネル、2つの行、1つの列を意味します。

5
blueboy21