Keras与标签合并类预测

时间:2019-06-24 15:31:28

标签: python tensorflow keras

在训练我的网络时,我遇到了一个多标签分类问题,其中我将类标签转换为一种热编码。

训练模型并生成预测后-keras只需输出一个值数组,而无需指定类标签。

合并这些内容的最佳实践是什么,以便我的API可以将有意义的结果返回给使用者?

示例

y = pd.get_dummies(df_merged.eventId)
y

2CBC9h3uple1SXxEVy8W    GiiFxmfrUwBNMGgFuoHo    e06onPbpyCucAGXw01mM
12  1                   0                       0
13  1                   0                       0
14  1                   0                       0

prediction = model.predict(pred_test_input)
prediction
array([[0.5002058 , 0.49697363, 0.50251794]], dtype=float32)

所需结果: {results: { 2CBC9h3uple1SXxEVy8W: 0.5002058, ...}

编辑:根据评论添加模型-但这只是一个玩具模型。

model = Sequential()
model.add(
  Embedding(
    input_dim=embeddings_index.shape[0],
    output_dim=embeddings_index.shape[1],
    weights=[embeddings_index],
    input_length=MAX_SEQ_LENGTH,
    trainable=False,
  )
)
model.add(LSTM(300))
model.add(Dense(units=len(y.columns), activation='sigmoid'))
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])

编辑2-加y。

所以我的y的格式如下:

eventId
123
123
234
...

然后我使用y = pd.get_dummies(df_merged.eventId)将其转换为模型可以使用的东西,并希望将eventIds返回到预测中。

2 个答案:

答案 0 :(得分:2)

首先,如果您要进行多标签分类,则应该使用binary_crossentropy损失:

model.compile(loss='binary_crossentropy', optimizer='sgd', metrics=['accuracy'])

那么重要的一点是,keras的准确性没有考虑多标签分类,因此这将是一个误导性指标。每个类的精度/召回率都是更合适的指标。

要获得班级预测,您必须对每个班级的预测进行阈值设置,这是您必须调整的阈值(每个班级不必相同),例如:

class_names = y.columns.tolist()
pred_classes = {}
preds = model.predict(pred_test_input)

thresh = 0.5
for i in range(num_classes):
    if preds[i] > thresh:
        pred_classes[class_name[i]] = preds[i]

这将输出pred_classes词典,其类别超过阈值,并包括一个置信度得分。

答案 1 :(得分:0)

对于分类问题,我们倾向于以Softmax层结尾,该层的作用是为我们提供不同类别上的概率分布。

考虑将模型的体系结构更改为以下内容:

model = Sequential()
model.add(
  Embedding(
    input_dim=embeddings_index.shape[0],
    output_dim=embeddings_index.shape[1],
    weights=[embeddings_index],
    input_length=MAX_SEQ_LENGTH,
    trainable=False,
  )
)
model.add(LSTM(300))
model.add(Dense(units=len(y.columns), activation='sigmoid'))
model.add(Softmax(3))

然后我们可以通过采用其他人用argmax建议的最高值索引来获得预测。

相关问题