graphorge.gnn_base_model.model.model_summary.get_model_summary
- get_model_summary(model, input_data=None, device_type='cpu', is_verbose=False, **kwargs)
Get a summary of a PyTorch model.
Wrapper around torchinfo (https://pypi.org/project/torchinfo/).
- Parameters:
model (torch.nn.Module) – PyTorch model.
input_data (list[torch.Tensor], default=None) – Input data for the PyTorch model forward propagation. If provided, additional summary data is computed and displayed (e.g., input/output shapes, number of operations, memory requirements). Must be a list of one or more torch.Tensor objects to avoid unexpected behavior.
device_type ({'cpu', 'cuda'}, default='cpu') – Type of device on which torch.Tensor is allocated.
is_verbose (bool, default=False) – If True, enable verbose output.
kwargs (dict) – Other keyword arguments of the PyTorch model forward propagation.
- Returns:
model_statistics – PyTorch model summary object.
- Return type:
torchinfo.model_statistics.ModelStatistics