graphorge.gnn_base_model.model.gnn_model.graph_standard_partial_fit
- graph_standard_partial_fit(dataset, features_type, n_features, is_verbose=False, tqdm_flavor='default')
Perform batch fitting of standardization data scalers.
- Parameters:
dataset (torch.utils.data.Dataset) – GNN-based data set. Each sample corresponds to a torch_geometric.data.Data object describing a homogeneous graph.
features_type (str) –
Features for which a data scaler is required:
'node_features_in' : Node features input matrix
'edge_features_in' : Edge features input matrix
'global_features_in' : Global features input matrix
'node_features_out' : Node features output matrix
'edge_features_out' : Edge features output matrix
'global_features_out' : Global features output matrix
n_features (int) – Number of features to standardize.
is_verbose (bool, default=False) – If True, enable verbose output.
tqdm_flavor ({'default', 'notebook'}, default='default') – Type of tqdm progress bar to use when is_verbose=True.
- Returns:
mean (torch.Tensor) – Per-feature standardization mean, with shape (n_features,).
std (torch.Tensor) – Per-feature standardization standard deviation, with shape (n_features,).
Notes
A biased estimator (dividing by the number of samples N rather than N - 1) is used to compute the standard deviation, consistent with the scikit-learn 1.3.2 documentation (sklearn.preprocessing.StandardScaler).
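Example
The following is a minimal sketch of how a per-feature mean and biased standard deviation can be accumulated batch by batch. It is not the graphorge implementation: the class RunningStandardizer and its methods are hypothetical names introduced for illustration, plain Python lists stand in for torch.Tensor feature matrices, and the pairwise merge of batch statistics is one common way to implement this kind of partial fit.

```python
import math


class RunningStandardizer:
    """Hypothetical helper: per-feature mean and biased std over batches."""

    def __init__(self, n_features):
        self.n_samples = 0
        self.mean = [0.0] * n_features
        # Per-feature sum of squared deviations from the running mean
        self.m2 = [0.0] * n_features

    def partial_fit(self, batch):
        """Update the statistics with a batch given as a list of feature rows."""
        n_batch = len(batch)
        if n_batch == 0:
            return
        n_features = len(self.mean)
        # Statistics of the incoming batch alone
        batch_mean = [
            sum(row[j] for row in batch) / n_batch for j in range(n_features)
        ]
        batch_m2 = [
            sum((row[j] - batch_mean[j]) ** 2 for row in batch)
            for j in range(n_features)
        ]
        n_total = self.n_samples + n_batch
        for j in range(n_features):
            delta = batch_mean[j] - self.mean[j]
            # Merge the running and batch statistics (uses the old sample count)
            self.m2[j] += batch_m2[j] + delta**2 * self.n_samples * n_batch / n_total
            self.mean[j] += delta * n_batch / n_total
        self.n_samples = n_total

    def std(self):
        """Biased estimator: divide by N rather than N - 1."""
        return [math.sqrt(m2 / self.n_samples) for m2 in self.m2]


# Two batches of samples with n_features = 2 (illustrative data)
scaler = RunningStandardizer(n_features=2)
scaler.partial_fit([[1.0, 10.0], [2.0, 20.0]])
scaler.partial_fit([[3.0, 30.0], [4.0, 40.0], [5.0, 50.0]])
mean, std = scaler.mean, scaler.std()
```

After both batches, mean and std match the values obtained from all five samples at once, which is the property that makes batch-wise fitting usable on data sets too large to hold in memory.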