Karar kurallarını scikit-learn karar ağacından nasıl çıkarabilirim?


157

Temel karar kurallarını (veya 'karar yollarını') bir karar ağacındaki eğitimli bir ağaçtan metin listesi olarak alabilir miyim?

Gibi bir şey:

if A>0.4 then if B<0.2 then if C>0.8 then class='X'

Yardımınız için teşekkürler.



Bu soruna hiç cevap buldunuz mu? Karar ağacı kurallarını neredeyse tam olarak listelediğiniz gibi bir SAS veri adım biçiminde dışa aktarmak zorundayım.
Zelazny7

1
Sklearn-porter paketini karar ağaçlarını (rastgele orman ve güçlendirilmiş ağaçlar) C, Java, JavaScript ve diğerlerine aktarmak ve aktarmak için kullanabilirsiniz .
Darius

Bu bağlantıyı kontrol edebilirsiniz- kdnuggets.com/2017/05/…
yogesh agrawal

Yanıtlar:


139

Bu cevabın buradaki diğer cevaplardan daha doğru olduğuna inanıyorum:

from sklearn.tree import _tree

def tree_to_code(tree, feature_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print "def tree({}):".format(", ".join(feature_names))

    def recurse(node, depth):
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print "{}if {} <= {}:".format(indent, name, threshold)
            recurse(tree_.children_left[node], depth + 1)
            print "{}else:  # if {} > {}".format(indent, name, threshold)
            recurse(tree_.children_right[node], depth + 1)
        else:
            print "{}return {}".format(indent, tree_.value[node])

    recurse(0, 1)

Bu geçerli bir Python işlevi yazdırır. İşte girdisini döndürmeye çalışan bir ağaç için örnek çıktı, 0 ile 10 arasında bir sayı.

def tree(f0):
  if f0 <= 6.0:
    if f0 <= 1.5:
      return [[ 0.]]
    else:  # if f0 > 1.5
      if f0 <= 4.5:
        if f0 <= 3.5:
          return [[ 3.]]
        else:  # if f0 > 3.5
          return [[ 4.]]
      else:  # if f0 > 4.5
        return [[ 5.]]
  else:  # if f0 > 6.0
    if f0 <= 8.5:
      if f0 <= 7.5:
        return [[ 7.]]
      else:  # if f0 > 7.5
        return [[ 8.]]
    else:  # if f0 > 8.5
      return [[ 9.]]

İşte diğer cevaplarda gördüğüm bazı engeller:

  1. Kullanılması tree_.threshold == -2Bir düğüm bir yaprak olup olmadığına karar vermek iyi bir fikir değildir. Ya -2 eşiği olan gerçek bir karar düğümü ise? Bunun yerine, tree.featureveya 'ye bakmalısınız tree.children_*.
  2. Satır features = [feature_names[i] for i in tree_.feature], sklearn sürümümle kilitleniyor, çünkü bazı değerleri tree.tree_.feature-2 (özellikle yaprak düğümleri için).
  3. Özyinelemeli işlevde birden fazla if ifadesine gerek yoktur, sadece biri iyidir.

1
Bu kod benim için harika çalışıyor. Ancak, çıkış kodu bir insanın anlaması neredeyse imkansız böylece 500 + özellik_adı var. Yalnızca işleve merak ettiğim özellik_adlarını girmeme izin vermenin bir yolu var mı?
user3768495

1
Önceki yoruma katılıyorum. Sınıf indeksini döndürmek için fonksiyonun IIUC print "{}return {}".format(indent, tree_.value[node])olarak değiştirilmesi gerekir print "{}return {}".format(indent, np.argmax(tree_.value[node][0])).
soupault

1
@paulkernfeld Ah evet, görebildiğinizi görüyorum RandomForestClassifier.estimators_ , ancak tahmincilerin sonuçlarını nasıl birleştireceğimizi .
Nathan Lloyd

6
Python 3'te bu çalışmayı alamadım, _tree bitleri hiç çalışacak gibi görünmüyor ve TREE_UNDEFINED tanımlanmadı. Bu bağlantı bana yardımcı oldu. Dışa aktarılan kod python'da
Josiah

1
@Josiah, python3 içinde çalışması için basılı ifadelere () ekleyin. eg print "bla"=>print("bla")
Nir

48

Sklearn tarafından oluşturulan karar ağaçlarından kuralları çıkarmak için kendi işlevimi yarattım:

import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# dummy data:
df = pd.DataFrame({'col1':[0,1,2,3],'col2':[3,4,5,6],'dv':[0,1,0,1]})

# create decision tree
dt = DecisionTreeClassifier(max_depth=5, min_samples_leaf=1)
dt.fit(df.ix[:,:2], df.dv)

Bu işlev önce düğümlerle başlar (alt dizilerde -1 ile tanımlanır) ve daha sonra ebeveynleri özyinelemeli olarak bulur. Buna düğümün 'soy' diyorum. Yol boyunca if / then / else SAS mantığı için oluşturmam gereken değerleri yakalarım:

def get_lineage(tree, feature_names):
     left      = tree.tree_.children_left
     right     = tree.tree_.children_right
     threshold = tree.tree_.threshold
     features  = [feature_names[i] for i in tree.tree_.feature]

     # get ids of child nodes
     idx = np.argwhere(left == -1)[:,0]     

     def recurse(left, right, child, lineage=None):          
          if lineage is None:
               lineage = [child]
          if child in left:
               parent = np.where(left == child)[0].item()
               split = 'l'
          else:
               parent = np.where(right == child)[0].item()
               split = 'r'

          lineage.append((parent, split, threshold[parent], features[parent]))

          if parent == 0:
               lineage.reverse()
               return lineage
          else:
               return recurse(left, right, parent, lineage)

     for child in idx:
          for node in recurse(left, right, child):
               print node

Aşağıdaki tuples setleri, if / then / else ifadeleri oluşturmak için ihtiyacım olan her şeyi içerir. doSAS'ta blokları kullanmayı sevmiyorum , bu yüzden bir düğümün tüm yolunu tanımlayan mantık oluşturuyorum. Tüplerden sonraki tek tamsayı, bir yoldaki terminal düğümünün kimliğidir. Önceki tuples'lar bu düğümü oluşturmak için birleşir.

In [1]: get_lineage(dt, df.columns)
(0, 'l', 0.5, 'col1')
1
(0, 'r', 0.5, 'col1')
(2, 'l', 4.5, 'col2')
3
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'l', 2.5, 'col1')
5
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'r', 2.5, 'col1')
6

