Tensorflow多输入损耗函数

时间:2017-08-02 09:48:17

标签: function tensorflow loss

我正在尝试在Tensorflow中实现CNN(与VGG非常相似的架构),然后在第一个完全连接的层之后分成两个分支。它遵循本文:https://arxiv.org/abs/1612.01697

网络的两个分支中的每一个都输出一组32个数字。我想写一个关节损失函数,它将需要3个输入:

  1. 分支1(y)的预测
  2. 分支2(alpha)的预测
  3. 标签Y(地面实况)(q)
  4. 并计算加权损失,如下图所示:

    Loss function definition

    q_hat = tf.divide(tf.reduce_sum(tf.multiply(alpha, y),0), tf.reduce_sum(alpha,0))
    loss  = tf.abs(tf.subtract(q_hat, q))
    

    我理解我需要使用tf函数来实现这个损失函数。实施上述功能后,网络正在进行培训,但一旦经过培训,就无法输出预期结果。

    有没有人试过在一个关节损失函数中组合网络的两个分支的输出?这是TensorFlow支持的吗?也许我在某处犯了错误?任何帮助将不胜感激。如果您希望我添加更多详细信息,请与我们联系。

1 个答案:

答案 0 :(得分:1)

从TensorFlow的角度来看,“常规”CNN图和“分支”图之间绝对没有区别。对于TensorFlow,它只是一个需要执行的图形。所以,TensorFlow当然支持这一点。 “将两个分支合并为关节损失”也没什么特别之处。事实上,损失取决于两个分支是“好的”。这意味着当你要求TensorFlow计算损失时,它必须通过两个分支进行正向传递,这就是你想要的。

我注意到的一件事是你的损失代码与图像不同。您的代码似乎执行此操作https://ibb.co/kbEH95

相关问题