web-dev-qa-db-ja.com

TensorFlow:軸に沿ったテンソルの最大

私の質問は、2つの接続された部分にあります。

  1. テンソルの特定の軸に沿って最大値を計算するにはどうすればよいですか?たとえば、私が持っている場合

    x = tf.constant([[1,220,55],[4,3,-1]])
    

    みたいなものが欲しい

    x_max = tf.max(x, axis=1)
    print sess.run(x_max)
    
    output: [220,4]
    

    tf.argmaxおよびtf.maximumが、単一のテンソルの軸に沿って最大値を与えません。今のところ回避策があります:

    x_max = tf.slice(x, begin=[0,0], size=[-1,1])
    for a in range(1,2):
        x_max = tf.maximum(x_max , tf.slice(x, begin=[0,a], size=[-1,1]))
    

    しかし、それは最適ではありません。これを行うためのより良い方法はありますか?

  2. テンソルのargmaxのインデックスが与えられた場合、それらのインデックスを使用して別のテンソルにインデックスを付けるにはどうすればよいですか?上記のxの例を使用して、次のようなことを行うにはどうすればよいですか:

    ind_max = tf.argmax(x, dimension=1)    #output is [1,0]
    y = tf.constant([[1,2,3], [6,5,4])
    y_ = y[:, ind_max]                     #y_ should be [2,6]
    

    最後の行のように、スライスがTensorFlowにまだ存在しないことを知っています( #206 )。

    私の質問は次のとおりです:私の特定のケースに最適な回避策は何ですか(収集、選択などの他の方法を使用する可能性があります)?

    追加情報:xyが2次元テンソルのみになることを知っています!

28
aphdstudent

tf.reduce_max() 演算子は、まさにこの機能を提供します。デフォルトでは、指定されたテンソルのグローバルな最大値を計算しますが、NumPyのaxisと同じ意味を持つ_reduction_indices_のリストを指定できます。あなたの例を完了するには:

_x = tf.constant([[1, 220, 55], [4, 3, -1]])
x_max = tf.reduce_max(x, reduction_indices=[1])
print sess.run(x_max)  # ==> "array([220,   4], dtype=int32)"
_

tf.argmax() を使用してargmaxを計算する場合、 tf.reshape() 、次のようにargmaxインデックスをベクトルインデックスに変換し、 tf.gather() を使用して適切な値を抽出します。

_ind_max = tf.argmax(x, dimension=1)
y = tf.constant([[1, 2, 3], [6, 5, 4]])

flat_y = tf.reshape(y, [-1])  # Reshape to a vector.

# N.B. Handles 2-D case only.
flat_ind_max = ind_max + tf.cast(tf.range(tf.shape(y)[0]) * tf.shape(y)[1], tf.int64)

y_ = tf.gather(flat_y, flat_ind_max)

print sess.run(y_) # ==> "array([2, 6], dtype=int32)"
_
58
mrry

TensorFlow 1.10.0-dev20180626tf.reduce_max は、numpy.maxと同様の機能を提供するaxisおよびkeepdimsキーワード引数を受け入れます。

In [55]: x = tf.constant([[1,220,55],[4,3,-1]])

In [56]: tf.reduce_max(x, axis=1).eval() 
Out[56]: array([220,   4], dtype=int32)

入力テンソルと同じ次元の結果テンソルを使用するには、keepdims=Trueを使用します

In [57]: tf.reduce_max(x, axis=1, keepdims=True).eval()Out[57]: 
array([[220],
       [  4]], dtype=int32)

axis引数が明示的に指定されていない場合、テンソルレベルの最大要素が返されます(つまり、すべての軸が削減されます)。

In [58]: tf.reduce_max(x).eval()
Out[58]: 220
1
kmario23