Örnek ağacın GraphViz çıktısı


Bu tür bir ağaç doğru çünkü col1 tekrar geliyor biri col1 <= 0.50000 ve bir col1 <= 2.5000 evet ise, bu kütüphanede kullanılan herhangi bir özyineleme
dilek mi

sağ dalın kayıtları var (0.5, 2.5]. Ağaçlar özyinelemeli bölümleme ile yapılır. Bir değişkenin birden çok kez seçilmesini engelleyen hiçbir şey yoktur.
Zelazny7

tamam özyinelemeli kısmı açıklayabilir miyim xactly olur neden kodumu kullandım ve benzer sonuç görülür
jayant singh

38

Bazı sözde kod yazdırmak için Zelazny7 tarafından gönderilen kodu değiştirdim :

def get_code(tree, feature_names):
        left      = tree.tree_.children_left
        right     = tree.tree_.children_right
        threshold = tree.tree_.threshold
        features  = [feature_names[i] for i in tree.tree_.feature]
        value = tree.tree_.value

        def recurse(left, right, threshold, features, node):
                if (threshold[node] != -2):
                        print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
                        if left[node] != -1:
                                recurse (left, right, threshold, features,left[node])
                        print "} else {"
                        if right[node] != -1:
                                recurse (left, right, threshold, features,right[node])
                        print "}"
                else:
                        print "return " + str(value[node])

        recurse(left, right, threshold, features, 0)

get_code(dt, df.columns)aynı örneği ararsanız aşağıdakileri elde edersiniz:

if ( col1 <= 0.5 ) {
return [[ 1.  0.]]
} else {
if ( col2 <= 4.5 ) {
return [[ 0.  1.]]
} else {
if ( col1 <= 2.5 ) {
return [[ 1.  0.]]
} else {
return [[ 0.  1.]]
}
}
}

1
Yukarıdaki ifadede dönüş ifadesindeki tam olarak [[1. 0.]] ne anlama geldiğini söyleyebilir misiniz? Ben bir Python adamı değilim, ama aynı tür şeyler üzerinde çalışıyorum. Bu yüzden, benim için daha kolay olması için bazı detayları kanıtlarsanız benim için iyi olacak.
Subhradip Bose

