web-dev-qa-db-ja.com

Pytorch Tensorが特定の値のインデックスを取得する方法

pythonリストでは、list.index(somevalue)を使用できます。pytorchでこれを行うにはどうすればよいですか?
例えば:

    a=[1,2,3]
    print(a.index(2))

次に、1が出力されます。 pytorchテンソルは、pythonリストに変換せずにこれを行うにはどうすればよいですか?

15
Han Bing

list.index()からpytorch関数への直接の変換はないと思います。ただし、_tensor==number_を使用してnonzero()関数を使用すると、同様の結果を得ることができます。例えば:

_t = torch.Tensor([1, 2, 3])
print ((t == 2).nonzero())
_

このコードは

1

[サイズ1x1のTorch.LongTensor]

24
Manuel Lagunas

次のようにnumpyに変換することで実行できます

import torch
x = torch.range(1,4)
print(x)
===> tensor([ 1.,  2.,  3.,  4.]) 
nx = x.numpy()
np.where(nx == 3)[0][0]
===> 2
2
vlad

浮動小数点テンソルの場合、これを使用してテンソル内の要素のインデックスを取得します。

print((torch.abs((torch.max(your_tensor).item()-your_tensor))<0.0001).nonzero())

ここで、フロートテンソルのmax_valueのインデックスを取得します。このような値を入力して、テンソルの任意の要素のインデックスを取得することもできます。

print((torch.abs((YOUR_VALUE-your_tensor))<0.0001).nonzero())
0
Giang Nguyễn