web-dev-qa-db-ja.com

Tensorflow LSTMのc_stateおよびm_stateとは何ですか?

Tensorflow r0.12のtf.nn.rnn_cell.LSTMCellに関するドキュメントでは、これをinitとして説明しています。

tf.nn.rnn_cell.LSTMCell.__call__(inputs, state, scope=None)

ここで、stateは次のとおりです。

state:state_is_TupleがFalseの場合、これは状態Tensor、2-D、バッチx state_sizeでなければなりません。 state_is_TupleがTrueの場合、これは列サイズがc_stateとm_stateの状態テンソルのタプルで、両方とも2次元でなければなりません。

何アールc_stateおよびm_stateそして、それらはどのようにLSTMに適合しますか?ドキュメント内のどこにもそれらへの参照が見つかりません。

ここにドキュメントのそのページへのリンクがあります。

16
Haziq Nordin

私は同じ質問につまずきました、ここに私がそれを理解する方法があります!最小限のLSTMの例:

import tensorflow as tf

sample_input = tf.constant([[1,2,3]],dtype=tf.float32)

LSTM_CELL_SIZE = 2

lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(LSTM_CELL_SIZE, state_is_Tuple=True)
state = (tf.zeros([1,LSTM_CELL_SIZE]),)*2

output, state_new = lstm_cell(sample_input, state)

init_op = tf.global_variables_initializer()

sess = tf.Session()
sess.run(init_op)
print sess.run(output)

state_is_Tuple=Trueであるため、stateをこのcellに渡すときは、Tuple形式にする必要があります。 c_statem_stateはおそらく「メモリ状態」と「セル状態」ですが、これらの用語はドキュメントでのみ言及されているため、正直なところわかりません。 LSTMに関するコードと論文では、hcの文字が「出力値」と「セル状態」を示すために一般的に使用されています。 http://colah.github.io/posts/2015-08-Understanding-LSTMs/ これらのテンソルはセルの内部状態の組み合わせを表し、一緒に渡す必要があります。それを行う古い方法は単純にそれらを連結することでしたが、新しい方法はタプルを使用することです。

古い方法:

lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(LSTM_CELL_SIZE, state_is_Tuple=False)
state = tf.zeros([1,LSTM_CELL_SIZE*2])

output, state_new = lstm_cell(sample_input, state)

新しい方法:

lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(LSTM_CELL_SIZE, state_is_Tuple=True)
state = (tf.zeros([1,LSTM_CELL_SIZE]),)*2

output, state_new = lstm_cell(sample_input, state)

したがって、基本的に私たちが行ったことは、stateが長さ4の1テンソルから、長さ2の2つのテンソルに変更されることです。内容は同じままでした。 [0,0,0,0]([0,0],[0,0])になります。 (これにより高速化されるはずです)

14
avloss

ドキュメントが不明確であることに同意します。 _tf.nn.rnn_cell.LSTMCell.__call___ を見ると明確になります(TensorFlow 1.0.0からコードを取得しました):

