web-dev-qa-db-ja.com

TensorFlowの新しいバージョンのtf.nn.rnnに相当するものは何ですか?

TensorFlowのバージョン0.8で、以下を使用してRNNネットワークを作成していました。

_from tensorflow.python.ops import rnn

# Define a lstm cell with tensorflow
lstm_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)

# Get lstm cell output
outputs, states = rnn.rnn(cell=lstm_cell, inputs=x, dtype=tf.float32)
_

rnn.rnn()は使用できなくなり、_tf.contrib_に移動されたようです。 BasicLSTMCellからRNNネットワークを作成するための正確なコードは何ですか?

または、LSTMがスタックされている場合は、

_lstm_cell = tf.contrib.rnn.BasicLSTMCell(hidden_size, forget_bias=0.0)
stacked_lstm = tf.contrib.rnn.MultiRNNCell([lstm_cell] * num_layers)
outputs, new_state =  tf.nn.rnn(stacked_lstm, inputs, initial_state=_initial_state)
_

新しいバージョンのTensorFlowでの_tf.nn.rnn_の置き換えは何ですか?

10
Saeed

tf.nn.rnntf.nn.static_rnn と同等です。

注: TensorFlowのバージョン1.2 の前は、namespacetf.nn.static_rnnは存在しませんでしたが、 tf.contrib.rnn.static_rnn (これは、tf.nn.static_rnnエイリアスになりました)。

13
ruoho ruotsi
2