web-dev-qa-db-ja.com

scikit learnを使用して構築された決定木を探索する方法

私はを使用して決定木を構築しています

clf = tree.DecisionTreeClassifier()
clf = clf.fit(X_train, Y_train)

これはすべて正常に動作します。ただし、決定木を探索するにはどうすればよいですか?

たとえば、X_trainのどのエントリが特定のリーフに表示されるかを確認するにはどうすればよいですか?

13
eleanora

予測方法を使用する必要があります。

ツリーをトレーニングした後、X値をフィードして出力を予測します。

_from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier(random_state=0)
iris = load_iris()
tree = clf.fit(iris.data, iris.target)
tree.predict(iris.data) 
_

出力:

_>>> tree.predict(iris.data)
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
_

ツリー構造の詳細を取得するには、tree_.__getstate__()を使用できます

「ASCIIアート」画像に変換されたツリー構造

_              0  
        _____________
        1           2
               ______________
               3            12
            _______      _______
            4     7      13   16
           ___   ______        _____
           5 6   8    9        14 15
                      _____
                      10 11
_

配列としてのツリー構造。

_In [38]: tree.tree_.__getstate__()['nodes']
Out[38]: 
array([(1, 2, 3, 0.800000011920929, 0.6666666666666667, 150, 150.0),
       (-1, -1, -2, -2.0, 0.0, 50, 50.0),
       (3, 12, 3, 1.75, 0.5, 100, 100.0),
       (4, 7, 2, 4.949999809265137, 0.16803840877914955, 54, 54.0),
       (5, 6, 3, 1.6500000953674316, 0.04079861111111116, 48, 48.0),
       (-1, -1, -2, -2.0, 0.0, 47, 47.0), 
       (-1, -1, -2, -2.0, 0.0, 1, 1.0),
       (8, 9, 3, 1.5499999523162842, 0.4444444444444444, 6, 6.0),
       (-1, -1, -2, -2.0, 0.0, 3, 3.0),
       (10, 11, 2, 5.449999809265137, 0.4444444444444444, 3, 3.0),
       (-1, -1, -2, -2.0, 0.0, 2, 2.0), 
       (-1, -1, -2, -2.0, 0.0, 1, 1.0),
       (13, 16, 2, 4.850000381469727, 0.042533081285444196, 46, 46.0),
       (14, 15, 1, 3.0999999046325684, 0.4444444444444444, 3, 3.0),
       (-1, -1, -2, -2.0, 0.0, 2, 2.0), 
       (-1, -1, -2, -2.0, 0.0, 1, 1.0),
       (-1, -1, -2, -2.0, 0.0, 43, 43.0)], 
      dtype=[('left_child', '<i8'), ('right_child', '<i8'), 
             ('feature', '<i8'), ('threshold', '<f8'), 
             ('impurity', '<f8'), ('n_node_samples', '<i8'), 
             ('weighted_n_node_samples', '<f8')])
_

どこ:

  • 最初のノード[0]はルートノードです。
  • 内部ノードでは、left_childとright_childが正の値を持ち、現在のノードより大きいノードを参照しています。
  • 葉は、左右の子ノードに対して-1の値を持ちます。
  • ノード1、5、6、8、10、11、14、15、16はリーフです。
  • ノード構造は、深さ優先検索アルゴリズムを使用して構築されます。
  • 機能フィールドは、このサンプルのパスを決定するためにノードで使用されたiris.data機能を示します。
  • しきい値は、特徴に基づいて方向を評価するために使用される値を示します。
  • 葉で不純物が0に達する...葉に達するとすべてのサンプルが同じクラスにあるため。
  • n_node_samplesは、各リーフに到達するサンプルの数を示します。

この情報を使用して、スクリプトの分類ルールとしきい値に従うことにより、各サンプルXをリーフまで簡単に追跡できます。さらに、n_node_samplesを使用すると、各ノードが正しい数のサンプルを取得することを確認する単体テストを実行できます。次に、tree.predictの出力を使用して、各リーフを関連するクラスにマップできます。

13
PabTorre

注:これは答えではなく、考えられる解決策のヒントにすぎません。

最近、私のプロジェクトで同様の問題が発生しました。私の目標は、いくつかの特定のサンプルに対応する一連の決定を抽出することです。意思決定チェーンの最後のステップを記録するだけでよいので、あなたの問題は私のサブセットだと思います。

これまでのところ、実行可能な唯一の解決策はPythonでカスタムpredictメソッドを作成するであり、途中で決定を追跡することです。その理由は、scikit-learnが提供するpredictメソッドは(私が知る限り)これをそのままでは実行できないためです。さらに悪いことに、これはC実装のラッパーであり、カスタマイズが非常に困難です。

