Làm cách nào để sử dụng đầu ra của GridSearch?


23

Tôi hiện đang làm việc với Python và Scikit tìm hiểu cho mục đích phân loại và đọc một số xung quanh GridSearch Tôi nghĩ rằng đây là một cách tuyệt vời để tối ưu hóa các tham số ước tính của tôi để có kết quả tốt nhất.

Phương pháp của tôi là thế này:

  1. Chia dữ liệu của tôi thành đào tạo / kiểm tra.
  2. Sử dụng GridSearch với xác thực 5Fold Cross để đào tạo và kiểm tra các công cụ ước tính của tôi (Random Forest, Gradient Boost, SVC trong số những người khác) để có được các công cụ ước tính tốt nhất với sự kết hợp tối ưu của các tham số siêu tốc.
  3. Sau đó, tôi tính toán các số liệu trên từng công cụ ước tính của mình như Chính xác, Nhớ lại, FMeasure và Matthews Correlation Coffic, sử dụng bộ kiểm tra của tôi để dự đoán các phân loại và so sánh chúng với các nhãn lớp thực tế.

Ở giai đoạn này, tôi thấy hành vi lạ và tôi không biết phải tiến hành như thế nào. Tôi có lấy .best_estimator_ từ GridSearch và sử dụng điều này làm đầu ra 'tối ưu' từ tìm kiếm lưới và thực hiện dự đoán bằng công cụ ước tính này không? Nếu tôi làm điều này, tôi thấy rằng các số liệu của giai đoạn 3 thường thấp hơn nhiều so với việc tôi chỉ đơn giản là đào tạo trên tất cả các dữ liệu đào tạo và kiểm tra trên bộ kiểm tra. Hoặc, tôi chỉ đơn giản lấy đối tượng GridSearchCV đầu ra làm công cụ ước tính mới ? Nếu tôi làm điều này, tôi sẽ nhận được điểm số tốt hơn cho số liệu giai đoạn 3 của mình, nhưng có vẻ kỳ quặc khi sử dụng đối tượng GridSearchCV thay vì trình phân loại dự định (Ví dụ: Khu rừng ngẫu nhiên) ...

EDIT: Vậy câu hỏi của tôi là sự khác biệt giữa đối tượng GridSearchCV được trả về và thuộc tính .best_estimator_ là gì? Tôi nên sử dụng cái nào trong số này để tính các số liệu tiếp theo? Tôi có thể sử dụng đầu ra này như một trình phân loại thông thường (ví dụ: sử dụng dự đoán) hoặc nếu không thì tôi nên sử dụng nó như thế nào?

Câu trả lời:


27

Quyết định đi xa và tìm câu trả lời thỏa mãn câu hỏi của tôi, và viết chúng lên đây cho bất cứ ai khác thắc mắc.

Thuộc tính .best_estimator_ là một thể hiện của kiểu mô hình đã chỉ định, có tổ hợp 'tốt nhất' của các tham số đã cho từ param_grid. Trường hợp này có hữu ích hay không phụ thuộc vào việc tham số refit có được đặt thành True hay không (theo mặc định). Ví dụ:

clf = GridSearchCV(estimator=RandomForestClassifier(), 
                    param_grid=parameter_candidates,
                    cv=5,
                    refit=True,
                    error_score=0,
                    n_jobs=-1)

clf.fit(training_set, training_classifications)
optimised_random_forest = clf.best_estimator_
return optimised_random_forest

Sẽ trả về RandomForestClassifier. Đây là tất cả khá rõ ràng từ các tài liệu . Điều không rõ ràng từ tài liệu này là tại sao hầu hết các ví dụ không sử dụng cụ thể .best_estimator_ và thay vào đó làm điều này:

clf = GridSearchCV(estimator=RandomForestClassifier(), 
                    param_grid=parameter_candidates,
                    cv=5,
                    refit=True,
                    error_score=0,
                    n_jobs=-1)

clf.fit(training_set, training_classifications)
return clf

Cách tiếp cận thứ hai này trả về một cá thể GridSearchCV, với tất cả các chuông và còi của GridSearchCV, chẳng hạn như .best_estimator_, .best_params, v.v., bản thân nó có thể được sử dụng như một trình phân loại được đào tạo bởi vì:

Optimised Random Forest Accuracy:  0.916970802919708
[[139  47]
 [ 44 866]]
GridSearchCV Accuracy:  0.916970802919708
[[139  47]
 [ 44 866]]

Nó chỉ sử dụng cùng một ví dụ ước tính tốt nhất khi đưa ra dự đoán. Vì vậy, trong thực tế không có sự khác biệt giữa hai điều này trừ khi bạn đặc biệt chỉ muốn bản thân trình ước tính. Là một lưu ý phụ, sự khác biệt của tôi về số liệu không liên quan và xuống chức năng trọng số của lớp lỗi.


Cảm ơn bài viết của bạn @Dan, nó rất hữu ích. Tôi muốn yêu cầu bạn làm rõ. Tại trường hợp sau, nếu tôi đã refit=Falsesau đó clf.fitsẽ không được thực hiện với sự phân loại tốt nhất?
Poete Maudit

@PoeteMaudit Tham số refit cho biết hàm GridSearchCV lấy các tham số tốt nhất được tìm thấy và đào tạo lại mô hình bằng các tham số đó trên toàn bộ tập dữ liệu. Nếu refit = Sai, thì best_estimator không có sẵn, theo tài liệu: scikit-learn.org/ sóng / modules / generic / trộm
Dan Carter

