我如何获得与预测类相反的概率?

时间:2018-12-28 00:02:58

标签: r r-caret

我正在使用扫帚扩充,希望我可以检索课堂概率和预测:

# Statistical Modeling
## dummy vars
training_data <- mtcars
dummy <- caret::dummyVars(~ ., data = training_data, fullRank = T, sep = ".")
training_data <- predict(dummy, mtcars) %>% as.data.frame()
clean_names <- names(training_data) %>% str_replace_all(" |`", "")
names(training_data) <- clean_names

## make target a factor
target <- training_data$mpg
target <- ifelse(target < 20, 0,1) %>% as.factor() %>% make.names()

## custom evaluation metric function
my_summary  <- function(data, lev = NULL, model = NULL){
  a1 <- defaultSummary(data, lev, model)
  b1 <- twoClassSummary(data, lev, model)
  c1 <- prSummary(data, lev, model)
  out <- c(a1, b1, c1)
  out}

## tuning & parameters
set.seed(123)
train_control <- trainControl(
  method = "cv",
  number = 3,
  savePredictions = TRUE,
  verboseIter = TRUE,
  classProbs = TRUE,
  summaryFunction = my_summary
)

linear_model = train(
  x = select(training_data, -mpg), 
  y = target,
  trControl = train_control,
  method = "glm", # logistic regression
  family = "binomial",
  metric = "AUC" # prAUC since using prSummary
)

library(broom)
linear_augment <- augment(linear_model$finalModel)

现在,如果我看一下broom :: augment创建的新数据框,我会发现新功能.fitted有一些负值:

> glimpse(linear_augment)
Observations: 32
Variables: 19
$ .rownames  <chr> "Mazda.RX4", "Mazda.RX4.Wag", "Datsun.710", "Hornet.4.Drive", "Hornet.Sportabout", "Valiant", "Duster.360", "Merc.240D", "Merc.230", ...
$ .outcome   <fct> X1, X1, X1, X1, X0, X0, X0, X1, X1, X0, X0, X0, X0, X0, X0, X0, X0, X1, X1, X1, X1, X0, X0, X0, X0, X1, X1, X1, X0, X0, X0, X1
$ cyl        <dbl> 6, 6, 4, 6, 8, 6, 8, 4, 4, 6, 6, 8, 8, 8, 8, 8, 8, 4, 4, 4, 4, 8, 8, 8, 8, 4, 4, 4, 8, 6, 8, 4
$ disp       <dbl> 160.0, 160.0, 108.0, 258.0, 360.0, 225.0, 360.0, 146.7, 140.8, 167.6, 167.6, 275.8, 275.8, 275.8, 472.0, 460.0, 440.0, 78.7, 75.7, 71...
$ hp         <dbl> 110, 110, 93, 110, 175, 105, 245, 62, 95, 123, 123, 180, 180, 180, 205, 215, 230, 66, 52, 65, 97, 150, 150, 245, 175, 66, 91, 113, 26...
$ drat       <dbl> 3.90, 3.90, 3.85, 3.08, 3.15, 2.76, 3.21, 3.69, 3.92, 3.92, 3.92, 3.07, 3.07, 3.07, 2.93, 3.00, 3.23, 4.08, 4.93, 4.22, 3.70, 2.76, 3...
$ wt         <dbl> 2.620, 2.875, 2.320, 3.215, 3.440, 3.460, 3.570, 3.190, 3.150, 3.440, 3.440, 4.070, 3.730, 3.780, 5.250, 5.424, 5.345, 2.200, 1.615, ...
$ qsec       <dbl> 16.46, 17.02, 18.61, 19.44, 17.02, 20.22, 15.84, 20.00, 22.90, 18.30, 18.90, 17.40, 17.60, 18.00, 17.98, 17.82, 17.42, 19.47, 18.52, ...
$ vs         <dbl> 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1
$ am         <dbl> 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1
$ gear       <dbl> 4, 4, 4, 3, 3, 3, 3, 4, 4, 4, 4, 3, 3, 3, 3, 3, 3, 4, 4, 4, 3, 3, 3, 3, 3, 4, 5, 5, 5, 5, 5, 4
$ carb       <dbl> 4, 4, 1, 1, 2, 1, 4, 2, 2, 4, 4, 3, 3, 3, 4, 4, 4, 1, 2, 1, 1, 2, 2, 4, 2, 1, 2, 2, 4, 6, 8, 2
$ .fitted    <dbl> 40.86100, 24.31240, 42.67493, 22.89140, -25.45002, -23.47658, -24.80498, 26.11860, 25.57239, -28.22688, -24.39062, -119.66717, -91.54...
$ .se.fit    <dbl> 136811.47, 115039.17, 425411.90, 56691.38, 102820.13, 75999.04, 147489.21, 283467.63, 214587.09, 137360.84, 118556.72, 281060.10, 206...
$ .resid     <dbl> 2.107342e-08, 7.432678e-06, 2.107342e-08, 1.512555e-05, -4.208358e-06, -1.128859e-05, -5.810077e-06, 3.012542e-06, 3.958608e-06, -1.0...
$ .hat       <dbl> 4.156093e-06, 9.965104e-01, 4.018459e-05, 9.982376e-01, 2.741217e-01, 9.996313e-01, 9.916032e-01, 9.973865e-01, 9.875200e-01, 3.31231...
$ .sigma     <dbl> 5.991067e-06, NaN, 5.991067e-06, NaN, 5.888378e-06, NaN, NaN, NaN, NaN, 5.986310e-06, NaN, 5.991067e-06, 5.991067e-06, 5.991067e-06, ...
$ .cooksd    <dbl> 8.389507e-23, 2.054949e-07, 8.112262e-22, 3.342117e-06, 4.188107e-13, 4.259059e-05, 2.157994e-08, 6.023639e-08, 4.516228e-09, 1.77508...
$ .std.resid <dbl> 2.107347e-08, 1.258224e-04, 2.107385e-08, 3.602950e-04, -4.939475e-06, -5.878873e-04, -6.340513e-05, 5.892791e-05, 3.543529e-05, -1.0...

如果我在控制台中输入linear_model$pred,则可能返回与每个类相关的概率,但是由于我使用了k折,所以顺序错误。我想有一种“正确的方法”可以从插入符号中提取概率,因为我在火车控制函数中设置了参数classProbs = T。

我找到了此页面:https://rdrr.io/cran/caret/man/predict.train.html

哪个说我可以使用extractProb()提取概率,但是这会导致错误消息:

(我不确定如何正确调用extractProb())

> extractProb(linear_model)
Error: $ operator is invalid for atomic vectors
> extractProb(linear_model$finalModel)
Error: $ operator is invalid for atomic vectors
> extractProb(linear_model$finalModel, testX = linear_model$trainingData)
Error: $ operator is invalid for atomic vectors

如何获取X0和X1的类概率向量?

1 个答案:

答案 0 :(得分:0)

如果您需要使用带有完整训练数据和最终模型的班级概率:

使用extractProb,请注意,它以a list of objects of the class train作为输入:

extractProb(models = list(linear_model))

OR

对训练数据(如M_M所述)使用predict函数

predict(linear_model, type = "prob")

在进行k折叠交叉验证时,您势必会获得k个(即,示例中k = 3)折叠中的重新采样数据。您已经启用了classProbs = TRUE。按rowIndex进行排序/排列,以获得k折的交叉验证概率:

linear_model$pred %>% dplyr::arrange(rowIndex)
相关问题