Source code for model_architectures.hybrid_base_model.model.hybridization

"""Hybridization model.

Classes
-------
HybridizationModel(torch.nn.Module)
    Hybridization model.
HMIdentity(torch.nn.Module)
    Hybridization model: Identity.
HMAdditive(torch.nn.Module)
    Hybridization model: Additive.
"""
#
#                                                                       Modules
# =============================================================================
# Third-party
import torch
#
#                                                          Authorship & Credits
# =============================================================================
__author__ = 'Bernardo Ferreira (bernardo_ferreira@brown.edu)'
__credits__ = ['Bernardo Ferreira', ]
__status__ = 'Stable'
# =============================================================================
#
# =============================================================================
class HybridizationModel(torch.nn.Module):
    """Hybridization model.

    Thin wrapper that instantiates the hybridization model matching the
    requested type and delegates forward propagation to it.

    Attributes
    ----------
    _hybridization_type : str
        Hybridization model type.
    hybridization_model : torch.nn.Module
        Hybridization model.

    Methods
    -------
    forward(self, list_features_in)
        Forward propagation.
    """
    def __init__(self, hybridization_type='identity'):
        """Constructor.

        Parameters
        ----------
        hybridization_type : str, default='identity'
            Hybridization model type. Available types are 'identity'
            (HMIdentity) and 'additive' (HMAdditive).

        Raises
        ------
        RuntimeError
            If the hybridization model type is unknown.
        """
        # Initialize from base class
        super(HybridizationModel, self).__init__()
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Set hybridization model type
        self._hybridization_type = hybridization_type
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Initialize hybridization model
        if hybridization_type == 'identity':
            self.hybridization_model = HMIdentity()
        elif hybridization_type == 'additive':
            self.hybridization_model = HMAdditive()
        else:
            # Keep the original message as prefix, then report the offending
            # value and the valid options to ease debugging
            raise RuntimeError(
                'Unknown hybridization type. '
                f'Got {hybridization_type!r}; available types are '
                "'identity' and 'additive'.")
    # -------------------------------------------------------------------------
    def forward(self, list_features_in):
        """Forward propagation.

        Parameters
        ----------
        list_features_in : list[torch.Tensor]
            List of similar shaped tensors of input features stored as
            torch.Tensor(2d) of shape (sequence_length, n_features_in) for
            unbatched input or torch.Tensor(3d) of shape
            (sequence_length, batch_size, n_features_in) for batched input.

        Returns
        -------
        features_out : torch.Tensor
            Tensor of output features stored as torch.Tensor(2d) of shape
            (sequence_length, n_features_in) for unbatched input or
            torch.Tensor(3d) of shape
            (sequence_length, batch_size, n_features_in) for batched input.
        """
        # Forward propagation (delegated to the selected hybridization model)
        features_out = self.hybridization_model(list_features_in)
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        return features_out
# =============================================================================
class HMIdentity(torch.nn.Module):
    """Hybridization model: Identity.

    Forwards the first input tensor untouched and ignores every other input
    tensor.

    Methods
    -------
    forward(self, list_features_in)
        Forward propagation.
    """
    def __init__(self):
        """Constructor."""
        # Set up the underlying torch module
        super().__init__()
    # -------------------------------------------------------------------------
    def forward(self, list_features_in):
        """Forward propagation.

        Parameters
        ----------
        list_features_in : list[torch.Tensor]
            List of similar shaped tensors of input features stored as
            torch.Tensor(2d) of shape (sequence_length, n_features_in) for
            unbatched input or torch.Tensor(3d) of shape
            (sequence_length, batch_size, n_features_in) for batched input.

        Returns
        -------
        features_out : torch.Tensor
            First tensor of the input list, returned as is (same object,
            no copy).
        """
        # Only the leading tensor is propagated
        return list_features_in[0]
# =============================================================================
class HMAdditive(torch.nn.Module):
    """Hybridization model: Additive.

    Sums all input tensors.

    Methods
    -------
    forward(self, list_features_in)
        Forward propagation.
    """
    def __init__(self):
        """Constructor."""
        # Set up the underlying torch module
        super().__init__()
    # -------------------------------------------------------------------------
    def forward(self, list_features_in):
        """Forward propagation.

        Parameters
        ----------
        list_features_in : list[torch.Tensor]
            List of similar shaped tensors of input features stored as
            torch.Tensor(2d) of shape (sequence_length, n_features_in) for
            unbatched input or torch.Tensor(3d) of shape
            (sequence_length, batch_size, n_features_in) for batched input.

        Returns
        -------
        features_out : torch.Tensor
            Sum of all input tensors. The accumulation starts from the
            integer 0, which is therefore the value returned for an empty
            list.
        """
        # Accumulate the input tensors one by one, starting from 0
        accumulated = 0
        for features in list_features_in:
            accumulated = accumulated + features
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        return accumulated