如何在keras中将组ID传递给自定义丢失函数

时间:2018-03-21 11:32:21

标签: tensorflow keras

我的样本有组ID。有没有办法将组ID传递给我的keras模型并根据组ID计算损失?

1 个答案:

答案 0 :(得分:0)

您可以将组标签作为虚拟列插入目标向量 y 中,并将其从损失中排除。以下代码计算每组的均方误差并返回最大值。

def worst_case_group_mse(y_true, y_pred):
   """calculate mean squared error for each group separately and return worst value

    Args:
        y_true, y_pred (tf.Tensor): 
            last column corresponds to group index, 
            mean squared error calculated over all other columns

    Returns:
        tf.Tensor: maximum grouped mean squared error
    """
    groups = tf.cast(y_true[:,-1], tf.int32)
    y_true, y_pred = y_true[:,:-1], y_pred[:,:-1]

    square = tf.math.square(y_pred - y_true)
    unique, idx, count = tf.unique_with_counts(groups)
    group_losses = tf.math.unsorted_segment_mean(square, idx, tf.size(unique))
    group_losses = tf.math.reduce_mean(group_losses, axis=1)
    return tf.math.reduce_max(group_losses)