Keras - 如何获得非标准化的logits而不是概率

时间:2017-10-31 13:15:42

标签: nlp keras

我正在Keras中创建一个模型,并希望计算自己的指标(困惑)。这需要使用非标准化的概率/对数。但是,keras模型仅返回softmax概率:

model = Sequential()
model.add(embedding_layer)
model.add(LSTM(n_hidden, return_sequences=False))
model.add(Dropout(dropout_keep_prob))
model.add(Dense(vocab_size))
model.add(Activation('softmax'))
optimizer = RMSprop(lr=self.lr)

model.compile(optimizer=optimizer, 
loss='sparse_categorical_crossentropy')

Keras常见问题解答提供了获取中间层here输出的解决方案。另一种解决方案是here。但是,这些答案将中间输出存储在不同的模型中,这不是我需要的。 我想使用自定义指标的logits。自定义指标应包含在model.compile()函数中,以便在培训期间进行评估和显示。所以我不需要在不同模型中分隔Dense图层的输出,而是作为原始模型的一部分。

简而言之,我的问题是:

  • 使用def custom_metric(y_true, y_pred)定义here概述的自定义指标时,y_pred是否包含对数或规范化概率?

  • 如果它包含标准化概率,我如何得到非标准化概率,即Dense层输出的对数?

3 个答案:

答案 0 :(得分:3)

我想我找到了解决方案

首先,我将激活层更改为线性,以便我收到@loannis Nasios概述的logits。

其次,为了仍然将sparse_categorical_crossentropy作为一个损失函数,我定义了自己的损失函数,将from_logits参数设置为true。

model.add(embedding_layer)
model.add(LSTM(n_hidden, return_sequences=False))
model.add(Dropout(dropout_keep_prob))
model.add(Dense(vocab_size))
model.add(Activation('linear'))
optimizer = RMSprop(lr=self.lr)


def my_sparse_categorical_crossentropy(y_true, y_pred):
    return K.sparse_categorical_crossentropy(y_true, y_pred, from_logits=True)

model.compile(optimizer=optimizer,loss=my_sparse_categorical_crossentropy)

答案 1 :(得分:1)

尝试将上次激活从softmax更改为线性

model = Sequential()
model.add(embedding_layer)
model.add(LSTM(n_hidden, return_sequences=False))
model.add(Dropout(dropout_keep_prob))
model.add(Dense(vocab_size))
model.add(Activation('linear'))
optimizer = RMSprop(lr=self.lr)

model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy')

答案 2 :(得分:-1)

您可以为训练制作模型,为预测制作另一个模型。

对于培训,您可以使用功能API模型并简单地使用现有模型的一部分,将激活放在一边:

model = yourExistingModelWithSoftmax 
modelForTraining = Model(model.input,model.layers[-2].output)

#use your loss function in this model:
modelForTraining.compile(optimizer=optimizer,loss=my_sparse_categorical_crossentropy, metrics=[my_custom_metric])

由于您将一个模型作为另一个模型的一部分,因此它们将共享相同的权重。

  • 如果您想训练,请使用modelForTraining.fit()
  • 如果要预测概率,请使用model.predict()
相关问题