# Source code for utilities.optimizers

"""PyTorch-based optimizers.

Functions
---------
get_pytorch_optimizer
    Get PyTorch optimizer.
get_learning_rate_scheduler
    Get PyTorch optimizer learning rate scheduler.
"""
#
#                                                                       Modules
# =============================================================================
# Third-party
import torch
#
#                                                          Authorship & Credits
# =============================================================================
# Module authorship, contributor credits and development status metadata
__author__ = 'Bernardo Ferreira (bernardo_ferreira@brown.edu)'
__credits__ = ['Bernardo Ferreira']
__status__ = 'Stable'
# =============================================================================
#
# =============================================================================
def get_pytorch_optimizer(algorithm, params, **kwargs):
    """Get PyTorch optimizer.

    Parameters
    ----------
    algorithm : {'adam',}
        Name of the optimization algorithm. Supported values:

        'adam' : Adam (torch.optim.Adam)

    params : list
        Parameters (torch.Tensor) to be optimized, or a list of dicts where
        each dict defines a parameter group.
    **kwargs
        Keyword arguments forwarded unchanged to the optimizer initializer.

    Returns
    -------
    optimizer : torch.optim.Optimizer
        PyTorch optimizer.

    Raises
    ------
    RuntimeError
        If `algorithm` is not one of the supported optimization algorithms.
    """
    # Guard clause: Adam is the only algorithm currently supported
    if algorithm != 'adam':
        raise RuntimeError('Unknown or unavailable PyTorch optimizer.')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    return torch.optim.Adam(params, **kwargs)
# =============================================================================
def get_learning_rate_scheduler(optimizer, scheduler_type, **kwargs):
    """Get PyTorch optimizer learning rate scheduler.

    Parameters
    ----------
    optimizer : torch.optim.Optimizer
        PyTorch optimizer.
    scheduler_type : {'steplr', 'explr', 'linlr'}
        Type of learning rate scheduler:

        'steplr' : Step-based decay (torch.optim.lr_scheduler.StepLR)

        'explr' : Exponential decay (torch.optim.lr_scheduler.ExponentialLR)

        'linlr' : Linear decay (torch.optim.lr_scheduler.LinearLR)

    **kwargs
        Arguments of the selected torch.optim.lr_scheduler initializer.
        'step_size' is mandatory for 'steplr' and 'gamma' is mandatory for
        'explr'.

    Returns
    -------
    scheduler : torch.optim.lr_scheduler.LRScheduler
        PyTorch optimizer learning rate scheduler.

    Raises
    ------
    RuntimeError
        If a mandatory parameter of the selected scheduler is missing from
        `kwargs`, or if `scheduler_type` is unknown.
    """
    if scheduler_type == 'steplr':
        # StepLR has no default for 'step_size': fail early with a clear
        # message instead of a TypeError from the scheduler initializer
        if 'step_size' not in kwargs:
            raise RuntimeError('The parameter \'step_size\' needs to be '
                               'provided to initialize step-based decay '
                               'learning rate scheduler.')
        # Initialize scheduler
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, **kwargs)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    elif scheduler_type == 'explr':
        # ExponentialLR has no default for 'gamma': fail early with a clear
        # message instead of a TypeError from the scheduler initializer
        if 'gamma' not in kwargs:
            raise RuntimeError('The parameter \'gamma\' needs to be '
                               'provided to initialize exponential decay '
                               'learning rate scheduler.')
        # Initialize scheduler
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                           **kwargs)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    elif scheduler_type == 'linlr':
        # LinearLR defaults all of its decay parameters, so nothing is checked
        scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, **kwargs)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    else:
        raise RuntimeError('Unknown or unavailable PyTorch optimizer '
                           'learning rate scheduler.')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    return scheduler