hookeai.miscellaneous.pruning.pruning_dataset.perform_model_prediction

perform_model_prediction(predict_directory, dataset_file_path, model_directory, is_remove_sample_prediction=False, device_type='cpu', is_verbose=False)

Perform prediction with an RNN-based model.

Parameters:
  • predict_directory (str) – Directory where model prediction results are stored.

  • dataset_file_path (str) – Testing data set file path.

  • model_directory (str) – Directory where the model is stored.

  • is_remove_sample_prediction (bool, default=False) – If True, then remove sample prediction files after plots are generated.

  • device_type ({'cpu', 'cuda'}, default='cpu') – Type of device on which torch.Tensor is allocated.

  • is_verbose (bool, default=False) – If True, enable verbose output.

Returns:

  • predict_subdir (str) – Subdirectory where sample prediction result files are stored.

  • avg_predict_loss (float) – Average prediction loss per sample. Set to None if ground truth is not available for all data set samples.
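
Example:

The following is a minimal sketch of how the call and its return values might be handled, not the package's own usage pattern. It assumes the two documented return values come back as a tuple in the order listed above. The helpers run_prediction and summarize_prediction are hypothetical names introduced here for illustration. The prediction function is passed in as an argument so that the sketch can be read and run without the package installed.

```python
# In real use (assuming the package is importable), the function would come from:
#   from hookeai.miscellaneous.pruning.pruning_dataset import \
#       perform_model_prediction


def summarize_prediction(predict_subdir, avg_predict_loss):
    """Format the two return values (hypothetical helper).

    avg_predict_loss may be None when ground truth is not available,
    so it is checked before being formatted as a number.
    """
    if avg_predict_loss is None:
        loss_text = 'not available'
    else:
        loss_text = f'{avg_predict_loss:.4e}'
    return (f'Prediction files: {predict_subdir} | '
            f'Avg. loss per sample: {loss_text}')


def run_prediction(predict_fn, predict_directory, dataset_file_path,
                   model_directory, device_type='cpu'):
    """Call a prediction function and summarize its output (hypothetical).

    predict_fn is expected to have the signature documented above and,
    by assumption, to return (predict_subdir, avg_predict_loss).
    """
    # Only the two documented device types are accepted
    if device_type not in ('cpu', 'cuda'):
        raise ValueError("device_type must be 'cpu' or 'cuda'.")
    # Keep sample prediction files and suppress verbose output
    predict_subdir, avg_predict_loss = predict_fn(
        predict_directory, dataset_file_path, model_directory,
        is_remove_sample_prediction=False, device_type=device_type,
        is_verbose=False)
    return summarize_prediction(predict_subdir, avg_predict_loss)
```

With the package available, perform_model_prediction would be passed as predict_fn. Checking avg_predict_loss against None before formatting avoids a TypeError when the testing data set has no ground truth.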