web-dev-qa-db-ja.com

カスタムTensorFlow Kerasオプティマイザー

_tf.keras_ APIに準拠するカスタムオプティマイザークラスを記述したいとします(TensorFlowバージョン> = 2.0を使用)。これを行うための文書化された方法と実装で何が行われるかについて混乱しています。

_tf.keras.optimizers.Optimizer_のドキュメント states

_  ### Write a customized optimizer.
  If you intend to create your own optimization algorithm, simply inherit from
  this class and override the following methods:

    - resource_apply_dense (update variable given gradient tensor is dense)
    - resource_apply_sparse (update variable given gradient tensor is sparse)
    - create_slots (if your optimizer algorithm requires additional variables)
_

ただし、現在の_tf.keras.optimizers.Optimizer_実装は_resource_apply_dense_メソッドを定義していませんが、doesプライベートな外観の __resource_apply_dense_メソッドスタブを定義しています 。同様に、_resource_apply_sparse_または_create_slots_メソッドはありませんが、 __resource_apply_sparse_メソッドスタブ および __create_slots_メソッド呼び出し

公式の_tf.keras.optimizers.Optimizer_サブクラス(例として_tf.keras.optimizers.Adam_を使用)には、 __resource_apply_dense___resource_apply_sparse_があります 、および __create_slots_ メソッドがあり、先頭に下線がないと、そのようなメソッドはありません。

やや少ない公式の_tf.keras.optimizers.Optimizer_サブクラス(たとえば、TensorFlowアドオンの_tfa.optimizers.MovingAverage_: __resource_apply_dense___resource_apply_sparse___create_slots_ )。

私にとってもう1つの混乱点は、TensorFlow Addonsオプティマイザの一部alsoが_apply_gradients_メソッドをオーバーライドすることです(たとえば、 _tfa.optimizers.MovingAverage_ )、一方、_tf.keras.optimizers_オプティマイザーはそうしません。

さらに、_apply_gradients_メソッドの_tf.keras.optimizers.Optimizer_メソッド が__create_slots_ を呼び出していることに気付きましたが、基本の_tf.keras.optimizers.Optimizer_クラスには__create_slots_メソッド。したがって、サブクラスが__create_slots_をオーバーライドしない場合、オプティマイザーサブクラスで_apply_gradients_メソッドmustが定義されているようです。


ご質問

_tf.keras.optimizers.Optimizer_をサブクラス化する正しい方法は何ですか?具体的には

  1. 上部にリストされている_tf.keras.optimizers.Optimizer_のドキュメントは、それらが言及しているメソッドの先行アンダースコアバージョンをオーバーライドすることを単に意味しますか(たとえば、__resource_apply_dense_ではなく_resource_apply_dense_)。もしそうなら、これらのプライベートに見えるメソッドがTensorFlowの将来のバージョンでそれらの振る舞いを変更しないことについてのAPI保証はありますか?これらのメソッドのシグネチャは何ですか?
  2. _apply_gradients_メソッドに加えて、いつ__apply_resource_[dense|sparse]_をオーバーライドしますか?

編集。GitHubで未解決の問題: #36449

30
Artem Mavrin
  1. はい、これはドキュメントのエラーのようです。上記のアンダースコア名は、オーバーライドする正しい方法です。関連するのは、これらがすべて定義されているが、基本クラスには実装されていない非Kerasオプティマイザーです https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/optimizer.py
  def _create_slots(self, var_list):
    """Create all slots needed by the variables.
    Args:
      var_list: A list of `Variable` objects.
    """
    # No slots needed by default
    pass

  def _resource_apply_dense(self, grad, handle):
    """Add ops to apply dense gradients to the variable `handle`.
    Args:
      grad: a `Tensor` representing the gradient.
      handle: a `Tensor` of dtype `resource` which points to the variable
       to be updated.
    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError()

  def _resource_apply_sparse(self, grad, handle, indices):
    """Add ops to apply sparse gradients to the variable `handle`.
    Similar to `_apply_sparse`, the `indices` argument to this method has been
    de-duplicated. Optimizers which deal correctly with non-unique indices may
    instead override `_resource_apply_sparse_duplicate_indices` to avoid this
    overhead.
    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable
       to be updated.
      indices: a `Tensor` of integral type representing the indices for
       which the gradient is nonzero. Indices are unique.
    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError()
  1. apply_denseについて知りません。 1つには、オーバーライドすると、レプリカごとのDistributionStrategyが「危険」である可能性があることをコードが述べている
    # TODO(isaprykin): When using a DistributionStrategy, and when an
    # optimizer is created in each replica, it might be dangerous to
    # rely on some Optimizer methods.  When such methods are called on a
    # per-replica optimizer, an exception needs to be thrown.  We do
    # allow creation per-replica optimizers however, because the
    # compute_gradients()->apply_gradients() sequence is safe.
1
Tyler