Evaluating a BertForSequenceClassification model

Time: 2020-10-31 21:33:52

Tags: python nlp pytorch bert-language-model huggingface-transformers

I am working on a text classification problem with a BertForSequenceClassification model, and I am trying to evaluate the model with the following code:

from sklearn.metrics import classification_report
from sklearn.metrics import accuracy_score
predicted = model.predict(x_test)

It gave me the error: ModuleAttributeError: 'BertForSequenceClassification' object has no attribute 'predict'
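For context on the error: BertForSequenceClassification is a PyTorch nn.Module rather than a scikit-learn estimator, so predictions come from calling the model itself and taking the argmax of the logits. A minimal sketch of that forward pass follows. It assumes a tiny randomly initialised model and a made-up batch of token ids; both are placeholders so the snippet runs without downloading weights. A real evaluation would use the fine-tuned model and the tokenized test data.

```python
import torch
from transformers import BertConfig, BertForSequenceClassification

# Placeholder: a tiny, randomly initialised model (assumption, not the asker's model).
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=1,
                    num_attention_heads=2, intermediate_size=64, num_labels=2)
model = BertForSequenceClassification(config)
model.eval()  # switch off dropout for inference

# Placeholder batch: 4 already-tokenized sequences of length 8.
input_ids = torch.randint(0, config.vocab_size, (4, 8))
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():  # gradients are not needed for evaluation
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)

# When no labels are passed, the first output element holds the logits.
logits = outputs[0]
predicted = logits.argmax(dim=-1).tolist()  # one class id per input sequence
```

The resulting list of class ids could then be passed to classification_report together with the true labels.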

I also tried the following code (only the relevant part is shown), but it does not print anything other than the training loss:

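import numpy as np
from sklearn.metrics import classification_report, confusion_matrix, f1_score, precision_score, recall_score, accuracy_score
from transformers import AutoConfig, AutoTokenizer, BertForSequenceClassification, Trainer

config = AutoConfig.from_pretrained(model_name, num_labels=num_labels, output_attentions=True, gradient_checkpointing=True, keep_check_point_max=0)
tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False, do_basic_tokenize=True, never_split=never_split_tokens)
model = BertForSequenceClassification.from_pretrained(model_name, config=config)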

def compute_metrics(p): #p should be of type EvalPrediction
  preds = np.argmax(p.predictions, axis=1)
  assert len(preds) == len(p.label_ids)
  print(classification_report(p.label_ids,preds))
  print(confusion_matrix(p.label_ids,preds))

  f1_Positive = f1_score(p.label_ids,preds,pos_label=1,average='binary')
  f1_Negative = f1_score(p.label_ids,preds,pos_label=0,average='binary')
  macro_f1 = f1_score(p.label_ids,preds,average='macro')
  macro_precision = precision_score(p.label_ids,preds,average='macro')
  macro_recall = recall_score(p.label_ids,preds,average='macro')
  acc = accuracy_score(p.label_ids,preds)
  return {
      'f1_pos': f1_Positive,
      'f1_neg': f1_Negative,
      'macro_f1' : macro_f1, 
      'macro_precision': macro_precision,
      'macro_recall': macro_recall,
      'accuracy': acc
  }

trainer = Trainer(model=model,
                  args = training_args,
                  train_dataset = train_features,
                  eval_dataset = test_features,
                  compute_metrics = compute_metrics)
trainer.train()

Is there any way to evaluate this model?

This is what is displayed:

Step    Training Loss
497     0.048466
994     0.020653
1491    0.013453
1988    0.005992
2485    0.004062
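A log that contains only the training loss is consistent with the evaluation loop never having run: trainer.train() on its own does not necessarily evaluate, unless the training arguments request it (the exact option name depends on the installed transformers version). One way to trigger evaluation explicitly is sketched below. The sketch assumes `trainer` is a transformers Trainer built as in the question; `evaluate_with_trainer` is a hypothetical helper name, not part of the library. The tuple check is an assumption tied to output_attentions=True, which may cause the predictions to come back as a tuple with the logits first.

```python
import numpy as np

def evaluate_with_trainer(trainer, dataset):
    """Run evaluation explicitly; return (metrics, predicted_labels, true_labels)."""
    # Trainer.evaluate() runs the evaluation loop and calls compute_metrics if one was given.
    metrics = trainer.evaluate(eval_dataset=dataset)

    # Trainer.predict() returns an object exposing .predictions and .label_ids.
    output = trainer.predict(dataset)
    logits = output.predictions
    if isinstance(logits, tuple):
        # Assumption: with extra model outputs (e.g. attentions), logits come first.
        logits = logits[0]

    predicted = np.argmax(logits, axis=1)
    return metrics, predicted, output.label_ids

# Hypothetical usage with the objects from the question:
# metrics, y_pred, y_true = evaluate_with_trainer(trainer, test_features)
# print(metrics)
```

The returned predicted and true labels can be fed to classification_report or confusion_matrix directly, which avoids relying on print calls inside compute_metrics.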

0 answers:

No answers yet.