web-dev-qa-db-ja.com

Tensorflowにテンソルの次元をドロップします

形状が(50, 100, 1, 512)のテンソルがあり、新しいテンソルの形状が(50, 100, 512)になるように、形状を変更するか、3次元を削除したいと思います。

tf.slicetf.squeezeで試しました。

a = tf.slice(a, [50, 100, 1, 512], [50, 100, 1, 512])
b = tf.squeeze(a)

abの形状を印刷しようとすると、すべてが機能しているように見えますが、モデルのトレーニングを開始すると、このエラーが発生しました

tensorflow.python.framework.errors_impl.InvalidArgumentError: Expected size[0] in [0, 0], but got 50
     [[Node: Slice = Slice[Index=DT_INT32, T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"](MaxPool_2, Slice/begin, Slice/size)]]

sliceに問題はありますか。どうすれば修正できますか。ありがとう

4
lamhoangtung

一般的にtf.squeezeは寸法を削除します。

a = tf.constant([[[1,2,3],[3,4,5]]])

上記のテンソル形状は[1,2,3]。スクイーズ操作後、

b = tf.squeeze(a)

さて、テンソルの形は[2,3]

3

それを行うには複数の方法があります。 Tensorflowはインデックス作成のサポートを開始しました。試してみてください

_a = a[:,:,0,:]_

[〜#〜]または[〜#〜]

_a = a[:,:,-1,:]_

[〜#〜]または[〜#〜]

a = tf.reshape(a,[50,100,512])

3
mnis

この場合、tf.sliceを間違って使用していますが、次のようになります。

a = tf.slice(a, [0, 0, 0, 0], [50, 100, 1, 512])
b = tf.squeeze(a)

tf.slicedocumentation を見ると、その理由がわかります。

1
lamhoangtung