在决策树中查找最大值

时间:2016-02-24 10:08:11

标签: r machine-learning classification decision-tree party

我用R中的Party包创建了决策树。 我试图获得具有最大值的路线/分支。

它可以是来自箱图Picture 1

的平均值

它可以是来自二叉树Picture 2 http://www.rdatamining.com/_/rsrc/1421496897574/examples/decision-tree/iris_ctree_simple.jpg

的概率值

1 个答案:

答案 0 :(得分:0)

实际上这可以很容易地完成,虽然最大值的定义对于回归树是明确的,但分类并不十分清楚树,因为在每个节点中不同的级别可以拥有它自己的最大值

无论哪种方式,这里都是一个非常简单的辅助函数,可以为每种类型的树返回预测

GetPredicts <- function(ct){
      f <- function(ct, i) nodes(ct, i)[[1]]$prediction
      Terminals <- unique(where(ct))
      Predictions <- sapply(Terminals, f, ct = ct)
      if(is.matrix(Predictions)){
        colnames(Predictions) <- Terminals
        return(Predictions)
       } else {
        return(setNames(Predictions, Terminals))
       }
}

幸运的是,您已经从?ctree的示例中获取了树木,因此我们可以测试它们(下次请提供您自己使用的代码)

回归树(你的第一个树)

## load the package and create the tree
library(party)
airq <- subset(airquality, !is.na(Ozone))
airct <- ctree(Ozone ~ ., data = airq, 
               controls = ctree_control(maxsurrogate = 3))
plot(airct)

现在,测试功能

res <- GetPredicts(airct)
res
#        5        3        6        9        8 
# 18.47917 55.60000 31.14286 48.71429 81.63333 

所以我们得到了每个终端节点的预测。您可以从这里轻松地继续which.max(res)(我会留待您决定)

分类树(您的第二棵树)

irisct <- ctree(Species ~ .,data = iris)
plot(irisct, type = "simple")

运行功能

res <- GetPredicts(irisct)
res
#      2          5   6          7
# [1,] 1 0.00000000 0.0 0.00000000
# [2,] 0 0.97826087 0.5 0.02173913
# [3,] 0 0.02173913 0.5 0.97826087

现在,输出有点难以阅读,因为每个类都有自己的概率。您可以使用

使其更具可读性
row.names(res) <- levels(iris$Species)
res
#            2          5   6          7
# setosa     1 0.00000000 0.0 0.00000000
# versicolor 0 0.97826087 0.5 0.02173913
# virginica  0 0.02173913 0.5 0.97826087

您可以执行以下操作以获得总体最大值

which(res == max(res), arr.ind = TRUE)
#        row col
# setosa   1   1

对于列/行最大值,您可以执行

matrixStats::colMaxs(res)
# [1] 1.0000000 0.9782609 0.5000000 0.9782609
matrixStats::rowMaxs(res)
# [1] 1.0000000 0.9782609 0.9782609

但是,我会再次请你决定如何从这里开始。