"""Graph Neural Network based model.
Classes
-------
GNNEPDBaseModel(torch.nn.Module)
GNN Encoder-Processor-Decoder base model.
Functions
---------
graph_standard_partial_fit
Perform batch fitting of standardization data scalers.
"""
#
# Modules
# =============================================================================
# Standard
import copy
import os
import re
import pickle
# Third-party
import torch
import torch_geometric.nn
import torch_geometric.data
import torch_geometric.loader
import sklearn.preprocessing
# Local
from gnn_base_model.model.gnn_epd_model import EncodeProcessDecode
from utilities.data_scalers import TorchStandardScaler
#
# Authorship & Credits
# =============================================================================
__author__ = 'Bernardo Ferreira (bernardo_ferreira@brown.edu)'
__credits__ = ['Bernardo Ferreira', 'Rui Barreira']
__status__ = 'Planning'
# =============================================================================
#
# =============================================================================
[docs]
class GNNEPDBaseModel(torch.nn.Module):
"""GNN Encoder-Processor-Decoder base model.
Attributes
----------
model_directory : str
Directory where model is stored.
model_name : str, default='gnn_epd_model'
Name of model.
_n_node_in : int
Number of node input features.
_n_node_out : int
Number of node output features.
_n_edge_in : int
Number of edge input features.
_n_edge_out : int
Number of edge output features.
_n_global_in : int
Number of global input features.
_n_global_out : int
Number of global output features.
_n_time_node : int
Number of discrete time steps of nodal features.
If greater than 0, then nodal input features include a time
dimension and message passing layers are RNNs.
_n_time_edge : int
Number of discrete time steps of edge features.
If greater than 0, then edge input features include a time
dimension and message passing layers are RNNs.
_n_time_global : int
Number of discrete time steps of global features.
If greater than 0, then global input features include a time
dimension and message passing layers are RNNs.
_n_message_steps : int
Number of message-passing steps.
_enc_n_hidden_layers : int
Encoder: Number of hidden layers of multilayer feed-forward neural
network update functions.
_pro_n_hidden_layers : int
Processor: Number of hidden layers of multilayer feed-forward
neural network update functions.
_dec_n_hidden_layers : int
Decoder: Number of hidden layers of multilayer feed-forward neural
network update functions.
_hidden_layer_size : int
Number of neurons of hidden layers of multilayer feed-forward
neural network update functions.
_pro_edge_to_node_aggr : {'add',}, default='add'
Processor: Edge-to-node aggregation scheme.
_pro_node_to_global_aggr : {'add',}, default='add'
Processor: Node-to-global aggregation scheme.
_enc_node_hidden_activ_type : str, default='identity'
Encoder: Hidden unit activation function type of node update function
(multilayer feed-forward neural network). Defaults to identity
(linear) unit activation function.
_enc_node_output_activ_type : str, default='identity'
Encoder: Output unit activation function type of node update function
(multilayer feed-forward neural network). Defaults to identity
(linear) unit activation function.
_enc_edge_hidden_activ_type : str, default='identity'
Encoder: Hidden unit activation function type of edge update function
(multilayer feed-forward neural network). Defaults to identity
(linear) unit activation function.
_enc_edge_output_activ_type : str, default='identity'
Encoder: Output unit activation function type of edge update function
(multilayer feed-forward neural network). Defaults to identity
(linear) unit activation function.
_enc_global_hidden_activ_type : str, default='identity'
Encoder: Hidden unit activation function type of global update function
(multilayer feed-forward neural network). Defaults to identity
(linear) unit activation function.
_enc_global_output_activ_type : str, default='identity'
Encoder: Output unit activation function type of global update function
(multilayer feed-forward neural network). Defaults to identity
(linear) unit activation function.
_pro_node_hidden_activ_type : str, default='identity'
Processor: Hidden unit activation function type of node update function
(multilayer feed-forward neural network). Defaults to identity
(linear) unit activation function.
_pro_node_output_activ_type : str, default='identity'
Processor: Output unit activation function type of node update function
(multilayer feed-forward neural network). Defaults to identity
(linear) unit activation function.
_pro_edge_hidden_activ_type : str, default='identity'
Processor: Hidden unit activation function type of edge update function
(multilayer feed-forward neural network). Defaults to identity
(linear) unit activation function.
_pro_edge_output_activ_type : str, default='identity'
Processor: Output unit activation function type of edge update function
(multilayer feed-forward neural network). Defaults to identity
(linear) unit activation function.
_pro_global_hidden_activ_type : str, default='identity'
Processor: Hidden unit activation function type of global update
function (multilayer feed-forward neural network). Defaults to identity
(linear) unit activation function.
_pro_global_output_activ_type : str, default='identity'
Processor: Output unit activation function type of global update
function (multilayer feed-forward neural network). Defaults to identity
(linear) unit activation function.
_dec_node_hidden_activ_type : str, default='identity'
Decoder: Hidden unit activation function type of node update function
(multilayer feed-forward neural network). Defaults to identity
(linear) unit activation function.
_dec_node_output_activ_type : str, default='identity'
Decoder: Output unit activation function type of node update function
(multilayer feed-forward neural network). Defaults to identity
(linear) unit activation function.
_dec_edge_hidden_activ_type : str, default='identity'
Decoder: Hidden unit activation function type of edge update function
(multilayer feed-forward neural network). Defaults to identity
(linear) unit activation function.
_dec_edge_output_activ_type : str, default='identity'
Decoder: Output unit activation function type of edge update function
(multilayer feed-forward neural network). Defaults to identity
(linear) unit activation function.
_dec_global_hidden_activ_type : str, default='identity'
Decoder: Hidden unit activation function type of global update function
(multilayer feed-forward neural network). Defaults to identity
(linear) unit activation function.
_dec_global_output_activ_type : str, default='identity'
Decoder: Output unit activation function type of global update function
(multilayer feed-forward neural network). Defaults to identity
(linear) unit activation function.
_gnn_epd_model : EncodeProcessDecode
GNN-based Encoder-Process-Decoder model.
_device_type : {'cpu', 'cuda'}, default='cpu'
Type of device on which torch.Tensor is allocated.
_device : torch.device
Device on which torch.Tensor is allocated.
is_model_in_normalized : bool, default=False
If True, then model input features are assumed to be normalized
(normalized input data has been seen during model training).
is_model_out_normalized : bool, default=False
If True, then model output features are assumed to be normalized
(normalized output data has been seen during model training).
_data_scalers : dict
Data scaler (item, sklearn.preprocessing.StandardScaler) for each
feature data (key, str).
_available_activ_fn : dict
For each available activation function type (key, str), store the
corresponding PyTorch unit activation function (item, torch.nn.Module).
Methods
-------
init_model_from_file(model_directory)
Initialize GNN-based model from initialization file.
set_device(self, device_type)
Set device on which torch.Tensor is allocated.
get_device(self)
Get device on which torch.Tensor is allocated.
forward(self, node_features_in=None, edge_features_in=None, \
global_features_in=None, edges_indexes=None, batch_vector=None)
Forward propagation.
save_model_init_file(self)
Save model class initialization attributes.
get_input_features_from_graph(self, graph, is_normalized=False)
Get input features from graph.
get_output_features_from_graph(self, graph, is_normalized=False)
Get output features from graph.
get_metadata_from_graph(self, graph)
Get metadata from graph.
predict_output_features(self, input_graph, is_normalized=False)
Predict output features.
save_model_init_state(self)
Save model initial state to file.
save_model_state(self)
Save model state to file.
load_model_state(self)
Load model state from file.
_check_state_file(self, filename)
Check if file is model training epoch state file.
_check_best_state_file(self, filename)
Check if file is model training epoch best state file.
_remove_posterior_state_files(self, epoch)
Delete model training epoch state files posterior to given epoch.
_remove_best_state_files(self)
Delete existent model best state files.
_init_data_scalers(self)
Initialize model data scalers.
set_data_scalers(self, scaler_node_in, scaler_edge_in, scaler_global_in,
scaler_node_out, scaler_edge_out, scaler_global_out)
Set fitted model data scalers.
fit_data_scalers(self, dataset, is_verbose=False)
Fit model data scalers.
get_fitted_data_scaler(self, features_type)
Get fitted model data scalers.
get_fitted_data_scaler(self, features_type)
Get fitted model data scalers.
load_model_data_scalers_from_file(self)
Load data scalers from model initialization file.
check_normalized_return(self)
Check if model data normalization is available.
"""
# Set available unit activation function types
_available_activ_fn = {
str(name).lower(): getattr(torch.nn.modules.activation, name)
for name in torch.nn.modules.activation.__all__}
# Add identity which is not available in torch.nn.modules.activation
_available_activ_fn['identity'] = torch.nn.Identity
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[docs]
def __init__(self, n_node_in, n_node_out, n_edge_in, n_edge_out,
             n_global_in, n_global_out, n_message_steps,
             enc_n_hidden_layers, pro_n_hidden_layers, dec_n_hidden_layers,
             hidden_layer_size, model_directory,
             model_name='gnn_epd_model',
             n_time_node=0, n_time_edge=0, n_time_global=0,
             is_model_in_normalized=False, is_model_out_normalized=False,
             pro_edge_to_node_aggr='add', pro_node_to_global_aggr='add',
             enc_node_hidden_activ_type='identity',
             enc_node_output_activ_type='identity',
             enc_edge_hidden_activ_type='identity',
             enc_edge_output_activ_type='identity',
             enc_global_hidden_activ_type='identity',
             enc_global_output_activ_type='identity',
             pro_node_hidden_activ_type='identity',
             pro_node_output_activ_type='identity',
             pro_edge_hidden_activ_type='identity',
             pro_edge_output_activ_type='identity',
             pro_global_hidden_activ_type='identity',
             pro_global_output_activ_type='identity',
             dec_node_hidden_activ_type='identity',
             dec_node_output_activ_type='identity',
             dec_edge_hidden_activ_type='identity',
             dec_edge_output_activ_type='identity',
             dec_global_hidden_activ_type='identity',
             dec_global_output_activ_type='identity',
             is_save_model_init_file=True,
             device_type='cpu'):
    """Constructor.

    Stores the model configuration, validates the model directory and
    name, sets the device, builds the underlying EncodeProcessDecode
    model, initializes the data scalers and (optionally) saves the model
    initialization file.

    Parameters
    ----------
    n_node_in : int
        Number of node input features.
    n_node_out : int
        Number of node output features.
    n_edge_in : int
        Number of edge input features.
    n_edge_out : int
        Number of edge output features.
    n_global_in : int
        Number of global input features.
    n_global_out : int
        Number of global output features.
    n_message_steps : int
        Number of message-passing steps.
    enc_n_hidden_layers : int
        Encoder: Number of hidden layers of multilayer feed-forward neural
        network update functions.
    pro_n_hidden_layers : int
        Processor: Number of hidden layers of multilayer feed-forward
        neural network update functions.
    dec_n_hidden_layers : int
        Decoder: Number of hidden layers of multilayer feed-forward neural
        network update functions.
    hidden_layer_size : int
        Number of neurons of hidden layers of multilayer feed-forward
        neural network update functions.
    model_directory : str
        Directory where model is stored. Must already exist.
    model_name : str, default='gnn_epd_model'
        Name of model.
    n_time_node : int, default=0
        Number of discrete time steps of nodal features. If greater than
        0, then nodal input features include a time dimension and message
        passing layers are RNNs.
    n_time_edge : int, default=0
        Number of discrete time steps of edge features. If greater than
        0, then edge input features include a time dimension and message
        passing layers are RNNs.
    n_time_global : int, default=0
        Number of discrete time steps of global features. If greater than
        0, then global input features include a time dimension and message
        passing layers are RNNs.
    is_model_in_normalized : bool, default=False
        If True, then model input features are assumed to be normalized
        (normalized input data has been seen during model training).
    is_model_out_normalized : bool, default=False
        If True, then model output features are assumed to be normalized
        (normalized output data has been seen during model training).
    pro_edge_to_node_aggr : {'add',}, default='add'
        Processor: Edge-to-node aggregation scheme.
    pro_node_to_global_aggr : {'add',}, default='add'
        Processor: Node-to-global aggregation scheme.
    enc_node_hidden_activ_type : str, default='identity'
        Encoder: Hidden unit activation function type of node update
        function (multilayer feed-forward neural network).
    enc_node_output_activ_type : str, default='identity'
        Encoder: Output unit activation function type of node update
        function (multilayer feed-forward neural network).
    enc_edge_hidden_activ_type : str, default='identity'
        Encoder: Hidden unit activation function type of edge update
        function (multilayer feed-forward neural network).
    enc_edge_output_activ_type : str, default='identity'
        Encoder: Output unit activation function type of edge update
        function (multilayer feed-forward neural network).
    enc_global_hidden_activ_type : str, default='identity'
        Encoder: Hidden unit activation function type of global update
        function (multilayer feed-forward neural network).
    enc_global_output_activ_type : str, default='identity'
        Encoder: Output unit activation function type of global update
        function (multilayer feed-forward neural network).
    pro_node_hidden_activ_type : str, default='identity'
        Processor: Hidden unit activation function type of node update
        function (multilayer feed-forward neural network).
    pro_node_output_activ_type : str, default='identity'
        Processor: Output unit activation function type of node update
        function (multilayer feed-forward neural network).
    pro_edge_hidden_activ_type : str, default='identity'
        Processor: Hidden unit activation function type of edge update
        function (multilayer feed-forward neural network).
    pro_edge_output_activ_type : str, default='identity'
        Processor: Output unit activation function type of edge update
        function (multilayer feed-forward neural network).
    pro_global_hidden_activ_type : str, default='identity'
        Processor: Hidden unit activation function type of global update
        function (multilayer feed-forward neural network).
    pro_global_output_activ_type : str, default='identity'
        Processor: Output unit activation function type of global update
        function (multilayer feed-forward neural network).
    dec_node_hidden_activ_type : str, default='identity'
        Decoder: Hidden unit activation function type of node update
        function (multilayer feed-forward neural network).
    dec_node_output_activ_type : str, default='identity'
        Decoder: Output unit activation function type of node update
        function (multilayer feed-forward neural network).
    dec_edge_hidden_activ_type : str, default='identity'
        Decoder: Hidden unit activation function type of edge update
        function (multilayer feed-forward neural network).
    dec_edge_output_activ_type : str, default='identity'
        Decoder: Output unit activation function type of edge update
        function (multilayer feed-forward neural network).
    dec_global_hidden_activ_type : str, default='identity'
        Decoder: Hidden unit activation function type of global update
        function (multilayer feed-forward neural network).
    dec_global_output_activ_type : str, default='identity'
        Decoder: Output unit activation function type of global update
        function (multilayer feed-forward neural network).
    is_save_model_init_file : bool, default=True
        If True, saves model initialization file when model is initialized
        (overwriting existent initialization file), False otherwise. When
        initializing model from initialization file this option should be
        set to False to avoid updating the initialization file and preserve
        fitted data scalers.
    device_type : {'cpu', 'cuda'}, default='cpu'
        Type of device on which torch.Tensor is allocated.

    Notes
    -----
    Every activation type defaults to 'identity', i.e., the identity
    (linear) unit activation function.

    Raises
    ------
    RuntimeError
        If the model directory does not exist or the model name is not a
        string. set_device may also raise for an invalid or unavailable
        device type.
    """
    # Initialize from base class (must precede any attribute assignment on
    # a torch.nn.Module)
    super(GNNEPDBaseModel, self).__init__()
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Discrete time steps of node, edge and global features
    self._n_time_node = n_time_node
    self._n_time_edge = n_time_edge
    self._n_time_global = n_time_global
    # Number of input/output features per feature type
    self._n_node_in, self._n_node_out = n_node_in, n_node_out
    self._n_edge_in, self._n_edge_out = n_edge_in, n_edge_out
    self._n_global_in, self._n_global_out = n_global_in, n_global_out
    # Architecture parameters
    self._n_message_steps = n_message_steps
    self._enc_n_hidden_layers = enc_n_hidden_layers
    self._pro_n_hidden_layers = pro_n_hidden_layers
    self._dec_n_hidden_layers = dec_n_hidden_layers
    self._hidden_layer_size = hidden_layer_size
    self._pro_edge_to_node_aggr = pro_edge_to_node_aggr
    self._pro_node_to_global_aggr = pro_node_to_global_aggr
    # Unit activation function types keyed by <stage>_<features>_<unit>;
    # each is stored as self._<key>_activ_type
    activ_types = {
        'enc_node_hidden': enc_node_hidden_activ_type,
        'enc_node_output': enc_node_output_activ_type,
        'enc_edge_hidden': enc_edge_hidden_activ_type,
        'enc_edge_output': enc_edge_output_activ_type,
        'enc_global_hidden': enc_global_hidden_activ_type,
        'enc_global_output': enc_global_output_activ_type,
        'pro_node_hidden': pro_node_hidden_activ_type,
        'pro_node_output': pro_node_output_activ_type,
        'pro_edge_hidden': pro_edge_hidden_activ_type,
        'pro_edge_output': pro_edge_output_activ_type,
        'pro_global_hidden': pro_global_hidden_activ_type,
        'pro_global_output': pro_global_output_activ_type,
        'dec_node_hidden': dec_node_hidden_activ_type,
        'dec_node_output': dec_node_output_activ_type,
        'dec_edge_hidden': dec_edge_hidden_activ_type,
        'dec_edge_output': dec_edge_output_activ_type,
        'dec_global_hidden': dec_global_hidden_activ_type,
        'dec_global_output': dec_global_output_activ_type,
    }
    for key, activ_type in activ_types.items():
        setattr(self, f'_{key}_activ_type', activ_type)
    # Model directory must already exist; model name must be a string
    if not os.path.isdir(model_directory):
        raise RuntimeError('The model directory has not been found.')
    self.model_directory = str(model_directory)
    if not isinstance(model_name, str):
        raise RuntimeError('The model name must be a string.')
    self.model_name = model_name
    # Model input and output features normalization flags
    self.is_model_in_normalized = is_model_in_normalized
    self.is_model_out_normalized = is_model_out_normalized
    # Device
    self.set_device(device_type)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Resolve every activation type into a PyTorch activation; the model
    # expects them as <stage>_<features>_<unit>_activation keywords
    to_activation = type(self).get_pytorch_activation
    activations = {
        f'{key}_activation': to_activation(getattr(self, f'_{key}_activ_type'))
        for key in activ_types}
    # Build the GNN-based Encoder-Process-Decoder model
    self._gnn_epd_model = EncodeProcessDecode(
        n_message_steps=n_message_steps,
        n_node_in=n_node_in, n_node_out=n_node_out,
        n_edge_in=n_edge_in, n_edge_out=n_edge_out,
        n_global_in=n_global_in, n_global_out=n_global_out,
        n_time_node=self._n_time_node,
        n_time_edge=self._n_time_edge,
        n_time_global=self._n_time_global,
        enc_n_hidden_layers=enc_n_hidden_layers,
        pro_n_hidden_layers=pro_n_hidden_layers,
        dec_n_hidden_layers=dec_n_hidden_layers,
        hidden_layer_size=hidden_layer_size,
        pro_edge_to_node_aggr=pro_edge_to_node_aggr,
        pro_node_to_global_aggr=pro_node_to_global_aggr,
        is_node_res_connect=False, is_edge_res_connect=False,
        **activations)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Data scalers
    self._init_data_scalers()
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Persist initialization attributes unless explicitly disabled
    if is_save_model_init_file:
        self.save_model_init_file()
