2つのベクトルのドット積(つまり、1-dテンソル)を計算し、テンソルフローでスカラー値を返す簡単な方法があるかどうか疑問に思っていました。
2つのベクトルX =(x1、...、xn)およびY =(y1、...、yn)が与えられた場合、ドット積はdot(X、Y)= x1 * y1 + ... + xn * ynです。
最初にベクトルXとYを2次元テンソルにブロードキャストし、次にtf.matmulを使用することでこれを実現できることを知っています。しかし、結果は行列であり、私はスカラーを求めています。
ベクトルに固有のtf.matmulのような演算子はありますか?
2つのテンソル(ベクトルは1Dテンソル)間のドット積を計算する最も簡単な方法の1つは、 tf.tensordot
を使用することです
a = tf.placeholder(tf.float32, shape=(5))
b = tf.placeholder(tf.float32, shape=(5))
dot_a_b = tf.tensordot(a, b, 1)
with tf.Session() as sess:
print(dot_a_b.eval(feed_dict={a: [1, 2, 3, 4, 5], b: [6, 7, 8, 9, 10]}))
# results: 130.0
tf.reduce_sum(tf.multiply(x, y))
に加えて、tf.matmul(x, tf.reshape(y, [-1, 1]))
も実行できます。
tf.matmulとtf.transposeを使用できます
tf.matmul(x,tf.transpose(y))
または
tf.matmul(tf.transpose(x),y)
xおよびyの次元に応じて
import tensorflow as tf
x = tf.Variable([1, -2, 3], tf.float32, name='x')
y = tf.Variable([-1, 2, -3], tf.float32, name='y')
dot_product = tf.reduce_sum(tf.multiply(x, y))
sess = tf.InteractiveSession()
init_op = tf.global_variables_initializer()
sess.run(init_op)
dot_product.eval()
Out[46]: -14
ここで、xとyは両方ともベクトルです。要素ごとの積を求め、tf.reduce_sumを使用して、結果のベクトルの要素を合計できます。このソリューションは読みやすく、再形成する必要はありません。
興味深いことに、 docs に組み込みのドット積演算子があるようには見えません。
中間ステップを簡単に確認できることに注意してください。
In [48]: tf.multiply(x, y).eval()
Out[48]: array([-1, -4, -9], dtype=int32)
おそらく、新しいドキュメントでは、ドット積の最初の引数または2番目の引数のいずれかの転置オプションをtrueに設定することができます。
tf.matmul(a, b, transpose_a=False, transpose_b=False, adjoint_a=False, adjoint_b=False, a_is_sparse=False, b_is_sparse=False, name=None)
リーディング:
tf.matmul(a, b, transpose_a=True, transpose_b=False)
tf.matmul(a, b, transpose_a=False, transpose_b=True)
Tf.mul(x、y)の後にtf.reduce_sum()を実行できます
2つの列ベクトルがあると仮定しましょう
u = tf.constant([[2.], [3.]])
v = tf.constant([[5.], [7.]])
1x1マトリックスが必要な場合は、使用できます
tf.einsum('ij,ik->jk',x,y)
スカラーに興味がある場合は、使用できます
tf.einsum('ij,ik->',x,y)
ab = tf.reduce_sum(a*b)
次のような簡単な例を示します。
import tensorflow as tf
a = tf.constant([1,2,3])
b = tf.constant([2,3,4])
print(a.get_shape())
print(b.get_shape())
c = a*b
ab = tf.reduce_sum(c)
with tf.Session() as sess:
print(c.eval())
print(ab.eval())
# output
# (3,)
# (3,)
# [2 6 12]
# 20