web-dev-qa-db-ja.com

Theanoでテンソル共有変数のサブセットを割り当てる/更新するにはどうすればよいですか?

theanoで関数をコンパイルする場合、updates=[(X, new_value)]を指定することで共有変数(たとえばX)を更新できます。現在、共有変数のサブセットのみを更新しようとしています。

from theano import tensor as T
from theano import function
import numpy

X = T.shared(numpy.array([0,1,2,3,4]))
Y = T.vector()
f = function([Y], updates=[(X[2:4], Y)] # error occur:
                                        # 'update target must 
                                        # be a SharedVariable'

コードは「更新ターゲットはSharedVariableでなければなりません」というエラーを発生させます。これは、更新ターゲットを非共有変数にすることはできないことを意味していると思います。では、共有変数のサブセットを更新するだけの関数をコンパイルする方法はありますか?

25
gaga.zhn

set_subtensor または inc_subtensor を使用します:

from theano import tensor as T
from theano import function, shared
import numpy

X = shared(numpy.array([0,1,2,3,4]))
Y = T.vector()
X_update = (X, T.set_subtensor(X[2:4], Y))
f = function([Y], updates=[X_update])
f([100,10])
print X.get_value() # [0 1 100 10 4]

Theano FAQにこれに関するページがあります: http://deeplearning.net/software/theano/tutorial/faq_tutorial.html

32
dpfried

このコードはあなたの問題を解決するはずです:

from theano import tensor as T
from theano import function, shared
import numpy

X = shared(numpy.array([0,1,2,3,4], dtype='int'))
Y = T.lvector()
X_update = (X, X[2:4]+Y)
f = function(inputs=[Y], updates=[X_update])
f([100,10])
print X.get_value()
# output: [102 13]

そしてここに 公式チュートリアルの共有変数についての紹介 があります。

ご不明な点がございましたら、お問い合わせください。

0
Framester