1
@ user3156186 Bu, '0' sınıfında bir nesne ve '1' sınıfında sıfır nesne olduğu anlamına gelir
Daniele

1
@Daniele, sınıfların nasıl sıralandığını biliyor musunuz? Alfasayısal tahmin ediyorum, ama hiçbir yerde onay bulamadım.
IanS

Teşekkürler! Eşik değerinin aslında -2 olduğu uç senaryo senaryosu (threshold[node] != -2)için ( left[node] != -1)(alt düğümlerin kimliklerini almak için aşağıdaki yönteme benzer)
tlingf

@Daniele, nasıl "get_code" "bir değer döndürmek ve" baskı "değil, bir işlev yapmak için herhangi bir fikir, çünkü başka bir işleve göndermek gerekir?
RoyaumeIX

17

Scikit öğren export_text, kuralları bir ağaçtan çıkarmak için 0.21 (Mayıs 2019) sürümünde adlandırılan lezzetli yeni bir yöntem tanıttı . Buradaki belgeler . Artık özel bir işlev oluşturmak gerekli değildir.

Modelinize uyduktan sonra, sadece iki satır kod gerekir. İlk olarak, içe aktarın export_text:

from sklearn.tree.export import export_text

İkinci olarak, kurallarınızı içerecek bir nesne oluşturun. Kuralların daha okunabilir görünmesini sağlamak için, feature_namesargümanı kullanın ve özellik adlarınızın bir listesini iletin. Örneğin, modeliniz çağrılırsa modelve özellikleriniz adlı bir veri çerçevesinde adlandırılırsa X_train, şu adlı bir nesne oluşturabilirsiniz tree_rules:

tree_rules = export_text(model, feature_names=list(X_train))

Sonra yazdırın veya kaydedin tree_rules. Çıktınız şöyle görünecek:

|--- Age <= 0.63
|   |--- EstimatedSalary <= 0.61
|   |   |--- Age <= -0.16
|   |   |   |--- class: 0
|   |   |--- Age >  -0.16
|   |   |   |--- EstimatedSalary <= -0.06
|   |   |   |   |--- class: 0
|   |   |   |--- EstimatedSalary >  -0.06
|   |   |   |   |--- EstimatedSalary <= 0.40
|   |   |   |   |   |--- EstimatedSalary <= 0.03
|   |   |   |   |   |   |--- class: 1

14

Yeni Orada DecisionTreeClassifieryöntem decision_pathde, 0.18.0 sürümü. Geliştiriciler, kapsamlı (iyi belgelenmiş) bir izlenecek yol sağlar .

İzlenecek yoldaki ağaç yapısını yazdıran ilk kod bölümü iyi görünüyor. Ancak, bir örnek sorgulamak için ikinci bölümdeki kodu değiştirdim. Değişikliklerim ile belirtildi# <--

Düzenle Aşağıdaki # <--kodda işaretlenen değişiklikler , # 8653 ve # 10951 numaralı çekme isteklerinde hatalar belirtildikten sonra, izlenecek bağlantıda güncellenmiştir . Şimdi takip etmek çok daha kolay.

sample_id = 0
node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                    node_indicator.indptr[sample_id + 1]]

print('Rules used to predict sample %s: ' % sample_id)
for node_id in node_index:

    if leave_id[sample_id] == node_id:  # <-- changed != to ==
        #continue # <-- comment out
        print("leaf node {} reached, no decision here".format(leave_id[sample_id])) # <--

    else: # < -- added else to iterate through decision nodes
        if (X_test[sample_id, feature[node_id]] <= threshold[node_id]):
            threshold_sign = "<="
        else:
            threshold_sign = ">"

        print("decision id node %s : (X[%s, %s] (= %s) %s %s)"
              % (node_id,
                 sample_id,
                 feature[node_id],
                 X_test[sample_id, feature[node_id]], # <-- changed i to sample_id
                 threshold_sign,
                 threshold[node_id]))

Rules used to predict sample 0: 
decision id node 0 : (X[0, 3] (= 2.4) > 0.800000011921)
decision id node 2 : (X[0, 2] (= 5.1) > 4.94999980927)
leaf node 4 reached, no decision here

sample_idDiğer örneklerin karar yollarını görmek için simgesini değiştirin . Geliştiricilere bu değişiklikleri sormadım, örnek üzerinde çalışırken daha sezgisel görünüyordu.


