Làm thế nào để trích xuất các quy tắc quyết định từ cây quyết định tìm hiểu scikit?


156

Tôi có thể trích xuất các quy tắc quyết định cơ bản (hoặc 'đường dẫn quyết định') từ cây được đào tạo trong cây quyết định dưới dạng danh sách văn bản không?

Cái gì đó như:

if A>0.4 then if B<0.2 then if C>0.8 then class='X'

Cảm ơn bạn đã giúp đỡ.



Bạn đã bao giờ tìm thấy một câu trả lời cho vấn đề này? Tôi phải xuất các quy tắc cây quyết định theo định dạng bước dữ liệu SAS gần như chính xác như bạn đã liệt kê.
Zelazny7

1
Bạn có thể sử dụng gói sklearn-porter để xuất và chuyển các cây quyết định (cũng là rừng ngẫu nhiên và cây được tăng cường) sang C, Java, JavaScript và các cây khác.
Darius

Bạn có thể kiểm tra liên kết này- kdnuggets.com/2017/05/ từ
agrawal

Câu trả lời:


138

Tôi tin rằng câu trả lời này đúng hơn các câu trả lời khác ở đây:

from sklearn.tree import _tree

def tree_to_code(tree, feature_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print "def tree({}):".format(", ".join(feature_names))

    def recurse(node, depth):
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print "{}if {} <= {}:".format(indent, name, threshold)
            recurse(tree_.children_left[node], depth + 1)
            print "{}else:  # if {} > {}".format(indent, name, threshold)
            recurse(tree_.children_right[node], depth + 1)
        else:
            print "{}return {}".format(indent, tree_.value[node])

    recurse(0, 1)

Điều này in ra một hàm Python hợp lệ. Đây là một đầu ra ví dụ cho một cây đang cố gắng trả về đầu vào của nó, một số trong khoảng từ 0 đến 10.

def tree(f0):
  if f0 <= 6.0:
    if f0 <= 1.5:
      return [[ 0.]]
    else:  # if f0 > 1.5
      if f0 <= 4.5:
        if f0 <= 3.5:
          return [[ 3.]]
        else:  # if f0 > 3.5
          return [[ 4.]]
      else:  # if f0 > 4.5
        return [[ 5.]]
  else:  # if f0 > 6.0
    if f0 <= 8.5:
      if f0 <= 7.5:
        return [[ 7.]]
      else:  # if f0 > 7.5
        return [[ 8.]]
    else:  # if f0 > 8.5
      return [[ 9.]]

Dưới đây là một số vấp ngã mà tôi thấy trong các câu trả lời khác:

  1. Sử dụng tree_.threshold == -2để quyết định xem một nút có phải là một chiếc lá không là một ý tưởng tốt. Điều gì xảy ra nếu đó là nút quyết định thực sự với ngưỡng -2? Thay vào đó, bạn nên nhìn vào tree.featurehoặc tree.children_*.
  2. Dòng features = [feature_names[i] for i in tree_.feature]gặp sự cố với phiên bản sklearn của tôi, vì một số giá trị tree.tree_.featurelà -2 (đặc biệt cho các nút lá).
  3. Không cần phải có nhiều câu lệnh if trong hàm đệ quy, chỉ cần một câu lệnh là ổn.

1
Mã này hoạt động rất tốt cho tôi. Tuy nhiên, tôi có hơn 500 Feature_names nên mã đầu ra gần như không thể hiểu được. Có cách nào để tôi chỉ nhập tính năng_names mà tôi tò mò vào chức năng không?
dùng3768495

1
Tôi đồng ý với những nhận xét trước. IIUC, print "{}return {}".format(indent, tree_.value[node])nên được thay đổi thành print "{}return {}".format(indent, np.argmax(tree_.value[node][0]))hàm để trả về chỉ mục lớp.
súpault

1
@paulkernfeld À đúng rồi, tôi thấy rằng bạn có thể lặp lại RandomForestClassifier.estimators_, nhưng tôi không thể tìm ra cách kết hợp các kết quả của người ước tính.
Nathan Lloyd

6
Tôi không thể làm việc này trong python 3, các bit _tree dường như chưa từng hoạt động và TREE_UNDEFINED không được xác định. Liên kết này đã giúp tôi. Mặc dù mã xuất khẩu không thể chạy trực tiếp bằng python, nhưng nó rất dễ dịch và dễ dịch sang các ngôn ngữ khác: web.archive.org/web/20171005203850/http://www.kdnuggets.com/
Josiah

1
@Josiah, thêm () vào các câu lệnh in để làm cho nó hoạt động trong python3. ví dụ print "bla"=>print("bla")
Nir

48

Tôi đã tạo chức năng của riêng mình để trích xuất các quy tắc từ các cây quyết định được tạo bởi sklearn:

import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# dummy data:
df = pd.DataFrame({'col1':[0,1,2,3],'col2':[3,4,5,6],'dv':[0,1,0,1]})

# create decision tree
dt = DecisionTreeClassifier(max_depth=5, min_samples_leaf=1)
dt.fit(df.ix[:,:2], df.dv)

Hàm này trước tiên bắt đầu bằng các nút (được xác định bởi -1 trong mảng con) và sau đó tìm đệ quy cha mẹ. Tôi gọi đây là 'dòng dõi' của nút. Trên đường đi, tôi lấy các giá trị tôi cần để tạo logic logic if / then / other:

def get_lineage(tree, feature_names):
     left      = tree.tree_.children_left
     right     = tree.tree_.children_right
     threshold = tree.tree_.threshold
     features  = [feature_names[i] for i in tree.tree_.feature]

     # get ids of child nodes
     idx = np.argwhere(left == -1)[:,0]     

     def recurse(left, right, child, lineage=None):          
          if lineage is None:
               lineage = [child]
          if child in left:
               parent = np.where(left == child)[0].item()
               split = 'l'
          else:
               parent = np.where(right == child)[0].item()
               split = 'r'

          lineage.append((parent, split, threshold[parent], features[parent]))

          if parent == 0:
               lineage.reverse()
               return lineage
          else:
               return recurse(left, right, parent, lineage)

     for child in idx:
          for node in recurse(left, right, child):
               print node

Các bộ dữ liệu bên dưới chứa mọi thứ tôi cần để tạo ra các câu lệnh if / then / other. Tôi không thích sử dụng docác khối trong SAS, đó là lý do tại sao tôi tạo logic mô tả toàn bộ đường dẫn của một nút. Số nguyên đơn sau các bộ dữ liệu là ID của nút đầu cuối trong một đường dẫn. Tất cả các bộ trước đó kết hợp để tạo nút đó.

In [1]: get_lineage(dt, df.columns)
(0, 'l', 0.5, 'col1')
1
(0, 'r', 0.5, 'col1')
(2, 'l', 4.5, 'col2')
3
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'l', 2.5, 'col1')
5
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'r', 2.5, 'col1')
6

