# Source code for gnn_base_model.train.training

"""Training of Graph Neural Network model.

Classes
-------
EarlyStopper
    Early stopping procedure (implicit regularization).

Functions
---------
train_model
    Training of Graph Neural Network model.
get_pytorch_optimizer
    Get PyTorch optimizer.
get_learning_rate_scheduler
    Get PyTorch optimizer learning rate scheduler.
save_training_state
    Save model and optimizer states at given training epoch.
load_training_state
    Load model and optimizer states from available training data.
remove_posterior_optim_state_files
    Delete optimizer training epoch state files posterior to given epoch.
save_loss_history
    Save training process loss history record.
load_loss_history
    Load training process loss history record.
load_lr_history
    Load training process learning rate history record.
seed_worker
    Set workers seed in PyTorch data loaders to preserve reproducibility.
read_loss_history_from_file
    Read training loss history from loss history record file.
read_lr_history_from_file
    Read training learning rate history from loss history record file.
write_training_summary_file
    Write summary data file for model training process.
"""
#
#                                                                       Modules
# =============================================================================
# Standard
import os
import pickle
import random
import re
import time
import datetime
import copy
# Third-party
import torch
import torch_geometric.loader
import numpy as np
# Local
from gnn_base_model.model.gnn_model import GNNEPDBaseModel
from gnn_base_model.train.torch_loss import get_pytorch_loss
from gnn_base_model.predict.prediction import predict
from ioput.iostandard import write_summary_file
#
#                                                          Authorship & Credits
# =============================================================================
__author__ = 'Bernardo Ferreira (bernardo_ferreira@brown.edu)'
__credits__ = ['Bernardo Ferreira', 'Rui Barreira', ]
__status__ = 'Planning'
# =============================================================================
#
# =============================================================================
def train_model(n_max_epochs, dataset, model_init_args, lr_init,
                opt_algorithm='adam', lr_scheduler_type=None,
                lr_scheduler_kwargs=None, loss_nature='node_features_out',
                loss_type='mse', loss_kwargs=None, batch_size=1,
                is_sampler_shuffle=False, data_loader_kwargs=None,
                is_early_stopping=False, early_stopping_kwargs=None,
                load_model_state=None, save_every=None, save_loss_every=None,
                dataset_file_path=None, device_type='cpu', seed=None,
                is_verbose=False, tqdm_flavor='default'):
    """Training of Graph Neural Network model.

    Parameters
    ----------
    n_max_epochs : int
        Maximum number of training epochs. Must be a positive integer.
    dataset : torch.utils.data.Dataset
        Graph Neural Network graph data set. Each sample corresponds to a
        torch_geometric.data.Data object describing a homogeneous graph.
    model_init_args : dict
        Graph Neural Network model class initialization parameters (check
        class GNNEPDBaseModel).
    lr_init : float
        Initial value optimizer learning rate. Constant learning rate value
        if no learning rate scheduler is specified (lr_scheduler_type=None).
    opt_algorithm : {'adam',}, default='adam'
        Optimization algorithm:

        'adam'  : Adam (torch.optim.Adam)

    lr_scheduler_type : {'steplr', 'explr', 'linlr'}, default=None
        Type of learning rate scheduler:

        'steplr'  : Step-based decay (torch.optim.lr_scheduler.StepLR)

        'explr'   : Exponential decay (torch.optim.lr_scheduler.ExponentialLR)

        'linlr'   : Linear decay (torch.optim.lr_scheduler.LinearLR)

    lr_scheduler_kwargs : dict, default=None
        Arguments of torch.optim.lr_scheduler.LRScheduler initializer. None
        is equivalent to an empty dictionary.
    loss_nature : {'node_features_out', 'edge_features_out', \
                   'global_features_out'}, default='node_features_out'
        Loss nature:

        'node_features_out' : Based on node output features

        'edge_features_out' : Based on edge output features

        'global_features_out' : Based on global output features

    loss_type : {'mse',}, default='mse'
        Loss function type:

        'mse'  : MSE (torch.nn.MSELoss)

    loss_kwargs : dict, default=None
        Arguments of torch.nn._Loss initializer. None is equivalent to an
        empty dictionary.
    batch_size : int, default=1
        Number of samples loaded per batch.
    is_sampler_shuffle : bool, default=False
        If True, shuffles data set samples at every epoch. Honored both with
        and without a reproducibility seed.
    data_loader_kwargs : dict, default=None
        Additional arguments for torch_geometric.loader.dataloader.DataLoader.
        None is equivalent to an empty dictionary. Entries given here take
        precedence over the loader arguments set by this function.
    is_early_stopping : bool, default=False
        If True, then training process is halted when early stopping
        criterion is triggered. By default, 20% of the training data set is
        allocated for the underlying validation procedures.
    early_stopping_kwargs : dict, default=None
        Early stopping criterion parameters (key, str, item, value). None is
        equivalent to an empty dictionary.
    load_model_state : {'best', 'last', 'init', int, None}, default=None
        Load available GNN-based model state from the model directory.
        Data scalers are also loaded from model initialization file.
        Options:

        'best'      : Model state corresponding to best performance available

        'last'      : Model state corresponding to highest training epoch

        int         : Model state corresponding to given training epoch

        'init'      : Model state corresponding to initial state

        None        : Model default state file

    save_every : int, default=None
        Save Graph Neural Network model every save_every epochs. If None,
        then saves only last epoch and best performance states.
    save_loss_every : int, default=None
        Save loss history model every save_loss_every epochs. If None, then
        saves loss history only after the last epoch.
    dataset_file_path : str, default=None
        Graph Neural Network graph data set file path if such file exists.
        Only used for output purposes.
    device_type : {'cpu', 'cuda'}, default='cpu'
        Type of device on which torch.Tensor is allocated.
    seed : int, default=None
        Seed used to initialize the random number generators of Python and
        other libraries (e.g., NumPy, PyTorch) for all devices to preserve
        reproducibility. Does also set workers seed in PyTorch data loaders.
    is_verbose : bool, default=False
        If True, enable verbose output.
    tqdm_flavor : {'default', 'notebook'}, default='default'
        Type of tqdm progress bar to use when is_verbose=True.

    Returns
    -------
    model : torch.nn.Module
        Graph Neural Network model.
    best_loss : float
        Best loss during training process.
    best_training_epoch : int
        Training epoch corresponding to best loss during training process.

    Raises
    ------
    RuntimeError
        If n_max_epochs is not positive, or if the loss nature or the
        optimization algorithm is unknown.
    """
    # Import tqdm
    if tqdm_flavor == 'notebook':
        from tqdm.notebook import tqdm
    else:
        from tqdm import tqdm
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Normalize optional keyword arguments dictionaries. Local copies avoid
    # shared mutable defaults and never mutate the caller's dictionaries
    lr_scheduler_kwargs = dict(lr_scheduler_kwargs or {})
    loss_kwargs = dict(loss_kwargs or {})
    data_loader_kwargs = dict(data_loader_kwargs or {})
    early_stopping_kwargs = dict(early_stopping_kwargs or {})
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Validate training parameters before any model file is written. A
    # non-positive number of epochs would never satisfy the training loop
    # stopping condition (infinite training)
    if n_max_epochs < 1:
        raise RuntimeError('The maximum number of training epochs must be '
                           'a positive integer.')
    if loss_nature not in ('node_features_out', 'edge_features_out',
                           'global_features_out'):
        raise RuntimeError('Unknown loss nature.')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Set random number generators initialization for reproducibility
    generator = None
    if isinstance(seed, int):
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)
        generator = torch.Generator().manual_seed(seed)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Set device
    device = torch.device(device_type)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if is_verbose:
        if device_type == 'cuda':
            device_name = torch.cuda.get_device_name(device)
        else:
            device_name = 'cpu'
        print(f'\n> Setting device: {device_name}')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    start_time_sec = time.time()
    if is_verbose:
        print('\nGraph Neural Network model training'
              '\n-----------------------------------')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Initialize Graph Neural Network model state
    if is_verbose:
        print('\n> Initializing model...')
    if load_model_state is not None:
        # Initialize Graph Neural Network model from the model directory
        # (includes loading of data scalers)
        model = GNNEPDBaseModel.init_model_from_file(
            model_init_args['model_directory'])
        # Set model device
        model.set_device(device_type)
        # Get model input and output features normalization
        is_model_in_normalized = model.is_model_in_normalized
        is_model_out_normalized = model.is_model_out_normalized
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        if is_verbose:
            print('\n> Loading model state...')
        # Load Graph Neural Network model state
        _ = model.load_model_state(load_model_state=load_model_state,
                                   is_remove_posterior=True)
    else:
        # Initialize Graph Neural Network model
        model = GNNEPDBaseModel(**model_init_args)
        # Set model device
        model.set_device(device_type)
        # Get model input and output features normalization
        is_model_in_normalized = model.is_model_in_normalized
        is_model_out_normalized = model.is_model_out_normalized
        # Fit model data scalers
        if is_model_in_normalized or is_model_out_normalized:
            model.fit_data_scalers(dataset, is_verbose=is_verbose,
                                   tqdm_flavor=tqdm_flavor)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Save model initial state
    model.save_model_init_state()
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Get model parameters
    model_parameters = model.parameters(recurse=True)
    # Move model to device and set it in training mode
    model.to(device=device)
    model.train()
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Set optimizer, specifying the model (and submodels) parameters that
    # should be optimized. By default, parameters require gradients, so their
    # operations are recorded for automatic differentiation
    if opt_algorithm == 'adam':
        optimizer = torch.optim.Adam(params=model_parameters, lr=lr_init)
    else:
        raise RuntimeError('Unknown optimization algorithm')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Initialize learning rate scheduler
    lr_scheduler = None
    if lr_scheduler_type is not None:
        lr_scheduler = get_learning_rate_scheduler(
            optimizer=optimizer, scheduler_type=lr_scheduler_type,
            **lr_scheduler_kwargs)
    is_lr_scheduler = lr_scheduler is not None
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Initialize loss function
    loss_function = get_pytorch_loss(loss_type, **loss_kwargs)
    # Initialize loss and learning rate histories (per epoch)
    loss_history_epochs = []
    lr_history_epochs = []
    # Initialize loss and learning rate histories (per training step)
    loss_history_steps = []
    lr_history_steps = []
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Initialize number of training epochs and training steps
    epoch = 0
    step = 0
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Initialize validation loss history and early stopping flag
    validation_loss_history = None
    is_stop_training = False
    # Initialize early stopping criterion
    if is_early_stopping:
        if is_verbose:
            print('\n> Initializing early stopping criterion...')
        early_stopper = EarlyStopper(**early_stopping_kwargs)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if is_verbose:
        print(f'\n> Training data set size: {len(dataset)}')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Set data loader arguments. Shuffling is honored in both seeded and
    # non-seeded runs; a seed additionally fixes the sampler generator and
    # the workers seed. Explicit user arguments take precedence
    loader_kwargs = dict(shuffle=is_sampler_shuffle)
    if isinstance(seed, int):
        loader_kwargs.update(worker_init_fn=seed_worker, generator=generator)
    loader_kwargs.update(data_loader_kwargs)
    # Set data loader
    data_loader = torch_geometric.loader.dataloader.DataLoader(
        dataset=dataset, batch_size=batch_size, **loader_kwargs)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if is_verbose:
        input_normalization_str = 'Yes' if is_model_in_normalized else 'No'
        print(f'\n> Input data normalization: {input_normalization_str}')
        output_normalization_str = 'Yes' if is_model_out_normalized else 'No'
        print(f'\n> Output data normalization: {output_normalization_str}')
        print('\n\n> Starting training process...\n')
        # Epochs progress bar
        pbar = tqdm(total=n_max_epochs, mininterval=1, maxinterval=300,
                    miniters=0, desc='> Epochs: ', unit=' epoch')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Loop over training epochs
    while True:
        # Store epoch initial training step
        epoch_init_step = step
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Loop over graph batches. A graph batch is a data object describing a
        # batch of graphs as one large (disconnected) graph.
        for pyg_graph in tqdm(data_loader, leave=False, mininterval=1,
                              maxinterval=60, miniters=0, desc='> Steps: ',
                              disable=not is_verbose, unit=' step'):
            # Move graph sample to device
            pyg_graph = pyg_graph.to(device)
            # Get batch node assignment vector (only meaningful for batches)
            batch_vector = pyg_graph.batch if batch_size > 1 else None
            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # Get input features from input graph
            node_features_in, edge_features_in, global_features_in, \
                edges_indexes = model.get_input_features_from_graph(
                    pyg_graph, is_normalized=is_model_in_normalized)
            # Get output features ground-truth
            node_targets, edge_targets, global_targets = \
                model.get_output_features_from_graph(
                    pyg_graph, is_normalized=is_model_out_normalized)
            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # Compute output features predictions (forward propagation).
            # PyTorch records the operations on tensors that require
            # gradients (the model parameters) in a computation graph, storing
            # each operation's gradient function in the output .grad_fn
            node_features_out, edge_features_out, global_features_out = \
                model(node_features_in=node_features_in,
                      edge_features_in=edge_features_in,
                      global_features_in=global_features_in,
                      edges_indexes=edges_indexes,
                      batch_vector=batch_vector)
            # Compute loss according to loss nature (validated above)
            if loss_nature == 'node_features_out':
                loss = loss_function(node_features_out, node_targets)
            elif loss_nature == 'edge_features_out':
                loss = loss_function(edge_features_out, edge_targets)
            else:
                loss = loss_function(global_features_out, global_targets)
            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # Initialize gradients (set to zero), since gradients are
            # accumulated in the .grad attribute of the model parameters
            optimizer.zero_grad()
            # Compute gradients with respect to model parameters (backward
            # propagation through the recorded computation graph, applying
            # the chain rule until the leaf-nodes are reached)
            loss.backward()
            # Perform optimization step
            optimizer.step()
            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # Save training step loss and learning rate
            loss_history_steps.append(loss.clone().detach().cpu())
            if is_lr_scheduler:
                lr_history_steps.append(lr_scheduler.get_last_lr())
            else:
                lr_history_steps.append(lr_init)
            # Increment training step counter
            step += 1
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Update optimizer learning rate
        if is_lr_scheduler:
            lr_scheduler.step()
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Save training epoch loss (epoch average loss value)
        epoch_avg_loss = np.mean(loss_history_steps[epoch_init_step:])
        loss_history_epochs.append(epoch_avg_loss)
        # Save training epoch learning rate (epoch last value)
        if is_lr_scheduler:
            lr_history_epochs.append(lr_scheduler.get_last_lr())
        else:
            lr_history_epochs.append(lr_init)
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Save model and optimizer current states
        if save_every is not None and epoch % save_every == 0:
            save_training_state(model=model, optimizer=optimizer, epoch=epoch)
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        if save_loss_every is not None and epoch % save_loss_every == 0:
            # Get validation loss history
            if is_early_stopping:
                validation_loss_history = \
                    early_stopper.get_validation_loss_history()
            # Save loss and learning rate histories
            save_loss_history(model, n_max_epochs, loss_nature, loss_type,
                              loss_history_epochs,
                              lr_scheduler_type=lr_scheduler_type,
                              lr_history_epochs=lr_history_epochs,
                              validation_loss_history=validation_loss_history)
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Save model and optimizer best performance state corresponding to
        # minimum training loss
        if epoch_avg_loss <= min(loss_history_epochs):
            save_training_state(model=model, optimizer=optimizer, epoch=epoch,
                                is_best_state=True)
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Check early stopping criterion
        if is_early_stopping and early_stopper.is_evaluate_criterion(epoch):
            is_stop_training = early_stopper.evaluate_criterion(
                model, optimizer, epoch, loss_nature=loss_nature,
                loss_type=loss_type, loss_kwargs=loss_kwargs,
                batch_size=batch_size, device_type=device_type)
            # If early stopping is triggered, save model and optimizer best
            # performance corresponding to early stopping criterion
            if is_stop_training:
                # Load best performance model and optimizer states
                best_epoch = early_stopper.load_best_performance_state(
                    model, optimizer)
                # Save model and optimizer best performance states
                save_training_state(model, optimizer, epoch=best_epoch,
                                    is_best_state=True)
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Increment epoch counter
        epoch += 1
        # Update progress bar
        if is_verbose:
            pbar.update(1)
        # Check training process flow: maximum number of epochs completed or
        # early stopping criterion triggered
        if epoch >= n_max_epochs or is_stop_training:
            break
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if is_verbose:
        pbar.close()
        if is_stop_training:
            print('\n\n> Early stopping has been triggered!',
                  '\n\n> Finished training process!')
        else:
            print('\n\n> Finished training process!')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Get validation loss history
    if is_early_stopping:
        validation_loss_history = early_stopper.get_validation_loss_history()
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Save model and optimizer final states
    save_training_state(model=model, optimizer=optimizer, epoch=epoch)
    # Save loss and learning rate histories
    save_loss_history(model, n_max_epochs, loss_nature, loss_type,
                      loss_history_epochs,
                      lr_scheduler_type=lr_scheduler_type,
                      lr_history_epochs=lr_history_epochs,
                      validation_loss_history=validation_loss_history)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Get best loss and corresponding training epoch (first occurrence)
    best_training_epoch = int(np.argmin(loss_history_epochs))
    best_loss = float(loss_history_epochs[best_training_epoch])
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if is_verbose:
        if is_model_out_normalized:
            min_loss_str = 'Minimum training loss (normalized)'
        else:
            min_loss_str = 'Minimum training loss'
        print(f'\n\n> {min_loss_str}: {best_loss:.8e} | '
              f'Epoch: {best_training_epoch:d}')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Compute total training time and average training time per epoch (once,
    # so that printed and summary-file values agree)
    total_time_sec = time.time() - start_time_sec
    avg_time_epoch = total_time_sec/epoch
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if is_verbose:
        print(f'\n> Model directory: {model.model_directory}')
        print(f'\n> Total training time: '
              f'{str(datetime.timedelta(seconds=int(total_time_sec)))} | '
              f'Avg. training time per epoch: '
              f'{str(datetime.timedelta(seconds=int(avg_time_epoch)))}\n')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Write summary data file for model training process
    write_training_summary_file(
        device_type, seed, model.model_directory, load_model_state,
        n_max_epochs, is_model_in_normalized, is_model_out_normalized,
        batch_size, is_sampler_shuffle, loss_nature, loss_type, loss_kwargs,
        opt_algorithm, lr_init, lr_scheduler_type, lr_scheduler_kwargs, epoch,
        dataset_file_path, dataset, best_loss, best_training_epoch,
        total_time_sec, avg_time_epoch)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    return model, best_loss, best_training_epoch
# ============================================================================= def get_pytorch_optimizer(algorithm, params, **kwargs): """Get PyTorch optimizer. Parameters ---------- algorithm : {'adam',} Optimization algorithm: 'adam' : Adam (torch.optim.Adam) params : list List of parameters (torch.Tensors) to optimize or list of dicts defining parameter groups. **kwargs Arguments of torch.optim.Optimizer initializer. Returns ------- optimizer : torch.optim.Optimizer PyTorch optimizer. """ if algorithm == 'adam': optimizer = torch.optim.Adam(params, **kwargs) else: raise RuntimeError('Unknown or unavailable PyTorch optimizer.') # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ return optimizer # ============================================================================= def get_learning_rate_scheduler(optimizer, scheduler_type, **kwargs): """Get PyTorch optimizer learning rate scheduler. Parameters ---------- optimizer : torch.optim.Optimizer PyTorch optimizer. lr_scheduler_type : {'steplr', 'explr', 'linlr'} Type of learning rate scheduler: 'steplr' : Step-based decay (torch.optim.lr_scheduler.SetpLR) 'explr' : Exponential decay (torch.optim.lr_scheduler.ExponentialLR) 'linlr' : Linear decay (torch.optim.lr_scheduler.LinearLR) **kwargs Arguments of torch.optim.lr_scheduler.LRScheduler initializer. Returns ------- scheduler : torch.optim.lr_scheduler.LRScheduler PyTorch optimizer learning rate scheduler. 
""" if scheduler_type == 'steplr': # Check scheduler mandatory parameters if 'step_size' not in kwargs.keys(): raise RuntimeError('The parameter \'step_size\' needs to be ' 'provided to initialize step-based decay ' 'learning rate scheduler.') # Initialize scheduler scheduler = torch.optim.lr_scheduler.StepLR(optimizer, **kwargs) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ elif scheduler_type == 'explr': # Check scheduler mandatory parameters if 'gamma' not in kwargs.keys(): raise RuntimeError('The parameter \'gamma\' needs to be ' 'provided to initialize exponential decay ' 'learning rate scheduler.') # Initialize scheduler scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, **kwargs) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ elif scheduler_type == 'linlr': # Initialize scheduler scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, **kwargs) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else: raise RuntimeError('Unknown or unavailable PyTorch optimizer ' 'learning rate scheduler.') # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ return scheduler # ============================================================================= def save_training_state(model, optimizer, epoch=None, is_best_state=False, is_remove_posterior=True): """Save model and optimizer states at given training epoch. Material patch model state file is stored in model_directory under the name < model_name >.pt or < model_name >-< epoch >.pt if epoch is known. Material patch model state file corresponding to the best performance is stored in model_directory under the name < model_name >-best.pt or < model_name >-< epoch >-best.pt if epoch is known. Optimizer state file is stored in model_directory under the name < model_name >_optim-< epoch >.pt. 
Optimizer state file corresponding to the best performance is stored in model_directory under the name < model_name >_optim-best.pt or < model_name >_optim-< epoch >-best.pt if epoch is known. Parameters ---------- model : torch.nn.Module Model. optimizer : torch.optim.Optimizer PyTorch optimizer. epoch : int, default=None Training epoch. is_best_state : bool, default=False If True, save material patch model state file corresponding to the best performance instead of regular state file. is_remove_posterior : bool, default=True Remove material patch model and optimizer state files corresponding to training epochs posterior to the saved state file. Effective only if saved epoch is known. """ # Save model model.save_model_state(epoch=epoch, is_best_state=is_best_state) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Set optimizer state file optimizer_state_file = model.model_name + '_optim' # Append epoch if isinstance(epoch, int): optimizer_state_file += '-' + str(epoch) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Set optimizer state file corresponding to best performance if is_best_state: # Append best performance optimizer_state_file += '-' + 'best' # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Get optimizer state files in material patch model directory directory_list = os.listdir(model.model_directory) # Loop over files in material patch model directory for filename in directory_list: # Check if file is optimizer epoch best state file is_best_state_file = \ bool(re.search(r'^' + model.model_name + r'_optim' + r'-?[0-9]*' + r'-best' + r'\.pt', filename)) # Delete state file if is_best_state_file: os.remove(os.path.join(model.model_directory, filename)) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Set optimizer state file path optimizer_path = os.path.join(model.model_directory, optimizer_state_file + '.pt') # Save optimizer state 
optimizer_state = dict(state=optimizer.state_dict(), epoch=epoch) torch.save(optimizer_state, optimizer_path) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Delete model and optimizer epoch state files posterior to saved epoch if isinstance(epoch, int) and is_remove_posterior: remove_posterior_optim_state_files(model, epoch) # ============================================================================= def load_training_state(model, opt_algorithm, optimizer, load_model_state=None, is_remove_posterior=True): """Load model and optimizer states from available training data. Material patch model state file is stored in model_directory under the name < model_name >.pt, < model_name >-< epoch >.pt, < model_name >-best.pt or < model_name >-< epoch >-best.pt. Optimizer state file is stored in model_directory under the name < model_name >_optim.pt or < model_name >_optim-< epoch >.pt. Both model and optimizer are updated 'in-place' with loaded state data. Parameters ---------- model : torch.nn.Module Model. opt_algorithm : {'adam',}, default='adam' Optimization algorithm: 'adam' : Adam (torch.optim.Adam) optimizer : torch.optim.Optimizer PyTorch optimizer. load_model_state : {'best', 'last', int, None}, default=None Load available Graph Neural Network model state from the model directory. Options: 'best' : Model state corresponding to best performance available 'last' : Model state corresponding to highest training epoch int : Model state corresponding to given training epoch None : Model default state file is_remove_posterior : bool, default=True Remove material patch model state files corresponding to training epochs posterior to the loaded state file. Effective only if loaded training epoch is known. Returns ------- loaded_epoch : int Training epoch corresponding to loaded state data. Defaults to 0 if training epoch is unknown. 
""" # Load model state loaded_epoch = \ model.load_model_state(load_model_state=load_model_state, is_remove_posterior=is_remove_posterior) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Set optimizer state file optimizer_state_file = model.model_name + '_optim' # Append epoch if isinstance(loaded_epoch, int): optimizer_state_file += '-' + str(loaded_epoch) # Append best performance if load_model_state == 'best': optimizer_state_file += '-' + 'best' # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Set optimizer state file path optimizer_path = os.path.join(model.model_directory, optimizer_state_file + '.pt') # Load optimizer state if not os.path.isfile(optimizer_path): raise RuntimeError('Optimizer state file has not been found:\n\n' + optimizer_path) else: # Initialize optimizer if opt_algorithm == 'adam': optimizer = torch.optim.Adam(params=model.parameters(recurse=True)) else: raise RuntimeError('Unknown optimization algorithm') # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Load optimizer state optimizer_state = torch.load(optimizer_path, weights_only=True) # Set loaded optimizer state optimizer.load_state_dict(optimizer_state['state']) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Delete optimizer epoch state files posterior to loaded epoch if isinstance(loaded_epoch, int) and is_remove_posterior: remove_posterior_optim_state_files(model, loaded_epoch) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Set loaded epoch to 0 if unknown from state file if loaded_epoch is None: loaded_epoch = 0 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ return loaded_epoch # ============================================================================= def remove_posterior_optim_state_files(model, epoch): """Delete optimizer training epoch state files posterior to given epoch. 
Parameters ---------- model : torch.nn.Module Model. epoch : int Training epoch. """ # Get files in material patch model directory directory_list = os.listdir(model.model_directory) # Loop over files in material patch model directory for filename in directory_list: # Check if file is optimizer epoch state file is_state_file = bool(re.search(r'^' + model.model_name + r'_optim' + r'-[0-9]+' + r'\.pt', filename)) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Delete optimizer epoch state file posterior to given epoch if is_state_file: # Get optimizer state epoch file_epoch = int(os.path.splitext(filename)[0].split('-')[-1]) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Delete optimizer epoch state file if file_epoch > epoch: os.remove(os.path.join(model.model_directory, filename)) # ============================================================================= def save_loss_history(model, n_max_epochs, loss_nature, loss_type, training_loss_history, lr_scheduler_type=None, lr_history_epochs=None, validation_loss_history=None): """Save training process loss history record. Loss history record file is stored in model_directory under the name loss_history_record.pkl. Overwrites existing loss history record file. Parameters ---------- model : torch.nn.Module Model. n_max_epochs : int Maximum number of epochs of training process. loss_nature : str Loss nature. loss_type : str Loss function type. training_loss_history : list[float] Training process training loss history (per epoch). lr_scheduler_type : {'steplr', 'explr', 'linlr'}, default=None Type of learning rate scheduler. lr_history_epochs : list[float], default=None Training process learning rate history (per epoch). validation_loss_history : list[float], default=None Training process validation loss history (e.g., early stopping criterion). 
""" # Set loss history record file path loss_record_path = os.path.join(model.model_directory, 'loss_history_record' + '.pkl') # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Build training loss history record loss_history_record = {} loss_history_record['n_max_epochs'] = int(n_max_epochs) loss_history_record['loss_nature'] = str(loss_nature) loss_history_record['loss_type'] = str(loss_type) loss_history_record['training_loss_history'] = list(training_loss_history) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Store learning rate history record if lr_scheduler_type is not None: loss_history_record['lr_scheduler_type'] = str(lr_scheduler_type) else: loss_history_record['lr_scheduler_type'] = None if lr_history_epochs is not None: loss_history_record['lr_history_epochs'] = list(lr_history_epochs) else: loss_history_record['lr_history_epochs'] = None # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Store validation loss history if validation_loss_history is not None: loss_history_record['validation_loss_history'] = \ list(validation_loss_history) else: loss_history_record['validation_loss_history'] = None # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Save loss history record with open(loss_record_path, 'wb') as loss_record_file: pickle.dump(loss_history_record, loss_record_file) # ============================================================================= def load_loss_history(model, loss_nature, loss_type, epoch=None): """Load training process training loss history record. Loss history record file is stored in model_directory under the name loss_history_record.pkl. Parameters ---------- model : torch.nn.Module Model. loss_nature : str Loss nature. loss_type : str Loss function type. epoch : int, default=None Epoch to which loss history is loaded (included), with the first epoch being 0. If None, then loads the full loss history. 
Returns ------- training_loss_history : list[float] Training process training loss history (per epoch). """ # Set loss history record file path loss_record_path = os.path.join(model.model_directory, 'loss_history_record' + '.pkl') # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Load training process training loss history if os.path.isfile(loss_record_path): # Load loss history record with open(loss_record_path, 'rb') as loss_record_file: loss_history_record = pickle.load(loss_record_file) # Check consistency between loss history nature and current training # process loss nature history_loss_nature = loss_history_record['loss_nature'] if history_loss_nature != loss_nature: raise RuntimeError('Loss history nature (' + str(history_loss_nature) + ') is not consistent with current training ' 'process loss nature (' + str(loss_nature) + ').') # Check consistency between loss history type and current training # process loss type history_loss_type = loss_history_record['loss_type'] if history_loss_type != loss_type: raise RuntimeError('Loss history type (' + str(history_loss_type) + ') is not consistent with current training ' 'process loss type (' + str(loss_type) + ').') # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Check training loss history loss_record = loss_history_record['training_loss_history'] if not isinstance(loss_record, list): raise RuntimeError('Loaded loss history is not a list[float].') # Load training loss history if epoch is None or epoch + 1 == len(loss_record): training_loss_history = loss_record else: if epoch + 1 > len(loss_record): raise RuntimeError('Target epoch is beyond available loss ' 'history.') else: training_loss_history = loss_record[:epoch + 1] # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else: # Build training loss history with None entries if loss history record # file cannot be found if epoch is None: raise RuntimeError('Training process loss 
history file has not ' 'been found and loaded epoch is unknown.') else: training_loss_history = (epoch + 1)*[None,] # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ return training_loss_history # ============================================================================= def load_lr_history(model, epoch=None): """Load training process learning rate history record. Loss history record file is stored in model_directory under the name loss_history_record.pkl. Parameters ---------- model : torch.nn.Module Model. epoch : int, default=None Training epoch to which loss history is loaded (included), with the first training epoch being 0. If None, then loads the full loss history. Returns ------- lr_history_epochs : list[float] Training process learning rate history (per epoch). """ # Set loss history record file path loss_record_path = os.path.join(model.model_directory, 'loss_history_record' + '.pkl') # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Load training process learning history if os.path.isfile(loss_record_path): # Load loss history record with open(loss_record_path, 'rb') as loss_record_file: loss_history_record = pickle.load(loss_record_file) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Check learning rate history lr_record = loss_history_record['lr_history_epochs'] if not isinstance(lr_record, list) and lr_record is not None: raise RuntimeError('Loaded learning rate history is not a ' 'list[float] or None.') # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Load learning rate history if lr_record is None and epoch is None: # Build learning rate history with None entries if learning rate # history is not available lr_history_epochs = len( loss_history_record['training_loss_history'])*[None,] elif lr_record is None and epoch is not None: # Build learning rate history with None entries if learning rate # history is not available lr_history_epochs = 
(epoch + 1)*[None,] else: if epoch is None or epoch + 1 == len(lr_record): lr_history_epochs = lr_record elif epoch + 1 > len(lr_record): raise RuntimeError('Target epoch is beyond available ' 'learning rate history.') else: lr_history_epochs = lr_record[:epoch + 1] # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else: # Build learning rate history with None entries if loss history record # file cannot be found if epoch is None: raise RuntimeError('Training process loss history file has not ' 'been found and loaded epoch is unknown.') else: lr_history_epochs = (epoch + 1)*[None,] # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ return lr_history_epochs # ============================================================================= def seed_worker(worker_id): """Set workers seed in PyTorch data loaders to preserve reproducibility. Taken from: https://pytorch.org/docs/stable/notes/randomness.html Parameters ---------- worker_id : int Worker ID. """ worker_seed = torch.initial_seed() % 2**32 np.random.seed(worker_seed) random.seed(worker_seed) # ============================================================================= def read_loss_history_from_file(loss_record_path): """Read training process loss history from loss history record file. Loss history record file is stored in model_directory under the name loss_history_record.pkl. Detaches loss values from computation graph and moves them to CPU. Parameters ---------- loss_record_path : str Loss history record file path. Returns ------- loss_nature : str Loss nature. loss_type : str Loss function type. training_loss_history : list[float] Training process training loss history (per epoch). validation_loss_history : {None, list[float]} Training process validation loss history. Set to None if not available. 
""" # Check loss history record file if not os.path.isfile(loss_record_path): raise RuntimeError('Loss history record file has not been found:\n\n' + loss_record_path) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Load loss history record with open(loss_record_path, 'rb') as loss_record_file: loss_history_record = pickle.load(loss_record_file) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Check loss history if 'loss_nature' not in loss_history_record.keys(): raise RuntimeError('Loss nature is not available in loss history ' 'record.') elif 'loss_type' not in loss_history_record.keys(): raise RuntimeError('Loss type is not available in loss history ' 'record.') elif 'training_loss_history' not in loss_history_record.keys(): raise RuntimeError('Loss history is not available in loss history ' 'record.') elif not isinstance(loss_history_record['training_loss_history'], list): raise RuntimeError('Loss history is not a list[float].') # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Set loss nature loss_nature = str(loss_history_record['loss_nature']) # Set loss type loss_type = str(loss_history_record['loss_type']) # Set training loss history training_loss_history = [] for x in loss_history_record['training_loss_history']: if isinstance(x, torch.Tensor): training_loss_history.append(x.detach().cpu()) else: training_loss_history.append(x) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Set validation loss history if isinstance(loss_history_record['validation_loss_history'], list): validation_loss_history = [] for x in loss_history_record['validation_loss_history']: if isinstance(x, torch.Tensor): validation_loss_history.append(x.detach().cpu()) else: validation_loss_history.append(x) else: validation_loss_history = None # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ return (loss_nature, loss_type, training_loss_history, 
validation_loss_history) # ============================================================================= def read_lr_history_from_file(loss_record_path): """Read training learning rate history from loss history record file. Loss history record file is stored in model_directory under the name loss_history_record.pkl. Parameters ---------- loss_record_path : str Loss history record file path. Returns ------- lr_scheduler_type : {'steplr', 'explr', 'linlr'} Type of learning rate scheduler: 'steplr' : Step-based decay (torch.optim.lr_scheduler.SetpLR) 'explr' : Exponential decay (torch.optim.lr_scheduler.ExponentialLR) 'linlr' : Linear decay (torch.optim.lr_scheduler.LinearLR) lr_history_epochs : list[float] Training process learning rate history (per epoch). """ # Check loss history record file if not os.path.isfile(loss_record_path): raise RuntimeError('Loss history record file has not been found:\n\n' + loss_record_path) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Load loss history record with open(loss_record_path, 'rb') as loss_record_file: loss_history_record = pickle.load(loss_record_file) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Check learning rate history if 'lr_scheduler_type' not in loss_history_record.keys(): raise RuntimeError('Learning rate scheduler type is not available in ' 'loss history record.') elif 'lr_history_epochs' not in loss_history_record.keys(): raise RuntimeError('Learning rate history is not available in loss ' 'history record.') elif not isinstance(loss_history_record['lr_history_epochs'], list): raise RuntimeError('Learning rate history is not a list[float].') # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Set learning rate scheduler type lr_scheduler_type = loss_history_record['lr_scheduler_type'] # Set learning rate history lr_history_epochs = loss_history_record['lr_history_epochs'] # 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ return lr_scheduler_type, lr_history_epochs # ============================================================================= def write_training_summary_file( device_type, seed, model_directory, load_model_state, n_max_epochs, is_model_in_normalized, is_model_out_normalized, batch_size, is_sampler_shuffle, loss_nature, loss_type, loss_kwargs, opt_algorithm, lr_init, lr_scheduler_type, lr_scheduler_kwargs, n_epochs, dataset_file_path, dataset, best_loss, best_training_epoch, total_time_sec, avg_time_epoch, best_model_parameters=None, torchinfo_summary=None): """Write summary data file for model training process. Parameters ---------- device_type : {'cpu', 'cuda'} Type of device on which torch.Tensor is allocated. seed : int Seed used to initialize the random number generators of Python and other libraries (e.g., NumPy, PyTorch) for all devices to preserve reproducibility. Does also set workers seed in PyTorch data loaders. model_directory : str Directory where material patch model is stored. load_model_state : {'best', 'last', int, None} Load available Graph Neural Network model state from the model directory. Data scalers are also loaded from model initialization file. n_max_epochs : int Maximum number of training epochs. is_model_in_normalized : bool, default=False If True, then model input features are assumed to be normalized (normalized input data has been seen during model training). is_model_out_normalized : bool, default=False If True, then model output features are assumed to be normalized (normalized output data has been seen during model training). batch_size : int Number of samples loaded per batch. is_sampler_shuffle : bool If True, shuffles data set samples at every epoch. loss_nature : str Loss nature. loss_type : str Loss function type. loss_kwargs : dict Arguments of torch.nn._Loss initializer. opt_algorithm : str Optimization algorithm. 
lr_init : float Initial value optimizer learning rate. Constant learning rate value if no learning rate scheduler is specified (lr_scheduler_type=None). lr_scheduler_type : str Type of learning rate scheduler. lr_scheduler_kwargs : dict Arguments of torch.optim.lr_scheduler.LRScheduler initializer. n_epochs : int Number of completed epochs in the training process. dataset_file_path : str Graph Neural Network graph data set file path if such file exists. Only used for output purposes. dataset : torch.utils.data.Dataset Graph Neural Network graph data set. Each sample corresponds to a torch_geometric.data.Data object describing a homogeneous graph. best_loss : float Best loss during training process. best_training_epoch : int Training epoch corresponding to best loss during training process. total_time_sec : int Total training time in seconds. avg_time_epoch : float Average training time per epoch. best_model_parameters : dict Model parameters corresponding to best model state. torchinfo_summary : str, default=None Torchinfo model architecture summary. 
""" # Set summary data summary_data = {} summary_data['device_type'] = device_type summary_data['seed'] = seed summary_data['model_directory'] = model_directory summary_data['load_model_state'] = \ load_model_state if load_model_state else None summary_data['n_max_epochs'] = n_max_epochs summary_data['is_model_in_normalized'] = is_model_in_normalized summary_data['is_model_out_normalized'] = is_model_out_normalized summary_data['batch_size'] = batch_size summary_data['is_sampler_shuffle'] = is_sampler_shuffle summary_data['loss_nature'] = loss_nature summary_data['loss_type'] = loss_type summary_data['loss_kwargs'] = loss_kwargs if loss_kwargs else None summary_data['opt_algorithm'] = opt_algorithm summary_data['lr_init'] = lr_init summary_data['lr_scheduler_type'] = \ lr_scheduler_type if lr_scheduler_type else None summary_data['lr_scheduler_kwargs'] = \ lr_scheduler_kwargs if lr_scheduler_kwargs else None summary_data['Number of completed epochs'] = n_epochs summary_data['Training data set file'] = \ dataset_file_path if dataset_file_path else None summary_data['Training data set size'] = len(dataset) summary_data['Best loss: '] = \ f'{best_loss:.8e} (training epoch {best_training_epoch})' summary_data['Total training time'] = \ str(datetime.timedelta(seconds=int(total_time_sec))) summary_data['Avg. 
training time per epoch'] = \ str(datetime.timedelta(seconds=int(avg_time_epoch))) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Set summary optional data if best_model_parameters is not None: summary_data['Model parameters (best state)'] = best_model_parameters if torchinfo_summary is not None: summary_data['torchinfo summary'] = torchinfo_summary # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Write summary file write_summary_file( summary_directory=model_directory, summary_title='Summary: Model training', **summary_data) # ============================================================================= class EarlyStopper: """Early stopping procedure (implicit regularizaton). Attributes ---------- _validation_size : float Size of the validation data set for early stopping evaluation, where size is a fraction of the whole data set contained between 0 and 1. _validation_frequency : int Frequency of validation procedures, i.e., frequency with respect to training epochs at which model is validated to evaluate early stopping criterion. _trigger_tolerance : int Number of consecutive model validation procedures without performance improvement to trigger early stopping. _improvement_tolerance : float Minimum relative improvement required to count as a performance improvement. _validation_steps_history : list Validation steps history. _validation_loss_history : list Validation loss history. _min_validation_loss : float Minimum validation loss. _n_not_improve : int Number of consecutive model validations without improvement. _best_model_state : dict Model state corresponding to the best performance. _best_optimizer_state : dict Optimizer state corresponding to the best performance. _best_training_epoch : int Training epoch corresponding to the best performance. Methods ------- get_validation_loss_history(self) Get validation loss history. 
is_evaluate_criterion(self, epoch) Check whether to evaluate early stopping criterion. evaluate_criterion(self, model, optimizer, epoch, \ loss_nature='node_features_out', loss_type='mse', \ loss_kwargs={}, device_type='cpu') Evaluate early stopping criterion. _validate_model(self, model, optimizer, epoch, loss_nature='node_features_out', loss_type='mse', loss_kwargs={}, device_type='cpu') Perform model validation. load_best_performance_state(self, model, optimizer) Load minimum validation loss model and optimizer states. """ def __init__(self, validation_dataset, validation_frequency=1, trigger_tolerance=1, improvement_tolerance=1e-2): """Constructor. Parameters ---------- validation_dataset : torch.utils.data.Dataset Graph Neural Network graph data set. Each sample corresponds to a torch_geometric.data.Data object describing a homogeneous graph. validation_frequency : int, default=1 Frequency of validation procedures, i.e., frequency with respect to training epochs at which model is validated to evaluate early stopping criterion. trigger_tolerance : int, default=1 Number of consecutive model validation procedures without performance improvement to trigger early stopping. improvement_tolerance : float, default=1e-2 Minimum relative improvement required to count as a performance improvement. 
""" # Set validation data set self._validation_dataset = validation_dataset # Set validation frequency self._validation_frequency = validation_frequency # Set early stopping trigger tolerance self._trigger_tolerance = trigger_tolerance # Set minimum relative improvement tolerance self._improvement_tolerance = improvement_tolerance # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Initialize validation training steps history self._validation_steps_history = [] # Initialize validation loss history self._validation_loss_history = [] # Initialize minimum validation loss self._min_validation_loss = np.inf # Initialize number of consecutive model validations without # improvement self._n_not_improve = 0 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Initialize minimum validation loss state (best performance) self._best_model_state = None self._best_optimizer_state = None self._best_training_epoch = None # ------------------------------------------------------------------------- def get_validation_loss_history(self): """Get validation loss history. Returns ------- validation_loss_history : list[float] Validation loss history. """ return copy.deepcopy(self._validation_loss_history) # ------------------------------------------------------------------------- def is_evaluate_criterion(self, epoch): """Check whether to evaluate early stopping criterion. Parameters ---------- epoch : int Training epoch. Returns ------- is_evaluate_criterion : bool If True, then early stopping criterion should be evaluated, False otherwise. """ return epoch % self._validation_frequency == 0 # ------------------------------------------------------------------------- def evaluate_criterion(self, model, optimizer, epoch, loss_nature='node_features_out', loss_type='mse', loss_kwargs={}, batch_size=1, device_type='cpu'): """Evaluate early stopping criterion. Parameters ---------- model : torch.nn.Module Graph Neural Network model. 
optimizer : torch.optim.Optimizer PyTorch optimizer. epoch : int Training epoch. loss_nature : {'node_features_out', 'global_features_out'}, \ default='node_features_out' Loss nature: 'node_features_out' : Based on node output features 'global_features_out' : Based on global output features loss_type : {'mse',}, default='mse' Loss function type: 'mse' : MSE (torch.nn.MSELoss) loss_kwargs : dict, default={} Arguments of torch.nn._Loss initializer. batch_size : int, default=1 Number of samples loaded per batch. device_type : {'cpu', 'cuda'}, default='cpu' Type of device on which torch.Tensor is allocated. Returns ------- is_stop_training : bool True if early stopping criterion has been triggered, False otherwise. """ # Set early stopping flag is_stop_training = False # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Perform model validation avg_valid_loss_sample = self._validate_model( model, optimizer, epoch, loss_nature=loss_nature, loss_type=loss_type, loss_kwargs=loss_kwargs, batch_size=batch_size, device_type=device_type) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Update minimum validation loss and performance counter if avg_valid_loss_sample < self._min_validation_loss: # Check relative performance improvement with respect to minimum # validation loss if len(self._validation_steps_history) > 1: # Compute relative performance improvement relative_improvement = \ (self._min_validation_loss - avg_valid_loss_sample)/ \ np.abs(self._min_validation_loss) # Update performance counter if relative_improvement > self._improvement_tolerance: # Reset performance counter (significant improvement) self._n_not_improve = 0 else: # Reset performance counter (not significant improvement) self._n_not_improve += 1 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Update minimum validation loss self._min_validation_loss = avg_valid_loss_sample # 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Save best performance state (minimum validation loss) self._best_model_state = copy.deepcopy(model.state_dict()) self._best_optimizer_state = \ dict(state=copy.deepcopy(optimizer.state_dict()), epoch=epoch) self._best_training_epoch = epoch else: # Increment performance counter self._n_not_improve += 1 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Evaluate early stopping criterion if self._n_not_improve >= self._trigger_tolerance: # Trigger early stopping is_stop_training = True # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ return is_stop_training # ------------------------------------------------------------------------- def _validate_model(self, model, optimizer, epoch, loss_nature='node_features_out', loss_type='mse', loss_kwargs={}, batch_size=1, device_type='cpu'): """Perform model validation. Parameters ---------- model : torch.nn.Module Graph Neural Network model. optimizer : torch.optim.Optimizer PyTorch optimizer. epoch : int Training epoch. loss_nature : {'node_features_out', 'global_features_out'}, \ default='node_features_out' Loss nature: 'node_features_out' : Based on node output features 'global_features_out' : Based on global output features loss_type : {'mse',}, default='mse' Loss function type: 'mse' : MSE (torch.nn.MSELoss) loss_kwargs : dict, default={} Arguments of torch.nn._Loss initializer. batch_size : int, default=1 Number of samples loaded per batch. device_type : {'cpu', 'cuda'}, default='cpu' Type of device on which torch.Tensor is allocated. Returns ------- avg_predict_loss : float Average prediction loss per sample. 
""" # Set material patch model state file name and path model_state_file = model.model_name + '-' + str(int(epoch)) # Set material patch model state file path model_state_path = \ os.path.join(model.model_directory, model_state_file + '.pt') # Set optimizer state file name and path optimizer_state_file = \ model.model_name + '_optim' + '-' + str(int(epoch)) optimizer_state_path = \ os.path.join(model.model_directory, optimizer_state_file + '.pt') # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Initialize temporary state files flag is_state_file_temp = False # Save model and optimizer state files (required for validation) if not os.path.isfile(model_state_path): # Update temporary state files flag is_state_file_temp = True # Save state files save_training_state(model=model, optimizer=optimizer, epoch=epoch) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Prediction with Graph Neural Network model _, avg_valid_loss_sample = predict( self._validation_dataset, model.model_directory, model=model, predict_directory=None, load_model_state=epoch, loss_nature=loss_nature, loss_type=loss_type, loss_kwargs=loss_kwargs, is_normalized_loss=model.is_model_out_normalized, batch_size=batch_size, device_type=device_type, seed=None, is_verbose=False) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Set model in training mode model.train() # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Update validation epochs history self._validation_steps_history.append(epoch) # Propagate last validation loss until current epoch history_length = len(self._validation_loss_history) history_gap = epoch - history_length if history_length > 0: self._validation_loss_history += \ history_gap*[self._validation_loss_history[-1],] # Append validation loss self._validation_loss_history.append(avg_valid_loss_sample) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Remove model and 
optimizer state files (required for validation) if is_state_file_temp: os.remove(model_state_path) os.remove(optimizer_state_path) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ return avg_valid_loss_sample # ------------------------------------------------------------------------- def load_best_performance_state(self, model, optimizer): """Load minimum validation loss model and optimizer states. Both model and optimizer are updated 'in-place' with stored state data. Parameters ---------- model : torch.nn.Module Graph Neural Network model. optimizer : torch.optim.Optimizer PyTorch optimizer. Returns ------- best_training_epoch : int Training epoch corresponding to the best performance. """ # Check best performance states if self._best_model_state is None: raise RuntimeError('The best performance model state has not been ' 'stored.') if self._best_optimizer_state is None: raise RuntimeError('The best performance optimization state has ' 'not been stored.') # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Load material patch model state model.load_state_dict(self._best_model_state) # Set loaded optimizer state optimizer.load_state_dict(self._best_optimizer_state['state']) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ return self._best_training_epoch