hookeai.model_architectures.rnn_base_model.predict.prediction.get_pytorch_loss
- get_pytorch_loss(loss_type, **kwargs)
Get PyTorch-based loss function.
Includes both native and custom PyTorch-based loss functions.
- Parameters:
loss_type ({'mse', 'mean_relative_error'}) –
Loss function type:
'mse' : MSE (torch.nn.MSELoss)
'mre' : MRE (Mean Relative Error, custom)
**kwargs – Keyword arguments passed to the PyTorch-based loss function.
- Returns:
loss_function – PyTorch-based loss function.
- Return type:
torch.nn.Module
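The dispatch described above can be illustrated with a minimal, self-contained sketch. It is not the hookeai implementation: the names `sketch_get_loss` and `MeanRelativeErrorSketch` are hypothetical, and the MRE formula used here, mean(|input - target| / (|target| + eps)), is an assumption about what the custom loss computes. Because the documentation lists both 'mre' and 'mean_relative_error' as the key for this loss, the sketch accepts either.

```python
import torch


class MeanRelativeErrorSketch(torch.nn.Module):
    """Hypothetical MRE loss (assumed formula, not taken from hookeai)."""

    def __init__(self, eps=1e-8):
        super().__init__()
        # Small constant guarding against division by zero (assumption)
        self.eps = eps

    def forward(self, input, target):
        # Mean of the element-wise relative absolute error
        relative_error = torch.abs(input - target) / (torch.abs(target) + self.eps)
        return torch.mean(relative_error)


def sketch_get_loss(loss_type, **kwargs):
    """Illustrative dispatch from a loss type string to a torch.nn.Module."""
    if loss_type == 'mse':
        # Native PyTorch loss; kwargs are forwarded to its constructor
        return torch.nn.MSELoss(**kwargs)
    if loss_type in ('mre', 'mean_relative_error'):
        # Custom loss; kwargs are forwarded to the hypothetical module
        return MeanRelativeErrorSketch(**kwargs)
    raise RuntimeError(f'Unknown loss function type: {loss_type}')


prediction = torch.tensor([1.0, 2.0, 4.0])
target = torch.tensor([1.0, 2.0, 2.0])

mse = sketch_get_loss('mse', reduction='mean')
mre = sketch_get_loss('mre')

mse_value = mse(prediction, target)
mre_value = mre(prediction, target)
print(mse_value.item(), mre_value.item())
```

In this sketch, both branches return a `torch.nn.Module`, so the caller can evaluate either loss with the same `loss_function(prediction, target)` call regardless of whether it is native or custom.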