Keras'a kayıp değerine göre eğitimi nasıl durdurabilirim?


82

Şu anda aşağıdaki kodu kullanıyorum:

callbacks = [
    EarlyStopping(monitor='val_loss', patience=2, verbose=0),
    ModelCheckpoint(kfold_weights_path, monitor='val_loss', save_best_only=True, verbose=0),
]
model.fit(X_train.astype('float32'), Y_train, batch_size=batch_size, nb_epoch=nb_epoch,
      shuffle=True, verbose=1, validation_data=(X_valid, Y_valid),
      callbacks=callbacks)

Keras'a, kayıp 2 dönem boyunca iyileşmediğinde eğitimi durdurmasını söyler. Ancak, kayıp sabit bir "THR" den daha küçük hale geldikten sonra eğitimi durdurmak istiyorum:

if val_loss < THR:
    break

Dokümantasyonda kendi geri aramanızı yapma olasılığınız olduğunu gördüm: http://keras.io/callbacks/ Ancak hiçbir şey eğitim sürecini nasıl durduracağınızı bulamadı. Bir tavsiyeye ihtiyacım var.

Yanıtlar:


85

Cevabı buldum. Keras kaynaklarına baktım ve EarlyStopping kodunu buldum. Buna dayanarak kendi geri aramamı yaptım:

class EarlyStoppingByLossVal(Callback):
    def __init__(self, monitor='val_loss', value=0.00001, verbose=0):
        super(Callback, self).__init__()
        self.monitor = monitor
        self.value = value
        self.verbose = verbose

    def on_epoch_end(self, epoch, logs={}):
        current = logs.get(self.monitor)
        if current is None:
            warnings.warn("Early stopping requires %s available!" % self.monitor, RuntimeWarning)

        if current < self.value:
            if self.verbose > 0:
                print("Epoch %05d: early stopping THR" % epoch)
            self.model.stop_training = True

Ve kullanım:

callbacks = [
    EarlyStoppingByLossVal(monitor='val_loss', value=0.00001, verbose=1),
    # EarlyStopping(monitor='val_loss', patience=2, verbose=0),
    ModelCheckpoint(kfold_weights_path, monitor='val_loss', save_best_only=True, verbose=0),
]
model.fit(X_train.astype('float32'), Y_train, batch_size=batch_size, nb_epoch=nb_epoch,
      shuffle=True, verbose=1, validation_data=(X_valid, Y_valid),
      callbacks=callbacks)

1
Birisi için yararlı olacaksa - benim durumumda monitör = 'kayıp' kullandım, iyi çalıştı.
QtRoS

15
Görünüşe göre Keras güncellendi. EarlyStopping geri arama işlevi artık yerleşik min_delta vardır. Artık kaynak kodunu kırmaya gerek yok, yaşasın! stackoverflow.com/a/41459368/3345375
jkdev

3
Soruyu ve yanıtları yeniden okuduktan sonra kendimi düzeltmem gerekiyor: min_delta, "Dönem başına (veya birden çok dönem için) yeterli gelişme yoksa erken dur" anlamına gelir. Ancak OP, "Kayıp belirli bir seviyenin altına düştüğünde erken durdurulmasını" sordu.
jkdev

NameError: 'Geri arama' adı tanımlanmadı ... Bunu nasıl düzelteceğim?
alyssaeliyah

2
Eliyah bunu deneyin: from keras.callbacks import Callback
ZFTurbo

26

Keras.callbacks.EarlyStopping geri aramasının bir min_delta bağımsız değişkeni vardır. Keras belgelerinden:

min_delta: bir iyileştirme olarak nitelendirilmek için izlenen miktardaki minimum değişiklik, yani min_delta'dan daha küçük bir mutlak değişiklik, iyileştirme olarak sayılmayacaktır.


3
Referans için, min_delta bağımsız değişkeninin henüz dahil edilmediği Keras'ın önceki bir sürümü (1.1.0) için dokümanlar: faroit.github.io/keras-docs/1.1.0/callbacks/#earlystopping
jkdev

