graphorge.gnn_base_model.predict.prediction.write_prediction_summary_file
- write_prediction_summary_file(predict_subdir, device_type, seed, model_directory, load_model_state, loss_type, loss_kwargs, is_normalized_loss, dataset_file_path, dataset, avg_predict_loss, total_time_sec, avg_time_sample)
Write the summary data file of the model prediction process.
- Parameters:
predict_subdir (str) – Subdirectory where the sample prediction result files are stored.
device_type ({'cpu', 'cuda'}) – Type of device on which torch.Tensor is allocated.
seed (int) – Seed used to initialize the random number generators of Python and other libraries (e.g., NumPy, PyTorch) for all devices to preserve reproducibility. Also sets the workers' seed in PyTorch data loaders.
model_directory (str) – Directory where model is stored.
load_model_state ({'best', 'last', int, None}) – Load the available model state from the model directory. Data scalers are also loaded from the model initialization file.
loss_type ({'mse',}) – Loss function type.
loss_kwargs (dict) – Arguments of the torch.nn._Loss initializer.
is_normalized_loss (bool, default=False) – If True, the sample prediction losses are computed from the normalized data; if False, they are not. Normalization requires that the model features data scalers are fitted.
dataset_file_path (str) – Data set file path, if such a file exists. Only used for output purposes.
dataset (torch.utils.data.Dataset) – Data set.
avg_predict_loss (float) – Average prediction loss per sample.
total_time_sec (int) – Total prediction time in seconds.
avg_time_sample (float) – Average prediction time per sample.
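The parameters above are essentially the fields recorded in the summary file. As a rough illustration only, the sketch below writes them as plain `key: value` lines. The function name `write_summary_sketch`, the file name `summary.dat`, the field labels, and the number formatting are all hypothetical assumptions, not graphorge's actual output format; the data set is only used through `len()`, which assumes a map-style data set.

```python
import os
import tempfile


def write_summary_sketch(predict_subdir, device_type, seed, model_directory,
                         load_model_state, loss_type, loss_kwargs,
                         is_normalized_loss, dataset_file_path, dataset,
                         avg_predict_loss, total_time_sec, avg_time_sample,
                         filename='summary.dat'):
    """Write prediction summary data as 'key: value' lines.

    Hypothetical sketch: file name, labels and formatting are assumptions.
    """
    entries = [
        ('Device type', device_type),
        ('Seed', seed),
        ('Model directory', model_directory),
        ('Loaded model state', load_model_state),
        ('Loss type', loss_type),
        ('Loss parameters', loss_kwargs),
        ('Normalized loss', is_normalized_loss),
        ('Data set file', dataset_file_path),
        # Assumes a map-style data set that implements __len__
        ('Number of samples', len(dataset)),
        ('Avg. prediction loss per sample', f'{avg_predict_loss:.8e}'),
        ('Total prediction time (s)', total_time_sec),
        ('Avg. prediction time per sample', f'{avg_time_sample:.8e}'),
    ]
    path = os.path.join(predict_subdir, filename)
    with open(path, 'w') as summary_file:
        for key, value in entries:
            summary_file.write(f'{key}: {value}\n')
    return path


if __name__ == '__main__':
    # Write a summary into a throwaway directory and print it
    with tempfile.TemporaryDirectory() as tmp_dir:
        summary_path = write_summary_sketch(
            tmp_dir, 'cpu', 0, 'model_dir', 'best', 'mse', {}, False, None,
            list(range(10)), 1.5e-3, 12, 1.2)
        with open(summary_path) as summary_file:
            print(summary_file.read())
```

A plain-text, one-field-per-line layout keeps such a file easy to read by eye and to parse later; the real function may use a different layout.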
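The `is_normalized_loss` flag changes the scale of the reported loss. The following sketch assumes a standardization (mean/standard deviation) scaler and a plain mean squared error. Both are assumptions: the actual graphorge scalers and loss setup may differ, and the helper names (`fit_standard_scaler`, `sample_loss`) are hypothetical. For brevity the scaler is fitted on the same targets; in practice it would be fitted beforehand.

```python
from statistics import fmean, pstdev


def fit_standard_scaler(values):
    """Return (mean, std); a hypothetical stand-in for a fitted data scaler."""
    mean = fmean(values)
    std = pstdev(values)
    # Guard against zero spread so the division below stays defined
    if std == 0:
        std = 1.0
    return mean, std


def mse(pred, target):
    """Mean squared error between two equal-length sequences."""
    return fmean((p - t) ** 2 for p, t in zip(pred, target))


def sample_loss(pred, target, scaler, is_normalized_loss=False):
    """Compute the loss on normalized data if requested, else on raw data."""
    if is_normalized_loss:
        mean, std = scaler
        pred = [(p - mean) / std for p in pred]
        target = [(t - mean) / std for t in target]
    return mse(pred, target)


if __name__ == '__main__':
    target = [10.0, 20.0, 30.0]
    pred = [12.0, 18.0, 33.0]
    scaler = fit_standard_scaler(target)
    print('raw MSE:', sample_loss(pred, target, scaler))
    print('normalized MSE:', sample_loss(pred, target, scaler, True))
```

With a standardization scaler, the normalized MSE is the raw MSE divided by the squared standard deviation, which is why losses computed with and without normalization are not directly comparable.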