Keras custom loss not training

Asked: 2019-01-08 19:43:48

Tags: python tensorflow keras

I have a network with a custom loss function that masks out pairs of inputs within a beam that share the same label. During training the accuracy/loss does not improve past 1%. I'm not sure what is causing this; I've checked the input data and found no problems with it.

# Custom loss to take the full batch (of size beam_size) and apply a mask so the
# true loss is calculated only within the beam, over pairs whose labels differ
from itertools import permutations

import tensorflow as tf
from keras import backend as K
from keras import optimizers
from keras.layers import Input, Dense, LSTM, Lambda, concatenate
from keras.models import Model

beam_size = 10

def create_mask(y, yhat):
    # All ordered index pairs within the beam
    idxs = list(permutations(range(beam_size), r=2))
    perms_y = tf.squeeze(tf.gather(y, idxs))
    perms_yhat = tf.squeeze(tf.gather(yhat, idxs))
    # Indices of the pairs whose labels differ
    mask = tf.where(tf.not_equal(perms_y[:, 0], perms_y[:, 1]))
    mask = tf.reduce_sum(mask, 1)
    # Keep only those pairs in the labels and predictions
    uneq = tf.boolean_mask(perms_y, mask, axis=0)
    yhat_uneq = tf.boolean_mask(perms_yhat, mask, axis=0)
    return uneq, yhat_uneq

def mask_acc(y, yhat):
    uneq, yhat_uneq = create_mask(y, yhat)
    # argmax each side of the kept pairs and compare
    uneq = tf.argmax(uneq, 1)
    yhat_uneq = tf.argmax(yhat_uneq, 1)
    return tf.cond(tf.greater(tf.size(yhat_uneq), 1),
                   lambda: tf.reduce_sum(tf.cast(tf.equal(uneq, yhat_uneq), tf.float32)),
                   lambda: 0.)

def mask_loss(y, yhat):
    # Consider a weighted loss
    uneq, yhat_uneq = create_mask(y, yhat)
    #uneq = tf.argmax(uneq, 1)
    # Build all permutations, zero out same-label pairs with the mask,
    # then take the cross-entropy over the remaining pairs
    total_loss = tf.reduce_mean(tf.losses.softmax_cross_entropy(onehot_labels=tf.cast(uneq, tf.int32), logits=yhat_uneq))
    #d = tf.Print(yhat_uneq, [yhat_uneq], summarize=-1)
    return total_loss

x = Input(shape=(72,300))
aux_input = Input(shape=(72, 3))
probs = Input(shape=(1,))
#dim_red_1 = Dense(100)(x)
dim_red_2 = Dense(25, activation='tanh')(x)
cat = concatenate([dim_red_2, aux_input])
encoded = LSTM(5)(cat)
output = Lambda(lambda x: K.sum(x, axis=1))(encoded)
#cat2 = concatenate([encoded, probs])
#output = Dense(1, activation='linear')(cat2)

sgd = optimizers.SGD(lr=0.01, decay=1e-6, momentum=.8, nesterov=True)
lstm_model = Model(inputs=[x, aux_input, probs], outputs=output)
lstm_model.compile(optimizer=sgd, loss=mask_loss, metrics=[mask_acc])
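
A minimal NumPy sketch (using a made-up beam of 4 labels instead of the real beam_size of 10) of which index pairs the masking in create_mask is meant to keep, namely only the ordered pairs whose labels differ:

from itertools import permutations

import numpy as np

demo_beam = 4                       # small beam for illustration only
y = np.array([1, 0, 1, 0])          # labels within one beam
pairs = list(permutations(range(demo_beam), r=2))  # all ordered index pairs
kept = [(i, j) for i, j in pairs if y[i] != y[j]]  # pairs the loss should see
print(kept)  # [(0, 1), (0, 3), (1, 0), (1, 2), (2, 1), (2, 3), (3, 0), (3, 2)]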

The strange thing is that changing the output activation to softmax greatly improves the accuracy, but tf.losses.softmax_cross_entropy expects unnormalized logits, so I'm not sure why that helps.
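
For reference, a standalone sketch (with made-up numbers, unrelated to the beam data above) of how tf.losses.softmax_cross_entropy behaves when fed raw logits versus outputs that have already gone through a softmax; in the second case softmax is effectively applied twice, which flattens the gradients without raising an error:

import tensorflow as tf

labels = tf.constant([[1., 0.], [0., 1.], [1., 0.]])           # one-hot labels
logits = tf.constant([[2.0, -1.0], [0.5, 1.5], [-0.3, 0.3]])   # raw scores

# Intended usage: unnormalized logits, softmax applied inside the loss
loss_from_logits = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)

# If the model's last layer already applies softmax, the loss squashes
# the probabilities through softmax a second time
probs = tf.nn.softmax(logits)
loss_from_probs = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=probs)

with tf.Session() as sess:
    print(sess.run([loss_from_logits, loss_from_probs]))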

0 Answers:

No answers yet.