graphorge.gnn_base_model.train.training.get_pytorch_loss

get_pytorch_loss(loss_type, **kwargs)

Get PyTorch loss function.

Parameters:
  • loss_type ({'mse',}) –

    Loss function type:

    'mse' : MSE (torch.nn.MSELoss)

  • **kwargs – Keyword arguments passed to the initializer of the selected PyTorch loss class (for 'mse', torch.nn.MSELoss).

Returns:

loss_function – PyTorch loss function.

Return type:

torch.nn.modules.loss._Loss
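
Example:

The following is a minimal sketch of how a function with this interface might behave, not the Graphorge implementation. The name get_pytorch_loss_sketch, the lookup dictionary, and the error raised for an unknown type are all assumptions made for illustration; only the 'mse' option and the forwarding of keyword arguments to the loss initializer come from the description above.

```python
import torch


def get_pytorch_loss_sketch(loss_type, **kwargs):
    """Hypothetical stand-in for get_pytorch_loss (not the Graphorge source)."""
    # Map each supported loss type to its PyTorch loss class
    available = {'mse': torch.nn.MSELoss}
    if loss_type not in available:
        # Assumed behavior: reject loss types that are not supported
        raise RuntimeError(f'Unknown loss function type: {loss_type!r}')
    # Forward the keyword arguments to the loss class initializer
    return available[loss_type](**kwargs)


# Build an MSE loss that sums the squared errors instead of averaging them
loss_function = get_pytorch_loss_sketch('mse', reduction='sum')

predictions = torch.tensor([1.0, 2.0, 3.0])
targets = torch.tensor([1.0, 0.0, 5.0])
loss = loss_function(predictions, targets)
```

Here reduction is a standard torch.nn.MSELoss initializer argument. With the real function, the equivalent call would presumably be get_pytorch_loss('mse', reduction='sum').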