# -------------------------------------------------------------------------
[docs]
@staticmethod
def init_model_from_file(model_directory):
    """Initialize model from initialization file.

    Initialization file is assumed to be stored in the model directory
    under the name model_init_file.pkl (as written by
    save_model_init_file).

    Parameters
    ----------
    model_directory : str
        Directory where model is stored. Path-like objects are also
        accepted.

    Returns
    -------
    model : GNNEPDBaseModel
        Model built from the stored initialization parameters, with the
        stored data scalers set. The initialization file is not rewritten.

    Raises
    ------
    RuntimeError
        If the model directory or the model initialization file is not
        found.
    """
    # Check model directory (f-string so path-like inputs do not turn the
    # intended RuntimeError into a str-concatenation TypeError)
    if not os.path.isdir(model_directory):
        raise RuntimeError(f'The model directory has not been found:\n\n'
                           f'{model_directory}')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Get model initialization file path from model directory
    model_init_file_path = os.path.join(model_directory,
                                        'model_init_file' + '.pkl')
    if not os.path.isfile(model_init_file_path):
        raise RuntimeError(f'The model initialization file has not been '
                           f'found:\n\n{model_init_file_path}')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Load model initialization attributes from file.
    # NOTE(security): pickle.load can execute arbitrary code; only load
    # initialization files from trusted sources.
    with open(model_init_file_path, 'rb') as model_init_file:
        model_init_attributes = pickle.load(model_init_file)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Get model initialization attributes and point them at the directory
    # the model is being loaded from (it may have been moved)
    model_init_args = model_init_attributes['model_init_args']
    model_init_args['model_directory'] = model_directory
    # Initialize model without overwriting the initialization file, which
    # would otherwise discard the fitted data scalers
    model = GNNEPDBaseModel(**model_init_args,
                            is_save_model_init_file=False)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Restore fitted data scalers
    model._data_scalers = model_init_attributes['model_data_scalers']
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    return model
# -------------------------------------------------------------------------
[docs]
def set_device(self, device_type):
    """Set device on which torch.Tensor is allocated.

    Stores both the device type string and the matching torch.device.

    Parameters
    ----------
    device_type : {'cpu', 'cuda'}
        Type of device on which torch.Tensor is allocated.

    Raises
    ------
    RuntimeError
        If device_type is neither 'cpu' nor 'cuda', or if 'cuda' is
        requested while torch.cuda.is_available() is False.
    """
    # Only two device types are supported
    if device_type not in ('cpu', 'cuda'):
        raise RuntimeError('Invalid device type.')
    # A CUDA device can only be used when PyTorch reports CUDA support
    if device_type == 'cuda' and not torch.cuda.is_available():
        raise RuntimeError('PyTorch with CUDA is not available. '
                           'Please set the model device type as CPU '
                           'as:\n\n' + 'model.set_device(\'cpu\').')
    self._device_type = device_type
    self._device = torch.device(device_type)
