Prune unnecessary leaves in sklearn DecisionTreeClassifier

Asked: 2018-07-18 08:28:14

Tags: python scikit-learn decision-tree pruning

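I use sklearn.tree.DecisionTreeClassifier to build a decision tree. With the optimal parameter settings, I get a tree that has unnecessary leaves (see the example picture below: I do not need probabilities, so the leaf nodes marked in red are unnecessary splits).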

Tree

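Is there a third-party library for pruning these unnecessary nodes? Or a code snippet? I could write one, but I can't really imagine that I am the first person with this problem...

Code to reproduce: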

from sklearn.tree import DecisionTreeClassifier
from sklearn import datasets
iris = datasets.load_iris()
X = iris.data
y = iris.target
mdl = DecisionTreeClassifier(max_leaf_nodes=8)
mdl.fit(X,y)

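PS: I have tried multiple keyword searches and found nothing. Is there really no general post-pruning in sklearn?

PPS: To address the possible duplicate: while the suggested question might help me when coding the pruning algorithm myself, it answers a different question. I want to get rid of leaves that do not change the final decision, whereas the other question wants a minimum threshold for splitting nodes.

PPPS: The tree shown is only an example to illustrate my problem. I am aware that the parameter settings used to create it are suboptimal. I am not asking how to optimize this specific tree. I need post-pruning that removes leaves which may be useful when class probabilities are needed but are of no use when only the most likely class matters.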
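To make the notion of an "unnecessary split" concrete, here is a minimal diagnostic sketch. It lists internal nodes whose two children are both leaves and predict the same class. The helper name find_redundant_splits and the random_state=0 setting are assumptions added for illustration and are not part of the original question. It only inspects the bottom level of the tree, so further redundant splits may appear once these are collapsed.

```python
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree._tree import TREE_LEAF


def find_redundant_splits(clf):
    """Return ids of internal nodes whose two leaf children predict the same class.

    Hypothetical helper for illustration; it only reads the fitted tree.
    """
    tree = clf.tree_
    predicted = tree.value.argmax(axis=2).ravel()  # majority class index per node
    redundant = []
    for node in range(tree.node_count):
        left, right = tree.children_left[node], tree.children_right[node]
        if left == TREE_LEAF:  # the node itself is a leaf
            continue
        both_leaves = (tree.children_left[left] == TREE_LEAF
                       and tree.children_left[right] == TREE_LEAF)
        if both_leaves and predicted[left] == predicted[right]:
            redundant.append(node)
    return redundant


iris = datasets.load_iris()
# random_state is fixed only to make the sketch repeatable (an assumption).
clf = DecisionTreeClassifier(max_leaf_nodes=8, random_state=0).fit(iris.data, iris.target)
redundant_nodes = find_redundant_splits(clf)
print("splits that do not change the predicted class:", redundant_nodes)
```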

3 answers:

Answer 0 (score: 5)

Using ncfirth's link, I was able to modify the code there so that it fits my problem:

from sklearn.tree._tree import TREE_LEAF

def is_leaf(inner_tree, index):
    # Check whether node is leaf node
    return (inner_tree.children_left[index] == TREE_LEAF and 
            inner_tree.children_right[index] == TREE_LEAF)

def prune_index(inner_tree, decisions, index=0):
    # Start pruning from the bottom - if we start from the top, we might miss
    # nodes that become leaves during pruning.
    # Do not use this directly - use prune_duplicate_leaves instead.
    if not is_leaf(inner_tree, inner_tree.children_left[index]):
        prune_index(inner_tree, decisions, inner_tree.children_left[index])
    if not is_leaf(inner_tree, inner_tree.children_right[index]):
        prune_index(inner_tree, decisions, inner_tree.children_right[index])

    # Prune children if both children are leaves now and make the same decision:     
    if (is_leaf(inner_tree, inner_tree.children_left[index]) and
        is_leaf(inner_tree, inner_tree.children_right[index]) and
        (decisions[index] == decisions[inner_tree.children_left[index]]) and 
        (decisions[index] == decisions[inner_tree.children_right[index]])):
        # turn node into a leaf by "unlinking" its children
        inner_tree.children_left[index] = TREE_LEAF
        inner_tree.children_right[index] = TREE_LEAF
        ##print("Pruned {}".format(index))

def prune_duplicate_leaves(mdl):
    # Collapse sibling leaves that make the same decision as their parent
    decisions = mdl.tree_.value.argmax(axis=2).flatten().tolist() # Decision for each node
    prune_index(mdl.tree_, decisions)

To use it on a DecisionTreeClassifier clf:

prune_duplicate_leaves(clf)

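Edit: fixed a bug for more complex trees.

Answer 1 (score: 0)

DecisionTreeClassifier(max_leaf_nodes=8) specifies a maximum of 8 leaves, so unless the tree builder has another reason to stop, it will hit that maximum.

In the example shown, 5 of the 8 leaves have a very small number of samples (<= 3) compared with the other 3 leaves (> 50), which may be a sign of overfitting. Instead of pruning the tree after training, you can specify min_samples_leaf or min_samples_split to better guide the training, which will likely get rid of the problematic leaves. For example, use the value 0.05 to require at least 5% of the samples.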
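As a rough illustration of this suggestion, the sketch below fits the tree from the question with and without min_samples_leaf=0.05 and prints the number of training samples in each leaf. The helper name leaf_sizes and the random_state=0 setting are assumptions added for the example. Whether the smaller tree is preferable depends on the data.

```python
import math

from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree._tree import TREE_LEAF

iris = datasets.load_iris()
X, y = iris.data, iris.target


def leaf_sizes(clf):
    """Number of training samples in each leaf (hypothetical helper)."""
    tree = clf.tree_
    is_leaf = tree.children_left == TREE_LEAF
    return tree.n_node_samples[is_leaf]


# Only the leaf cap from the question.
capped = DecisionTreeClassifier(max_leaf_nodes=8, random_state=0).fit(X, y)
# Same cap, but every leaf must hold at least 5% of the training samples.
guided = DecisionTreeClassifier(
    max_leaf_nodes=8, min_samples_leaf=0.05, random_state=0
).fit(X, y)

print("leaf sizes, max_leaf_nodes only:", sorted(leaf_sizes(capped).tolist()))
print("leaf sizes, with min_samples_leaf=0.05:", sorted(leaf_sizes(guided).tolist()))
```

A float min_samples_leaf is interpreted as a fraction of the training set, so with the 150 iris samples each leaf should end up with at least ceil(0.05 * 150) = 8 samples.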

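Answer 2 (score: 0)

I had a problem with the code posted here, so I revised it and had to add a small section (it handles the case where both sides are the same but a comparison is still present):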

from sklearn.tree._tree import TREE_LEAF, TREE_UNDEFINED

def is_leaf(inner_tree, index):
    # Check whether node is leaf node
    return (inner_tree.children_left[index] == TREE_LEAF and 
            inner_tree.children_right[index] == TREE_LEAF)

def prune_index(inner_tree, decisions, index=0):
    # Start pruning from the bottom - if we start from the top, we might miss
    # nodes that become leaves during pruning.
    # Do not use this directly - use prune_duplicate_leaves instead.
    if not is_leaf(inner_tree, inner_tree.children_left[index]):
        prune_index(inner_tree, decisions, inner_tree.children_left[index])
    if not is_leaf(inner_tree, inner_tree.children_right[index]):
        prune_index(inner_tree, decisions, inner_tree.children_right[index])

    # Prune children if both children are leaves now and make the same decision:     
    if (is_leaf(inner_tree, inner_tree.children_left[index]) and
        is_leaf(inner_tree, inner_tree.children_right[index]) and
        (decisions[index] == decisions[inner_tree.children_left[index]]) and 
        (decisions[index] == decisions[inner_tree.children_right[index]])):
        # turn node into a leaf by "unlinking" its children
        inner_tree.children_left[index] = TREE_LEAF
        inner_tree.children_right[index] = TREE_LEAF
        inner_tree.feature[index] = TREE_UNDEFINED
        ##print("Pruned {}".format(index))

def prune_duplicate_leaves(mdl):
    # Collapse sibling leaves that make the same decision as their parent
    decisions = mdl.tree_.value.argmax(axis=2).flatten().tolist() # Decision for each node
    prune_index(mdl.tree_, decisions)
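
One way to sanity-check this kind of pruning is to confirm that predictions stay identical and that fewer leaves remain reachable. The sketch below is self-contained, so it uses a condensed variant of the same bottom-up idea rather than the exact functions above. The names collapse_same_class_leaves and count_reachable_leaves, and the random_state=0 setting, are assumptions made for illustration. Because the nodes are only unlinked and not deleted, size attributes such as tree_.node_count should not be expected to shrink, which is why the leaves are counted by walking the tree.

```python
import numpy as np
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree._tree import TREE_LEAF, TREE_UNDEFINED


def collapse_same_class_leaves(tree, decisions, node=0):
    """Bottom-up: make `node` a leaf if both children are leaves sharing its decision."""
    left, right = tree.children_left[node], tree.children_right[node]
    if left == TREE_LEAF:  # already a leaf, nothing to do
        return
    collapse_same_class_leaves(tree, decisions, left)
    collapse_same_class_leaves(tree, decisions, right)
    # Re-read the children's links: they may have become leaves just now.
    children_are_leaves = (tree.children_left[left] == TREE_LEAF
                           and tree.children_left[right] == TREE_LEAF)
    if children_are_leaves and decisions[left] == decisions[node] == decisions[right]:
        tree.children_left[node] = TREE_LEAF
        tree.children_right[node] = TREE_LEAF
        tree.feature[node] = TREE_UNDEFINED


def count_reachable_leaves(tree, node=0):
    """Count leaves reachable from `node` by following the child links."""
    left, right = tree.children_left[node], tree.children_right[node]
    if left == TREE_LEAF:
        return 1
    return count_reachable_leaves(tree, left) + count_reachable_leaves(tree, right)


iris = datasets.load_iris()
X, y = iris.data, iris.target
clf = DecisionTreeClassifier(max_leaf_nodes=8, random_state=0).fit(X, y)

before = clf.predict(X)
leaves_before = count_reachable_leaves(clf.tree_)

decisions = clf.tree_.value.argmax(axis=2).ravel().tolist()  # class index per node
collapse_same_class_leaves(clf.tree_, decisions)

after = clf.predict(X)
leaves_after = count_reachable_leaves(clf.tree_)
print("predictions unchanged:", np.array_equal(before, after))
print("reachable leaves:", leaves_before, "->", leaves_after)
```

Note that predict_proba can still change after pruning, since a collapsed node reports its own class distribution instead of its children's. That matches the question, where only the most likely class matters.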