scikit-learn: iterating over the nodes of a DecisionTreeClassifier

Time: 2014-09-04 09:33:45

Tags: tree scikit-learn

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X = iris.data
y = iris.target

clf = DecisionTreeClassifier()
clf = clf.fit(X, y)

How can I iterate over the nodes of clf? I could not find this in the documentation.

3 answers:

Answer 0: (score: 1)

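What do you want to do with the nodes of clf?

There is an attribute called clf.tree_ that holds the actual decision tree information. It is poorly covered in the user-facing documentation, but you can read the code to get a better idea of what it does.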

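Unfortunately, the actual node array appears to be hidden behind Cython attributes, but you can use the integer indices 0 through clf.tree_.node_count - 1 to index into clf.tree_.feature[i], clf.tree_.threshold[i], and so on (see the documentation in the linked code for more information). If you want to find out which node a sample ends up in, you can use clf.tree_.apply(X) to get the actual integer index of that node.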
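As a minimal sketch of the index-based access described above, the loop below walks every node id and reads the parallel arrays on clf.tree_. It assumes a reasonably recent scikit-learn; it uses the public clf.apply(X) method rather than calling clf.tree_.apply(X) directly, and random_state=0 is added here only to make the run repeatable.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)

tree = clf.tree_
for i in range(tree.node_count):
    # Leaves have the same sentinel value for both children.
    if tree.children_left[i] == tree.children_right[i]:
        print(f"node {i}: leaf with {tree.n_node_samples[i]} training samples")
    else:
        print(f"node {i}: split on feature {tree.feature[i]} "
              f"at threshold {tree.threshold[i]:.3f}")

# Index of the leaf node that each sample ends up in.
leaf_ids = clf.apply(iris.data)
print(leaf_ids[:5])
```

Every entry of leaf_ids is a valid index into the same arrays, so tree.n_node_samples[leaf_ids] gives the size of the leaf each sample falls into.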

Answer 1: (score: 0)

There is now an example showing how to do this in the documentation.

There, they iterate over the tree as follows:

import numpy as np

n_nodes = clf.tree_.node_count
children_left = clf.tree_.children_left
children_right = clf.tree_.children_right
feature = clf.tree_.feature
threshold = clf.tree_.threshold

node_depth = np.zeros(shape=n_nodes, dtype=np.int64)
is_leaves = np.zeros(shape=n_nodes, dtype=bool)
stack = [(0, 0)]  # start with the root node id (0) and its depth (0)
while len(stack) > 0:
    # `pop` ensures each node is only visited once
    node_id, depth = stack.pop()
    node_depth[node_id] = depth

    # If the left and right child of a node is not the same we have a split
    # node
    is_split_node = children_left[node_id] != children_right[node_id]
    # If a split node, append left and right children and depth to `stack`
    # so we can loop through them
    if is_split_node:
        stack.append((children_left[node_id], depth + 1))
        stack.append((children_right[node_id], depth + 1))
    else:
        is_leaves[node_id] = True

print("The binary tree structure has {n} nodes and has "
      "the following tree structure:\n".format(n=n_nodes))
for i in range(n_nodes):
    if is_leaves[i]:
        print("{space}node={node} is a leaf node.".format(
            space=node_depth[i] * "\t", node=i))
    else:
        print("{space}node={node} is a split node: "
              "go to node {left} if X[:, {feature}] <= {threshold} "
              "else to node {right}.".format(
                  space=node_depth[i] * "\t",
                  node=i,
                  left=children_left[i],
                  feature=feature[i],
                  threshold=threshold[i],
                  right=children_right[i]))
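If you only need to loop over the nodes, the stack-based traversal above can also be wrapped in a small generator. This is a sketch, not part of the scikit-learn API: iter_nodes is a hypothetical helper name, and it relies on the same children_left and children_right arrays used in the example.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier


def iter_nodes(tree, node_id=0, depth=0):
    """Yield (node_id, depth, is_leaf) for every node, in pre-order."""
    left = tree.children_left[node_id]
    right = tree.children_right[node_id]
    is_leaf = left == right  # both children share a sentinel at a leaf
    yield node_id, depth, is_leaf
    if not is_leaf:
        yield from iter_nodes(tree, left, depth + 1)
        yield from iter_nodes(tree, right, depth + 1)


iris = load_iris()
clf = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)

visited = list(iter_nodes(clf.tree_))
for node_id, depth, is_leaf in visited:
    kind = "leaf" if is_leaf else "split"
    print("  " * depth + f"node {node_id}: {kind}")
```

The recursion depth equals the depth of the tree, so for very deep trees the explicit stack from the documentation example is the safer choice.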

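Answer 2: (score: 0)

There is a library, pydotplus, that makes it easier to iterate over the nodes (or edges) of a decision tree.

Here is how you can iterate over the nodes of the fitted classifier from the code example: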

from sklearn import tree
import pydotplus

dot_data = tree.export_graphviz(clf,
                                feature_names=iris.feature_names,
                                out_file=None,
                                filled=True,
                                rounded=True)

graph = pydotplus.graph_from_dot_data(dot_data)
for node in graph.get_node_list(): # The iteration happens here!
    print(node.to_string())