web-dev-qa-db-ja.com

配列を使用したnp.add.atインデックス作成

私はcs231nに取り組んでおり、このインデックス作成がどのように機能するかを理解するのに苦労しています。とすれば

_x = [[0,4,1], [3,2,4]]
dW = np.zeros(5,6)
dout = [[[  1.19034710e-01  -4.65005990e-01   8.93743168e-01  -9.78047129e-01
            -8.88672957e-01  -4.66605091e-01]
         [ -1.38617461e-03  -2.64569728e-01  -3.83712733e-01  -2.61360826e-01
            8.07072009e-01  -5.47607277e-01]
         [ -3.97087458e-01  -4.25187949e-02   2.57931759e-01   7.49565950e-01
           1.37707667e+00   1.77392240e+00]]

       [[ -1.20692745e+00  -8.28111550e-01   6.53041092e-01  -2.31247762e+00
         -1.72370321e+00   2.44308033e+00]
        [ -1.45191870e+00  -3.49328154e-01   6.15445782e-01  -2.84190582e-01
           4.85997687e-02   4.81590106e-01]
        [ -1.14828583e+00  -9.69055406e-01  -1.00773809e+00   3.63553835e-01
          -1.28078363e+00  -2.54448436e+00]]]
_

彼らが行う操作は

np.add.at(dW, x, dout)

xは2次元配列です。ここでインデックス作成はどのように機能しますか? _np.ufunc.at_のドキュメントを確認しましたが、1次元配列と定数を使用した簡単な例があります。

_np.add.at(a, [0, 1, 2, 2], 1)
_
7
MoneyBall
In [226]: x = [[0,4,1], [3,2,4]]
     ...: dW = np.zeros((5,6),int)

In [227]: np.add.at(dW,x,1)
In [228]: dW
Out[228]: 
array([[0, 0, 0, 1, 0, 0],
       [0, 0, 0, 0, 1, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 1, 0, 0, 0]])

このxでは重複するエントリがないため、add.at+=インデックスを使用するのと同じです。同様に、変更された値は次のように読み取ることができます。

In [229]: dW[x[0], x[1]]
Out[229]: array([1, 1, 1])

インデックスは、ブロードキャストを含め、どちらの方法でも同じように機能します。

In [234]: dW[...]=0
In [235]: np.add.at(dW,[[[1],[2]],[2,4,4]],1)
In [236]: dW
Out[236]: 
array([[0, 0, 0, 0, 0, 0],
       [0, 0, 1, 0, 2, 0],
       [0, 0, 1, 0, 2, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0]])

可能な値

インデックスに関して、値はbroadcastableである必要があります。

In [112]: np.add.at(dW,[[[1],[2]],[2,4,4]],np.ones((2,3)))
...
In [114]: np.add.at(dW,[[[1],[2]],[2,4,4]],np.ones((2,3)).ravel())
...
ValueError: array is not broadcastable to correct shape
In [115]: np.add.at(dW,[[[1],[2]],[2,4,4]],[1,2,3])

In [117]: np.add.at(dW,[[[1],[2]],[2,4,4]],[[1],[2]])

In [118]: dW
Out[118]: 
array([[ 0,  0,  0,  0,  0,  0],
       [ 0,  0,  3,  0,  9,  0],
       [ 0,  0,  4,  0, 11,  0],
       [ 0,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0]])

この場合、インデックスは(2,3)形状を定義するため、(2,3)、(3、)、(2,1)、およびスカラー値が機能します。 (6、)はしません。

この場合、add.atは(2,3)配列をdWの(2,2)サブ配列にマッピングしています。

9
hpaulj

最近、私もこのコード行を理解するのに苦労しています。私が得たものがあなたを助けることができることを願っています、私が間違っているなら私を訂正してください。

このコード行の3つの配列は次のとおりです。

x , whose shape is (N,T)
dW,  ---(V,D)
dout ---(N,T,D)

次に、何が起こるかを理解したいラインコードに到達します

np.add.at(dW, x, dout)

あなたが思考手順を知りたくない場合。上記のコードは次と同等です:

for row in range(N):
   for col in range(T):
      dW[ x[row,col]  , :] += dout[row,col, :]

これは思考手順です:

このドキュメントを参照する

https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.ufunc.at.html

Xがインデックス配列であることがわかっています。したがって、重要なのはdW [x]を理解することです。これは、別のarray(x)を使用してarray(dW)にインデックスを付けるという概念です。この概念に精通していない場合は、このリンクをチェックしてください

https://docs.scipy.org/doc/numpy-1.13.0/user/basics.indexing.html

一般的に、インデックス配列を使用すると返されるのは、インデックス配列と同じ形状の配列ですが、インデックスが作成される配列の型と値が含まれます。

dW [x]は、形状が(N、T、D)、(N、T)部分がxから、(D)がdW(V、D)からなる配列を提供します。ここで、xのすべての要素が[0、v)の範囲内にあることに注意してください。

具体的な例としていくつかの数字を取り上げましょう

x:    np.array([[0,0],[0,0]]) ---- (2,2) N=2, T=2
dW:   np.array([[0,0],[2,2]]) ---- (2,2) V=2, D=2
dout: np.arange(1,9).reshape(2,2,2)  ----(2,2,2) N=2, T=2, D=2

dW[x] should be [ [[0 0] #this comes from the dW's firt row
                  [0 0]]

                  [[0 0]
                   [0 0]] ]

dW [x] add doutは、elemnetアイテムを追加することを意味します(ここでは、これはいくつかのトリックで、後で説明します)

np.add.at(dW, x, dout) gives 
 [ [16 20]
   [ 2  2] ]

どうして?手順は次のとおりです。

DWの最初の行である[0,0]に[1,2]を追加します。

なぜ最初の行? x [0,0] = 0であり、dWの最初の行を示しているため、dW [0] = dW [0、:] =最初の行です。

次に、dW [0,0]の最初の行に[3,4]を追加します。 [3,4] = dout [0,1 、:]。 [0,0]も、dW(x [0,1] = 0)から取得されますが、それでもdW [0]の最初の行です。

次に、dWの最初の行に[5,6]を追加します。

次に、dWの最初の行に[7,8]を追加します。

したがって、結果は[1 + 3 + 5 + 7、2 + 4 + 6 + 8] = [16,20]になります。 dWの2行目には触れないからです。 dWの2行目は変更されません。

秘訣は、Origin行を1回だけカウントし、バッファーがないと考えることができ、すべてのステップが元の場所で再生されることです。

1
kaishen