私はアンバランスなデータセットを扱っており、気になるサンプル(ポジティブなもの)はまれなので、カスタマイズは私の問題には問題ありません。したがって、最初にsklearn predictを使用してそれらを除外し、次にカスタマイズを使用して意思決定チェーンを取得できます。

ただし、データセットが大きい場合、これは機能しない可能性があります。ツリーを解析してPythonで予測すると、Python速度で実行が遅くなり、(簡単に)スケーリングされません。C実装のカスタマイズにフォールバックする必要がある場合があるためです。

5
zaxliu

ドリュー博士の投稿を少し変更しました。
次のコードは、データフレームとフィッティング後の決定木が与えられると、次を返します。

  • rules_list:ルールのリスト
  • values_path:エントリのリスト(パスを通過する各クラスのエントリ)

    import numpy as np  
    import pandas as pd  
    from sklearn.tree import DecisionTreeClassifier 
    
    def get_rules(dtc, df):
        rules_list = []
        values_path = []
        values = dtc.tree_.value
    
        def RevTraverseTree(tree, node, rules, pathValues):
            '''
            Traverase an skl decision tree from a node (presumably a leaf node)
            up to the top, building the decision rules. The rules should be
            input as an empty list, which will be modified in place. The result
            is a nested list of tuples: (feature, direction (left=-1), threshold).  
            The "tree" is a nested list of simplified tree attributes:
            [split feature, split threshold, left node, right node]
            '''
            # now find the node as either a left or right child of something
            # first try to find it as a left node            
    
            try:
                prevnode = tree[2].index(node)           
                leftright = '<='
                pathValues.append(values[prevnode])
            except ValueError:
                # failed, so find it as a right node - if this also causes an exception, something's really f'd up
                prevnode = tree[3].index(node)
                leftright = '>'
                pathValues.append(values[prevnode])
    
            # now let's get the rule that caused prevnode to -> node
            p1 = df.columns[tree[0][prevnode]]    
            p2 = tree[1][prevnode]    
            rules.append(str(p1) + ' ' + leftright + ' ' + str(p2))
    
            # if we've not yet reached the top, go up the tree one more step
            if prevnode != 0:
                RevTraverseTree(tree, prevnode, rules, pathValues)
    
        # get the nodes which are leaves
        leaves = dtc.tree_.children_left == -1
        leaves = np.arange(0,dtc.tree_.node_count)[leaves]
    
        # build a simpler tree as a nested list: [split feature, split threshold, left node, right node]
        thistree = [dtc.tree_.feature.tolist()]
        thistree.append(dtc.tree_.threshold.tolist())
        thistree.append(dtc.tree_.children_left.tolist())
        thistree.append(dtc.tree_.children_right.tolist())
    
        # get the decision rules for each leaf node & apply them
        for (ind,nod) in enumerate(leaves):
    
            # get the decision rules
            rules = []
            pathValues = []
            RevTraverseTree(thistree, nod, rules, pathValues)
    
            pathValues.insert(0, values[nod])      
            pathValues = list(reversed(pathValues))
    
            rules = list(reversed(rules))
    
            rules_list.append(rules)
            values_path.append(pathValues)
    
        return (rules_list, values_path)
    

それは例に従います:

df = pd.read_csv('df.csv')

X = df[df.columns[:-1]]
y = df['classification']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

dtc = DecisionTreeClassifier(max_depth=2)
dtc.fit(X_train, y_train)

適合した決定木は次のツリーを生成しました: 幅2の決定木

この時点で、関数を呼び出すだけです。

get_rules(dtc, df)

これは関数が返すものです:

rules = [  
    ['first <= 63.5', 'first <= 43.5'],  
    ['first <= 63.5', 'first > 43.5'],  
    ['first > 63.5', 'second <= 19.700000762939453'],  
    ['first > 63.5', 'second > 19.700000762939453']
]

values = [
    [array([[ 1568.,  1569.]]), array([[ 636.,  241.]]), array([[ 284.,  57.]])],
    [array([[ 1568.,  1569.]]), array([[ 636.,  241.]]), array([[ 352.,  184.]])],
    [array([[ 1568.,  1569.]]), array([[  932.,  1328.]]), array([[ 645.,  620.]])],
    [array([[ 1568.,  1569.]]), array([[  932.,  1328.]]), array([[ 287.,  708.]])]
]

明らかに、値には、パスごとにリーフ値もあります。

3
Federico Ibba

以下のコードは、上位10の機能のプロットを生成するはずです。

import numpy as np
import matplotlib.pyplot as plt

importances = clf.feature_importances_
std = np.std(clf.feature_importances_,axis=0)
indices = np.argsort(importances)[::-1]

# Print the feature ranking
print("Feature ranking:")

for f in range(10):
    print("%d. feature %d (%f)" % (f + 1, indices[f], importances[indices[f]]))

