hookeai.miscellaneous.pruning.pruning_dataset.perform_model_prediction
- perform_model_prediction(predict_directory, dataset_file_path, model_directory, is_remove_sample_prediction=False, device_type='cpu', is_verbose=False)
Perform prediction with RNN-based model.
- Parameters:
predict_directory (str) – Directory where model prediction results are stored.
dataset_file_path (str) – Testing data set file path.
model_directory (str) – Directory where model is stored.
is_remove_sample_prediction (bool, default=False) – If True, then remove sample prediction files after plots are generated.
device_type ({'cpu', 'cuda'}, default='cpu') – Type of device on which torch.Tensor is allocated.
is_verbose (bool, default=False) – If True, enable verbose output.
- Returns:
predict_subdir (str) – Subdirectory where sample prediction result files are stored.
avg_predict_loss (float) – Average prediction loss per sample. Set to None if ground-truth is not available for all data set samples.
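Example: a minimal usage sketch based only on the signature and return values documented above. The helper names `run_prediction` and `summarize_prediction` are hypothetical and not part of hookeai, and any paths passed in are placeholders to be replaced with real locations. The sketch mainly illustrates that `avg_predict_loss` may be None and should be checked before formatting.

```python
def summarize_prediction(predict_subdir, avg_predict_loss):
    """Format the two documented return values as a one-line summary.

    Hypothetical helper (not part of hookeai).
    """
    # avg_predict_loss is documented as possibly None, so guard before
    # applying a numeric format specifier
    if avg_predict_loss is None:
        loss_str = 'n/a (ground-truth not available)'
    else:
        loss_str = f'{avg_predict_loss:.4e}'
    return f'predictions: {predict_subdir} | avg. loss per sample: {loss_str}'


def run_prediction(predict_directory, dataset_file_path, model_directory):
    """Call perform_model_prediction and return a summary string.

    Hypothetical wrapper; hookeai is imported lazily so that this module
    can be loaded even where the package is not installed.
    """
    from hookeai.miscellaneous.pruning.pruning_dataset import \
        perform_model_prediction

    # Keyword arguments follow the documented signature
    predict_subdir, avg_predict_loss = perform_model_prediction(
        predict_directory,
        dataset_file_path,
        model_directory,
        is_remove_sample_prediction=False,
        device_type='cpu',
        is_verbose=True,
    )
    return summarize_prediction(predict_subdir, avg_predict_loss)
```

Calling `run_prediction` with a prediction directory, a testing data set file path, and a model directory returns the summary string. Pass `device_type='cuda'` inside the wrapper only if a CUDA device is available.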