在R中使用randomForest包,如何从分类模型中获取概率?

时间:2014-09-07 22:57:44

标签: r machine-learning random-forest predict

TL; DR:

我可以在原始 randomForest 电话中标记,以避免重新运行 predict 函数来获取预测的分类概率,而不仅仅是可能的类别?

详细说明:

我正在使用randomForest包。

我有一个类似的模型:

model <- randomForest(x=out.data[train.rows, feature.cols],
                      y=out.data[train.rows, response.col],
                      xtest=out.data[test.rows, feature.cols],
                      ytest=out.data[test.rows, response.col],
                      importance= TRUE)

其中out.data是数据框,feature.cols是数字和分类要素的混合,而response.colTRUE / FALSE二进制变量,我强迫进入factor,以便randomForest模型将其正确地视为分类。

一切运行良好,变量model正确地返回给我。但是,我似乎找不到要传递给randomForest函数的标记或参数,以便使用 概率 model返回给我TRUEFALSE。相反,我得到的只是预测值。也就是说,如果我查看model$predicted,我会看到类似的内容:

FALSE
FALSE
TRUE
TRUE
FALSE
.
.
.

相反,我希望看到类似的内容:

   FALSE  TRUE
1  0.84   0.16
2  0.66   0.34
3  0.11   0.89
4  0.17   0.83
5  0.92   0.08
.   .      .
.   .      .
.   .      .

我可以得到上述内容,但为了做到这一点,我需要做类似的事情:

tmp <- predict(model, out.data[test.rows, feature.cols], "prob")

[test.rows捕获模型测试期间使用的行号。这里没有显示详细信息,但很简单,因为测试行ID输出到model。]

然后一切正常。 问题 是模型很大并且需要很长时间才能运行,甚至预测本身也需要一段时间。由于预测 应该 完全没必要(我只是想计算测试数据集上的ROC曲线,应该已经计算过的数据集),我是希望跳过这一步。 我可以在原始 randomForest 电话中标记,以避免重新运行 predict 功能?

1 个答案:

答案 0 :(得分:25)

model$predicted NOT predict()返回的内容相同。如果您想要TRUEFALSE类的概率,那么您必须运行predict(),或者传递x,y,xtest,ytest

randomForest(x,y,xtest=x,ytest=y), 

其中x=out.data[, feature.cols], y=out.data[, response.col]

model$predicted根据每个记录在model$votes中具有较大值的类返回类。 votes,正如@joran所指出的那样是来自随机森林的OOB(袋外)'投票'的比例,只有当在OOB样本中选择记录时才进行投票。另一方面,predict()基于所有树的投票返回每个类的真实概率。

使用randomForest(x,y,xtest=x,ytest=y)功能与传递公式或仅randomForest(x,y)时的功能略有不同,如上面给出的示例所示。 randomForest(x,y,xtest=x,ytest=y)将返回每个班级的概率,这可能听起来有点奇怪,但它在model$test$votes下找到,而model$test$predicted下的预测班级,它只是选择基于哪个班级class在model$test$votes中的值较大。此外,使用randomForest(x,y,xtest=x,ytest=y)时,model$predictedmodel$votes的定义与上述相同。

最后,请注意,如果使用randomForest(x,y,xtest=x,ytest=y),则为了使用predict()函数,keep.forest标志应设置为TRUE。

model=randomForest(x,y,xtest=x,ytest=y,keep.forest=TRUE). 
prob=predict(model,x,type="prob")

prob 等同于model$test$votes,因为测试数据输入均为x

相关问题