# Source code for model_architectures.procedures.model_state_files

"""Procedures to save and load model state files.

Functions
---------
save_model_state
    Save model state to file.
load_model_state
    Load model state from file.
check_state_file
    Check if file is model training epoch state file.
check_best_state_file
    Check if file is model best state file.
remove_posterior_state_files
    Delete model training epoch state files posterior to given epoch.
remove_best_state_files
    Delete existent model best state files.
"""
#
#                                                                       Modules
# =============================================================================
# Standard
import os
import re
# Third-party
import torch
#
#                                                          Authorship & Credits
# =============================================================================
__author__ = 'Bernardo Ferreira (bernardo_ferreira@brown.edu)'
__credits__ = ['Bernardo Ferreira', ]
__status__ = 'Stable'
# =============================================================================
#
# =============================================================================
def save_model_state(model, state_type='default', epoch=None,
                     is_remove_posterior=True):
    """Save model state to file.

    Model default state file is stored in model_directory under the name
    < model_name >.pt.

    Model initial state file is stored in model_directory under the name
    < model_name >-init.pt.

    Model state file corresponding to given training epoch is stored in
    model_directory under the name < model_name >.pt or
    < model_name >-< epoch >.pt if epoch is known.

    Model state file corresponding to the best performance is stored in
    model_directory under the name < model_name >-best.pt or
    < model_name >-< epoch >-best.pt if epoch is known.

    Parameters
    ----------
    model : torch.nn.Module
        Model. Must provide the attributes model_directory (existing
        directory) and model_name.
    state_type : {'default', 'init', 'epoch', 'best'}, default='default'
        Saved model state file type. Options:

        'default' : Model default state

        'init'    : Model initial state

        'epoch'   : Model state of given training epoch

        'best'    : Model state of best performance

    epoch : int, default=None
        Training epoch corresponding to current model state.
    is_remove_posterior : bool, default=True
        Remove model state files corresponding to training epochs posterior
        to the saved state file. Effective only if saved training epoch is
        known.

    Raises
    ------
    RuntimeError
        If model is not a torch.nn.Module or if the model directory is not
        available.
    """
    # Check model
    if not isinstance(model, torch.nn.Module):
        raise RuntimeError('Model is not a torch.nn.Module.')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Get model directory
    if not hasattr(model, 'model_directory'):
        raise RuntimeError('The model directory is not available.')
    elif not os.path.isdir(model.model_directory):
        raise RuntimeError('The model directory has not been found:\n\n'
                           + model.model_directory)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Initialize model state filename
    model_state_file = model.model_name
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Set model state type
    if state_type == 'init':
        # Set model state filename
        model_state_file += '-init'
    else:
        # Append epoch
        if isinstance(epoch, int):
            model_state_file += '-' + str(epoch)
        # Set model state corresponding to the best performance
        if state_type == 'best':
            model_state_file += '-' + 'best'
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Set model state file name and path
    model_state_file += '.pt'
    model_path = os.path.join(model.model_directory, model_state_file)
    # Save model state
    torch.save(model.state_dict(), model_path)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Remove any previously existent best model state file. This is done only
    # after the new state has been written, so that a failed save does not
    # leave the model directory without any best state file
    if state_type == 'best':
        for filename in os.listdir(model.model_directory):
            # Keep the best state file that has just been written
            if filename == model_state_file:
                continue
            # Check if file is model best state file
            is_best_state_file, _ = check_best_state_file(model, filename)
            # Delete outdated best state file
            if is_best_state_file:
                os.remove(os.path.join(model.model_directory, filename))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Delete model epoch state files posterior to saved epoch
    if isinstance(epoch, int) and is_remove_posterior:
        remove_posterior_state_files(model, epoch)
