web-dev-qa-db-ja.com

PyTorchの連結テンソル

形状_[128, 4, 150, 150]_のdataというテンソルがあります。ここで、128はバッチサイズ、4はチャネル数、最後の2つの次元は高さと幅です。形状_[128, 1, 150, 150]_のfakeと呼ばれる別のテンソルがあります。

dataの2番目の次元から最後の_list/array_を削除したい;データの形状は_[128, 3, 150, 150]_になります。 fakeと連結して、連結の出力ディメンションを_[128, 4, 150, 150]_として指定します。

基本的に、言い換えれば、dataの最初の3次元をfakeと連結して、4次元テンソルを与えたいと思います。

私はPyTorchを使用していて、関数torch.cat()torch.stack()に出くわしました。

これが私が書いたサンプルコードです:

_fake_combined = []
        for j in range(batch_size):
            fake_combined.append(torch.stack((data[j][0].to(device), data[j][1].to(device), data[j][2].to(device), fake[j][0].to(device))))
fake_combined = torch.tensor(fake_combined, dtype=torch.float32)
fake_combined = fake_combined.to(device)
_

しかし、次の行でエラーが発生しています。

_fake_combined = torch.tensor(fake_combined, dtype=torch.float32)
_

エラーは次のとおりです。

_ValueError: only one element tensors can be converted to Python scalars
_

また、_fake_combined_の形状を印刷すると、出力は_[128,]_ではなく_[128, 4, 150, 150]_として取得されます。

そして、_fake_combined[0]_の形状を印刷すると、出力は_[4, 150, 150]_として取得されますが、これは予想どおりです。

だから私の質問は、なぜtorch.tensor()を使用してリストをテンソルに変換できないのかということです。私は何かが足りないのですか?私がやろうとしていることをするためのより良い方法はありますか?

どんな助けでも大歓迎です!ありがとう!

3
ntd

その特定のディメンションに割り当てることもできます。

orig = torch.randint(low=0, high=10, size=(2,3,2,2))
fake = torch.randint(low=111, high=119, size=(2,1,2,2))
orig[:,[2],:,:] = fake

元の前

tensor([[[[0, 1],
      [8, 0]],

     [[4, 9],
      [6, 1]],

     [[8, 2],
      [7, 6]]],


    [[[1, 1],
      [8, 5]],

     [[5, 0],
      [8, 6]],

     [[5, 5],
      [2, 8]]]])

tensor([[[[117, 115],
      [114, 111]]],


    [[[115, 115],
      [118, 115]]]])

オリジナルアフター

tensor([[[[  0,   1],
      [  8,   0]],

     [[  4,   9],
      [  6,   1]],

     [[117, 115],
      [114, 111]]],


    [[[  1,   1],
      [  8,   5]],

     [[  5,   0],
      [  8,   6]],

     [[115, 115],
      [118, 115]]]])

お役に立てれば! :)

2
PankajKabra

@ rollthedice32の答えは完全にうまく機能します。教育目的で、ここではtorch.catを使用しています

a = torch.Rand(128, 4, 150, 150)
b = torch.Rand(128, 1, 150, 150)

# Cut out last dimension
a = a[:, :3, :, :]
# Concatenate in 2nd dimension
result = torch.cat([a, b], dim=1)
print(result.shape)
# => torch.Size([128, 4, 150, 150])
6
Coolness