テンソルフローモデルのパラメーターの総数をカウントする関数呼び出しや別の方法はありますか?
つまり、トレーニング可能な変数のN次元ベクトルにはN個のパラメーターがあり、NxM
マトリックスにはN*M
パラメーターなどがあります。したがって、基本的には、すべてのテンソルフローセッションでトレーニング可能な変数。
tf.trainable_variables()
のすべての変数の形状をループします。
total_parameters = 0
for variable in tf.trainable_variables():
# shape is an array of tf.Dimension
shape = variable.get_shape()
print(shape)
print(len(shape))
variable_parameters = 1
for dim in shape:
print(dim)
variable_parameters *= dim.value
print(variable_parameters)
total_parameters += variable_parameters
print(total_parameters)
更新:この回答により、Tensorflowの動的/静的な形状を明確にする記事を書きました: https://pgaleone.eu/tensorflow/2018/07/28/understanding-tensorflow-tensors-shape-static- dynamic /
私はさらに短いバージョン、numpyを使用して1行のソリューションを持っています:
np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])
与えられた答えが実際に実行されるかどうかはわかりません(動作するためには、dimオブジェクトをintに変換する必要があることがわかりました)。機能するものを次に示します。コピーして関数を貼り付けて呼び出すことができます(いくつかのコメントを追加しました)。
def count_number_trainable_params():
'''
Counts the number of trainable variables.
'''
tot_nb_params = 0
for trainable_variable in tf.trainable_variables():
shape = trainable_variable.get_shape() # e.g [D,F] or [W,H,C]
current_nb_params = get_nb_params_shape(shape)
tot_nb_params = tot_nb_params + current_nb_params
return tot_nb_params
def get_nb_params_shape(shape):
'''
Computes the total number of params for a given shap.
Works for any number of shapes etc [D,F] or [W,H,C] computes D*F and W*H*C.
'''
nb_params = 1
for dim in shape:
nb_params = nb_params*int(dim)
return nb_params
パラメータの数を自分で計算する場合は、2つの既存の答えが役立ちます。 「TensorFlowモデルのプロファイルを作成する簡単な方法はありますか?」という行に沿った質問であれば、 tfprof を調べることを強くお勧めします。パラメーターの数の計算など、モデルのプロファイルを作成します。
私は同等の短い実装をスローします:
def count_params():
"print number of trainable variables"
size = lambda v: reduce(lambda x, y: x*y, v.get_shape().as_list())
n = sum(size(v) for v in tf.trainable_variables())
print "Model size: %dK" % (n/1000,)
Numpyを避けたい場合(多くのプロジェクトでは省略できます):
all_trainable_vars = tf.reduce_sum([tf.reduce_prod(v.shape) for v in tf.trainable_variables()])
これは、Julius Kunzeによる以前の回答のTF翻訳です。
TF操作と同様に、以下を評価するにはセッションの実行が必要です。
print(sess.run(all_trainable_vars))
モデルがKeras
モデル、具体的にはtensorflow.python.keras.engine.training.Model
である場合、model.count_params()
を使用できます。
ドキュメントはここにあります: https://www.tensorflow.org/api_docs/python/tf/keras/backend/count_params