私はを使用して決定木を構築しています
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X_train, Y_train)
これはすべて正常に動作します。ただし、決定木を探索するにはどうすればよいですか?
たとえば、X_trainのどのエントリが特定のリーフに表示されるかを確認するにはどうすればよいですか?
予測方法を使用する必要があります。
ツリーをトレーニングした後、X値をフィードして出力を予測します。
_from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier(random_state=0)
iris = load_iris()
tree = clf.fit(iris.data, iris.target)
tree.predict(iris.data)
_
出力:
_>>> tree.predict(iris.data)
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
_
ツリー構造の詳細を取得するには、tree_.__getstate__()
を使用できます
「ASCIIアート」画像に変換されたツリー構造
_ 0
_____________
1 2
______________
3 12
_______ _______
4 7 13 16
___ ______ _____
5 6 8 9 14 15
_____
10 11
_
配列としてのツリー構造。
_In [38]: tree.tree_.__getstate__()['nodes']
Out[38]:
array([(1, 2, 3, 0.800000011920929, 0.6666666666666667, 150, 150.0),
(-1, -1, -2, -2.0, 0.0, 50, 50.0),
(3, 12, 3, 1.75, 0.5, 100, 100.0),
(4, 7, 2, 4.949999809265137, 0.16803840877914955, 54, 54.0),
(5, 6, 3, 1.6500000953674316, 0.04079861111111116, 48, 48.0),
(-1, -1, -2, -2.0, 0.0, 47, 47.0),
(-1, -1, -2, -2.0, 0.0, 1, 1.0),
(8, 9, 3, 1.5499999523162842, 0.4444444444444444, 6, 6.0),
(-1, -1, -2, -2.0, 0.0, 3, 3.0),
(10, 11, 2, 5.449999809265137, 0.4444444444444444, 3, 3.0),
(-1, -1, -2, -2.0, 0.0, 2, 2.0),
(-1, -1, -2, -2.0, 0.0, 1, 1.0),
(13, 16, 2, 4.850000381469727, 0.042533081285444196, 46, 46.0),
(14, 15, 1, 3.0999999046325684, 0.4444444444444444, 3, 3.0),
(-1, -1, -2, -2.0, 0.0, 2, 2.0),
(-1, -1, -2, -2.0, 0.0, 1, 1.0),
(-1, -1, -2, -2.0, 0.0, 43, 43.0)],
dtype=[('left_child', '<i8'), ('right_child', '<i8'),
('feature', '<i8'), ('threshold', '<f8'),
('impurity', '<f8'), ('n_node_samples', '<i8'),
('weighted_n_node_samples', '<f8')])
_
どこ:
この情報を使用して、スクリプトの分類ルールとしきい値に従うことにより、各サンプルXをリーフまで簡単に追跡できます。さらに、n_node_samplesを使用すると、各ノードが正しい数のサンプルを取得することを確認する単体テストを実行できます。次に、tree.predictの出力を使用して、各リーフを関連するクラスにマップできます。
注:これは答えではなく、考えられる解決策のヒントにすぎません。
最近、私のプロジェクトで同様の問題が発生しました。私の目標は、いくつかの特定のサンプルに対応する一連の決定を抽出することです。意思決定チェーンの最後のステップを記録するだけでよいので、あなたの問題は私のサブセットだと思います。
これまでのところ、実行可能な唯一の解決策はPythonでカスタムpredict
メソッドを作成するであり、途中で決定を追跡することです。その理由は、scikit-learnが提供するpredict
メソッドは(私が知る限り)これをそのままでは実行できないためです。さらに悪いことに、これはC実装のラッパーであり、カスタマイズが非常に困難です。
私はアンバランスなデータセットを扱っており、気になるサンプル(ポジティブなもの)はまれなので、カスタマイズは私の問題には問題ありません。したがって、最初にsklearn predict
を使用してそれらを除外し、次にカスタマイズを使用して意思決定チェーンを取得できます。
ただし、データセットが大きい場合、これは機能しない可能性があります。ツリーを解析してPythonで予測すると、Python速度で実行が遅くなり、(簡単に)スケーリングされません。C実装のカスタマイズにフォールバックする必要がある場合があるためです。
ドリュー博士の投稿を少し変更しました。
次のコードは、データフレームとフィッティング後の決定木が与えられると、次を返します。
values_path:エントリのリスト(パスを通過する各クラスのエントリ)
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
def get_rules(dtc, df):
rules_list = []
values_path = []
values = dtc.tree_.value
def RevTraverseTree(tree, node, rules, pathValues):
'''
Traverase an skl decision tree from a node (presumably a leaf node)
up to the top, building the decision rules. The rules should be
input as an empty list, which will be modified in place. The result
is a nested list of tuples: (feature, direction (left=-1), threshold).
The "tree" is a nested list of simplified tree attributes:
[split feature, split threshold, left node, right node]
'''
# now find the node as either a left or right child of something
# first try to find it as a left node
try:
prevnode = tree[2].index(node)
leftright = '<='
pathValues.append(values[prevnode])
except ValueError:
# failed, so find it as a right node - if this also causes an exception, something's really f'd up
prevnode = tree[3].index(node)
leftright = '>'
pathValues.append(values[prevnode])
# now let's get the rule that caused prevnode to -> node
p1 = df.columns[tree[0][prevnode]]
p2 = tree[1][prevnode]
rules.append(str(p1) + ' ' + leftright + ' ' + str(p2))
# if we've not yet reached the top, go up the tree one more step
if prevnode != 0:
RevTraverseTree(tree, prevnode, rules, pathValues)
# get the nodes which are leaves
leaves = dtc.tree_.children_left == -1
leaves = np.arange(0,dtc.tree_.node_count)[leaves]
# build a simpler tree as a nested list: [split feature, split threshold, left node, right node]
thistree = [dtc.tree_.feature.tolist()]
thistree.append(dtc.tree_.threshold.tolist())
thistree.append(dtc.tree_.children_left.tolist())
thistree.append(dtc.tree_.children_right.tolist())
# get the decision rules for each leaf node & apply them
for (ind,nod) in enumerate(leaves):
# get the decision rules
rules = []
pathValues = []
RevTraverseTree(thistree, nod, rules, pathValues)
pathValues.insert(0, values[nod])
pathValues = list(reversed(pathValues))
rules = list(reversed(rules))
rules_list.append(rules)
values_path.append(pathValues)
return (rules_list, values_path)
それは例に従います:
df = pd.read_csv('df.csv')
X = df[df.columns[:-1]]
y = df['classification']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
dtc = DecisionTreeClassifier(max_depth=2)
dtc.fit(X_train, y_train)
適合した決定木は次のツリーを生成しました: 幅2の決定木
この時点で、関数を呼び出すだけです。
get_rules(dtc, df)
これは関数が返すものです:
rules = [
['first <= 63.5', 'first <= 43.5'],
['first <= 63.5', 'first > 43.5'],
['first > 63.5', 'second <= 19.700000762939453'],
['first > 63.5', 'second > 19.700000762939453']
]
values = [
[array([[ 1568., 1569.]]), array([[ 636., 241.]]), array([[ 284., 57.]])],
[array([[ 1568., 1569.]]), array([[ 636., 241.]]), array([[ 352., 184.]])],
[array([[ 1568., 1569.]]), array([[ 932., 1328.]]), array([[ 645., 620.]])],
[array([[ 1568., 1569.]]), array([[ 932., 1328.]]), array([[ 287., 708.]])]
]
明らかに、値には、パスごとにリーフ値もあります。
以下のコードは、上位10の機能のプロットを生成するはずです。
import numpy as np
import matplotlib.pyplot as plt
importances = clf.feature_importances_
std = np.std(clf.feature_importances_,axis=0)
indices = np.argsort(importances)[::-1]
# Print the feature ranking
print("Feature ranking:")
for f in range(10):
print("%d. feature %d (%f)" % (f + 1, indices[f], importances[indices[f]]))
# Plot the feature importances of the forest
plt.figure()
plt.title("Feature importances")
plt.bar(range(10), importances[indices],
color="r", yerr=std[indices], align="center")
plt.xticks(range(10), indices)
plt.xlim([-1, 10])
plt.show()
here から取得し、 DecisionTreeClassifier に合うように少し変更しました。
これは、ツリーを探索するのに正確には役立ちませんが、ツリーについては教えてくれます。
このコードはあなたが望むことを正確に行います。ここで、n
はX_train
の観測数です。最後に、(n、number_of_leaves)サイズの配列leaf_observations
は、各リーフの観測値を取得するためにX_train
にインデックスを付けるためのブール値を各列に保持します。 leaf_observations
の各列は、リーフのノードIDを持つleaves
の要素に対応します。
# get the nodes which are leaves
leaves = clf.tree_.children_left == -1
leaves = np.arange(0,clf.tree_.node_count)[leaves]
# loop through each leaf and figure out the data in it
leaf_observations = np.zeros((n,len(leaves)),dtype=bool)
# build a simpler tree as a nested list: [split feature, split threshold, left node, right node]
thistree = [clf.tree_.feature.tolist()]
thistree.append(clf.tree_.threshold.tolist())
thistree.append(clf.tree_.children_left.tolist())
thistree.append(clf.tree_.children_right.tolist())
# get the decision rules for each leaf node & apply them
for (ind,nod) in enumerate(leaves):
# get the decision rules in numeric list form
rules = []
RevTraverseTree(thistree, nod, rules)
# convert & apply to the data by sequentially &ing the rules
thisnode = np.ones(n,dtype=bool)
for rule in rules:
if rule[1] == 1:
thisnode = np.logical_and(thisnode,X_train[:,rule[0]] > rule[2])
else:
thisnode = np.logical_and(thisnode,X_train[:,rule[0]] <= rule[2])
# get the observations that obey all the rules - they are the ones in this leaf node
leaf_observations[:,ind] = thisnode
これには、ここで定義されているヘルパー関数が必要です。ヘルパー関数は、指定されたノードから開始してツリーを再帰的に走査し、決定規則を構築します。
def RevTraverseTree(tree, node, rules):
'''
Traverase an skl decision tree from a node (presumably a leaf node)
up to the top, building the decision rules. The rules should be
input as an empty list, which will be modified in place. The result
is a nested list of tuples: (feature, direction (left=-1), threshold).
The "tree" is a nested list of simplified tree attributes:
[split feature, split threshold, left node, right node]
'''
# now find the node as either a left or right child of something
# first try to find it as a left node
try:
prevnode = tree[2].index(node)
leftright = -1
except ValueError:
# failed, so find it as a right node - if this also causes an exception, something's really f'd up
prevnode = tree[3].index(node)
leftright = 1
# now let's get the rule that caused prevnode to -> node
rules.append((tree[0][prevnode],leftright,tree[1][prevnode]))
# if we've not yet reached the top, go up the tree one more step
if prevnode != 0:
RevTraverseTree(tree, prevnode, rules)
簡単なオプションは、訓練された決定木の適用方法を使用することだと思います。ツリーをトレーニングし、traindataを適用して、返されたインデックスからルックアップテーブルを作成します。
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
iris = load_iris()
clf = DecisionTreeClassifier()
clf = clf.fit(iris.data, iris.target)
# apply training data to decision tree
leaf_indices = clf.apply(iris.data)
lookup = {}
# build lookup table
for i, leaf_index in enumerate(leaf_indices):
try:
lookup[leaf_index].append(iris.data[i])
except KeyError:
lookup[leaf_index] = []
lookup[leaf_index].append(iris.data[i])
# test
unkown_sample = [[4., 3.1, 6.1, 1.2]]
index = clf.apply(unkown_sample)
print(lookup[index[0]])
DecisionTreeをgraphvizの.dotファイル[1]にダンプしてから、graph_tool [2]でロードしてみましたか。
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
from graph_tool.all import *
iris = load_iris()
clf = DecisionTreeClassifier()
clf = clf.fit(iris.data, iris.target)
tree.export_graphviz(clf,out_file='tree.dot')
#load graph with graph_tool and explore structure as you please
g = load_graph('tree.dot')
for v in g.vertices():
for e in v.out_edges():
print(e)
for w in v.out_neighbours():
print(w)
[1] http://scikit-learn.org/stable/modules/generated/sklearn.tree.export_graphviz.html