web-dev-qa-db-ja.com

TensorFlowの `conv2d_transpose()`操作は何をしますか?

conv2d_transpose()操作のドキュメントでは、その機能について明確に説明されていません。

Conv2dの転置。

この操作は Deconvolutional Networks の後に「デコンボリューション」と呼ばれることもありますが、実際には実際のデコンボリューションではなくconv2dの転置(勾配)です。

私はドキュメントが指している論文を読みましたが、助けにはなりませんでした。

この操作は何をするものであり、それを使用する理由の例は何ですか?

34
MiniQuark

これは、畳み込み転置がどのように機能するかをオンラインで見た最も良い説明です here です。

簡単な説明をします。これは、分数ストライドによる畳み込みを適用します。つまり、入力値(ゼロ)の間隔を空けて、フィルターサイズよりも潜在的に小さい領域にフィルターを適用します。

なぜそれを使いたいのかについて。バイリニア補間やその他の固定形式のアップサンプリングとは対照的に、学習した重みを持つ一種のアップサンプリングとして使用できます。

32
Steven

「勾配」の観点から見た別の視点、つまりTensorFlowのドキュメントにconv2d_transpose()が「実際にはconv2dの転置(gradient)」と書かれている理由実際のデコンボリューションではなく」。 conv2d_transposeで行われる実際の計算の詳細については、19ページから この記事 をお勧めします。

4つの関連機能

tf.nnには、2次元畳み込みのための密接に関連する4つのやや混乱する関数があります。

  • tf.nn.conv2d
  • tf.nn.conv2d_backprop_filter
  • tf.nn.conv2d_backprop_input
  • tf.nn.conv2d_transpose

1つの文の要約:これらはすべて2d畳み込みです。それらの違いは、入力引数の順序、入力回転または転置、ストライド(小数ストライドサイズを含む)、パディングなどにあります。tf.nn.conv2dを使用すると、入力を変換し、 conv2d引数。

問題設定

  • 前方および後方計算:
# forward
out = conv2d(x, w)

# backward, given d_out
=> find d_x?
=> find d_w?

前方計算では、フィルターxを使用して入力画像wの畳み込みを計算し、結果はoutになります。後方計算では、d_out(勾配w.r.t)が与えられていると仮定します。 out。私たちの目標は、グラデーションw.r.tであるd_xd_wを見つけることです。それぞれxおよびw

議論を簡単にするために、次のことを想定しています。

  • すべてのストライドサイズは1になります
  • in_channelsout_channelsはすべて1です
  • VALIDパディングを使用する
  • 奇数のフィルターサイズ、これは非対称形状の問題を回避します

簡潔な答え

概念的には、上記の仮定により、次の関係があります。

out = conv2d(x, w, padding='VALID')
d_x = conv2d(d_out, rot180(w), padding='FULL')
d_w = conv2d(x, d_out, padding='VALID')

rot180は180度回転した2Dマトリックス(左右反転と上下反転)であり、FULLは「入力と部分的に重複する場所にフィルターを適用する」ことを意味します( theano docs を参照) =)。 これは上記の仮定でのみ有効ですが、conv2d引数を変更して一般化することができます。

重要なポイント:

  • 入力勾配d_xは、出力勾配d_outと重みwの畳み込みであり、いくつかの修正が加えられています。
  • 重み勾配d_wは、入力xと出力勾配d_outの畳み込みであり、いくつかの修正が加えられています。

ロングアンサー

ここで、上記の4つの関数を使用してd_xおよびd_wが与えられたd_outを計算する方法の実際の動作コード例を示します。これは、conv2dconv2d_backprop_filterconv2d_backprop_input、およびconv2d_transposeの相互関係を示しています。 ここで完全なスクリプトを見つけてください

4つの異なる方法でd_xを計算します:

# Method 1: TF's autodiff
d_x = tf.gradients(f, x)[0]

# Method 2: manually using conv2d
d_x_manual = tf.nn.conv2d(input=tf_pad_to_full_conv2d(d_out, w_size),
                          filter=tf_rot180(w),
                          strides=strides,
                          padding='VALID')

# Method 3: conv2d_backprop_input
d_x_backprop_input = tf.nn.conv2d_backprop_input(input_sizes=x_shape,
                                                 filter=w,
                                                 out_backprop=d_out,
                                                 strides=strides,
                                                 padding='VALID')

# Method 4: conv2d_transpose
d_x_transpose = tf.nn.conv2d_transpose(value=d_out,
                                       filter=w,
                                       output_shape=x_shape,
                                       strides=strides,
                                       padding='VALID')