min_deltabirden çok dönem boyunca devam edene kadar nasıl durdurabilirim ?
zyxue

EarlyStopping için sabır denen başka bir parametre daha var: iyileştirme olmayan dönemlerin sayısı ve sonrasında eğitim durdurulacak.
devin

13

Çözümlerden biri, model.fit(nb_epoch=1, ...)for döngüsü içinde arama yapmaktır, ardından for döngüsünün içine bir break ifadesi koyabilir ve istediğiniz diğer özel kontrol akışını yapabilirsiniz.


Bunu yapabilen tek bir işlevi alan bir geri arama yapsalar iyi olurdu.
Dürüstlük

8

Özel geri aramayı kullanarak aynı sorunu çözdüm.

Aşağıdaki özel geri arama kodunda, eğitimi durdurmak istediğiniz değeri THR'ye atayın ve geri aramayı modelinize ekleyin.

from keras.callbacks import Callback

class stopAtLossValue(Callback):

        def on_batch_end(self, batch, logs={}):
            THR = 0.03 #Assign THR with the value at which you want to stop training.
            if logs.get('loss') <= THR:
                 self.model.stop_training = True

2

TensorFlow'u uygulama uzmanlığı alırken çok zarif bir teknik öğrendim. Kabul edilen cevaptan biraz değiştirildi.

En sevdiğimiz MNIST verileriyle örnek oluşturalım.

import tensorflow as tf

class new_callback(tf.keras.callbacks.Callback):
    def epoch_end(self, epoch, logs={}): 
        if(logs.get('accuracy')> 0.90): # select the accuracy
            print("\n !!! 90% accuracy, no further training !!!")
            self.model.stop_training = True

mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0 #normalize

callbacks = new_callback()

# model = tf.keras.models.Sequential([# define your model here])

model.compile(optimizer=tf.optimizers.Adam(),
          loss='sparse_categorical_crossentropy',
          metrics=['accuracy'])
model.fit(x_train, y_train, epochs=10, callbacks=[callbacks])

Bu yüzden burada metrics=['accuracy'], ve böylece geri arama sınıfında koşul ayarlandı 'accuracy'> 0.90.

Herhangi bir ölçüyü seçebilir ve bu örnek gibi eğitimi izleyebilirsiniz. En önemlisi, farklı metrikler için farklı koşullar belirleyebilir ve bunları aynı anda kullanabilirsiniz.

Umarım bu yardımcı olur!


1
işlev adı on_epoch_end olmalıdır
xarion

0

Benim için model, self.model.evaluate'den sonra çağırdığım için stop_training parametresini True olarak ayarladıktan sonra bir dönüş ifadesi eklersem eğitimi durdurur. Bu nedenle, işlevin sonuna stop_training = True koyduğunuzdan emin olun veya bir dönüş ifadesi ekleyin.

def on_epoch_end(self, batch, logs):
        self.epoch += 1
        self.stoppingCounter += 1
        print('\nstopping counter \n',self.stoppingCounter)

        #Stop training if there hasn't been any improvement in 'Patience' epochs
        if self.stoppingCounter >= self.patience:
            self.model.stop_training = True
            return

        # Test on additional set if there is one
        if self.testingOnAdditionalSet:
            evaluation = self.model.evaluate(self.val2X, self.val2Y, verbose=0)
            self.validationLoss2.append(evaluation[0])
            self.validationAcc2.append(evaluation[1])enter code here

0

Özel bir eğitim döngüsü kullanıyorsanız, eklenebilen collections.deque"yuvarlanan" bir liste olan a kullanabilirsiniz ve sol taraftaki öğeler, liste daha uzun olduğunda açılır.maxlen . İşte satır:

loss_history = deque(maxlen=early_stopping + 1)