# =============================================================================
def load_model_state(model, model_load_state=None, is_remove_posterior=True):
    """Load model state from file.

    Model state file is stored in model_directory under the name
    < model_name >.pt or < model_name >-< epoch >.pt if epoch is known.

    Model state file corresponding to the best performance is stored in
    model_directory under the name < model_name >-best.pt or
    < model_name >-< epoch >-best.pt if epoch is known.

    Model initial state file is stored in model directory under the name
    < model_name >-init.pt

    Parameters
    ----------
    model : torch.nn.Module
        Model. Must provide the attributes model_directory (existing
        directory) and model_name.
    model_load_state : {'default', 'init', int, 'best', 'last'}, default=None
        Available model state to be loaded from the model directory.
        Any value other than the options 'init', int, 'best' and 'last'
        (including None and 'default') loads the model default state file.
        Options:

        'default' : Model default state file

        'init'    : Model initial state

        int       : Model state of given training epoch

        'best'    : Model state of best performance

        'last'    : Model state of latest training epoch

    is_remove_posterior : bool, default=True
        Remove model state files corresponding to training epochs posterior
        to the loaded state file. Effective only if loaded training epoch is
        known. Files are only removed after the model state has been
        successfully loaded.

    Returns
    -------
    epoch : int
        Loaded model state training epoch. Defaults to None if training
        epoch is unknown.

    Raises
    ------
    RuntimeError
        If model is not a torch.nn.Module, if the model directory is not
        available, or if the required model state file cannot be uniquely
        found.
    """
    # Check model
    if not isinstance(model, torch.nn.Module):
        raise RuntimeError('Model is not a torch.nn.Module.')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Get model directory
    if not hasattr(model, 'model_directory'):
        raise RuntimeError('The model directory is not available.')
    elif not os.path.isdir(model.model_directory):
        raise RuntimeError('The model directory has not been found:\n\n'
                           + model.model_directory)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Initialize model state filename
    model_state_file = model.model_name
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Set model state file
    if model_load_state == 'init':
        # Set model initial state file
        model_state_file += '-init'
        # Set epoch
        epoch = 0
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    elif isinstance(model_load_state, int):
        # Get epoch
        epoch = model_load_state
        # Set model state filename with epoch
        model_state_file += '-' + str(int(epoch))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    elif model_load_state == 'best':
        # Initialize model best state files epochs
        best_state_epochs = []
        # Loop over files in model directory
        for filename in os.listdir(model.model_directory):
            # Check if file is model epoch best state file
            is_best_state_file, best_state_epoch = \
                check_best_state_file(model, filename)
            # Store model best state file training epoch
            if is_best_state_file:
                best_state_epochs.append(best_state_epoch)
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Set model best state file (must be unique)
        if not best_state_epochs:
            raise RuntimeError('Model best state file has not been found '
                               'in directory:\n\n' + model.model_directory)
        elif len(best_state_epochs) > 1:
            raise RuntimeError('Two or more model best state files have '
                               'been found in directory:'
                               '\n\n' + model.model_directory)
        # Set best state epoch (None if unknown)
        epoch = best_state_epochs[0]
        # Set model best state file
        if isinstance(epoch, int):
            model_state_file += '-' + str(epoch)
        model_state_file += '-' + 'best'
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    elif model_load_state == 'last':
        # Initialize model state files training epochs
        epochs = []
        # Loop over files in model directory
        for filename in os.listdir(model.model_directory):
            # Check if file is model epoch state file
            is_state_file, file_epoch = check_state_file(model, filename)
            # Store model state file training epoch
            if is_state_file:
                epochs.append(file_epoch)
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Set highest epoch model state file
        if not epochs:
            raise RuntimeError('Model state files corresponding to epochs '
                               'have not been found in directory:\n\n'
                               + model.model_directory)
        # Set highest epoch
        epoch = max(epochs)
        # Set highest epoch model state file
        model_state_file += '-' + str(epoch)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    else:
        # Set epoch as unknown
        epoch = None
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Set model state file path
    model_path = os.path.join(model.model_directory, model_state_file + '.pt')
    # Check model state file
    if not os.path.isfile(model_path):
        raise RuntimeError('Model state file has not been found:\n\n'
                           + model_path)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Load model state
    model.load_state_dict(torch.load(model_path,
                                     map_location=torch.device('cpu')))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Delete model epoch state files posterior to loaded epoch. This is done
    # only after the state has been successfully loaded, so that a missing or
    # unreadable state file does not cause other state files to be deleted.
    # For the 'last' option there are no posterior files by construction
    if is_remove_posterior and isinstance(epoch, int):
        remove_posterior_state_files(model, epoch)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    return epoch
# =============================================================================
def check_state_file(model, filename):
    """Check if file is model training epoch state file.

    Model training epoch state file is stored in model_directory under the
    name < model_name >-< epoch >.pt.

    Parameters
    ----------
    model : torch.nn.Module
        Model. Only the attribute model_name is read.
    filename : str
        File name.

    Returns
    -------
    is_state_file : bool
        True if model training epoch state file, False otherwise.
    epoch : {None, int}
        Training epoch corresponding to model state file if
        is_state_file=True, None otherwise.
    """
    # Match the complete file name. The model name is escaped so that any
    # regular expression special characters in it are taken literally, and
    # the epoch is captured directly by the pattern
    match = re.fullmatch(re.escape(model.model_name) + r'-([0-9]+)\.pt',
                         filename)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if match is None:
        return False, None
    # Get model state epoch
    return True, int(match.group(1))
# =============================================================================
def check_best_state_file(model, filename):
    """Check if file is model best state file.

    Model state file corresponding to the best performance is stored in
    model_directory under the name < model_name >-best.pt or
    < model_name >-< epoch >-best.pt if the training epoch is known.

    Parameters
    ----------
    model : torch.nn.Module
        Model. Only the attribute model_name is read.
    filename : str
        File name.

    Returns
    -------
    is_best_state_file : bool
        True if model best state file, False otherwise.
    epoch : {None, int}
        Training epoch corresponding to model state file if
        is_best_state_file=True and training epoch is known, None otherwise.
    """
    # Match the complete file name with an optional '-< epoch >' component.
    # The model name is escaped so that any regular expression special
    # characters in it are taken literally
    match = re.fullmatch(re.escape(model.model_name)
                         + r'(?:-([0-9]+))?-best\.pt', filename)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if match is None:
        return False, None
    # Get model state epoch (None if file name does not include the epoch)
    epoch = None
    if match.group(1) is not None:
        epoch = int(match.group(1))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    return True, epoch
# =============================================================================
def remove_posterior_state_files(model, epoch):
    """Delete model training epoch state files posterior to given epoch.

    Only files named < model_name >-< epoch >.pt are considered. Best state
    files, the initial state file and the default state file are kept.

    Parameters
    ----------
    model : torch.nn.Module
        Model. The attributes model_directory and model_name are read.
    epoch : int
        Training epoch.
    """
    # Loop over files in model directory
    for filename in os.listdir(model.model_directory):
        # Check if file is model epoch state file
        is_state_file, file_epoch = check_state_file(model, filename)
        # Delete model epoch state file posterior to given epoch
        if is_state_file and file_epoch > epoch:
            os.remove(os.path.join(model.model_directory, filename))
# =============================================================================
def remove_best_state_files(model):
    """Delete existent model best state files.

    Parameters
    ----------
    model : torch.nn.Module
        Model. The attributes model_directory and model_name are read.
    """
    # Loop over files in model directory
    for filename in os.listdir(model.model_directory):
        # Check if file is model best state file
        is_best_state_file, _ = check_best_state_file(model, filename)
        # Delete state file
        if is_best_state_file:
            os.remove(os.path.join(model.model_directory, filename))