Pytorch初心者はこちら!次のカスタムモジュールについて考えてみます。
_class Testme(nn.Module):
def __init__(self):
super(Testme, self).__init__()
def forward(self, x):
return x / t_.max(x).expand_as(x)
_
ドキュメントを理解している限り、これはカスタムFunction
として実装することもできると思います。 Function
のサブクラスにはbackward()
メソッドが必要ですが、Module
には必要ありません。同様に、Linear Module
のドキュメントの例では、Linear Function
に依存しています。
_class Linear(nn.Module):
def __init__(self, input_features, output_features, bias=True):
...
def forward(self, input):
return Linear()(input, self.weight, self.bias)
_
質問:Module
とFunction
の関係がわかりません。上記の最初のリスト(モジュールTestme
)では、関連する関数が必要ですか?そうでない場合は、モジュールをサブクラス化することで、このなしbackward
メソッドを実装できます。それでは、なぜFunction
は常にbackward
メソッドを必要とするのでしょうか。 ?
おそらく、Function
sは、既存のトーチ関数で構成されていない関数のみを対象としていますか?別の言い方をすれば、モジュールのFunction
メソッドが以前に定義されたトーチ関数から完全に構成されている場合、モジュールは関連するforward
を必要としないでしょうか?
この情報は、公式のPyTorchDocumentaionから収集および要約されています。
_torch.autograd.Function
_本当にPyTorchのautogradパッケージの中心にあります。 PyTorchで作成するグラフ、およびPyTorchのVariables
で実行する操作は、Function
に基づいています。すべての関数には、__init__(), forward()
メソッドとbackward()
メソッドが必要です(詳細はこちらをご覧ください: http://pytorch.org/docs/notes/extending.html )。これにより、PyTorchは結果を計算し、Variables
の勾配を計算できます。
対照的に、nn.Module()
は、モデルやさまざまなレイヤーなどを整理するのに非常に便利です。たとえば、モデル内のすべてのトレーニング可能なパラメーターを.parameters()
に整理し、別のレイヤーを追加できます。 forward()
メソッドであるため、後方メソッドを定義する場所ではありませんnot 、すでにFunction()
を定義しているbackward()
のサブクラスを使用することになっています。したがって、forward()
で演算の順序を指定した場合、PyTorchは勾配を逆伝播する方法をすでに知っています。
さて、いつ何を使うべきですか?
PyTorchに実装されている既存の関数の単なる合成である操作(上記のような)がある場合、Function()にサブクラスを自分で追加しても意味がありません。操作を積み重ねて動的グラフを作成できるからです。ただし、これらの操作をまとめることは賢明なアイデアです。操作にトレーニング可能なパラメーター(ニューラルネットワークの線形層など)が含まれる場合は、nn.Module()
をサブクラス化し、forwardメソッドで操作をまとめる必要があります。これにより、_torch.optim
_などを使用するためのパラメーター(上記で概説)に簡単にアクセスできます。トレーニング可能なパラメーターがない場合でも、おそらくそれらをまとめますが、標準のPython関数、使用する各操作のインスタンス化を処理するだけで十分です。
新しいカスタム操作(たとえば、複雑なサンプリング手順を伴う新しい確率的レイヤー)がある場合は、Function()
をサブクラス化し、__init__(), forward()
とbackward()
を定義してPyTorchに通知する必要があります。この操作を使用する場合の結果の計算方法と勾配の計算方法。その後、関数のインスタンス化を処理して操作を使用する機能バージョンを作成するか、操作にトレーニング可能なパラメーターがある場合はモジュールを作成する必要があります。繰り返しになりますが、これについて詳しくは上記のリンクをご覧ください。