Dưới đây là một ví dụ về Tối đa hóa kỳ vọng (EM) được sử dụng để ước tính độ lệch trung bình và độ lệch chuẩn. Mã này bằng Python, nhưng nó rất dễ theo dõi ngay cả khi bạn không quen với ngôn ngữ này.
Động lực cho EM
Các điểm màu đỏ và màu xanh hiển thị bên dưới được rút ra từ hai phân phối bình thường khác nhau, mỗi điểm có một độ lệch chuẩn và trung bình cụ thể:
Để tính các xấp xỉ hợp lý của các tham số độ lệch chuẩn và trung bình "đúng" cho phân bố màu đỏ, chúng ta có thể dễ dàng nhìn vào các điểm đỏ và ghi lại vị trí của từng điểm, sau đó sử dụng các công thức quen thuộc (và tương tự cho nhóm màu xanh) .
Bây giờ hãy xem xét trường hợp chúng ta biết rằng có hai nhóm điểm, nhưng chúng ta không thể thấy điểm nào thuộc về nhóm nào. Nói cách khác, các màu được ẩn đi:
Không rõ ràng làm thế nào để chia điểm thành hai nhóm. Bây giờ chúng ta không thể chỉ nhìn vào các vị trí và tính toán các ước tính cho các tham số của phân phối màu đỏ hoặc phân phối màu xanh.
Đây là nơi EM có thể được sử dụng để giải quyết vấn đề.
Sử dụng EM để ước tính các tham số
Đây là mã được sử dụng để tạo các điểm được hiển thị ở trên. Bạn có thể thấy các phương tiện thực tế và độ lệch chuẩn của các bản phân phối bình thường mà các điểm được rút ra từ đó. Các biến red
và blue
giữ vị trí của từng điểm trong các nhóm màu đỏ và màu xanh tương ứng:
import numpy as np
from scipy import stats
np.random.seed(110) # for reproducible random results
# set parameters
red_mean = 3
red_std = 0.8
blue_mean = 7
blue_std = 2
# draw 20 samples from normal distributions with red/blue parameters
red = np.random.normal(red_mean, red_std, size=20)
blue = np.random.normal(blue_mean, blue_std, size=20)
both_colours = np.sort(np.concatenate((red, blue)))
Nếu chúng ta có thể thấy màu của từng điểm, chúng ta sẽ thử và phục hồi các phương tiện và độ lệch chuẩn bằng các hàm thư viện:
>>> np.mean(red)
2.802
>>> np.std(red)
0.871
>>> np.mean(blue)
6.932
>>> np.std(blue)
2.195
Nhưng vì màu sắc bị ẩn khỏi chúng ta, chúng ta sẽ bắt đầu quá trình EM ...
Đầu tiên, chúng tôi chỉ đoán các giá trị cho các tham số của mỗi nhóm ( bước 1 ). Những dự đoán này không phải là tốt:
# estimates for the mean
red_mean_guess = 1.1
blue_mean_guess = 9
# estimates for the standard deviation
red_std_guess = 2
blue_std_guess = 1.7
Dự đoán khá tệ - phương tiện trông giống như chúng là một chặng đường dài từ bất kỳ "giữa" nào của một nhóm điểm.
Để tiếp tục với EM và cải thiện những dự đoán này, chúng tôi tính toán khả năng của từng điểm dữ liệu (bất kể màu bí mật của nó) xuất hiện dưới những dự đoán này cho độ lệch trung bình và độ lệch chuẩn ( bước 2 ).
Biến both_colours
giữ từng điểm dữ liệu. Hàm stats.norm
tính toán xác suất của điểm dưới một phân phối bình thường với các tham số đã cho:
likelihood_of_red = stats.norm(red_mean_guess, red_std_guess).pdf(both_colours)
likelihood_of_blue = stats.norm(blue_mean_guess, blue_std_guess).pdf(both_colours)
Điều này cho chúng ta, ví dụ, với dự đoán hiện tại của chúng tôi, điểm dữ liệu tại 1.761 có nhiều khả năng là màu đỏ (0.189) hơn màu xanh lam (0,00003).
Chúng ta có thể biến hai giá trị khả năng này thành trọng số ( bước 3 ) để chúng tổng hợp thành 1 như sau:
likelihood_total = likelihood_of_red + likelihood_of_blue
red_weight = likelihood_of_red / likelihood_total
blue_weight = likelihood_of_blue / likelihood_total
Với các ước tính hiện tại của chúng tôi và các trọng số mới được tính toán của chúng tôi, giờ đây chúng tôi có thể tính toán các ước tính mới, có thể tốt hơn cho các tham số ( bước 4 ). Chúng ta cần một hàm cho giá trị trung bình và một hàm cho độ lệch chuẩn:
def estimate_mean(data, weight):
return np.sum(data * weight) / np.sum(weight)
def estimate_std(data, weight, mean):
variance = np.sum(weight * (data - mean)**2) / np.sum(weight)
return np.sqrt(variance)
Chúng trông rất giống với các hàm thông thường với độ lệch trung bình và độ lệch chuẩn của dữ liệu. Sự khác biệt là việc sử dụng weight
tham số gán trọng số cho từng điểm dữ liệu.
Trọng số này là chìa khóa cho EM. Trọng lượng của màu trên điểm dữ liệu càng lớn, điểm dữ liệu càng ảnh hưởng đến các ước tính tiếp theo cho các tham số của màu đó. Cuối cùng, điều này có tác dụng kéo từng tham số theo đúng hướng.
Các dự đoán mới được tính toán với các chức năng này:
# new estimates for standard deviation
blue_std_guess = estimate_std(both_colours, blue_weight, blue_mean_guess)
red_std_guess = estimate_std(both_colours, red_weight, red_mean_guess)
# new estimates for mean
red_mean_guess = estimate_mean(both_colours, red_weight)
blue_mean_guess = estimate_mean(both_colours, blue_weight)
Quá trình EM sau đó được lặp lại với những dự đoán mới từ bước 2 trở đi. Chúng ta có thể lặp lại các bước cho một số lần lặp nhất định (giả sử 20) hoặc cho đến khi chúng ta thấy các tham số hội tụ.
Sau năm lần lặp lại, chúng ta thấy những dự đoán xấu ban đầu của mình bắt đầu tốt hơn:
Sau 20 lần lặp, quy trình EM đã hội tụ ít nhiều:
Để so sánh, đây là kết quả của quá trình EM so với các giá trị được tính toán trong đó thông tin màu không bị ẩn:
| EM guess | Actual
----------+----------+--------
Red mean | 2.910 | 2.802
Red std | 0.854 | 0.871
Blue mean | 6.838 | 6.932
Blue std | 2.227 | 2.195
Lưu ý: câu trả lời này đã được điều chỉnh từ câu trả lời của tôi trên Stack Overflow tại đây .