0

GridSearchCV cho phép bạn kết hợp một công cụ ước tính với phần mở đầu tìm kiếm lưới để điều chỉnh các tham số siêu. Phương thức chọn tham số tối ưu từ tìm kiếm lưới và sử dụng nó với công cụ ước tính do người dùng chọn. GridSearchCV kế thừa các phương thức từ trình phân loại, vì vậy, bạn có thể sử dụng các phương thức .score, .predict, v.v. trực tiếp thông qua giao diện GridSearchCV. Nếu bạn muốn trích xuất các tham số siêu tốt nhất được xác định bởi tìm kiếm dạng lưới, bạn có thể sử dụng .best_params_ và điều này sẽ trả về siêu tham số tốt nhất. Sau đó, bạn có thể chuyển siêu tham số này cho công cụ ước tính của mình một cách riêng biệt.

Sử dụng .predict trực tiếp sẽ mang lại kết quả tương tự như nhận được siêu tham số tốt nhất thông qua .best_param_ và sau đó sử dụng nó trong mô hình của bạn. Bằng cách hiểu các hoạt động gạch chân của tìm kiếm lưới, chúng ta có thể thấy tại sao đây là trường hợp.


Tìm kiếm lưới

Kỹ thuật này được sử dụng để tìm các tham số tối ưu để sử dụng với thuật toán. Đây KHÔNG phải là trọng số hoặc mô hình, những người được học bằng cách sử dụng dữ liệu. Điều này rõ ràng là khá khó hiểu vì vậy tôi sẽ phân biệt giữa các tham số này, bằng cách gọi một siêu tham số.

Các tham số siêu giống như k trong Hàng xóm k-Gần nhất (k-NN). k-NN yêu cầu người dùng chọn hàng xóm cần xem xét khi tính khoảng cách. Thuật toán sau đó điều chỉnh một tham số, một ngưỡng, để xem nếu một ví dụ mới nằm trong phân phối đã học, điều này được thực hiện với dữ liệu.

Làm thế nào để chúng ta chọn k?

Một số người chỉ cần đi với các khuyến nghị dựa trên các nghiên cứu trước đây về loại dữ liệu. Những người khác sử dụng tìm kiếm lưới. Phương pháp này sẽ có thể xác định tốt nhất k nào là tối ưu để sử dụng cho dữ liệu của bạn.

Làm thế nào nó hoạt động?

[1,2,3,...,10]

Điều này đi ngược lại các nguyên tắc không sử dụng dữ liệu thử nghiệm !!

nnn1n

Giá trị siêu tham số đã chọn là giá trị đạt được hiệu suất trung bình cao nhất trong các lần n. Một khi bạn hài lòng với thuật toán của mình, thì bạn có thể kiểm tra nó trên bộ thử nghiệm. Nếu bạn đi thẳng vào bộ thử nghiệm thì bạn có nguy cơ bị thừa.


Xin chào Jah, đây là một câu trả lời hay nhưng tôi vẫn không khôn ngoan hơn khi trả lời câu hỏi của mình. Tôi đã cập nhật tiêu đề câu hỏi và chính câu hỏi để thử và làm cho mọi thứ rõ ràng hơn.
Dan Carter

Viết tìm kiếm lưới của riêng bạn. Đó là nghĩa đen tạo một mảng, sau đó thêm một vòng lặp for xung quanh mô hình của bạn. Sau đó, ở cuối vòng lặp for của bạn ghi lại hiệu suất kết quả thành một mảng. Sau khi bạn đã trải qua tất cả các giá trị có thể có trong lưới của mình, hãy xem các mảng biểu diễn và chọn ra giá trị tốt nhất. Đó là giá trị tối ưu cho siêu tham số của bạn. Dựa vào các chức năng tích hợp cho các vấn đề cơ bản rất không được khuyến khích cho khoa học dữ liệu. Dữ liệu thay đổi rất lớn và tốt nhất để bạn có quyền kiểm soát!
JahKnows

Đó sẽ là một gợi ý tốt nếu tôi chỉ có một siêu tham số để tối ưu hóa, nhưng nếu tôi có 4? 5? Một vòng 4/5 lần lồng cho vòng lặp là xấu xí và tôi thấy không cần phải phát minh lại bánh xe ở đây, điều đó sẽ lãng phí thời gian, và đó là lý do các gói như thế này tồn tại.
Dan Carter

GridSearchCV cho phép bạn kết hợp một công cụ ước tính với cài đặt GridSearchCV. Vì vậy, nó làm chính xác những gì chúng ta vừa thảo luận. Sau đó, nó chọn tham số tối ưu và sử dụng nó với công cụ ước tính bạn đã chọn. GridSearchCV kế thừa các phương thức từ trình phân loại, vì vậy, bạn có thể sử dụng các phương thức .score, .predict, v.v. trực tiếp thông qua giao diện GridSearchCV. Tôi không khuyên bạn nên làm điều này tuy nhiên, các công cụ dễ dàng hơn có nghĩa là kiểm soát ít hơn. Đối với một cái gì đó đơn giản như một tìm kiếm lưới chỉ cần tự viết mã.
JahKnows

1
Câu trả lời này không giải quyết câu hỏi liên quan đến việc sử dụng GridSearchCV.
Hobbes
Khi sử dụng trang web của chúng tôi, bạn xác nhận rằng bạn đã đọc và hiểu Chính sách cookieChính sách bảo mật của chúng tôi.
Licensed under cc by-sa 3.0 with attribution required.