Keras:模型

时间:2017-07-12 00:40:45

标签: python keras

我在所有图层中设置trainable=False,通过Model API实现,但我想验证这是否有效。 model.count_params()会返回参数总数,但除了查看model.summary()的最后几行之外,还有哪些方法可以获取可训练参数的总数?

4 个答案:

答案 0 :(得分:15)

from keras import backend as K

trainable_count = int(
    np.sum([K.count_params(p) for p in set(model.trainable_weights)]))
non_trainable_count = int(
    np.sum([K.count_params(p) for p in set(model.non_trainable_weights)]))

print('Total params: {:,}'.format(trainable_count + non_trainable_count))
print('Trainable params: {:,}'.format(trainable_count))
print('Non-trainable params: {:,}'.format(non_trainable_count))

上述代码段可以在layer_utils.print_summary()定义的末尾发现,summary()正在调用。

答案 1 :(得分:0)

计数可训练参数的另一种方法是:

model.count_params()

答案 2 :(得分:0)

对于tensorflow.keras这对我有用。来自tensorflow github代码中layer_utils.py中的函数print_layer_summary_with_connections()

import numpy as np
from tensorflow.python.util import object_identity

def count_params(weights):
    return int(sum(np.prod(p.shape.as_list())
      for p in object_identity.ObjectIdentitySet(weights)))

if hasattr(model, '_collected_trainable_weights'):
    trainable_count = count_params(model._collected_trainable_weights)
else:
    trainable_count = count_params(model.trainable_weights)

print (trainable_count)

答案 3 :(得分:0)

对于 TensorFlow 2.0

import tensorflow.keras.backend as K

trainable_count = np.sum([K.count_params(w) for w in model.trainable_weights])
non_trainable_count = np.sum([K.count_params(w) for w in model.non_trainable_weights])

print('Total params: {:,}'.format(trainable_count + non_trainable_count))
print('Trainable params: {:,}'.format(trainable_count))
print('Non-trainable params: {:,}'.format(non_trainable_count))
相关问题