Beklenti Maksimizasyonu tekniğinin sezgisel açıklaması nedir? [kapalı]


109

Beklenti Maksimizasyonu (EM), verileri sınıflandırmak için bir tür olasılıklı yöntemdir. Bir sınıflandırıcı değilse, yanılıyorsam lütfen düzeltin.

Bu EM tekniğinin sezgisel açıklaması nedir? Ne expectationburada ve varlık nedir maximized?


12
Beklenti maksimizasyon algoritması nedir? , Nature Biotechnology 26 , 897–899 (2008), algoritmanın nasıl çalıştığını gösteren güzel bir resme sahiptir.
chl

@chl In kısım b ait güzel resim , nasıl Z (vs. yani 0.45xA, 0.55xB) üzerindeki olasılık dağılımının değerlerini elde ettiniz?
Noob Saibot


3
@Chl'nin bahsettiği resmin bağlantısı güncellendi .
n1k31t4

Yanıtlar:


120

Not: Bu cevabın arkasındaki kod burada bulunabilir .


Kırmızı ve mavi olmak üzere iki farklı gruptan örneklenmiş bazı verilerimiz olduğunu varsayalım:

görüntü açıklamasını buraya girin

Burada hangi veri noktasının kırmızı veya mavi gruba ait olduğunu görebiliriz. Bu, her bir grubu karakterize eden parametreleri bulmayı kolaylaştırır. Örneğin, kırmızı grubun ortalaması 3 civarındadır, mavi grubun ortalaması 7 civarındadır (ve istersek tam anlamını bulabiliriz).

Bu, genel olarak maksimum olasılık tahmini olarak bilinir . Bazı veriler göz önüne alındığında, bu verileri en iyi açıklayan bir parametrenin (veya parametrelerin) değerini hesaplarız.

Şimdi hayal edemez hangi gruptan örneklenmiş hangi değer görüyoruz. Bize her şey mor görünüyor:

görüntü açıklamasını buraya girin

Burada iki değer grubu olduğu bilgisine sahibiz , ancak belirli bir değerin hangi gruba ait olduğunu bilmiyoruz.

Bu verilere en iyi uyan kırmızı grup ve mavi grup için ortalamaları hala tahmin edebilir miyiz?

Evet, sık sık yapabiliriz! Beklenti Maksimizasyonu , bunu yapmamız için bize bir yol sunar. Algoritmanın arkasındaki çok genel fikir şudur:

  1. Her parametrenin ne olabileceğine dair bir ilk tahminle başlayın.
  2. Her parametrenin veri noktasını oluşturma olasılığını hesaplayın .
  3. Her veri noktası için, bir parametre tarafından üretilme olasılığına bağlı olarak daha fazla kırmızı mı yoksa daha fazla mavi mi olduğunu gösteren ağırlıkları hesaplayın. Ağırlıkları verilerle birleştirin ( beklenti ).
  4. Ağırlık ayarlı verileri ( maksimizasyon ) kullanarak parametreler için daha iyi bir tahmin hesaplayın .
  5. Parametre tahmini yakınsayıncaya kadar 2'den 4'e kadar olan adımları tekrarlayın (işlem farklı bir tahmin üretmeyi durdurur).

Bu adımların daha fazla açıklamaya ihtiyacı vardır, bu nedenle yukarıda açıklanan sorunu gözden geçireceğim.

Örnek: ortalama ve standart sapmayı tahmin etme

Bu örnekte Python kullanacağım, ancak bu dile aşina değilseniz kodun anlaşılması oldukça kolay olacaktır.

Yukarıdaki görüntüdeki gibi dağıtılmış değerlere sahip kırmızı ve mavi olmak üzere iki grubumuz olduğunu varsayalım. Spesifik olarak, her grup aşağıdaki parametrelerle normal bir dağılımdan alınan bir değer içerir :

import numpy as np
from scipy import stats

np.random.seed(110) # for reproducible results

# set parameters
red_mean = 3
red_std = 0.8

blue_mean = 7
blue_std = 2

# draw 20 samples from normal distributions with red/blue parameters
red = np.random.normal(red_mean, red_std, size=20)
blue = np.random.normal(blue_mean, blue_std, size=20)

both_colours = np.sort(np.concatenate((red, blue))) # for later use...

İşte yine bu kırmızı ve mavi grupların bir görüntüsü (sizi yukarı kaydırmaktan kurtarmak için):

görüntü açıklamasını buraya girin

Her noktanın rengini (yani hangi gruba ait olduğunu) görebildiğimizde, her grup için ortalama ve standart sapmayı tahmin etmek çok kolaydır. Kırmızı ve mavi değerleri NumPy'deki yerleşik işlevlere aktarıyoruz. Örneğin:

>>> np.mean(red)
2.802
>>> np.std(red)
0.871
>>> np.mean(blue)
6.932
>>> np.std(blue)
2.195

Ama ya eğer edemez noktalarının renkleri görmek? Yani kırmızı veya mavi yerine her nokta mor renkle boyanmıştır.