Đầu ra GraphViz của cây ví dụ


loại cây này có đúng không bởi vì col1 đang đến một lần nữa là col1 <= 0,50000 và một col1 <= 2.5000 nếu có, đây có phải là bất kỳ loại đệ quy nào được sử dụng trong thư viện không
jayant singh

nhánh bên phải sẽ có hồ sơ giữa (0.5, 2.5]. Các cây được thực hiện với phân vùng đệ quy. Không có gì ngăn cản một biến được chọn nhiều lần.
Zelazny7

được rồi, bạn có thể giải thích phần đệ quy những gì xảy ra một cách chính xác vì tôi đã sử dụng nó trong mã của mình và kết quả tương tự được nhìn thấy
jayant singh

38

Tôi đã sửa đổi mã được gửi bởi Zelazny7 để in một số mã giả:

def get_code(tree, feature_names):
        left      = tree.tree_.children_left
        right     = tree.tree_.children_right
        threshold = tree.tree_.threshold
        features  = [feature_names[i] for i in tree.tree_.feature]
        value = tree.tree_.value

        def recurse(left, right, threshold, features, node):
                if (threshold[node] != -2):
                        print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
                        if left[node] != -1:
                                recurse (left, right, threshold, features,left[node])
                        print "} else {"
                        if right[node] != -1:
                                recurse (left, right, threshold, features,right[node])
                        print "}"
                else:
                        print "return " + str(value[node])

        recurse(left, right, threshold, features, 0)

nếu bạn gọi get_code(dt, df.columns)cùng một ví dụ, bạn sẽ nhận được:

if ( col1 <= 0.5 ) {
return [[ 1.  0.]]
} else {
if ( col2 <= 4.5 ) {
return [[ 0.  1.]]
} else {
if ( col1 <= 2.5 ) {
return [[ 1.  0.]]
} else {
return [[ 0.  1.]]
}
}
}

1
Bạn có thể cho biết, chính xác [[1. 0.]] trong câu lệnh return có nghĩa là gì ở đầu ra trên. Tôi không phải là một người Python, nhưng làm việc trên cùng một thứ. Vì vậy, nó sẽ tốt cho tôi nếu bạn vui lòng chứng minh một số chi tiết để nó sẽ dễ dàng hơn cho tôi.
Subhradip Bose

