graphorge.gnn_base_model.train.training.train_model

train_model(n_max_epochs, dataset, model_init_args, lr_init, opt_algorithm='adam', lr_scheduler_type=None, lr_scheduler_kwargs={}, loss_nature='node_features_out', loss_type='mse', loss_kwargs={}, batch_size=1, is_sampler_shuffle=False, data_loader_kwargs={}, is_early_stopping=False, early_stopping_kwargs={}, load_model_state=None, save_every=None, save_loss_every=None, dataset_file_path=None, device_type='cpu', seed=None, is_verbose=False, tqdm_flavor='default')

Train a Graph Neural Network model.

Parameters:
  • n_max_epochs (int) – Maximum number of training epochs.

  • dataset (torch.utils.data.Dataset) – Graph Neural Network graph data set. Each sample corresponds to a torch_geometric.data.Data object describing a homogeneous graph.

  • model_init_args (dict) – Graph Neural Network model class initialization parameters (see class GNNEPDBaseModel).

  • lr_init (float) – Initial value of the optimizer learning rate. Used as a constant learning rate if no learning rate scheduler is specified (lr_scheduler_type=None).

  • opt_algorithm ({'adam',}, default='adam') –

    Optimization algorithm:

    'adam' : Adam (torch.optim.Adam)

  • lr_scheduler_type ({'steplr', 'explr', 'linlr'}, default=None) –

    Type of learning rate scheduler:

    'steplr' : Step-based decay (torch.optim.lr_scheduler.StepLR)

    'explr' : Exponential decay (torch.optim.lr_scheduler.ExponentialLR)

    'linlr' : Linear decay (torch.optim.lr_scheduler.LinearLR)

  • lr_scheduler_kwargs (dict, default={}) – Keyword arguments of the learning rate scheduler initializer (torch.optim.lr_scheduler.LRScheduler).

  • loss_nature ({'node_features_out', 'edge_features_out', 'global_features_out'}, default='node_features_out') –

    Loss nature:

    'node_features_out' : Based on node output features

    'edge_features_out' : Based on edge output features

    'global_features_out' : Based on global output features

  • loss_type ({'mse',}, default='mse') –

    Loss function type:

    'mse' : MSE (torch.nn.MSELoss)

  • loss_kwargs (dict, default={}) – Keyword arguments of the loss function initializer (torch.nn._Loss).

  • batch_size (int, default=1) – Number of samples loaded per batch.

  • is_sampler_shuffle (bool, default=False) – If True, shuffles data set samples at every epoch.

  • data_loader_kwargs (dict, default={}) – Additional arguments for torch_geometric.loader.dataloader.DataLoader.

  • is_early_stopping (bool, default=False) – If True, the training process is halted when the early stopping criterion is triggered. By default, 20% of the training data set is set aside for the underlying validation procedure.

  • early_stopping_kwargs (dict, default={}) – Early stopping criterion parameters (key: parameter name as str; item: parameter value).

  • load_model_state ({'best', 'last', 'init', int, None}, default=None) –

    Load an available GNN-based model state from the model directory. Data scalers are also loaded from the model initialization file. Options:

    'best' : Model state corresponding to the best performance available

    'last' : Model state corresponding to the highest training epoch

    int : Model state corresponding to the given training epoch

    'init' : Model state corresponding to the initial state

    None : Model default state file

  • save_every (int, default=None) – Save the Graph Neural Network model every save_every epochs. If None, only the last-epoch and best-performance states are saved.

  • save_loss_every (int, default=None) – Save the loss history every save_loss_every epochs. If None, the loss history is saved only after the last epoch.

  • dataset_file_path (str, default=None) – File path of the Graph Neural Network graph data set, if such a file exists. Only used for output purposes.

  • device_type ({'cpu', 'cuda'}, default='cpu') – Type of device on which torch.Tensor is allocated.

  • seed (int, default=None) – Seed used to initialize the random number generators of Python and other libraries (e.g., NumPy, PyTorch) for all devices to preserve reproducibility. Also sets the worker seeds of the PyTorch data loaders.

  • is_verbose (bool, default=False) – If True, enable verbose output.

  • tqdm_flavor ({'default', 'notebook'}, default='default') – Type of tqdm progress bar to use when is_verbose=True.
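The three lr_scheduler_type options correspond to standard PyTorch schedulers whose learning rate has a simple closed form. The pure-Python sketch below evaluates that closed form for a given epoch, assuming the scheduler is stepped once per epoch (not stated in this reference). lr_at_epoch is a hypothetical helper, not part of graphorge; its keyword names (step_size, gamma, start_factor, end_factor, total_iters) and defaults are those of the corresponding torch.optim.lr_scheduler classes, i.e. presumably what would be passed through lr_scheduler_kwargs.

```python
def lr_at_epoch(lr_init, epoch, lr_scheduler_type=None, **kwargs):
    """Closed-form learning rate after `epoch` scheduler steps.

    Hypothetical helper; assumes one scheduler step per training epoch.
    """
    if lr_scheduler_type is None:
        # No scheduler: lr_init is used as a constant learning rate
        return lr_init
    if lr_scheduler_type == 'steplr':
        # StepLR: multiply by gamma once every step_size epochs
        step_size = kwargs['step_size']
        gamma = kwargs.get('gamma', 0.1)
        return lr_init * gamma ** (epoch // step_size)
    if lr_scheduler_type == 'explr':
        # ExponentialLR: multiply by gamma at every epoch
        gamma = kwargs['gamma']
        return lr_init * gamma ** epoch
    if lr_scheduler_type == 'linlr':
        # LinearLR: scale factor moves linearly from start_factor to
        # end_factor over total_iters epochs, then stays at end_factor
        start = kwargs.get('start_factor', 1.0 / 3.0)
        end = kwargs.get('end_factor', 1.0)
        total = kwargs.get('total_iters', 5)
        t = min(epoch, total)
        return lr_init * (start + (end - start) * t / total)
    raise ValueError(f'Unknown learning rate scheduler type: {lr_scheduler_type}')
```

For example, lr_at_epoch(1e-3, 25, 'steplr', step_size=10, gamma=0.5) evaluates to 1e-3 * 0.5**2: after 25 epochs the rate has been halved twice.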
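With batch_size and is_sampler_shuffle, each epoch visits every sample exactly once in mini-batches, optionally in a fresh random order. The sketch below mimics that sampling behaviour in plain Python for illustration only; epoch_batches and its rng argument are hypothetical and do not reflect how torch_geometric's DataLoader is implemented. Passing a seeded random.Random instance makes the sample order reproducible, which is the kind of guarantee the seed parameter aims at.

```python
import random


def epoch_batches(n_samples, batch_size=1, is_sampler_shuffle=False, rng=None):
    """Split the sample indices 0..n_samples-1 into mini-batches for one epoch."""
    indices = list(range(n_samples))
    if is_sampler_shuffle:
        # A new permutation is drawn on every call, i.e. once per epoch
        shuffle = rng.shuffle if rng is not None else random.shuffle
        shuffle(indices)
    # The last batch is smaller when n_samples is not a multiple of batch_size
    # (PyTorch data loaders behave the same way with their default drop_last=False)
    return [indices[i:i + batch_size] for i in range(0, n_samples, batch_size)]
```

In a training loop, one would create rng = random.Random(seed) once and call epoch_batches(len(dataset), batch_size, True, rng) at the start of every epoch.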
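loss_nature selects which model output enters the loss, and loss_type='mse' averages the squared errors over all elements, as torch.nn.MSELoss does with its default reduction='mean'. The sketch below illustrates this with nested lists instead of tensors. The dictionary layout of outputs and targets, as well as the names mse and compute_loss, are hypothetical assumptions and not the data structures graphorge uses internally.

```python
LOSS_NATURES = ('node_features_out', 'edge_features_out', 'global_features_out')


def mse(prediction, target):
    """Mean squared error over all elements of two equally shaped 2D lists."""
    if len(prediction) != len(target):
        raise ValueError('prediction and target must have the same number of rows')
    squared_errors = []
    for p_row, t_row in zip(prediction, target):
        if len(p_row) != len(t_row):
            raise ValueError('prediction and target rows must have the same length')
        squared_errors.extend((p - t) ** 2 for p, t in zip(p_row, t_row))
    # 'mean' reduction: average over every element, not per row
    return sum(squared_errors) / len(squared_errors)


def compute_loss(outputs, targets, loss_nature='node_features_out'):
    """Pick the output feature matrix named by loss_nature and score it with MSE."""
    if loss_nature not in LOSS_NATURES:
        raise ValueError(f'Unknown loss nature: {loss_nature}')
    return mse(outputs[loss_nature], targets[loss_nature])
```

Only the selected feature matrix contributes to the loss; with loss_nature='edge_features_out', for instance, errors in the node outputs are ignored.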

Returns:

  • model (torch.nn.Module) – Graph Neural Network model.

  • best_loss (float) – Best loss achieved during the training process.

  • best_training_epoch (int) – Training epoch at which the best loss was achieved.
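best_loss and best_training_epoch are the minimum of the per-epoch loss history and the epoch at which it occurred. The sketch below shows one common way to obtain them, together with a patience-based early stopping rule and a shuffled 20% validation split matching the default fraction mentioned for is_early_stopping. The patience rule, the shuffled split and the helper names are assumptions for illustration; the actual criterion is configured through early_stopping_kwargs and is not specified in this reference.

```python
import random


def split_train_validation(n_samples, validation_fraction=0.2, seed=None):
    """Shuffle sample indices and hold out a fraction of them for validation."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    n_validation = round(validation_fraction * n_samples)
    # Returns (training indices, validation indices)
    return indices[n_validation:], indices[:n_validation]


def track_best_loss(epoch_losses, patience=None):
    """Return (best_loss, best_epoch, last_epoch) for a sequence of per-epoch losses.

    If patience is given, training stops once the loss has failed to improve
    on the best value for `patience` consecutive epochs (hypothetical rule).
    """
    best_loss, best_epoch, last_epoch = float('inf'), -1, -1
    for epoch, loss in enumerate(epoch_losses):
        last_epoch = epoch
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
        elif patience is not None and epoch - best_epoch >= patience:
            break  # early stopping criterion triggered
    return best_loss, best_epoch, last_epoch
```

With the loss history [1.0, 0.5, 0.6, 0.7, 0.4] and patience=2, training stops after epoch 3 and reports a best loss of 0.5 at epoch 1, whereas without early stopping the best loss would be 0.4 at epoch 4.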