web-dev-qa-db-ja.com

TensorflowEstimator-warm_start_fromおよびmodel_dir

tf.estimatorwarm_start_fromで使用する場合およびmodel_dir、およびwarm_start_fromディレクトリとmodel_dirディレクトリの両方に有効なチェックポイントが含まれている場合、実際に復元されるチェックポイントはどれですか?

コンテキストを与えるために、私の推定コードは次のようになります

est = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir=model_dir,
    warm_start_from=warm_start_dir)

for Epoch in range(num_epochs):
    est.train(input_fn=train_input_fn)
    est.evaluate(input_fn=eval_input_fn)

(入力関数はワンショットイテレータを使用します。)

したがって、最初の反復でmodel_dirが空のときに、ウォームスタートチェックポイントをロードしますが、次のエポックでは、model_dirの最後の反復からの中間の微調整されたチェックポイントをロードします。しかし、少なくともログからは、warm_start_dirがまだロードされているように見えます。

次の反復のためにEstimatorをオーバーライドすることもできますが、何らかの方法でEstimatorに組み込むべきではないかと思います。

9
mtngld

同様の問題が発生しました。セッションの開始時に実行される初期化フックを提供し、tf.estimator.train_and_evaluateを使用することでこれを解決しました(ただし、このソリューション全体の功績は認められません。他の場所で別の目的のために似たようなもの):

class InitHook(tf.train.SessionRunHook):
    """initializes model from a checkpoint_path
    args:
        modelPath: full path to checkpoint
    """
    def __init__(self, checkpoint_dir):
        self.modelPath = checkpoint_dir
        self.initialized = False

    def begin(self):
        """
        Restore encoder parameters if a pre-trained encoder model is available and we haven't trained previously
        """
        if not self.initialized:
            log = logging.getLogger('tensorflow')
            checkpoint = tf.train.latest_checkpoint(self.modelPath)
            if checkpoint is None:
                log.info('No pre-trained model is available, training from scratch.')
            else:
                log.info('Pre-trained model {0} found in {1} - warmstarting.'.format(checkpoint, self.modelPath))
                tf.train.warm_start(checkpoint)
            self.initialized = True

次に、トレーニング用:

initHook = InitHook(checkpoint_dir = warm_start_dir)
trainSpec = tf.estimator.TrainSpec(
    input_fn = train_input_fn,
    max_steps = N_STEPS, 
    hooks = [initHook]
)
evalSpec = tf.estimator.EvalSpec(
    input_fn = eval_input_fn,
    steps = None,
    name = 'eval',
    throttle_secs = 3600
)
tf.estimator.train_and_evaluate(estimator, trainSpec, evalSpec)

これは最初に1回実行され、warm_start_dirから変数を初期化します。後で、推定器model_dirに新しいチェックポイントがある場合、そこからwarm_startingを続行します。

3
kamyonet