Kırmızı ve mavi gruplar için ortalama ve standart sapma parametrelerini denemek ve kurtarmak için Beklenti Maksimizasyonunu kullanabiliriz.

İlk adımımız ( yukarıdaki 1. adım ), her grubun ortalama ve standart sapması için parametre değerlerini tahmin etmektir. Akıllıca tahmin etmemize gerek yok; İstediğimiz sayıları seçebiliriz:

# estimates for the mean
red_mean_guess = 1.1
blue_mean_guess = 9

# estimates for the standard deviation
red_std_guess = 2
blue_std_guess = 1.7

Bu parametre tahminleri, şuna benzeyen çan eğrileri üretir:

görüntü açıklamasını buraya girin

Bunlar kötü tahminlerdir. Her iki araç da (dikey noktalı çizgiler), örneğin, mantıklı nokta grupları için herhangi bir tür "orta" dan çok uzak görünür. Bu tahminleri iyileştirmek istiyoruz.

Bir sonraki adım ( adım 2 ), her bir veri noktasının mevcut parametre tahminlerinin altında görünme olasılığını hesaplamaktır:

likelihood_of_red = stats.norm(red_mean_guess, red_std_guess).pdf(both_colours)
likelihood_of_blue = stats.norm(blue_mean_guess, blue_std_guess).pdf(both_colours)

Burada, kırmızı ve mavi için ortalama ve standart sapmadaki mevcut tahminlerimizi kullanarak her veri noktasını normal dağılım için olasılık yoğunluğu fonksiyonuna koyduk . Bu bize, örneğin, mevcut tahminlerimizle 1.761'deki veri noktasının mavi (0.00003) yerine kırmızı (0.189) olma olasılığının çok daha yüksek olduğunu söylüyor.

Her veri noktası için, bu iki olabilirlik değerini ağırlıklara dönüştürebiliriz ( 3. adım ), böylece toplamları aşağıdaki gibi 1 olur:

likelihood_total = likelihood_of_red + likelihood_of_blue

red_weight = likelihood_of_red / likelihood_total
blue_weight = likelihood_of_blue / likelihood_total

Mevcut tahminlerimiz ve yeni hesaplanmış ağırlıklarımızla, artık kırmızı ve mavi grupların ortalama ve standart sapması için yeni tahminler hesaplayabiliriz ( 4. adım ).

Tüm veri noktalarını kullanarak ortalama ve standart sapmayı iki kez hesaplıyoruz , ancak farklı ağırlıklarla: bir kez kırmızı ağırlıklar ve bir kez mavi ağırlıklar için.

Sezginin temel noktası, bir veri noktasındaki bir rengin ağırlığı ne kadar büyükse, veri noktasının o rengin parametreleri için sonraki tahminleri o kadar fazla etkilemesidir. Bu, parametreleri doğru yönde "çekme" etkisine sahiptir.

def estimate_mean(data, weight):
    """
    For each data point, multiply the point by the probability it
    was drawn from the colour's distribution (its "weight").

    Divide by the total weight: essentially, we're finding where 
    the weight is centred among our data points.
    """
    return np.sum(data * weight) / np.sum(weight)

def estimate_std(data, weight, mean):
    """
    For each data point, multiply the point's squared difference
    from a mean value by the probability it was drawn from
    that distribution (its "weight").

    Divide by the total weight: essentially, we're finding where 
    the weight is centred among the values for the difference of
    each data point from the mean.

    This is the estimate of the variance, take the positive square
    root to find the standard deviation.
    """
    variance = np.sum(weight * (data - mean)**2) / np.sum(weight)
    return np.sqrt(variance)

# new estimates for standard deviation
blue_std_guess = estimate_std(both_colours, blue_weight, blue_mean_guess)
red_std_guess = estimate_std(both_colours, red_weight, red_mean_guess)

# new estimates for mean
red_mean_guess = estimate_mean(both_colours, red_weight)
blue_mean_guess = estimate_mean(both_colours, blue_weight)

Parametreler için yeni tahminlerimiz var. Bunları tekrar iyileştirmek için 2. adıma geri dönüp işlemi tekrar edebiliriz. Bunu tahminler yakınsayıncaya kadar veya bazı yineleme yapıldıktan sonra yapıyoruz ( adım 5 ).

Verilerimiz için, bu sürecin ilk beş yinelemesi şuna benzer (son yinelemeler daha güçlü görünüme sahiptir):

görüntü açıklamasını buraya girin

Araçların zaten bazı değerlerde birleştiğini ve eğrilerin şekillerinin (standart sapma tarafından yönetilen) daha kararlı hale geldiğini görüyoruz.

20 yineleme boyunca devam edersek, aşağıdakilerle sonuçlanırız:

görüntü açıklamasını buraya girin

EM süreci, gerçek değerlere çok yakın olduğu ortaya çıkan aşağıdaki değerlere yakınsamıştır (renkleri görebiliyoruz - gizli değişkenler yok):

          | EM guess | Actual |  Delta
