使用Python中的Keras进行神经网络中的特征重要性图表

时间:2017-07-27 21:47:01

标签: python neural-network keras

我正在使用python(3.6)anaconda(64位)spyder(3.1.2)。我已经使用keras(2.0.6)设置了一个神经网络模型用于回归问题(一个响应,10个变量)。我想知道如何生成像这样的特征重要性图表:

feature importance chart

def base_model():
    model = Sequential()
    model.add(Dense(200, input_dim=10, kernel_initializer='normal', activation='relu'))
    model.add(Dense(1, kernel_initializer='normal'))
    model.compile(loss='mean_squared_error', optimizer = 'adam')
    return model

clf = KerasRegressor(build_fn=base_model, epochs=100, batch_size=5,verbose=0)
clf.fit(X_train,Y_train)

3 个答案:

答案 0 :(得分:7)

我最近正在寻找这个问题的答案,发现一些对我所做的事情有用的东西,并认为分享会有所帮助。我最终使用了permutation importance中的eli5 package模块。它最容易与scikit学习模型一起使用。幸运的是,Keras提供了wrapper for sequential models。如下面的代码所示,使用起来非常简单。

from keras.wrappers.scikit_learn import KerasClassifier, KerasRegressor
import eli5
from eli5.sklearn import PermutationImportance

def base_model():
    model = Sequential()        
    ...
    return model

X = ...
y = ...

my_model = KerasRegressor(build_fn=base_model, **sk_params)    
my_model.fit(X,y)

perm = PermutationImportance(my_model, random_state=1).fit(X,y)
eli5.show_weights(perm, feature_names = X.columns.tolist())

答案 1 :(得分:5)

目前,Keras没有提供任何功能来提取功能重要性。

您可以查看上一个问题: Keras: Any way to get variable importance?

或相关的GoogleGroup:Feature importance

Spoiler:在GoogleGroup中,有人宣布了一个解决此问题的开源项目..

答案 2 :(得分:4)

这是一个相对较旧的帖子,具有相对较旧的答案,因此,我想提出另一个建议,使用SHAP来确定Keras模型的特征重要性。 SHAP同时支持2d和3d数组,而eli5目前仅支持2d数组(因此,如果模型使用需要3d输入的图层,例如LSTMGRUeli5不起作用)。

这是link的示例,说明SHAP如何绘制Keras模型的功能重要性,但万一它被破坏了,下面提供一些示例代码和图表以及(摘自所述链接):


import shap

# load your data here, e.g. X and y
# create and fit your model here

# load JS visualization code to notebook
shap.initjs()

# explain the model's predictions using SHAP
# (same syntax works for LightGBM, CatBoost, scikit-learn and spark models)
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)

# visualize the first prediction's explanation (use matplotlib=True to avoid Javascript)
shap.force_plot(explainer.expected_value, shap_values[0,:], X.iloc[0,:])

shap.summary_plot(shap_values, X, plot_type="bar")

enter image description here