1
@ user3156186 Điều đó có nghĩa là có một đối tượng trong lớp '0' và không có đối tượng trong lớp '1'
Daniele

1
@Daniele, bạn có biết các lớp học được sắp xếp như thế nào không? Tôi đoán chữ và số, nhưng tôi không tìm thấy xác nhận ở bất cứ đâu.
IanS

Cảm ơn! Đối với kịch bản trường hợp cạnh trong đó giá trị ngưỡng thực sự là -2, chúng ta có thể cần phải thay đổi (threshold[node] != -2)thành ( left[node] != -1)(tương tự như phương pháp bên dưới để lấy id của các nút con)
tlingf

@Daniele, có ý tưởng nào để làm cho hàm của bạn "get_code" "trả lại" một giá trị và không "in" nó không, vì tôi cần gửi nó đến một hàm khác?
RoyaumeIX

17

Scikit learn đã giới thiệu một phương pháp mới tuyệt vời được gọi là export_texttrong phiên bản 0.21 (tháng 5 năm 2019) để trích xuất các quy tắc từ một cái cây. Tài liệu ở đây . Không còn cần thiết để tạo một chức năng tùy chỉnh.

Khi bạn đã phù hợp với mô hình của mình, bạn chỉ cần hai dòng mã. Đầu tiên, nhập khẩu export_text:

from sklearn.tree.export import export_text

Thứ hai, tạo một đối tượng sẽ chứa các quy tắc của bạn. Để làm cho các quy tắc trông dễ đọc hơn, hãy sử dụng feature_namesđối số và chuyển danh sách các tên tính năng của bạn. Ví dụ: nếu mô hình của bạn được gọi modelvà các tính năng của bạn được đặt tên trong khung dữ liệu được gọi X_train, bạn có thể tạo một đối tượng được gọi là tree_rules:

tree_rules = export_text(model, feature_names=list(X_train))

Sau đó chỉ cần in hoặc lưu tree_rules. Đầu ra của bạn sẽ trông như thế này:

|--- Age <= 0.63
|   |--- EstimatedSalary <= 0.61
|   |   |--- Age <= -0.16
|   |   |   |--- class: 0
|   |   |--- Age >  -0.16
|   |   |   |--- EstimatedSalary <= -0.06
|   |   |   |   |--- class: 0
|   |   |   |--- EstimatedSalary >  -0.06
|   |   |   |   |--- EstimatedSalary <= 0.40
|   |   |   |   |   |--- EstimatedSalary <= 0.03
|   |   |   |   |   |   |--- class: 1

14

Có một DecisionTreeClassifierphương thức mới decision_path, trong bản phát hành 0.18.0 . Các nhà phát triển cung cấp một (cũng như các tài liệu) rộng hương .

Phần đầu tiên của mã trong hướng dẫn in cấu trúc cây có vẻ ổn. Tuy nhiên, tôi đã sửa đổi mã trong phần thứ hai để thẩm vấn một mẫu. Những thay đổi của tôi được biểu thị bằng# <--

Chỉnh sửa Các thay đổi được đánh dấu # <--trong mã bên dưới đã được cập nhật trong liên kết hướng dẫn sau khi các lỗi được chỉ ra trong các yêu cầu kéo # 8653# 10951 . Bây giờ dễ dàng hơn nhiều để làm theo.

sample_id = 0
node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                    node_indicator.indptr[sample_id + 1]]

print('Rules used to predict sample %s: ' % sample_id)
for node_id in node_index:

    if leave_id[sample_id] == node_id:  # <-- changed != to ==
        #continue # <-- comment out
        print("leaf node {} reached, no decision here".format(leave_id[sample_id])) # <--

    else: # < -- added else to iterate through decision nodes
        if (X_test[sample_id, feature[node_id]] <= threshold[node_id]):
            threshold_sign = "<="
        else:
            threshold_sign = ">"

        print("decision id node %s : (X[%s, %s] (= %s) %s %s)"
              % (node_id,
                 sample_id,
                 feature[node_id],
                 X_test[sample_id, feature[node_id]], # <-- changed i to sample_id
                 threshold_sign,
                 threshold[node_id]))

Rules used to predict sample 0: 
decision id node 0 : (X[0, 3] (= 2.4) > 0.800000011921)
decision id node 2 : (X[0, 2] (= 5.1) > 4.94999980927)
leaf node 4 reached, no decision here