----------+----------+--------+-------
Red mean  |    2.910 |  2.802 |  0.108
Red std   |    0.854 |  0.871 | -0.017
Blue mean |    6.838 |  6.932 | -0.094
Blue std  |    2.227 |  2.195 |  0.032

Yukarıdaki kodda, standart sapma için yeni tahminin, önceki yinelemenin ortalama tahmini kullanılarak hesaplandığını fark etmiş olabilirsiniz. Nihayetinde, ilk önce ortalama için yeni bir değer hesaplayıp hesaplamamız önemli değildir, çünkü sadece bazı merkezi noktalar etrafında değerlerin (ağırlıklı) varyansını buluyoruz. Yine de parametreler için tahminlerin yakınsadığını göreceğiz.


Ya bunun geldiği normal dağılımların sayısını bile bilmiyorsak? Burada k = 2 dağılımlarına bir örnek aldınız, ayrıca k ve k parametre setlerini de tahmin edebilir miyiz?
stackit

1
@stackit: Bu durumda EM sürecinin bir parçası olarak en olası k değerini hesaplamanın basit bir genel yolu olduğundan emin değilim. Ana sorun, EM'yi bulmak istediğimiz her parametre için tahminlerle başlatmamız gerektiğidir ve bu, başlamadan önce k'yi bilmemiz / tahmin etmemiz gerektiğini gerektirir. Ancak burada bir gruba ait noktaların oranını EM vasıtasıyla tahmin etmek mümkündür. Belki k'yi fazla tahmin edersek, iki grup dışındaki tüm grupların oranı sıfıra yakın düşer. Bunu denemedim, bu yüzden pratikte ne kadar işe yarayacağını bilmiyorum.
Alex Riley

1
@AlexRiley Yeni ortalama ve standart sapma tahminlerini hesaplamak için formüllerden biraz daha bahsedebilir misiniz?
Lemon

2
@AlexRiley Açıklama için teşekkürler. Yeni standart sapma tahminleri neden ortalamanın eski tahmini kullanılarak hesaplanıyor? Ya ortalamanın yeni tahminleri önce bulunursa?
GoodDeeds

1
@Lemon GoodDeeds Kaushal - sorularınıza geç yanıt verdiğim için özür dilerim. Ortaya koyduğunuz noktaları ele almak için cevabı düzenlemeye çalıştım. Ayrıca, bu cevapta kullanılan tüm kodu burada bir not defterinde erişilebilir hale getirdim (ayrıca değindiğim bazı noktaların daha ayrıntılı açıklamalarını da içeriyor).
Alex Riley

36

EM, modelinizdeki bazı değişkenler gözlemlenmediğinde (ör. Gizli değişkenleriniz olduğunda) olasılık fonksiyonunu maksimize etmek için bir algoritmadır.

Bir işlevi maksimize etmeye çalışıyorsak, bir işlevi maksimize etmek için neden sadece mevcut makineyi kullanmıyoruz diye sorabilirsiniz. Bunu türevleri alıp sıfıra koyarak maksimize etmeye çalışırsanız, çoğu durumda birinci dereceden koşulların bir çözümü olmadığını görürsünüz. Model parametrelerinizi çözmek için gözlenmeyen verilerinizin dağılımını bilmeniz gereken bir tavuk ve yumurta sorunu vardır; ancak gözlemlenmemiş verilerinizin dağılımı, model parametrelerinizin bir fonksiyonudur.

EM, gözlemlenmemiş veriler için bir dağılımı yinelemeli olarak tahmin ederek, ardından gerçek olasılık fonksiyonunda daha düşük bir sınır olan bir şeyi maksimize ederek ve yakınsamaya kadar tekrarlayarak model parametrelerini tahmin ederek bu sorunu aşmaya çalışır:

EM algoritması

Model parametrelerinizin değerleri için tahminle başlayın

E-adımı: Eksik değerleri olan her veri noktası için, model parametreleriyle ilgili mevcut tahmininiz ve gözlemlenen verilere göre eksik verilerin dağılımını çözmek için model denkleminizi kullanın (her bir eksik için bir dağıtım için çözdüğünüzü unutmayın. değer, beklenen değer için değil). Artık her bir eksik değer için bir dağılımımız olduğuna göre, gözlenmeyen değişkenlere göre olasılık fonksiyonunun beklentisini hesaplayabiliriz . Model parametresi için tahminimiz doğruysa, bu beklenen olasılık, gözlemlenen verilerimizin gerçek olasılığı olacaktır; parametreler doğru değilse, bu sadece bir alt sınır olacaktır.

M adımı: Artık, içinde gözlenmeyen değişkenler olmayan beklenen bir olasılık fonksiyonuna sahip olduğumuza göre, model parametrelerinizin yeni bir tahminini elde etmek için, tamamen gözlemlenen durumda yapacağınız gibi fonksiyonu maksimize edin.

Yakınsamaya kadar tekrarlayın.