sen arkadaşım bir efsanesin! Belirli bir örnek için karar ağacının nasıl çizileceği hakkında bir fikriniz var mı? çok yardım için teşekkür ederiz

1
Teşekkürler Victor, çizim gereksinimleri bir kullanıcının ihtiyaçlarına özel olabileceğinden, bunu ayrı bir soru olarak sormak muhtemelen en iyisidir. Çıktının nasıl görünmesini istediğinize dair bir fikir verirseniz muhtemelen iyi bir yanıt alırsınız.
Kevin

hey kevin, şu soruyu oluşturdum stackoverflow.com/questions/48888893/…

Şuna bir göz atmak çok nazik olur mu? stackoverflow.com/questions/52654280/…
Alexander Chervov

Lütfen node_index adlı kısmı açıklayabilir misiniz? bu ne işe yarıyor?
Anindya Sankar Dey

12
from StringIO import StringIO
out = StringIO()
out = tree.export_graphviz(clf, out_file=out)
print out.getvalue()

Bir digraf Ağacı görebilirsiniz. Sonra, clf.tree_.featureve clf.tree_.valuesırasıyla düğüm bölme özelliği ve düğüm değerleri dizisi vardır. Bu github kaynağından daha fazla ayrıntıya başvurabilirsiniz .


1
Evet, ağacın nasıl çizileceğini biliyorum - ama daha metinsel versiyona ihtiyacım var - kurallar. şuna
Dror Hilman

4

Herkes çok yardımcı olduğu için Zelazny7 ve Daniele'nin güzel çözümlerine bir değişiklik ekleyeceğim. Bu, daha okunabilir hale getirmek için sekmeleri olan python 2.7 içindir:

def get_code(tree, feature_names, tabdepth=0):
    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    features  = [feature_names[i] for i in tree.tree_.feature]
    value = tree.tree_.value

    def recurse(left, right, threshold, features, node, tabdepth=0):
            if (threshold[node] != -2):
                    print '\t' * tabdepth,
                    print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
                    if left[node] != -1:
                            recurse (left, right, threshold, features,left[node], tabdepth+1)
                    print '\t' * tabdepth,
                    print "} else {"
                    if right[node] != -1:
                            recurse (left, right, threshold, features,right[node], tabdepth+1)
                    print '\t' * tabdepth,
                    print "}"
            else:
                    print '\t' * tabdepth,
                    print "return " + str(value[node])

    recurse(left, right, threshold, features, 0)

3

Aşağıdaki kodlar, anaconda python 2.7 ve paket kuralları "pydot-ng" altında karar kurallarıyla PDF dosyası yapmak için benim yaklaşımımdır. Umarım faydalıdır.

from sklearn import tree

clf = tree.DecisionTreeClassifier(max_leaf_nodes=n)
clf_ = clf.fit(X, data_y)

feature_names = X.columns
class_name = clf_.classes_.astype(int).astype(str)

def output_pdf(clf_, name):
    from sklearn import tree
    from sklearn.externals.six import StringIO
    import pydot_ng as pydot
    dot_data = StringIO()
    tree.export_graphviz(clf_, out_file=dot_data,
                         feature_names=feature_names,
                         class_names=class_name,
                         filled=True, rounded=True,
                         special_characters=True,
                          node_ids=1,)
    graph = pydot.graph_from_dot_data(dot_data.getvalue())
    graph.write_pdf("%s.pdf"%name)

output_pdf(clf_, name='filename%s'%n)

burada bir ağaç grafik gösterisi


3

Bunu yaptım ama kuralların bu formatta yazılması gerekiyordu

if A>0.4 then if B<0.2 then if C>0.8 then class='X' 

Böylece @paulkernfeld'in cevabını uyarladım (teşekkürler) ihtiyaçlarınıza göre özelleştirebilirsiniz

def tree_to_code(tree, feature_names, Y):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    pathto=dict()

    global k
    k = 0
    def recurse(node, depth, parent):
        global k
        indent = "  " * depth

        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            s= "{} <= {} ".format( name, threshold, node )
            if node == 0:
                pathto[node]=s
            else:
                pathto[node]=pathto[parent]+' & ' +s

            recurse(tree_.children_left[node], depth + 1, node)
            s="{} > {}".format( name, threshold)
            if node == 0:
                pathto[node]=s
            else:
                pathto[node]=pathto[parent]+' & ' +s
            recurse(tree_.children_right[node], depth + 1, node)
        else:
            k=k+1
            print(k,')',pathto[parent], tree_.value[node])
    recurse(0, 1, 0)

