如何提取Keras图层权重作为可训练参数?

时间:2019-07-12 06:18:09

标签: python tensorflow keras

我正在训练类似GAN的模型,但不完全相同。我正在将Keras与TensorFlow后端一起使用。

我有两个Keras模型GD。我想将G中目标图层的 weights参数输出为模型D的输入,并将D.predict(G.weights)的结果用作G的损失函数的一部分,即D是不可训练的,但是参数G.weights是可训练的。要以这种方式进一步训练G.weights

我尝试使用

def custom_loss(ytrue, ypred):
    ### Something to do with ytrue and ypred
    weight = self.G.get_layer('target').get_weights()
    loss += self.D.predict(weight)
    return loss

,但是显然它不起作用,因为weight只是一个小数组,并且是不可训练的。

有没有办法获得仍然可以在Keras中训练的模型的权重?我是Keras的新手,对TensorFlow知之甚少。我将非常感谢有人可以提供帮助!

1 个答案:

答案 0 :(得分:1)

正如您提到的,layer.get_weights()将返回矩阵的当前权重。您想要馈送进行预测的是计算图中表示此类权重的节点。您可以改用layer.trainable_weights,这将返回两个tf.Variable,您可以将其馈送到另一个图层/模型。

请注意,一个变量用于单元到单元的连接,另一个变量用于偏置。如果您想从中获得平坦的张量,可以执行以下操作:

from keras import backend as K

...
ww, bias = self.G.get_layer('target').trainable_weights
flattened_weights = Flatten()(K.concat([ww, K.reshape(bias, (5, 1))], axis=1))
相关问题