# -------------------------------------------------------------------------
[docs]
def get_device(self):
    """Get device on which torch.Tensor is allocated.

    Returns
    -------
    device_type : {'cpu', 'cuda'}
        Type of device on which torch.Tensor is allocated.
    device : torch.device
        Device on which torch.Tensor is allocated.
    """
    # set_device stores these as private attributes (_device_type,
    # _device); the public names do not exist on the instance
    return self._device_type, self._device
# -------------------------------------------------------------------------
[docs]
def forward(self, node_features_in=None, edge_features_in=None,
            global_features_in=None, edges_indexes=None,
            batch_vector=None):
    """Forward propagation.

    Validates the provided inputs and delegates the prediction to
    predict_output_features.

    Parameters
    ----------
    node_features_in : {torch.Tensor, None}, default=None
        Nodes features input matrix stored as a torch.Tensor(2d) of shape
        (n_nodes, n_features).
    edge_features_in : {torch.Tensor, None}, default=None
        Edges features input matrix stored as a torch.Tensor(2d) of shape
        (n_edges, n_features).
    global_features_in : {torch.Tensor, None}, default=None
        Global features input matrix stored as a torch.Tensor(2d) of shape
        (1, n_features).
    edges_indexes : {torch.Tensor, None}, default=None
        Edges indexes matrix stored as torch.Tensor(2d) with shape
        (2, n_edges), where the i-th edge is stored in
        edges_indexes[:, i] as (start_node_index, end_node_index).
    batch_vector : {torch.Tensor, None}, default=None
        Batch vector stored as torch.Tensor(1d) of shape (n_nodes,),
        assigning each node to a specific batch subgraph. Required to
        process a graph holding multiple isolated subgraphs when batch
        size is greater than 1.

    Returns
    -------
    node_features_out : {torch.Tensor, None}
        Nodes features output matrix stored as a torch.Tensor(2d) of shape
        (n_nodes, n_features).
    edge_features_out : {torch.Tensor, None}
        Edges features output matrix stored as a torch.Tensor(2d) of shape
        (n_edges, n_features).
    global_features_out : {torch.Tensor, None}
        Global features output matrix stored as a torch.Tensor(2d) of shape
        (1, n_features).

    Raises
    ------
    RuntimeError
        If node_features_in, edge_features_in, global_features_in or
        edges_indexes is provided (not None) but is not a torch.Tensor.
    """
    # Check optional tensor inputs (None means "not provided")
    for label, tensor in (('Node input features', node_features_in),
                          ('Edge input features', edge_features_in),
                          ('Global input features', global_features_in),
                          ('Edges indexes', edges_indexes)):
        if tensor is not None and not isinstance(tensor, torch.Tensor):
            raise RuntimeError(f'{label} were not provided as '
                               f'torch.Tensor.')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Predict output features
    node_features_out, edge_features_out, global_features_out = \
        self.predict_output_features(
            node_features_in=node_features_in,
            edge_features_in=edge_features_in,
            global_features_in=global_features_in,
            edges_indexes=edges_indexes, batch_vector=batch_vector)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    return node_features_out, edge_features_out, global_features_out
