创建自定义损失时,有没有一种方法可以访问输入?

时间:2019-10-22 15:11:53

标签: python tensorflow neural-network

我有一个非常复杂的CNN网络,其中有 3个不同的输入 6个输出。我需要创建一个损失函数,其中输入之一会影响损失计算。

我尝试过custom loss function tutorial,但是当我使用模型输入作为参数并在计算中使用它时,我得到:

  

NotImplementedError:无法将符号张量(add_5:0)转换为   numpy数组。

调试之后,我意识到所提到的输入永远不会获得实际值,它保持张量的形状(无,无,无),而expected_y和predicted_y则获得真实的数据形状(10、46、46、19)。

自定义损失函数的代码( model 参数是fit_generator中使用的模型):

def custom_loss(model):
    def loss(y, gt):
        # W is the background weight
        W = model.inputs[-1][:, :, :, -1]
        print(y.shape)
        print(W.shape)
        return K.sum(K.square(y - gt) * W) / c.BATCH_SIZE / 2

    return loss

我可以访问流经模型的数据,而不仅仅是占位符吗?

1 个答案:

答案 0 :(得分:0)

我通过在所有输出中添加同一层(我需要的层)来解决此问题,然后自定义损失从该层获取信息,并将其应用于问题中所述的自定义MSE的默认计算。

仍然无法回答如何在损失中访问网络输入的问题,我对任何可能的见解都感兴趣。