Thay đổi sample_idđể xem các đường dẫn quyết định cho các mẫu khác. Tôi đã không hỏi các nhà phát triển về những thay đổi này, có vẻ trực quan hơn khi làm việc qua ví dụ.


bạn của tôi là một huyền thoại! bất kỳ ý tưởng làm thế nào để vẽ cây quyết định cho mẫu cụ thể đó? nhiều sự giúp đỡ được đánh giá cao

1
Cảm ơn Victor, có lẽ tốt nhất nên hỏi đây là một câu hỏi riêng vì các yêu cầu về âm mưu có thể cụ thể theo nhu cầu của người dùng. Bạn có thể sẽ nhận được phản hồi tốt nếu bạn cung cấp ý tưởng về những gì bạn muốn đầu ra trông như thế nào.
Kevin

này kevin, tôi đã tạo câu hỏi stackoverflow.com/questions/48888893/ từ

bạn sẽ thật tử tế khi xem qua: stackoverflow.com/questions/52654280/
Kẻ

Bạn có thể vui lòng giải thích phần được gọi là node_index, không nhận phần đó. nó làm gì?
Anindya Sankar Dey

12
from StringIO import StringIO
out = StringIO()
out = tree.export_graphviz(clf, out_file=out)
print out.getvalue()

Bạn có thể nhìn thấy một cây sơ đồ. Sau đó, clf.tree_.featureclf.tree_.valuelà mảng các tính năng phân tách các nút và mảng các giá trị nút tương ứng. Bạn có thể tham khảo thêm chi tiết từ nguồn github này .


1
Vâng, tôi biết làm thế nào để vẽ cây - nhưng tôi cần phiên bản nhiều văn bản hơn - các quy tắc. một cái gì đó như: cam.biolab.si/docs/latest/reference/rst/iêng
Dror

4

Chỉ vì mọi người rất hữu ích, tôi sẽ thêm một sửa đổi cho các giải pháp đẹp của Zelazny7 và Daniele. Cái này dành cho python 2.7, với các tab để dễ đọc hơn:

def get_code(tree, feature_names, tabdepth=0):
    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    features  = [feature_names[i] for i in tree.tree_.feature]
    value = tree.tree_.value

    def recurse(left, right, threshold, features, node, tabdepth=0):
            if (threshold[node] != -2):
                    print '\t' * tabdepth,
                    print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
                    if left[node] != -1:
                            recurse (left, right, threshold, features,left[node], tabdepth+1)
                    print '\t' * tabdepth,
                    print "} else {"
                    if right[node] != -1:
                            recurse (left, right, threshold, features,right[node], tabdepth+1)
                    print '\t' * tabdepth,
                    print "}"
            else:
                    print '\t' * tabdepth,
                    print "return " + str(value[node])

    recurse(left, right, threshold, features, 0)

3

Các mã dưới đây là cách tiếp cận của tôi theo anaconda python 2.7 cộng với tên gói "pydot-ng" để tạo tệp PDF với các quy tắc quyết định. Tôi hi vọng nó hữu ích.

from sklearn import tree

clf = tree.DecisionTreeClassifier(max_leaf_nodes=n)
clf_ = clf.fit(X, data_y)

feature_names = X.columns
class_name = clf_.classes_.astype(int).astype(str)

def output_pdf(clf_, name):
    from sklearn import tree
    from sklearn.externals.six import StringIO
    import pydot_ng as pydot
    dot_data = StringIO()
    tree.export_graphviz(clf_, out_file=dot_data,
                         feature_names=feature_names,
                         class_names=class_name,
                         filled=True, rounded=True,
                         special_characters=True,
                          node_ids=1,)
    graph = pydot.graph_from_dot_data(dot_data.getvalue())
    graph.write_pdf("%s.pdf"%name)

output_pdf(clf_, name='filename%s'%n)

một chương trình biểu đồ cây ở đây


3

Tôi đã trải qua điều này, nhưng tôi cần các quy tắc được viết theo định dạng này

if A>0.4 then if B<0.2 then if C>0.8 then class='X' 

Vì vậy, tôi đã điều chỉnh câu trả lời của @paulkernfeld (cảm ơn) để bạn có thể tùy chỉnh theo nhu cầu của mình

def tree_to_code(tree, feature_names, Y):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    pathto=dict()

    global k
    k = 0
    def recurse(node, depth, parent):
        global k
        indent = "  " * depth

        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            s= "{} <= {} ".format( name, threshold, node )
            if node == 0:
                pathto[node]=s
            else:
                pathto[node]=pathto[parent]+' & ' +s

            recurse(tree_.children_left[node], depth + 1, node)
            s="{} > {}".format( name, threshold)
            if node == 0:
                pathto[node]=s
            else:
                pathto[node]=pathto[parent]+' & ' +s
            recurse(tree_.children_right[node], depth + 1, node)
        else:
            k=k+1
            print(k,')',pathto[parent], tree_.value[node])
    recurse(0, 1, 0)

