graphorge.gnn_base_model.train.cross_validation.predict

predict(dataset, model_directory, model=None, predict_directory=None, file_name_pattern=None, load_model_state=None, loss_nature='node_features_out', loss_type='mse', loss_kwargs={}, is_normalized_loss=False, batch_size=1, dataset_file_path=None, device_type='cpu', seed=None, is_verbose=False, tqdm_flavor='default')

Make predictions with a Graph Neural Network model for a given data set.

Parameters:
  • dataset (torch.utils.data.Dataset) – Graph Neural Network graph data set. Each sample corresponds to a torch_geometric.data.Data object describing a homogeneous graph.

  • model_directory (str) – Directory where Graph Neural Network model is stored.

  • model (GNNEPDBaseModel, default=None) – Graph Neural Network model. If None, then model is initialized from the initialization file and the state is loaded from the state file. In both cases the model is set to evaluation mode.

  • predict_directory (str, default=None) – Directory where the model prediction results are stored. If None, then all output files are suppressed.

  • file_name_pattern (str, default=None) – An f-string pattern for the file name used to save prediction results.

  • load_model_state ({'best', 'last', int, None}, default=None) –

    Load available Graph Neural Network model state from the model directory. Options:

    'best' : Model state corresponding to the best available performance

    'last' : Model state corresponding to the highest training epoch

    int : Model state corresponding to the given training epoch

    None : Model default state file

  • loss_nature ({'node_features_out', 'global_features_out'}, default='node_features_out') –

    Loss nature:

    'node_features_out' : Based on node output features

    'global_features_out' : Based on global output features

  • loss_type ({'mse',}, default='mse') –

    Loss function type:

    'mse' : MSE (torch.nn.MSELoss)

  • loss_kwargs (dict, default={}) – Keyword arguments passed to the torch.nn._Loss initializer.

  • is_normalized_loss (bool, default=False) – If True, then the prediction loss of each sample is computed from normalized output data. Normalization of output data requires that the model data scalers are available.

  • batch_size (int, default=1) – Number of samples loaded per batch.

  • dataset_file_path (str, default=None) – Graph Neural Network graph data set file path, if such a file exists. Only used for output purposes.

  • device_type ({'cpu', 'cuda'}, default='cpu') – Type of device on which torch.Tensor is allocated.

  • seed (int, default=None) – Seed used to initialize the random number generators of Python and other libraries (e.g., NumPy, PyTorch) for all devices to preserve reproducibility. Also sets the worker seed in PyTorch data loaders.

  • is_verbose (bool, default=False) – If True, enable verbose output.

  • tqdm_flavor ({'default', 'notebook'}, default='default') – Type of tqdm progress bar to use when is_verbose=True.
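
The seeding behavior described for the seed parameter can be pictured with the sketch below. It is an illustration only, not Graphorge's implementation: set_seed is a hypothetical helper, NumPy and PyTorch are seeded only if they are installed, and the data loader worker seeding mentioned above is not reproduced here.

```python
import random


def set_seed(seed):
    """Seed the available random number generators (hypothetical helper)."""
    # Python's built-in generator
    random.seed(seed)
    # NumPy's global generator, if NumPy is installed
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    # PyTorch generators (CPU and, when available, all CUDA devices)
    try:
        import torch
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
    except ImportError:
        pass


# Re-seeding with the same value reproduces the same random draws
set_seed(0)
first_draw = random.random()
set_seed(0)
second_draw = random.random()
```

With the same seed, first_draw and second_draw are identical, which is the kind of reproducibility the seed parameter is meant to preserve.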

Returns:

  • predict_subdir (str) – Subdirectory where the sample prediction result files are stored.

  • avg_predict_loss (float) – Average prediction loss per sample. Defaults to None if ground-truth is not available for all data set samples.
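
A minimal usage sketch follows, built only from the signature and return values documented above. The import path is taken from the module path at the top of this page. The helper run_prediction, the 'predictions' directory, and all argument values are hypothetical; the data set and the model directory are assumed to exist already and to match what the trained model expects.

```python
# Keyword arguments mirror the documented signature; values are illustrative.
predict_kwargs = dict(
    predict_directory='predictions',   # hypothetical output directory
    load_model_state='best',           # load the best available model state
    loss_nature='node_features_out',
    loss_type='mse',
    is_normalized_loss=False,
    batch_size=1,
    device_type='cpu',
    seed=0,
    is_verbose=True,
)


def run_prediction(dataset, model_directory):
    """Call predict() on an existing data set (hypothetical wrapper)."""
    # Imported lazily so that this sketch can be read and loaded without
    # Graphorge being installed; the call itself requires Graphorge.
    from graphorge.gnn_base_model.train.cross_validation import predict

    predict_subdir, avg_predict_loss = predict(
        dataset, model_directory, **predict_kwargs)
    # avg_predict_loss may be None (see the Returns section above)
    return predict_subdir, avg_predict_loss
```

Leaving model=None, as done here, relies on the documented behavior that the model is initialized from the files stored in model_directory.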