_def __call__(self, inputs, state, scope=None):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2D, batch x num_units.
      state: if `state_is_Tuple` is False, this must be a state Tensor,
        `2-D, batch x state_size`.  If `state_is_Tuple` is True, this must be a
        Tuple of state Tensors, both `2-D`, with column sizes `c_state` and
        `m_state`.
      scope: VariableScope for the created subgraph; defaults to "lstm_cell".

    Returns:
      A Tuple containing:

      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
        LSTM after reading `inputs` when previous state was `state`.
        Here output_dim is:
           num_proj if num_proj was set,
           num_units otherwise.
      - Tensor(s) representing the new state of LSTM after reading `inputs` when
        the previous state was `state`.  Same type and shape(s) as `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    num_proj = self._num_units if self._num_proj is None else self._num_proj

    if self._state_is_Tuple:
      (c_prev, m_prev) = state
    else:
      c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
      m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])

    dtype = inputs.dtype
    input_size = inputs.get_shape().with_rank(2)[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
    with vs.variable_scope(scope or "lstm_cell",
                           initializer=self._initializer) as unit_scope:
      if self._num_unit_shards is not None:
        unit_scope.set_partitioner(
            partitioned_variables.fixed_size_partitioner(
                self._num_unit_shards))
      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
      lstm_matrix = _linear([inputs, m_prev], 4 * self._num_units, bias=True,
                            scope=scope)
      i, j, f, o = array_ops.split(
          value=lstm_matrix, num_or_size_splits=4, axis=1)

      # Diagonal connections
      if self._use_peepholes:
        with vs.variable_scope(unit_scope) as projection_scope:
          if self._num_unit_shards is not None:
            projection_scope.set_partitioner(None)
          w_f_diag = vs.get_variable(
              "w_f_diag", shape=[self._num_units], dtype=dtype)
          w_i_diag = vs.get_variable(
              "w_i_diag", shape=[self._num_units], dtype=dtype)
          w_o_diag = vs.get_variable(
              "w_o_diag", shape=[self._num_units], dtype=dtype)

      if self._use_peepholes:
        c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
             sigmoid(i + w_i_diag * c_prev) * self._activation(j))
      else:
        c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *
             self._activation(j))

      if self._cell_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
        # pylint: enable=invalid-unary-operand-type

      if self._use_peepholes:
        m = sigmoid(o + w_o_diag * c) * self._activation(c)
      else:
        m = sigmoid(o) * self._activation(c)

      if self._num_proj is not None:
        with vs.variable_scope("projection") as proj_scope:
          if self._num_proj_shards is not None:
            proj_scope.set_partitioner(
                partitioned_variables.fixed_size_partitioner(
                    self._num_proj_shards))
          m = _linear(m, self._num_proj, bias=False, scope=scope)

        if self._proj_clip is not None:
          # pylint: disable=invalid-unary-operand-type
          m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
          # pylint: enable=invalid-unary-operand-type

    new_state = (LSTMStateTuple(c, m) if self._state_is_Tuple else
                 array_ops.concat([c, m], 1))
    return m, new_state
_

キー行は次のとおりです。

_c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *
         self._activation(j))
_

そして

_m = sigmoid(o) * self._activation(c)
_

そして

_new_state = (LSTMStateTuple(c, m) 
_

cmを計算するコードをLSTM方程式(以下を参照)で比較すると、それぞれセル状態(通常cで示される)と非表示状態(通常hで示される)に対応することがわかります:

enter image description here

new_state = (LSTMStateTuple(c, m)は、返された状態Tupleの最初の要素がc(セル状態aka _c_state_)であり、返された状態Tupleの2番目の要素がm(隠された状態aka _m_state_ )。

19

コードからのこの抜粋は役立つかもしれません

def __call__(self, inputs, state, scope=None):
  """Long short-term memory cell (LSTM)."""
  with vs.variable_scope(scope or type(self).__name__):  # "BasicLSTMCell"
    # Parameters of gates are concatenated into one multiply for efficiency.
    if self._state_is_Tuple:
      c, h = state
    else:
      c, h = array_ops.split(1, 2, state)
    concat = _linear([inputs, h], 4 * self._num_units, True)

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = array_ops.split(1, 4, concat)

    new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) *
             self._activation(j))
    new_h = self._activation(new_c) * sigmoid(o)

    if self._state_is_Tuple:
      new_state = LSTMStateTuple(new_c, new_h)
    else:
      new_state = array_ops.concat(1, [new_c, new_h])
    return new_h, new_state
2

https://github.com/tensorflow/tensorflow/blob/r1.2/tensorflow/python/ops/rnn_cell_impl.py

ライン#308-314

クラスLSTMStateTuple(_LSTMStateTuple): "" "state_sizezero_state、および出力状態にLSTMセルで使用されるタプル。2つの要素を格納します:(c, h)、この順序。state_is_Tuple=Trueの場合にのみ使用されます。

0
Z Chen