web-dev-qa-db-ja.com

(モジュールだけでなく)pytorchカスタム関数が必要になるのはいつですか?

Pytorch初心者はこちら!次のカスタムモジュールについて考えてみます。

_class Testme(nn.Module):
    def __init__(self):
        super(Testme, self).__init__()

def forward(self, x):
    return x / t_.max(x).expand_as(x)
_

ドキュメントを理解している限り、これはカスタムFunctionとして実装することもできると思います。 Functionのサブクラスにはbackward()メソッドが必要ですが、Moduleには必要ありません。同様に、Linear Moduleのドキュメントの例では、Linear Functionに依存しています。

_class Linear(nn.Module):
    def __init__(self, input_features, output_features, bias=True):
        ...    
    def forward(self, input):
        return Linear()(input, self.weight, self.bias)
_

質問:ModuleFunctionの関係がわかりません。上記の最初のリスト(モジュールTestme)では、関連する関数が必要ですか?そうでない場合は、モジュールをサブクラス化することで、このなしbackwardメソッドを実装できます。それでは、なぜFunctionは常にbackwardメソッドを必要とするのでしょうか。 ?

おそらく、Functionsは、既存のトーチ関数で構成されていない関数のみを対象としていますか?別の言い方をすれば、モジュールのFunctionメソッドが以前に定義されたトーチ関数から完全に構成されている場合、モジュールは関連するforwardを必要としないでしょうか?

13
forgotmysocks

この情報は、公式のPyTorchDocumentaionから収集および要約されています。

_torch.autograd.Function_本当にPyTorchのautogradパッケージの中心にあります。 PyTorchで作成するグラフ、およびPyTorchのVariablesで実行する操作は、Functionに基づいています。すべての関数には、__init__(), forward()メソッドとbackward()メソッドが必要です(詳細はこちらをご覧ください: http://pytorch.org/docs/notes/extending.html )。これにより、PyTorchは結果を計算し、Variablesの勾配を計算できます。

対照的に、nn.Module()は、モデルやさまざまなレイヤーなどを整理するのに非常に便利です。たとえば、モデル内のすべてのトレーニング可能なパラメーターを.parameters()に整理し、別のレイヤーを追加できます。 forward()メソッドであるため、後方メソッドを定義する場所ではありませんnot 、すでにFunction()を定義しているbackward()のサブクラスを使用することになっています。したがって、forward()で演算の順序を指定した場合、PyTorchは勾配を逆伝播する方法をすでに知っています。

さて、いつ何を使うべきですか?

PyTorchに実装されている既存の関数の単なる合成である操作(上記のような)がある場合、Function()にサブクラスを自分で追加しても意味がありません。操作を積み重ねて動的グラフを作成できるからです。ただし、これらの操作をまとめることは賢明なアイデアです。操作にトレーニング可能なパラメーター(ニューラルネットワークの線形層など)が含まれる場合は、nn.Module()をサブクラス化し、forwardメソッドで操作をまとめる必要があります。これにより、_torch.optim_などを使用するためのパラメーター(上記で概説)に簡単にアクセスできます。トレーニング可能なパラメーターがない場合でも、おそらくそれらをまとめますが、標準のPython関数、使用する各操作のインスタンス化を処理するだけで十分です。

新しいカスタム操作(たとえば、複雑なサンプリング手順を伴う新しい確率的レイヤー)がある場合は、Function()をサブクラス化し、__init__(), forward()backward()を定義してPyTorchに通知する必要があります。この操作を使用する場合の結果の計算方法と勾配の計算方法。その後、関数のインスタンス化を処理して操作を使用する機能バージョンを作成するか、操作にトレーニング可能なパラメーターがある場合はモジュールを作成する必要があります。繰り返しになりますが、これについて詳しくは上記のリンクをご覧ください。

13
mexmex