# -------------------------------------------------------------------------
[docs]
def save_model_init_file(self):
    """Save model initialization file.

    Initialization file is stored in the model directory under the name
    model_init_file.pkl.

    Initialization file contains a dictionary model_init_attributes that
    includes:

    'model_init_args' - Model initialization parameters (keys match the
    constructor parameter names)
    'model_data_scalers' - Model fitted data scalers

    Raises
    ------
    RuntimeError
        If the model directory has not been found.
    """
    # Check model directory
    if not os.path.isdir(self.model_directory):
        raise RuntimeError('The model directory has not been found:\n\n'
                           + self.model_directory)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Constructor parameters stored as private attributes self._<name>.
    # Deriving the attribute name from the key prevents the key/attribute
    # mismatches that previously swapped encoder/decoder hidden layers and
    # saved encoder activation types for the processor.
    private_arg_names = [
        'n_node_in', 'n_node_out', 'n_edge_in', 'n_edge_out',
        'n_global_in', 'n_global_out',
        'n_time_node', 'n_time_edge', 'n_time_global',
        'n_message_steps',
        'enc_n_hidden_layers', 'pro_n_hidden_layers', 'dec_n_hidden_layers',
        'pro_edge_to_node_aggr', 'pro_node_to_global_aggr',
        'hidden_layer_size',
    ]
    # Activation function types: <stage>_<features>_<unit>_activ_type
    private_arg_names += [
        f'{stage}_{features}_{unit}_activ_type'
        for stage in ('enc', 'pro', 'dec')
        for features in ('node', 'edge', 'global')
        for unit in ('hidden', 'output')]
    private_arg_names.append('device_type')
    # Constructor parameters stored as public attributes self.<name>
    public_arg_names = ('model_directory', 'model_name',
                        'is_model_in_normalized', 'is_model_out_normalized')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Build initialization parameters
    model_init_args = {name: getattr(self, '_' + name)
                       for name in private_arg_names}
    model_init_args.update(
        {name: getattr(self, name) for name in public_arg_names})
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Assemble initialization parameters and fitted data scalers
    model_init_attributes = {
        'model_init_args': model_init_args,
        'model_data_scalers': self._data_scalers,
    }
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Save model initialization file (overwrites any existing one)
    model_init_file_path = os.path.join(self.model_directory,
                                        'model_init_file' + '.pkl')
    with open(model_init_file_path, 'wb') as init_file:
        pickle.dump(model_init_attributes, init_file)
