import torch
class EarlyStopping:
    """
    Early Stopping Mechanism.

    ``stop`` returns True if training should stop and False if it should
    continue. Every call whose loss does not increase relative to the previous
    call saves a checkpoint of the model's ``state_dict``. Only meaningful for
    losses that are expected to decrease.

    Args:
        patience (int): Stop after ``patience`` consecutive increases of the loss.
        checkpoint_path (str): File the checkpoint is written to and read back
            from. Defaults to ``'checkpoint.pth.tar'`` (relative to the current
            working directory), matching the previous hard-coded location.

    Attributes:
        lastLoss (float): Loss value passed to the previous call of ``stop``
            (``float('inf')`` before the first call).
        patience (int): Stop after how many consecutive increases.
        succeedingHigherValues (int): Number of consecutive increases of the loss.
        best_model: Until ``stop`` returns True, the model object passed to the
            first call of ``stop``. When ``stop`` returns True with
            ``patience > 1``, the dict ``{'state_dict': ...}`` loaded from
            ``checkpoint_path``.
        checkpoint_path (str): See Args.
    """

    def __init__(self, patience, checkpoint_path='checkpoint.pth.tar'):
        # Start at +inf so the very first loss always counts as an improvement:
        # it resets the counter and writes a checkpoint. Starting at 0 made the
        # first (positive) loss count as an increase, which could stop training
        # immediately or leave no checkpoint on disk for ``stop`` to reload.
        self.lastLoss = float('inf')
        self.patience = patience
        self.succeedingHigherValues = 0
        self.best_model = None
        self.checkpoint_path = checkpoint_path

    def stop(self, newLoss, model):
        """
        Decide whether training should be stopped.

        Every time the loss does not increase compared to the previous call,
        the model's ``state_dict`` is saved to ``checkpoint_path``.

        Args:
            newLoss (float): Training or validation loss.
            model: Object providing ``state_dict()`` (presumably a
                ``torch.nn.Module``).

        Returns:
            bool: True once the number of consecutive losses higher than their
            predecessor reaches ``patience``, otherwise False.
        """
        if self.best_model is None:
            self.best_model = model
        if newLoss > self.lastLoss:
            self.succeedingHigherValues += 1
        else:
            # Loss did not rise: reset the streak and checkpoint this model.
            self.succeedingHigherValues = 0
            self.__save_model(model)
        self.lastLoss = newLoss
        if self.patience <= self.succeedingHigherValues:
            # NOTE(review): for patience <= 1 the checkpoint is not reloaded, so
            # best_model keeps the model reference from the first call; kept
            # as-is for backward compatibility with existing callers.
            if self.patience > 1:
                # The last checkpoint is the model from just before the streak
                # of increases began.
                self.best_model = torch.load(self.checkpoint_path)
            return True
        return False

    def __save_model(self, model):
        """Write ``{'state_dict': model.state_dict()}`` to ``checkpoint_path``."""
        torch.save({'state_dict': model.state_dict()}, self.checkpoint_path)
def printMessage(Message, verbosity):
    """Print ``Message`` only when ``verbosity`` equals 1.

    Any other verbosity value suppresses output. Kept as a thin wrapper so the
    output policy can be extended in one place later.
    """
    if verbosity != 1:
        return
    print(Message)