私は苦労してきた問題を抱えています。これは、tf.matmul()
とそのブロードキャストの欠如に関連しています。
https://github.com/tensorflow/tensorflow/issues/216 で同様の問題を認識していますが、tf.batch_matmul()
は私の場合の解決策のようには見えません。
入力データを4Dテンソルとしてエンコードする必要があります。X = tf.placeholder(tf.float32, shape=(None, None, None, 100))
最初の次元はバッチのサイズで、2番目の次元はバッチ内のエントリの数です。各エントリは、いくつかのオブジェクトの構成として想像できます(3次元)。最後に、各オブジェクトは100個のfloat値のベクトルで記述されます。
実際のサイズはバッチごとに変わる可能性があるため、2番目と3番目の次元にはNoneを使用したことに注意してください。ただし、簡単にするために、実際の数でテンソルを形成しましょう:X = tf.placeholder(tf.float32, shape=(5, 10, 4, 100))
これらは私の計算のステップです:
100個の浮動小数点値の各ベクトルの関数を計算します(例:線形関数)W = tf.Variable(tf.truncated_normal([100, 50], stddev=0.1))
Y = tf.matmul(X, W)
problem: tf.matmul()
のブロードキャストがなく、tf.batch_matmul()
を使用しても成功しないYの予想される形状:(5、10、4、50)
バッチの各エントリに平均プーリングを適用する(各エントリのオブジェクトに):Y_avg = tf.reduce_mean(Y, 2)
Y_avgの予想される形状:(5、10、50)
tf.matmul()
がブロードキャストをサポートするだろうと思っていました。次に、tf.batch_matmul()
を見つけましたが、それでも私の場合には当てはまらないようです(たとえば、Wは少なくとも3次元である必要があり、理由は明確ではありません)。
ところで、上記では単純な線形関数を使用しました(その重みはWに格納されています)。しかし、私のモデルでは、代わりに深いネットワークがあります。したがって、私が抱えているより一般的な問題は、テンソルの各スライスの関数を自動的に計算することです。これが、tf.matmul()
がブロードキャスト動作をすることを期待した理由です(もしそうなら、おそらくtf.batch_matmul()
は必要ないでしょう)。
あなたから学ぶことを楽しみにしています!アレッシオ
これは、X
を再形成して_[n, d]
_を形成することで実現できます。ここで、d
は、計算の1つの「インスタンス」の次元(この例では100)であり、n
は、多次元オブジェクト内のインスタンスの数です(_5*10*4=200
_あなたの例では)。形状を変更した後、_tf.matmul
_を使用して、目的の形状に形状を変更できます。最初の3次元が変化する可能性があるという事実は、少し注意が必要ですが、_tf.shape
_を使用して、実行時に実際の形状を決定できます。最後に、計算の2番目のステップを実行できます。これは、それぞれの次元で単純な_tf.reduce_mean
_である必要があります。全体として、次のようになります。
_X = tf.placeholder(tf.float32, shape=(None, None, None, 100))
W = tf.Variable(tf.truncated_normal([100, 50], stddev=0.1))
X_ = tf.reshape(X, [-1, 100])
Y_ = tf.matmul(X_, W)
X_shape = tf.gather(tf.shape(X), [0,1,2]) # Extract the first three dimensions
target_shape = tf.concat(0, [X_shape, [50]])
Y = tf.reshape(Y_, target_shape)
Y_avg = tf.reduce_mean(Y, 2)
_
リンクした GitHubの問題 の名前が変更されたタイトルが示すように、 tf.tensordot()
を使用する必要があります。これにより、Numpyの tensordot()
に沿って、2つのテンソル間の軸ペアの収縮が可能になります。あなたの場合:
X = tf.placeholder(tf.float32, shape=(5, 10, 4, 100))
W = tf.Variable(tf.truncated_normal([100, 50], stddev=0.1))
Y = tf.tensordot(X, W, [[3], [0]]) # gives shape=[5, 10, 4, 50]