"""PyTorch-based optimizers.
Functions
---------
get_pytorch_optimizer
Get PyTorch optimizer.
get_learning_rate_scheduler
Get PyTorch optimizer learning rate scheduler.
"""
#
# Modules
# =============================================================================
# Third-party
import torch
#
# Authorship & Credits
# =============================================================================
__author__ = 'Bernardo Ferreira (bernardo_ferreira@brown.edu)'
__credits__ = ['Bernardo Ferreira', ]
__status__ = 'Stable'
# =============================================================================
#
# =============================================================================
[docs]def get_pytorch_optimizer(algorithm, params, **kwargs):
"""Get PyTorch optimizer.
Parameters
----------
algorithm : {'adam',}
Optimization algorithm:
'adam' : Adam (torch.optim.Adam)
params : list
List of parameters (torch.Tensors) to optimize or list of dicts
defining parameter groups.
**kwargs
Arguments of torch.optim.Optimizer initializer.
Returns
-------
optimizer : torch.optim.Optimizer
PyTorch optimizer.
"""
if algorithm == 'adam':
optimizer = torch.optim.Adam(params, **kwargs)
else:
raise RuntimeError('Unknown or unavailable PyTorch optimizer.')
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
return optimizer
# =============================================================================
[docs]def get_learning_rate_scheduler(optimizer, scheduler_type, **kwargs):
"""Get PyTorch optimizer learning rate scheduler.
Parameters
----------
optimizer : torch.optim.Optimizer
PyTorch optimizer.
lr_scheduler_type : {'steplr', 'explr', 'linlr'}
Type of learning rate scheduler:
'steplr' : Step-based decay (torch.optim.lr_scheduler.SetpLR)
'explr' : Exponential decay (torch.optim.lr_scheduler.ExponentialLR)
'linlr' : Linear decay (torch.optim.lr_scheduler.LinearLR)
**kwargs
Arguments of torch.optim.lr_scheduler.LRScheduler initializer.
Returns
-------
scheduler : torch.optim.lr_scheduler.LRScheduler
PyTorch optimizer learning rate scheduler.
"""
if scheduler_type == 'steplr':
# Check scheduler mandatory parameters
if 'step_size' not in kwargs.keys():
raise RuntimeError('The parameter \'step_size\' needs to be '
'provided to initialize step-based decay '
'learning rate scheduler.')
# Initialize scheduler
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, **kwargs)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
elif scheduler_type == 'explr':
# Check scheduler mandatory parameters
if 'gamma' not in kwargs.keys():
raise RuntimeError('The parameter \'gamma\' needs to be '
'provided to initialize exponential decay '
'learning rate scheduler.')
# Initialize scheduler
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, **kwargs)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
elif scheduler_type == 'linlr':
# Initialize scheduler
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, **kwargs)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
else:
raise RuntimeError('Unknown or unavailable PyTorch optimizer '
'learning rate scheduler.')
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
return scheduler