-
Notifications
You must be signed in to change notification settings - Fork 1
/
decision_tree.py
executable file
·83 lines (67 loc) · 3.36 KB
/
decision_tree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from sklearn.model_selection import GridSearchCV
from sklearn import tree
from sklearn.tree._tree import TREE_LEAF
def learn_local_decision_tree(Z, Yb, weights, class_values, multi_label=False, one_vs_rest=False, cv=5,
                              prune_tree=False):
    """Fit a surrogate decision tree on the local neighborhood (Z, Yb).

    Parameters
    ----------
    Z : array-like of shape (n_samples, n_features)
        Neighborhood instances used as training data.
    Yb : array-like
        Labels for Z (e.g. black-box predictions to be mimicked).
    weights : array-like of shape (n_samples,)
        Per-sample weights forwarded to ``fit``.
    class_values : sequence
        Distinct class values; only ``len(class_values)`` is used, to
        choose between binary and macro F1 scoring.
    multi_label : bool, default=False
        Whether Yb is a multi-label indicator.
    one_vs_rest : bool, default=False
        With ``multi_label``, whether each label is treated as a binary
        one-vs-rest problem.
    cv : int, default=5
        Number of cross-validation folds for the grid search.
    prune_tree : bool, default=False
        If True, grid-search hyper-parameters and collapse duplicate leaves.

    Returns
    -------
    sklearn.tree.DecisionTreeClassifier
        The fitted (and possibly pruned) tree.
    """
    dt = tree.DecisionTreeClassifier()
    if not prune_tree:
        dt.fit(Z, Yb, sample_weight=weights)
        return dt

    try:
        param_list = {'min_samples_split': [0.002, 0.01, 0.05, 0.1, 0.2],
                      'min_samples_leaf': [0.001, 0.01, 0.05, 0.1, 0.2],
                      'max_depth': [None, 2, 4, 6, 8, 10, 12, 16],
                      }
        # Choose a scoring metric consistent with the label structure.
        # (`not multi_label or (multi_label and one_vs_rest)` simplifies
        # to `not multi_label or one_vs_rest`.)
        if not multi_label or one_vs_rest:
            if len(class_values) == 2 or (multi_label and one_vs_rest):
                scoring = 'f1'
            else:
                scoring = 'f1_macro'
        else:
            scoring = 'f1_samples'
        # BUG FIX: the original passed iid=False, a parameter deprecated in
        # scikit-learn 0.22 and removed in 0.24. On current sklearn that
        # raises TypeError at construction time, which the except clause
        # below (ValueError only) would not catch.
        dt_search = GridSearchCV(dt, param_grid=param_list, scoring=scoring, cv=cv, n_jobs=-1)
        dt_search.fit(Z, Yb, sample_weight=weights)
        dt = dt_search.best_estimator_
        prune_duplicate_leaves(dt)
    except ValueError:
        # Grid search can fail (e.g. scoring incompatible with the labels);
        # fall back to a plain unpruned fit rather than crashing.
        dt.fit(Z, Yb, sample_weight=weights)
    return dt
def show_tree(clf, feature_names=None, class_names=None):
    """Export the fitted tree `clf` in Graphviz DOT format to 'tree.dot'.

    Parameters
    ----------
    clf : fitted sklearn.tree.DecisionTreeClassifier
        The tree to export.
    feature_names : sequence of str, optional
        Names used for the split features in the rendered tree.
    class_names : sequence of str, optional
        Names used for the predicted classes.

    Returns
    -------
    None
    """
    # NOTE: when `out_file` is a filename, export_graphviz writes the file
    # and returns None — the original assigned that None to an unused
    # `dot_data` variable (and kept a commented-out graphviz rendering
    # snippet); both removed here. Behavior is unchanged: 'tree.dot' is
    # still written.
    tree.export_graphviz(clf, out_file='tree.dot', feature_names=feature_names,
                         class_names=class_names, filled=True)
    return
def is_leaf(inner_tree, index):
    """Return True if node `index` of `inner_tree` has no children."""
    left_child = inner_tree.children_left[index]
    right_child = inner_tree.children_right[index]
    return left_child == TREE_LEAF and right_child == TREE_LEAF
def prune_index(inner_tree, decisions, index=0):
    """
    Recursively collapse redundant sibling leaves, bottom-up.

    The traversal prunes from the leaves toward the root: a node can only
    become prunable after its children have been collapsed, so descending
    first is essential. Call prune_duplicate_leaves instead of using this
    directly.
    """
    left = inner_tree.children_left[index]
    right = inner_tree.children_right[index]

    # Descend into internal children first (post-order). Recursion only
    # mutates entries at the child indices, never this node's pointers,
    # so `left` and `right` remain valid afterwards.
    if not is_leaf(inner_tree, left):
        prune_index(inner_tree, decisions, left)
    if not is_leaf(inner_tree, right):
        prune_index(inner_tree, decisions, right)

    # After pruning below, both children may now be leaves. If they also
    # agree with this node's decision, the split is redundant: detach the
    # children so this node becomes a leaf itself.
    both_leaves = is_leaf(inner_tree, left) and is_leaf(inner_tree, right)
    if both_leaves and decisions[index] == decisions[left] == decisions[right]:
        inner_tree.children_left[index] = TREE_LEAF
        inner_tree.children_right[index] = TREE_LEAF
def prune_duplicate_leaves(dt):
    """Prune `dt` in place, merging sibling leaves with identical decisions."""
    # Majority class per node: argmax over the class-count axis of the
    # (n_nodes, n_outputs, n_classes) `value` array, one int per node.
    node_decisions = dt.tree_.value.argmax(axis=2).ravel().tolist()
    prune_index(dt.tree_, node_decisions)