web-dev-qa-db-ja.com

pytorchのリシェイプとビューの違いは何ですか?

Numpyでは、ndarray.reshape()を使用して配列を再形成します。

Pytorchでは、同じ目的でtorch.view(...)を使用しますが、同時にtorch.reshape(...)も存在することに気付きました。

だから、それらの違いは何ですか、いつどちらを使用する必要がありますか?

37
Lifu Huang

torch.viewは長い間存在しています。新しい形状のテンソルを返します。返されたテンソルは、元のテンソルと基礎データを共有します。 こちらのドキュメント をご覧ください。

一方、torch.reshapeバージョン0.4で最近導入された のようです。 document によると、このメソッドは

入力と同じデータと要素数で、指定された形状を持つテンソルを返します。可能な場合、返されるテンソルは入力のビューになります。それ以外の場合は、コピーになります。連続した入力および互換性のあるストライドを持つ入力は、コピーせずに再形成できますが、コピーと表示の動作に依存しないでください。

torch.reshapeは元のテンソルのコピーまたはビューを返す可能性があることを意味します。ビューまたはコピーを返すことを期待することはできません。開発者によると:

コピーが必要な場合は、同じストレージが必要な場合はclone()を使用し、view()を使用します。 reshape()のセマンティクスは、ストレージを共有する場合としない場合があり、事前に知らないことです。

別の違いは、reshape()は連続テンソルと非連続テンソルの両方で操作できますが、view()は連続テンソルでのみ操作できます。 contiguousの意味については here も参照してください。

39
jdhao

torch.viewtorch.reshapeは両方ともテンソルの形状を変更するために使用されますが、それらの違いは次のとおりです。

  1. 名前が示すように、torch.viewは単に元のテンソルのviewを作成するだけです。新しいテンソルは、always元のテンソルとデータを共有します。つまり、元のテンソルを変更すると、再構成されたテンソルも変更され、その逆も同様です。
>>> z = torch.zeros(3, 2)
>>> x = z.view(2, 3)
>>> z.fill_(1)
>>> x
tensor([[1., 1., 1.],
        [1., 1., 1.]])
  1. 新しいテンソルが元のデータと常にデータを共有することを保証するために、torch.viewは2つのテンソルの形状に連続性制約を課します[ docs ]。多くの場合、これは懸念事項ではありませんが、2つのテンソルの形状に互換性がある場合でも、torch.viewはエラーをスローすることがあります。これは有名な反例です。
>>> z = torch.zeros(3, 2)
>>> y = z.t()
>>> y.size()
torch.Size([2, 3])
>>> y.view(6)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: invalid argument 2: view size is not compatible with input tensor's
size and stride (at least one dimension spans across two contiguous subspaces).
Call .contiguous() before .view().
  1. torch.reshapeは連続性制約を課しませんが、データ共有を保証しません。新しいテンソルは、元のテンソルのビューである場合もあれば、まったく新しいテンソルである場合もあります。
>>> z = torch.zeros(3, 2)
>>> y = z.reshape(6)
>>> x = z.t().reshape(6)
>>> z.fill_(1)
tensor([[1., 1.],
        [1., 1.],
        [1., 1.]])
>>> y
tensor([1., 1., 1., 1., 1., 1.])
>>> x
tensor([0., 0., 0., 0., 0., 0.])

TL; DR:
テンソルの形状を変更するだけの場合は、torch.reshapeを使用します。メモリ使用量も心配で、2つのテンソルが同じデータを共有するようにしたい場合は、torch.viewを使用します。

8
nikhilweee

Tensor.reshape()はより堅牢です。 Tensor.view()はテンソルt where t.is_contiguous()==Trueでのみ機能しますが、どのテンソルでも機能します。

非連続および連続について説明することは別のタイムストーリーですが、t.contiguous()を呼び出すことで、tをテンソルに常に連続させることができ、view()をエラーなしで呼び出すことができます。

0
prosti