Source code for model_architectures.procedures.model_summary

"""Get PyTorch model summary data.
 
Functions
---------
get_n_model_parameters
    Get number of parameters of PyTorch model.
get_model_summary
    Get summary of PyTorch model.
"""
#
#                                                                       Modules
# =============================================================================
# Third-party
import torch
import torchinfo
#
#                                                          Authorship & Credits
# =============================================================================
__author__ = 'Bernardo Ferreira (bernardo_ferreira@brown.edu)'
__credits__ = ['Bernardo Ferreira', ]
__status__ = 'Stable'
# =============================================================================
#
# =============================================================================
def get_n_model_parameters(model):
    """Count total, trainable and non-trainable parameters of PyTorch model.

    Every tensor yielded by ``model.parameters()`` contributes its number of
    elements to the total; those with ``requires_grad`` set contribute to the
    trainable count as well. The non-trainable count is the difference.

    Parameters
    ----------
    model : torch.nn.Module
        PyTorch model.

    Returns
    -------
    n_total_params : int
        Total number of model parameters.
    n_train_params : int
        Number of trainable model parameters.
    n_nontrain_params : int
        Number of non-trainable model parameters.

    Raises
    ------
    RuntimeError
        If model is not a torch.nn.Module.
    """
    # Check model
    if not isinstance(model, torch.nn.Module):
        raise RuntimeError('Model is not a torch.nn.Module.')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Initialize parameters counters
    n_total_params = 0
    n_train_params = 0
    # Loop over model parameters and accumulate number of elements
    for param in model.parameters():
        n_elements = param.numel()
        n_total_params += n_elements
        # Trainable parameters are those tracked by autograd
        if param.requires_grad:
            n_train_params += n_elements
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Get model number of non-trainable parameters
    n_nontrain_params = n_total_params - n_train_params
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    return n_total_params, n_train_params, n_nontrain_params
# =============================================================================
def get_model_summary(model, input_data=None, device_type='cpu',
                      is_verbose=False, **kwargs):
    """Get summary of PyTorch model.

    Wrapper: torchinfo (https://pypi.org/project/torchinfo/)

    Parameters
    ----------
    model : torch.nn.Module
        PyTorch model.
    input_data : list[torch.Tensor], default=None
        Input data of PyTorch model forward propagation. If provided, then
        further summary data is computed and displayed (e.g., input/output
        shapes, number of operations, memory requirements). Must be list of
        one or more torch.Tensor to avoid unexpected behavior.
    device_type : {'cpu', 'cuda'}, default='cpu'
        Type of device on which torch.Tensor is allocated.
    is_verbose : bool, default=False
        If True, enable verbose output.
    kwargs : dict
        Other arguments of PyTorch model forward propagation. These are
        forwarded as keyword arguments to torchinfo.summary().

    Returns
    -------
    model_statistics : torchinfo.model_statistics.ModelStatistics
        PyTorch model summary object.

    Raises
    ------
    RuntimeError
        If model is not a torch.nn.Module, if input data is provided but is
        not a list of torch.Tensor, if the device type is not 'cpu' or
        'cuda', or if 'cuda' is requested but CUDA is not available.
    """
    # Check model (same check and message as get_n_model_parameters)
    if not isinstance(model, torch.nn.Module):
        raise RuntimeError('Model is not a torch.nn.Module.')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Check PyTorch model input data
    if input_data is not None:
        is_tensor_list = (
            isinstance(input_data, list)
            and all(isinstance(x, torch.Tensor) for x in input_data))
        if not is_tensor_list:
            raise RuntimeError('If provided, input data of PyTorch model must '
                               'be list of torch.Tensor.')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Check device
    if device_type not in ('cpu', 'cuda'):
        raise RuntimeError('Invalid device type.')
    if device_type == 'cuda' and not torch.cuda.is_available():
        raise RuntimeError('PyTorch with CUDA is not available.')
    device = torch.device(device_type)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Set model forward propagation mode
    mode = 'eval'
    # Set maximum nesting model depth to display
    depth = 5
    # Set torchinfo verbosity level from verbose flag
    verbose = 1 if is_verbose else 0
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if is_verbose:
        print('\ntorchinfo - Model summary in PyTorch'
              '\n------------------------------------')
        print('\n> Model summary:\n')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Generate PyTorch model summary
    model_statistics = torchinfo.summary(model, input_data=input_data,
                                         depth=depth, device=device,
                                         mode=mode, verbose=verbose, **kwargs)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if is_verbose:
        print()
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    return model_statistics