pythonリストでは、list.index(somevalue)
を使用できます。pytorchでこれを行うにはどうすればよいですか?
例えば:
a=[1,2,3]
print(a.index(2))
次に、1
が出力されます。 pytorchテンソルは、pythonリストに変換せずにこれを行うにはどうすればよいですか?
list.index()
からpytorch関数への直接の変換はないと思います。ただし、_tensor==number
_を使用してnonzero()
関数を使用すると、同様の結果を得ることができます。例えば:
_t = torch.Tensor([1, 2, 3])
print ((t == 2).nonzero())
_
このコードは
1
[サイズ1x1のTorch.LongTensor]
次のようにnumpyに変換することで実行できます
import torch
x = torch.range(1,4)
print(x)
===> tensor([ 1., 2., 3., 4.])
nx = x.numpy()
np.where(nx == 3)[0][0]
===> 2
浮動小数点テンソルの場合、これを使用してテンソル内の要素のインデックスを取得します。
print((torch.abs((torch.max(your_tensor).item()-your_tensor))<0.0001).nonzero())
ここで、フロートテンソルのmax_valueのインデックスを取得します。このような値を入力して、テンソルの任意の要素のインデックスを取得することもできます。
print((torch.abs((YOUR_VALUE-your_tensor))<0.0001).nonzero())