3つの異なる方法でd_wを計算します:

# Method 1: TF's autodiff
d_w = tf.gradients(f, w)[0]

# Method 2: manually using conv2d
d_w_manual = tf_NHWC_to_HWIO(tf.nn.conv2d(input=x,
                                          filter=tf_NHWC_to_HWIO(d_out),
                                          strides=strides,
                                          padding='VALID'))

# Method 3: conv2d_backprop_filter
d_w_backprop_filter = tf.nn.conv2d_backprop_filter(input=x,
                                                   filter_sizes=w_shape,
                                                   out_backprop=d_out,
                                                   strides=strides,
                                                   padding='VALID')

tf_rot180tf_pad_to_full_conv2dtf_NHWC_to_HWIOの実装については、 フルスクリプト をご覧ください。スクリプトでは、さまざまなメソッドの最終出力値が同じであることを確認します。 numpy実装も利用可能です。

23
Yixing

conv2d_transpose()は、単に重みを転置し、それらを180度反転します。次に、標準のconv2d()を適用します。 「転置」は、実際には、重みテンソルの「列」の順序を変更することを意味します。以下の例を確認してください。

ここに、stride = 1およびpadding = 'SAME'のたたみ込みを使用する例があります。これは単純なケースですが、他のケースにも同じ推論を適用できます。

私たちが持っていると言う:

  • 入力:28x28x1のMNIST画像、形状= [28,28,1]
  • 畳み込み層:7x7の32個のフィルター、重み形状= [7、7、1、32]、名前= W_conv1

入力の畳み込みを実行すると、その活性化の形状は[1,28,28,32]になります。

 activations = sess.run(h_conv1,feed_dict={x:np.reshape(image,[1,784])})

どこ:

 W_conv1 = weight_variable([7, 7, 1, 32])
 b_conv1 = bias_variable([32])
 h_conv1 = conv2d(x, W_conv1, strides=[1, 1, 1, 1], padding='SAME') + b_conv1

「デコンボリューション」または「転置コンボリューション」を取得するには、この方法でコンボリューションのアクティベーションでconv2d_transpose()を使用できます。

  deconv = conv2d_transpose(activations,W_conv1, output_shape=[1,28,28,1],padding='SAME')

または、conv2d()を使用して、重みを転置および反転する必要があります。

  transposed_weights = tf.transpose(W_conv1, perm=[0, 1, 3, 2])

ここでは、「列」の順序を[0,1,2,3]から[0,1,3,2]に変更します。したがって、[7、7、1、32]からshape =のテンソルを取得します。 [7,7,32,1]。次に、重みを反転します。

  for i in range(n_filters):
      # Flip the weights by 180 degrees
      transposed_and_flipped_weights[:,:,i,0] =  sess.run(tf.reverse(transposed_weights[:,:,i,0], axis=[0, 1]))

次に、conv2d()を使用して畳み込みを計算できます。

  strides = [1,1,1,1]
  deconv = conv2d(activations,transposed_and_flipped_weights,strides=strides,padding='SAME')

そして、以前と同じ結果が得られます。また、次を使用してconv2d_backprop_input()を使用してもまったく同じ結果が得られます。

   deconv = conv2d_backprop_input([1,28,28,1],W_conv1,activations, strides=strides, padding='SAME')

結果は次のとおりです。

conv2d()、conv2d_tranposed()、conv2d_backprop_input()のテスト

結果が同じであることがわかります。より良い方法でそれを見るには、私のコードをチェックしてください:

https://github.com/simo23/conv2d_transpose

ここでは、標準のconv2d()を使用してconv2d_transpose()関数の出力を複製します。

12
simo23

Conv2d_transposeのアプリケーションの1つはアップスケーリングです。その仕組みを説明する例を次に示します。

a = np.array([[0, 0, 1.5],
              [0, 1, 0],
              [0, 0, 0]]).reshape(1,3,3,1)

filt = np.array([[1, 2],
                 [3, 4.0]]).reshape(2,2,1,1)

b = tf.nn.conv2d_transpose(a,
                           filt,
                           output_shape=[1,6,6,1],
                           strides=[1,2,2,1],
                           padding='SAME')

print(tf.squeeze(b))

tf.Tensor(
[[0.  0.  0.  0.  1.5 3. ]
 [0.  0.  0.  0.  4.5 6. ]
 [0.  0.  1.  2.  0.  0. ]
 [0.  0.  3.  4.  0.  0. ]
 [0.  0.  0.  0.  0.  0. ]
 [0.  0.  0.  0.  0.  0. ]], shape=(6, 6), dtype=float64)
1
pentadecagon