for epoch in range(epochs):
    fit(epoch)
    loss_history.append(test_loss.result().numpy())
    if len(loss_history) > early_stopping and loss_history.popleft() < min(loss_history)
            break

İşte tam bir örnek:

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow_datasets as tfds
import tensorflow as tf
from tensorflow.keras.layers import Dense
from collections import deque

data, info = tfds.load('iris', split='train', as_supervised=True, with_info=True)

data = data.map(lambda x, y: (tf.cast(x, tf.int32), y))

train_dataset = data.take(120).batch(4)
test_dataset = data.skip(120).take(30).batch(4)

model = tf.keras.models.Sequential([
    Dense(8, activation='relu'),
    Dense(16, activation='relu'),
    Dense(info.features['label'].num_classes)])

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

train_loss = tf.keras.metrics.Mean()
test_loss = tf.keras.metrics.Mean()

train_acc = tf.keras.metrics.SparseCategoricalAccuracy()
test_acc = tf.keras.metrics.SparseCategoricalAccuracy()

opt = tf.keras.optimizers.Adam(learning_rate=1e-3)


@tf.function
def train_step(inputs, labels):
    with tf.GradientTape() as tape:
        logits = model(inputs, training=True)
        loss = loss_object(labels, logits)

    gradients = tape.gradient(loss, model.trainable_variables)
    opt.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss(loss)
    train_acc(labels, logits)


@tf.function
def test_step(inputs, labels):
    logits = model(inputs, training=False)
    loss = loss_object(labels, logits)
    test_loss(loss)
    test_acc(labels, logits)


def fit(epoch):
    template = 'Epoch {:>2} Train Loss {:.3f} Test Loss {:.3f} ' \
               'Train Acc {:.2f} Test Acc {:.2f}'

    train_loss.reset_states()
    test_loss.reset_states()
    train_acc.reset_states()
    test_acc.reset_states()

    for X_train, y_train in train_dataset:
        train_step(X_train, y_train)

    for X_test, y_test in test_dataset:
        test_step(X_test, y_test)

    print(template.format(
        epoch + 1,
        train_loss.result(),
        test_loss.result(),
        train_acc.result(),
        test_acc.result()
    ))


def main(epochs=50, early_stopping=10):
    loss_history = deque(maxlen=early_stopping + 1)

    for epoch in range(epochs):
        fit(epoch)
        loss_history.append(test_loss.result().numpy())
        if len(loss_history) > early_stopping and loss_history.popleft() < min(loss_history):
            print(f'\nEarly stopping. No validation loss '
                  f'improvement in {early_stopping} epochs.')
            break

if __name__ == '__main__':
    main(epochs=250, early_stopping=10)
Epoch  1 Train Loss 1.730 Test Loss 1.449 Train Acc 0.33 Test Acc 0.33
Epoch  2 Train Loss 1.405 Test Loss 1.220 Train Acc 0.33 Test Acc 0.33
Epoch  3 Train Loss 1.173 Test Loss 1.054 Train Acc 0.33 Test Acc 0.33
Epoch  4 Train Loss 1.006 Test Loss 0.935 Train Acc 0.33 Test Acc 0.33
Epoch  5 Train Loss 0.885 Test Loss 0.846 Train Acc 0.33 Test Acc 0.33
...
Epoch 89 Train Loss 0.196 Test Loss 0.240 Train Acc 0.89 Test Acc 0.87
Epoch 90 Train Loss 0.195 Test Loss 0.239 Train Acc 0.89 Test Acc 0.87
Epoch 91 Train Loss 0.195 Test Loss 0.239 Train Acc 0.89 Test Acc 0.87
Epoch 92 Train Loss 0.194 Test Loss 0.239 Train Acc 0.90 Test Acc 0.87

Early stopping. No validation loss improvement in 10 epochs.
Sitemizi kullandığınızda şunları okuyup anladığınızı kabul etmiş olursunuz: Çerez Politikası ve Gizlilik Politikası.
Licensed under cc by-sa 3.0 with attribution required.