# Source code for model_architectures.rnn_base_model.custom.gru_vmap

"""Multi-layer gated recurrent unit (GRU) recurrent neural network model.

Custom implementation of GRUCell and GRU compatible with VMAPs.

Classes
-------
GRUCell(torch.nn.Module)
    Gated Recurrent Unit (GRU) cell.
GRU(torch.nn.Module)
    Multi-layer gated recurrent unit (GRU) recurrent neural network model.
"""
#
#                                                                       Modules
# =============================================================================
# Standard
import math
import time
# Third-party
import torch
#
#                                                          Authorship & Credits
# =============================================================================
__author__ = 'Bernardo Ferreira (bernardo_ferreira@brown.edu)'
__credits__ = ['Bernardo Ferreira', ]
__status__ = 'Stable'
# =============================================================================
#
# =============================================================================
class GRUCell(torch.nn.Module):
    """Gated Recurrent Unit (GRU) cell.

    Attributes
    ----------
    input_size : int
        Number of input features.
    hidden_size : int
        Number of output features.
    bias : bool
        If True, then consider bias weights as learnable parameters.
    num_chunks : int
        Number of gates.
    w_ih : torch.Tensor(2d)
        Input-to-hidden weights stored as torch.Tensor(2d) of shape
        (3*hidden_size, input_size).
    w_hh : torch.Tensor(2d)
        Hidden-to-hidden weights stored as torch.Tensor(2d) of shape
        (3*hidden_size, hidden_size).
    b_ih : torch.Tensor(1d)
        Input-to-hidden bias weights stored as torch.Tensor(1d) of shape
        (3*hidden_size,). Fixed (non-persistent buffer) zeros if bias=False.
    b_hh : torch.Tensor(1d)
        Hidden-to-hidden bias weights stored as torch.Tensor(1d) of shape
        (3*hidden_size,). Fixed (non-persistent buffer) zeros if bias=False.
    device : torch.device
        Device on which torch.Tensor is allocated.

    Methods
    -------
    forward(self, input, hx=None)
        Forward propagation.
    """
    def __init__(self, input_size, hidden_size, bias=True, device=None):
        """Constructor.

        Parameters
        ----------
        input_size : int
            Number of input features.
        hidden_size : int
            Number of output features.
        bias : bool, default=True
            If True, then consider bias weights.
        device : torch.device
            Device on which torch.Tensor is allocated.
        """
        # Initialize from base class
        super(GRUCell, self).__init__()
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Set number of features
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Set architecture parameters
        self.bias = bias
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Set device
        self.device = device
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Set number of chunks (reset, update and new gates)
        self.num_chunks = 3
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Set weights (input-to-hidden)
        self.w_ih = torch.nn.Parameter(
            torch.empty((self.num_chunks*hidden_size, input_size),
                        device=self.device))
        # Set weights (hidden-to-hidden)
        self.w_hh = torch.nn.Parameter(
            torch.empty((self.num_chunks*hidden_size, hidden_size),
                        device=self.device))
        # Set biases
        if self.bias:
            # Set bias (input-to-hidden)
            self.b_ih = torch.nn.Parameter(
                torch.empty(self.num_chunks*hidden_size, device=self.device))
            # Set bias (hidden-to-hidden)
            self.b_hh = torch.nn.Parameter(
                torch.empty(self.num_chunks*hidden_size, device=self.device))
        else:
            # Set fixed zero biases as non-persistent buffers so that they
            # follow the module on device/dtype moves (e.g., module.to(...))
            # while remaining non-learnable and excluded from the state
            # dictionary (same external behavior as plain zero tensors)
            self.register_buffer(
                'b_ih', torch.zeros(self.num_chunks*hidden_size,
                                    device=self.device), persistent=False)
            self.register_buffer(
                'b_hh', torch.zeros(self.num_chunks*hidden_size,
                                    device=self.device), persistent=False)
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Initialize parameters
        self.reset_parameters()
    # -------------------------------------------------------------------------
    def reset_parameters(self):
        """Initialize learnable parameters.

        Parameters are drawn from U(-stdv, stdv) with
        stdv = 1/sqrt(hidden_size), matching the PyTorch RNN convention.
        Non-learnable zero biases (bias=False) are buffers and are not
        touched here.
        """
        # Initialize uniform distribution bounds
        if self.hidden_size > 0:
            stdv = 1.0 / math.sqrt(self.hidden_size)
        else:
            stdv = 0
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Loop over learnable parameters
        for parameter in self.parameters():
            # Initialize from uniform distribution
            torch.nn.init.uniform_(parameter, -stdv, stdv)
    # -------------------------------------------------------------------------
    def forward(self, input, hx=None):
        """Forward propagation.

        Parameters
        ----------
        input : torch.Tensor
            Tensor of input features stored as torch.Tensor(1d) of shape
            (input_size) for unbatched input or torch.Tensor(2d) of shape
            (batch_size, input_size) for batched input.
        hx : torch.Tensor, default=None
            Tensor of initial hidden state features stored as torch.Tensor(1d)
            of shape (hidden_size) for unbatched input or torch.Tensor(2d) of
            shape (batch_size, hidden_size) for batched input. If None, then
            defaults to zero.

        Returns
        -------
        h_updated : torch.Tensor
            Tensor of updated hidden state features stored as torch.Tensor(1d)
            of shape (hidden_size) for unbatched input or torch.Tensor(2d) of
            shape (batch_size, hidden_size) for batched input.
        """
        # Check if batched input
        is_batched = input.dim() == 2
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Set batched dimension
        if not is_batched:
            input = input.unsqueeze(0)
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Initialize hidden state features (match input device and dtype)
        if hx is None:
            hx = torch.zeros((input.size(0), self.hidden_size),
                             dtype=input.dtype, device=input.device)
        else:
            # Set batched dimension
            if not is_batched:
                hx = hx.unsqueeze(0)
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Compute input tensor linear transformations
        gi = torch.mm(input, self.w_ih.t()) + self.b_ih
        # Compute hidden tensor linear transformations
        gh = torch.mm(hx, self.w_hh.t()) + self.b_hh
        # Split input and hidden tensor linear transformations
        i_r, i_i, i_n = torch.chunk(gi, chunks=self.num_chunks, dim=1)
        h_r, h_i, h_n = torch.chunk(gh, chunks=self.num_chunks, dim=1)
        # Compute update gate output
        # (shape: batch_size x hidden_size)
        update_gate = torch.sigmoid(i_i + h_i)
        # Compute reset gate output
        # (shape: batch_size x hidden_size)
        reset_gate = torch.sigmoid(i_r + h_r)
        # Compute new gate output (new candidate hidden tensor)
        # (shape: batch_size x hidden_size)
        new_gate = torch.tanh(i_n + reset_gate*h_n)
        # Compute updated hidden tensor, equivalent to
        # (1 - update_gate)*new_gate + update_gate*hx
        # (shape: batch_size x hidden_size)
        h_updated = update_gate*(hx - new_gate) + new_gate
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Remove batched dimension
        if not is_batched:
            h_updated = h_updated.squeeze(0)
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        return h_updated
# =============================================================================
class GRU(torch.nn.Module):
    """Multi-layer gated recurrent unit (GRU) recurrent neural network model.

    It is assumed that batched input and output tensors are provided with
    shape (sequential_length, batch_size, n_features), corresponding to the
    default setting batch_first=False.

    Dropout layers and bidirectionality are not implemented.

    Attributes
    ----------
    input_size : int
        Number of input features.
    hidden_size : int
        Number of output features.
    bias : bool
        If True, then consider bias weights as learnable parameters.
    num_layers : int, default=1
        Number of recurrent layers. A number of recurrent layers greater
        than 1 results in a stacked GRU (output of GRU in each time t is
        the input of next GRU).
    device : torch.device
        Device on which torch.Tensor is allocated.

    Methods
    -------
    forward(self, input, hx=None)
        Forward propagation.
    """
    def __init__(self, input_size, hidden_size, num_layers=1, bias=True,
                 device=None):
        """Constructor.

        Parameters
        ----------
        input_size : int
            Number of input features.
        hidden_size : int
            Number of output features.
        num_layers : int, default=1
            Number of recurrent layers. A number of recurrent layers greater
            than 1 results in a stacked GRU (output of GRU in each time t is
            the input of next GRU).
        bias : bool, default=True
            If True, then consider bias weights.
        device : torch.device
            Device on which torch.Tensor is allocated.
        """
        # Initialize from base class
        super(GRU, self).__init__()
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Set number of features
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Set architecture parameters
        self.bias = bias
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Set number of layers
        self.num_layers = num_layers
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Set device
        self.device = device
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Initialize GRU layers
        self.gru_layers = torch.nn.ModuleList()
        # Set initial GRU layer (input to hidden)
        self.gru_layers.append(GRUCell(self.input_size, self.hidden_size,
                                       bias=self.bias, device=self.device))
        # Set remaining GRU layers (hidden to hidden)
        for _ in range(1, num_layers):
            self.gru_layers.append(GRUCell(self.hidden_size, self.hidden_size,
                                           bias=self.bias,
                                           device=self.device))
    # -------------------------------------------------------------------------
    def forward(self, input, hx=None):
        """Forward propagation.

        Parameters
        ----------
        input : torch.Tensor
            Tensor of input features stored as torch.Tensor(2d) of shape
            (sequence_length, input_size) for unbatched input or
            torch.Tensor(3d) of shape (sequence_length, batch_size,
            input_size) for batched input.
        hx : torch.Tensor, default=None
            Tensor of initial hidden state features stored as torch.Tensor(2d)
            of shape (num_layers, hidden_size) for unbatched input or
            torch.Tensor(3d) of shape (num_layers, batch_size, hidden_size)
            for batched input. If None, then defaults to zero.

        Returns
        -------
        output : torch.Tensor
            Tensor of output features stored as torch.Tensor(2d) of shape
            (sequence_length, hidden_size) for unbatched input or
            torch.Tensor(3d) of shape (sequence_length, batch_size,
            hidden_size) for batched input. Corresponds to the tensor of
            hidden state features output from the last GRU layer.
        h_n : torch.Tensor
            Tensor of final multi-layer GRU hidden state features stored as
            torch.Tensor(2d) of shape (num_layers, hidden_size) for unbatched
            input or torch.Tensor(3d) of shape (num_layers, batch_size,
            hidden_size) for batched input.
        """
        # Get sequence length
        n_time = input.size(0)
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Set batching dimension
        batch_dim = 1
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Check if batched input
        is_batched = input.dim() == 3
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Set batched dimension
        if not is_batched:
            input = input.unsqueeze(batch_dim)
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Get batch size
        batch_size = input.size(batch_dim)
        # Initialize hidden state features
        if hx is None:
            hx = torch.zeros((self.num_layers, batch_size, self.hidden_size),
                             device=input.device)
        else:
            # Set batched dimension
            if not is_batched:
                hx = hx.unsqueeze(1)
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Initialize GRU layers (current) hidden state
        h_layers = [hx[l, :, :] for l in range(self.num_layers)]
        # Initialize output features
        output_times = []
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Loop over time steps
        for t in range(n_time):
            # Loop over GRU layers
            for l in range(self.num_layers):
                # Get GRU layer input for current time step
                if l == 0:
                    layer_input = input[t, :, :]
                else:
                    layer_input = h_layers[l - 1]
                # Compute GRU layer updated hidden state for current time step
                h_updated = self.gru_layers[l](layer_input, h_layers[l])
                # Update GRU layer (current) hidden state
                h_layers[l] = h_updated
            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # Store last GRU layer output for current time step
            output_times.append(h_updated)
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Build tensor of output features
        output = torch.stack(output_times, dim=0)
        # Build tensor of final multi-layer GRU hidden state features
        h_n = torch.stack(h_layers, dim=0)
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Remove batched dimension
        if not is_batched:
            output = output.squeeze(1)
            h_n = h_n.squeeze(1)
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        return output, h_n
# =============================================================================
if __name__ == '__main__':
    # Set function timer
    def function_timer(function, args, n_calls=1):
        """Return average execution time per call of function."""
        # Initialize total execution time
        total_time = 0
        # Loop over number of function calls
        for _ in range(n_calls):
            # Set initial call time
            t0 = time.time()
            # Call function
            function(*args)
            # Add to total execution time
            total_time += time.time() - t0
        # Compute average time per function call
        avg_time_call = total_time/n_calls
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        return avg_time_call
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Set backward timer
    def backward_timer(scalar):
        """Return execution time of backward propagation from scalar."""
        # Initialize total execution time
        t0 = time.time()
        # Backward propagation
        scalar.backward()
        # Set total execution time
        total_time = time.time() - t0
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        return total_time
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Set seed for reproducibility
    def set_seed(seed):
        """Seed all random number generators for reproducibility."""
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Set random number generators initialization for reproducibility
    seed = 0
    set_seed(seed)
    # Set device (fall back to CPU when CUDA is not available so that the
    # script also runs on CPU-only machines)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Set batch testing
    is_batched = True
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Set parameter option
    option = ('simple', 'complex')[0]
    # Set parameters
    if option == 'complex':
        # Set number of features
        input_size = 6
        hidden_size = 500
        # Set sequence length
        sequence_length = 100
        # Set batch size (number of time series)
        batch_size = 1
        # Set number of GRU layers
        num_layers = 3
    else:
        # Set number of features
        input_size = 6
        hidden_size = 5
        # Set sequence length
        sequence_length = 10
        # Set batch size (number of time series)
        batch_size = 4
        # Set number of GRU layers
        num_layers = 3
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Set random input time series data
    input = torch.randn((sequence_length, batch_size, input_size),
                        device=device)
    # Set random initial hidden state
    hx = torch.randn((num_layers, batch_size, hidden_size), device=device)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Display
    print(f'\nTesting: Torch GRU Cell vs Custom GRU Cell' + '\n'
          + '-'*len('\nTesting: Torch GRU Cell vs Custom GRU Cell'))
    print(f'\n > Device: {device}')
    print(f'\n > Batched input/output: {is_batched}')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Initialize PyTorch GRU Cell
    set_seed(seed)
    gru_cell_torch = torch.nn.GRUCell(input_size, hidden_size, bias=True,
                                      device=device)
    # Initialize Custom GRU Cell
    set_seed(seed)
    gru_cell_custom = GRUCell(input_size, hidden_size, bias=True,
                              device=device)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Set batched/unbatched data
    if is_batched:
        input_cell = input[0, :, :].clone()
        hx_cell = hx[0, :, :].clone()
    else:
        input_cell = input[0, 0, :].clone()
        hx_cell = hx[0, 0, :].clone()
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Forward propagation (updated hidden state)
    h_updated_torch = gru_cell_torch(input_cell, hx_cell)
    # Forward propagation (updated hidden state)
    h_updated_custom = gru_cell_custom(input_cell, hx_cell)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Compare results
    if not torch.allclose(h_updated_torch, h_updated_custom):
        print(f'\n > Matching results? FALSE')
        raise RuntimeError('Torch and Custom GRU Cells results do not match!')
    else:
        print(f'\n > Matching results? TRUE')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Check backward propagation time
    backward_time_torch = backward_timer(torch.sum(h_updated_torch))
    backward_time_custom = backward_timer(torch.sum(h_updated_custom))
    backward_time_ratio = backward_time_custom/backward_time_torch
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Check average time per call
    avg_time_call_torch = function_timer(gru_cell_torch,
                                         (input_cell, hx_cell), n_calls=1000)
    avg_time_call_custom = function_timer(gru_cell_custom,
                                          (input_cell, hx_cell), n_calls=1000)
    avg_time_call_ratio = avg_time_call_custom/avg_time_call_torch
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Compare computational times
    print(f'\n > Forward propagation times:')
    print(f'\n   - Avg. time per call (Torch): {avg_time_call_torch:.4e}')
    print(f'\n   - Avg. time per call (Custom): {avg_time_call_custom:.4e}')
    print(f'\n   - Custom/Torch = {avg_time_call_ratio:.2f} ')
    print(f'\n\n > Backward propagation times:')
    print(f'\n   - Time (Torch): {backward_time_torch:.4e}')
    print(f'\n   - Time (Custom): {backward_time_custom:.4e}')
    print(f'\n   - Custom/Torch = {backward_time_ratio:.2f}')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Display
    print(f'\n\nTesting: Torch GRU vs Custom GRU' + '\n'
          + '-'*len('\nTesting: Torch GRU vs Custom GRU'))
    print(f'\n > Device: {device}')
    print(f'\n > Batched input/output: {is_batched}')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Initialize PyTorch GRU
    set_seed(seed)
    gru_torch = torch.nn.GRU(input_size, hidden_size, num_layers=num_layers,
                             bias=True, batch_first=False, device=device)
    # Initialize Custom GRU
    set_seed(seed)
    gru_custom = GRU(input_size, hidden_size, num_layers=num_layers,
                     bias=True, device=device)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Set batched/unbatched data
    if is_batched:
        pass
    else:
        input = input[:, 0, :].clone()
        hx = hx[:, 0, :].clone()
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Forward propagation
    output_torch, h_n_torch = gru_torch(input, hx)
    # Forward propagation
    output_custom, h_n_custom = gru_custom(input, hx)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Notes:
    # The absolute difference in the following comparison seems to increase
    # as the number of sequential computation increases (increasing number of
    # layers or sequence length). Suspect that this may result from the
    # accumulation of error because (1) matching results are obtained for
    # GRU Cell irrespective of parameters and (2) matching results are
    # obtained for GRU with parameters leading to lower number of sequential
    # operations (layers or time).
    #
    # Example of GRU parameters leading to matching results:
    # input_size = 6
    # hidden_size = 5
    # sequence_length = 10
    # batch_size = 4
    # num_layers = 3
    #
    # Set raise error when mismatching results
    is_raise_error = False
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Compare results
    if not torch.allclose(output_torch, output_custom, atol=1e-08):
        print(f'\n > Output matching results? FALSE')
        if is_raise_error:
            raise RuntimeError('Torch and Custom GRUs output results do not '
                               'match!')
    elif not torch.allclose(h_n_torch, h_n_custom, atol=1e-08):
        print(f'\n > Layers hidden states matching results? FALSE')
        if is_raise_error:
            raise RuntimeError('Torch and Custom GRUs layers hidden states '
                               'results do not match!')
    else:
        print(f'\n > Matching results? TRUE')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Check backward propagation time
    backward_time_torch = backward_timer(torch.sum(output_torch))
    backward_time_custom = backward_timer(torch.sum(output_custom))
    backward_time_ratio = backward_time_custom/backward_time_torch
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Check average time per call
    avg_time_call_torch = function_timer(gru_torch, (input, hx), n_calls=100)
    avg_time_call_custom = function_timer(gru_custom, (input, hx), n_calls=100)
    avg_time_call_ratio = avg_time_call_custom/avg_time_call_torch
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Compare computational times
    print(f'\n > Forward propagation times:')
    print(f'\n   - Avg. time per call (Torch): {avg_time_call_torch:.4e}')
    print(f'\n   - Avg. time per call (Custom): {avg_time_call_custom:.4e}')
    print(f'\n   - Custom/Torch = {avg_time_call_ratio:.2f}')
    print(f'\n\n > Backward propagation times:')
    print(f'\n   - Time (Torch): {backward_time_torch:.4e}')
    print(f'\n   - Time (Custom): {backward_time_custom:.4e}')
    print(f'\n   - Custom/Torch = {backward_time_ratio:.2f}')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Test vectorizing maps (only for batched input/output)
    if is_batched:
        # Display
        print(f'\n\nTesting: Custom GRU Cell vs Custom GRU Cell VMAP' + '\n'
              + '-'*len('\nTesting: Custom GRU Cell vs Custom GRU Cell VMAP'))
        print(f'\n > Device: {device}')
        print(f'\n > Batched input/output: {is_batched}')
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Set vectorized GRU cell (batch along time)
        vmap_gru_cell_custom = \
            torch.vmap(gru_cell_custom, in_dims=(0, 0), out_dims=(0,))
        # Forward propagation (updated hidden state)
        h_updated_custom_vmap = vmap_gru_cell_custom(input_cell, hx_cell)
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Compare results
        if not torch.allclose(h_updated_custom, h_updated_custom_vmap):
            print(f'\n > Matching results? FALSE')
            raise RuntimeError('Custom and Custom VMAP GRU Cells results do '
                               'not match!')
        else:
            print(f'\n > Matching results? TRUE')
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Backward propagation
        torch.sum(h_updated_custom_vmap).backward()
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Display
        print(f'\n\nTesting: Custom GRU vs Custom GRU VMAP' + '\n'
              + '-'*len('\nTesting: Custom GRU vs Custom GRU VMAP'))
        print(f'\n > Device: {device}')
        print(f'\n > Batched input/output: {is_batched}')
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Set vectorized GRU (batch along time series)
        vmap_gru_custom = \
            torch.vmap(gru_custom, in_dims=(1, 1), out_dims=(1, 1))
        # Forward propagation (updated hidden state)
        output_custom_vmap, h_n_custom_vmap = vmap_gru_custom(input, hx)
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Compare results
        if not torch.allclose(output_custom, output_custom_vmap):
            print(f'\n > Output matching results? FALSE')
            if is_raise_error:
                raise RuntimeError('Torch and Custom GRUs output results do '
                                   'not match!')
        elif not torch.allclose(h_n_custom, h_n_custom_vmap):
            print(f'\n > Layers hidden states matching results? FALSE')
            if is_raise_error:
                raise RuntimeError('Torch and Custom GRUs layers hidden '
                                   'states results do not match!')
        else:
            print(f'\n > Matching results? TRUE')
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    print()