web-dev-qa-db-ja.com

カスタムサブクラスモデルを保存できません

tf.keras.Model subclassing に触発されて、カスタムモデルが作成されました。
私はそれを訓練して成功することができますが、それを保存することはできません
私はtensorflow v1.10(またはv1.9)でpython3.6を使用しています

ここに最小限の完全なコード例:

_import tensorflow as tf
from tensorflow.keras.datasets import mnist


class Classifier(tf.keras.Model):
    def __init__(self):
        super().__init__(name="custom_model")

        self.batch_norm1 = tf.layers.BatchNormalization()
        self.conv1 = tf.layers.Conv2D(32, (7, 7))
        self.pool1 = tf.layers.MaxPooling2D((2, 2), (2, 2))

        self.batch_norm2 = tf.layers.BatchNormalization()
        self.conv2 = tf.layers.Conv2D(64, (5, 5))
        self.pool2 = tf.layers.MaxPooling2D((2, 2), (2, 2))

    def call(self, inputs, training=None, mask=None):
        x = self.batch_norm1(inputs)
        x = self.conv1(x)
        x = tf.nn.relu(x)
        x = self.pool1(x)

        x = self.batch_norm2(x)
        x = self.conv2(x)
        x = tf.nn.relu(x)
        x = self.pool2(x)

        return x


if __name__ == '__main__':
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    x_train = x_train.reshape(*x_train.shape, 1)[:1000]
    y_train = y_train.reshape(*y_train.shape, 1)[:1000]

    x_test = x_test.reshape(*x_test.shape, 1)
    y_test = y_test.reshape(*y_test.shape, 1)

    y_train = tf.keras.utils.to_categorical(y_train)
    y_test = tf.keras.utils.to_categorical(y_test)

    model = Classifier()

    inputs = tf.keras.Input((28, 28, 1))

    x = model(inputs)
    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(10, activation="sigmoid")(x)

    model = tf.keras.Model(inputs=inputs, outputs=x)
    model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
    model.fit(x_train, y_train, epochs=1, shuffle=True)

    model.save("./my_model")
_

エラーメッセージ:

_1000/1000 [==============================] - 1s 1ms/step - loss: 4.6037 - acc: 0.7025
Traceback (most recent call last):
  File "/home/user/Data/test/python/mnist/mnist_run.py", line 62, in <module>
    model.save("./my_model")
  File "/home/user/miniconda3/envs/ml3.6/lib/python3.6/site-packages/tensorflow/python/keras/engine/network.py", line 1278, in save
    save_model(self, filepath, overwrite, include_optimizer)
  File "/home/user/miniconda3/envs/ml3.6/lib/python3.6/site-packages/tensorflow/python/keras/engine/saving.py", line 101, in save_model
    'config': model.get_config()
  File "/home/user/miniconda3/envs/ml3.6/lib/python3.6/site-packages/tensorflow/python/keras/engine/network.py", line 1049, in get_config
    layer_config = layer.get_config()
  File "/home/user/miniconda3/envs/ml3.6/lib/python3.6/site-packages/tensorflow/python/keras/engine/network.py", line 1028, in get_config
    raise NotImplementedError
NotImplementedError

Process finished with exit code 1
_

エラー行を調べたところ、get_configメソッドがself._is_graph_networkをチェックしていることがわかりました

誰かがこの問題に対処しますか?

ありがとう!

更新1:
ケラス2.2.2(tf.kerasではない)
コメントが見つかりました(モデル保存用)
ファイル:keras/engine/network.py
関数:get_config

#サブクラス化されたネットワークはシリアル化できません
#(シリアライゼーションが
#サブクラス化されたネットワークの作成者)。

だから、明らかにそれはうまくいきません...
documentation でそれを指摘しないのはなぜでしょうか。

更新2:
kerasドキュメント にあります:

サブクラスモデルでは、モデルのトポロジはPython codeとして定義されます
(レイヤーの静的グラフとしてではなく)。つまり、モデルの
トポロジは検査またはシリアル化できません。その結果、次の
メソッドと属性は、サブクラス化されたモデルでは使用できません:

model.inputsおよびmodel.outputs。
model.to_yaml()およびmodel.to_json()
model.get_config()およびmodel.save()。

したがって、サブクラス化を使用してモデルを保存する方法はありません。
Model.save_weights()のみを使用できます

13
RedEyed

この回答はTensorflow 2.0向けです

TL; DR:

  1. カスタムサブクラスのケラスモデルにはmodel.save()を使用しないでください。
  2. 代わりにsave_weights()およびload_weights()を使用してください。

Tensorflowチームの助けを借りて、カスタムサブクラスKerasモデルを保存するベストプラクティスは、その重みを保存し、必要に応じてロードすることです。

Kerasカスタムサブクラスモデルを単純に保存できないのは、安全にシリアル化できないカスタムコードが含まれているためです。ただし、同じモデル構造とカスタムコードがある場合は、問題なくウェイトを保存/ロードできます。

Kerasの作者であるFrancois Cholletが書いた優れたチュートリアルがあり、Sequential/Functional/Keras/Custom Sub-Class ModelsをColabのTensorflow 2.0に保存/ロードする方法については here にあります。 Saving Subclassed Modelsセクションでは、次のように述べています:

シーケンシャルモデルと機能モデルは、レイヤーのDAGを表すデータ構造です。そのため、それらは安全にシリアライズおよびデシリアライズできます。

サブクラス化されたモデルは、データ構造ではなく、コードの一部であるという点で異なります。モデルのアーキテクチャは、callメソッドの本体を介して定義されます。これは、モデルのアーキテクチャを安全にシリアル化できないことを意味します。モデルをロードするには、モデルを作成したコード(モデルサブクラスのコード)にアクセスできる必要があります。または、このコードをバイトコードとしてシリアル化することもできます(たとえば酸洗いを使用)。

14
Huan

これは、次のリリースで 1.13プレリリースパッチノート に従って修正される予定です。

  • Keras&Python API:
    • サブクラス化されたKerasモデルはtf.contrib.saved_model.save_keras_modelを介して保存できるようになりました。

編集:これはノートが示唆するほど完成していないようです。 v1.13のその関数のドキュメント 状態:

モデルの制限:-シーケンシャルモデルと機能モデルは常に保存できます。 -サブクラス化されたモデルは、serving_only = Trueの場合にのみ保存できます。これは、トレーニングと評価のグラフをエクスポートするためにモデルをコピーする現在の実装によるものです。サブクラス化されたモデルのトポロジーを決定できないため、サブクラス化されたモデルを複製できません。サブクラス化されたモデルは、将来的に完全にエクスポート可能になります。

6
qwitwa

使用する model.predicttf.saved_model.save

0
Antoine Liutkus