Tôi đang đào tạo mạng nơ-ron cho dự án của mình bằng Keras. Keras đã cung cấp một chức năng để dừng sớm. Tôi có thể biết những thông số nào cần được quan sát để tránh mạng nơ-ron của tôi bị quá tải bằng cách dừng sớm không?
Tôi đang đào tạo mạng nơ-ron cho dự án của mình bằng Keras. Keras đã cung cấp một chức năng để dừng sớm. Tôi có thể biết những thông số nào cần được quan sát để tránh mạng nơ-ron của tôi bị quá tải bằng cách dừng sớm không?
Câu trả lời:
Dừng sớm về cơ bản là dừng việc đào tạo khi khoản lỗ của bạn bắt đầu tăng lên (hay nói cách khác là độ chính xác của việc xác nhận bắt đầu giảm). Theo các tài liệu nó được sử dụng như sau;
keras.callbacks.EarlyStopping(monitor='val_loss',
min_delta=0,
patience=0,
verbose=0, mode='auto')
Giá trị phụ thuộc vào việc triển khai của bạn (sự cố, kích thước lô, v.v.) nhưng nói chung để tránh trang bị quá mức, tôi sẽ sử dụng;
monitor
đối số thành 'val_loss'
.min_delta
là một ngưỡng để xác định liệu lượng lỗ tại một số kỷ nguyên có được cải thiện hay không. Nếu chênh lệch lỗ dưới đây min_delta
, nó được định lượng là không cải thiện. Tốt hơn nên để nó là 0 vì chúng ta quan tâm đến thời điểm tổn thất trở nên tồi tệ hơn.patience
đối số đại diện cho số kỷ nguyên trước khi dừng lại khi khoản lỗ của bạn bắt đầu tăng lên (ngừng cải thiện). Điều này phụ thuộc vào việc triển khai của bạn, nếu bạn sử dụng các lô rất nhỏ
hoặc tỷ lệ học tập lớn theo đường zig-zag mất mát của bạn (độ chính xác sẽ ồn ào hơn) vì vậy tốt hơn nên đặt một patience
đối số lớn . Nếu bạn sử dụng các lô lớn và tỷ lệ học tập nhỏ, việc thua lỗ của bạn sẽ suôn sẻ hơn, vì vậy bạn có thể sử dụng một patience
đối số nhỏ hơn . Dù bằng cách nào, tôi sẽ để nó là 2 vì vậy tôi sẽ cho người mẫu nhiều cơ hội hơn.verbose
quyết định những gì sẽ in, để nó ở mặc định (0).mode
đối số phụ thuộc vào hướng số lượng được theo dõi của bạn (nó được cho là đang giảm hay tăng), vì chúng tôi theo dõi tổn thất, chúng tôi có thể sử dụng min
. Nhưng hãy để keras xử lý điều đó cho chúng tôi và đặt nó thànhauto
Vì vậy, tôi sẽ sử dụng một cái gì đó như thế này và thử nghiệm bằng cách vẽ biểu đồ tổn thất lỗi có và không có điểm dừng sớm.
keras.callbacks.EarlyStopping(monitor='val_loss',
min_delta=0,
patience=2,
verbose=0, mode='auto')
Đối với sự mơ hồ có thể xảy ra về cách hoạt động của lệnh gọi lại, tôi sẽ cố gắng giải thích thêm. Khi bạn gọi fit(... callbacks=[es])
trên mô hình của mình, Keras sẽ gọi các đối tượng gọi lại đã cho các hàm được xác định trước. Các chức năng này có thể được gọi on_train_begin
, on_train_end
, on_epoch_begin
, on_epoch_end
và on_batch_begin
, on_batch_end
. Gọi lại dừng sớm được gọi trên mọi kết thúc kỷ nguyên, so sánh giá trị được giám sát tốt nhất với giá trị hiện tại và dừng nếu các điều kiện được đáp ứng (bao nhiêu kỷ nguyên đã qua kể từ khi quan sát giá trị được theo dõi tốt nhất và nó không chỉ là đối số kiên nhẫn, sự khác biệt giữa giá trị cuối cùng lớn hơn min_delta, v.v.).
Như được chỉ ra bởi @BrentFaust trong phần nhận xét, quá trình đào tạo của mô hình sẽ tiếp tục cho đến khi đáp ứng các điều kiện Dừng sớm hoặc epochs
tham số (mặc định = 10) trong fit()
được thỏa mãn. Đặt lệnh gọi lại Dừng sớm sẽ không làm cho mô hình đào tạo vượt quá epochs
tham số của nó . Vì vậy, fit()
hàm gọi với epochs
giá trị lớn hơn sẽ được hưởng lợi nhiều hơn từ lệnh gọi lại Dừng sớm.
callbacks=[EarlyStopping(patience=2)]
không có tác dụng, trừ khi các kỷ nguyên được đưa ra model.fit(..., epochs=max_epochs)
.
epoch=1
trong vòng lặp for (cho các trường hợp sử dụng khác nhau) trong đó việc gọi lại này sẽ không thành công. Nếu có sự mơ hồ trong câu trả lời của tôi, tôi sẽ cố gắng giải thích nó theo cách tốt hơn.
restore_best_weights
đối số (chưa có trong tài liệu), đối số này tải mô hình với trọng lượng tốt nhất sau khi tập luyện. Tuy nhiên, cho mục đích của bạn, tôi sẽ sử dụng ModelCheckpoint
callback với save_best_only
đối số. Bạn có thể kiểm tra tài liệu hướng dẫn sử dụng nhưng bạn cần phải tự nạp tạ tốt nhất sau khi tập.
min_delta
là một ngưỡng để đánh giá liệu lượng thay đổi trong giá trị được giám sát như một sự cải thiện hay không. Vì vậy, có, nếu chúng tôi đưa ramonitor = 'val_loss'
thì nó sẽ đề cập đến sự khác biệt giữa mất xác thực hiện tại và mất xác thực trước đó. Trong thực tế, nếu bạn đưa ramin_delta=0.1
mức giảm mất xác thực (hiện tại - trước đó) nhỏ hơn 0,1 sẽ không định lượng được, do đó sẽ ngừng đào tạo (nếu bạn cópatience = 0
).