web-dev-qa-db-ja.com

PyTorchテンソルをpython listに変換します

PyTorch Tensorをpythonリストに変換するにはどうすればよいですか?

私の現在の使用例は、サイズ[1, 2048, 1, 1]のテンソルを2048要素のリストに変換することです。

私のテンソルには浮動小数点値があります。 intおよびおそらく他のデータ型も考慮するソリューションはありますか?

19
Tom Hale

私は Tensor.tolist() を見つけました。これは次の使用例を示しています:

_>>> a = torch.randn(2, 2)
>>> a.tolist()
[[0.012766935862600803, 0.5415473580360413],
 [-0.08909505605697632, 0.7729271650314331]]
>>> a[0,0].tolist()
0.012766935862600803
_

したがって、質問に答えるには、a.squeeze().tolist()を使用して、サイズ_1_のすべての次元を削除します。

リストのリストが必要ない場合は、 .flatten() も検討してください。


.tolist()に出会う前に、私は以下を使用していました:

_list = [element.item() for element in tensor.flatten()]
_

これはテンソルを単一の次元に平坦化し、次に .item() を呼び出して各要素をPython数値に変換します。

30
Tom Hale