Làm thế nào để nói Keras ngừng đào tạo dựa trên giá trị tổn thất?


82

Hiện tại tôi sử dụng mã sau:

callbacks = [
    EarlyStopping(monitor='val_loss', patience=2, verbose=0),
    ModelCheckpoint(kfold_weights_path, monitor='val_loss', save_best_only=True, verbose=0),
]
model.fit(X_train.astype('float32'), Y_train, batch_size=batch_size, nb_epoch=nb_epoch,
      shuffle=True, verbose=1, validation_data=(X_valid, Y_valid),
      callbacks=callbacks)

Nó yêu cầu Keras ngừng luyện tập khi tình trạng thua lỗ không được cải thiện trong 2 kỷ. Nhưng tôi muốn ngừng đào tạo sau khi tổn thất nhỏ hơn một số "THR" không đổi:

if val_loss < THR:
    break

Tôi đã thấy trong tài liệu có khả năng thực hiện gọi lại của riêng bạn: http://keras.io/callbacks/ Nhưng không tìm thấy cách nào để dừng quá trình đào tạo. Tôi cần một lời khuyên.

Câu trả lời:


85

Tôi đã tìm thấy câu trả lời. Tôi đã xem xét các nguồn của Keras và tìm ra mã cho EarlyStopping. Tôi đã thực hiện cuộc gọi lại của riêng mình, dựa trên nó:

class EarlyStoppingByLossVal(Callback):
    def __init__(self, monitor='val_loss', value=0.00001, verbose=0):
        super(Callback, self).__init__()
        self.monitor = monitor
        self.value = value
        self.verbose = verbose

    def on_epoch_end(self, epoch, logs={}):
        current = logs.get(self.monitor)
        if current is None:
            warnings.warn("Early stopping requires %s available!" % self.monitor, RuntimeWarning)

        if current < self.value:
            if self.verbose > 0:
                print("Epoch %05d: early stopping THR" % epoch)
            self.model.stop_training = True

Và cách sử dụng:

callbacks = [
    EarlyStoppingByLossVal(monitor='val_loss', value=0.00001, verbose=1),
    # EarlyStopping(monitor='val_loss', patience=2, verbose=0),
    ModelCheckpoint(kfold_weights_path, monitor='val_loss', save_best_only=True, verbose=0),
]
model.fit(X_train.astype('float32'), Y_train, batch_size=batch_size, nb_epoch=nb_epoch,
      shuffle=True, verbose=1, validation_data=(X_valid, Y_valid),
      callbacks=callbacks)

1
Chỉ cần nó hữu ích cho ai đó - trong trường hợp của tôi, tôi đã sử dụng monitor = 'loss', nó hoạt động tốt.
QtRoS

15
Có vẻ như Keras đã được cập nhật. Các EarlyStopping hàm callback có min_delta xây dựng vào nó bây giờ. Không cần phải hack mã nguồn nữa, yay! stackoverflow.com/a/41459368/3345375
jkdev

3
Khi đọc lại câu hỏi và câu trả lời, tôi cần sửa lại bản thân mình: min_delta có nghĩa là "Dừng lại sớm nếu không có đủ cải thiện trên mỗi kỷ nguyên (hoặc trên nhiều kỷ nguyên)." Tuy nhiên, OP đã hỏi làm thế nào để "Dừng lại sớm khi khoản lỗ xuống dưới một mức nhất định."
jkdev

NameError: name 'Callback' không được xác định ... Tôi sẽ sửa nó như thế nào?
alyssaeliyah

2
Eliyah hãy thử cái này: from keras.callbacks import Callback
ZFTurbo

26

Lệnh gọi lại keras.callbacks.EarlyStopping không có đối số min_delta. Từ tài liệu của Keras:

min_delta: thay đổi tối thiểu trong số lượng được giám sát để đủ điều kiện là cải tiến, tức là thay đổi tuyệt đối nhỏ hơn min_delta, sẽ được coi là không cải thiện.


3
Để tham khảo, đây là tài liệu cho phiên bản cũ hơn của Keras (1.1.0), trong đó đối số min_delta chưa được bao gồm: faroit.github.io/keras-docs/1.1.0/callbacks/#earlystopping
jkdev