# -------------------------------------------------------------------------
# -------------------------------------------------------------------------
[docs]
def get_output_features_from_graph(self, graph, is_normalized=False):
    """Get output features from graph.

    Parameters
    ----------
    graph : torch_geometric.data.Data
        Homogeneous graph.
    is_normalized : bool, default=False
        If True, get normalized output features from graph, False
        otherwise.

    Returns
    -------
    node_features_out : {torch.Tensor, None}
        Nodes features output matrix stored as a torch.Tensor(2d) of shape
        (n_nodes, n_features). Cloned from graph.y; None if graph.y is
        not available as a torch.Tensor.
    edge_features_out : {torch.Tensor, None}
        Edges features output matrix stored as a torch.Tensor(2d) of shape
        (n_edges, n_features). Cloned from graph.edge_targets_matrix; None
        if not available as a torch.Tensor.
    global_features_out : {torch.Tensor, None}
        Global features output matrix stored as a torch.Tensor(2d) of shape
        (1, n_features). Cloned from graph.global_targets_matrix; None if
        not available as a torch.Tensor.

    Raises
    ------
    RuntimeError
        If graph is not a torch_geometric.data.Data object, or if the last
        dimension of any available output features matrix is not
        consistent with the model (n_out features, or n_out*n_time
        features when the corresponding number of time steps is positive).
    """
    # Check input graph
    if not isinstance(graph, torch_geometric.data.Data):
        raise RuntimeError('Input graph is not torch_geometric.data.Data.')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Get features from graph. Tensors are cloned so that normalization
    # below does not operate on the graph's own tensors
    graph_keys = graph.keys()
    features_out = []
    for attr_name in ('y', 'edge_targets_matrix', 'global_targets_matrix'):
        attr = getattr(graph, attr_name) if attr_name in graph_keys else None
        if isinstance(attr, torch.Tensor):
            features_out.append(attr.clone())
        else:
            features_out.append(None)
    node_features_out, edge_features_out, global_features_out = features_out
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Check consistency with simulator. When the number of discrete time
    # steps is positive, the last dimension holds n_out*n_time features
    features_specs = (
        ('node', node_features_out, self._n_node_out, self._n_time_node),
        ('edge', edge_features_out, self._n_edge_out, self._n_time_edge),
        ('global', global_features_out, self._n_global_out,
         self._n_time_global),
    )
    for label, features, n_out, n_time in features_specs:
        if features is None:
            continue
        n_expected = n_out*n_time if n_time > 0 else n_out
        if features.shape[-1] != n_expected:
            raise RuntimeError(f'Input graph ({features.shape[-1]}) and '
                               f'simulator ({n_expected}) number of output '
                               f'{label} features are not consistent.')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Normalize output features data
    if is_normalized:
        node_features_out, edge_features_out, global_features_out = [
            None if features is None else self.data_scaler_transform(
                tensor=features, features_type=features_type,
                mode='normalize')
            for features, features_type in (
                (node_features_out, 'node_features_out'),
                (edge_features_out, 'edge_features_out'),
                (global_features_out, 'global_features_out'))]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    return node_features_out, edge_features_out, global_features_out
# -------------------------------------------------------------------------
# -------------------------------------------------------------------------
[docs]
def predict_output_features(self, node_features_in=None,
                            edge_features_in=None, global_features_in=None,
                            edges_indexes=None, batch_vector=None):
    """Predict output features.

    Parameters
    ----------
    node_features_in : {torch.Tensor, None}, default=None
        Nodes features input matrix stored as a torch.Tensor(2d) of shape
        (n_nodes, n_features).
    edge_features_in : {torch.Tensor, None}, default=None
        Edges features input matrix stored as a torch.Tensor(2d) of shape
        (n_edges, n_features).
    global_features_in : {torch.Tensor, None}, default=None
        Global features input matrix stored as a torch.Tensor(2d) of shape
        (1, n_features).
    edges_indexes : {torch.Tensor, None}, default=None
        Edges indexes matrix stored as torch.Tensor(2d) with shape
        (2, n_edges), where the i-th edge is stored in
        edges_indexes[:, i] as (start_node_index, end_node_index).
    batch_vector : torch.Tensor, default=None
        Batch vector stored as torch.Tensor(1d) of shape (n_nodes,),
        assigning each node to a specific batch subgraph. Required to
        process a graph holding multiple isolated subgraphs when batch
        size is greater than 1.

    Returns
    -------
    node_features_out : {torch.Tensor, None}
        Nodes features output matrix stored as a torch.Tensor(2d) of shape
        (n_nodes, n_features).
    edge_features_out : {torch.Tensor, None}
        Edges features output matrix stored as a torch.Tensor(2d) of shape
        (n_edges, n_features).
    global_features_out : {torch.Tensor, None}
        Global features output matrix stored as a torch.Tensor(2d) of shape
        (1, n_features).

    Raises
    ------
    RuntimeError
        If any provided input features matrix or the edges indexes matrix
        is not a torch.Tensor.
    """
    # Check that every provided input is a torch.Tensor (None is allowed)
    inputs_to_check = (
        (node_features_in, 'Node input features'),
        (edge_features_in, 'Edge input features'),
        (global_features_in, 'Global input features'),
        (edges_indexes, 'Edges indexes'),
    )
    for value, label in inputs_to_check:
        if value is not None and not isinstance(value, torch.Tensor):
            raise RuntimeError(f'{label} were not provided as torch.Tensor.')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Predict output features
    node_features_out, edge_features_out, global_features_out = \
        self._gnn_epd_model(edges_indexes=edges_indexes,
                            node_features_in=node_features_in,
                            edge_features_in=edge_features_in,
                            global_features_in=global_features_in,
                            batch_vector=batch_vector)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    return node_features_out, edge_features_out, global_features_out
# -------------------------------------------------------------------------
[docs]
@classmethod
def get_pytorch_activation(cls, activation_type, **kwargs):
    """Get PyTorch unit activation function.

    Parameters
    ----------
    activation_type : str
        Unit activation function type, looked up in the class mapping
        _available_activ_fn. Documented types include:
        'identity' : Linear (torch.nn.Identity)
        'relu' : Rectified linear unit (torch.nn.ReLU)
        'tanh' : Hyperbolic Tangent (torch.nn.Tanh)
    **kwargs
        Arguments of the activation function (torch.nn.Module) initializer.

    Returns
    -------
    activation_function : torch.nn.Module
        PyTorch unit activation function.

    Raises
    ------
    RuntimeError
        If activation_type is not available in _available_activ_fn.
    """
    # Get activation function class. A dict lookup raises KeyError (not
    # ValueError) for an unknown key
    try:
        activation_function_cls = cls._available_activ_fn[activation_type]
    except KeyError:
        raise RuntimeError(f'Unknown or unavailable PyTorch unit '
                           f'activation function: \'{activation_type}\'.'
                           '\n\nAvailable: '
                           f'{cls._available_activ_fn.keys()}.') from None
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Initialize and return activation function
    return activation_function_cls(**kwargs)
# -------------------------------------------------------------------------
[docs]
def save_model_init_state(self):
    """Save model initial state to file.

    The model state dictionary is written with torch.save to
    model_directory under the name < model_name >-init.pt.

    Raises
    ------
    RuntimeError
        If model_directory does not exist.
    """
    directory = self.model_directory
    if os.path.isdir(directory):
        # Initial state file: < model_name >-init.pt
        init_state_path = os.path.join(directory,
                                       self.model_name + '-init' + '.pt')
        torch.save(self.state_dict(), init_state_path)
    else:
        raise RuntimeError('The model directory has not been found:\n\n'
                           + directory)