3

Đây là một cách để dịch toàn bộ cây thành một biểu thức python duy nhất (không nhất thiết phải quá dễ đọc của con người) bằng thư viện SKompiler :

from skompiler import skompile
skompile(dtree.predict).to('python/code')

3

Điều này được xây dựng dựa trên câu trả lời của @paulkernfeld. Nếu bạn có một khung dữ liệu X với các tính năng của bạn và một khung dữ liệu đích y với các giá trị lại của bạn và bạn muốn có một ý tưởng về giá trị y kết thúc ở nút nào (và cả ant để vẽ biểu đồ cho phù hợp), bạn có thể thực hiện như sau:

    def tree_to_code(tree, feature_names):
        from sklearn.tree import _tree
        codelines = []
        codelines.append('def get_cat(X_tmp):\n')
        codelines.append('   catout = []\n')
        codelines.append('   for codelines in range(0,X_tmp.shape[0]):\n')
        codelines.append('      Xin = X_tmp.iloc[codelines]\n')
        tree_ = tree.tree_
        feature_name = [
            feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
            for i in tree_.feature
        ]
        #print "def tree({}):".format(", ".join(feature_names))

        def recurse(node, depth):
            indent = "      " * depth
            if tree_.feature[node] != _tree.TREE_UNDEFINED:
                name = feature_name[node]
                threshold = tree_.threshold[node]
                codelines.append ('{}if Xin["{}"] <= {}:\n'.format(indent, name, threshold))
                recurse(tree_.children_left[node], depth + 1)
                codelines.append( '{}else:  # if Xin["{}"] > {}\n'.format(indent, name, threshold))
                recurse(tree_.children_right[node], depth + 1)
            else:
                codelines.append( '{}mycat = {}\n'.format(indent, node))

        recurse(0, 1)
        codelines.append('      catout.append(mycat)\n')
        codelines.append('   return pd.DataFrame(catout,index=X_tmp.index,columns=["category"])\n')
        codelines.append('node_ids = get_cat(X)\n')
        return codelines
    mycode = tree_to_code(clf,X.columns.values)

    # now execute the function and obtain the dataframe with all nodes
    exec(''.join(mycode))
    node_ids = [int(x[0]) for x in node_ids.values]
    node_ids2 = pd.DataFrame(node_ids)

    print('make plot')
    import matplotlib.cm as cm
    colors = cm.rainbow(np.linspace(0, 1, 1+max( list(set(node_ids)))))
    #plt.figure(figsize=cm2inch(24, 21))
    for i in list(set(node_ids)):
        plt.plot(y[node_ids2.values==i],'o',color=colors[i], label=str(i))  
    mytitle = ['y colored by node']
    plt.title(mytitle ,fontsize=14)
    plt.xlabel('my xlabel')
    plt.ylabel(tagname)
    plt.xticks(rotation=70)       
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.00), shadow=True, ncol=9)
    plt.tight_layout()
    plt.show()
    plt.close 

không phải là phiên bản thanh lịch nhất nhưng nó thực hiện công việc ...


1
Đây là cách tiếp cận tốt khi bạn muốn trả về các dòng mã thay vì chỉ in chúng.
Hajar Homayouni

3

Đây là mã bạn cần

Tôi đã sửa đổi mã thích hàng đầu để thụt vào một máy tính xách tay jupyter python 3 một cách chính xác

import numpy as np
from sklearn.tree import _tree

