sklearn决策树的BFS遍历

时间:2020-04-20 13:13:48

标签: python-3.x scikit-learn decision-tree

如何对sklearn决策树进行广度优先搜索遍历?

在我的代码中,我尝试了sklearn.tree_库,并使用了诸如tree_.feature和tree_.threshold之类的各种功能来理解树的结构。但是如果我想做bfs,这些功能会遍历树的dfs吗?

假设

clf1 = DecisionTreeClassifier( max_depth = 2 )
clf1 = clf1.fit(x_train, y_train)

这是我的分类器,生成的决策树是

Decision tree

然后我使用以下函数遍历了树

def encoding(clf, features):
l1 = list()
l2 = list()

for i in range(len(clf.tree_.feature)):
    if(clf.tree_.feature[i]>=0):
        l1.append( features[clf.tree_.feature[i]])
        l2.append(clf.tree_.threshold[i])
    else:
        l1.append(None)
        print(np.max(clf.tree_.value))
        l2.append(np.argmax(clf.tree_.value[i]))

l = [l1 , l2]

return np.array(l)

产生的输出是

array([[['address','age',None,None,'age',None,None], [0.5、17.5、2、1、15.5、1、1]],dtype = object) 其中第一个数组是节点的特征,或者如果它离开节点,则它被标记为无,第二个数组是特征节点的阈值,对于类节点,它是类,但是这是树的dfs遍历,我想做bfs遍历,我应该怎么做? 上面的部分已经回答。

我想知道我们是否可以将树以完整的二叉树的形式存储到数组中,从而使第i个节点的子代存储在2i +1和2i +2索引处?

enter image description here

对于上面的树,生成的输出是 array([['address','age',None,None],[0.5,15.5,1,1]],dtype = object)

但所需的输出是

array([[['address',None,'age',None,None,None,None],[0.5,-1,15.5,-1,-1,1,1]],dtype = object)

如果在第一个数组中没有值,而在第二个数组中为-1,则表示该节点不存在。因此,这里是地址的正确子代的年龄为2 * 0 + 2 = 2 数组中的索引,分别在数组的2 * 2 +1 =第5个索引和2 * 2 + 2 =第6个索引中分别找到左和右年龄的孩子。

1 个答案:

答案 0 :(得分:0)

像这样吗?

def reformat_tree(clf):
    tree = clf.tree_

    feature_out = np.full((2 ** tree.max_depth), -1, dtype=tree.feature.dtype)
    threshold_out = np.zeros((2 ** tree.max_depth), dtype=tree.threshold.dtype)

    stack = []
    stack.append((0, 0))

    while stack:
        current_node, new_node = stack.pop()

        feature_out[new_node] = tree.feature[current_node]
        threshold_out[new_node] = tree.threshold[current_node]

        left_child = tree.children_left[current_node]
        if left_child >= 0:
            stack.append((left_child, 2 * current_node + 1))

        right_child = tree.children_right[current_node]
        if right_child >= 0:
            stack.append((right_child, 2 * current_node + 2))

    return feature_out, threshold_out

我无法在您的树上对其进行测试,因为您仍然没有办法重现它,但是它应该可以工作。

该函数以所需格式返回特征和阈值。特征值为-1是该节点不存在,如果该节点是叶则为-2

这可以通过遍历树并跟踪当前位置来实现。

相关问题