# -------------------------------------------------------------------------
[docs]
def save_model_state(self, epoch=None, is_best_state=False,
                     is_remove_posterior=True):
    """Save model state to file.

    The model state dictionary is saved in model_directory as
    < model_name >.pt, or < model_name >-< epoch >.pt when epoch is an
    int. When is_best_state is True, '-best' is appended to the name
    (< model_name >-best.pt or < model_name >-< epoch >-best.pt) and any
    existing best state files are removed first.

    Parameters
    ----------
    epoch : int, default=None
        Training epoch corresponding to current model state.
    is_best_state : bool, default=False
        If True, save model state file corresponding to the best
        performance instead of regular state file.
    is_remove_posterior : bool, default=True
        Remove model state files corresponding to training epochs
        posterior to the saved state file. Effective only if saved
        training epoch is known.

    Raises
    ------
    RuntimeError
        If model_directory does not exist.
    """
    if not os.path.isdir(self.model_directory):
        raise RuntimeError('The model directory has not been found:\n\n'
                           + self.model_directory)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    is_epoch_known = isinstance(epoch, int)
    # Build the state filename from '-'-separated parts
    name_parts = [self.model_name]
    if is_epoch_known:
        name_parts.append(str(epoch))
    if is_best_state:
        name_parts.append('best')
        # Only a single best state file is kept
        self._remove_best_state_files()
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    state_path = os.path.join(self.model_directory,
                              '-'.join(name_parts) + '.pt')
    torch.save(self.state_dict(), state_path)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if is_epoch_known and is_remove_posterior:
        self._remove_posterior_state_files(epoch)
# -------------------------------------------------------------------------
[docs]
def load_model_state(self, load_model_state=None,
                     is_remove_posterior=True):
    """Load model state from file.

    Model state file is stored in model_directory under the name
    < model_name >.pt or < model_name >-< epoch >.pt if epoch is known.
    Model state file corresponding to the best performance is stored in
    model_directory under the name < model_name >-best.pt or
    < model_name >-< epoch >-best.pt if epoch is known.
    Model initial state file is stored in model directory under the name
    < model_name >-init.pt

    Posterior training epoch state files are only removed after the
    requested state file has been found and loaded. A failed load
    therefore leaves the model directory untouched.

    Parameters
    ----------
    load_model_state : {'best', 'last', int, 'init', None}, default=None
        Load available GNN-based model state from the model directory.
        Options:
        'best' : Model state corresponding to best performance
        'last' : Model state corresponding to highest training epoch
        int : Model state corresponding to given training epoch
        'init' : Model initial state (taken as training epoch 0)
        None : Model default state file
    is_remove_posterior : bool, default=True
        Remove model state files corresponding to training epochs posterior
        to the loaded state file. Effective only if loaded training epoch
        is known.

    Returns
    -------
    epoch : {int, None}
        Loaded model state training epoch (0 for the initial state).
        None if training epoch is unknown.

    Raises
    ------
    RuntimeError
        If the model directory or the requested model state file is not
        found, or if zero or several best state files are found.
    """
    # Check model directory
    if not os.path.isdir(self.model_directory):
        raise RuntimeError('The model directory has not been found:\n\n'
                           + self.model_directory)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Whether posterior state files may be removed once the state is loaded
    is_epoch_state = False
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if load_model_state == 'best':
        # Collect training epochs of model best state files
        best_state_epochs = []
        for filename in os.listdir(self.model_directory):
            is_best_state_file, best_state_epoch = \
                self._check_best_state_file(filename)
            if is_best_state_file:
                best_state_epochs.append(best_state_epoch)
        # Exactly one best state file must exist
        if not best_state_epochs:
            raise RuntimeError('Model best state file has not been found '
                               'in directory:\n\n' + self.model_directory)
        elif len(best_state_epochs) > 1:
            raise RuntimeError('Two or more model best state files have '
                               'been found in directory:'
                               '\n\n' + self.model_directory)
        # Set best state epoch (None if unknown) and state file
        epoch = best_state_epochs[0]
        model_state_file = self.model_name
        if isinstance(epoch, int):
            model_state_file += '-' + str(epoch)
        model_state_file += '-' + 'best'
        is_epoch_state = isinstance(epoch, int)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    elif load_model_state == 'last':
        # Collect training epochs of model state files
        epochs = []
        for filename in os.listdir(self.model_directory):
            is_state_file, file_epoch = self._check_state_file(filename)
            if is_state_file:
                epochs.append(file_epoch)
        if not epochs:
            raise RuntimeError('Model state files corresponding to epochs '
                               'have not been found in directory:\n\n'
                               + self.model_directory)
        # Set highest epoch model state file (no posterior files exist)
        epoch = max(epochs)
        model_state_file = self.model_name + '-' + str(epoch)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    elif isinstance(load_model_state, int):
        # Set model state filename with given epoch
        epoch = load_model_state
        model_state_file = self.model_name + '-' + str(int(epoch))
        is_epoch_state = True
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    elif load_model_state == 'init':
        # Set model initial state file, taken as training epoch 0
        model_state_file = self.model_name + '-init'
        epoch = 0
        is_epoch_state = True
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    else:
        # Set default model state filename; epoch is unknown
        model_state_file = self.model_name
        epoch = None
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Set model state file path
    model_path = os.path.join(self.model_directory,
                              model_state_file + '.pt')
    # Check model state file
    if not os.path.isfile(model_path):
        raise RuntimeError('Model state file has not been found:\n\n'
                           + model_path)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Load model state
    self.load_state_dict(torch.load(model_path,
                                    map_location=torch.device('cpu'),
                                    weights_only=True))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Delete model epoch state files posterior to loaded epoch, only after
    # the state has been successfully loaded
    if is_remove_posterior and is_epoch_state:
        self._remove_posterior_state_files(epoch)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    return epoch
# -------------------------------------------------------------------------
[docs]
def _check_state_file(self, filename):
    """Check if file is model training epoch state file.

    Model training epoch state file is stored in model_directory under the
    name < model_name >-< epoch >.pt. The whole filename must match this
    pattern; the model name is matched literally, including any regex
    metacharacters it contains.

    Parameters
    ----------
    filename : str
        File name.

    Returns
    -------
    is_state_file : bool
        True if model training epoch state file, False otherwise.
    epoch : {None, int}
        Training epoch corresponding to model state file if
        is_state_file=True, None otherwise.
    """
    # Match '< model_name >-< epoch >.pt' exactly. An unanchored match would
    # accept e.g. '< model_name >-3.pt.bak' and then fail to parse the epoch
    match = re.fullmatch(re.escape(self.model_name) + r'-([0-9]+)\.pt',
                         filename)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    is_state_file = match is not None
    # Get model state epoch
    epoch = int(match.group(1)) if is_state_file else None
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    return is_state_file, epoch
# -------------------------------------------------------------------------
[docs]
def _check_best_state_file(self, filename):
    """Check if file is model best state file.

    Model state file corresponding to the best performance is stored in
    model_directory under the name < model_name >-best.pt or
    < model_name >-< epoch >-best.pt if the training epoch is known. The
    whole filename must match one of these patterns; the model name is
    matched literally.

    Parameters
    ----------
    filename : str
        File name.

    Returns
    -------
    is_best_state_file : bool
        True if model best state file, False otherwise.
    epoch : {None, int}
        Training epoch corresponding to model state file if
        is_best_state_file=True and training epoch is known, None
        otherwise.
    """
    # Match '< model_name >[-< epoch >]-best.pt' exactly, capturing the
    # optional training epoch
    match = re.fullmatch(re.escape(self.model_name)
                         + r'(?:-([0-9]+))?-best\.pt', filename)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    is_best_state_file = match is not None
    # Get model state epoch (unknown if no epoch is present in the name)
    epoch = None
    if is_best_state_file and match.group(1) is not None:
        epoch = int(match.group(1))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    return is_best_state_file, epoch