3

İşte tüm ağacı SKompiler kütüphanesini kullanarak tek bir (mutlaka insan tarafından okunamayan) bir python ifadesine çevirmenin bir yolu :

from skompiler import skompile
skompile(dtree.predict).to('python/code')

3

Bu @paulkernfeld'in cevabı üzerine kuruludur. Özelliklerinizle birlikte bir veri kareniz X ve resonlarınızla birlikte bir hedef veri kareniz varsa ve hangi y değerinin hangi düğümde sona erdiğini (ve buna göre çizmek için karınca) bir fikir edinmek istiyorsanız aşağıdakileri yapabilirsiniz:

    def tree_to_code(tree, feature_names):
        from sklearn.tree import _tree
        codelines = []
        codelines.append('def get_cat(X_tmp):\n')
        codelines.append('   catout = []\n')
        codelines.append('   for codelines in range(0,X_tmp.shape[0]):\n')
        codelines.append('      Xin = X_tmp.iloc[codelines]\n')
        tree_ = tree.tree_
        feature_name = [
            feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
            for i in tree_.feature
        ]
        #print "def tree({}):".format(", ".join(feature_names))

        def recurse(node, depth):
            indent = "      " * depth
            if tree_.feature[node] != _tree.TREE_UNDEFINED:
                name = feature_name[node]
                threshold = tree_.threshold[node]
                codelines.append ('{}if Xin["{}"] <= {}:\n'.format(indent, name, threshold))
                recurse(tree_.children_left[node], depth + 1)
                codelines.append( '{}else:  # if Xin["{}"] > {}\n'.format(indent, name, threshold))
                recurse(tree_.children_right[node], depth + 1)
            else:
                codelines.append( '{}mycat = {}\n'.format(indent, node))

        recurse(0, 1)
        codelines.append('      catout.append(mycat)\n')
        codelines.append('   return pd.DataFrame(catout,index=X_tmp.index,columns=["category"])\n')
        codelines.append('node_ids = get_cat(X)\n')
        return codelines
    mycode = tree_to_code(clf,X.columns.values)

    # now execute the function and obtain the dataframe with all nodes
    exec(''.join(mycode))
    node_ids = [int(x[0]) for x in node_ids.values]
    node_ids2 = pd.DataFrame(node_ids)

    print('make plot')
    import matplotlib.cm as cm
    colors = cm.rainbow(np.linspace(0, 1, 1+max( list(set(node_ids)))))
    #plt.figure(figsize=cm2inch(24, 21))
    for i in list(set(node_ids)):
        plt.plot(y[node_ids2.values==i],'o',color=colors[i], label=str(i))  
    mytitle = ['y colored by node']
    plt.title(mytitle ,fontsize=14)
    plt.xlabel('my xlabel')
    plt.ylabel(tagname)
    plt.xticks(rotation=70)       
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.00), shadow=True, ncol=9)
    plt.tight_layout()
    plt.show()
    plt.close 

en şık versiyonu değil ama işi yapıyor ...


1
Kod satırlarını yazdırmak yerine döndürmek istediğinizde bu iyi bir yaklaşımdır.
Hajar Homayouni

3

Bu ihtiyacınız olan kod

Bir jupyter notebook python 3 girintili en sevilen kodu değiştirdim

import numpy as np
from sklearn.tree import _tree