def tree_to_code(tree, feature_names):
    tree_ = tree.tree_
    feature_name = [feature_names[i] 
                    if i != _tree.TREE_UNDEFINED else "undefined!" 
                    for i in tree_.feature]
    print("def tree({}):".format(", ".join(feature_names)))

    def recurse(node, depth):
        indent = "    " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print("{}if {} <= {}:".format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            print("{}else:  # if {} > {}".format(indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            print("{}return {}".format(indent, np.argmax(tree_.value[node])))

    recurse(0, 1)

2

Đây là một chức năng, in các quy tắc của cây quyết định tìm hiểu scikit theo python 3 và với phần bù cho các khối có điều kiện để làm cho cấu trúc dễ đọc hơn:

def print_decision_tree(tree, feature_names=None, offset_unit='    '):
    '''Plots textual representation of rules of a decision tree
    tree: scikit-learn representation of tree
    feature_names: list of feature names. They are set to f1,f2,f3,... if not specified
    offset_unit: a string of offset of the conditional block'''

    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    value = tree.tree_.value
    if feature_names is None:
        features  = ['f%d'%i for i in tree.tree_.feature]
    else:
        features  = [feature_names[i] for i in tree.tree_.feature]        

    def recurse(left, right, threshold, features, node, depth=0):
            offset = offset_unit*depth
            if (threshold[node] != -2):
                    print(offset+"if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
                    if left[node] != -1:
                            recurse (left, right, threshold, features,left[node],depth+1)
                    print(offset+"} else {")
                    if right[node] != -1:
                            recurse (left, right, threshold, features,right[node],depth+1)
                    print(offset+"}")
            else:
                    print(offset+"return " + str(value[node]))

    recurse(left, right, threshold, features, 0,0)

2

Bạn cũng có thể làm cho nó nhiều thông tin hơn bằng cách phân biệt nó thuộc lớp nào hoặc thậm chí bằng cách đề cập đến giá trị đầu ra của nó.

def print_decision_tree(tree, feature_names, offset_unit='    '):    
left      = tree.tree_.children_left
right     = tree.tree_.children_right
threshold = tree.tree_.threshold
value = tree.tree_.value
if feature_names is None:
    features  = ['f%d'%i for i in tree.tree_.feature]
else:
    features  = [feature_names[i] for i in tree.tree_.feature]        

def recurse(left, right, threshold, features, node, depth=0):
        offset = offset_unit*depth
        if (threshold[node] != -2):
                print(offset+"if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
                if left[node] != -1:
                        recurse (left, right, threshold, features,left[node],depth+1)
                print(offset+"} else {")
                if right[node] != -1:
                        recurse (left, right, threshold, features,right[node],depth+1)
                print(offset+"}")
        else:
                #print(offset,value[node]) 

                #To remove values from node
                temp=str(value[node])
                mid=len(temp)//2
                tempx=[]
                tempy=[]
                cnt=0
                for i in temp:
                    if cnt<=mid:
                        tempx.append(i)
                        cnt+=1
                    else:
                        tempy.append(i)
                        cnt+=1
                val_yes=[]
                val_no=[]
                res=[]
                for j in tempx:
                    if j=="[" or j=="]" or j=="." or j==" ":
                        res.append(j)
                    else:
                        val_no.append(j)
                for j in tempy:
                    if j=="[" or j=="]" or j=="." or j==" ":
                        res.append(j)
                    else:
                        val_yes.append(j)
                val_yes = int("".join(map(str, val_yes)))
                val_no = int("".join(map(str, val_no)))

                if val_yes>val_no:
                    print(offset,'\033[1m',"YES")
                    print('\033[0m')
                elif val_no>val_yes:
                    print(offset,'\033[1m',"NO")
                    print('\033[0m')
                else:
                    print(offset,'\033[1m',"Tie")
                    print('\033[0m')

recurse(left, right, threshold, features, 0,0)

nhập mô tả hình ảnh ở đây


2

Dưới đây là cách tiếp cận của tôi để trích xuất các quy tắc quyết định trong một hình thức có thể được sử dụng trực tiếp trong sql, vì vậy dữ liệu có thể được nhóm theo nút. (Dựa trên cách tiếp cận của các áp phích trước.)

Kết quả sẽ là các CASEmệnh đề tiếp theo có thể được sao chép vào câu lệnh sql, ví dụ.

SELECT COALESCE(*CASE WHEN <conditions> THEN > <NodeA>*, > *CASE WHEN <conditions> THEN <NodeB>*, > ....)NodeName,* > FROM <table or view>


import numpy as np

import pickle
feature_names=.............
features  = [feature_names[i] for i in range(len(feature_names))]
clf= pickle.loads(trained_model)
impurity=clf.tree_.impurity
importances = clf.feature_importances_
SqlOut=""

#global Conts
global ContsNode
global Path
#Conts=[]#
ContsNode=[]
Path=[]
global Results
Results=[]

def print_decision_tree(tree, feature_names, offset_unit=''    ''):    
    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    value = tree.tree_.value

    if feature_names is None:
        features  = [''f%d''%i for i in tree.tree_.feature]
    else:
        features  = [feature_names[i] for i in tree.tree_.feature]        

    def recurse(left, right, threshold, features, node, depth=0,ParentNode=0,IsElse=0):
        global Conts
        global ContsNode
        global Path
        global Results
        global LeftParents
        LeftParents=[]
        global RightParents
        RightParents=[]
        for i in range(len(left)): # This is just to tell you how to create a list.
            LeftParents.append(-1)
            RightParents.append(-1)
            ContsNode.append("")
            Path.append("")


        for i in range(len(left)): # i is node
            if (left[i]==-1 and right[i]==-1):      
                if LeftParents[i]>=0:
                    if Path[LeftParents[i]]>" ":
                        Path[i]=Path[LeftParents[i]]+" AND " +ContsNode[LeftParents[i]]                                 
                    else:
                        Path[i]=ContsNode[LeftParents[i]]                                   
                if RightParents[i]>=0:
                    if Path[RightParents[i]]>" ":
                        Path[i]=Path[RightParents[i]]+" AND not " +ContsNode[RightParents[i]]                                   
                    else:
                        Path[i]=" not " +ContsNode[RightParents[i]]                     
                Results.append(" case when  " +Path[i]+"  then ''" +"{:4d}".format(i)+ " "+"{:2.2f}".format(impurity[i])+" "+Path[i][0:180]+"''")

            else:       
                if LeftParents[i]>=0:
                    if Path[LeftParents[i]]>" ":
                        Path[i]=Path[LeftParents[i]]+" AND " +ContsNode[LeftParents[i]]                                 
                    else:
                        Path[i]=ContsNode[LeftParents[i]]                                   
                if RightParents[i]>=0:
                    if Path[RightParents[i]]>" ":
                        Path[i]=Path[RightParents[i]]+" AND not " +ContsNode[RightParents[i]]                                   
                    else:
                        Path[i]=" not "+ContsNode[RightParents[i]]                      
                if (left[i]!=-1):
                    LeftParents[left[i]]=i
                if (right[i]!=-1):
                    RightParents[right[i]]=i
                ContsNode[i]=   "( "+ features[i] + " <= " + str(threshold[i])   + " ) "

    recurse(left, right, threshold, features, 0,0,0,0)
print_decision_tree(clf,features)
SqlOut=""
for i in range(len(Results)): 
    SqlOut=SqlOut+Results[i]+ " end,"+chr(13)+chr(10)

1

Bây giờ bạn có thể sử dụng export bản.

from sklearn.tree import export_text

r = export_text(loan_tree, feature_names=(list(X_train.columns)))
print(r)

Một ví dụ hoàn chỉnh từ [sklearn] [1]

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_text
iris = load_iris()
X = iris['data']
y = iris['target']
decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2)
decision_tree = decision_tree.fit(X, y)
r = export_text(decision_tree, feature_names=iris['feature_names'])
print(r)

0

Mã Zelazny7 đã sửa đổi để tìm nạp SQL từ cây quyết định.

# SQL from decision tree

def get_lineage(tree, feature_names):
     left      = tree.tree_.children_left
     right     = tree.tree_.children_right
     threshold = tree.tree_.threshold
     features  = [feature_names[i] for i in tree.tree_.feature]
     le='<='               
     g ='>'
     # get ids of child nodes
     idx = np.argwhere(left == -1)[:,0]     

     def recurse(left, right, child, lineage=None):          
          if lineage is None:
               lineage = [child]
          if child in left:
               parent = np.where(left == child)[0].item()
               split = 'l'
          else:
               parent = np.where(right == child)[0].item()
               split = 'r'
          lineage.append((parent, split, threshold[parent], features[parent]))
          if parent == 0:
               lineage.reverse()
               return lineage
          else:
               return recurse(left, right, parent, lineage)
     print 'case '
     for j,child in enumerate(idx):
        clause=' when '
        for node in recurse(left, right, child):
            if len(str(node))<3:
                continue
            i=node
            if i[1]=='l':  sign=le 
            else: sign=g
            clause=clause+i[3]+sign+str(i[2])+' and '
        clause=clause[:-4]+' then '+str(j)
        print clause
     print 'else 99 end as clusters'

0

Rõ ràng từ lâu, ai đó đã quyết định thử thêm chức năng sau vào các hàm xuất cây chính thức của scikit (về cơ bản chỉ hỗ trợ export_graphviz)

def export_dict(tree, feature_names=None, max_depth=None) :
    """Export a decision tree in dict format.

Đây là cam kết đầy đủ của anh ấy:

https://github.com/scikit-learn/scikit-learn/blob/79bdc8f711d0af225ed6be9fdb708cea9f98a910/sklearn/tree/export.py

Không chắc chắn chính xác những gì đã xảy ra với bình luận này. Nhưng bạn cũng có thể thử sử dụng chức năng đó.

Tôi nghĩ rằng điều này đảm bảo một yêu cầu tài liệu nghiêm túc cho những người giỏi về scikit-learn để viết tài liệu đúng về sklearn.tree.TreeAPI, cấu trúc cây bên dưới DecisionTreeClassifierlộ ra như thuộc tính của nó tree_.


0

Chỉ cần sử dụng chức năng từ sklearn.tree như thế này

from sklearn.tree import export_graphviz
    export_graphviz(tree,
                out_file = "tree.dot",
                feature_names = tree.columns) //or just ["petal length", "petal width"]

Và sau đó tìm trong thư mục dự án của bạn để tìm tập tin tree.dot , sao chép TẤT CẢ nội dung và dán vào đây http://www.webgraphviz.com/ và tạo biểu đồ của bạn :)


0

Cảm ơn giải pháp tuyệt vời của @paulkerfeld. Ngày đầu của giải pháp của mình, cho tất cả những ai muốn có một phiên bản serialized cây, chỉ cần sử dụng tree.threshold, tree.children_left, tree.children_right, tree.featuretree.value. Vì các lá không có sự phân chia và do đó không có tên tính năng và trẻ em, giữ chỗ của chúng trong tree.featuretree.children_***đang _tree.TREE_UNDEFINED_tree.TREE_LEAF. Mỗi phân chia được chỉ định một chỉ số duy nhất bởi depth first search.
Lưu ý rằng tree.valuehình dạng[n, 1, 1]


0

Đây là một hàm tạo mã Python từ cây quyết định bằng cách chuyển đổi đầu ra của export_text:

import string
from sklearn.tree import export_text

def export_py_code(tree, feature_names, max_depth=100, spacing=4):
    if spacing < 2:
        raise ValueError('spacing must be > 1')

    # Clean up feature names (for correctness)
    nums = string.digits
    alnums = string.ascii_letters + nums
    clean = lambda s: ''.join(c if c in alnums else '_' for c in s)
    features = [clean(x) for x in feature_names]
    features = ['_'+x if x[0] in nums else x for x in features if x]
    if len(set(features)) != len(feature_names):
        raise ValueError('invalid feature names')

    # First: export tree to text
    res = export_text(tree, feature_names=features, 
                        max_depth=max_depth,
                        decimals=6,
                        spacing=spacing-1)

    # Second: generate Python code from the text
    skip, dash = ' '*spacing, '-'*(spacing-1)
    code = 'def decision_tree({}):\n'.format(', '.join(features))
    for line in repr(tree).split('\n'):
        code += skip + "# " + line + '\n'
    for line in res.split('\n'):
        line = line.rstrip().replace('|',' ')
        if '<' in line or '>' in line:
            line, val = line.rsplit(maxsplit=1)
            line = line.replace(' ' + dash, 'if')
            line = '{} {:g}:'.format(line, float(val))
        else:
            line = line.replace(' {} class:'.format(dash), 'return')
        code += skip + line + '\n'

    return code

Sử dụng mẫu:

res = export_py_code(tree, feature_names=names, spacing=4)
print (res)

Đầu ra mẫu:

def decision_tree(f1, f2, f3):
    # DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=3,
    #                        max_features=None, max_leaf_nodes=None,
    #                        min_impurity_decrease=0.0, min_impurity_split=None,
    #                        min_samples_leaf=1, min_samples_split=2,
    #                        min_weight_fraction_leaf=0.0, presort=False,
    #                        random_state=42, splitter='best')
    if f1 <= 12.5:
        if f2 <= 17.5:
            if f1 <= 10.5:
                return 2
            if f1 > 10.5:
                return 3
        if f2 > 17.5:
            if f2 <= 22.5:
                return 1
            if f2 > 22.5:
                return 1
    if f1 > 12.5:
        if f1 <= 17.5:
            if f3 <= 23.5:
                return 2
            if f3 > 23.5:
                return 3
        if f1 > 17.5:
            if f1 <= 25:
                return 1
            if f1 > 25:
                return 2

Ví dụ trên được tạo ra với names = ['f'+str(j+1) for j in range(NUM_FEATURES)].

Một tính năng tiện dụng là nó có thể tạo kích thước tệp nhỏ hơn với khoảng cách giảm. Chỉ cần đặt spacing=2.

Khi sử dụng trang web của chúng tôi, bạn xác nhận rằng bạn đã đọc và hiểu Chính sách cookieChính sách bảo mật của chúng tôi.
Licensed under cc by-sa 3.0 with attribution required.