Tham số class_weight trong scikit-learning hoạt động như thế nào?


116

Tôi đang gặp rất nhiều khó khăn khi hiểu cách class_weighttham số trong hồi quy logistic của scikit-learning hoạt động.

Tình huống

Tôi muốn sử dụng hồi quy logistic để thực hiện phân loại nhị phân trên một tập dữ liệu rất không cân bằng. Các lớp được dán nhãn 0 (tiêu cực) và 1 (dương tính) và dữ liệu quan sát được theo tỷ lệ khoảng 19: 1 với phần lớn các mẫu có kết quả âm tính.

Lần thử đầu tiên: Chuẩn bị thủ công dữ liệu đào tạo

Tôi chia dữ liệu tôi có thành các bộ rời rạc để đào tạo và thử nghiệm (khoảng 80/20). Sau đó, tôi lấy mẫu ngẫu nhiên dữ liệu đào tạo bằng tay để lấy dữ liệu đào tạo theo các tỷ lệ khác nhau so với 19: 1; từ 2: 1 -> 16: 1.

Sau đó, tôi đã đào tạo hồi quy logistic trên các tập con dữ liệu đào tạo khác nhau này và vẽ biểu đồ thu hồi (= TP / (TP + FN)) như một hàm của các tỷ lệ đào tạo khác nhau. Tất nhiên, việc thu hồi được tính toán trên các mẫu TEST rời rạc có tỷ lệ quan sát được là 19: 1. Lưu ý, mặc dù tôi đã huấn luyện các mô hình khác nhau trên các dữ liệu huấn luyện khác nhau, nhưng tôi đã tính toán thu hồi cho tất cả chúng trên cùng một dữ liệu thử nghiệm (rời rạc).

Kết quả đúng như mong đợi: tỷ lệ thu hồi là khoảng 60% ở tỷ lệ đào tạo 2: 1 và giảm khá nhanh vào thời điểm nó chuyển sang tỷ lệ 16: 1. Có một số tỷ lệ 2: 1 -> 6: 1 trong đó tỷ lệ thu hồi trên 5%.

Lần thử thứ hai: Tìm kiếm theo lưới

Tiếp theo, tôi muốn kiểm tra các thông số chính quy khác nhau và vì vậy tôi đã sử dụng GridSearchCV và tạo một lưới gồm một số giá trị của Ctham số cũng như class_weighttham số. Để dịch tỷ lệ n: m của các mẫu đào tạo âm: dương của class_weighttôi sang ngôn ngữ từ điển của tôi, tôi nghĩ rằng tôi chỉ cần chỉ định một số từ điển như sau:

{ 0:0.67, 1:0.33 } #expected 2:1
{ 0:0.75, 1:0.25 } #expected 3:1
{ 0:0.8, 1:0.2 }   #expected 4:1

và tôi cũng bao gồm Noneauto.

Lần này kết quả hoàn toàn bất ngờ. Tất cả các lần thu hồi của tôi đều rất nhỏ (<0,05) cho mọi giá trị class_weightngoại trừ auto. Vì vậy, tôi chỉ có thể cho rằng hiểu biết của tôi về cách đặt class_weighttừ điển là sai. Thật thú vị, class_weightgiá trị của 'tự động' trong tìm kiếm lưới là khoảng 59% cho tất cả các giá trị của Cvà tôi đoán nó cân bằng thành 1: 1?

Những câu hỏi của tôi

  1. Làm thế nào để bạn sử dụng đúng cách class_weightđể đạt được sự cân bằng khác nhau trong dữ liệu đào tạo từ những gì bạn thực sự cung cấp? Cụ thể, tôi chuyển qua từ điển nào class_weightđể sử dụng tỷ lệ n: m của các mẫu đào tạo âm: dương?

  2. Nếu bạn chuyển nhiều class_weighttừ điển khác nhau cho GridSearchCV, trong quá trình xác thực chéo, nó có cân bằng lại dữ liệu trong màn hình đào tạo theo từ điển nhưng sử dụng tỷ lệ mẫu thực cho trước để tính toán chức năng tính điểm của tôi trong màn hình thử nghiệm không? Điều này rất quan trọng vì bất kỳ số liệu nào chỉ hữu ích với tôi nếu nó đến từ dữ liệu theo tỷ lệ quan sát được.

  3. autogiá trị của class_weightlàm như xa như tỷ lệ? Tôi đọc tài liệu và tôi giả định rằng "cân bằng dữ liệu tỷ lệ nghịch với tần suất của chúng" chỉ có nghĩa là nó làm cho nó là 1: 1. Điều này có chính xác? Nếu không, ai đó có thể làm rõ?


