Basit bir lojistik regresyon modeli MNIST'de nasıl bir% 92 sınıflandırma doğruluğu elde eder?


64

MNIST veri setindeki tüm görüntüler ortalanmış olsa da, benzer bir ölçekte ve dönme olmadan yüz yüze olsalar bile, lineer bir modelin bu kadar yüksek bir sınıflandırma doğruluğunu nasıl elde ettiğini gösteren çok önemli bir el yazısı varyasyonu var.

Görselleştirebildiğim kadarıyla, önemli el yazısı varyasyonu dikkate alındığında, rakamlar 784 boyutlu bir uzayda doğrusal olarak birbirinden ayrılmamalı, yani farklı basamakları ayıran küçük bir kompleks (çok karmaşık olmasa da) doğrusal olmayan bir sınır olmalıdır. pozitif ve negatif sınıfların herhangi bir doğrusal sınıflandırıcı ile ayrılamayacağı , iyi belirtilmiş örneğine benzer şekilde . Çok sınıflı lojistik regresyonun tamamen lineer özelliklerle (polinom özelliği olmadan) nasıl bu kadar yüksek bir doğruluk sağladığını bana şaşırttı.XOR

Örnek olarak, görüntüdeki herhangi bir piksel verildiğinde, ve basamakların farklı el yazısı varyasyonları bu pikseli aydınlatıp açmamayı sağlayabilir. Bu nedenle, öğrenilen ağırlık dizisi ile, her piksel bir gibi bir rakam bakmak yapabilirsiniz aynı zamanda iyi tanımlanmış . Sadece bir piksel değerlerinin birleşimiyle bir rakamın mi mü olduğunu söylemek mümkün olmalıdır . Bu, rakam çiftlerinin çoğu için geçerlidir. Öyleyse, kararını bağımsız olarak tüm piksel değerlerine (hiçbir pikseller arası bağımlılığı düşünmeden) bağımsız bir şekilde temel alan lojistik regresyon, böyle yüksek doğruluklara ulaşabiliyor.232323

Bir yerlerde yanıldığımı veya resimlerdeki çeşitliliği fazla tahmin ettiğimi biliyorum. Bununla birlikte, birisinin rakamların 'neredeyse' doğrusal olarak nasıl ayrılabilir olduğu konusunda bir sezgiye yardımcı olması harika olurdu.


Sparsity ile ders kitabı İstatistiksel Öğrenme göz at: Kement ve genellemeler 3.3.1 Örnek: Handwritten Rakamlar web.stanford.edu/~hastie/StatLearnSparsity_files/SLS.pdf
Adrian

Merak ettim: cezalandırılmış bir lineer model (yani, glmnet) gibi bir şey bu problemi ne kadar iyi yapıyor? Hatırladığımı hatırlatırsam, bildirdiğin şey örneklenmemiş örnek dışı doğruluktur.
Cliff AB,

Yanıtlar:


82

tl; dr Her ne kadar bu bir sınıflandırma veri seti olsa da , girdilerden tahminlere kadar kolayca doğrudan bir harita bulabilen çok kolay bir görevdir .


Cevap:

Bu çok ilginç bir sorudur ve lojistik regresyonun basitliği sayesinde cevabı gerçekten öğrenebilirsiniz.

Lojistik regresyonun yaptığı, her bir görüntü için girişi kabul etmekte ve öngörüsünü oluşturmak için bunları ağırlıklarla çarpmaktadır. İlginç olan, girdi ve çıktı arasındaki doğrudan eşlemeden dolayı (yani gizli katman yok), her ağırlığın değerinin , her sınıfın olasılığını hesaplarken girdiden her birinin ne kadar dikkate alındığına karşılık gelmesidir. Şimdi, her bir sınıf için ağırlıklar alarak ve bunları (yani görüntü çözünürlüğü) olarak yeniden şekillendirerek , her bir sınıfın hesaplanması için hangi piksellerin en önemli olduğunu söyleyebiliriz .78478428×28

Yine bunların ağırlık olduğuna dikkat edin .

Şimdi yukarıdaki resme bakın ve ilk iki haneye odaklanın (yani sıfır ve bir). Mavi ağırlıklar, bu pikselin yoğunluğunun o sınıfa çok katkıda bulunduğunu ve kırmızı değerlerin olumsuz katkıda bulunduğunu gösterir.

Şimdi düşünün, bir kişi nasıl ? Aralarında boş olan dairesel bir şekil çizer. Bu tam olarak ağırlığın aldığı şeydi. Aslında eğer biri görüntünün ortasını çizerse, sıfır olarak negatif sayılır . Bu nedenle sıfırları tanımak için bazı karmaşık filtrelere ve üst düzey özelliklere ihtiyacınız yoktur. Çizilen piksel konumlarına bakabilir ve buna göre karar verebilirsiniz.0