làm thế nào tôi có thể làm cho nó không dừng lại cho đến khi min_deltavẫn tồn tại trong nhiều kỷ nguyên?
zyxue

có một thông số khác đối với EarlyStopping được gọi là sự kiên nhẫn: số kỷ nguyên không có sự cải thiện mà sau đó quá trình đào tạo sẽ bị dừng.
devin

13

Một giải pháp là gọi model.fit(nb_epoch=1, ...)bên trong vòng lặp for, sau đó bạn có thể đặt câu lệnh break bên trong vòng lặp for và thực hiện bất kỳ luồng điều khiển tùy chỉnh nào khác mà bạn muốn.


Sẽ thật tuyệt nếu họ thực hiện một lệnh gọi lại có một hàm duy nhất có thể làm được điều đó.
Trung thực

7

Tôi đã giải quyết vấn đề tương tự bằng cách sử dụng gọi lại tùy chỉnh.

Trong mã gọi lại tùy chỉnh sau đây, chỉ định THR với giá trị mà bạn muốn dừng đào tạo và thêm lệnh gọi lại vào mô hình của mình.

from keras.callbacks import Callback

class stopAtLossValue(Callback):

        def on_batch_end(self, batch, logs={}):
            THR = 0.03 #Assign THR with the value at which you want to stop training.
            if logs.get('loss') <= THR:
                 self.model.stop_training = True

2

Trong khi tham gia TensorFlow trong chuyên ngành thực hành , tôi đã học được một kỹ thuật rất thanh lịch. Chỉ cần sửa đổi một chút từ câu trả lời được chấp nhận.

Hãy làm ví dụ với dữ liệu MNIST yêu thích của chúng tôi.

import tensorflow as tf

class new_callback(tf.keras.callbacks.Callback):
    def epoch_end(self, epoch, logs={}): 
        if(logs.get('accuracy')> 0.90): # select the accuracy
            print("\n !!! 90% accuracy, no further training !!!")
            self.model.stop_training = True

mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0 #normalize

callbacks = new_callback()

# model = tf.keras.models.Sequential([# define your model here])

model.compile(optimizer=tf.optimizers.Adam(),
          loss='sparse_categorical_crossentropy',
          metrics=['accuracy'])
model.fit(x_train, y_train, epochs=10, callbacks=[callbacks])

Vì vậy, ở đây tôi đặt metrics=['accuracy'], và do đó trong lớp gọi lại, điều kiện được đặt thành'accuracy'> 0.90 .

Bạn có thể chọn bất kỳ số liệu nào và theo dõi quá trình đào tạo như ví dụ này. Quan trọng nhất là bạn có thể đặt các điều kiện khác nhau cho các chỉ số khác nhau và sử dụng chúng đồng thời.

Hy vọng rằng điều này sẽ giúp!


tên hàm phải là on_epoch_end
xarion

0

Đối với tôi, mô hình sẽ chỉ dừng đào tạo nếu tôi thêm câu lệnh trả về sau khi đặt tham số stop_training thành True vì tôi đang gọi sau self.model.evaluate. Vì vậy, hãy đảm bảo đặt stop_training = True ở cuối hàm hoặc thêm một câu lệnh trả về.

def on_epoch_end(self, batch, logs):
        self.epoch += 1
        self.stoppingCounter += 1
        print('\nstopping counter \n',self.stoppingCounter)

        #Stop training if there hasn't been any improvement in 'Patience' epochs
        if self.stoppingCounter >= self.patience:
            self.model.stop_training = True
            return

        # Test on additional set if there is one
        if self.testingOnAdditionalSet:
            evaluation = self.model.evaluate(self.val2X, self.val2Y, verbose=0)
            self.validationLoss2.append(evaluation[0])
            self.validationAcc2.append(evaluation[1])enter code here

0

Nếu đang sử dụng vòng lặp đào tạo tùy chỉnh, bạn có thể sử dụng một collections.deque, là danh sách "cuộn" có thể được thêm vào và các mục bên trái sẽ xuất hiện khi danh sách dài hơn maxlen. Đây là dòng:

loss_history = deque(maxlen=early_stopping + 1)

for epoch in range(epochs):
    fit(epoch)
    loss_history.append(test_loss.result().numpy())
    if len(loss_history) > early_stopping and loss_history.popleft() < min(loss_history)
            break

Đây là một ví dụ đầy đủ:

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow_datasets as tfds
import tensorflow as tf
from tensorflow.keras.layers import Dense
from collections import deque

data, info = tfds.load('iris', split='train', as_supervised=True, with_info=True)

data = data.map(lambda x, y: (tf.cast(x, tf.int32), y))

train_dataset = data.take(120).batch(4)
test_dataset = data.skip(120).take(30).batch(4)

model = tf.keras.models.Sequential([
    Dense(8, activation='relu'),
    Dense(16, activation='relu'),
    Dense(info.features['label'].num_classes)])

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

train_loss = tf.keras.metrics.Mean()
test_loss = tf.keras.metrics.Mean()

train_acc = tf.keras.metrics.SparseCategoricalAccuracy()
test_acc = tf.keras.metrics.SparseCategoricalAccuracy()

opt = tf.keras.optimizers.Adam(learning_rate=1e-3)


@tf.function
def train_step(inputs, labels):
    with tf.GradientTape() as tape:
        logits = model(inputs, training=True)
        loss = loss_object(labels, logits)

    gradients = tape.gradient(loss, model.trainable_variables)
    opt.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss(loss)
    train_acc(labels, logits)


@tf.function
def test_step(inputs, labels):
    logits = model(inputs, training=False)
    loss = loss_object(labels, logits)
    test_loss(loss)
    test_acc(labels, logits)


def fit(epoch):
    template = 'Epoch {:>2} Train Loss {:.3f} Test Loss {:.3f} ' \
               'Train Acc {:.2f} Test Acc {:.2f}'

    train_loss.reset_states()
    test_loss.reset_states()
    train_acc.reset_states()
    test_acc.reset_states()

    for X_train, y_train in train_dataset:
        train_step(X_train, y_train)

    for X_test, y_test in test_dataset:
        test_step(X_test, y_test)

    print(template.format(
        epoch + 1,
        train_loss.result(),
        test_loss.result(),
        train_acc.result(),
        test_acc.result()
    ))


def main(epochs=50, early_stopping=10):
    loss_history = deque(maxlen=early_stopping + 1)

    for epoch in range(epochs):
        fit(epoch)
        loss_history.append(test_loss.result().numpy())
        if len(loss_history) > early_stopping and loss_history.popleft() < min(loss_history):
            print(f'\nEarly stopping. No validation loss '
                  f'improvement in {early_stopping} epochs.')
            break

if __name__ == '__main__':
    main(epochs=250, early_stopping=10)
Epoch  1 Train Loss 1.730 Test Loss 1.449 Train Acc 0.33 Test Acc 0.33
Epoch  2 Train Loss 1.405 Test Loss 1.220 Train Acc 0.33 Test Acc 0.33
Epoch  3 Train Loss 1.173 Test Loss 1.054 Train Acc 0.33 Test Acc 0.33
Epoch  4 Train Loss 1.006 Test Loss 0.935 Train Acc 0.33 Test Acc 0.33
Epoch  5 Train Loss 0.885 Test Loss 0.846 Train Acc 0.33 Test Acc 0.33
...
Epoch 89 Train Loss 0.196 Test Loss 0.240 Train Acc 0.89 Test Acc 0.87
Epoch 90 Train Loss 0.195 Test Loss 0.239 Train Acc 0.89 Test Acc 0.87
Epoch 91 Train Loss 0.195 Test Loss 0.239 Train Acc 0.89 Test Acc 0.87
Epoch 92 Train Loss 0.194 Test Loss 0.239 Train Acc 0.90 Test Acc 0.87

Early stopping. No validation loss improvement in 10 epochs.
Khi sử dụng trang web của chúng tôi, bạn xác nhận rằng bạn đã đọc và hiểu Chính sách cookieChính sách bảo mật của chúng tôi.
Licensed under cc by-sa 3.0 with attribution required.