Khi một người sử dụng class_weight, hàm mất mát được sửa đổi. Ví dụ, thay vì entropy chéo, nó trở thành entropy chéo có trọng lượng. hướngdatascience.com/
prashanth

Câu trả lời:


123

Trước hết, sẽ không tốt nếu chỉ nhớ lại một mình. Bạn có thể chỉ cần nhớ lại 100% bằng cách phân loại mọi thứ là lớp tích cực. Tôi thường đề xuất sử dụng AUC để chọn các tham số, sau đó tìm ngưỡng cho điểm hoạt động (giả sử mức độ chính xác nhất định) mà bạn quan tâm.

Đối với bao class_weightcông trình: Nó phạt sai lầm trong mẫu class[i]với class_weight[i]thay vì 1. Vì vậy, phương tiện đẳng cấp cao hơn trọng lượng bạn muốn đặt trọng tâm nhiều hơn vào một lớp. Theo những gì bạn nói, có vẻ như lớp 0 thường xuyên hơn 19 lần so với lớp 1. Vì vậy, bạn nên tăng class_weightlớp 1 so với lớp 0, giả sử {0: .1, 1: .9}. Nếu class_weighttổng không thành 1, về cơ bản nó sẽ thay đổi tham số chính quy hóa.

Để biết cách class_weight="auto"hoạt động, bạn có thể xem cuộc thảo luận này . Trong phiên bản dành cho nhà phát triển, bạn có thể sử dụng class_weight="balanced", điều này dễ hiểu hơn: về cơ bản nó có nghĩa là sao chép lớp nhỏ hơn cho đến khi bạn có nhiều mẫu như trong lớp lớn hơn, nhưng theo cách ẩn.


1
Cảm ơn! Câu hỏi nhanh: Tôi đã đề cập đến việc thu hồi cho rõ ràng và thực tế là tôi đang cố gắng quyết định sử dụng AUC nào làm thước đo của mình. Sự hiểu biết của tôi là tôi nên tối đa hóa diện tích dưới đường cong ROC hoặc khu vực dưới đường cong nhớ lại so với đường cong chính xác để tìm các tham số. Sau khi chọn các tham số theo cách này, tôi tin rằng tôi chọn ngưỡng phân loại bằng cách trượt dọc theo đường cong. Đây có phải là những gì bạn muốn nói? Nếu vậy, đường cong nào trong hai đường cong hợp lý nhất để xem nếu mục tiêu của tôi là chiếm được càng nhiều TP càng tốt? Ngoài ra, cảm ơn bạn đã làm việc và đóng góp cho scikit-learning !!!
kilgoretrout

1
Tôi nghĩ sử dụng ROC sẽ là cách chuẩn hơn để đi, nhưng tôi không nghĩ sẽ có sự khác biệt lớn. Tuy nhiên, bạn cần một số tiêu chí để chọn điểm trên đường cong.
Andreas Mueller

3
@MiNdFrEaK Tôi nghĩ ý của Andrew là công cụ ước lượng sao chép các mẫu trong lớp thiểu số, để mẫu của các lớp khác nhau được cân bằng. Nó chỉ lấy mẫu quá mức theo một cách ngầm hiểu.
Shawn TIAN

8
@MiNdFrEaK và Shawn Tian: Bộ phân loại dựa trên SV không tạo ra nhiều mẫu hơn của các lớp nhỏ hơn khi bạn sử dụng 'cân bằng'. Nó thực sự trừng phạt những sai lầm mắc phải ở các lớp nhỏ hơn. Nói cách khác là một sai lầm và gây hiểu lầm, đặc biệt là trong các tập dữ liệu lớn khi bạn không đủ khả năng tạo thêm mẫu. Câu trả lời này phải được chỉnh sửa.
Pablo Rivas

4
scikit-learn.org/dev/glossary.html#term-class-weight Trọng số lớp sẽ được sử dụng khác nhau tùy thuộc vào thuật toán: đối với mô hình tuyến tính (chẳng hạn như SVM tuyến tính hoặc hồi quy logistic), trọng số lớp sẽ thay đổi hàm mất mát bằng cách tính trọng lượng hao hụt của từng mẫu theo khối lượng loại của nó. Đối với các thuật toán dựa trên cây, các trọng số của lớp sẽ được sử dụng để định lại trọng số cho tiêu chí phân tách. Tuy nhiên, lưu ý rằng việc tái cân bằng này không tính đến khối lượng của các mẫu trong từng loại.
prashanth

2

Câu trả lời đầu tiên là tốt cho việc hiểu cách nó hoạt động. Nhưng tôi muốn hiểu cách tôi nên sử dụng nó trong thực tế.

