Tôi có thể trích xuất các quy tắc quyết định cơ bản (hoặc 'đường dẫn quyết định') từ cây được đào tạo trong cây quyết định dưới dạng danh sách văn bản không?
Cái gì đó như:
if A>0.4 then if B<0.2 then if C>0.8 then class='X'
Cảm ơn bạn đã giúp đỡ.
Tôi có thể trích xuất các quy tắc quyết định cơ bản (hoặc 'đường dẫn quyết định') từ cây được đào tạo trong cây quyết định dưới dạng danh sách văn bản không?
Cái gì đó như:
if A>0.4 then if B<0.2 then if C>0.8 then class='X'
Cảm ơn bạn đã giúp đỡ.
Câu trả lời:
Tôi tin rằng câu trả lời này đúng hơn các câu trả lời khác ở đây:
from sklearn.tree import _tree
def tree_to_code(tree, feature_names):
tree_ = tree.tree_
feature_name = [
feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
for i in tree_.feature
]
print "def tree({}):".format(", ".join(feature_names))
def recurse(node, depth):
indent = " " * depth
if tree_.feature[node] != _tree.TREE_UNDEFINED:
name = feature_name[node]
threshold = tree_.threshold[node]
print "{}if {} <= {}:".format(indent, name, threshold)
recurse(tree_.children_left[node], depth + 1)
print "{}else: # if {} > {}".format(indent, name, threshold)
recurse(tree_.children_right[node], depth + 1)
else:
print "{}return {}".format(indent, tree_.value[node])
recurse(0, 1)
Điều này in ra một hàm Python hợp lệ. Đây là một đầu ra ví dụ cho một cây đang cố gắng trả về đầu vào của nó, một số trong khoảng từ 0 đến 10.
def tree(f0):
if f0 <= 6.0:
if f0 <= 1.5:
return [[ 0.]]
else: # if f0 > 1.5
if f0 <= 4.5:
if f0 <= 3.5:
return [[ 3.]]
else: # if f0 > 3.5
return [[ 4.]]
else: # if f0 > 4.5
return [[ 5.]]
else: # if f0 > 6.0
if f0 <= 8.5:
if f0 <= 7.5:
return [[ 7.]]
else: # if f0 > 7.5
return [[ 8.]]
else: # if f0 > 8.5
return [[ 9.]]
Dưới đây là một số vấp ngã mà tôi thấy trong các câu trả lời khác:
tree_.threshold == -2để quyết định xem một nút có phải là một chiếc lá không là một ý tưởng tốt. Điều gì xảy ra nếu đó là nút quyết định thực sự với ngưỡng -2? Thay vào đó, bạn nên nhìn vào tree.featurehoặc tree.children_*.features = [feature_names[i] for i in tree_.feature]gặp sự cố với phiên bản sklearn của tôi, vì một số giá trị tree.tree_.featurelà -2 (đặc biệt cho các nút lá).print "{}return {}".format(indent, tree_.value[node])nên được thay đổi thành print "{}return {}".format(indent, np.argmax(tree_.value[node][0]))hàm để trả về chỉ mục lớp.
RandomForestClassifier.estimators_, nhưng tôi không thể tìm ra cách kết hợp các kết quả của người ước tính.
print "bla"=>print("bla")
Tôi đã tạo chức năng của riêng mình để trích xuất các quy tắc từ các cây quyết định được tạo bởi sklearn:
import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier
# dummy data:
df = pd.DataFrame({'col1':[0,1,2,3],'col2':[3,4,5,6],'dv':[0,1,0,1]})
# create decision tree
dt = DecisionTreeClassifier(max_depth=5, min_samples_leaf=1)
dt.fit(df.ix[:,:2], df.dv)
Hàm này trước tiên bắt đầu bằng các nút (được xác định bởi -1 trong mảng con) và sau đó tìm đệ quy cha mẹ. Tôi gọi đây là 'dòng dõi' của nút. Trên đường đi, tôi lấy các giá trị tôi cần để tạo logic logic if / then / other:
def get_lineage(tree, feature_names):
left = tree.tree_.children_left
right = tree.tree_.children_right
threshold = tree.tree_.threshold
features = [feature_names[i] for i in tree.tree_.feature]
# get ids of child nodes
idx = np.argwhere(left == -1)[:,0]
def recurse(left, right, child, lineage=None):
if lineage is None:
lineage = [child]
if child in left:
parent = np.where(left == child)[0].item()
split = 'l'
else:
parent = np.where(right == child)[0].item()
split = 'r'
lineage.append((parent, split, threshold[parent], features[parent]))
if parent == 0:
lineage.reverse()
return lineage
else:
return recurse(left, right, parent, lineage)
for child in idx:
for node in recurse(left, right, child):
print node
Các bộ dữ liệu bên dưới chứa mọi thứ tôi cần để tạo ra các câu lệnh if / then / other. Tôi không thích sử dụng docác khối trong SAS, đó là lý do tại sao tôi tạo logic mô tả toàn bộ đường dẫn của một nút. Số nguyên đơn sau các bộ dữ liệu là ID của nút đầu cuối trong một đường dẫn. Tất cả các bộ trước đó kết hợp để tạo nút đó.
In [1]: get_lineage(dt, df.columns)
(0, 'l', 0.5, 'col1')
1
(0, 'r', 0.5, 'col1')
(2, 'l', 4.5, 'col2')
3
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'l', 2.5, 'col1')
5
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'r', 2.5, 'col1')
6