5
E-adımınızı anlamıyorum. Sorunun bir kısmı, bu şeyleri öğrenirken aynı terminolojiyi kullanan insanları bulamıyorum. Öyleyse model denklem derken neyi kastediyorsunuz? Olasılık dağılımını çözerek ne demek istediğini bilmiyorum?
user678392

27

Beklenti Maksimizasyonu algoritmasını anlamak için basit bir tarif:

1- Do ve Batzoglou'nun hazırladığı EM eğitim belgesini okuyun .

2- Kafanızda soru işaretleri olabilir, bu matematik yığını değişim sayfasındaki açıklamalara bir göz atın .

3- Python'da yazdığım, 1. maddenin EM eğitim kağıdındaki örneği açıklayan bu koda bakın:

Uyarı: Python geliştiricisi olmadığım için kod dağınık / yetersiz olabilir. Ama işi yapıyor.

import numpy as np
import math

#### E-M Coin Toss Example as given in the EM tutorial paper by Do and Batzoglou* #### 

def get_mn_log_likelihood(obs,probs):
    """ Return the (log)likelihood of obs, given the probs"""
    # Multinomial Distribution Log PMF
    # ln (pdf)      =             multinomial coeff            *   product of probabilities
    # ln[f(x|n, p)] = [ln(n!) - (ln(x1!)+ln(x2!)+...+ln(xk!))] + [x1*ln(p1)+x2*ln(p2)+...+xk*ln(pk)]     

    multinomial_coeff_denom= 0
    prod_probs = 0
    for x in range(0,len(obs)): # loop through state counts in each observation
        multinomial_coeff_denom = multinomial_coeff_denom + math.log(math.factorial(obs[x]))
        prod_probs = prod_probs + obs[x]*math.log(probs[x])

    multinomial_coeff = math.log(math.factorial(sum(obs))) -  multinomial_coeff_denom
    likelihood = multinomial_coeff + prod_probs
    return likelihood

# 1st:  Coin B, {HTTTHHTHTH}, 5H,5T
# 2nd:  Coin A, {HHHHTHHHHH}, 9H,1T
# 3rd:  Coin A, {HTHHHHHTHH}, 8H,2T
# 4th:  Coin B, {HTHTTTHHTT}, 4H,6T
# 5th:  Coin A, {THHHTHHHTH}, 7H,3T
# so, from MLE: pA(heads) = 0.80 and pB(heads)=0.45

# represent the experiments
head_counts = np.array([5,9,8,4,7])
tail_counts = 10-head_counts
experiments = zip(head_counts,tail_counts)

# initialise the pA(heads) and pB(heads)
pA_heads = np.zeros(100); pA_heads[0] = 0.60
pB_heads = np.zeros(100); pB_heads[0] = 0.50

# E-M begins!
delta = 0.001  
j = 0 # iteration counter
improvement = float('inf')
while (improvement>delta):
    expectation_A = np.zeros((5,2), dtype=float) 
    expectation_B = np.zeros((5,2), dtype=float)
    for i in range(0,len(experiments)):
        e = experiments[i] # i'th experiment
        ll_A = get_mn_log_likelihood(e,np.array([pA_heads[j],1-pA_heads[j]])) # loglikelihood of e given coin A
        ll_B = get_mn_log_likelihood(e,np.array([pB_heads[j],1-pB_heads[j]])) # loglikelihood of e given coin B

        weightA = math.exp(ll_A) / ( math.exp(ll_A) + math.exp(ll_B) ) # corresponding weight of A proportional to likelihood of A 
        weightB = math.exp(ll_B) / ( math.exp(ll_A) + math.exp(ll_B) ) # corresponding weight of B proportional to likelihood of B                            

        expectation_A[i] = np.dot(weightA, e) 
        expectation_B[i] = np.dot(weightB, e)

    pA_heads[j+1] = sum(expectation_A)[0] / sum(sum(expectation_A)); 
    pB_heads[j+1] = sum(expectation_B)[0] / sum(sum(expectation_B)); 

    improvement = max( abs(np.array([pA_heads[j+1],pB_heads[j+1]]) - np.array([pA_heads[j],pB_heads[j]]) ))
    j = j+1

Programınızın hem A hem de B ile 0.66 sonuçlanacağını buldum, ayrıca scala kullanarak da uyguluyorum, sonucun 0.66 olduğunu da buldum, kontrol etmeye yardımcı olabilir misiniz?
zjffdu

Bir elektronik tablo kullanarak, 0.66 sonuçlarınızı yalnızca ilk tahminlerim eşitse bulurum. Aksi takdirde, eğiticinin çıktısını yeniden oluşturabilirim.
soakley

@zjffdu, EM size 0.66 döndürmeden önce kaç tane yineleme yapıyor? Eşit değerlerle başlatırsanız, yerel bir maksimumda takılıp kalıyor olabilir ve yineleme sayısının son derece düşük olduğunu göreceksiniz (çünkü iyileşme yoktur).
Zhubarb


16

Teknik olarak "EM" terimi biraz daha az belirtilmiştir, ancak genel EM ilkesinin bir örneği olan Gauss Karışımı Modelleme küme analizi tekniğine atıfta bulunduğunuzu varsayıyorum .

