Pytorchで手動で定義されたパラメーターでGRU/LSTMを埋めようとしています。
ドキュメントに定義されている形状を持つパラメーターのnumpy配列があります( https://pytorch.org/docs/stable/nn.html#torch.nn.GR )。
うまくいくようですが、返された値が正しいかどうかわかりません。
これは、numpyパラメータでGRU/LSTMを埋める正しい方法ですか?
gru = nn.GRU(input_size, hidden_size, num_layers,
bias=True, batch_first=False, dropout=dropout, bidirectional=bidirectional)
def set_nn_wih(layer, parameter_name, w, l0=True):
param = getattr(layer, parameter_name)
if l0:
for i in range(3*hidden_size):
param.data[i] = w[i*input_size:(i+1)*input_size]
else:
for i in range(3*hidden_size):
param.data[i] = w[i*num_directions*hidden_size:(i+1)*num_directions*hidden_size]
def set_nn_whh(layer, parameter_name, w):
param = getattr(layer, parameter_name)
for i in range(3*hidden_size):
param.data[i] = w[i*hidden_size:(i+1)*hidden_size]
l0=True
for i in range(num_directions):
for j in range(num_layers):
if j == 0:
wih = w0[i, :, :3*input_size]
whh = w0[i, :, 3*input_size:] # check
l0=True
else:
wih = w[j-1, i, :, :num_directions*3*hidden_size]
whh = w[j-1, i, :, num_directions*3*hidden_size:]
l0=False
if i == 0:
set_nn_wih(
gru, "weight_ih_l{}".format(j), torch.from_numpy(wih.flatten()),l0)
set_nn_whh(
gru, "weight_hh_l{}".format(j), torch.from_numpy(whh.flatten()))
else:
set_nn_wih(
gru, "weight_ih_l{}_reverse".format(j), torch.from_numpy(wih.flatten()),l0)
set_nn_whh(
gru, "weight_hh_l{}_reverse".format(j), torch.from_numpy(whh.flatten()))
y, hn = gru(x_t, h_t)
numpy配列は次のように定義されます:
rng = np.random.RandomState(313)
w0 = rng.randn(num_directions, hidden_size, 3*(input_size +
hidden_size)).astype(np.float32)
w = rng.randn(max(1, num_layers-1), num_directions, hidden_size,
3*(num_directions*hidden_size + hidden_size)).astype(np.float32)
それは良い質問であり、あなたはすでにまともな答えを与えています。しかし、それは車輪を再発明します-非常にエレガントなPytorchの内部ルーチンがあり、これを同じように労力をかけずに行うことができます-そして、どのネットワークにも適用できます。
ここでの中心概念は、PyTorchの_state_dict
_です。状態ディクショナリには、_nn.Modules
_とそのサブモジュールなどの関係によって与えられるツリー構造で編成されたparameters
が効果的に含まれています。
コードで_state_dict
_を使用して値をテンソルにロードする場合は、次の行を試してください(dict
には有効な_state_dict
_が含まれます)。
_`model.load_state_dict(dict, strict=False)`
_
ここで_strict=False
_は、一部のパラメーター値のみをロードする場合に重要です。
state_dict
_の紹介を含む状態ディクがGRUを検索する方法の例を次に示します(状態ディク全体を印刷できるように_input_size = hidden_size = 2
_を選択しました)。
_rnn = torch.nn.GRU(2, 2, 1)
rnn.state_dict()
# Out[10]:
# OrderedDict([('weight_ih_l0', tensor([[-0.0023, -0.0460],
# [ 0.3373, 0.0070],
# [ 0.0745, -0.5345],
# [ 0.5347, -0.2373],
# [-0.2217, -0.2824],
# [-0.2983, 0.4771]])),
# ('weight_hh_l0', tensor([[-0.2837, -0.0571],
# [-0.1820, 0.6963],
# [ 0.4978, -0.6342],
# [ 0.0366, 0.2156],
# [ 0.5009, 0.4382],
# [-0.7012, -0.5157]])),
# ('bias_ih_l0',
# tensor([-0.2158, -0.6643, -0.3505, -0.0959, -0.5332, -0.6209])),
# ('bias_hh_l0',
# tensor([-0.1845, 0.4075, -0.1721, -0.4893, -0.2427, 0.3973]))])
_
したがって、_state_dict
_ネットワークのすべてのパラメーター。 「ネストされた」_nn.Modules
_がある場合、パラメーター名で表されるツリーを取得します。
_class MLP(torch.nn.Module):
def __init__(self):
torch.nn.Module.__init__(self)
self.lin_a = torch.nn.Linear(2, 2)
self.lin_b = torch.nn.Linear(2, 2)
mlp = MLP()
mlp.state_dict()
# Out[23]:
# OrderedDict([('lin_a.weight', tensor([[-0.2914, 0.0791],
# [-0.1167, 0.6591]])),
# ('lin_a.bias', tensor([-0.2745, -0.1614])),
# ('lin_b.weight', tensor([[-0.4634, -0.2649],
# [ 0.4552, 0.3812]])),
# ('lin_b.bias', tensor([ 0.0273, -0.1283]))])
class NestedMLP(torch.nn.Module):
def __init__(self):
torch.nn.Module.__init__(self)
self.mlp_a = MLP()
self.mlp_b = MLP()
n_mlp = NestedMLP()
n_mlp.state_dict()
# Out[26]:
# OrderedDict([('mlp_a.lin_a.weight', tensor([[ 0.2543, 0.3412],
# [-0.1984, -0.3235]])),
# ('mlp_a.lin_a.bias', tensor([ 0.2480, -0.0631])),
# ('mlp_a.lin_b.weight', tensor([[-0.4575, -0.6072],
# [-0.0100, 0.5887]])),
# ('mlp_a.lin_b.bias', tensor([-0.3116, 0.5603])),
# ('mlp_b.lin_a.weight', tensor([[ 0.3722, 0.6940],
# [-0.5120, 0.5414]])),
# ('mlp_b.lin_a.bias', tensor([0.3604, 0.0316])),
# ('mlp_b.lin_b.weight', tensor([[-0.5571, 0.0830],
# [ 0.5230, -0.1020]])),
# ('mlp_b.lin_b.bias', tensor([ 0.2156, -0.2930]))])
_
だから-状態辞書を抽出せずに、それを変更したい場合-そしてそれによってネットワークのパラメータをどうするか? nn.Module.load_state_dict(state_dict, strict=True)
を使用( ドキュメントへのリンク )このメソッドを使用すると、state_dict全体を任意の値で同じ種類のインスタンス化されたモデルにロードできますキー(パラメータ名)が正しく、値(パラメータ)が正しい形状の_torch.tensors
_である限り。 strict
kwargがTrue
(デフォルト)に設定されている場合、ロードする辞書は、パラメーターの値を除き、元の状態の辞書と正確に一致する必要があります。つまり、各パラメーターに1つの新しい値が必要です。
上記のGRUの例では、各_'weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0'
_に対して正しいサイズのテンソル(および正しいデバイスbtw)が必要です。時々some値のみをロードしたいので(あなたがしたいと思うように)、strict
kwargをFalse
に設定できます。その後、部分的な状態の辞書のみを読み込むことができます、例えば_'weight_ih_l0'
_のパラメーター値のみを含むもの。
実用的なアドバイスとして、値をロードするモデルを作成し、状態の辞書(または少なくともキーとそれぞれのテンソルサイズのリスト)を出力します。
_print([k, v.shape for k, v in model.state_dict().items()])
_
これにより、変更するパラメーターの正確な名前がわかります。次に、それぞれのパラメーター名とテンソルを使用して状態辞書を作成し、ロードします。
_from dollections import OrderedDict
new_state_dict = OrderedDict({'tensor_name_retrieved_from_original_dict': new_tensor_value})
model.load_state_dict(new_state_dict, strict=False)
_