Extracting the encoder and decoder from a saved autoencoder

Date: 2021-04-28 09:19:59

Tags: python keras autoencoder

I have saved models for a large number of autoencoders used in my project. They were saved with autoencoder.save(outdir + "autoencoder_" + params).

Is there any way for me to extract the encoder and decoder components of each saved model, or do I need to re-run the script, add the lines encoder = Model(input, bottleneck) and decoder = Model(bottleneck, output), and save those models as well?

This is the structure of the autoencoder I am trying to retrieve:

autoencoder.summary()

Model: "model_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 3593, 4)]         0         
_________________________________________________________________
flatten (Flatten)            (None, 14372)             0         
_________________________________________________________________
dense (Dense)                (None, 1797)              25828281  
_________________________________________________________________
dense_1 (Dense)              (None, 719)               1292762   
_________________________________________________________________
dense_2 (Dense)              (None, 180)               129600    
_________________________________________________________________
dense_3 (Dense)              (None, 719)               130139    
_________________________________________________________________
dense_4 (Dense)              (None, 1797)              1293840   
_________________________________________________________________
dense_5 (Dense)              (None, 14372)             25840856  
_________________________________________________________________
reshape (Reshape)            multiple                  0         
=================================================================
Total params: 54,515,478
Trainable params: 54,515,478
Non-trainable params: 0
_________________________________________________________________

1 Answer:

Answer 0 (score: 1)

You can transfer the weights into two separate neural network models. You only need to identify the index of the bottleneck layer, which you can easily find by running model.summary().
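For example, assuming the saved models can be loaded back with tf.keras.models.load_model (the path below is only a placeholder for whatever outdir + "autoencoder_" + params resolved to), you could inspect one like this:

import tensorflow as tf

# Hypothetical path; substitute the actual outdir + "autoencoder_" + params used when saving.
ae_model = tf.keras.models.load_model("outdir/autoencoder_example")
ae_model.summary()  # read off the index of the bottleneck layer (dense_2 in the summary above)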

Here is a snippet that can help you copy the model:

bottleneck_index = ...  # this you need to identify from model.summary(); use the slice end so the bottleneck layer itself is included in the encoder
encoder_model = tf.keras.Sequential()
for layer in ae_model.layers[:bottleneck_index]:
    layer_config = layer.get_config()                       # all of the layer's parameters (units, activation, etc.)
    copied_layer = type(layer).from_config(layer_config)    # initialise the same layer class with the same parameters
    copied_layer.build(layer.input_shape)                   # build the layer to create its weight variables
    copied_layer.set_weights(layer.get_weights())           # transfer the trained parameters
    encoder_model.add(copied_layer)                         # add it to the encoder model

Do the same for the decoder, using ae_model.layers[bottleneck_index:], as sketched below.
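A minimal sketch of that decoder loop, under the same assumptions (ae_model loaded and bottleneck_index chosen as above):

decoder_model = tf.keras.Sequential()
for layer in ae_model.layers[bottleneck_index:]:
    copied_layer = type(layer).from_config(layer.get_config())  # same layer class, same parameters
    copied_layer.build(layer.input_shape)                       # create the weight variables
    copied_layer.set_weights(layer.get_weights())                # copy the trained weights
    decoder_model.add(copied_layer)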

Of course, you could even identify the bottleneck index automatically by checking the units of each layer, since the bottleneck is the layer with fewer units than the layers around it.
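A sketch of that check, assuming the bottleneck is simply the Dense layer with the fewest units (dense_2 with 180 units in the summary above):

# find the Dense layer with the fewest units; +1 so the encoder slice includes it
dense_layers = [(i, l.units) for i, l in enumerate(ae_model.layers)
                if isinstance(l, tf.keras.layers.Dense)]
bottleneck_layer_index = min(dense_layers, key=lambda t: t[1])[0]
bottleneck_index = bottleneck_layer_index + 1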
