# Source code for hookeai.utilities.type_conversion

"""Torch data types enforcement and conversion.

Functions
---------
convert_dict_to_tensor
    Convert all int, float and bool in dictionary to torch.Tensor.
convert_tensor_to_float32
    Convert floating point torch tensor to torch.float32.
convert_dict_to_float32
    Convert all floating point torch tensors in dictionary to torch.float32.
convert_tensor_to_float64
    Convert floating point torch tensor to torch.float64.
convert_dict_to_float64
    Convert all floating point torch tensors in dictionary to torch.float64.
"""
#
#                                                                       Modules
# =============================================================================
# Third-party
import torch
#
#                                                          Authorship & Credits
# =============================================================================
__author__ = "Bernardo Ferreira (bernardo_ferreira@brown.edu)"
__credits__ = ["Bernardo Ferreira"]
__status__ = "Stable"
# =============================================================================
#
# =============================================================================
def convert_dict_to_tensor(data_dict, is_inplace=True):
    """Convert all int, float and bool in dictionary to torch.Tensor.

    Torch default types are assumed for each variable input type. Torch
    tensors and non-listed types are kept unchanged. Nested dictionaries are
    processed recursively.

    Parameters
    ----------
    data_dict : dict
        Dictionary.
    is_inplace : bool, default=True
        If True, then input dictionary (and any nested dictionary) is updated
        in-place and returned. If False, then a new dictionary (with new
        nested dictionaries) is returned and the input dictionary is left
        unchanged; values that are not converted are shared with the input
        dictionary (not copied).

    Returns
    -------
    data_dict : dict
        Dictionary.
    """
    # Select output dictionary: input dictionary (in-place) or new dictionary
    if is_inplace:
        converted_dict = data_dict
    else:
        converted_dict = {}
    # Loop over dictionary items (in-place assignments only overwrite
    # existing keys, so the dictionary size never changes during iteration)
    for key, value in data_dict.items():
        # Perform type conversion
        if isinstance(value, dict):
            # Process nested dictionary recursively, propagating the mode
            converted_dict[key] = \
                convert_dict_to_tensor(value, is_inplace=is_inplace)
        elif isinstance(value, (int, float, bool)):
            # Convert Python scalar to torch tensor (torch default type)
            converted_dict[key] = torch.tensor(value)
        else:
            # Keep value unchanged (torch tensors and non-listed types)
            converted_dict[key] = value
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    return converted_dict
# =============================================================================
def convert_tensor_to_float32(tensor):
    """Convert floating point torch tensor to torch.float32.

    Inputs that are not floating point torch tensors (including non-tensor
    objects) are returned as is. A tensor that is already torch.float32 is
    also returned unchanged.

    Parameters
    ----------
    tensor : torch.Tensor
        Tensor.

    Returns
    -------
    tensor : torch.Tensor
        Tensor.
    """
    # Guard: anything other than a floating point tensor passes through
    if not isinstance(tensor, torch.Tensor) \
            or not torch.is_floating_point(tensor):
        return tensor
    # Cast floating point tensor to single precision
    return tensor.to(torch.float32)
# =============================================================================
def convert_dict_to_float32(tensor_dict, is_inplace=True):
    """Convert all floating point torch tensors in dictionary to torch.float32.

    Torch tensors with type torch.float32 or other non-float types are kept
    unchanged. Nested dictionaries are processed recursively.

    Parameters
    ----------
    tensor_dict : dict
        Dictionary.
    is_inplace : bool, default=True
        If True, then input dictionary (and any nested dictionary) is updated
        in-place and returned. If False, then a new dictionary (with new
        nested dictionaries) is returned and the input dictionary is left
        unchanged; values that are not converted are shared with the input
        dictionary (not copied).

    Returns
    -------
    tensor_dict : dict
        Dictionary.
    """
    # Select output dictionary: input dictionary (in-place) or new dictionary
    if is_inplace:
        converted_dict = tensor_dict
    else:
        converted_dict = {}
    # Loop over dictionary items (in-place assignments only overwrite
    # existing keys, so the dictionary size never changes during iteration)
    for key, value in tensor_dict.items():
        # Perform type conversion
        if isinstance(value, dict):
            # Process nested dictionary recursively, propagating the mode so
            # that nested dictionaries are not mutated when is_inplace=False
            converted_dict[key] = \
                convert_dict_to_float32(value, is_inplace=is_inplace)
        elif (isinstance(value, torch.Tensor)
              and torch.is_floating_point(value)):
            # Convert torch tensor to torch.float32
            converted_dict[key] = value.to(torch.float32)
        else:
            # Keep value unchanged
            converted_dict[key] = value
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    return converted_dict
# =============================================================================
def convert_tensor_to_float64(tensor):
    """Convert floating point torch tensor to torch.float64.

    Inputs that are not floating point torch tensors (including non-tensor
    objects) are returned as is. A tensor that is already torch.float64 is
    also returned unchanged.

    Parameters
    ----------
    tensor : torch.Tensor
        Tensor.

    Returns
    -------
    tensor : torch.Tensor
        Tensor.
    """
    # Only floating point tensors are cast to double precision
    needs_cast = (isinstance(tensor, torch.Tensor)
                  and torch.is_floating_point(tensor))
    return tensor.to(torch.float64) if needs_cast else tensor
# =============================================================================
def convert_dict_to_float64(tensor_dict, is_inplace=True):
    """Convert all floating point torch tensors in dictionary to torch.float64.

    Torch tensors with type torch.float64 or other non-float types are kept
    unchanged. Nested dictionaries are processed recursively.

    Parameters
    ----------
    tensor_dict : dict
        Dictionary.
    is_inplace : bool, default=True
        If True, then input dictionary (and any nested dictionary) is updated
        in-place and returned. If False, then a new dictionary (with new
        nested dictionaries) is returned and the input dictionary is left
        unchanged; values that are not converted are shared with the input
        dictionary (not copied).

    Returns
    -------
    tensor_dict : dict
        Dictionary.
    """
    # Select output dictionary: input dictionary (in-place) or new dictionary
    if is_inplace:
        converted_dict = tensor_dict
    else:
        converted_dict = {}
    # Loop over dictionary items (in-place assignments only overwrite
    # existing keys, so the dictionary size never changes during iteration)
    for key, value in tensor_dict.items():
        # Perform type conversion
        if isinstance(value, dict):
            # Process nested dictionary recursively, propagating the mode so
            # that nested dictionaries are not mutated when is_inplace=False
            converted_dict[key] = \
                convert_dict_to_float64(value, is_inplace=is_inplace)
        elif (isinstance(value, torch.Tensor)
              and torch.is_floating_point(value)):
            # Convert torch tensor to torch.float64
            converted_dict[key] = value.to(torch.float64)
        else:
            # Keep value unchanged
            converted_dict[key] = value
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    return converted_dict