# -------------------------------------------------------------------------
[docs]
def _remove_posterior_state_files(self, epoch):
    """Delete model training epoch state files posterior to given epoch.

    Only files that _check_state_file() recognizes as training epoch state
    files and whose training epoch is strictly greater than epoch are
    deleted from model_directory.

    Parameters
    ----------
    epoch : int
        Training epoch.
    """
    directory = self.model_directory
    for entry in os.listdir(directory):
        is_state_file, entry_epoch = self._check_state_file(entry)
        is_posterior = is_state_file and entry_epoch > epoch
        if is_posterior:
            os.remove(os.path.join(directory, entry))
# -------------------------------------------------------------------------
[docs]
def _remove_best_state_files(self):
    """Delete existent model best state files.

    Every file in model_directory that _check_best_state_file() recognizes
    as a best state file is deleted.
    """
    directory = self.model_directory
    for entry in os.listdir(directory):
        is_best_state_file = self._check_best_state_file(entry)[0]
        if is_best_state_file:
            os.remove(os.path.join(directory, entry))
# -------------------------------------------------------------------------
[docs]
def _init_data_scalers(self):
    """Initialize model data scalers.

    Resets the data scalers mapping so that every features type key is
    present and mapped to None (no data scaler set).
    """
    features_types = ('node_features_in', 'edge_features_in',
                      'global_features_in', 'node_features_out',
                      'edge_features_out', 'global_features_out')
    self._data_scalers = dict.fromkeys(features_types)
# -------------------------------------------------------------------------
[docs]
def set_data_scalers(self, scaler_node_in, scaler_edge_in,
                     scaler_global_in, scaler_node_out, scaler_edge_out,
                     scaler_global_out):
    """Set fitted model data scalers.

    The data scalers are reset and then replaced by the given ones. If
    the model is set to save its initialization file, that file is
    updated with the new data scalers.

    Parameters
    ----------
    scaler_node_in : {TorchStandardScaler, None}
        Data scaler for input node features.
    scaler_edge_in : {TorchStandardScaler, None}
        Data scaler for input edge features.
    scaler_global_in : {TorchStandardScaler, None}
        Data scaler for input global features.
    scaler_node_out : {TorchStandardScaler, None}
        Data scaler for output node features.
    scaler_edge_out : {TorchStandardScaler, None}
        Data scaler for output edge features.
    scaler_global_out : {TorchStandardScaler, None}
        Data scaler for output global features.
    """
    # Reset data scalers before storing the given ones
    self._init_data_scalers()
    self._data_scalers.update({
        'node_features_in': scaler_node_in,
        'edge_features_in': scaler_edge_in,
        'global_features_in': scaler_global_in,
        'node_features_out': scaler_node_out,
        'edge_features_out': scaler_edge_out,
        'global_features_out': scaler_global_out,
    })
    # Keep model initialization file in sync with the data scalers
    if self._is_save_model_init_file:
        self.save_model_init_file()