TÓM LƯỢC

  • cho dữ liệu không cân bằng vừa phải KHÔNG bị nhiễu, không có nhiều sự khác biệt trong việc áp dụng trọng số lớp
  • đối với dữ liệu mất cân bằng vừa phải CÓ nhiễu và mất cân bằng mạnh, tốt hơn nên áp dụng trọng số lớp
  • param class_weight="balanced"hoạt động tốt trong trường hợp bạn không muốn tối ưu hóa theo cách thủ công
  • với việc class_weight="balanced"bạn nắm bắt được nhiều sự kiện đúng hơn (truy lại TRUE cao hơn) nhưng bạn cũng có nhiều khả năng nhận được cảnh báo sai hơn (độ chính xác TRUE thấp hơn)
    • do đó, tổng% TRUE có thể cao hơn thực tế vì tất cả các kết quả dương tính giả
    • AUC có thể dẫn đường cho bạn ở đây nếu cảnh báo sai là một vấn đề
  • không cần thay đổi ngưỡng quyết định thành% mất cân bằng, ngay cả đối với mức mất cân bằng mạnh, có thể giữ nguyên 0,5 (hoặc ở đâu đó xung quanh mức đó tùy thuộc vào những gì bạn cần)

NB

Kết quả có thể khác khi sử dụng RF hoặc GBM. sklearn không có class_weight="balanced" cho GBM nhưng lightgbmLGBMClassifier(is_unbalance=False)

CODE

# scikit-learn==0.21.3
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, classification_report
import numpy as np
import pandas as pd

# case: moderate imbalance
X, y = datasets.make_classification(n_samples=50*15, n_features=5, n_informative=2, n_redundant=0, random_state=1, weights=[0.8]) #,flip_y=0.1,class_sep=0.5)
np.mean(y) # 0.2

LogisticRegression(C=1e9).fit(X,y).predict(X).mean() # 0.184
(LogisticRegression(C=1e9).fit(X,y).predict_proba(X)[:,1]>0.5).mean() # 0.184 => same as first
LogisticRegression(C=1e9,class_weight={0:0.5,1:0.5}).fit(X,y).predict(X).mean() # 0.184 => same as first
LogisticRegression(C=1e9,class_weight={0:2,1:8}).fit(X,y).predict(X).mean() # 0.296 => seems to make things worse?
LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X).mean() # 0.292 => seems to make things worse?

roc_auc_score(y,LogisticRegression(C=1e9).fit(X,y).predict(X)) # 0.83
roc_auc_score(y,LogisticRegression(C=1e9,class_weight={0:2,1:8}).fit(X,y).predict(X)) # 0.86 => about the same
roc_auc_score(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)) # 0.86 => about the same

# case: strong imbalance
X, y = datasets.make_classification(n_samples=50*15, n_features=5, n_informative=2, n_redundant=0, random_state=1, weights=[0.95])
np.mean(y) # 0.06

LogisticRegression(C=1e9).fit(X,y).predict(X).mean() # 0.02
(LogisticRegression(C=1e9).fit(X,y).predict_proba(X)[:,1]>0.5).mean() # 0.02 => same as first
LogisticRegression(C=1e9,class_weight={0:0.5,1:0.5}).fit(X,y).predict(X).mean() # 0.02 => same as first
LogisticRegression(C=1e9,class_weight={0:1,1:20}).fit(X,y).predict(X).mean() # 0.25 => huh??
LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X).mean() # 0.22 => huh??
(LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict_proba(X)[:,1]>0.5).mean() # same as last

roc_auc_score(y,LogisticRegression(C=1e9).fit(X,y).predict(X)) # 0.64
roc_auc_score(y,LogisticRegression(C=1e9,class_weight={0:1,1:20}).fit(X,y).predict(X)) # 0.84 => much better
roc_auc_score(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)) # 0.85 => similar to manual
roc_auc_score(y,(LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict_proba(X)[:,1]>0.5).astype(int)) # same as last

print(classification_report(y,LogisticRegression(C=1e9).fit(X,y).predict(X)))
pd.crosstab(y,LogisticRegression(C=1e9).fit(X,y).predict(X),margins=True)
pd.crosstab(y,LogisticRegression(C=1e9).fit(X,y).predict(X),margins=True,normalize='index') # few prediced TRUE with only 28% TRUE recall and 86% TRUE precision so 6%*28%~=2%

print(classification_report(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)))
pd.crosstab(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X),margins=True)
pd.crosstab(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X),margins=True,normalize='index') # 88% TRUE recall but also lot of false positives with only 23% TRUE precision, making total predicted % TRUE > actual % TRUE
Khi sử dụng trang web của chúng tôi, bạn xác nhận rằng bạn đã đọc và hiểu Chính sách cookieChính sách bảo mật của chúng tôi.
Licensed under cc by-sa 3.0 with attribution required.