BERT 模型如何选择标签排序?

时间:2021-04-21 06:15:39

标签: pytorch bert-language-model huggingface-transformers logits

我正在为分类任务训练 BertForSequenceClassification。我的数据集由“包含不利影响”(1)和“不包含不利影响”(0)组成。数据集包含所有 1,然后是 0(数据未打乱)。对于训练,我已经洗牌了我的数据并获得了 logits。据我所知,logits 是 softmax 之前的概率分布。一个示例 logit 是 [-4.673831, 4.7095485]。第一个值是否对应于标签 1(包含 AE),因为它首先出现在数据集中,还是标签 0。感谢任何帮助。

1 个答案:

答案 0 :(得分:1)

第一个值对应于标签 0,第二个值对应于标签 1。BertForSequenceClassification 的作用是将池化器的输出馈送到线性层(在我将在本答案中忽略的 dropout 之后)。我们来看下面的例子:

from torch import nn
from transformers import BertModel, BertTokenizer
t = BertTokenizer.from_pretrained('bert-base-uncased')
m = BertModel.from_pretrained('bert-base-uncased')
i = t.encode_plus('This is an example.', return_tensors='pt')
o = m(**i)
print(o.pooler_output.shape)

输出:

torch.Size([1, 768])

pooled_output 是形状为 [batch_size,hidden_​​size] 的张量,表示输入序列的上下文化(即应用了注意力)[CLS] 标记。该张量被馈送到线性层以计算序列的 logits

classificationLayer = nn.Linear(768,2)
logits = classificationLayer(o.pooler_output)

当我们对这些 logits 进行归一化时,我们可以看到线性层预测我们的输入应该属于标签 1:

print(nn.functional.softmax(logits,dim=-1))

输出(会有所不同,因为线性层是随机初始化的):

tensor([[0.1679, 0.8321]], grad_fn=<SoftmaxBackward>)

线性层应用线性变换:y=xA^T+b,您已经可以看到线性层不知道您的标签。它“仅”具有大小为 [2,768] 的权重矩阵以生成大小为 [1,2] 的对数(即:第一行对应于第一个值,第二行对应于第二个值):

import torch:

logitsOwnCalculation = torch.matmul(o.pooler_output,  classificationLayer.weight.transpose(0,1))+classificationLayer.bias
print(nn.functional.softmax(logitsOwnCalculation,dim=-1))

输出:

tensor([[0.1679, 0.8321]], grad_fn=<SoftmaxBackward>)

BertForSequenceClassification 模型通过应用 CrossEntropyLoss 进行学习。当某个类(在您的例子中为标签)的 logits 仅与期望值略有不同时,此损失函数会产生较小的损失。这意味着 CrossEntropyLoss 是让您的模型学习第一个 logit 在输入 does not contain adverse effect 时应该高或在 contains adverse effect 时小的那个。您可以使用以下内容检查我们的示例:

loss_fct = nn.CrossEntropyLoss()
label0 = torch.tensor([0]) #does not contain adverse effect
label1 = torch.tensor([1]) #contains adverse effect
print(loss_fct(logits, label0))
print(loss_fct(logits, label1))

输出:

tensor(1.7845, grad_fn=<NllLossBackward>)
tensor(0.1838, grad_fn=<NllLossBackward>)
相关问题