Beklenti Maksimizasyonu (EM), verileri sınıflandırmak için bir tür olasılıklı yöntemdir. Bir sınıflandırıcı değilse, yanılıyorsam lütfen düzeltin.
Bu EM tekniğinin sezgisel açıklaması nedir? Ne expectation
burada ve varlık nedir maximized
?
Beklenti Maksimizasyonu (EM), verileri sınıflandırmak için bir tür olasılıklı yöntemdir. Bir sınıflandırıcı değilse, yanılıyorsam lütfen düzeltin.
Bu EM tekniğinin sezgisel açıklaması nedir? Ne expectation
burada ve varlık nedir maximized
?
Yanıtlar:
Not: Bu cevabın arkasındaki kod burada bulunabilir .
Kırmızı ve mavi olmak üzere iki farklı gruptan örneklenmiş bazı verilerimiz olduğunu varsayalım:
Burada hangi veri noktasının kırmızı veya mavi gruba ait olduğunu görebiliriz. Bu, her bir grubu karakterize eden parametreleri bulmayı kolaylaştırır. Örneğin, kırmızı grubun ortalaması 3 civarındadır, mavi grubun ortalaması 7 civarındadır (ve istersek tam anlamını bulabiliriz).
Bu, genel olarak maksimum olasılık tahmini olarak bilinir . Bazı veriler göz önüne alındığında, bu verileri en iyi açıklayan bir parametrenin (veya parametrelerin) değerini hesaplarız.
Şimdi hayal edemez hangi gruptan örneklenmiş hangi değer görüyoruz. Bize her şey mor görünüyor:
Burada iki değer grubu olduğu bilgisine sahibiz , ancak belirli bir değerin hangi gruba ait olduğunu bilmiyoruz.
Bu verilere en iyi uyan kırmızı grup ve mavi grup için ortalamaları hala tahmin edebilir miyiz?
Evet, sık sık yapabiliriz! Beklenti Maksimizasyonu , bunu yapmamız için bize bir yol sunar. Algoritmanın arkasındaki çok genel fikir şudur:
Bu adımların daha fazla açıklamaya ihtiyacı vardır, bu nedenle yukarıda açıklanan sorunu gözden geçireceğim.
Bu örnekte Python kullanacağım, ancak bu dile aşina değilseniz kodun anlaşılması oldukça kolay olacaktır.
Yukarıdaki görüntüdeki gibi dağıtılmış değerlere sahip kırmızı ve mavi olmak üzere iki grubumuz olduğunu varsayalım. Spesifik olarak, her grup aşağıdaki parametrelerle normal bir dağılımdan alınan bir değer içerir :
import numpy as np
from scipy import stats
np.random.seed(110) # for reproducible results
# set parameters
red_mean = 3
red_std = 0.8
blue_mean = 7
blue_std = 2
# draw 20 samples from normal distributions with red/blue parameters
red = np.random.normal(red_mean, red_std, size=20)
blue = np.random.normal(blue_mean, blue_std, size=20)
both_colours = np.sort(np.concatenate((red, blue))) # for later use...
İşte yine bu kırmızı ve mavi grupların bir görüntüsü (sizi yukarı kaydırmaktan kurtarmak için):
Her noktanın rengini (yani hangi gruba ait olduğunu) görebildiğimizde, her grup için ortalama ve standart sapmayı tahmin etmek çok kolaydır. Kırmızı ve mavi değerleri NumPy'deki yerleşik işlevlere aktarıyoruz. Örneğin:
>>> np.mean(red)
2.802
>>> np.std(red)
0.871
>>> np.mean(blue)
6.932
>>> np.std(blue)
2.195
Ama ya eğer edemez noktalarının renkleri görmek? Yani kırmızı veya mavi yerine her nokta mor renkle boyanmıştır.
Kırmızı ve mavi gruplar için ortalama ve standart sapma parametrelerini denemek ve kurtarmak için Beklenti Maksimizasyonunu kullanabiliriz.
İlk adımımız ( yukarıdaki 1. adım ), her grubun ortalama ve standart sapması için parametre değerlerini tahmin etmektir. Akıllıca tahmin etmemize gerek yok; İstediğimiz sayıları seçebiliriz:
# estimates for the mean
red_mean_guess = 1.1
blue_mean_guess = 9
# estimates for the standard deviation
red_std_guess = 2
blue_std_guess = 1.7
Bu parametre tahminleri, şuna benzeyen çan eğrileri üretir:
Bunlar kötü tahminlerdir. Her iki araç da (dikey noktalı çizgiler), örneğin, mantıklı nokta grupları için herhangi bir tür "orta" dan çok uzak görünür. Bu tahminleri iyileştirmek istiyoruz.
Bir sonraki adım ( adım 2 ), her bir veri noktasının mevcut parametre tahminlerinin altında görünme olasılığını hesaplamaktır:
likelihood_of_red = stats.norm(red_mean_guess, red_std_guess).pdf(both_colours)
likelihood_of_blue = stats.norm(blue_mean_guess, blue_std_guess).pdf(both_colours)
Burada, kırmızı ve mavi için ortalama ve standart sapmadaki mevcut tahminlerimizi kullanarak her veri noktasını normal dağılım için olasılık yoğunluğu fonksiyonuna koyduk . Bu bize, örneğin, mevcut tahminlerimizle 1.761'deki veri noktasının mavi (0.00003) yerine kırmızı (0.189) olma olasılığının çok daha yüksek olduğunu söylüyor.
Her veri noktası için, bu iki olabilirlik değerini ağırlıklara dönüştürebiliriz ( 3. adım ), böylece toplamları aşağıdaki gibi 1 olur:
likelihood_total = likelihood_of_red + likelihood_of_blue
red_weight = likelihood_of_red / likelihood_total
blue_weight = likelihood_of_blue / likelihood_total
Mevcut tahminlerimiz ve yeni hesaplanmış ağırlıklarımızla, artık kırmızı ve mavi grupların ortalama ve standart sapması için yeni tahminler hesaplayabiliriz ( 4. adım ).
Tüm veri noktalarını kullanarak ortalama ve standart sapmayı iki kez hesaplıyoruz , ancak farklı ağırlıklarla: bir kez kırmızı ağırlıklar ve bir kez mavi ağırlıklar için.
Sezginin temel noktası, bir veri noktasındaki bir rengin ağırlığı ne kadar büyükse, veri noktasının o rengin parametreleri için sonraki tahminleri o kadar fazla etkilemesidir. Bu, parametreleri doğru yönde "çekme" etkisine sahiptir.
def estimate_mean(data, weight):
"""
For each data point, multiply the point by the probability it
was drawn from the colour's distribution (its "weight").
Divide by the total weight: essentially, we're finding where
the weight is centred among our data points.
"""
return np.sum(data * weight) / np.sum(weight)
def estimate_std(data, weight, mean):
"""
For each data point, multiply the point's squared difference
from a mean value by the probability it was drawn from
that distribution (its "weight").
Divide by the total weight: essentially, we're finding where
the weight is centred among the values for the difference of
each data point from the mean.
This is the estimate of the variance, take the positive square
root to find the standard deviation.
"""
variance = np.sum(weight * (data - mean)**2) / np.sum(weight)
return np.sqrt(variance)
# new estimates for standard deviation
blue_std_guess = estimate_std(both_colours, blue_weight, blue_mean_guess)
red_std_guess = estimate_std(both_colours, red_weight, red_mean_guess)
# new estimates for mean
red_mean_guess = estimate_mean(both_colours, red_weight)
blue_mean_guess = estimate_mean(both_colours, blue_weight)
Parametreler için yeni tahminlerimiz var. Bunları tekrar iyileştirmek için 2. adıma geri dönüp işlemi tekrar edebiliriz. Bunu tahminler yakınsayıncaya kadar veya bazı yineleme yapıldıktan sonra yapıyoruz ( adım 5 ).
Verilerimiz için, bu sürecin ilk beş yinelemesi şuna benzer (son yinelemeler daha güçlü görünüme sahiptir):
Araçların zaten bazı değerlerde birleştiğini ve eğrilerin şekillerinin (standart sapma tarafından yönetilen) daha kararlı hale geldiğini görüyoruz.
20 yineleme boyunca devam edersek, aşağıdakilerle sonuçlanırız:
EM süreci, gerçek değerlere çok yakın olduğu ortaya çıkan aşağıdaki değerlere yakınsamıştır (renkleri görebiliyoruz - gizli değişkenler yok):
| EM guess | Actual | Delta
----------+----------+--------+-------
Red mean | 2.910 | 2.802 | 0.108
Red std | 0.854 | 0.871 | -0.017
Blue mean | 6.838 | 6.932 | -0.094
Blue std | 2.227 | 2.195 | 0.032
Yukarıdaki kodda, standart sapma için yeni tahminin, önceki yinelemenin ortalama tahmini kullanılarak hesaplandığını fark etmiş olabilirsiniz. Nihayetinde, ilk önce ortalama için yeni bir değer hesaplayıp hesaplamamız önemli değildir, çünkü sadece bazı merkezi noktalar etrafında değerlerin (ağırlıklı) varyansını buluyoruz. Yine de parametreler için tahminlerin yakınsadığını göreceğiz.
EM, modelinizdeki bazı değişkenler gözlemlenmediğinde (ör. Gizli değişkenleriniz olduğunda) olasılık fonksiyonunu maksimize etmek için bir algoritmadır.
Bir işlevi maksimize etmeye çalışıyorsak, bir işlevi maksimize etmek için neden sadece mevcut makineyi kullanmıyoruz diye sorabilirsiniz. Bunu türevleri alıp sıfıra koyarak maksimize etmeye çalışırsanız, çoğu durumda birinci dereceden koşulların bir çözümü olmadığını görürsünüz. Model parametrelerinizi çözmek için gözlenmeyen verilerinizin dağılımını bilmeniz gereken bir tavuk ve yumurta sorunu vardır; ancak gözlemlenmemiş verilerinizin dağılımı, model parametrelerinizin bir fonksiyonudur.
EM, gözlemlenmemiş veriler için bir dağılımı yinelemeli olarak tahmin ederek, ardından gerçek olasılık fonksiyonunda daha düşük bir sınır olan bir şeyi maksimize ederek ve yakınsamaya kadar tekrarlayarak model parametrelerini tahmin ederek bu sorunu aşmaya çalışır:
EM algoritması
Model parametrelerinizin değerleri için tahminle başlayın
E-adımı: Eksik değerleri olan her veri noktası için, model parametreleriyle ilgili mevcut tahmininiz ve gözlemlenen verilere göre eksik verilerin dağılımını çözmek için model denkleminizi kullanın (her bir eksik için bir dağıtım için çözdüğünüzü unutmayın. değer, beklenen değer için değil). Artık her bir eksik değer için bir dağılımımız olduğuna göre, gözlenmeyen değişkenlere göre olasılık fonksiyonunun beklentisini hesaplayabiliriz . Model parametresi için tahminimiz doğruysa, bu beklenen olasılık, gözlemlenen verilerimizin gerçek olasılığı olacaktır; parametreler doğru değilse, bu sadece bir alt sınır olacaktır.
M adımı: Artık, içinde gözlenmeyen değişkenler olmayan beklenen bir olasılık fonksiyonuna sahip olduğumuza göre, model parametrelerinizin yeni bir tahminini elde etmek için, tamamen gözlemlenen durumda yapacağınız gibi fonksiyonu maksimize edin.
Yakınsamaya kadar tekrarlayın.
Beklenti Maksimizasyonu algoritmasını anlamak için basit bir tarif:
1- Do ve Batzoglou'nun hazırladığı EM eğitim belgesini okuyun .
2- Kafanızda soru işaretleri olabilir, bu matematik yığını değişim sayfasındaki açıklamalara bir göz atın .
3- Python'da yazdığım, 1. maddenin EM eğitim kağıdındaki örneği açıklayan bu koda bakın:
Uyarı: Python geliştiricisi olmadığım için kod dağınık / yetersiz olabilir. Ama işi yapıyor.
import numpy as np
import math
#### E-M Coin Toss Example as given in the EM tutorial paper by Do and Batzoglou* ####
def get_mn_log_likelihood(obs,probs):
""" Return the (log)likelihood of obs, given the probs"""
# Multinomial Distribution Log PMF
# ln (pdf) = multinomial coeff * product of probabilities
# ln[f(x|n, p)] = [ln(n!) - (ln(x1!)+ln(x2!)+...+ln(xk!))] + [x1*ln(p1)+x2*ln(p2)+...+xk*ln(pk)]
multinomial_coeff_denom= 0
prod_probs = 0
for x in range(0,len(obs)): # loop through state counts in each observation
multinomial_coeff_denom = multinomial_coeff_denom + math.log(math.factorial(obs[x]))
prod_probs = prod_probs + obs[x]*math.log(probs[x])
multinomial_coeff = math.log(math.factorial(sum(obs))) - multinomial_coeff_denom
likelihood = multinomial_coeff + prod_probs
return likelihood
# 1st: Coin B, {HTTTHHTHTH}, 5H,5T
# 2nd: Coin A, {HHHHTHHHHH}, 9H,1T
# 3rd: Coin A, {HTHHHHHTHH}, 8H,2T
# 4th: Coin B, {HTHTTTHHTT}, 4H,6T
# 5th: Coin A, {THHHTHHHTH}, 7H,3T
# so, from MLE: pA(heads) = 0.80 and pB(heads)=0.45
# represent the experiments
head_counts = np.array([5,9,8,4,7])
tail_counts = 10-head_counts
experiments = zip(head_counts,tail_counts)
# initialise the pA(heads) and pB(heads)
pA_heads = np.zeros(100); pA_heads[0] = 0.60
pB_heads = np.zeros(100); pB_heads[0] = 0.50
# E-M begins!
delta = 0.001
j = 0 # iteration counter
improvement = float('inf')
while (improvement>delta):
expectation_A = np.zeros((5,2), dtype=float)
expectation_B = np.zeros((5,2), dtype=float)
for i in range(0,len(experiments)):
e = experiments[i] # i'th experiment
ll_A = get_mn_log_likelihood(e,np.array([pA_heads[j],1-pA_heads[j]])) # loglikelihood of e given coin A
ll_B = get_mn_log_likelihood(e,np.array([pB_heads[j],1-pB_heads[j]])) # loglikelihood of e given coin B
weightA = math.exp(ll_A) / ( math.exp(ll_A) + math.exp(ll_B) ) # corresponding weight of A proportional to likelihood of A
weightB = math.exp(ll_B) / ( math.exp(ll_A) + math.exp(ll_B) ) # corresponding weight of B proportional to likelihood of B
expectation_A[i] = np.dot(weightA, e)
expectation_B[i] = np.dot(weightB, e)
pA_heads[j+1] = sum(expectation_A)[0] / sum(sum(expectation_A));
pB_heads[j+1] = sum(expectation_B)[0] / sum(sum(expectation_B));
improvement = max( abs(np.array([pA_heads[j+1],pB_heads[j+1]]) - np.array([pA_heads[j],pB_heads[j]]) ))
j = j+1
Teknik olarak "EM" terimi biraz daha az belirtilmiştir, ancak genel EM ilkesinin bir örneği olan Gauss Karışımı Modelleme küme analizi tekniğine atıfta bulunduğunuzu varsayıyorum .
Aslında EM küme analizi bir sınıflandırıcı değildir . Bazı insanların kümelemeyi "denetimsiz sınıflandırma" olarak gördüğünü biliyorum, ama aslında kümeleme analizi oldukça farklı bir şey.
Temel fark ve insanların küme analiziyle her zaman sahip oldukları büyük yanlış anlama sınıflandırması şudur: küme analizinde "doğru çözüm" yoktur . Bu bir bilgi keşif yöntemidir, aslında yeni bir şey bulmak içindir ! Bu, değerlendirmeyi çok zor hale getirir. Genellikle referans olarak bilinen bir sınıflandırma kullanılarak değerlendirilir, ancak bu her zaman uygun değildir: sahip olduğunuz sınıflandırma verilerde ne olduğunu yansıtabilir veya yansıtmayabilir.
Size bir örnek vereyim: Cinsiyet verileri de dahil olmak üzere geniş bir müşteri veri kümeniz var. Bu veri kümesini "erkek" ve "dişi" olarak bölen bir yöntem, onu mevcut sınıflarla karşılaştırdığınızda idealdir. Yeni kullanıcılar için artık cinsiyetlerini tahmin edebildiğinizden, bu iyi bir "tahmin" tarzıdır. "Bilgi keşfi" şeklinde düşündüğünüzde, bu aslında kötü, çünkü verilerde yeni bir yapı keşfetmek istediniz . Örneğin, verileri yaşlı insanlara ve çocuklara ayıran bir yöntem, ancak erkek / kadın sınıfına göre alabileceği kadar kötü puan alacaktır . Bununla birlikte, bu mükemmel bir kümeleme sonucu olur (yaş belirtilmemişse).
Şimdi EM'ye geri dönün. Esasen, verilerinizin çok değişkenli normal dağılımlardan oluştuğunu varsayar (bunun özellikle küme sayısını sabitlediğinizde çok güçlü bir varsayım olduğuna dikkat edin !). Daha sonra modeli ve modele nesne atamasını dönüşümlü olarak geliştirerek bunun için yerel bir optimal model bulmaya çalışır .
Bir sınıflandırma bağlamında en iyi sonuçlar için, sınıf sayısından daha büyük küme sayısını seçin veya hatta kümelemeyi yalnızca tek sınıflara uygulayın (sınıfta bir yapı olup olmadığını öğrenmek için!).
"Arabaları", "bisikletleri" ve "kamyonları" birbirinden ayırmak için bir sınıflandırıcı eğitmek istediğinizi varsayalım. Verilerin tam olarak 3 normal dağılımdan oluştuğunu varsaymanın çok az faydası vardır. Bununla birlikte, birden fazla araba türü (ve kamyon ve bisiklet) olduğunu varsayabilirsiniz . Dolayısıyla, bu üç sınıf için bir sınıflandırıcı eğitmek yerine, arabaları, kamyonları ve bisikletleri her biri 10 kümeye (veya belki 10 araba, 3 kamyon ve 3 bisiklet, her neyse) ayırırsınız, ardından bu 30 sınıfı ayırmak için bir sınıflandırıcı eğitirsiniz ve sonra sınıf sonucunu orijinal sınıflara geri birleştirin. Ayrıca, Trikes gibi sınıflandırılması özellikle zor olan bir küme olduğunu da keşfedebilirsiniz. Onlar biraz araba ve biraz da bisiklet. Veya kamyonlardan çok büyük arabalara benzeyen teslimat kamyonları.
Diğer cevaplar iyidir, başka bir bakış açısı sağlamaya ve sorunun sezgisel kısmını ele almaya çalışacağım.
EM (Beklenti-Maksimizasyon) algoritması , dualite kullanan bir yinelemeli algoritmalar sınıfının bir varyantıdır.
Alıntı (vurgu benim):
Matematikte, genel olarak konuşursak, bir dualite, kavramları, teoremleri veya matematiksel yapıları diğer kavramlara, teoremlere veya yapılara bire bir biçimde, genellikle (ancak her zaman değil) bir evrişim işlemi aracılığıyla çevirir: eğer A, B'dir, sonra B'nin ikilisi, A'dır. bazen sabit noktaları vardır , böylece A'nın duali A'nın kendisidir.
Genellikle , çift bir B nesnesi A bir muhafaza bir şekilde A ile ilgilidir simetri veya uyumluluk . Örneğin AB = const
Dualite kullanan (önceki anlamda) yinelemeli algoritmalara örnekler şunlardır:
Benzer bir şekilde, EM algoritması iki ikili maksimizasyon adımı olarak da görülebilir :
.. [EM], parametrelerin ve gözlenmeyen değişkenler üzerindeki dağılımın birleşik bir fonksiyonunu maksimize ettiği görülmektedir. E-adımı, bu fonksiyonu, gözlemlenmemiş değişkenler üzerindeki dağılıma göre maksimize eder; parametrelere göre M-adımı ..
Dualite kullanan yinelemeli bir algoritmada, bir denge (veya sabit) yakınsama noktasının açık (veya örtük) varsayımı vardır (EM için bu Jensen'in eşitsizliği kullanılarak kanıtlanmıştır)
Dolayısıyla bu tür algoritmaların ana hatları şu şekildedir:
Not a (global) optimuma Böyle bir algoritma yakınsak, bunun bir yapılandırma bulduğunda her iki anlamda da iyi (her ikisi de, yani X alanı / parametreler ve y alan / parametreler). Ancak algoritma yerel bir optimum bulabilir , global optimum .
Bunun, algoritmanın ana hatlarının sezgisel açıklaması olduğunu söyleyebilirim
İstatistiksel argümanlar ve uygulamalar için, diğer cevaplar iyi açıklamalar vermiştir (bu cevaptaki referansları da kontrol edin)
Kabul edilen cevap, EM'yi açıklayan iyi bir iş çıkaran Chuong EM Paper'a atıfta bulunuyor . Ayrıca kağıdı daha detaylı anlatan bir youtube videosu da var .
Özetlemek gerekirse, senaryo şu şekildedir:
1st: {H,T,T,T,H,H,T,H,T,H} 5 Heads, 5 Tails; Did coin A or B generate me?
2nd: {H,H,H,H,T,H,H,H,H,H} 9 Heads, 1 Tails
3rd: {H,T,H,H,H,H,H,T,H,H} 8 Heads, 2 Tails
4th: {H,T,H,T,T,T,H,H,T,T} 4 Heads, 6 Tails
5th: {T,H,H,H,T,H,H,H,T,H} 7 Heads, 3 Tails
Two possible coins, A & B are used to generate these distributions.
A & B have an unknown parameter: their bias towards heads.
We don't know the biases, but we can simply start with a guess: A=60% heads, B=50% heads.
İlk denemenin sorusu durumunda, sezgisel olarak B'nin bunu ürettiğini düşünürüz çünkü kafa oranı B'nin önyargısına çok iyi uyuyor ... ama bu değer sadece bir tahmindi, bu yüzden emin olamayız.
Bunu aklımda tutarak, EM çözümünü şu şekilde düşünmeyi seviyorum:
Bu bir aşırı basitleştirme (veya hatta bazı seviyelerde temelde yanlış) olabilir, ancak umarım bu sezgisel bir seviyede yardımcı olur!
EM, gizli değişkenler Z olan bir Q modelinin olasılığını maksimize etmek için kullanılır.
Yinelemeli bir optimizasyon.
theta <- initial guess for hidden parameters
while not converged:
#e-step
Q(theta'|theta) = E[log L(theta|Z)]
#m-step
theta <- argmax_theta' Q(theta'|theta)
e-adım: verilen mevcut Z tahmini, beklenen mantıksallık fonksiyonunu hesapla
m-adımı: bu Q'yu maksimize eden teta'yı bulun
GMM Örneği:
e-step: mevcut gmm parametresi tahmini verildiğinde her veri noktası için etiket atamalarını tahmin edin
m-adımı: yeni etiket atamalarında yeni bir teta'yı maksimize edin
K-aracı aynı zamanda bir EM algoritmasıdır ve K-ortalamaları hakkında birçok açıklama animasyonu vardır.
Zhubarb yanıtında anılan Do ve Batzoglou aynı yazı kullanarak, bu sorun için EM uygulanan Java . Cevabına yapılan yorumlar, algoritmanın yerel bir optimumda takıldığını gösteriyor; bu, thetaA ve thetaB parametreleri aynı ise benim uygulamamda da ortaya çıkıyor.
Aşağıda, kodumun parametrelerin yakınsamasını gösteren standart çıktısı var.
thetaA = 0.71301, thetaB = 0.58134
thetaA = 0.74529, thetaB = 0.56926
thetaA = 0.76810, thetaB = 0.54954
thetaA = 0.78316, thetaB = 0.53462
thetaA = 0.79106, thetaB = 0.52628
thetaA = 0.79453, thetaB = 0.52239
thetaA = 0.79593, thetaB = 0.52073
thetaA = 0.79647, thetaB = 0.52005
thetaA = 0.79667, thetaB = 0.51977
thetaA = 0.79674, thetaB = 0.51966
thetaA = 0.79677, thetaB = 0.51961
thetaA = 0.79678, thetaB = 0.51960
thetaA = 0.79679, thetaB = 0.51959
Final result:
thetaA = 0.79678, thetaB = 0.51960
Aşağıda problemi çözmek için Java uygulamam (Do ve Batzoglou, 2008). Gerçekleştirmenin temel kısmı, parametreler birleşene kadar EM çalıştırma döngüsüdür.
private Parameters _parameters;
public Parameters run()
{
while (true)
{
expectation();
Parameters estimatedParameters = maximization();
if (_parameters.converged(estimatedParameters)) {
break;
}
_parameters = estimatedParameters;
}
return _parameters;
}
Kodun tamamı aşağıdadır.
import java.util.*;
/*****************************************************************************
This class encapsulates the parameters of the problem. For this problem posed
in the article by (Do and Batzoglou, 2008), the parameters are thetaA and
thetaB, the probability of a coin coming up heads for the two coins A and B,
respectively.
*****************************************************************************/
class Parameters
{
double _thetaA = 0.0; // Probability of heads for coin A.
double _thetaB = 0.0; // Probability of heads for coin B.
double _delta = 0.00001;
public Parameters(double thetaA, double thetaB)
{
_thetaA = thetaA;
_thetaB = thetaB;
}
/*************************************************************************
Returns true if this parameter is close enough to another parameter
(typically the estimated parameter coming from the maximization step).
*************************************************************************/
public boolean converged(Parameters other)
{
if (Math.abs(_thetaA - other._thetaA) < _delta &&
Math.abs(_thetaB - other._thetaB) < _delta)
{
return true;
}
return false;
}
public double getThetaA()
{
return _thetaA;
}
public double getThetaB()
{
return _thetaB;
}
public String toString()
{
return String.format("thetaA = %.5f, thetaB = %.5f", _thetaA, _thetaB);
}
}
/*****************************************************************************
This class encapsulates an observation, that is the number of heads
and tails in a trial. The observation can be either (1) one of the
experimental observations, or (2) an estimated observation resulting from
the expectation step.
*****************************************************************************/
class Observation
{
double _numHeads = 0;
double _numTails = 0;
public Observation(String s)
{
for (int i = 0; i < s.length(); i++)
{
char c = s.charAt(i);
if (c == 'H')
{
_numHeads++;
}
else if (c == 'T')
{
_numTails++;
}
else
{
throw new RuntimeException("Unknown character: " + c);
}
}
}
public Observation(double numHeads, double numTails)
{
_numHeads = numHeads;
_numTails = numTails;
}
public double getNumHeads()
{
return _numHeads;
}
public double getNumTails()
{
return _numTails;
}
public String toString()
{
return String.format("heads: %.1f, tails: %.1f", _numHeads, _numTails);
}
}
/*****************************************************************************
This class runs expectation-maximization for the problem posed by the article
from (Do and Batzoglou, 2008).
*****************************************************************************/
public class EM
{
// Current estimated parameters.
private Parameters _parameters;
// Observations from the trials. These observations are set once.
private final List<Observation> _observations;
// Estimated observations per coin. These observations are the output
// of the expectation step.
private List<Observation> _expectedObservationsForCoinA;
private List<Observation> _expectedObservationsForCoinB;
private static java.io.PrintStream o = System.out;
/*************************************************************************
Principal constructor.
@param observations The observations from the trial.
@param parameters The initial guessed parameters.
*************************************************************************/
public EM(List<Observation> observations, Parameters parameters)
{
_observations = observations;
_parameters = parameters;
}
/*************************************************************************
Run EM until parameters converge.
*************************************************************************/
public Parameters run()
{
while (true)
{
expectation();
Parameters estimatedParameters = maximization();
o.printf("%s\n", estimatedParameters);
if (_parameters.converged(estimatedParameters)) {
break;
}
_parameters = estimatedParameters;
}
return _parameters;
}
/*************************************************************************
Given the observations and current estimated parameters, compute new
estimated completions (distribution over the classes) and observations.
*************************************************************************/
private void expectation()
{
_expectedObservationsForCoinA = new ArrayList<Observation>();
_expectedObservationsForCoinB = new ArrayList<Observation>();
for (Observation observation : _observations)
{
int numHeads = (int)observation.getNumHeads();
int numTails = (int)observation.getNumTails();
double probabilityOfObservationForCoinA=
binomialProbability(10, numHeads, _parameters.getThetaA());
double probabilityOfObservationForCoinB=
binomialProbability(10, numHeads, _parameters.getThetaB());
double normalizer = probabilityOfObservationForCoinA +
probabilityOfObservationForCoinB;
// Compute the completions for coin A and B (i.e. the probability
// distribution of the two classes, summed to 1.0).
double completionCoinA = probabilityOfObservationForCoinA /
normalizer;
double completionCoinB = probabilityOfObservationForCoinB /
normalizer;
// Compute new expected observations for the two coins.
Observation expectedObservationForCoinA =
new Observation(numHeads * completionCoinA,
numTails * completionCoinA);
Observation expectedObservationForCoinB =
new Observation(numHeads * completionCoinB,
numTails * completionCoinB);
_expectedObservationsForCoinA.add(expectedObservationForCoinA);
_expectedObservationsForCoinB.add(expectedObservationForCoinB);
}
}
/*************************************************************************
Given new estimated observations, compute new estimated parameters.
*************************************************************************/
private Parameters maximization()
{
double sumCoinAHeads = 0.0;
double sumCoinATails = 0.0;
double sumCoinBHeads = 0.0;
double sumCoinBTails = 0.0;
for (Observation observation : _expectedObservationsForCoinA)
{
sumCoinAHeads += observation.getNumHeads();
sumCoinATails += observation.getNumTails();
}
for (Observation observation : _expectedObservationsForCoinB)
{
sumCoinBHeads += observation.getNumHeads();
sumCoinBTails += observation.getNumTails();
}
return new Parameters(sumCoinAHeads / (sumCoinAHeads + sumCoinATails),
sumCoinBHeads / (sumCoinBHeads + sumCoinBTails));
//o.printf("parameters: %s\n", _parameters);
}
/*************************************************************************
Since the coin-toss experiment posed in this article is a Bernoulli trial,
use a binomial probability Pr(X=k; n,p) = (n choose k) * p^k * (1-p)^(n-k).
*************************************************************************/
private static double binomialProbability(int n, int k, double p)
{
double q = 1.0 - p;
return nChooseK(n, k) * Math.pow(p, k) * Math.pow(q, n-k);
}
private static long nChooseK(int n, int k)
{
long numerator = 1;
for (int i = 0; i < k; i++)
{
numerator = numerator * n;
n--;
}
long denominator = factorial(k);
return (long)(numerator / denominator);
}
private static long factorial(int n)
{
long result = 1;
for (; n >0; n--)
{
result = result * n;
}
return result;
}
/*************************************************************************
Entry point into the program.
*************************************************************************/
public static void main(String argv[])
{
// Create the observations and initial parameter guess
// from the (Do and Batzoglou, 2008) article.
List<Observation> observations = new ArrayList<Observation>();
observations.add(new Observation("HTTTHHTHTH"));
observations.add(new Observation("HHHHTHHHHH"));
observations.add(new Observation("HTHHHHHTHH"));
observations.add(new Observation("HTHTTTHHTT"));
observations.add(new Observation("THHHTHHHTH"));
Parameters initialParameters = new Parameters(0.6, 0.5);
EM em = new EM(observations, initialParameters);
Parameters finalParameters = em.run();
o.printf("Final result:\n%s\n", finalParameters);
}
}