hookeai.model_architectures.rc_base_model.train.training.get_pytorch_loss

get_pytorch_loss(loss_type, **kwargs)

Get PyTorch-based loss function.

Supports both native PyTorch loss functions (e.g., torch.nn.MSELoss) and custom PyTorch-based loss functions.

Parameters:
  • loss_type ({'mse', 'mre'}) –

    Loss function type:

    'mse' : MSE (Mean Squared Error, torch.nn.MSELoss)

    'mre' : MRE (Mean Relative Error, custom)

  • **kwargs – Keyword arguments passed to the PyTorch-based loss function.

Returns:

loss_function – PyTorch-based loss function.

Return type:

torch.nn.Module
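The dispatch described above (a loss type string mapped to a torch.nn.Module instance, with **kwargs forwarded to the loss) can be sketched as follows. This is a minimal illustration, not the library's implementation: get_pytorch_loss_sketch and MeanRelativeErrorLoss are hypothetical names, and the relative-error formula mean(|y_pred - y_true| / (|y_true| + eps)), including the eps guard against zero targets, is an assumption about what the custom MRE loss computes.

```python
import torch


class MeanRelativeErrorLoss(torch.nn.Module):
    """Hypothetical MRE loss: mean(|y_pred - y_true| / (|y_true| + eps)).

    The exact definition used by the library is an assumption here.
    """

    def __init__(self, eps=1e-8):
        super().__init__()
        # Small constant guarding against division by zero-valued targets
        self.eps = eps

    def forward(self, y_pred, y_true):
        # Element-wise relative error, averaged over all entries
        rel_err = torch.abs(y_pred - y_true) / (torch.abs(y_true) + self.eps)
        return torch.mean(rel_err)


def get_pytorch_loss_sketch(loss_type, **kwargs):
    """Hypothetical sketch: map a loss type string to a loss module."""
    if loss_type == 'mse':
        # Native PyTorch loss; kwargs are forwarded (e.g. reduction='sum')
        return torch.nn.MSELoss(**kwargs)
    elif loss_type == 'mre':
        # Custom loss; kwargs are forwarded to its constructor
        return MeanRelativeErrorLoss(**kwargs)
    raise ValueError(f"Unknown loss type: {loss_type!r}")


if __name__ == '__main__':
    y_pred = torch.tensor([1.1, 1.8])
    y_true = torch.tensor([1.0, 2.0])
    mse = get_pytorch_loss_sketch('mse', reduction='mean')
    mre = get_pytorch_loss_sketch('mre')
    print(mse(y_pred, y_true).item())
    print(mre(y_pred, y_true).item())
```

Returning an instantiated torch.nn.Module (rather than a bare function) lets native and custom losses be called interchangeably as loss_function(y_pred, y_true) inside a training loop.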