# Plot the feature importances of the forest
plt.figure()
plt.title("Feature importances")
plt.bar(range(10), importances[indices],
       color="r", yerr=std[indices], align="center")
plt.xticks(range(10), indices)
plt.xlim([-1, 10])
plt.show()

here から取得し、 DecisionTreeClassifier に合うように少し変更しました。

これは、ツリーを探索するのに正確には役立ちませんが、ツリーについては教えてくれます。

3
Charlie Haley

このコードはあなたが望むことを正確に行います。ここで、nX_trainの観測数です。最後に、(n、number_of_leaves)サイズの配列leaf_observationsは、各リーフの観測値を取得するためにX_trainにインデックスを付けるためのブール値を各列に保持します。 leaf_observationsの各列は、リーフのノードIDを持つleavesの要素に対応します。

# get the nodes which are leaves
leaves = clf.tree_.children_left == -1
leaves = np.arange(0,clf.tree_.node_count)[leaves]

# loop through each leaf and figure out the data in it
leaf_observations = np.zeros((n,len(leaves)),dtype=bool)
# build a simpler tree as a nested list: [split feature, split threshold, left node, right node]
thistree = [clf.tree_.feature.tolist()]
thistree.append(clf.tree_.threshold.tolist())
thistree.append(clf.tree_.children_left.tolist())
thistree.append(clf.tree_.children_right.tolist())
# get the decision rules for each leaf node & apply them
for (ind,nod) in enumerate(leaves):
    # get the decision rules in numeric list form
    rules = []
    RevTraverseTree(thistree, nod, rules)
    # convert & apply to the data by sequentially &ing the rules
    thisnode = np.ones(n,dtype=bool)
    for rule in rules:
        if rule[1] == 1:
            thisnode = np.logical_and(thisnode,X_train[:,rule[0]] > rule[2])
        else:
            thisnode = np.logical_and(thisnode,X_train[:,rule[0]] <= rule[2])
    # get the observations that obey all the rules - they are the ones in this leaf node
    leaf_observations[:,ind] = thisnode

これには、ここで定義されているヘルパー関数が必要です。ヘルパー関数は、指定されたノードから開始してツリーを再帰的に走査し、決定規則を構築します。

def RevTraverseTree(tree, node, rules):
    '''
    Traverase an skl decision tree from a node (presumably a leaf node)
    up to the top, building the decision rules. The rules should be
    input as an empty list, which will be modified in place. The result
    is a nested list of tuples: (feature, direction (left=-1), threshold).  
    The "tree" is a nested list of simplified tree attributes:
    [split feature, split threshold, left node, right node]
    '''
    # now find the node as either a left or right child of something
    # first try to find it as a left node
    try:
        prevnode = tree[2].index(node)
        leftright = -1
    except ValueError:
        # failed, so find it as a right node - if this also causes an exception, something's really f'd up
        prevnode = tree[3].index(node)
        leftright = 1
    # now let's get the rule that caused prevnode to -> node
    rules.append((tree[0][prevnode],leftright,tree[1][prevnode]))
    # if we've not yet reached the top, go up the tree one more step
    if prevnode != 0:
        RevTraverseTree(tree, prevnode, rules)
3
Dr. Drew

簡単なオプションは、訓練された決定木の適用方法を使用することだと思います。ツリーをトレーニングし、traindataを適用して、返されたインデックスからルックアップテーブルを作成します。

import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris

iris = load_iris()
clf = DecisionTreeClassifier()
clf = clf.fit(iris.data, iris.target)

# apply training data to decision tree
leaf_indices = clf.apply(iris.data)
lookup = {}

# build lookup table
for i, leaf_index in enumerate(leaf_indices):
    try:
        lookup[leaf_index].append(iris.data[i])
    except KeyError:
        lookup[leaf_index] = []
        lookup[leaf_index].append(iris.data[i])

# test
unkown_sample = [[4., 3.1, 6.1, 1.2]]
index = clf.apply(unkown_sample)
print(lookup[index[0]])
1
maltesar

DecisionTreeをgraphvizの.dotファイル[1]にダンプしてから、graph_tool [2]でロードしてみましたか。

import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
from graph_tool.all import *

iris = load_iris()
clf = DecisionTreeClassifier()
clf = clf.fit(iris.data, iris.target)

tree.export_graphviz(clf,out_file='tree.dot')

#load graph with graph_tool and explore structure as you please
g = load_graph('tree.dot')

for v in g.vertices():
   for e in v.out_edges():
       print(e)
   for w in v.out_neighbours():
       print(w)

[1] http://scikit-learn.org/stable/modules/generated/sklearn.tree.export_graphviz.html

[2] https://graph-tool.skewed.de/

0
roj4s