(0.5, 2.5]. Các cây được thực hiện với phân vùng đệ quy. Không có gì ngăn cản một biến được chọn nhiều lần.
Tôi đã sửa đổi mã được gửi bởi Zelazny7 để in một số mã giả:
def get_code(tree, feature_names):
left = tree.tree_.children_left
right = tree.tree_.children_right
threshold = tree.tree_.threshold
features = [feature_names[i] for i in tree.tree_.feature]
value = tree.tree_.value
def recurse(left, right, threshold, features, node):
if (threshold[node] != -2):
print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
if left[node] != -1:
recurse (left, right, threshold, features,left[node])
print "} else {"
if right[node] != -1:
recurse (left, right, threshold, features,right[node])
print "}"
else:
print "return " + str(value[node])
recurse(left, right, threshold, features, 0)
nếu bạn gọi get_code(dt, df.columns)cùng một ví dụ, bạn sẽ nhận được:
if ( col1 <= 0.5 ) {
return [[ 1. 0.]]
} else {
if ( col2 <= 4.5 ) {
return [[ 0. 1.]]
} else {
if ( col1 <= 2.5 ) {
return [[ 1. 0.]]
} else {
return [[ 0. 1.]]
}
}
}
(threshold[node] != -2)thành ( left[node] != -1)(tương tự như phương pháp bên dưới để lấy id của các nút con)
Scikit learn đã giới thiệu một phương pháp mới tuyệt vời được gọi là export_texttrong phiên bản 0.21 (tháng 5 năm 2019) để trích xuất các quy tắc từ một cái cây. Tài liệu ở đây . Không còn cần thiết để tạo một chức năng tùy chỉnh.
Khi bạn đã phù hợp với mô hình của mình, bạn chỉ cần hai dòng mã. Đầu tiên, nhập khẩu export_text:
from sklearn.tree.export import export_text
Thứ hai, tạo một đối tượng sẽ chứa các quy tắc của bạn. Để làm cho các quy tắc trông dễ đọc hơn, hãy sử dụng feature_namesđối số và chuyển danh sách các tên tính năng của bạn. Ví dụ: nếu mô hình của bạn được gọi modelvà các tính năng của bạn được đặt tên trong khung dữ liệu được gọi X_train, bạn có thể tạo một đối tượng được gọi là tree_rules:
tree_rules = export_text(model, feature_names=list(X_train))
Sau đó chỉ cần in hoặc lưu tree_rules. Đầu ra của bạn sẽ trông như thế này:
|--- Age <= 0.63
| |--- EstimatedSalary <= 0.61
| | |--- Age <= -0.16
| | | |--- class: 0
| | |--- Age > -0.16
| | | |--- EstimatedSalary <= -0.06
| | | | |--- class: 0
| | | |--- EstimatedSalary > -0.06
| | | | |--- EstimatedSalary <= 0.40
| | | | | |--- EstimatedSalary <= 0.03
| | | | | | |--- class: 1
Có một DecisionTreeClassifierphương thức mới decision_path, trong bản phát hành 0.18.0 . Các nhà phát triển cung cấp một (cũng như các tài liệu) rộng hương .
Phần đầu tiên của mã trong hướng dẫn in cấu trúc cây có vẻ ổn. Tuy nhiên, tôi đã sửa đổi mã trong phần thứ hai để thẩm vấn một mẫu. Những thay đổi của tôi được biểu thị bằng# <--
Chỉnh sửa Các thay đổi được đánh dấu # <--trong mã bên dưới đã được cập nhật trong liên kết hướng dẫn sau khi các lỗi được chỉ ra trong các yêu cầu kéo # 8653 và # 10951 . Bây giờ dễ dàng hơn nhiều để làm theo.
sample_id = 0
node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
node_indicator.indptr[sample_id + 1]]
print('Rules used to predict sample %s: ' % sample_id)
for node_id in node_index:
if leave_id[sample_id] == node_id: # <-- changed != to ==
#continue # <-- comment out
print("leaf node {} reached, no decision here".format(leave_id[sample_id])) # <--
else: # < -- added else to iterate through decision nodes
if (X_test[sample_id, feature[node_id]] <= threshold[node_id]):
threshold_sign = "<="
else:
threshold_sign = ">"
print("decision id node %s : (X[%s, %s] (= %s) %s %s)"
% (node_id,
sample_id,
feature[node_id],
X_test[sample_id, feature[node_id]], # <-- changed i to sample_id
threshold_sign,
threshold[node_id]))
Rules used to predict sample 0:
decision id node 0 : (X[0, 3] (= 2.4) > 0.800000011921)
decision id node 2 : (X[0, 2] (= 5.1) > 4.94999980927)
leaf node 4 reached, no decision here
Thay đổi sample_idđể xem các đường dẫn quyết định cho các mẫu khác. Tôi đã không hỏi các nhà phát triển về những thay đổi này, có vẻ trực quan hơn khi làm việc qua ví dụ.
from StringIO import StringIO
out = StringIO()
out = tree.export_graphviz(clf, out_file=out)
print out.getvalue()
Bạn có thể nhìn thấy một cây sơ đồ. Sau đó, clf.tree_.featurevà clf.tree_.valuelà mảng các tính năng phân tách các nút và mảng các giá trị nút tương ứng. Bạn có thể tham khảo thêm chi tiết từ nguồn github này .
Chỉ vì mọi người rất hữu ích, tôi sẽ thêm một sửa đổi cho các giải pháp đẹp của Zelazny7 và Daniele. Cái này dành cho python 2.7, với các tab để dễ đọc hơn:
def get_code(tree, feature_names, tabdepth=0):
left = tree.tree_.children_left
right = tree.tree_.children_right
threshold = tree.tree_.threshold
features = [feature_names[i] for i in tree.tree_.feature]
value = tree.tree_.value
def recurse(left, right, threshold, features, node, tabdepth=0):
if (threshold[node] != -2):
print '\t' * tabdepth,
print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
if left[node] != -1:
recurse (left, right, threshold, features,left[node], tabdepth+1)
print '\t' * tabdepth,
print "} else {"
if right[node] != -1:
recurse (left, right, threshold, features,right[node], tabdepth+1)
print '\t' * tabdepth,
print "}"
else:
print '\t' * tabdepth,
print "return " + str(value[node])
recurse(left, right, threshold, features, 0)
Các mã dưới đây là cách tiếp cận của tôi theo anaconda python 2.7 cộng với tên gói "pydot-ng" để tạo tệp PDF với các quy tắc quyết định. Tôi hi vọng nó hữu ích.
from sklearn import tree
clf = tree.DecisionTreeClassifier(max_leaf_nodes=n)
clf_ = clf.fit(X, data_y)
feature_names = X.columns
class_name = clf_.classes_.astype(int).astype(str)
def output_pdf(clf_, name):
from sklearn import tree
from sklearn.externals.six import StringIO
import pydot_ng as pydot
dot_data = StringIO()
tree.export_graphviz(clf_, out_file=dot_data,
feature_names=feature_names,
class_names=class_name,
filled=True, rounded=True,
special_characters=True,
node_ids=1,)
graph = pydot.graph_from_dot_data(dot_data.getvalue())
graph.write_pdf("%s.pdf"%name)
output_pdf(clf_, name='filename%s'%n)
Tôi đã trải qua điều này, nhưng tôi cần các quy tắc được viết theo định dạng này
if A>0.4 then if B<0.2 then if C>0.8 then class='X'
Vì vậy, tôi đã điều chỉnh câu trả lời của @paulkernfeld (cảm ơn) để bạn có thể tùy chỉnh theo nhu cầu của mình
def tree_to_code(tree, feature_names, Y):
tree_ = tree.tree_
feature_name = [
feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
for i in tree_.feature
]
pathto=dict()
global k
k = 0
def recurse(node, depth, parent):
global k
indent = " " * depth
if tree_.feature[node] != _tree.TREE_UNDEFINED:
name = feature_name[node]
threshold = tree_.threshold[node]
s= "{} <= {} ".format( name, threshold, node )
if node == 0:
pathto[node]=s
else:
pathto[node]=pathto[parent]+' & ' +s
recurse(tree_.children_left[node], depth + 1, node)
s="{} > {}".format( name, threshold)
if node == 0:
pathto[node]=s
else:
pathto[node]=pathto[parent]+' & ' +s
recurse(tree_.children_right[node], depth + 1, node)
else:
k=k+1
print(k,')',pathto[parent], tree_.value[node])
recurse(0, 1, 0)
Điều này được xây dựng dựa trên câu trả lời của @paulkernfeld. Nếu bạn có một khung dữ liệu X với các tính năng của bạn và một khung dữ liệu đích y với các giá trị lại của bạn và bạn muốn có một ý tưởng về giá trị y kết thúc ở nút nào (và cả ant để vẽ biểu đồ cho phù hợp), bạn có thể thực hiện như sau:
def tree_to_code(tree, feature_names):
from sklearn.tree import _tree
codelines = []
codelines.append('def get_cat(X_tmp):\n')
codelines.append(' catout = []\n')
codelines.append(' for codelines in range(0,X_tmp.shape[0]):\n')
codelines.append(' Xin = X_tmp.iloc[codelines]\n')
tree_ = tree.tree_
feature_name = [
feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
for i in tree_.feature
]
#print "def tree({}):".format(", ".join(feature_names))
def recurse(node, depth):
indent = " " * depth
if tree_.feature[node] != _tree.TREE_UNDEFINED:
name = feature_name[node]
threshold = tree_.threshold[node]
codelines.append ('{}if Xin["{}"] <= {}:\n'.format(indent, name, threshold))
recurse(tree_.children_left[node], depth + 1)
codelines.append( '{}else: # if Xin["{}"] > {}\n'.format(indent, name, threshold))
recurse(tree_.children_right[node], depth + 1)
else:
codelines.append( '{}mycat = {}\n'.format(indent, node))
recurse(0, 1)
codelines.append(' catout.append(mycat)\n')
codelines.append(' return pd.DataFrame(catout,index=X_tmp.index,columns=["category"])\n')
codelines.append('node_ids = get_cat(X)\n')
return codelines
mycode = tree_to_code(clf,X.columns.values)
# now execute the function and obtain the dataframe with all nodes
exec(''.join(mycode))
node_ids = [int(x[0]) for x in node_ids.values]
node_ids2 = pd.DataFrame(node_ids)
print('make plot')
import matplotlib.cm as cm
colors = cm.rainbow(np.linspace(0, 1, 1+max( list(set(node_ids)))))
#plt.figure(figsize=cm2inch(24, 21))
for i in list(set(node_ids)):
plt.plot(y[node_ids2.values==i],'o',color=colors[i], label=str(i))
mytitle = ['y colored by node']
plt.title(mytitle ,fontsize=14)
plt.xlabel('my xlabel')
plt.ylabel(tagname)
plt.xticks(rotation=70)
plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.00), shadow=True, ncol=9)
plt.tight_layout()
plt.show()
plt.close
không phải là phiên bản thanh lịch nhất nhưng nó thực hiện công việc ...
Tôi đã sửa đổi mã thích hàng đầu để thụt vào một máy tính xách tay jupyter python 3 một cách chính xác
import numpy as np
from sklearn.tree import _tree
def tree_to_code(tree, feature_names):
tree_ = tree.tree_
feature_name = [feature_names[i]
if i != _tree.TREE_UNDEFINED else "undefined!"
for i in tree_.feature]
print("def tree({}):".format(", ".join(feature_names)))
def recurse(node, depth):
indent = " " * depth
if tree_.feature[node] != _tree.TREE_UNDEFINED:
name = feature_name[node]
threshold = tree_.threshold[node]
print("{}if {} <= {}:".format(indent, name, threshold))
recurse(tree_.children_left[node], depth + 1)
print("{}else: # if {} > {}".format(indent, name, threshold))
recurse(tree_.children_right[node], depth + 1)
else:
print("{}return {}".format(indent, np.argmax(tree_.value[node])))
recurse(0, 1)
Đây là một chức năng, in các quy tắc của cây quyết định tìm hiểu scikit theo python 3 và với phần bù cho các khối có điều kiện để làm cho cấu trúc dễ đọc hơn:
def print_decision_tree(tree, feature_names=None, offset_unit=' '):
'''Plots textual representation of rules of a decision tree
tree: scikit-learn representation of tree
feature_names: list of feature names. They are set to f1,f2,f3,... if not specified
offset_unit: a string of offset of the conditional block'''
left = tree.tree_.children_left
right = tree.tree_.children_right
threshold = tree.tree_.threshold
value = tree.tree_.value
if feature_names is None:
features = ['f%d'%i for i in tree.tree_.feature]
else:
features = [feature_names[i] for i in tree.tree_.feature]
def recurse(left, right, threshold, features, node, depth=0):
offset = offset_unit*depth
if (threshold[node] != -2):
print(offset+"if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
if left[node] != -1:
recurse (left, right, threshold, features,left[node],depth+1)
print(offset+"} else {")
if right[node] != -1:
recurse (left, right, threshold, features,right[node],depth+1)
print(offset+"}")
else:
print(offset+"return " + str(value[node]))
recurse(left, right, threshold, features, 0,0)
Bạn cũng có thể làm cho nó nhiều thông tin hơn bằng cách phân biệt nó thuộc lớp nào hoặc thậm chí bằng cách đề cập đến giá trị đầu ra của nó.
def print_decision_tree(tree, feature_names, offset_unit=' '):
left = tree.tree_.children_left
right = tree.tree_.children_right
threshold = tree.tree_.threshold
value = tree.tree_.value
if feature_names is None:
features = ['f%d'%i for i in tree.tree_.feature]
else:
features = [feature_names[i] for i in tree.tree_.feature]
def recurse(left, right, threshold, features, node, depth=0):
offset = offset_unit*depth
if (threshold[node] != -2):
print(offset+"if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
if left[node] != -1:
recurse (left, right, threshold, features,left[node],depth+1)
print(offset+"} else {")
if right[node] != -1:
recurse (left, right, threshold, features,right[node],depth+1)
print(offset+"}")
else:
#print(offset,value[node])
#To remove values from node
temp=str(value[node])
mid=len(temp)//2
tempx=[]
tempy=[]
cnt=0
for i in temp:
if cnt<=mid:
tempx.append(i)
cnt+=1
else:
tempy.append(i)
cnt+=1
val_yes=[]
val_no=[]
res=[]
for j in tempx:
if j=="[" or j=="]" or j=="." or j==" ":
res.append(j)
else:
val_no.append(j)
for j in tempy:
if j=="[" or j=="]" or j=="." or j==" ":
res.append(j)
else:
val_yes.append(j)
val_yes = int("".join(map(str, val_yes)))
val_no = int("".join(map(str, val_no)))
if val_yes>val_no:
print(offset,'\033[1m',"YES")
print('\033[0m')
elif val_no>val_yes:
print(offset,'\033[1m',"NO")
print('\033[0m')
else:
print(offset,'\033[1m',"Tie")
print('\033[0m')
recurse(left, right, threshold, features, 0,0)
Dưới đây là cách tiếp cận của tôi để trích xuất các quy tắc quyết định trong một hình thức có thể được sử dụng trực tiếp trong sql, vì vậy dữ liệu có thể được nhóm theo nút. (Dựa trên cách tiếp cận của các áp phích trước.)
Kết quả sẽ là các CASEmệnh đề tiếp theo có thể được sao chép vào câu lệnh sql, ví dụ.
SELECT COALESCE(*CASE WHEN <conditions> THEN > <NodeA>*, > *CASE WHEN
<conditions> THEN <NodeB>*, > ....)NodeName,* > FROM <table or view>
import numpy as np
import pickle
feature_names=.............
features = [feature_names[i] for i in range(len(feature_names))]
clf= pickle.loads(trained_model)
impurity=clf.tree_.impurity
importances = clf.feature_importances_
SqlOut=""
#global Conts
global ContsNode
global Path
#Conts=[]#
ContsNode=[]
Path=[]
global Results
Results=[]
def print_decision_tree(tree, feature_names, offset_unit='' ''):
left = tree.tree_.children_left
right = tree.tree_.children_right
threshold = tree.tree_.threshold
value = tree.tree_.value
if feature_names is None:
features = [''f%d''%i for i in tree.tree_.feature]
else:
features = [feature_names[i] for i in tree.tree_.feature]
def recurse(left, right, threshold, features, node, depth=0,ParentNode=0,IsElse=0):
global Conts
global ContsNode
global Path
global Results
global LeftParents
LeftParents=[]
global RightParents
RightParents=[]
for i in range(len(left)): # This is just to tell you how to create a list.
LeftParents.append(-1)
RightParents.append(-1)
ContsNode.append("")
Path.append("")
for i in range(len(left)): # i is node
if (left[i]==-1 and right[i]==-1):
if LeftParents[i]>=0:
if Path[LeftParents[i]]>" ":
Path[i]=Path[LeftParents[i]]+" AND " +ContsNode[LeftParents[i]]
else:
Path[i]=ContsNode[LeftParents[i]]
if RightParents[i]>=0:
if Path[RightParents[i]]>" ":
Path[i]=Path[RightParents[i]]+" AND not " +ContsNode[RightParents[i]]
else:
Path[i]=" not " +ContsNode[RightParents[i]]
Results.append(" case when " +Path[i]+" then ''" +"{:4d}".format(i)+ " "+"{:2.2f}".format(impurity[i])+" "+Path[i][0:180]+"''")
else:
if LeftParents[i]>=0:
if Path[LeftParents[i]]>" ":
Path[i]=Path[LeftParents[i]]+" AND " +ContsNode[LeftParents[i]]
else:
Path[i]=ContsNode[LeftParents[i]]
if RightParents[i]>=0:
if Path[RightParents[i]]>" ":
Path[i]=Path[RightParents[i]]+" AND not " +ContsNode[RightParents[i]]
else:
Path[i]=" not "+ContsNode[RightParents[i]]
if (left[i]!=-1):
LeftParents[left[i]]=i
if (right[i]!=-1):
RightParents[right[i]]=i
ContsNode[i]= "( "+ features[i] + " <= " + str(threshold[i]) + " ) "
recurse(left, right, threshold, features, 0,0,0,0)
print_decision_tree(clf,features)
SqlOut=""
for i in range(len(Results)):
SqlOut=SqlOut+Results[i]+ " end,"+chr(13)+chr(10)
Bây giờ bạn có thể sử dụng export bản.
from sklearn.tree import export_text
r = export_text(loan_tree, feature_names=(list(X_train.columns)))
print(r)
Một ví dụ hoàn chỉnh từ [sklearn] [1]
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_text
iris = load_iris()
X = iris['data']
y = iris['target']
decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2)
decision_tree = decision_tree.fit(X, y)
r = export_text(decision_tree, feature_names=iris['feature_names'])
print(r)
Mã Zelazny7 đã sửa đổi để tìm nạp SQL từ cây quyết định.
# SQL from decision tree
def get_lineage(tree, feature_names):
left = tree.tree_.children_left
right = tree.tree_.children_right
threshold = tree.tree_.threshold
features = [feature_names[i] for i in tree.tree_.feature]
le='<='
g ='>'
# get ids of child nodes
idx = np.argwhere(left == -1)[:,0]
def recurse(left, right, child, lineage=None):
if lineage is None:
lineage = [child]
if child in left:
parent = np.where(left == child)[0].item()
split = 'l'
else:
parent = np.where(right == child)[0].item()
split = 'r'
lineage.append((parent, split, threshold[parent], features[parent]))
if parent == 0:
lineage.reverse()
return lineage
else:
return recurse(left, right, parent, lineage)
print 'case '
for j,child in enumerate(idx):
clause=' when '
for node in recurse(left, right, child):
if len(str(node))<3:
continue
i=node
if i[1]=='l': sign=le
else: sign=g
clause=clause+i[3]+sign+str(i[2])+' and '
clause=clause[:-4]+' then '+str(j)
print clause
print 'else 99 end as clusters'
Rõ ràng từ lâu, ai đó đã quyết định thử thêm chức năng sau vào các hàm xuất cây chính thức của scikit (về cơ bản chỉ hỗ trợ export_graphviz)
def export_dict(tree, feature_names=None, max_depth=None) :
"""Export a decision tree in dict format.
Đây là cam kết đầy đủ của anh ấy:
Không chắc chắn chính xác những gì đã xảy ra với bình luận này. Nhưng bạn cũng có thể thử sử dụng chức năng đó.
Tôi nghĩ rằng điều này đảm bảo một yêu cầu tài liệu nghiêm túc cho những người giỏi về scikit-learn để viết tài liệu đúng về sklearn.tree.TreeAPI, cấu trúc cây bên dưới DecisionTreeClassifierlộ ra như thuộc tính của nó tree_.
Chỉ cần sử dụng chức năng từ sklearn.tree như thế này
from sklearn.tree import export_graphviz
export_graphviz(tree,
out_file = "tree.dot",
feature_names = tree.columns) //or just ["petal length", "petal width"]
Và sau đó tìm trong thư mục dự án của bạn để tìm tập tin tree.dot , sao chép TẤT CẢ nội dung và dán vào đây http://www.webgraphviz.com/ và tạo biểu đồ của bạn :)
Cảm ơn giải pháp tuyệt vời của @paulkerfeld. Ngày đầu của giải pháp của mình, cho tất cả những ai muốn có một phiên bản serialized cây, chỉ cần sử dụng tree.threshold, tree.children_left, tree.children_right, tree.featurevà tree.value. Vì các lá không có sự phân chia và do đó không có tên tính năng và trẻ em, giữ chỗ của chúng trong tree.featurevà tree.children_***đang _tree.TREE_UNDEFINEDvà _tree.TREE_LEAF. Mỗi phân chia được chỉ định một chỉ số duy nhất bởi depth first search.
Lưu ý rằng tree.valuehình dạng[n, 1, 1]
Đây là một hàm tạo mã Python từ cây quyết định bằng cách chuyển đổi đầu ra của export_text:
import string
from sklearn.tree import export_text
def export_py_code(tree, feature_names, max_depth=100, spacing=4):
if spacing < 2:
raise ValueError('spacing must be > 1')
# Clean up feature names (for correctness)
nums = string.digits
alnums = string.ascii_letters + nums
clean = lambda s: ''.join(c if c in alnums else '_' for c in s)
features = [clean(x) for x in feature_names]
features = ['_'+x if x[0] in nums else x for x in features if x]
if len(set(features)) != len(feature_names):
raise ValueError('invalid feature names')
# First: export tree to text
res = export_text(tree, feature_names=features,
max_depth=max_depth,
decimals=6,
spacing=spacing-1)
# Second: generate Python code from the text
skip, dash = ' '*spacing, '-'*(spacing-1)
code = 'def decision_tree({}):\n'.format(', '.join(features))
for line in repr(tree).split('\n'):
code += skip + "# " + line + '\n'
for line in res.split('\n'):
line = line.rstrip().replace('|',' ')
if '<' in line or '>' in line:
line, val = line.rsplit(maxsplit=1)
line = line.replace(' ' + dash, 'if')
line = '{} {:g}:'.format(line, float(val))
else:
line = line.replace(' {} class:'.format(dash), 'return')
code += skip + line + '\n'
return code
Sử dụng mẫu:
res = export_py_code(tree, feature_names=names, spacing=4)
print (res)
Đầu ra mẫu:
def decision_tree(f1, f2, f3):
# DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=3,
# max_features=None, max_leaf_nodes=None,
# min_impurity_decrease=0.0, min_impurity_split=None,
# min_samples_leaf=1, min_samples_split=2,
# min_weight_fraction_leaf=0.0, presort=False,
# random_state=42, splitter='best')
if f1 <= 12.5:
if f2 <= 17.5:
if f1 <= 10.5:
return 2
if f1 > 10.5:
return 3
if f2 > 17.5:
if f2 <= 22.5:
return 1
if f2 > 22.5:
return 1
if f1 > 12.5:
if f1 <= 17.5:
if f3 <= 23.5:
return 2
if f3 > 23.5:
return 3
if f1 > 17.5:
if f1 <= 25:
return 1
if f1 > 25:
return 2
Ví dụ trên được tạo ra với names = ['f'+str(j+1) for j in range(NUM_FEATURES)].
Một tính năng tiện dụng là nó có thể tạo kích thước tệp nhỏ hơn với khoảng cách giảm. Chỉ cần đặt spacing=2.