决策树-查找遍历树时常量预测如何变化

时间:2018-11-26 20:04:21

标签: python machine-learning scikit-learn

假设我具有以下DecisionTreeClassifier模型:

from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_breast_cancer

bunch = load_breast_cancer()

X, y = bunch.data, bunch.target

model = DecisionTreeClassifier(random_state=100)
model.fit(X, y)

我想遍历此树中的每个节点(叶子和决策),并确定遍历树时预测值如何变化。基本上,我想对给定的样本讲出如何确定最终预测(.predict返回的结果)。因此,也许最终可以对样本进行1预测,但是遍历四个节点,并且在每个节点上,其“常数”(在scikit文档中使用的语言)预测从10到{{ 1}}至0

现在还不清楚我如何从1获取该信息,该信息被描述为:

model.tree_.value

在这种情况下,看起来像这样:

 |  value : array of double, shape [node_count, n_outputs, max_n_classes]
 |      Contains the constant prediction value of each node.

有人知道我该怎么做吗?上面43个节点中的每个节点的类预测是否只是每个列表的argmax?那么1,1,1,1,1,1,1,0,0,...,是从上到下的?

1 个答案:

答案 0 :(得分:1)

一种解决方案是直接走到树上的决策路径。 您可以调整this solution,使其像子句一样打印整个决策树。 这是一个快速的改编来解释一个实例:

def tree_path(instance, values, left, right, threshold, features, node, depth):
    spacer = '    ' * depth
    if (threshold[node] != _tree.TREE_UNDEFINED):
        if instance[features[node]] <= threshold[node]:
            path = f'{spacer}{features[node]} ({round(instance[features[node]], 2)}) <= {round(threshold[node], 2)}'
            next_node = left[node]
        else:
            path = f'{spacer}{features[node]} ({round(instance[features[node]], 2)}) > {round(threshold[node], 2)}'
            next_node = right[node]
        return path + '\n' + tree_path(instance, values, left, right, threshold, features, next_node, depth+1)
    else:
        target = values[node]
        for i, v in zip(np.nonzero(target)[1],
                        target[np.nonzero(target)]):
            target_count = int(v)
            return spacer + "==> " + str(round(target[0][0], 2)) + \
                   " ( " + str(target_count) + " examples )"

def get_path_code(tree, feature_names, instance):
    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    features  = [feature_names[i] for i in tree.tree_.feature]
    values = tree.tree_.value
    return tree_path(instance, values, left, right, threshold, features, 0, 0)

# print the decision path of the first intance of a panda dataframe df
print(get_path_code(tree, df.columns, df.iloc[0]))