Aslında EM küme analizi bir sınıflandırıcı değildir . Bazı insanların kümelemeyi "denetimsiz sınıflandırma" olarak gördüğünü biliyorum, ama aslında kümeleme analizi oldukça farklı bir şey.

Temel fark ve insanların küme analiziyle her zaman sahip oldukları büyük yanlış anlama sınıflandırması şudur: küme analizinde "doğru çözüm" yoktur . Bu bir bilgi keşif yöntemidir, aslında yeni bir şey bulmak içindir ! Bu, değerlendirmeyi çok zor hale getirir. Genellikle referans olarak bilinen bir sınıflandırma kullanılarak değerlendirilir, ancak bu her zaman uygun değildir: sahip olduğunuz sınıflandırma verilerde ne olduğunu yansıtabilir veya yansıtmayabilir.

Size bir örnek vereyim: Cinsiyet verileri de dahil olmak üzere geniş bir müşteri veri kümeniz var. Bu veri kümesini "erkek" ve "dişi" olarak bölen bir yöntem, onu mevcut sınıflarla karşılaştırdığınızda idealdir. Yeni kullanıcılar için artık cinsiyetlerini tahmin edebildiğinizden, bu iyi bir "tahmin" tarzıdır. "Bilgi keşfi" şeklinde düşündüğünüzde, bu aslında kötü, çünkü verilerde yeni bir yapı keşfetmek istediniz . Örneğin, verileri yaşlı insanlara ve çocuklara ayıran bir yöntem, ancak erkek / kadın sınıfına göre alabileceği kadar kötü puan alacaktır . Bununla birlikte, bu mükemmel bir kümeleme sonucu olur (yaş belirtilmemişse).

Şimdi EM'ye geri dönün. Esasen, verilerinizin çok değişkenli normal dağılımlardan oluştuğunu varsayar (bunun özellikle küme sayısını sabitlediğinizde çok güçlü bir varsayım olduğuna dikkat edin !). Daha sonra modeli ve modele nesne atamasını dönüşümlü olarak geliştirerek bunun için yerel bir optimal model bulmaya çalışır .

Bir sınıflandırma bağlamında en iyi sonuçlar için, sınıf sayısından daha büyük küme sayısını seçin veya hatta kümelemeyi yalnızca tek sınıflara uygulayın (sınıfta bir yapı olup olmadığını öğrenmek için!).

"Arabaları", "bisikletleri" ve "kamyonları" birbirinden ayırmak için bir sınıflandırıcı eğitmek istediğinizi varsayalım. Verilerin tam olarak 3 normal dağılımdan oluştuğunu varsaymanın çok az faydası vardır. Bununla birlikte, birden fazla araba türü (ve kamyon ve bisiklet) olduğunu varsayabilirsiniz . Dolayısıyla, bu üç sınıf için bir sınıflandırıcı eğitmek yerine, arabaları, kamyonları ve bisikletleri her biri 10 kümeye (veya belki 10 araba, 3 kamyon ve 3 bisiklet, her neyse) ayırırsınız, ardından bu 30 sınıfı ayırmak için bir sınıflandırıcı eğitirsiniz ve sonra sınıf sonucunu orijinal sınıflara geri birleştirin. Ayrıca, Trikes gibi sınıflandırılması özellikle zor olan bir küme olduğunu da keşfedebilirsiniz. Onlar biraz araba ve biraz da bisiklet. Veya kamyonlardan çok büyük arabalara benzeyen teslimat kamyonları.


EM nasıl yetersiz tanımlanır?
sam boosalis

Birden fazla versiyonu var. Teknik olarak, Lloyd tarzı k anlamına gelen "EM" de diyebilirsiniz. Hangi modeli kullandığınızı belirtmeniz gerekiyor .
ÇIKTI - Anony-Mousse

2

Diğer cevaplar iyidir, başka bir bakış açısı sağlamaya ve sorunun sezgisel kısmını ele almaya çalışacağım.

EM (Beklenti-Maksimizasyon) algoritması , dualite kullanan bir yinelemeli algoritmalar sınıfının bir varyantıdır.

Alıntı (vurgu benim):

Matematikte, genel olarak konuşursak, bir dualite, kavramları, teoremleri veya matematiksel yapıları diğer kavramlara, teoremlere veya yapılara bire bir biçimde, genellikle (ancak her zaman değil) bir evrişim işlemi aracılığıyla çevirir: eğer A, B'dir, sonra B'nin ikilisi, A'dır. bazen sabit noktaları vardır , böylece A'nın duali A'nın kendisidir.

Genellikle , çift bir B nesnesi A bir muhafaza bir şekilde A ile ilgilidir simetri veya uyumluluk . Örneğin AB = const

