Burada ortalama ve standart sapmayı tahmin etmek için kullanılan Beklenti Maksimizasyonu (EM) örneği. Kod Python'dadır, ancak dili bilmeseniz bile takip etmesi kolay olmalıdır.
EM için motivasyon
Aşağıda gösterilen kırmızı ve mavi noktalar, her biri belirli bir ortalama ve standart sapmaya sahip iki farklı normal dağılımdan alınmıştır:
Kırmızı dağılım için "gerçek" ortalama ve standart sapma parametrelerinin makul tahminlerini hesaplamak için, kırmızı noktalara çok kolay bir şekilde bakabilir ve her birinin konumunu kaydedebilir ve sonra bilinen formülleri kullanabiliriz (ve benzer şekilde mavi grup için). .
Şimdi, iki puan grubunun olduğunu bildiğimiz durumu düşünün, ancak hangi noktanın hangi gruba ait olduğunu göremeyiz. Başka bir deyişle, renkler gizlenir:
Noktaları iki gruba nasıl ayıracağınız hiç belli değil. Artık kırmızı dağılımın veya mavi dağılımın parametrelerine ilişkin konumlara bakıp tahminleri hesaplayamıyoruz.
EM'nin sorunu çözmek için kullanılabileceği yer burasıdır.
Parametreleri tahmin etmek için EM kullanma
Yukarıda gösterilen noktaları oluşturmak için kullanılan kod. Noktaların çizildiği normal dağılımların gerçek araçlarını ve standart sapmalarını görebilirsiniz. Değişkenler red
ve blue
her bir noktanın sırasıyla kırmızı ve mavi gruplardaki pozisyonlarını tutun:
import numpy as np
from scipy import stats
np.random.seed(110) # for reproducible random results
# set parameters
red_mean = 3
red_std = 0.8
blue_mean = 7
blue_std = 2
# draw 20 samples from normal distributions with red/blue parameters
red = np.random.normal(red_mean, red_std, size=20)
blue = np.random.normal(blue_mean, blue_std, size=20)
both_colours = np.sort(np.concatenate((red, blue)))
Biz ise olabilir her noktanın renk görmek, denemek ve kütüphane işlevleri kullanarak ortalamaları ve standart sapmaları kurtarmak olacaktır:
>>> np.mean(red)
2.802
>>> np.std(red)
0.871
>>> np.mean(blue)
6.932
>>> np.std(blue)
2.195
Fakat renkler bizden gizlendiğinden, EM işlemine başlayacağız ...
İlk önce, sadece her grubun parametreleri için değerleri tahmin ediyoruz ( adım 1 ). Bu tahminlerin iyi olması gerekmiyor:
# estimates for the mean
red_mean_guess = 1.1
blue_mean_guess = 9
# estimates for the standard deviation
red_std_guess = 2
blue_std_guess = 1.7
Oldukça kötü tahminler - araçlar, bir grup grubun herhangi bir "ortasından" uzun bir yol gibi görünüyor.
EM ile devam etmek ve bu tahminleri geliştirmek için, ortalama ve standart sapma için bu tahminler altında görünen her veri noktasının (gizli renginden bağımsız olarak) olasılığını hesaplıyoruz ( adım 2 ).
Değişken both_colours
her veri noktasını tutar. İşlev stats.norm
, verilen parametrelerle normal bir dağılım altındaki noktanın olasılığını hesaplar:
likelihood_of_red = stats.norm(red_mean_guess, red_std_guess).pdf(both_colours)
likelihood_of_blue = stats.norm(blue_mean_guess, blue_std_guess).pdf(both_colours)
Bu bize, örneğin mevcut tahminlerimizde 1.761'deki veri noktasının kırmızı (0.189) 'dan mavi (0.00003)' ten çok daha muhtemel olduğunu söylemektedir.
Bu iki olasılık değerini ağırlıklara dönüştürebiliriz ( 3. adım ), böylece 1'i toplarlar:
likelihood_total = likelihood_of_red + likelihood_of_blue
red_weight = likelihood_of_red / likelihood_total
blue_weight = likelihood_of_blue / likelihood_total
Mevcut tahminlerimiz ve yeni hesaplanan ağırlıklarımızla, şimdi parametreler için yeni, muhtemelen daha iyi tahminler hesaplayabiliriz ( adım 4 ). Ortalama için bir fonksiyona ve standart sapma için bir fonksiyona ihtiyacımız var:
def estimate_mean(data, weight):
return np.sum(data * weight) / np.sum(weight)
def estimate_std(data, weight, mean):
variance = np.sum(weight * (data - mean)**2) / np.sum(weight)
return np.sqrt(variance)
Bunlar, normal fonksiyonlara verinin ortalama ve standart sapmasına çok benziyor. Fark, weight
her veri noktasına ağırlık atayan bir parametrenin kullanılmasıdır .
Bu ağırlıklandırma EM'nin anahtarıdır. Bir veri noktasındaki bir rengin ağırlığı arttıkça, veri noktası o rengin parametreleri için sonraki tahminleri o kadar fazla etkiler. Sonuç olarak, bu her parametreyi doğru yönde çekme etkisine sahiptir.
Yeni tahminler bu fonksiyonlarla hesaplanır:
# new estimates for standard deviation
blue_std_guess = estimate_std(both_colours, blue_weight, blue_mean_guess)
red_std_guess = estimate_std(both_colours, red_weight, red_mean_guess)
# new estimates for mean
red_mean_guess = estimate_mean(both_colours, red_weight)
blue_mean_guess = estimate_mean(both_colours, blue_weight)
EM süreci daha sonra bu yeni tahminlerle 2. adımdan itibaren tekrarlanır. Belirli bir yineleme sayısı için adımları tekrarlayabiliriz (örneğin, 20) ya da yakınsak parametreleri görene kadar.
Beş tekrardan sonra, ilk kötü tahminlerimizin iyileşmeye başladığını görüyoruz:
20 yinelemeden sonra, EM süreci az çok yakınlaşmıştır:
Karşılaştırma için, renk bilgilerinin gizlenmediği hesaplanan değerlerle karşılaştırıldığında EM sürecinin sonuçları:
| EM guess | Actual
----------+----------+--------
Red mean | 2.910 | 2.802
Red std | 0.854 | 0.871
Blue mean | 6.838 | 6.932
Blue std | 2.227 | 2.195
Not: Bu yanıt yığın taşması benim cevap uyarlanmıştır burada .