Keras自定义损失函数,使用隐藏层输出作为目标的一部分

时间:2018-08-21 03:42:45

标签: python tensorflow keras autoencoder

我正在尝试在Keras中实现一种自动编码器,该编码器不仅可以最大程度地减少重构错误,而且其构造功能还应该最大化我定义的度量。我真的不知道该怎么做。

这是我到目前为止的摘要:

corrupt_data = self._corrupt(self.data, 0.1)

# define encoder-decoder network structure
# create input layer
input_layer = Input(shape=(corrupt_data.shape[1], ))
encoded = Dense(self.encoding_dim, activation = "relu")(input_layer)
decoded = Dense(self.data.shape[1], activation="sigmoid")(encoded)

# create autoencoder
dae = Model(input_layer, decoded)

# define custom multitask loss with wlm measure
def multitask_loss(y_true, y_pred):
    # extract learned features from hidden layer
    learned_fea = Model(input_layer, encoded).predict(self.data)
    # additional measure I want to optimize from an external function
    wlm_measure = wlm.measure(learned_fea, self.labels)
    cross_entropy = losses.binary_crossentropy(y_true, y_pred)
    return wlm_measure + cross_entropy

# create optimizer
dae.compile(optimizer=self.optimizer, loss=multitask_loss)

dae.fit(corrupt_data, self.data, 
                epochs=self.epochs, batch_size=20, shuffle=True, 
                callbacks=[tensorboard])

# separately create an encoder model
encoder = Model(input_layer, encoded)

当前这无法正常工作...当我查看训练历史记录时,该模型似乎忽略了额外的度量,仅根据交叉熵损失进行训练。另外,如果我更改损失函数以仅考虑wlm量度,则会得到错误“ numpy.float64”对象没有属性“ get_shape”(我不知道将wlm函数的返回类型更改为张量是否有帮助)。

我认为有些地方可能出了问题。我不知道我是否在自定义损失函数中正确提取隐藏层的输出。另外我也不知道wlm.measure函数是否正确输出-是否应输出numpy.float32或float32类型的一维张量。

基本上,传统的损失函数仅关心输出层的预测标签和真实标签。就我而言,我还需要考虑隐藏层的输出(激活),这在Keras中实现起来并不那么简单。

感谢您的帮助!

1 个答案:

答案 0 :(得分:1)

您不想在自定义损失函数中定义learned_fea Model。相反,您可以预先定义一个具有两个输出的模型:解码器的输出(重建)和endder的输出(特征表示):

multi_output_model = Model(inputs=input_layer, outputs=[decoded, encoded])

现在,您可以编写仅适用于编码器输出的自定义损失函数:

def custom_loss(y_true, y_pred):
    return wlm.measure(y_pred, y_true)

在编译模型时,您传递损失函数列表(如果命名张量,则传递字典):

model.compile(loss=['binary_crossentropy', custom_loss], optimizer=...)

并通过传递输出列表来拟合模型:

model.fit(X=X, y=[data_to_be_reconstructed,labels_for_wlm_measure])
相关问题