web-dev-qa-db-ja.com

シーケンシャルコンテナによるPyTorchビルドのレイヤーをフラット化

PyTorchのシーケンシャルコンテナでcnnを構築しようとしていますが、問題はレイヤーをフラット化する方法がわからないことです。

main = nn.Sequential()
self._conv_block(main, 'conv_0', 3, 6, 5)
main.add_module('max_pool_0_2_2', nn.MaxPool2d(2,2))
self._conv_block(main, 'conv_1', 6, 16, 3)
main.add_module('max_pool_1_2_2', nn.MaxPool2d(2,2)) 
main.add_module('flatten', make_it_flatten)

「make_it_flatten」には何を入れるべきですか?メインを平らにしようとしましたが、機能しません。メインが存在しません。ビューを呼び出します。

main = main.view(-1, 16*3*3)
8
StereoMatching

これはあなたが探しているものと正確に一致しないかもしれませんが、入力をフラット化する独自の_nn.Module_を作成し、それをnn.Sequential()オブジェクトに追加することができます。

_class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size()[0], -1)
_

x.size()[0]はバッチdimを選択し、_-1_は要素の数に合うように残りのすべてのdimを計算し、それによってテンソル/変数を平坦化します。

そしてそれを_nn.Sequential_で使用する:

_main = nn.Sequential()
self._conv_block(main, 'conv_0', 3, 6, 5)
main.add_module('max_pool_0_2_2', nn.MaxPool2d(2,2))
self._conv_block(main, 'conv_1', 6, 16, 3)
main.add_module('max_pool_1_2_2', nn.MaxPool2d(2,2)) 
main.add_module('flatten', Flatten())
_
17
cleros

レイヤーをフラット化する最も速い方法は、新しいモジュールを作成せず、main.add_module('flatten', Flatten())を介してそのモジュールをメインに追加することです。

_class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)
_

代わりに、モデルのforward内の単純なout = inp.reshape(inp.size(0), -1)は、 here で示したように高速です。

2
prosti