graphorge.gnn_base_model.train.cross_validation.kfold_cross_validation
- kfold_cross_validation(cross_validation_dir, n_fold, n_max_epochs, dataset, model_init_args, lr_init, opt_algorithm='adam', lr_scheduler_type='steplr', lr_scheduler_kwargs={}, loss_nature='node_features_out', loss_type='mse', loss_kwargs={}, batch_size=1, is_sampler_shuffle=False, is_early_stopping=False, early_stopping_kwargs={}, dataset_file_path=None, device_type='cpu', is_verbose=False)
k-fold cross-validation of a Graph Neural Network model.
The data set is split into k = n_fold consecutive folds. The first n_samples % n_fold folds have size n_samples // n_fold + 1, and the remaining folds have size n_samples // n_fold, where n_samples is the number of samples. Each fold is then used once as the validation set, while the remaining k - 1 folds form the training set.
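The fold-size rule above can be sketched in plain Python. This is an illustrative helper (the name consecutive_kfold_indices is hypothetical, not part of graphorge); it follows the same sizing rule as scikit-learn's KFold with shuffle=False, but it is not claimed to be the function's actual implementation.

```python
def consecutive_kfold_indices(n_samples, n_fold):
    """Yield (train_indices, val_indices) for n_fold consecutive folds.

    Hypothetical helper: the first n_samples % n_fold folds get one extra
    sample, the remaining folds get n_samples // n_fold samples.
    """
    if not 2 <= n_fold <= n_samples:
        raise ValueError("n_fold must satisfy 2 <= n_fold <= n_samples")
    base, extra = divmod(n_samples, n_fold)
    start = 0
    for i in range(n_fold):
        # Folds 0..extra-1 hold base + 1 samples, the others hold base samples
        size = base + 1 if i < extra else base
        stop = start + size
        val = list(range(start, stop))
        # Every sample outside the current fold goes to the training set
        train = list(range(0, start)) + list(range(stop, n_samples))
        yield train, val
        start = stop


for train, val in consecutive_kfold_indices(10, 3):
    print(len(train), val)
```

With 10 samples and 3 folds, the first fold holds 4 samples and the other two hold 3 each, so every sample is used for validation exactly once.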
- Parameters:
cross_validation_dir (dir) – Directory where cross-validation process data is stored.
n_fold (int) – Number of folds into which the data set is split to perform cross-validation.
n_max_epochs (int) – Maximum number of training epochs.
dataset (torch.utils.data.Dataset) – Graph Neural Network graph data set. Each sample corresponds to a torch_geometric.data.Data object describing a homogeneous graph.
model_init_args (dict) – Graph Neural Network model class initialization parameters (see class GNNEPDBaseModel).
lr_init (float) – Initial value of the optimizer learning rate. Used as a constant learning rate if no learning rate scheduler is specified (lr_scheduler_type=None).
opt_algorithm ({'adam',}, default='adam') –
Optimization algorithm:
'adam' : Adam (torch.optim.Adam)
lr_scheduler_type ({'steplr', 'explr', 'linlr'} or None, default='steplr') –
Type of learning rate scheduler:
'steplr' : Step-based decay (torch.optim.lr_scheduler.StepLR)
'explr' : Exponential decay (torch.optim.lr_scheduler.ExponentialLR)
'linlr' : Linear decay (torch.optim.lr_scheduler.LinearLR)
lr_scheduler_kwargs (dict, default={}) – Keyword arguments passed to the learning rate scheduler initializer (a torch.optim.lr_scheduler.LRScheduler subclass).
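To show what lr_scheduler_kwargs controls, the sketch below evaluates the closed-form learning rate of each listed PyTorch scheduler, assuming scheduler.step() is called once per epoch. The helper scheduled_lr is hypothetical; the defaults it uses (StepLR gamma=0.1; LinearLR start_factor=1/3, end_factor=1.0, total_iters=5) are PyTorch's, while StepLR's step_size and ExponentialLR's gamma have no default and must be supplied.

```python
def scheduled_lr(lr_init, epoch, lr_scheduler_type=None, **kwargs):
    """Closed-form learning rate after `epoch` scheduler steps (hypothetical helper)."""
    if lr_scheduler_type is None:
        # No scheduler: the learning rate stays at lr_init
        return lr_init
    if lr_scheduler_type == 'steplr':
        # StepLR: multiply by gamma every step_size epochs
        step_size = kwargs['step_size']
        gamma = kwargs.get('gamma', 0.1)
        return lr_init * gamma ** (epoch // step_size)
    if lr_scheduler_type == 'explr':
        # ExponentialLR: multiply by gamma every epoch
        return lr_init * kwargs['gamma'] ** epoch
    if lr_scheduler_type == 'linlr':
        # LinearLR: factor moves linearly from start_factor to end_factor
        start = kwargs.get('start_factor', 1.0 / 3.0)
        end = kwargs.get('end_factor', 1.0)
        total = kwargs.get('total_iters', 5)
        factor = start + (end - start) * min(epoch, total) / total
        return lr_init * factor
    raise ValueError(f"Unknown lr_scheduler_type: {lr_scheduler_type!r}")


# Example: halve the learning rate every 10 epochs
print(scheduled_lr(1e-3, 25, 'steplr', step_size=10, gamma=0.5))
```

In this sketch, lr_scheduler_kwargs={'step_size': 10, 'gamma': 0.5} would correspond to the StepLR call above.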
loss_nature ({'node_features_out', 'global_features_out'}, default='node_features_out') –
Loss nature:
'node_features_out' : Based on node output features
'global_features_out' : Based on global output features
loss_type ({'mse',}, default='mse') –
Loss function type:
'mse' : MSE (torch.nn.MSELoss)
loss_kwargs (dict, default={}) – Keyword arguments passed to the loss function initializer (a torch.nn.modules.loss._Loss subclass).
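The difference between the two loss natures can be illustrated with a plain-Python MSE that averages over all entries, as torch.nn.MSELoss does with its default reduction='mean'. The dict layout with 'node' and 'global' keys and the helper sample_loss are assumptions made for illustration only; they are not graphorge data structures.

```python
def mse_loss(pred, target):
    """Mean squared error over all entries (like MSELoss with reduction='mean')."""
    flat_pred = [x for row in pred for x in row]
    flat_target = [x for row in target for x in row]
    if len(flat_pred) != len(flat_target):
        raise ValueError("prediction and target must have the same number of entries")
    return sum((p - t) ** 2 for p, t in zip(flat_pred, flat_target)) / len(flat_pred)


def sample_loss(output, target, loss_nature='node_features_out'):
    """Pick which output features enter the loss (hypothetical helper).

    Assumed layout: 'node' -> n_nodes x n_features, 'global' -> 1 x n_features.
    """
    key = {'node_features_out': 'node', 'global_features_out': 'global'}[loss_nature]
    return mse_loss(output[key], target[key])


output = {'node': [[1.0, 2.0], [3.0, 4.0]], 'global': [[10.0]]}
target = {'node': [[1.0, 2.0], [3.0, 2.0]], 'global': [[12.0]]}
print(sample_loss(output, target))                         # node features
print(sample_loss(output, target, 'global_features_out'))  # global features
```

Node-based losses average over every node feature of the graph, whereas global-based losses only compare the graph-level output.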
batch_size (int, default=1) – Number of samples loaded per batch.
is_sampler_shuffle (bool, default=False) – If True, shuffles data set samples at every epoch.
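The interaction of batch_size and is_sampler_shuffle can be sketched as index batching: without shuffling, batches are consecutive chunks; with shuffling, the sample order changes every epoch. This mirrors a PyTorch DataLoader with shuffle=True and drop_last=False in spirit, but the helper epoch_batches and its seed + epoch seeding scheme are assumptions for illustration.

```python
import random


def epoch_batches(n_samples, batch_size=1, is_sampler_shuffle=False, epoch=0, seed=0):
    """Return the sample indices of each batch for one epoch (hypothetical helper)."""
    indices = list(range(n_samples))
    if is_sampler_shuffle:
        # Assumed seeding scheme: a different but reproducible order per epoch
        random.Random(seed + epoch).shuffle(indices)
    # Consecutive chunks of batch_size; the last batch may be smaller
    return [indices[i:i + batch_size] for i in range(0, n_samples, batch_size)]


print(epoch_batches(5, batch_size=2))
print(epoch_batches(5, batch_size=2, is_sampler_shuffle=True, epoch=1))
```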
is_early_stopping (bool, default=False) – If True, the training process is halted when the early stopping criterion is triggered. By default, 20% of the training data set is set aside for the underlying validation procedure.
early_stopping_kwargs (dict, default={}) – Early stopping criterion parameters (key: str, item: value).
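A minimal sketch of early stopping under stated assumptions: 20% of the training indices are held out for validation (here simply the last 20%, which may differ from the actual split), and training stops once the validation loss has not improved for a given number of checks. The names split_for_early_stopping and PatienceEarlyStopping, and the patience and min_delta parameters, are hypothetical and are not documented keys of early_stopping_kwargs.

```python
def split_for_early_stopping(train_indices, validation_fraction=0.2):
    """Hold out the last fraction of the training indices for validation (assumed rule)."""
    n_val = max(1, int(round(validation_fraction * len(train_indices))))
    return train_indices[:-n_val], train_indices[-n_val:]


class PatienceEarlyStopping:
    """Patience-based criterion (hypothetical stand-in for the real criterion)."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.n_bad_checks = 0

    def should_stop(self, val_loss):
        if val_loss < self.best_loss - self.min_delta:
            # Improvement: remember it and reset the counter
            self.best_loss = val_loss
            self.n_bad_checks = 0
        else:
            # No improvement: count this check against the patience budget
            self.n_bad_checks += 1
        return self.n_bad_checks >= self.patience


train, val = split_for_early_stopping(list(range(10)))
stopper = PatienceEarlyStopping(patience=2)
for epoch, loss in enumerate([1.0, 0.8, 0.9, 0.85]):
    if stopper.should_stop(loss):
        print(f"stop at epoch {epoch}, best validation loss {stopper.best_loss}")
        break
```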
dataset_file_path (str, default=None) – File path of the Graph Neural Network graph data set, if such a file exists. Only used for output purposes.
device_type ({'cpu', 'cuda'}, default='cpu') – Type of device on which torch.Tensor is allocated.
is_verbose (bool, default=False) – If True, enable verbose output.
- Returns:
k_fold_loss_array – k-fold cross-validation loss array of shape (n_fold, 2). For the i-th fold, k_fold_loss_array[i, 0] stores the best training loss and k_fold_loss_array[i, 1] stores the average prediction loss per sample.
- Return type:
numpy.ndarray(2d)
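A common way to report the returned array is the mean and standard deviation of each column across folds. The sketch below assumes the (n_fold, 2) layout described above and treats column 1 as the loss on the held-out fold; the helper name and the numbers in the example are illustrative only, not results produced by the package.

```python
import numpy as np


def summarize_kfold_losses(k_fold_loss_array):
    """Mean and standard deviation across folds of each loss column."""
    losses = np.asarray(k_fold_loss_array, dtype=float)
    if losses.ndim != 2 or losses.shape[1] != 2:
        raise ValueError("expected an array of shape (n_fold, 2)")
    mean = losses.mean(axis=0)
    std = losses.std(axis=0)
    return {
        'train_loss_mean': mean[0], 'train_loss_std': std[0],
        'val_loss_mean': mean[1], 'val_loss_std': std[1],
    }


# Illustrative numbers only
losses = np.array([[0.10, 0.14], [0.12, 0.18], [0.08, 0.16]])
print(summarize_kfold_losses(losses))
```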