def tree_to_code(tree, feature_names):
    tree_ = tree.tree_
    feature_name = [feature_names[i] 
                    if i != _tree.TREE_UNDEFINED else "undefined!" 
                    for i in tree_.feature]
    print("def tree({}):".format(", ".join(feature_names)))

    def recurse(node, depth):
        indent = "    " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print("{}if {} <= {}:".format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            print("{}else:  # if {} > {}".format(indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            print("{}return {}".format(indent, np.argmax(tree_.value[node])))

    recurse(0, 1)

2

İşte bir fonksiyon, python 3 altında bir scikit-öğren karar ağacının kurallarını yazdırma ve yapıyı daha okunabilir hale getirmek için koşullu bloklar için ofsetler:

def print_decision_tree(tree, feature_names=None, offset_unit='    '):
    '''Plots textual representation of rules of a decision tree
    tree: scikit-learn representation of tree
    feature_names: list of feature names. They are set to f1,f2,f3,... if not specified
    offset_unit: a string of offset of the conditional block'''

    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    value = tree.tree_.value
    if feature_names is None:
        features  = ['f%d'%i for i in tree.tree_.feature]
    else:
        features  = [feature_names[i] for i in tree.tree_.feature]        

    def recurse(left, right, threshold, features, node, depth=0):
            offset = offset_unit*depth
            if (threshold[node] != -2):
                    print(offset+"if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
                    if left[node] != -1:
                            recurse (left, right, threshold, features,left[node],depth+1)
                    print(offset+"} else {")
                    if right[node] != -1:
                            recurse (left, right, threshold, features,right[node],depth+1)
                    print(offset+"}")
            else:
                    print(offset+"return " + str(value[node]))

    recurse(left, right, threshold, features, 0,0)

2

Hangi sınıfa ait olduğunu ayırt ederek veya hatta çıktı değerinden söz ederek daha bilgilendirici de yapabilirsiniz.

def print_decision_tree(tree, feature_names, offset_unit='    '):    
left      = tree.tree_.children_left
right     = tree.tree_.children_right
threshold = tree.tree_.threshold
value = tree.tree_.value
if feature_names is None:
    features  = ['f%d'%i for i in tree.tree_.feature]
else:
    features  = [feature_names[i] for i in tree.tree_.feature]        

def recurse(left, right, threshold, features, node, depth=0):
        offset = offset_unit*depth
        if (threshold[node] != -2):
                print(offset+"if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
                if left[node] != -1:
                        recurse (left, right, threshold, features,left[node],depth+1)
                print(offset+"} else {")
                if right[node] != -1:
                        recurse (left, right, threshold, features,right[node],depth+1)
                print(offset+"}")
        else:
                #print(offset,value[node]) 

                #To remove values from node
                temp=str(value[node])
                mid=len(temp)//2
                tempx=[]
                tempy=[]
                cnt=0
                for i in temp:
                    if cnt<=mid:
                        tempx.append(i)
                        cnt+=1
                    else:
                        tempy.append(i)
                        cnt+=1
                val_yes=[]
                val_no=[]
                res=[]
                for j in tempx:
                    if j=="[" or j=="]" or j=="." or j==" ":
                        res.append(j)
                    else:
                        val_no.append(j)
                for j in tempy:
                    if j=="[" or j=="]" or j=="." or j==" ":
                        res.append(j)
                    else:
                        val_yes.append(j)
                val_yes = int("".join(map(str, val_yes)))
                val_no = int("".join(map(str, val_no)))

                if val_yes>val_no:
                    print(offset,'\033[1m',"YES")
                    print('\033[0m')
                elif val_no>val_yes:
                    print(offset,'\033[1m',"NO")
                    print('\033[0m')
                else:
                    print(offset,'\033[1m',"Tie")
                    print('\033[0m')

recurse(left, right, threshold, features, 0,0)

resim açıklamasını buraya girin


2

İşte karar kurallarını doğrudan sql içinde kullanılabilecek bir formda çıkarmak için yaklaşımım, böylece veriler düğüm tarafından gruplandırılabilir. (Önceki posterlerin yaklaşımlarına dayanarak.)

Sonuç, CASEsql deyimine kopyalanabilen sonraki maddeler olacaktır , örn.

SELECT COALESCE(*CASE WHEN <conditions> THEN > <NodeA>*, > *CASE WHEN <conditions> THEN <NodeB>*, > ....)NodeName,* > FROM <table or view>


import numpy as np

import pickle
feature_names=.............
features  = [feature_names[i] for i in range(len(feature_names))]
clf= pickle.loads(trained_model)
impurity=clf.tree_.impurity
importances = clf.feature_importances_
SqlOut=""

#global Conts
global ContsNode
global Path
#Conts=[]#
ContsNode=[]
Path=[]
global Results
Results=[]

def print_decision_tree(tree, feature_names, offset_unit=''    ''):    
    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    value = tree.tree_.value

    if feature_names is None:
        features  = [''f%d''%i for i in tree.tree_.feature]
    else:
        features  = [feature_names[i] for i in tree.tree_.feature]        

    def recurse(left, right, threshold, features, node, depth=0,ParentNode=0,IsElse=0):
        global Conts
        global ContsNode
        global Path
        global Results
        global LeftParents
        LeftParents=[]
        global RightParents
        RightParents=[]
        for i in range(len(left)): # This is just to tell you how to create a list.
            LeftParents.append(-1)
            RightParents.append(-1)
            ContsNode.append("")
            Path.append("")


        for i in range(len(left)): # i is node
            if (left[i]==-1 and right[i]==-1):      
                if LeftParents[i]>=0:
                    if Path[LeftParents[i]]>" ":
                        Path[i]=Path[LeftParents[i]]+" AND " +ContsNode[LeftParents[i]]                                 
                    else:
                        Path[i]=ContsNode[LeftParents[i]]                                   
                if RightParents[i]>=0:
                    if Path[RightParents[i]]>" ":
                        Path[i]=Path[RightParents[i]]+" AND not " +ContsNode[RightParents[i]]                                   
                    else:
                        Path[i]=" not " +ContsNode[RightParents[i]]                     
                Results.append(" case when  " +Path[i]+"  then ''" +"{:4d}".format(i)+ " "+"{:2.2f}".format(impurity[i])+" "+Path[i][0:180]+"''")

            else:       
                if LeftParents[i]>=0:
                    if Path[LeftParents[i]]>" ":
                        Path[i]=Path[LeftParents[i]]+" AND " +ContsNode[LeftParents[i]]                                 
                    else:
                        Path[i]=ContsNode[LeftParents[i]]                                   
                if RightParents[i]>=0:
                    if Path[RightParents[i]]>" ":
                        Path[i]=Path[RightParents[i]]+" AND not " +ContsNode[RightParents[i]]                                   
                    else:
                        Path[i]=" not "+ContsNode[RightParents[i]]                      
                if (left[i]!=-1):
                    LeftParents[left[i]]=i
                if (right[i]!=-1):
                    RightParents[right[i]]=i
                ContsNode[i]=   "( "+ features[i] + " <= " + str(threshold[i])   + " ) "

    recurse(left, right, threshold, features, 0,0,0,0)
print_decision_tree(clf,features)
SqlOut=""
for i in range(len(Results)): 
    SqlOut=SqlOut+Results[i]+ " end,"+chr(13)+chr(10)

1

Artık export_text kullanabilirsiniz.

from sklearn.tree import export_text

r = export_text(loan_tree, feature_names=(list(X_train.columns)))
print(r)

[Sklearn] [1] 'den tam bir örnek

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_text
iris = load_iris()
X = iris['data']
y = iris['target']
decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2)
decision_tree = decision_tree.fit(X, y)
r = export_text(decision_tree, feature_names=iris['feature_names'])
print(r)

0

Zelazny7'nin kodu, karar ağacından SQL getirecek şekilde değiştirildi.

# SQL from decision tree

def get_lineage(tree, feature_names):
     left      = tree.tree_.children_left
     right     = tree.tree_.children_right
     threshold = tree.tree_.threshold
     features  = [feature_names[i] for i in tree.tree_.feature]
     le='<='               
     g ='>'
     # get ids of child nodes
     idx = np.argwhere(left == -1)[:,0]     

     def recurse(left, right, child, lineage=None):          
          if lineage is None:
               lineage = [child]
          if child in left:
               parent = np.where(left == child)[0].item()
               split = 'l'
          else:
               parent = np.where(right == child)[0].item()
               split = 'r'
          lineage.append((parent, split, threshold[parent], features[parent]))
          if parent == 0:
               lineage.reverse()
               return lineage
          else:
               return recurse(left, right, parent, lineage)
     print 'case '
     for j,child in enumerate(idx):
        clause=' when '
        for node in recurse(left, right, child):
            if len(str(node))<3:
                continue
            i=node
            if i[1]=='l':  sign=le 
            else: sign=g
            clause=clause+i[3]+sign+str(i[2])+' and '
        clause=clause[:-4]+' then '+str(j)
        print clause
     print 'else 99 end as clusters'

0

Görünüşe göre uzun zaman önce birileri resmi scikit'in ağaç dışa aktarma işlevlerine aşağıdaki işlevi eklemeye karar verdi (temel olarak sadece export_graphviz'i destekliyor)

def export_dict(tree, feature_names=None, max_depth=None) :
    """Export a decision tree in dict format.

İşte tam taahhüdü:

https://github.com/scikit-learn/scikit-learn/blob/79bdc8f711d0af225ed6be9fdb708cea9f98a910/sklearn/tree/export.py

Bu yoruma ne olduğundan tam olarak emin değilim. Ancak bu işlevi kullanmayı da deneyebilirsiniz.

Sanırım bu, scikit-öğrenmenin iyi kişilerine , özniteliği olarak ortaya sklearn.tree.Treeçıkan temel ağaç yapısı olan API'yi düzgün bir şekilde belgelendirmek için ciddi bir belge talebi gerektirdiğini düşünüyorum .DecisionTreeClassifiertree_


0

Sadece sklearn.tree fonksiyonunu şu şekilde kullanın

from sklearn.tree import export_graphviz
    export_graphviz(tree,
                out_file = "tree.dot",
                feature_names = tree.columns) //or just ["petal length", "petal width"]

Ve sonra tree.dot dosyası için proje klasörünüze bakın, TÜM içeriği kopyalayın ve buraya yapıştırın http://www.webgraphviz.com/ ve grafiğinizi oluşturun :)


0

@Paulkerfeld'in harika çözümü için teşekkür ederiz. Onun çözümü üzerine, ağaçların tefrika sürümüne sahip olmak isteyenler, sadece kullanmak için tree.threshold, tree.children_left, tree.children_right, tree.featureve tree.value. Yaprakları yarıklara sahip olmuştur ve dolayısıyla hiçbir isim ve çocukları, onların yer tutucu özelliği olmadığından tree.featureve tree.children_***vardır _tree.TREE_UNDEFINEDve _tree.TREE_LEAF. Her bölüme benzersiz bir dizin atanır depth first search. Şeklin
olduğuna dikkat edintree.value[n, 1, 1]


0

İşte çıktısını dönüştürerek bir karar ağacından Python kodu üreten bir fonksiyon export_text:

import string
from sklearn.tree import export_text

def export_py_code(tree, feature_names, max_depth=100, spacing=4):
    if spacing < 2:
        raise ValueError('spacing must be > 1')

    # Clean up feature names (for correctness)
    nums = string.digits
    alnums = string.ascii_letters + nums
    clean = lambda s: ''.join(c if c in alnums else '_' for c in s)
    features = [clean(x) for x in feature_names]
    features = ['_'+x if x[0] in nums else x for x in features if x]
    if len(set(features)) != len(feature_names):
        raise ValueError('invalid feature names')

    # First: export tree to text
    res = export_text(tree, feature_names=features, 
                        max_depth=max_depth,
                        decimals=6,
                        spacing=spacing-1)

    # Second: generate Python code from the text
    skip, dash = ' '*spacing, '-'*(spacing-1)
    code = 'def decision_tree({}):\n'.format(', '.join(features))
    for line in repr(tree).split('\n'):
        code += skip + "# " + line + '\n'
    for line in res.split('\n'):
        line = line.rstrip().replace('|',' ')
        if '<' in line or '>' in line:
            line, val = line.rsplit(maxsplit=1)
            line = line.replace(' ' + dash, 'if')
            line = '{} {:g}:'.format(line, float(val))
        else:
            line = line.replace(' {} class:'.format(dash), 'return')
        code += skip + line + '\n'

    return code

Örnek kullanım:

res = export_py_code(tree, feature_names=names, spacing=4)
print (res)

Örnek çıktı:

def decision_tree(f1, f2, f3):
    # DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=3,
    #                        max_features=None, max_leaf_nodes=None,
    #                        min_impurity_decrease=0.0, min_impurity_split=None,
    #                        min_samples_leaf=1, min_samples_split=2,
    #                        min_weight_fraction_leaf=0.0, presort=False,
    #                        random_state=42, splitter='best')
    if f1 <= 12.5:
        if f2 <= 17.5:
            if f1 <= 10.5:
                return 2
            if f1 > 10.5:
                return 3
        if f2 > 17.5:
            if f2 <= 22.5:
                return 1
            if f2 > 22.5:
                return 1
    if f1 > 12.5:
        if f1 <= 17.5:
            if f3 <= 23.5:
                return 2
            if f3 > 23.5:
                return 3
        if f1 > 17.5:
            if f1 <= 25:
                return 1
            if f1 > 25:
                return 2

Yukarıdaki örnek, names = ['f'+str(j+1) for j in range(NUM_FEATURES)] .

Kullanışlı bir özellik, daha az boşluk ile daha küçük dosya boyutu üretebilmesidir. Sadece hazır spacing=2.

Sitemizi kullandığınızda şunları okuyup anladığınızı kabul etmiş olursunuz: Çerez Politikası ve Gizlilik Politikası.
Licensed under cc by-sa 3.0 with attribution required.