Print DecisionTree feature names instead of column numbers

Date: 2018-02-20 13:31:45

Tags: python machine-learning scikit-learn

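How can I print the column names in the decision tree output, e.g. Feature 1, Feature 2, Feature 3, Feature 4, Feature 5 or Feature 6, instead of -2? The answer should scale, e.g. to someone who has 500 or more columns.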
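On the scalability requirement: a tree stores each split feature as a column position, so names never have to be typed by hand. They can be generated once and looked up by position, whatever the width. A minimal sketch, assuming a synthetic 500-feature dataset and the naming pattern "Feature N" (both chosen only for illustration):

```python
import pandas as pd
from sklearn.datasets import make_classification

# Illustrative wide dataset: 500 features (sizes are assumptions for the sketch)
X, y = make_classification(n_samples=200, n_features=500,
                           n_informative=10, random_state=0)

# Generate "Feature 1" ... "Feature 500" instead of writing each column out
names = [f"Feature {i + 1}" for i in range(X.shape[1])]
X_train = pd.DataFrame(X, columns=names)

# A column position (as stored in tree_.feature) maps straight back to a name
print(X_train.columns[0])    # Feature 1
print(X_train.columns[499])  # Feature 500
```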

import numpy as np
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier, _tree

X, y = make_classification(n_samples=1000,
                           n_features=6,
                           n_informative=3,
                           n_classes=2,
                           random_state=0,
                           shuffle=False)

# Creating a DataFrame
df = pd.DataFrame({'Feature 1': X[:, 0],
                   'Feature 2': X[:, 1],
                   'Feature 3': X[:, 2],
                   'Feature 4': X[:, 3],
                   'Feature 5': X[:, 4],
                   'Feature 6': X[:, 5],
                   'Class': y})

y_train = df['Class']
X_train = df.drop('Class', axis=1)

# Assumed: the question never shows how dt is built; a default decision tree is fitted here
dt = DecisionTreeClassifier(random_state=0)
dt.fit(X_train, y_train)

# Using those arrays, we can parse the tree structure:

n_nodes = dt.tree_.node_count
children_left = dt.tree_.children_left
children_right = dt.tree_.children_right
feature = dt.tree_.feature
threshold = dt.tree_.threshold

new_X = np.array(X_train)

node_indicator = dt.decision_path(new_X)

# Similarly, we can also have the leaves ids reached by each sample.

leave_id = dt.apply(new_X)

# Now, it's possible to get the tests that were used to predict a sample or
# a group of samples. First, let's make it for the sample.

sample_id = 703
node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                    node_indicator.indptr[sample_id + 1]]

print('Rules used to predict sample %s: ' % sample_id)
for node_id in node_index:
    # As written, this skips every node except the leaf reached by the sample
    if leave_id[sample_id] != node_id:
        continue

    if (new_X[sample_id, feature[node_id]] <= threshold[node_id]):
        threshold_sign = "<="
    else:
        threshold_sign = ">"

    print("decision id node %s : (For sample number [%s, %s] (= %s) %s %s)"
          % (node_id,
             sample_id,
             feature[node_id],
             new_X[sample_id, feature[node_id]],
             threshold_sign,
             threshold[node_id]))
  

Rules used to predict sample 703: 
decision id node 11 : (For sample number [703, -2] (= -0.210092480919) > -2.0)

1 answer:

Answer 0 (score: 0)

Just replace feature[node_id] with X_train.columns[feature[node_id]], like this:

print("decision id node %s : (For sample number [%s, %s] (= %s) %s %s)"
      % (node_id,
         sample_id,
         X_train.columns[feature[node_id]],
         new_X[sample_id, feature[node_id]],
         threshold_sign,
         threshold[node_id]))
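The substitution changes only how the index is displayed. In the question's loop the `!=` test skips every split node and prints just the leaf, whose feature entry is the -2 sentinel, so `X_train.columns[-2]` would silently print "Feature 5" through negative indexing. Below is a minimal sketch that follows the pattern of scikit-learn's "Understanding the decision tree structure" example, which skips the node *equal* to the leaf, and maps each index through `X_train.columns`. The tree construction and the helper name `decision_rules` are hypothetical, for illustration only:

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=1000, n_features=6, n_informative=3,
                           n_classes=2, random_state=0, shuffle=False)
# Generated names work the same for 6 or 500+ columns
X_train = pd.DataFrame(X, columns=[f"Feature {i + 1}" for i in range(X.shape[1])])

# Assumed setup: the question does not show how dt was built
dt = DecisionTreeClassifier(random_state=0).fit(X_train, y)

feature = dt.tree_.feature
threshold = dt.tree_.threshold
names = X_train.columns
new_X = X_train.to_numpy()


def decision_rules(sample_id):
    """Hypothetical helper: one rule string per split node on the sample's path."""
    row = X_train.iloc[[sample_id]]
    path = dt.decision_path(row)  # sparse (1, n_nodes) indicator matrix
    leaf_id = dt.apply(row)[0]
    rules = []
    for node_id in path.indices:
        # Skip the leaf: its feature/threshold hold the -2 sentinel, not a split
        if node_id == leaf_id:
            continue
        value = new_X[sample_id, feature[node_id]]
        sign = "<=" if value <= threshold[node_id] else ">"
        rules.append(f"node {node_id}: {names[feature[node_id]]} "
                     f"(= {value}) {sign} {threshold[node_id]}")
    return rules


for rule in decision_rules(703):
    print(rule)
```

Because the lookup is a positional index into `X_train.columns`, its cost does not depend on how many columns the frame has.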