# -------------------------------------------------------------------------
[docs]
def fit_data_scalers(self, dataset, is_verbose=False,
                     tqdm_flavor='default'):
    """Fit model data scalers.

    Data scalers are set as standard scalers (TorchStandardScaler) where
    features are normalized by removing the mean and scaling to unit
    variance. A data scaler is only fitted for features types with a
    positive number of features; the others are set to None. When the
    number of discrete time steps of a features type is positive, the
    scaler covers n_features*n_time features.

    Fitted data scalers are stored in the model, which makes model
    normalization procedures available, and the model initialization file
    is updated.

    Parameters
    ----------
    dataset : torch.utils.data.Dataset
        GNN-based data set. Each sample corresponds to a
        torch_geometric.data.Data object describing a homogeneous graph.
    is_verbose : bool, default=False
        If True, enable verbose output.
    tqdm_flavor : {'default', 'notebook'}, default='default'
        Type of tqdm progress bar to use when is_verbose=True. Forwarded to
        every data scaler fitting.
    """
    if is_verbose:
        print('\nFitting GNN-based model data scalers'
              '\n------------------------------------\n')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Initialize data scalers
    self._init_data_scalers()
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Features types specifications: (features type, number of features,
    # number of discrete time steps). The tuple order sets the fitting order
    features_specs = (
        ('node_features_in', self._n_node_in, self._n_time_node),
        ('node_features_out', self._n_node_out, self._n_time_node),
        ('edge_features_in', self._n_edge_in, self._n_time_edge),
        ('edge_features_out', self._n_edge_out, self._n_time_edge),
        ('global_features_in', self._n_global_in, self._n_time_global),
        ('global_features_out', self._n_global_out, self._n_time_global),
    )
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Instantiate and fit data scalers
    fitted_scalers = {}
    for features_type, n_type_features, n_time in features_specs:
        data_scaler = None
        if n_type_features > 0:
            # Time series data store n_time steps of each feature
            if n_time > 0:
                n_features = n_type_features*n_time
            else:
                n_features = n_type_features
            data_scaler = TorchStandardScaler(
                n_features=n_features, device_type=self._device_type)
            # Get scaling parameters and fit data scaler
            mean, std = graph_standard_partial_fit(
                dataset, features_type=features_type,
                n_features=n_features, is_verbose=is_verbose,
                tqdm_flavor=tqdm_flavor)
            data_scaler.set_mean_and_std(mean, std)
        fitted_scalers[features_type] = data_scaler
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if is_verbose:
        print('\n> Setting fitted standard scalers...\n')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Set fitted data scalers
    for features_type, data_scaler in fitted_scalers.items():
        self._data_scalers[features_type] = data_scaler
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Update model initialization file with fitted data scalers
    # NOTE(review): unlike set_data_scalers(), this is not conditioned on
    # self._is_save_model_init_file - confirm this is intended
    self.save_model_init_file()
# -------------------------------------------------------------------------
[docs]
def get_fitted_data_scaler(self, features_type):
    """Get fitted model data scaler.

    Parameters
    ----------
    features_type : str
        Features for which data scaler is required:
        'node_features_in' : Node features input matrix
        'edge_features_in' : Edge features input matrix
        'global_features_in' : Global features input matrix
        'node_features_out' : Node features output matrix
        'edge_features_out' : Edge features output matrix
        'global_features_out' : Global features output matrix

    Returns
    -------
    data_scaler
        Data scaler stored for features_type (a TorchStandardScaler when
        set by fit_data_scalers()).

    Raises
    ------
    RuntimeError
        If features_type is unknown or its data scaler has not been set.
    """
    scalers = self._data_scalers
    if features_type in scalers.keys():
        data_scaler = scalers[features_type]
        if data_scaler is not None:
            return data_scaler
        raise RuntimeError(f'Data scaler for {features_type} has not '
                           'been fitted. Fit data scalers by calling '
                           'method fit_data_scalers() before training '
                           'or predicting with the model.')
    raise RuntimeError(f'Unknown data scaler for {features_type}.')
# -------------------------------------------------------------------------
# -------------------------------------------------------------------------
[docs]
def load_model_data_scalers_from_file(self):
    """Load data scalers from model initialization file.

    Reads model_init_file.pkl from model_directory and sets the model data
    scalers to its 'model_data_scalers' entry.

    Raises
    ------
    RuntimeError
        If the model directory or the model initialization file is not
        found.
    """
    directory = self.model_directory
    if not os.path.isdir(directory):
        raise RuntimeError('The model directory has not '
                           'been found:\n\n' + directory)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    init_file_path = os.path.join(directory, 'model_init_file' + '.pkl')
    if not os.path.isfile(init_file_path):
        raise RuntimeError('The model initialization file '
                           'has not been found:\n\n'
                           + init_file_path)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # NOTE: pickle.load can execute arbitrary code; only load model
    # initialization files from trusted model directories
    with open(init_file_path, 'rb') as init_file:
        init_attributes = pickle.load(init_file)
    self._data_scalers = init_attributes['model_data_scalers']
# -------------------------------------------------------------------------
[docs]
def check_normalized_return(self):
    """Check if model data normalization is available.

    Raises
    ------
    RuntimeError
        If the data scalers container is None, or if every data scaler it
        holds is None (an empty container is also treated as unavailable).
    """
    data_scalers = self._data_scalers
    # A single short-circuiting condition replaces two identical raise
    # branches; all() over an empty container is True, so an empty
    # container is rejected too
    if data_scalers is None \
            or all(scaler is None for scaler in data_scalers.values()):
        raise RuntimeError('Data scalers for model features have not '
                           'been set or fitted. Call set_data_scalers() '
                           'or fit_data_scalers() to make model '
                           'normalization procedures available.')
# =============================================================================
def graph_standard_partial_fit(dataset, features_type, n_features,
                               is_verbose=False, tqdm_flavor='default'):
    """Perform batch fitting of standardization data scalers.

    Parameters
    ----------
    dataset : torch.utils.data.Dataset
        GNN-based data set. Each sample corresponds to a
        torch_geometric.data.Data object describing a homogeneous graph.
    features_type : str
        Features for which data scaler is required:

        'node_features_in' : Node features input matrix

        'edge_features_in' : Edge features input matrix

        'global_features_in' : Global features input matrix

        'node_features_out' : Node features output matrix

        'edge_features_out' : Edge features output matrix

        'global_features_out' : Global features output matrix
    n_features : int
        Number of features to standardize.
    is_verbose : bool, default=False
        If True, enable verbose output.
    tqdm_flavor : {'default', 'notebook'}, default='default'
        Type of tqdm progress bar to use when is_verbose=True.

    Returns
    -------
    mean : torch.Tensor
        Features standardization mean tensor stored as a torch.Tensor with
        shape (n_features,).
    std : torch.Tensor
        Features standardization standard deviation tensor stored as a
        torch.Tensor with shape (n_features,).

    Raises
    ------
    RuntimeError
        If features_type is unknown, if the data loader yields no samples,
        if a sample is not a torch_geometric.data.Data object, if a sample
        lacks the attribute associated with features_type, if that attribute
        is not a torch.Tensor, or if its last dimension differs from
        n_features.

    Notes
    -----
    A biased estimator is used to compute the standard deviation according with
    scikit-learn 1.3.2 documentation (sklearn.preprocessing.StandardScaler).
    """
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Import tqdm
    if tqdm_flavor == 'notebook':
        from tqdm.notebook import tqdm
    else:
        from tqdm import tqdm
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Map each features type to the graph sample attribute that stores it
    # (loop-invariant, so built once)
    features_map = {'node_features_in': 'x',
                    'edge_features_in': 'edge_attr',
                    'global_features_in': 'global_features_matrix',
                    'node_features_out': 'y',
                    'edge_features_out': 'edge_targets_matrix',
                    'global_features_out': 'global_targets_matrix'}
    # Validate features type up front (previously an unknown type surfaced
    # as a KeyError and the intended error was unreachable)
    if features_type not in features_map:
        raise RuntimeError('Unknown features type.')
    attribute_name = features_map[features_type]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Instantiate data scaler
    data_scaler = sklearn.preprocessing.StandardScaler()
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Set data loader
    data_loader = \
        torch_geometric.loader.dataloader.DataLoader(dataset=dataset)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Count processed samples so that an empty data set is reported clearly
    n_samples = 0
    # Loop over graph samples
    for pyg_graph in tqdm(data_loader, mininterval=1, maxinterval=60,
                          miniters=0, desc='> Processing data samples: ',
                          disable=not is_verbose, unit=' sample'):
        # Check sample graph type
        if not isinstance(pyg_graph, torch_geometric.data.Data):
            raise RuntimeError('Graph sample must be instance of '
                               'torch_geometric.data.Data.')
        # Check sample graph feature
        if attribute_name not in pyg_graph.keys():
            raise RuntimeError(f'Unknown or unexistent attribute '
                               f'{attribute_name} from graph '
                               f'sample.')
        # Get features tensor
        features_tensor = getattr(pyg_graph, attribute_name)
        if not isinstance(features_tensor, torch.Tensor):
            raise RuntimeError('Sample features tensor is not torch.Tensor.')
        if features_tensor.shape[-1] != n_features:
            raise RuntimeError(f'Mismatch between input graph '
                               f'({features_tensor.shape[-1]}) and '
                               f'model ({n_features}) number of '
                               f'features for features type: '
                               f'{features_type}')
        # Detach from the autograd graph and move to host memory so that
        # scikit-learn can consume the data as a NumPy array (a plain
        # conversion fails for tensors requiring grad or living on GPU)
        data_scaler.partial_fit(features_tensor.detach().cpu().numpy())
        n_samples += 1
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Without any sample the scaler is never fitted (no mean_/var_)
    if n_samples == 0:
        raise RuntimeError(f'Cannot fit data scaler for {features_type}: '
                           f'no graph samples found in data set.')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Get fitted mean and standard deviation tensors
    mean = torch.tensor(data_scaler.mean_)
    std = torch.sqrt(torch.tensor(data_scaler.var_))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Check features standardization mean tensor
    if mean.dim() != 1 or mean.shape[0] != n_features:
        raise RuntimeError('Features standardization mean tensor is not a '
                           'torch.Tensor(1d) with shape (n_features,).')
    # Check features standardization standard deviation tensor
    if std.dim() != 1 or std.shape[0] != n_features:
        raise RuntimeError('Features standardization standard deviation '
                           'tensor is not a torch.Tensor(1d) with shape '
                           '(n_features,).')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    return mean, std