web-dev-qa-db-ja.com

scikit learn-決定木の特徴の重要度の計算

Sci-kit Learnのディシジョンツリーで機能の重要性がどのように計算されるかを理解しようとしています。この質問は以前に尋ねられましたが、アルゴリズムが提供する結果を再現することはできません。

例えば:

from StringIO import StringIO

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree.export import export_graphviz
from sklearn.feature_selection import mutual_info_classif

X = [[1,0,0], [0,0,0], [0,0,1], [0,1,0]]

y = [1,0,1,1]

clf = DecisionTreeClassifier()
clf.fit(X, y)

feat_importance = clf.tree_.compute_feature_importances(normalize=False)
print("feat importance = " + str(feat_importance))

out = StringIO()
out = export_graphviz(clf, out_file='test/tree.dot')

機能の重要性の結果:

feat importance = [0.25       0.08333333 0.04166667]

そして、次の決定木を与えます:

decision tree

さて、これは answer 同様の質問に対する重要度が次のように計算されることを示唆しています

formula_a

ここで、Gはノード不純物、この場合はジニ不純物です。私が理解した限りでは、これは不純物の削減です。ただし、機能1の場合、これは次のようになります。

formula_b

この answer は、重要度がノードに到達する確率(そのノードに到達するサンプルの割合で概算される)で重み付けされることを示唆しています。繰り返しますが、機能1の場合、これは次のようになります。

formula_c

両方の式は間違った結果を提供します。機能の重要度はどのように正しく計算されますか?

11
Characeae

機能の重要性は実装に依存するため、scikit-learnのドキュメントを参照する必要があります。

機能の重要性。高いほど、機能はより重要になります。フィーチャの重要度は、そのフィーチャによってもたらされる基準の(正規化された)合計削減として計算されます。ジニの重要性としても知られています

その削減または加重情報ゲインは、次のように定義されます。

重み付き不純物減少方程式は次のとおりです。

N_t / N * (impurity - N_t_R / N_t * right_impurity - N_t_L / N_t * left_impurity)

ここで、Nはサンプルの合計数、N_tは現在のノードのサンプル数、N_t_Lは左の子のサンプル数、N_t_Rは右の子のサンプル数です。

http://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html#sklearn.tree.DecisionTreeClassifier

各機能はケースで1回使用されるため、機能情報は上記の式と等しくなければなりません。

X [2]の場合:

feature_importance = (4 / 4) * (0.375 - (0.75 * 0.444)) = 0.042

X [1]の場合:

feature_importance = (3 / 4) * (0.444 - (2/3 * 0.5)) = 0.083

X [0]の場合:

feature_importance = (2 / 4) * (0.5) = 0.25

15
Seljuk Gülcan

ツリーのさまざまなブランチで単一の機能を使用することができます。その場合、機能の重要性は、不純物の削減における全体的な貢献です。

_feature_importance += number_of_samples_at_parent_where_feature_is_used\*impurity_at_parent-left_child_samples\*impurity_left-right_child_samples\*impurity_right
_

不純物はジニ/エントロピー値です

_normalized_importance = feature_importance/number_of_samples_root_node(total num of samples)
_

上記の例:

_feature_2_importance = 0.375*4-0.444*3-0*1 = 0.16799 , 
normalized = 0.16799/4(total_num_of_samples) = 0.04199
_

_feature_2_が他のブランチで使用された場合、そのような各親ノードでの重要度を計算し、値を合計します。

グラフに表示される切り捨てられた値を使用しているため、計算される機能の重要度とライブラリによって返される重要度に違いがあります。

代わりに、使用される機能、しきい値、不純物、各ノードでのサンプル数などの調査に使用できる分類子の 'tree_'属性を使用して、必要なすべてのデータにアクセスできます。

例:_clf.tree_.feature_は、使用される機能のリストを提供します。負の値は、リーフノードであることを示します。

同様に、_clf.tree_.children_left/right_は、左右の子の_clf.tree_.feature_にインデックスを与えます

上記を使用してツリーをトラバースし、_clf.tree_.impurity & clf.tree_.weighted_n_node_samples_で同じインデックスを使用して、各ノードとその子でgini /エントロピー値とサンプル数を取得します。

_def dt_feature_importance(model,normalize=True):

    left_c = model.tree_.children_left
    right_c = model.tree_.children_right

    impurity = model.tree_.impurity    
    node_samples = model.tree_.weighted_n_node_samples 

    # Initialize the feature importance, those not used remain zero
    feature_importance = np.zeros((model.tree_.n_features,))

    for idx,node in enumerate(model.tree_.feature):
        if node >= 0:
            # Accumulate the feature importance over all the nodes where it's used
            feature_importance[node]+=impurity[idx]*node_samples[idx]- \
                                   impurity[left_c[idx]]*node_samples[left_c[idx]]-\
                                   impurity[right_c[idx]]*node_samples[right_c[idx]]

    # Number of samples at the root node
    feature_importance/=node_samples[0]

    if normalize:
        normalizer = feature_importance.sum()
        if normalizer > 0:
            feature_importance/=normalizer

    return feature_importance
_

この関数は、clf.tree_.compute_feature_importances(normalize=...)によって返される値とまったく同じ値を返します

重要度に基づいて機能をソートするには

_features = clf.tree_.feature[clf.tree_.feature>=0] # Feature number should not be negative, indicates a leaf node
sorted(Zip(features,dt_feature_importance(clf,False)[features]),key=lambda x:x[1],reverse=True)
_
2
bhasuru