私は、シーケンシャルモデルでスキップ接続に頭を包み込もうとしています。機能的なAPIを使用すると、次のように簡単に実行できます(簡単な例、100%構文的に正確ではないかもしれませんが、アイデアが得られるはずです)。
x1 = self.conv1(inp)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
x = self.deconv4(x)
x = self.deconv3(x)
x = self.deconv2(x)
x = torch.cat((x, x1), 1))
x = self.deconv1(x)
私は現在、シーケンシャルモデルを使用して、同様のことをしようとしています。最初のconvレイヤーのアクティベーションを最後のconvTransposeまでもたらすスキップ接続を作成します。 here 実装されたU-netアーキテクチャを見てきましたが、少し混乱します。次のようになります。
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1, bias=use_bias)
down = [downrelu, downconv, downnorm]
up = [uprelu, upconv, upnorm]
if use_dropout:
model = down + [submodule] + up + [nn.Dropout(0.5)]
else:
model = down + [submodule] + up
これは、シーケンシャルモデルにレイヤーを適切に追加するだけではありませんか? down
convに続いてsubmodule
(再帰的に内側の層を追加)が続き、upconv層であるup
に連結されます。おそらくSequential
APIがどのように機能するかについて重要なことを見逃していますが、U-NETから切り取られたコードはどのように実際にスキップを実装していますか?
観察は正しいが、UnetSkipConnectionBlock.forward()
(UnetSkipConnectionBlock
は共有したU-Netブロックを定義するModule
である)の定義を見逃している可能性があります。 :
(from _pytorch-CycleGAN-and-pix2pix/models/networks.py#L259
_ )
_# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
# |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
# ...
def forward(self, x):
if self.outermost:
return self.model(x)
else:
return torch.cat([x, self.model(x)], 1)
_
最後の行はキーです(すべての内部ブロックに適用されます)。スキップレイヤーは、入力x
と(再帰)ブロック出力self.model(x)
を、_self.model
_で指定した操作のリストと連結することで簡単に行われます。書いたFunctional
コード。