için aynı şey . Görüntünün ortasında her zaman düz bir dikey çizgi vardır. Diğer her şey olumsuz sayılır.1

Rakamların geri kalanı biraz daha karmaşık, ancak küçük hayallerde , , ve . Sayıların geri kalanı biraz daha zor, bu da lojistik regresyonun 90'lara ulaşmasını engelleyen şey.2378

Bu sayede lojistik regresyonun birçok görüntüyü doğru bir şekilde elde etme şansının yüksek olduğunu ve bu yüzden çok yüksek puan aldığını görebilirsiniz.


Yukarıdaki rakamın çoğaltılması için kod biraz tarihli, ama işte gidiyorsunuz:

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST:
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Create model
x = tf.placeholder(tf.float32, shape=(None, 784))
y = tf.placeholder(tf.float32, shape=(None, 10))

W = tf.Variable(tf.zeros((784,10)))
b = tf.Variable(tf.zeros((10)))
z = tf.matmul(x, W) + b

y_hat = tf.nn.softmax(z)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_hat), reduction_indices=[1]))
optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) # 

correct_pred = tf.equal(tf.argmax(y_hat, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Train model
batch_size = 64
with tf.Session() as sess:

    loss_tr, acc_tr, loss_ts, acc_ts = [], [], [], []

    sess.run(tf.global_variables_initializer()) 

    for step in range(1, 1001):

        x_batch, y_batch = mnist.train.next_batch(batch_size) 
        sess.run(optimizer, feed_dict={x: x_batch, y: y_batch})

        l_tr, a_tr = sess.run([cross_entropy, accuracy], feed_dict={x: x_batch, y: y_batch})
        l_ts, a_ts = sess.run([cross_entropy, accuracy], feed_dict={x: mnist.test.images, y: mnist.test.labels})
        loss_tr.append(l_tr)
        acc_tr.append(a_tr)
        loss_ts.append(l_ts)
        acc_ts.append(a_ts)

    weights = sess.run(W)      
    print('Test Accuracy =', sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})) 

# Plotting:
for i in range(10):
    plt.subplot(2, 5, i+1)
    weight = weights[:,i].reshape([28,28])
    plt.title(i)
    plt.imshow(weight, cmap='RdBu')  # as noted by @Eric Duminil, cmap='gray' makes the numbers stand out more
    frame1 = plt.gca()
    frame1.axes.get_xaxis().set_visible(False)
    frame1.axes.get_yaxis().set_visible(False)

9
Gösterim için teşekkürler. Bu ağırlık görüntüleri, doğruluğun bu kadar yüksek olmasından daha net olmasını sağlar. El yazısıyla yazılan bir rakam görüntüsünün, resmin gerçek etiketine karşılık gelen ağırlık görüntüsüyle nokta çarpımı, çoğu için diğer ağırlık etiketleriyle nokta ürününe kıyasla en yüksek gibi görünüyor (% 92'si bana çok benziyor) MNIST görüntülerin. Yine de, ve veya ve karışıklık matrisini inceledikten sonra birbirleriyle nadiren yanlış sınıflandırılmaları biraz şaşırtıcı . Neyse, olan bu. Veri asla yalan söylemez. :)2378
Nitish Agarwal

13
Elbette, MNIST örneklerinin, sınıflandırıcı onları görmeden önce ortalanmasına, ölçeklendirilmesine ve kontrast normalleştirilmesine yardımcı olur. "Sıfırın kenarı aslında kutunun ortasından geçerse ne olur?" Gibi soruları yanıtlamanız gerekmez. çünkü ön işlemci zaten sıfırın tamamını aynı görünmesi yönünde uzun bir yol kat etti.
Ocaklar

1
@EricDuminil Önerinizle birlikte bir komut dosyası ekledim. Giriş için çok teşekkürler! : D
Djib2011

1
@NitishAgarwal, Bu cevabın Sorunuzun Cevabı olduğunu düşünüyorsanız, bunu böyle işaretleyin.
sintax,

7
Bu tür işlemlerle ilgilenen ancak özellikle aşina olmayan biri için bu cevap, mekaniğin harika bir sezgisel örneğidir.
chrylis
Sitemizi kullandığınızda şunları okuyup anladığınızı kabul etmiş olursunuz: Çerez Politikası ve Gizlilik Politikası.
Licensed under cc by-sa 3.0 with attribution required.