Traversing a sklearn decision tree

Time: 2020-04-14 07:56:20

Tags: python scikit-learn decision-tree

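How do I traverse a sklearn decision tree in breadth-first order?

In my code I have been working with the tree_ attribute of a fitted sklearn decision tree, using arrays such as tree_.feature and tree_.threshold to understand the structure of the tree. But iterating over these arrays visits the tree in depth-first (DFS) order. What if I want a breadth-first (BFS) traversal?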

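Suppose

from sklearn.tree import DecisionTreeClassifier

clf1 = DecisionTreeClassifier(max_depth=2)
clf1 = clf1.fit(x_train, y_train)

is my classifier, and the decision tree it produces is

[image: the resulting decision tree]

I then traversed the tree with the following function:

import numpy as np

def encoding(clf, features):
    l1 = list()
    l2 = list()

    for i in range(len(clf.tree_.feature)):
        if clf.tree_.feature[i] >= 0:
            # internal node: record the feature name and its threshold
            l1.append(features[clf.tree_.feature[i]])
            l2.append(clf.tree_.threshold[i])
        else:
            # leaf node: no feature, record the index of the majority class
            l1.append(None)
            print(np.max(clf.tree_.value))
            l2.append(np.argmax(clf.tree_.value[i]))

    l = [l1, l2]

    return np.array(l)

In the output it produces, the first array holds the feature of each node, or None if the node is a leaf, and the second array holds the threshold for feature nodes and the class for leaf nodes. But this is a DFS traversal of the tree, and I want a BFS traversal. What should I do?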

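Since I am new to Stack Overflow, please suggest how I can improve the question description and what other information I should add to explain my problem further.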

X_train (sample): [image: X_train]

y_train (sample): [image: y_train]

1 answer:

Answer 0 (score: 0)

This should do it:

from collections import deque

tree = clf.tree_  # clf is the fitted DecisionTreeClassifier

queue = deque()
queue.append(0)  # start from the root, which is node 0

while queue:
    current_node = queue.popleft()  # take the oldest node first (FIFO)

    # do whatever you want with current node
    # ...

    # a negative child index means there is no child on that side
    left_child = tree.children_left[current_node]
    if left_child >= 0:
        queue.append(left_child)

    right_child = tree.children_right[current_node]
    if right_child >= 0:
        queue.append(right_child)

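This uses a deque to hold the nodes that are still waiting to be processed. Because we remove elements from the left and append new ones on the right, nodes are handled in first-in, first-out order, so every node on one level is visited before any node on the next level. That is a breadth-first traversal.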
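To see the difference in ordering without fitting a model, here is a minimal sketch. The two lists are hand-written stand-ins for tree_.children_left and tree_.children_right of a full depth-2 tree, assuming the depth-first node numbering described in the question and -1 as the "no child" marker; the function name bfs_order is made up for illustration.

```python
from collections import deque

# Hand-written stand-ins for tree_.children_left / tree_.children_right
# of a full depth-2 tree (assumed depth-first numbering; -1 means no child).
children_left = [1, 2, -1, -1, 5, -1, -1]
children_right = [4, 3, -1, -1, 6, -1, -1]


def bfs_order(children_left, children_right):
    """Return node indices in breadth-first (level) order."""
    order = []
    queue = deque([0])  # the root is node 0
    while queue:
        node = queue.popleft()  # oldest entry first (FIFO)
        order.append(node)
        for child in (children_left[node], children_right[node]):
            if child >= 0:
                queue.append(child)
    return order


storage_order = list(range(len(children_left)))  # what a plain index loop visits
level_order = bfs_order(children_left, children_right)

print(storage_order)  # [0, 1, 2, 3, 4, 5, 6]
print(level_order)    # [0, 1, 4, 2, 3, 5, 6]
```

Node 4 (the root's right child) moves ahead of nodes 2 and 3, which sit one level deeper.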


For practical use, I suggest turning this into a generator:

from collections import deque

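def breadth_first_traversal(tree):
    """Yield the node indices of a fitted tree_ in breadth-first order."""
    queue = deque()
    queue.append(0)  # start from the root, which is node 0

    while queue:
        current_node = queue.popleft()  # take the oldest node first (FIFO)

        yield current_node

        # a negative child index means there is no child on that side
        left_child = tree.children_left[current_node]
        if left_child >= 0:
            queue.append(left_child)

        right_child = tree.children_right[current_node]
        if right_child >= 0:
            queue.append(right_child)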
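If you also need to know which level each node sits on, one possible variant carries the depth along in the queue. This is only a sketch: the name breadth_first_with_depth is made up, and the SimpleNamespace object is a stand-in for clf.tree_ so the snippet runs without a fitted model.

```python
from collections import deque
from types import SimpleNamespace


def breadth_first_with_depth(tree):
    """Yield (node_index, depth) pairs in breadth-first order."""
    queue = deque([(0, 0)])  # (root index, depth 0)
    while queue:
        node, depth = queue.popleft()
        yield node, depth
        for child in (tree.children_left[node], tree.children_right[node]):
            if child >= 0:  # negative index: no child on that side
                queue.append((child, depth + 1))


# Stand-in for clf.tree_: the root has a leaf on the left (node 1) and an
# internal node on the right (node 2) with two leaf children (nodes 3 and 4).
fake_tree = SimpleNamespace(
    children_left=[1, -1, 3, -1, -1],
    children_right=[2, -1, 4, -1, -1],
)

pairs = list(breadth_first_with_depth(fake_tree))
print(pairs)  # [(0, 0), (1, 1), (2, 1), (3, 2), (4, 2)]
```

With a real classifier you would pass clf.tree_ instead of fake_tree.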

Then you only need a minimal change to your original function:

def encoding(clf, features):
    l1 = list()
    l2 = list()

    for i in breadth_first_traversal(clf.tree_):
        if(clf.tree_.feature[i]>=0):
            l1.append( features[clf.tree_.feature[i]])
            l2.append(clf.tree_.threshold[i])
        else:
            l1.append(None)
            print(np.max(clf.tree_.value))
            l2.append(np.argmax(clf.tree_.value[i]))

    l = [l1 , l2]

    return np.array(l)
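Another option, if you want several per-node arrays in BFS order, is to compute the order once and reindex with it. The sketch below uses plain lists as stand-ins for the tree_ arrays (with -2 as the feature and threshold value on leaves, as scikit-learn stores them); bfs_indices and the feature names "f0" and "f1" are hypothetical.

```python
from collections import deque


def bfs_indices(children_left, children_right):
    """Return all node indices in breadth-first order."""
    order, queue = [], deque([0])
    while queue:
        node = queue.popleft()
        order.append(node)
        for child in (children_left[node], children_right[node]):
            if child >= 0:
                queue.append(child)
    return order


# Stand-ins for the tree_ arrays of a full depth-2 tree.
children_left = [1, 2, -1, -1, 5, -1, -1]
children_right = [4, 3, -1, -1, 6, -1, -1]
feature = [0, 1, -2, -2, 1, -2, -2]              # -2 marks a leaf
threshold = [0.5, 1.5, -2.0, -2.0, 2.5, -2.0, -2.0]
feature_names = ["f0", "f1"]                     # hypothetical names

order = bfs_indices(children_left, children_right)

# Reindex each per-node array with the same BFS order.
names_bfs = [feature_names[feature[i]] if feature[i] >= 0 else None for i in order]
thresholds_bfs = [threshold[i] for i in order]

print(names_bfs)
print(thresholds_bfs)
```

With the real NumPy arrays, the same order list can be used for fancy indexing, for example clf.tree_.threshold[order].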