Dualite kullanan (önceki anlamda) yinelemeli algoritmalara örnekler şunlardır:

  1. Greatest Common Divisor için Öklid algoritması ve türevleri
  2. Gram – Schmidt Vektör Temeli algoritması ve çeşitleri
  3. Aritmetik Ortalama - Geometrik Ortalama Eşitsizlik ve çeşitleri
  4. Beklenti-Maksimizasyon algoritması ve türevleri (ayrıca geometrik bir bilgi görünümü için buraya bakın )
  5. (.. diğer benzer algoritmalar ..)

Benzer bir şekilde, EM algoritması iki ikili maksimizasyon adımı olarak da görülebilir :

.. [EM], parametrelerin ve gözlenmeyen değişkenler üzerindeki dağılımın birleşik bir fonksiyonunu maksimize ettiği görülmektedir. E-adımı, bu fonksiyonu, gözlemlenmemiş değişkenler üzerindeki dağılıma göre maksimize eder; parametrelere göre M-adımı ..

Dualite kullanan yinelemeli bir algoritmada, bir denge (veya sabit) yakınsama noktasının açık (veya örtük) varsayımı vardır (EM için bu Jensen'in eşitsizliği kullanılarak kanıtlanmıştır)

Dolayısıyla bu tür algoritmaların ana hatları şu şekildedir:

  1. E benzeri adım: Verilen y'ye göre en iyi çözümü x bulun sabit tutulmasına .
  2. M benzeri adım (ikili): Sabit tutulan (önceki adımda hesaplandığı gibi) x'e göre en iyi y çözümünü bulun .
  3. Sonlandırma Kriteri / Yakınsama adımı: Yakınsama (veya belirtilen yineleme sayısına ulaşılana) kadar 1 ve 2 numaralı adımları güncellenmiş x , y değerleri ile tekrarlayın.

Not a (global) optimuma Böyle bir algoritma yakınsak, bunun bir yapılandırma bulduğunda her iki anlamda da iyi (her ikisi de, yani X alanı / parametreler ve y alan / parametreler). Ancak algoritma yerel bir optimum bulabilir , global optimum .

Bunun, algoritmanın ana hatlarının sezgisel açıklaması olduğunu söyleyebilirim

İstatistiksel argümanlar ve uygulamalar için, diğer cevaplar iyi açıklamalar vermiştir (bu cevaptaki referansları da kontrol edin)


2

Kabul edilen cevap, EM'yi açıklayan iyi bir iş çıkaran Chuong EM Paper'a atıfta bulunuyor . Ayrıca kağıdı daha detaylı anlatan bir youtube videosu da var .

Özetlemek gerekirse, senaryo şu şekildedir:

1st:  {H,T,T,T,H,H,T,H,T,H} 5 Heads, 5 Tails; Did coin A or B generate me?
2nd:  {H,H,H,H,T,H,H,H,H,H} 9 Heads, 1 Tails
3rd:  {H,T,H,H,H,H,H,T,H,H} 8 Heads, 2 Tails
4th:  {H,T,H,T,T,T,H,H,T,T} 4 Heads, 6 Tails
5th:  {T,H,H,H,T,H,H,H,T,H} 7 Heads, 3 Tails

Two possible coins, A & B are used to generate these distributions.
A & B have an unknown parameter: their bias towards heads.

We don't know the biases, but we can simply start with a guess: A=60% heads, B=50% heads.

İlk denemenin sorusu durumunda, sezgisel olarak B'nin bunu ürettiğini düşünürüz çünkü kafa oranı B'nin önyargısına çok iyi uyuyor ... ama bu değer sadece bir tahmindi, bu yüzden emin olamayız.

Bunu aklımda tutarak, EM çözümünü şu şekilde düşünmeyi seviyorum:

  • Her çevirme denemesi, en çok hangi parayı sevdiğini 'oylamaya' alır
    • Bu, her bir madalyonun dağıtımına ne kadar iyi uyduğuna bağlıdır
    • VEYA, madeni para açısından bakıldığında, bu denemeyi diğer madeni paraya göre görmenin yüksek beklentisi var ( günlük olasılıklarına göre ).
  • Her denemenin her bir madeni parayı ne kadar beğendiğine bağlı olarak, o madeni paranın parametresinin (önyargı) tahminini güncelleyebilir.
    • Deneme bir madeni parayı ne kadar çok seviyorsa, madeni paranın eğilimini kendi eğilimini yansıtacak şekilde günceller!
    • Esasen madeni paranın önyargıları, bu ağırlıklı güncellemeleri tüm denemelerde birleştirerek güncellenir; bu süreç ( maksimazasyon ), bir dizi denemede her bir madeni paranın önyargısı için en iyi tahminleri almaya çalışmayı ifade eder.

Bu bir aşırı basitleştirme (veya hatta bazı seviyelerde temelde yanlış) olabilir, ancak umarım bu sezgisel bir seviyede yardımcı olur!


1

EM, gizli değişkenler Z olan bir Q modelinin olasılığını maksimize etmek için kullanılır.

Yinelemeli bir optimizasyon.

theta <- initial guess for hidden parameters
while not converged:
    #e-step
    Q(theta'|theta) = E[log L(theta|Z)]
    #m-step
    theta <- argmax_theta' Q(theta'|theta)

e-adım: verilen mevcut Z tahmini, beklenen mantıksallık fonksiyonunu hesapla

m-adımı: bu Q'yu maksimize eden teta'yı bulun

GMM Örneği:

e-step: mevcut gmm parametresi tahmini verildiğinde her veri noktası için etiket atamalarını tahmin edin

m-adımı: yeni etiket atamalarında yeni bir teta'yı maksimize edin

K-aracı aynı zamanda bir EM algoritmasıdır ve K-ortalamaları hakkında birçok açıklama animasyonu vardır.


1

Zhubarb yanıtında anılan Do ve Batzoglou aynı yazı kullanarak, bu sorun için EM uygulanan Java . Cevabına yapılan yorumlar, algoritmanın yerel bir optimumda takıldığını gösteriyor; bu, thetaA ve thetaB parametreleri aynı ise benim uygulamamda da ortaya çıkıyor.

Aşağıda, kodumun parametrelerin yakınsamasını gösteren standart çıktısı var.

thetaA = 0.71301, thetaB = 0.58134
thetaA = 0.74529, thetaB = 0.56926
thetaA = 0.76810, thetaB = 0.54954
thetaA = 0.78316, thetaB = 0.53462
thetaA = 0.79106, thetaB = 0.52628
thetaA = 0.79453, thetaB = 0.52239
thetaA = 0.79593, thetaB = 0.52073
thetaA = 0.79647, thetaB = 0.52005
thetaA = 0.79667, thetaB = 0.51977
thetaA = 0.79674, thetaB = 0.51966
thetaA = 0.79677, thetaB = 0.51961
thetaA = 0.79678, thetaB = 0.51960
thetaA = 0.79679, thetaB = 0.51959
Final result:
thetaA = 0.79678, thetaB = 0.51960

Aşağıda problemi çözmek için Java uygulamam (Do ve Batzoglou, 2008). Gerçekleştirmenin temel kısmı, parametreler birleşene kadar EM çalıştırma döngüsüdür.

private Parameters _parameters;

public Parameters run()
{
    while (true)
    {
        expectation();

        Parameters estimatedParameters = maximization();

        if (_parameters.converged(estimatedParameters)) {
            break;
        }

        _parameters = estimatedParameters;
    }

    return _parameters;
}

Kodun tamamı aşağıdadır.

import java.util.*;

/*****************************************************************************
This class encapsulates the parameters of the problem. For this problem posed
in the article by (Do and Batzoglou, 2008), the parameters are thetaA and
thetaB, the probability of a coin coming up heads for the two coins A and B,
respectively.
*****************************************************************************/
class Parameters
{
    double _thetaA = 0.0; // Probability of heads for coin A.
    double _thetaB = 0.0; // Probability of heads for coin B.

    double _delta = 0.00001;

    public Parameters(double thetaA, double thetaB)
    {
        _thetaA = thetaA;
        _thetaB = thetaB;
    }

    /*************************************************************************
    Returns true if this parameter is close enough to another parameter
    (typically the estimated parameter coming from the maximization step).
    *************************************************************************/
    public boolean converged(Parameters other)
    {
        if (Math.abs(_thetaA - other._thetaA) < _delta &&
            Math.abs(_thetaB - other._thetaB) < _delta)
        {
            return true;
        }

        return false;
    }

    public double getThetaA()
    {
        return _thetaA;
    }

    public double getThetaB()
    {
        return _thetaB;
    }

    public String toString()
    {
        return String.format("thetaA = %.5f, thetaB = %.5f", _thetaA, _thetaB);
    }

}


/*****************************************************************************
This class encapsulates an observation, that is the number of heads
and tails in a trial. The observation can be either (1) one of the
experimental observations, or (2) an estimated observation resulting from
the expectation step.
*****************************************************************************/
class Observation
{
    double _numHeads = 0;
    double _numTails = 0;

    public Observation(String s)
    {
        for (int i = 0; i < s.length(); i++)
        {
            char c = s.charAt(i);

            if (c == 'H')
            {
                _numHeads++;
            }
            else if (c == 'T')
            {
                _numTails++;
            }
            else
            {
                throw new RuntimeException("Unknown character: " + c);
            }
        }
    }

    public Observation(double numHeads, double numTails)
    {
        _numHeads = numHeads;
        _numTails = numTails;
    }

    public double getNumHeads()
    {
        return _numHeads;
    }

    public double getNumTails()
    {
        return _numTails;
    }

    public String toString()
    {
        return String.format("heads: %.1f, tails: %.1f", _numHeads, _numTails);
    }

}

/*****************************************************************************
This class runs expectation-maximization for the problem posed by the article
from (Do and Batzoglou, 2008).
*****************************************************************************/
public class EM
{
    // Current estimated parameters.
    private Parameters _parameters;

    // Observations from the trials. These observations are set once.
    private final List<Observation> _observations;

    // Estimated observations per coin. These observations are the output
    // of the expectation step.
    private List<Observation> _expectedObservationsForCoinA;
    private List<Observation> _expectedObservationsForCoinB;

    private static java.io.PrintStream o = System.out;

    /*************************************************************************
    Principal constructor.
    @param observations The observations from the trial.
    @param parameters The initial guessed parameters.
    *************************************************************************/
    public EM(List<Observation> observations, Parameters parameters)
    {
        _observations = observations;
        _parameters = parameters;
    }

    /*************************************************************************
    Run EM until parameters converge.
    *************************************************************************/
    public Parameters run()
    {

        while (true)
        {
            expectation();

            Parameters estimatedParameters = maximization();

            o.printf("%s\n", estimatedParameters);

            if (_parameters.converged(estimatedParameters)) {
                break;
            }

            _parameters = estimatedParameters;
        }

        return _parameters;

    }

    /*************************************************************************
    Given the observations and current estimated parameters, compute new
    estimated completions (distribution over the classes) and observations.
    *************************************************************************/
    private void expectation()
    {

        _expectedObservationsForCoinA = new ArrayList<Observation>();
        _expectedObservationsForCoinB = new ArrayList<Observation>();

        for (Observation observation : _observations)
        {
            int numHeads = (int)observation.getNumHeads();
            int numTails = (int)observation.getNumTails();

            double probabilityOfObservationForCoinA=
                binomialProbability(10, numHeads, _parameters.getThetaA());

            double probabilityOfObservationForCoinB=
                binomialProbability(10, numHeads, _parameters.getThetaB());

            double normalizer = probabilityOfObservationForCoinA +
                                probabilityOfObservationForCoinB;

            // Compute the completions for coin A and B (i.e. the probability
            // distribution of the two classes, summed to 1.0).

            double completionCoinA = probabilityOfObservationForCoinA /
                                     normalizer;
            double completionCoinB = probabilityOfObservationForCoinB /
                                     normalizer;

            // Compute new expected observations for the two coins.

            Observation expectedObservationForCoinA =
                new Observation(numHeads * completionCoinA,
                                numTails * completionCoinA);

            Observation expectedObservationForCoinB =
                new Observation(numHeads * completionCoinB,
                                numTails * completionCoinB);

            _expectedObservationsForCoinA.add(expectedObservationForCoinA);
            _expectedObservationsForCoinB.add(expectedObservationForCoinB);
        }
    }

    /*************************************************************************
    Given new estimated observations, compute new estimated parameters.
    *************************************************************************/
    private Parameters maximization()
    {

        double sumCoinAHeads = 0.0;
        double sumCoinATails = 0.0;
        double sumCoinBHeads = 0.0;
        double sumCoinBTails = 0.0;

        for (Observation observation : _expectedObservationsForCoinA)
        {
            sumCoinAHeads += observation.getNumHeads();
            sumCoinATails += observation.getNumTails();
        }

        for (Observation observation : _expectedObservationsForCoinB)
        {
            sumCoinBHeads += observation.getNumHeads();
            sumCoinBTails += observation.getNumTails();
        }

        return new Parameters(sumCoinAHeads / (sumCoinAHeads + sumCoinATails),
                              sumCoinBHeads / (sumCoinBHeads + sumCoinBTails));

        //o.printf("parameters: %s\n", _parameters);

    }

    /*************************************************************************
    Since the coin-toss experiment posed in this article is a Bernoulli trial,
    use a binomial probability Pr(X=k; n,p) = (n choose k) * p^k * (1-p)^(n-k).
    *************************************************************************/
    private static double binomialProbability(int n, int k, double p)
    {
        double q = 1.0 - p;
        return nChooseK(n, k) * Math.pow(p, k) * Math.pow(q, n-k);
    }

    private static long nChooseK(int n, int k)
    {
        long numerator = 1;

        for (int i = 0; i < k; i++)
        {
            numerator = numerator * n;
            n--;
        }

        long denominator = factorial(k);

        return (long)(numerator / denominator);
    }

    private static long factorial(int n)
    {
        long result = 1;
        for (; n >0; n--)
        {
            result = result * n;
        }

        return result;
    }

    /*************************************************************************
    Entry point into the program.
    *************************************************************************/
    public static void main(String argv[])
    {
        // Create the observations and initial parameter guess
        // from the (Do and Batzoglou, 2008) article.

        List<Observation> observations = new ArrayList<Observation>();
        observations.add(new Observation("HTTTHHTHTH"));
        observations.add(new Observation("HHHHTHHHHH"));
        observations.add(new Observation("HTHHHHHTHH"));
        observations.add(new Observation("HTHTTTHHTT"));
        observations.add(new Observation("THHHTHHHTH"));

        Parameters initialParameters = new Parameters(0.6, 0.5);

        EM em = new EM(observations, initialParameters);

        Parameters finalParameters = em.run();

        o.printf("Final result:\n%s\n", finalParameters);
    }
}
Sitemizi kullandığınızda şunları okuyup anladığınızı kabul etmiş olursunuz: Çerez Politikası ve Gizlilik Politikası.
Licensed under cc by-sa 3.0 with attribution required.