web-dev-qa-db-ja.com

tf.estimatorで早期に停止する方法は?

TensorFlow 1.4でtf.estimatorを使用しており、tf.estimator.train_and_evaluateは素晴らしいですが、早期に停止する必要があります。それを追加する好ましい方法は何ですか?

このためにtf.train.SessionRunHookがどこかにあると思います。 ValidationMonitorが付いた古いcontribパッケージがあり、それが早期に停止したように見えましたが、1.4ではもう存在しないようです。または、将来的にtf.kerasの代わりにtf.estimator/tf.layers/tf.data(早期停止が非常に簡単です)に依存することをお勧めしますか?

19
Carl Thomé

良いニュースです! tf.estimatorがmasterで早期にサポートを停止しました。1.10になりそうです。

estimator = tf.estimator.Estimator(model_fn, model_dir)

os.makedirs(estimator.eval_dir())  # TODO This should not be expected IMO.

early_stopping = tf.contrib.estimator.stop_if_no_decrease_hook(
    estimator,
    metric_name='loss',
    max_steps_without_decrease=1000,
    min_steps=100)

tf.estimator.train_and_evaluate(
    estimator,
    train_spec=tf.estimator.TrainSpec(train_input_fn, hooks=[early_stopping]),
    eval_spec=tf.estimator.EvalSpec(eval_input_fn))
25
Carl Thomé

はいあります - tf.train.StopAtStepHook

このフック要求は、いくつかのステップが実行されるか、最後のステップに到達した後に停止します。 2つのオプションのうち1つのみを指定できます。

また、拡張して、ステップの結果に基づいて独自の停止戦略を実装することもできます。

class MyHook(session_run_hook.SessionRunHook):
  ...
  def after_run(self, run_context, run_values):
    if condition:
      run_context.request_stop()
2
Maxim

最初に、早期​​停止コールで使用できるように損失に名前を付ける必要があります。推定変数で損失変数の名前が「loss」の場合、行

copyloss = tf.identity(loss, name="loss")

そのすぐ下で動作します。

次に、このコードでフックを作成します。

class EarlyStopping(tf.train.SessionRunHook):
    def __init__(self,smoothing=.997,tolerance=.03):
        self.lowestloss=float("inf")
        self.currentsmoothedloss=-1
        self.tolerance=tolerance
        self.smoothing=smoothing
    def before_run(self, run_context):
        graph = ops.get_default_graph()
        #print(graph)
        self.lossop=graph.get_operation_by_name("loss")
        #print(self.lossop)
        #print(self.lossop.outputs)
        self.element = self.lossop.outputs[0]
        #print(self.element)
        return tf.train.SessionRunArgs([self.element])
    def after_run(self, run_context, run_values):
        loss=run_values.results[0]
        #print("loss "+str(loss))
        #print("running average "+str(self.currentsmoothedloss))
        #print("")
        if(self.currentsmoothedloss<0):
            self.currentsmoothedloss=loss*1.5
        self.currentsmoothedloss=self.currentsmoothedloss*self.smoothing+loss*(1-self.smoothing)
        if(self.currentsmoothedloss<self.lowestloss):
            self.lowestloss=self.currentsmoothedloss
        if(self.currentsmoothedloss>self.lowestloss+self.tolerance):
            run_context.request_stop()
            print("REQUESTED_STOP")
            raise ValueError('Model Stopping because loss is increasing from EarlyStopping hook')

これは、指数関数的に平滑化された損失検証をその最小値と比較し、許容値が高い場合、トレーニングを停止します。停止が早すぎる場合、許容値とスムージングを上げると、後で停止します。 1未満のスムージングを維持しないと、停止しません。

別の条件に基づいて停止する場合は、after_runのロジックを別のものに置き換えることができます。

次に、このフックを評価仕様に追加します。コードは次のようになります。

eval_spec=tf.estimator.EvalSpec(input_fn=lambda:eval_input_fn(batchsize),steps=100,hooks=[EarlyStopping()])#

重要な注意:関数run_context.request_stop()はtrain_and_evaluate呼び出しで壊れており、トレーニングを停止しません。そこで、トレーニングを停止するために値エラーを発生させました。したがって、train_and_evaluate呼び出しを次のようなtry catchブロックでラップする必要があります。

try:
    tf.estimator.train_and_evaluate(classifier,train_spec,eval_spec)
except ValueError as e:
    print("training stopped")

これを行わないと、トレーニングが停止したときにコードがエラーでクラッシュします。

2
user3806120

フックを使用しないもう1つのオプションは、tf.contrib.learn.Experimentを作成することです(これは、たとえ貢献している場合でも、新しいtf.estimator.Estimatorもサポートするようです)。

次に、(明らかに実験的な)メソッドcontinuous_train_and_evalを使用して、適切にカスタマイズされたcontinuous_eval_predicate_fnでトレーニングします。

テンソルフロー文書によれば、continuous_eval_predicate_fn

各反復後にevalを続行するかどうかを決定する述語関数。

最後の評価実行からのeval_resultsで呼び出されます。早期停止の場合、現在の最良の結果とカウンターを状態として保持し、早期停止の条件に達したときにFalseを返すカスタマイズされた関数を使用します。

追加された注:このアプローチは、tensorflow 1.7の非推奨メソッドを使用します(tf.contrib.learnのすべては、そのバージョン以降非推奨です: https://www.tensorflow.org/api_docs/python/tf/contrib/learn

1
skb