Scikit predict_proba çıktı yorumu


12

Python'da scikit-learn kütüphanesi ile çalışıyorum. Aşağıdaki kodda, olasılığı tahmin ediyorum, ancak çıktıyı nasıl okuyacağımı bilmiyorum.

Verileri test etme

from sklearn.ensemble import RandomForestClassifier as RF
from sklearn import cross_validation

X = np.array([[5,5,5,5],[10,10,10,10],[1,1,1,1],[6,6,6,6],[13,13,13,13],[2,2,2,2]])
y = np.array([0,1,1,0,1,2])

Veri kümesini böl

X_train, X_test, y_train, y_test = cross_validation.train_test_split(X, y, test_size=0.5, random_state=0) 

Olasılığı hesapla

clf = RF()
clf.fit(X_train,y_train)
pred_pro = clf.predict_proba(X_test)
print pred_pro

Çıktı

[[ 1.  0.]
 [ 1.  0.]
 [ 0.  1.]]

X_test listesi 3 dizi içerir (6 örneğim ve test_size = 0,5), bu nedenle çıktıda 3 var.

Ama ben 3 değer (0,1,2) tahmin ediyorum, bu yüzden neden her dizide sadece 2 eleman alıyorum?

Çıktıyı nasıl okumalıyım?

Ayrıca, y'deki farklı değerlerin sayısını değiştirdiğimde, çıktıdaki sütun sayısının her zaman y-1'in farklı sayımı olduğunu fark ettim.


CrossValidated'a hoş geldiniz. Cevabımı aşağıda gördün mü? Sorunuzu çözdüyse, devam edin ve doğru cevap olarak işaretleyin. Aksi takdirde, neyin eksik olduğunu bana bildirin, sizin için temizlemeye çalışacağım.
Ben

Yanıtlar:


5

Bir göz atın y_train. Öyle array([0, 0, 1]). Bu, split'inizin y = 2 olduğu örneği almadığı anlamına gelir. Dolayısıyla, modelinizin y = 2 sınıfının var olduğu hakkında hiçbir fikri yoktur.

Anlamlı bir şey döndürmek için daha fazla örneğe ihtiyacınız var.

Çıktının nasıl yorumlanacağını anlamak için dokümanlara da göz atın .


1
Doğru. Ayarladıysanız y = np.array([0,2,1,0,1,2])ve random_state=2şimdi 3 çıkış sütunu göreceksiniz
tdc

Cevap sorumu çözdü. Çok teşekkür ederim. Sütunlar hangi sırada? Her zaman artan mı?
HonzaB

Koş clf.classes_. Sütunlar bu sırayla olacaktır.
Ben

Sadece böyle: clf.fit(X_train,y_train).classes_?
HonzaB

1
Bence bu işe yarayacaktır ama clf.classes_ clf.fit(X_train,y_train)
Ben
Sitemizi kullandığınızda şunları okuyup anladığınızı kabul etmiş olursunuz: Çerez Politikası ve Gizlilik Politikası.
Licensed under cc by-sa 3.0 with attribution required.