Keras自定义丢失函数(无法更新y_pred)

时间:2018-04-05 15:39:26

标签: tensorflow keras loss-function

我有这种损失功能:

def bingo_loss(y_true, y_pred):
    one = tf.ones([10])
    y2 = tf.subtract(one, y_pred)
    _, indices = tf.nn.top_k(y_pred, k = 3)
    loss = tf.scatter_update(y_pred, indices, y2)
    return loss

输出有10个节点。

损失选择前3个最大输出,其索引存储在indices

对于前3个值,损失为(1 - y_pred)。

否则,损失是(y_pred)。

这是通过scatter_update实现的。

但我有一个模糊的错误:

File "Keras-NN-training2.py", line 104, in bingo_loss
loss = tf.scatter_update(y_pred, indices, y2)
File ".....state_ops.py", line 352, in scatter_update
ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
AttributeError: 'Tensor' object has no attribute 'handle'

我的朋友在另一台机器上跑步说他得到了

TypeError: 'ScatterUpdate' Op requires that input 'ref' be a mutable tensor (e.g.: a tf.Variable)

我已经放弃了这种方法 - 我宁愿编写一个新的层来实现我所需要的,而不仅仅是非常严格的损失函数。但知道解决这个问题的方法真好。谢谢!

0 个答案:

没有答案
相关问题