"""Prediction of Graph Neural Network model.
Functions
---------
predict
Make predictions with Graph Neural Network model for given dataset.
make_predictions_subdir
Create model predictions subdirectory.
save_sample_predictions
Save model prediction results for given sample.
load_sample_predictions
Load model prediction results for given sample.
compute_sample_prediction_loss
Compute loss of sample output features prediction.
seed_worker
Set workers seed in PyTorch data loaders to preserve reproducibility.
write_prediction_summary_file
Write summary data file for model prediction process.
"""
#
# Modules
# =============================================================================
# Standard
import os
import pickle
import random
import re
import time
import datetime
# Third-party
import torch
import torch_geometric
import numpy as np
# Local
from gnn_base_model.model.gnn_model import GNNEPDBaseModel
from gnn_base_model.train.torch_loss import get_pytorch_loss
from ioput.iostandard import make_directory, write_summary_file
#
# Authorship & Credits
# =============================================================================
__author__ = 'Bernardo Ferreira (bernardo_ferreira@brown.edu)'
__credits__ = ['Bernardo Ferreira', ]
__status__ = 'Planning'
# =============================================================================
#
# =============================================================================
[docs]
def predict(dataset, model_directory, model=None, predict_directory=None,
file_name_pattern=None, load_model_state=None,
loss_nature='node_features_out', loss_type='mse', loss_kwargs={},
is_normalized_loss=False, batch_size=1, dataset_file_path=None,
device_type='cpu', seed=None, is_verbose=False,
tqdm_flavor='default'):
"""Make predictions with Graph Neural Network model for given dataset.
Parameters
----------
dataset : torch.utils.data.Dataset
Graph Neural Network graph data set. Each sample corresponds to a
torch_geometric.data.Data object describing a homogeneous graph.
model_directory : str
Directory where Graph Neural Network model is stored.
model : GNNEPDBaseModel, default=None
Graph Neural Network model. If None, then model is initialized
from the initialization file and the state is loaded from the state
file. In both cases the model is set to evaluation mode.
predict_directory : str, default=None
Directory where model predictions results are stored. If None, then
all output files are supressed.
file_name_pattern : str, default=None
A f-string pattern for the file name used to save prediction results.
load_model_state : {'best', 'last', int, None}, default=None
Load available Graph Neural Network model state from the model
directory. Options:
'best' : Model state corresponding to best performance available
'last' : Model state corresponding to highest training epoch
int : Model state corresponding to given training epoch
None : Model default state file
loss_nature : {'node_features_out', 'global_features_out'}, \
default='node_features_out'
Loss nature:
'node_features_out' : Based on node output features
'global_features_out' : Based on global output features
loss_type : {'mse',}, default='mse'
Loss function type:
'mse' : MSE (torch.nn.MSELoss)
loss_kwargs : dict, default={}
Arguments of torch.nn._Loss initializer.
is_normalized_loss : bool, default=False
If True, then samples prediction loss are computed from normalized
output data, False otherwise. Normalization of output data requires
that model data scalers are available.
batch_size : int, default=1
Number of samples loaded per batch.
dataset_file_path : str, default=None
Graph Neural Network graph data set file path if such file exists. Only
used for output purposes.
device_type : {'cpu', 'cuda'}, default='cpu'
Type of device on which torch.Tensor is allocated.
seed : int, default=None
Seed used to initialize the random number generators of Python and
other libraries (e.g., NumPy, PyTorch) for all devices to preserve
reproducibility. Does also set workers seed in PyTorch data loaders.
is_verbose : bool, default=False
If True, enable verbose output.
tqdm_flavor : {'default', 'notebook'}, default='default'
Type of tqdm progress bar to use when is_verbose=True.
Returns
-------
predict_subdir : str
Subdirectory where samples predictions results files are stored.
avg_predict_loss : float
Average prediction loss per sample. Defaults to None if ground-truth is
not available for all data set samples.
"""
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Import tqdm
if tqdm_flavor == 'notebook':
from tqdm.notebook import tqdm
else:
from tqdm import tqdm
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Set random number generators initialization for reproducibility
if isinstance(seed, int):
random.seed(seed)
np.random.seed(seed)
generator = torch.Generator().manual_seed(seed)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Set device
device = torch.device(device_type)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
start_time_sec = time.time()
if is_verbose:
print('\nGraph Neural Network model prediction'
'\n-------------------------------------')
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Check model directory
if not os.path.exists(model_directory):
raise RuntimeError('The Graph Neural Network model directory has not '
'been found:\n\n' + model_directory)
# Check prediction directory
if predict_directory is not None and not os.path.exists(predict_directory):
raise RuntimeError('The Graph Neural Network model prediction '
'directory has not been found:\n\n'
+ predict_directory)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Initialize model and load model state if not provided
if model is None:
if is_verbose:
print('\n> Loading Graph Neural Network model...')
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Initialize Graph Neural Network model
model = GNNEPDBaseModel.init_model_from_file(model_directory)
# Set model device
model.set_device(device_type)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Load Graph Neural Network model state
_ = model.load_model_state(load_model_state=load_model_state,
is_remove_posterior=False)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Get model input and output features normalization
is_model_in_normalized = model.is_model_in_normalized
is_model_out_normalized = model.is_model_out_normalized
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Move model to device
model.to(device=device)
# Set model in evaluation mode
model.eval()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Create model predictions subdirectory for current prediction process
predict_subdir = None
if predict_directory is not None:
predict_subdir = make_predictions_subdir(predict_directory)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Set data loader
if isinstance(seed, int):
data_loader = torch_geometric.loader.dataloader.DataLoader(
dataset=dataset, batch_size=batch_size,
worker_init_fn=seed_worker, generator=generator, shuffle=False)
else:
data_loader = torch_geometric.loader.dataloader.DataLoader(
dataset=dataset, batch_size=batch_size, shuffle=False)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Initialize loss function
loss_function = get_pytorch_loss(loss_type, **loss_kwargs)
# Initialize samples prediction loss
loss_samples = []
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
if is_verbose:
print('\n\n> Starting prediction process...\n')
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Set context manager to avoid creation of computation graphs during the
# model evaluation (forward propagation)
with torch.no_grad():
# Loop over graph samples
for i, pyg_graph in enumerate(tqdm(data_loader, mininterval=1,
maxinterval=60, miniters=0,
desc='> Predictions: ',
disable=not is_verbose,
unit=' sample')):
# Move sample to device
pyg_graph.to(device)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Get batch node assignment vector
if batch_size > 1:
batch_vector = pyg_graph.batch
else:
batch_vector = None
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Get input features from input graph
node_features_in, edge_features_in, global_features_in, \
edges_indexes = model.get_input_features_from_graph(
pyg_graph, is_normalized=is_model_in_normalized)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Get node output features ground-truth
node_targets, edge_targets, global_targets = \
model.get_output_features_from_graph(
pyg_graph, is_normalized=False)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Initialize sample results
results = {}
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Get metadata
metadata = model.get_metadata_from_graph(pyg_graph)
# Store metadata
results['metadata'] = {}
if isinstance(metadata, dict):
# Iterate over metadata items
for key, value in metadata.items():
# Process tensor metadata
if isinstance(value, torch.Tensor):
# If there is only one element, store it as a scalar
if value.numel() == 1:
results['metadata'][key] = (
value.detach().cpu().item())
# Otherwise, store it as a numpy array
else:
results['metadata'][key] = (
value.detach().cpu().numpy())
# Process non-tensor metadata
else:
results['metadata'][key] = value
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Compute output features predictions (forward propagation)
if loss_nature == 'node_features_out':
# Compute node output features
features_out, _, _ = model(
node_features_in=node_features_in,
edge_features_in=edge_features_in,
global_features_in=global_features_in,
edges_indexes=edges_indexes, batch_vector=batch_vector)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Denormalize node output features
if is_model_out_normalized:
# Get model data scaler
features_out = model.data_scaler_transform(
tensor=features_out, features_type='node_features_out',
mode='denormalize')
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Get sample node output features ground-truth
# (None if not available)
targets = node_targets
# Store sample results
results['node_features_out'] = features_out.detach().cpu()
results['node_targets'] = None
if targets is not None:
results['node_targets'] = targets.detach().cpu()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
elif loss_nature == 'global_features_out':
# Compute global output features
_, _, features_out = model(
node_features_in=node_features_in,
edge_features_in=edge_features_in,
global_features_in=global_features_in,
edges_indexes=edges_indexes, batch_vector=batch_vector)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Denormalize global output features
if is_model_out_normalized:
# Get model data scaler
features_out = model.data_scaler_transform(
tensor=features_out,
features_type='global_features_out',
mode='denormalize')
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Get sample global output features ground-truth
# (None if not available)
targets = global_targets
# Store sample results
results['global_features_out'] = features_out.detach().cpu()
results['global_targets'] = None
if targets is not None:
results['global_targets'] = targets.detach().cpu()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
else:
raise RuntimeError('Unknown loss nature.')
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Compute sample output features prediction loss
loss = compute_sample_prediction_loss(
model, loss_nature, loss_function, features_out, targets,
is_normalized_loss=is_normalized_loss)
# Store prediction loss data
results['prediction_loss_data'] = \
(loss_nature, loss_type, loss, is_normalized_loss)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Assemble sample prediction loss if ground-truth is available
if loss is not None:
loss_samples.append(loss.detach().cpu())
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Save sample predictions results
if predict_directory is not None:
save_sample_predictions(predictions_dir=predict_subdir,
prediction_id=i,
sample_results=results,
file_name_pattern=file_name_pattern)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
if is_verbose:
print('\n> Finished prediction process!\n')
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Compute average prediction loss per sample
avg_predict_loss = None
if isinstance(loss_samples, list):
avg_predict_loss = np.mean(loss_samples)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
if is_verbose:
# Set average prediction loss output format
if avg_predict_loss:
loss_str = (f'{avg_predict_loss:.8e} | {loss_type}')
if is_normalized_loss:
loss_str += ', normalized'
else:
loss_str = 'Ground-truth not available'
# Display average loss
print('\n> Avg. prediction loss per sample: '
+ loss_str)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Compute total prediction time and average prediction time per sample
total_time_sec = time.time() - start_time_sec
if len(dataset) > 0:
avg_time_sample = total_time_sec/len(dataset)
else:
avg_time_sample = float('nan')
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
if is_verbose:
print(f'\n> Prediction results directory: {predict_subdir}')
print(f'\n> Total prediction time: '
f'{str(datetime.timedelta(seconds=int(total_time_sec)))} | '
f'Avg. prediction time per sample: '
f'{str(datetime.timedelta(seconds=int(avg_time_sample)))}\n')
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Write summary data file for model prediction process
if predict_directory is not None:
write_prediction_summary_file(
predict_subdir, device_type, seed, model_directory,
load_model_state, loss_type, loss_kwargs, is_normalized_loss,
dataset_file_path, dataset, avg_predict_loss, total_time_sec,
avg_time_sample)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
return predict_subdir, avg_predict_loss
# =============================================================================
def make_predictions_subdir(predict_directory):
"""Create model predictions subdirectory.
Parameters
----------
predict_directory : str
Directory where model predictions results are stored.
Returns
-------
predict_subdir : str
Subdirectory where samples predictions results files are stored.
"""
# Check prediction directory
if not os.path.exists(predict_directory):
raise RuntimeError('The model prediction directory has not been '
'found:\n\n' + predict_directory)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Set predictions subdirectory path
predict_subdir = os.path.join(predict_directory, 'prediction_set_0')
while os.path.exists(predict_subdir):
predict_subdir = os.path.join(
predict_directory,
'prediction_set_' + str(int(predict_subdir.split('_')[-1]) + 1))
# Create model predictions subdirectory
predict_subdir = make_directory(predict_subdir, is_overwrite=False)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
return predict_subdir
# =============================================================================
def save_sample_predictions(predictions_dir, prediction_id, sample_results,
file_name_pattern = None):
"""Save model prediction results for given sample.
Parameters
----------
predictions_dir : str
Directory where sample prediction results are stored.
prediction_id : int
Prediction ID appended to the prediction sample results file name.
sample_results : dict
Sample prediction results.
file_name_pattern: str, default=None
A f-string pattern for the file name. The pattern will be evaluated
when saving the predictions and has access to `prediction_id` and
all the `sample_results['metadata']` content. If None, the pattern
``'prediction_sample_{prediction_id}'`` is used.
"""
# Check prediction results directory
if not os.path.exists(predictions_dir):
raise RuntimeError('The prediction results directory has not been '
'found:\n\n' + predictions_dir)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Generate file name
if file_name_pattern is None:
file_name_pattern = 'prediction_sample_{prediction_id}'
file_name = file_name_pattern.format(**sample_results['metadata'],
prediction_id=prediction_id)
# Set sample prediction results file path
sample_path = os.path.join(predictions_dir, file_name + '.pkl')
# Save sample prediction results
with open(sample_path, 'wb') as sample_file:
pickle.dump(sample_results, sample_file)
# =============================================================================
def load_sample_predictions(sample_prediction_path):
"""Load model prediction results for given sample.
Parameters
----------
sample_prediction_path : str
Sample prediction results file path.
Returns
-------
sample_results : dict
Sample prediction results.
"""
# Check sample prediction results file
if not os.path.isfile(sample_prediction_path):
raise RuntimeError('Sample prediction results file has not been '
'found:\n\n' + sample_prediction_path)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Load sample prediction results
with open(sample_prediction_path, 'rb') as sample_prediction_file:
sample_results = pickle.load(sample_prediction_file)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
return sample_results
# =============================================================================
def compute_sample_prediction_loss(model, loss_nature, loss_function,
features_out, targets,
is_normalized_loss=False):
"""Compute loss of sample output features prediction.
Assumes that provided output features and targets are denormalized.
Parameters
----------
model : GNNEPDBaseModel
Graph Neural Network model.
loss_nature : {'node_features_out', 'global_features_out'}
Loss nature.
loss_function : torch.nn._Loss
PyTorch loss function.
features_out : torch.Tensor
Predicted output features stored as a torch.Tensor(2d).
targets : {torch.Tensor, None}
Output features ground-truth stored as a torch.Tensor(2d).
is_normalized_loss : bool, default=False
If True, then samples prediction loss are computed from normalized
output data, False otherwise. Normalization of output data requires
that model data scalers are available.
Returns
-------
loss : {float, None}
Loss of sample output features prediction. Set to None if output
features ground-truth is not available.
"""
# Check if output features ground-truth is available
is_ground_truth_available = targets is not None
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Compute sample loss
if is_ground_truth_available:
# Normalize output features
if is_normalized_loss:
# Get model data scaler
if loss_nature == 'node_features_out':
scaler = model.get_fitted_data_scaler('node_features_out')
elif loss_nature == 'global_features_out':
scaler = model.get_fitted_data_scaler('global_features_out')
else:
raise RuntimeError('Unknown loss nature.')
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Get normalized output features predictions
features_out = scaler.transform(features_out)
# Get normalized output features ground-truth
targets = scaler.transform(targets)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Compute sample loss
loss = loss_function(features_out, targets)
else:
# Set sample loss to None if ground-truth is not available
loss = None
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
return loss
# =============================================================================
def seed_worker(worker_id):
"""Set workers seed in PyTorch data loaders to preserve reproducibility.
Taken from: https://pytorch.org/docs/stable/notes/randomness.html
Parameters
----------
worker_id : int
Worker ID.
"""
worker_seed = torch.initial_seed() % 2**32
np.random.seed(worker_seed)
random.seed(worker_seed)
# =============================================================================
def write_prediction_summary_file(
predict_subdir, device_type, seed, model_directory, load_model_state,
loss_type, loss_kwargs, is_normalized_loss, dataset_file_path, dataset,
avg_predict_loss, total_time_sec, avg_time_sample):
"""Write summary data file for model prediction process.
Parameters
----------
predict_subdir : str
Subdirectory where samples predictions results files are stored.
device_type : {'cpu', 'cuda'}
Type of device on which torch.Tensor is allocated.
seed : int
Seed used to initialize the random number generators of Python and
other libraries (e.g., NumPy, PyTorch) for all devices to preserve
reproducibility. Does also set workers seed in PyTorch data loaders.
model_directory : str
Directory where model is stored.
load_model_state : {'best', 'last', int, None}
Load availabl model state from the model directory. Data scalers are
also loaded from model initialization file.
loss_type : {'mse',}
Loss function type.
loss_kwargs : dict
Arguments of torch.nn._Loss initializer.
is_normalized_loss : bool, default=False
If True, then samples prediction loss are computed from the normalized
data, False otherwise. Normalization requires that model features data
scalers are fitted.
dataset_file_path : str
Data set file path if such file exists. Only used for output purposes.
dataset : torch.utils.data.Dataset
Data set.
avg_predict_loss : float
Average prediction loss per sample.
total_time_sec : int
Total prediction time in seconds.
avg_time_sample : float
Average prediction time per sample.
"""
# Set summary data
summary_data = {}
summary_data['device_type'] = device_type
summary_data['seed'] = seed
summary_data['model_directory'] = model_directory
summary_data['load_model_state'] = load_model_state
summary_data['loss_type'] = loss_type
summary_data['loss_kwargs'] = loss_kwargs if loss_kwargs else None
summary_data['is_normalized_loss'] = is_normalized_loss
summary_data['Prediction data set file'] = \
dataset_file_path if dataset_file_path else None
summary_data['Prediction data set size'] = len(dataset)
summary_data['Avg. prediction loss per sample'] = \
f'{avg_predict_loss:.8e}' if avg_predict_loss else None
summary_data['Total prediction time'] = \
str(datetime.timedelta(seconds=int(total_time_sec)))
summary_data['Avg. prediction time per sample'] = \
str(datetime.timedelta(seconds=int(avg_time_sample)))
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Write summary file
write_summary_file(
summary_directory=predict_subdir,
summary_title='Summary: Model prediction',
**summary_data)