Numpyでは、ndarray.reshape()
を使用して配列を再形成します。
Pytorchでは、同じ目的でtorch.view(...)
を使用しますが、同時にtorch.reshape(...)
も存在することに気付きました。
だから、それらの違いは何ですか、いつどちらを使用する必要がありますか?
torch.view
は長い間存在しています。新しい形状のテンソルを返します。返されたテンソルは、元のテンソルと基礎データを共有します。 こちらのドキュメント をご覧ください。
一方、torch.reshape
バージョン0.4で最近導入された のようです。 document によると、このメソッドは
入力と同じデータと要素数で、指定された形状を持つテンソルを返します。可能な場合、返されるテンソルは入力のビューになります。それ以外の場合は、コピーになります。連続した入力および互換性のあるストライドを持つ入力は、コピーせずに再形成できますが、コピーと表示の動作に依存しないでください。
torch.reshape
は元のテンソルのコピーまたはビューを返す可能性があることを意味します。ビューまたはコピーを返すことを期待することはできません。開発者によると:
コピーが必要な場合は、同じストレージが必要な場合はclone()を使用し、view()を使用します。 reshape()のセマンティクスは、ストレージを共有する場合としない場合があり、事前に知らないことです。
別の違いは、reshape()
は連続テンソルと非連続テンソルの両方で操作できますが、view()
は連続テンソルでのみ操作できます。 contiguous
の意味については here も参照してください。
torch.view
とtorch.reshape
は両方ともテンソルの形状を変更するために使用されますが、それらの違いは次のとおりです。
torch.view
は単に元のテンソルのviewを作成するだけです。新しいテンソルは、always元のテンソルとデータを共有します。つまり、元のテンソルを変更すると、再構成されたテンソルも変更され、その逆も同様です。>>> z = torch.zeros(3, 2)
>>> x = z.view(2, 3)
>>> z.fill_(1)
>>> x
tensor([[1., 1., 1.],
[1., 1., 1.]])
torch.view
は2つのテンソルの形状に連続性制約を課します[ docs ]。多くの場合、これは懸念事項ではありませんが、2つのテンソルの形状に互換性がある場合でも、torch.view
はエラーをスローすることがあります。これは有名な反例です。>>> z = torch.zeros(3, 2)
>>> y = z.t()
>>> y.size()
torch.Size([2, 3])
>>> y.view(6)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: invalid argument 2: view size is not compatible with input tensor's
size and stride (at least one dimension spans across two contiguous subspaces).
Call .contiguous() before .view().
torch.reshape
は連続性制約を課しませんが、データ共有を保証しません。新しいテンソルは、元のテンソルのビューである場合もあれば、まったく新しいテンソルである場合もあります。>>> z = torch.zeros(3, 2)
>>> y = z.reshape(6)
>>> x = z.t().reshape(6)
>>> z.fill_(1)
tensor([[1., 1.],
[1., 1.],
[1., 1.]])
>>> y
tensor([1., 1., 1., 1., 1., 1.])
>>> x
tensor([0., 0., 0., 0., 0., 0.])
TL; DR:
テンソルの形状を変更するだけの場合は、torch.reshape
を使用します。メモリ使用量も心配で、2つのテンソルが同じデータを共有するようにしたい場合は、torch.view
を使用します。
Tensor.reshape()
はより堅牢です。 Tensor.view()
はテンソルt
where t.is_contiguous()==True
でのみ機能しますが、どのテンソルでも機能します。
非連続および連続について説明することは別のタイムストーリーですが、t.contiguous()
を呼び出すことで、t
をテンソルに常に連続させることができ、view()
をエラーなしで呼び出すことができます。