web-dev-qa-db-ja.com

テンソルフローでは、テンソルに保存されている一連の入力を反復処理する方法は?

可変長多変量シーケンス分類問題でRNNを試しています。

シーケンスの出力(つまり、シーケンスからの最終入力が供給された後のRNNセルの出力)を取得するために、次の関数を定義しました

def get_sequence_output(x_sequence, initial_hidden_state):
    previous_hidden_state = initial_hidden_state
    for x_single in x_sequence:
        hidden_state = gru_unit(previous_hidden_state, x_single)
        previous_hidden_state = hidden_state
    final_hidden_state = hidden_state
    return final_hidden_state

ここに x_sequenceは形状のテンソルです(?, ?, 10)最初はどこ?バッチサイズと秒用ですか?はシーケンスの長さであり、各入力要素の長さは10です。gru関数は、前の非表示状態と現在の入力を受け取り、次の非表示状態(標準のゲート付き回帰ユニット)を吐き出します。

エラーが発生しました:'Tensor' object is not iterable. Tensorを順番に繰り返す(一度に1つの要素を読み取る)にはどうすればよいですか?

私の目的は、シーケンスからのすべての入力にgru関数を適用し、最終的な非表示状態を取得することです。

9
exAres

最初の次元をリストに変換するunpack関数を使用して、テンソルをリストに変換できます。同様のことを行う分割関数もあります。作業中のRNNモデルでアンスタックを使用しています。

y = tf.unstack(tf.transpose(y, (1, 0, 2)))

この場合、yは形状(BATCH_SIZE、TIME_STEPS、128)で始まります。転置してタイムステップを外側の次元にし、タイムステップごとに1つずつテンソルのリストにアンパックします。これで、yリストのすべての要素が形状(BATCH_SIZE、128)の場合、RNNにフィードできます。

7
chasep255

TF> = 1.0では、tf.packtf.unpackの名前がそれぞれtf.stacktf.unstackに変更されます。

10
wuhy08