Lưu ý: mã đằng sau câu trả lời này có thể được tìm thấy ở đây .
Giả sử chúng ta có một số dữ liệu được lấy mẫu từ hai nhóm khác nhau, màu đỏ và xanh lam:
Tại đây, chúng ta có thể xem điểm dữ liệu nào thuộc nhóm màu đỏ hoặc xanh lam. Điều này giúp bạn dễ dàng tìm thấy các thông số đặc trưng cho từng nhóm. Ví dụ: giá trị trung bình của nhóm màu đỏ là khoảng 3, giá trị trung bình của nhóm màu xanh là khoảng 7 (và chúng tôi có thể tìm thấy phương tiện chính xác nếu chúng tôi muốn).
Nói chung, đây được gọi là ước tính khả năng xảy ra tối đa . Với một số dữ liệu, chúng tôi tính toán giá trị của một tham số (hoặc các tham số) giải thích tốt nhất cho dữ liệu đó.
Bây giờ hãy tưởng tượng rằng chúng ta không thể thấy giá trị nào được lấy mẫu từ nhóm nào. Mọi thứ đều có màu tím đối với chúng tôi:
Ở đây, chúng tôi biết rằng có hai nhóm giá trị, nhưng chúng tôi không biết nhóm nào có giá trị cụ thể thuộc về nhóm nào.
Chúng ta vẫn có thể ước tính phương tiện cho nhóm màu đỏ và nhóm màu xanh lam phù hợp nhất với dữ liệu này chứ?
Có, thường thì chúng ta có thể! Tối đa hóa kỳ vọng cung cấp cho chúng tôi một cách để thực hiện điều đó. Ý tưởng chung đằng sau thuật toán là:
- Bắt đầu với ước tính ban đầu về giá trị của từng thông số.
- Tính khả năng mỗi tham số tạo ra điểm dữ liệu.
- Tính toán trọng số cho mỗi điểm dữ liệu cho biết nó có nhiều màu đỏ hơn hay nhiều màu xanh hơn dựa trên khả năng nó được tạo ra bởi một tham số. Kết hợp trọng số với dữ liệu ( kỳ vọng ).
- Tính toán ước tính tốt hơn cho các thông số bằng cách sử dụng dữ liệu đã điều chỉnh trọng lượng ( tối đa hóa ).
- Lặp lại các bước từ 2 đến 4 cho đến khi ước tính tham số hội tụ (quá trình ngừng tạo ra ước tính khác).
Các bước này cần giải thích thêm, vì vậy tôi sẽ xem xét vấn đề được mô tả ở trên.
Ví dụ: ước tính giá trị trung bình và độ lệch chuẩn
Tôi sẽ sử dụng Python trong ví dụ này, nhưng mã phải khá dễ hiểu nếu bạn không quen với ngôn ngữ này.
Giả sử chúng ta có hai nhóm, màu đỏ và màu xanh lam, với các giá trị được phân phối như trong hình trên. Cụ thể, mỗi nhóm chứa một giá trị được rút ra từ phân phối chuẩn với các tham số sau:
import numpy as np
from scipy import stats
np.random.seed(110) # for reproducible results
# set parameters
red_mean = 3
red_std = 0.8
blue_mean = 7
blue_std = 2
# draw 20 samples from normal distributions with red/blue parameters
red = np.random.normal(red_mean, red_std, size=20)
blue = np.random.normal(blue_mean, blue_std, size=20)
both_colours = np.sort(np.concatenate((red, blue))) # for later use...
Đây là hình ảnh của các nhóm màu đỏ và xanh lam này một lần nữa (để giúp bạn không phải cuộn lên):
Khi chúng ta có thể nhìn thấy màu sắc của từng điểm (tức là nó thuộc nhóm nào), rất dễ dàng để ước tính giá trị trung bình và độ lệch chuẩn cho mỗi nhóm. Chúng tôi chỉ chuyển các giá trị màu đỏ và xanh lam cho các hàm nội trang trong NumPy. Ví dụ:
>>> np.mean(red)
2.802
>>> np.std(red)
0.871
>>> np.mean(blue)
6.932
>>> np.std(blue)
2.195
Nhưng nếu chúng ta không thể nhìn thấy màu sắc của các điểm thì sao? Đó là, thay vì màu đỏ hoặc xanh lam, mọi điểm đã được tô màu tím.
Để thử và khôi phục các tham số trung bình và độ lệch chuẩn cho các nhóm màu đỏ và xanh lam, chúng ta có thể sử dụng Tối đa hóa kỳ vọng.
Bước đầu tiên của chúng tôi ( bước 1 ở trên) là đoán các giá trị tham số cho độ lệch chuẩn và trung bình của mỗi nhóm. Chúng ta không cần phải đoán một cách thông minh; chúng tôi có thể chọn bất kỳ số nào chúng tôi thích:
# estimates for the mean
red_mean_guess = 1.1
blue_mean_guess = 9
# estimates for the standard deviation
red_std_guess = 2
blue_std_guess = 1.7
Các ước tính tham số này tạo ra các đường cong hình chuông trông như sau:
Đây là những ước tính không tốt. Ví dụ, cả hai đều có nghĩa là (các đường chấm dọc) nhìn xa bất kỳ loại "giữa" nào đối với các nhóm điểm hợp lý. Chúng tôi muốn cải thiện những ước tính này.
Bước tiếp theo ( bước 2 ) là tính toán khả năng mỗi điểm dữ liệu xuất hiện trong các dự đoán tham số hiện tại:
likelihood_of_red = stats.norm(red_mean_guess, red_std_guess).pdf(both_colours)
likelihood_of_blue = stats.norm(blue_mean_guess, blue_std_guess).pdf(both_colours)
Ở đây, chúng tôi chỉ cần đặt từng điểm dữ liệu vào hàm mật độ xác suất cho phân phối chuẩn bằng cách sử dụng các phỏng đoán hiện tại của chúng tôi ở mức trung bình và độ lệch chuẩn cho màu đỏ và xanh lam. Ví dụ, điều này cho chúng ta biết rằng với những dự đoán hiện tại của chúng ta, điểm dữ liệu tại 1.761 có nhiều khả năng có màu đỏ (0,189) hơn là màu xanh lam (0,00003).
Đối với mỗi điểm dữ liệu, chúng tôi có thể chuyển hai giá trị khả năng này thành trọng số ( bước 3 ) để chúng tổng hợp thành 1 như sau:
likelihood_total = likelihood_of_red + likelihood_of_blue
red_weight = likelihood_of_red / likelihood_total
blue_weight = likelihood_of_blue / likelihood_total
Với ước tính hiện tại và trọng số mới được tính toán của chúng tôi, giờ đây chúng tôi có thể tính toán các ước tính mới cho giá trị trung bình và độ lệch chuẩn của nhóm màu đỏ và xanh lam ( bước 4 ).
Chúng tôi tính toán hai lần giá trị trung bình và độ lệch chuẩn bằng cách sử dụng tất cả các điểm dữ liệu, nhưng với các trọng số khác nhau: một lần cho trọng số màu đỏ và một lần cho trọng số màu xanh.
Điểm mấu chốt của trực giác là trọng lượng của một màu trên một điểm dữ liệu càng lớn, thì điểm dữ liệu đó càng ảnh hưởng đến các ước tính tiếp theo cho các tham số của màu đó. Điều này có tác dụng “kéo” các thông số đi đúng hướng.
def estimate_mean(data, weight):
"""
For each data point, multiply the point by the probability it
was drawn from the colour's distribution (its "weight").
Divide by the total weight: essentially, we're finding where
the weight is centred among our data points.
"""
return np.sum(data * weight) / np.sum(weight)
def estimate_std(data, weight, mean):
"""
For each data point, multiply the point's squared difference
from a mean value by the probability it was drawn from
that distribution (its "weight").
Divide by the total weight: essentially, we're finding where
the weight is centred among the values for the difference of
each data point from the mean.
This is the estimate of the variance, take the positive square
root to find the standard deviation.
"""
variance = np.sum(weight * (data - mean)**2) / np.sum(weight)
return np.sqrt(variance)
# new estimates for standard deviation
blue_std_guess = estimate_std(both_colours, blue_weight, blue_mean_guess)
red_std_guess = estimate_std(both_colours, red_weight, red_mean_guess)
# new estimates for mean
red_mean_guess = estimate_mean(both_colours, red_weight)
blue_mean_guess = estimate_mean(both_colours, blue_weight)
Chúng tôi có ước tính mới cho các thông số. Để cải thiện chúng một lần nữa, chúng ta có thể quay lại bước 2 và lặp lại quy trình. Chúng tôi thực hiện điều này cho đến khi các ước tính hội tụ hoặc sau khi một số lần lặp đã được thực hiện ( bước 5 ).
Đối với dữ liệu của chúng tôi, năm lần lặp đầu tiên của quá trình này trông như thế này (các lần lặp gần đây có hình thức mạnh mẽ hơn):
Chúng tôi thấy rằng các phương tiện đã hội tụ trên một số giá trị và hình dạng của các đường cong (bị chi phối bởi độ lệch chuẩn) cũng đang trở nên ổn định hơn.
Nếu chúng ta tiếp tục trong 20 lần lặp, chúng ta kết thúc với những điều sau:
Quá trình EM đã hội tụ các giá trị sau, hóa ra rất gần với giá trị thực (nơi chúng ta có thể nhìn thấy màu sắc - không có biến ẩn):
| EM guess | Actual | Delta
----------+----------+--------+-------
Red mean | 2.910 | 2.802 | 0.108
Red std | 0.854 | 0.871 | -0.017
Blue mean | 6.838 | 6.932 | -0.094
Blue std | 2.227 | 2.195 | 0.032
Trong đoạn mã trên, bạn có thể nhận thấy rằng ước tính mới cho độ lệch chuẩn được tính bằng cách sử dụng ước tính trung bình của lần lặp trước. Cuối cùng sẽ không thành vấn đề nếu chúng ta tính giá trị mới cho giá trị trung bình trước vì chúng ta chỉ đang tìm phương sai (có trọng số) của các giá trị xung quanh một số điểm trung tâm. Chúng tôi vẫn sẽ thấy các ước tính cho các tham số hội tụ.