pytorch中的多标签分类

时间:2018-10-17 13:16:40

标签: pytorch

我有一个多标签分类问题。我有11个课程,大约4k个示例。每个示例可以具有1到4-5的标签。目前,我正在为每个类别分别使用log_loss训练分类器。如您所料,训练11个分类器需要花费很多时间,我想尝试另一种方法,只训练1个分类器。这个想法是,该分类器的最后一层将有11个节点,并将按类输出实数,然后将其通过S型转换为Proba。我要优化的损失是所有类上log_loss的平均值。

不幸的是,我是某种带有pytorch的菜鸟,即使通过阅读损失的源代码,我也无法弄清楚是否已经存在的损失之一确实满足我的要求,或者我是否应该创建新的损失,如果是这样,我真的不知道该怎么办。

具体来说,我想为批次中的每个元素提供一个大小为11的向量(其中每个标签均包含一个实数(越接近无穷大,则该类的预测值越接近1),并且1个大小为11的向量(每个真实标签均包含1),并且能够计算所有11个标签的均值log_loss,并基于该损失优化分类器。

任何帮助将不胜感激:)

1 个答案:

答案 0 :(得分:5)

您正在寻找torch.nn.BCELoss。这是示例代码:

"use strict";