ROC trung bình để xác nhận chéo 10 lần lặp lại với ước tính xác suất


15

Tôi đang lên kế hoạch sử dụng lặp lại (10 lần) phân tầng xác thực chéo 10 lần cho khoảng 10.000 trường hợp sử dụng thuật toán học máy. Mỗi lần lặp lại sẽ được thực hiện với hạt giống ngẫu nhiên khác nhau.

Trong quá trình này, tôi tạo ra 10 trường hợp ước tính xác suất cho mỗi trường hợp. 1 trường hợp ước tính xác suất cho mỗi trong 10 lần lặp lại của xác thực chéo 10 lần

Tôi có thể tính trung bình 10 xác suất cho mỗi trường hợp và sau đó tạo đường cong ROC trung bình mới (đại diện cho kết quả của CV 10 lần lặp lại), có thể so sánh với các đường cong ROC khác bằng cách so sánh ghép đôi không?

Câu trả lời:


13

Từ mô tả của bạn, nó dường như có ý nghĩa hoàn hảo: không chỉ bạn có thể tính toán đường cong ROC trung bình, mà cả phương sai xung quanh nó để xây dựng khoảng tin cậy. Nó sẽ cho bạn ý tưởng về mức độ ổn định của mô hình của bạn.

Ví dụ, như thế này:

nhập mô tả hình ảnh ở đây

Ở đây tôi đặt các đường cong ROC riêng lẻ cũng như đường cong trung bình và khoảng tin cậy. Có những khu vực mà các đường cong đồng ý, vì vậy chúng ta có ít phương sai hơn và có những khu vực chúng không đồng ý.

Đối với CV lặp đi lặp lại, bạn có thể chỉ cần lặp lại nhiều lần và nhận tổng số trung bình trên tất cả các nếp gấp riêng lẻ:

nhập mô tả hình ảnh ở đây

Nó khá giống với hình ảnh trước, nhưng đưa ra các ước tính ổn định hơn (nghĩa là đáng tin cậy) về giá trị trung bình và phương sai.

Đây là mã để có được cốt truyện:

import matplotlib.pyplot as plt
import numpy as np
from scipy import interp

from sklearn.datasets import make_classification
from sklearn.cross_validation import KFold
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_curve

X, y = make_classification(n_samples=500, random_state=100, flip_y=0.3)

kf = KFold(n=len(y), n_folds=10)

tprs = []
base_fpr = np.linspace(0, 1, 101)

plt.figure(figsize=(5, 5))

for i, (train, test) in enumerate(kf):
    model = LogisticRegression().fit(X[train], y[train])
    y_score = model.predict_proba(X[test])
    fpr, tpr, _ = roc_curve(y[test], y_score[:, 1])

    plt.plot(fpr, tpr, 'b', alpha=0.15)
    tpr = interp(base_fpr, fpr, tpr)
    tpr[0] = 0.0
    tprs.append(tpr)

tprs = np.array(tprs)
mean_tprs = tprs.mean(axis=0)
std = tprs.std(axis=0)

tprs_upper = np.minimum(mean_tprs + std, 1)
tprs_lower = mean_tprs - std


plt.plot(base_fpr, mean_tprs, 'b')
plt.fill_between(base_fpr, tprs_lower, tprs_upper, color='grey', alpha=0.3)

plt.plot([0, 1], [0, 1],'r--')
plt.xlim([-0.01, 1.01])
plt.ylim([-0.01, 1.01])
plt.ylabel('True Positive Rate')
plt.xlabel('False Positive Rate')
plt.axes().set_aspect('equal', 'datalim')
plt.show()

Đối với CV lặp đi lặp lại:

idx = np.arange(0, len(y))

for j in np.random.randint(0, high=10000, size=10):
    np.random.shuffle(idx)
    kf = KFold(n=len(y), n_folds=10, random_state=j)

    for i, (train, test) in enumerate(kf):
        model = LogisticRegression().fit(X[idx][train], y[idx][train])
        y_score = model.predict_proba(X[idx][test])
        fpr, tpr, _ = roc_curve(y[idx][test], y_score[:, 1])

        plt.plot(fpr, tpr, 'b', alpha=0.05)
        tpr = interp(base_fpr, fpr, tpr)
        tpr[0] = 0.0
        tprs.append(tpr)

Nguồn cảm hứng: http://scikit-learn.org/urdy/auto_examples/model_selection/plot_roc_crossval.html


3

Điều này không đúng với xác suất trung bình vì điều đó sẽ không thể hiện dự đoán mà bạn đang cố xác thực và liên quan đến ô nhiễm trên các mẫu xác nhận.

Lưu ý rằng 100 lần lặp lại xác nhận chéo 10 lần có thể được yêu cầu để đạt được độ chính xác đầy đủ. Hoặc sử dụng bootstrap lạc quan Efron-Gong yêu cầu số lần lặp ít hơn cho cùng độ chính xác (xem ví dụ: các hàm rmsgói R validate).

c


Bạn có thể vui lòng giải thích thêm về lý do tại sao tính trung bình là không chính xác?
DataD'oh

Đã nêu. Bạn cần xác thực biện pháp bạn sẽ sử dụng trong trường.
Frank Harrell
Khi sử dụng trang web của chúng tôi, bạn xác nhận rằng bạn đã đọc và hiểu Chính sách cookieChính sách bảo mật của chúng tôi.
Licensed under cc by-sa 3.0 with attribution required.