web-dev-qa-db-ja.com

Tensorflowの中央値

TensorFlowでリストの中央値を計算するにはどうすればよいですか?お気に入り

node = tf.median(X)

Xはプレースホルダーです
numpyでは、np.medianを直接使用して中央値を取得できます。テンソルフローでnumpy操作を使用するにはどうすればよいですか?

9
Yingchao Xiong

編集:この回答は古くなっています。代わりにLucas VenezianPovoaのソリューションを使用してください。それはより簡単でより速いです。

以下を使用して、テンソルフロー内の中央値を計算できます。

def get_median(v):
    v = tf.reshape(v, [-1])
    mid = v.get_shape()[0]//2 + 1
    return tf.nn.top_k(v, mid).values[-1]

Xがすでにベクトルである場合は、再形成をスキップできます。

中央値が偶数サイズのベクトルの2つの中央要素の平均であることに関心がある場合は、代わりにこれを使用する必要があります。

def get_real_median(v):
    v = tf.reshape(v, [-1])
    l = v.get_shape()[0]
    mid = l//2 + 1
    val = tf.nn.top_k(v, mid).values
    if l % 2 == 1:
        return val[-1]
    else:
        return 0.5 * (val[-1] + val[-2])
4
BlueSun

tensorflowを使用して配列の中央値を計算するには、quantile関数を使用できます。これは、50%の分位数がmedianであるためです。

_import tensorflow as tf
import numpy as np 

np.random.seed(0)   
x = np.random.normal(3.0, .1, 100)

median = tf.contrib.distributions.percentile(x, 50.0)

tf.Session().run(median)
_

interpolationパラメータは結果をlowerhigher、またはnearestサンプル値に近似するため、このコードは_np.median_と同じ動作をしません。

同じ動作が必要な場合は、次を使用できます。

_median = tf.contrib.distributions.percentile(x, 50.0, interpolation='lower')
median += tf.contrib.distributions.percentile(x, 50.0, interpolation='higher')
median /= 2.
tf.Session().run(median)
_

それ以外に、上記のコードはnp.percentile(x, 50, interpolation='midpoint')と同等です。

BlueSunのソリューションを変更して、GPUではるかに高速にすることができます。

_def get_median(v):
    v = tf.reshape(v, [-1])
    m = v.get_shape()[0]//2
    return tf.reduce_min(tf.nn.top_k(v, m, sorted=False).values)
_

これは(私の経験では)tf.contrib.distributions.percentile(v, 50.0)を使用するのと同じくらい高速で、実際の要素の1つを返します。

3
noname

現在、TFには 中央値関数なし があります。 TFでnumpy操作を使用する唯一の方法は、グラフを実行した後です。

import tensorflow as tf
import numpy as np

a = tf.random_uniform(shape=(5, 5))

with tf.Session() as sess:
    np_matrix = sess.run(a)
    print np.median(np_matrix)
1
Salvador Dali