Keras vgg传输学习保存/加载自定义完全连接层

时间:2017-08-28 21:23:58

标签: keras

在转学习中,可以仅在训练期间保存完全连接的层,并在以后加载它以继续训练?如果可能的话,怎么做?

1 个答案:

答案 0 :(得分:0)

这可能不是最佳解决方案,但我认为这可以解决您的问题。

import pickle

#set a name for all fully connected layers.
model.add(Dense(...,name='fc1'))
model.add(Dense(...,name='fc2'))
model.add(Dense(...,name='fc3'))


layers_to_save = ['fc1','fc2','fc3'] # add here any layer you want to save

# Save weights to a dictionary
weights_dict = dict([(layer.name, layer.get_weights()) for layer in model.layers if layer.name in layers_to_save]) 

with open('filename.pickle', 'wb') as handle:
    pickle.dump(weights_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)




# Load weights 
with open('filename.pickle', 'rb') as handle:
    weights_dict = pickle.load(handle)

for name in layers_to_save:
    model.get_layer(name).set_weights(weights_dict[name])

您还可以查看此keras blog post的后半部分以了解其他方法。