# Source code for gnn_base_model.model.gnn_architectures

"""Graph Neural Networks architectures.

Classes
-------
GraphIndependentNetwork(torch.nn.Module)
    Graph Independent Network.
GraphInteractionNetwork(torch_geometric.nn.MessagePassing)
    Graph Interaction Network.
 
Functions
---------
build_fnn
    Build multilayer feed-forward neural network.
"""
#
#                                                                       Modules
# =============================================================================
# Third-party
import torch
import torch_geometric.nn
#
#                                                          Authorship & Credits
# =============================================================================
__author__ = 'Bernardo Ferreira (bernardo_ferreira@brown.edu)'
__credits__ = ['Bernardo Ferreira', 'Rui Barreira']
__status__ = 'Planning'
# =============================================================================
#
# =============================================================================
def build_fnn(input_size, output_size, output_activation=torch.nn.Identity(),
              hidden_layer_sizes=None, hidden_activation=torch.nn.Identity()):
    """Build multilayer feed-forward neural network.

    Each layer of adaptive weights is a linear transformation followed by
    its unit activation function.

    Parameters
    ----------
    input_size : int
        Number of neurons of input layer. Must be at least 1.
    output_size : int
        Number of neurons of output layer. Must be at least 1.
    output_activation : torch.nn.Module, default=torch.nn.Identity
        Output unit activation function. Defaults to identity (linear)
        unit activation function.
    hidden_layer_sizes : list[int], default=None
        Number of neurons of hidden layers. Defaults to None (no hidden
        layers, i.e., a single linear transformation).
    hidden_activation : torch.nn.Module, default=torch.nn.Identity
        Hidden unit activation function. Defaults to identity (linear)
        unit activation function.

    Returns
    -------
    fnn : torch.nn.Sequential
        Multilayer feed-forward neural network.

    Raises
    ------
    RuntimeError
        If input or output size is lower than 1, or if an activation
        function is not callable.
    """
    # Avoid mutable default argument: materialize hidden layer sizes locally
    if hidden_layer_sizes is None:
        hidden_layer_sizes = []
    # Check input and output size
    if int(input_size) < 1 or int(output_size) < 1:
        raise RuntimeError(f'Number of input ({int(input_size)}) and output '
                           f'({int(output_size)}) features must be at '
                           f'least 1.')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Set number of neurons of each layer
    layer_sizes = [input_size, *hidden_layer_sizes, output_size]
    # Set number of layers of adaptive weights
    n_layer = len(layer_sizes) - 1
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Set hidden and output layers unit activation functions
    if not callable(hidden_activation):
        raise RuntimeError('Hidden unit activation function must be derived '
                           'from torch.nn.Module class.')
    # NOTE: the same activation module instance is shared by all hidden
    # layers (safe for stateless activation modules)
    activation_functions = [hidden_activation for i in range(n_layer - 1)]
    if not callable(output_activation):
        raise RuntimeError('Output unit activation function must be derived '
                           'from torch.nn.Module class.')
    activation_functions.append(output_activation)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Create multilayer feed-forward neural network:
    # Initialize neural network
    fnn = torch.nn.Sequential()
    # Loop over neural network layers
    for i in range(n_layer):
        # Set layer linear transformation
        fnn.add_module("Layer-" + str(i),
                       torch.nn.Linear(layer_sizes[i], layer_sizes[i + 1],
                                       bias=True))
        # Set layer unit activation function
        fnn.add_module("Activation-" + str(i), activation_functions[i])
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    return fnn
# =============================================================================
def build_rnn(input_size, hidden_layer_sizes, output_size, num_layers=None,
              rnn_cell='GRU', bias=True):
    """Build multilayer recurrent neural network.

    The network stacks one RNN module per hidden layer size, each followed
    by a TorchRNNWrapper that discards the hidden state, and is closed by a
    linear output layer.

    Parameters
    ----------
    input_size : int
        Number of neurons of input layer. Must be at least 1.
    hidden_layer_sizes : list[int]
        Number of neurons of hidden layers. If None, defaults to
        [output_size,], ensuring at least one pass through an RNN before
        the linear output layer.
    output_size : int
        Number of neurons of output layer. Must be at least 1.
    num_layers : list[int], default=None
        Number of stacked layers within each RNN module. Defaults to None,
        in which case each RNN module has a single layer. Must have the
        same length as hidden_layer_sizes when provided.
    rnn_cell : {'GRU',}, default='GRU'
        RNN architecture cell.
    bias : bool, default=True
        Whether to use bias within the RNN cell and the linear output
        layer.

    Returns
    -------
    rnn : torch.nn.Sequential
        Multilayer recurrent neural network with linear output layer.

    Raises
    ------
    RuntimeError
        If input or output size is lower than 1, if the RNN cell is not
        recognized, or if num_layers and hidden_layer_sizes have
        mismatched lengths.
    """
    # Check input and output size
    if int(input_size) < 1 or int(output_size) < 1:
        raise RuntimeError(f'Number of input ({int(input_size)}) and output '
                           f'({int(output_size)}) features must be at '
                           f'least 1.')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Check cells
    if rnn_cell != 'GRU':
        raise RuntimeError(f'({rnn_cell}) is not a recognized RNN cell.')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if hidden_layer_sizes is None:
        # This ensures at least one pass through an RNN, before the linear
        # output layer.
        hidden_layer_sizes = [output_size, ]
    # Default every RNN module to a single stacked layer
    if num_layers is None:
        num_layers = [1 for _ in hidden_layer_sizes]
    elif len(num_layers) != len(hidden_layer_sizes):
        # Fixed message: the original implicit string concatenation
        # ('...''num_layers''...') dropped the intended quote marks
        raise RuntimeError(f"Expected same number of 'num_layers' as "
                           f"'hidden_layer_sizes'. Instead, got "
                           f"{len(num_layers)} and "
                           f"{len(hidden_layer_sizes)}.")
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Set number of neurons of each layer
    layer_sizes = [input_size, *hidden_layer_sizes, output_size]
    n_layer = len(layer_sizes) - 1
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Create multilayer recurrent neural network:
    rnn = torch.nn.Sequential()
    # Loop over neural network layers
    for idx in range(n_layer - 1):
        # Set layer recurrent transformation
        rnn.add_module(rnn_cell + '-' + str(idx),
                       torch.nn.GRU(input_size=layer_sizes[idx],
                                    hidden_size=layer_sizes[idx + 1],
                                    bias=bias,
                                    num_layers=num_layers[idx]))
        # Extracts 'output' from '(output, hidden_state)'. 'hidden_state' is
        # not passed between RNN cells; it is internal to each RNN module.
        rnn.add_module('RNN-output' + str(idx), TorchRNNWrapper())
    # Linear output layer (last hidden_size -> output_size)
    rnn.add_module('Output-layer',
                   torch.nn.Linear(in_features=layer_sizes[-2],
                                   out_features=layer_sizes[-1], bias=bias))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    return rnn
# =============================================================================
class TorchRNNWrapper(torch.nn.Module):
    """Adapter that keeps only the output of a PyTorch RNN module.

    Recurrent architectures implemented in PyTorch return a tuple
    '(output, hidden_state)'. When such modules are chained inside a
    torch.nn.Sequential, this wrapper discards 'hidden_state' so that only
    'output' flows to the next module.

    Methods
    -------
    forward(self, x)
        Forward propagation.
    """
    def forward(self, x):
        """Forward propagation.

        Parameters
        ----------
        x : Tuple
            Tuple with (output, hidden_state) as returned by the recurrent
            neural network architectures implemented in PyTorch.

        Returns
        -------
        output : {torch.Tensor, None}
            First element of the tuple, corresponding to the output of the
            recurrent neural network.
        """
        # Unpack the pair and drop the hidden state
        rnn_output, _unused_hidden_state = x
        return rnn_output
# =============================================================================
class GraphIndependentNetwork(torch.nn.Module):
    """Graph Independent Network.

    A Graph Network block with (1) distinct update functions for node, edge
    and global features implemented as multilayer feed-forward neural
    networks (or recurrent neural networks when the features carry a time
    dimension) with optional normalization layer and (2) no aggregation
    functions, i.e., independent node, edge and global blocks.

    Attributes
    ----------
    _node_fn : {torch.nn.Sequential, None}
        Node update function (None if not setup).
    _n_node_in : int
        Number of node input features.
    _n_node_out : int
        Number of node output features.
    _edge_fn : {torch.nn.Sequential, None}
        Edge update function (None if not setup).
    _n_edge_in : int
        Number of edge input features.
    _n_edge_out : int
        Number of edge output features.
    _global_fn : {torch.nn.Sequential, None}
        Global update function (None if not setup).
    _n_global_in : int
        Number of global input features.
    _n_global_out : int
        Number of global output features.
    _n_time_node : int
        Number of discrete time steps of nodal features. If greater than 0,
        then nodal input features include a time dimension and the node
        update function is a recurrent neural network.
    _n_time_edge : int
        Number of discrete time steps of edge features (analogous to
        _n_time_node).
    _n_time_global : int
        Number of discrete time steps of global features (analogous to
        _n_time_node).
    _is_norm_layer : bool
        If True, then a normalization layer is appended to node, edge and
        global update functions.
    _is_skip_unset_update : bool
        If True, then forward() returns the features input matrix when the
        corresponding update function has not been setup, otherwise it
        returns None.

    Methods
    -------
    forward(self, node_features_in=None, edge_features_in=None, \
            global_features_in=None, batch_vector=None)
        Forward propagation.
    """
    def __init__(self, n_hidden_layers, hidden_layer_size, n_node_in=0,
                 n_node_out=0, n_edge_in=0, n_edge_out=0, n_global_in=0,
                 n_global_out=0, n_time_node=0, n_time_edge=0,
                 n_time_global=0,
                 node_hidden_activation=torch.nn.Identity(),
                 node_output_activation=torch.nn.Identity(),
                 edge_hidden_activation=torch.nn.Identity(),
                 edge_output_activation=torch.nn.Identity(),
                 global_hidden_activation=torch.nn.Identity(),
                 global_output_activation=torch.nn.Identity(),
                 is_norm_layer=False, is_skip_unset_update=False):
        """Constructor.

        An update function is only setup when both its input and output
        feature counts are greater than zero; at least one of the node,
        edge or global update functions must be setup.

        Parameters
        ----------
        n_hidden_layers : int
            Number of hidden layers of the neural network update functions.
        hidden_layer_size : int
            Number of neurons of each hidden layer of the neural network
            update functions.
        n_node_in : int, default=0
            Number of node input features.
        n_node_out : int, default=0
            Number of node output features.
        n_edge_in : int, default=0
            Number of edge input features.
        n_edge_out : int, default=0
            Number of edge output features.
        n_global_in : int, default=0
            Number of global input features.
        n_global_out : int, default=0
            Number of global output features.
        n_time_node : int, default=0
            Number of discrete time steps of nodal features. If greater
            than 0, the node update function is a recurrent neural network.
        n_time_edge : int, default=0
            Number of discrete time steps of edge features (analogous).
        n_time_global : int, default=0
            Number of discrete time steps of global features (analogous).
        node_hidden_activation : torch.nn.Module, \
                                 default=torch.nn.Identity
            Hidden unit activation function of node update function
            (ignored when the node update function is an RNN).
        node_output_activation : torch.nn.Module, \
                                 default=torch.nn.Identity
            Output unit activation function of node update function
            (ignored when the node update function is an RNN).
        edge_hidden_activation : torch.nn.Module, \
                                 default=torch.nn.Identity
            Hidden unit activation function of edge update function.
        edge_output_activation : torch.nn.Module, \
                                 default=torch.nn.Identity
            Output unit activation function of edge update function.
        global_hidden_activation : torch.nn.Module, \
                                   default=torch.nn.Identity
            Hidden unit activation function of global update function.
        global_output_activation : torch.nn.Module, \
                                   default=torch.nn.Identity
            Output unit activation function of global update function.
        is_norm_layer : bool, default=False
            If True, then add normalization layer to node, edge and global
            update functions.
        is_skip_unset_update : bool, default=False
            If True, then forward() returns the features input matrix when
            the corresponding update function has not been setup, otherwise
            it returns None. Ignored if the update function is setup.

        Raises
        ------
        RuntimeError
            If the nonzero numbers of time steps disagree, or if no update
            function gets setup, or if a global normalization layer is
            requested with fewer than 2 global input features.
        """
        # Initialize Graph Network block from base class
        super(GraphIndependentNetwork, self).__init__()
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Set number of features
        self._n_node_in = int(n_node_in)
        self._n_node_out = int(n_node_out)
        self._n_edge_in = int(n_edge_in)
        self._n_edge_out = int(n_edge_out)
        self._n_global_in = int(n_global_in)
        self._n_global_out = int(n_global_out)
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Set input with time dimension: any two feature sets that both
        # carry a time dimension (nonzero steps) must agree on its size
        if n_time_node != 0 and n_time_edge !=0 and n_time_node != n_time_edge:
            raise RuntimeError(f'Number of time steps must match across '
                               f'nodal and edge features. Instead, got '
                               f'n_time_node=({n_time_node}) and '
                               f'n_time_edge=({n_time_edge})')
        if n_time_node != 0 and n_time_global !=0 and \
                n_time_node != n_time_global:
            raise RuntimeError(f'Number of time steps must match across '
                               f'nodal and global features. Instead, got '
                               f'n_time_global=({n_time_global}) and '
                               f'n_time_node=({n_time_node})')
        if n_time_edge != 0 and n_time_global !=0 and \
                n_time_edge != n_time_global:
            raise RuntimeError(f'Number of time steps must match across '
                               f'edge and global features. Instead, got '
                               f'n_time_global=({n_time_global}) and '
                               f'n_time_edge=({n_time_edge})')
        self._n_time_node = int(n_time_node)
        self._n_time_edge = int(n_time_edge)
        self._n_time_global = int(n_time_global)
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Set normalization layer
        self._is_norm_layer = is_norm_layer
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Set node update function as multilayer feed-forward neural network
        # or recurrent neural network with normalization layer
        if self._n_node_in > 0 and self._n_node_out > 0:
            # Set node update function
            self._node_fn = torch.nn.Sequential()
            if self._n_time_node > 0:
                # Build recurrent neural network (time-dependent features)
                rnn = build_rnn(input_size=self._n_node_in,
                                output_size=self._n_node_out,
                                hidden_layer_sizes=n_hidden_layers*[
                                    hidden_layer_size, ],
                                rnn_cell='GRU', bias=True)
                self._node_fn.add_module('RNN', rnn)
            else:
                # Build multilayer feed-forward neural network
                fnn = build_fnn(
                    input_size=self._n_node_in,
                    output_size=self._n_node_out,
                    output_activation=node_output_activation,
                    hidden_layer_sizes=n_hidden_layers*[hidden_layer_size,],
                    hidden_activation=node_hidden_activation)
                self._node_fn.add_module('FNN', fnn)
            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # Add normalization layer (per-feature) to node update function
            if is_norm_layer:
                norm_layer = torch.nn.BatchNorm1d(
                    num_features=self._n_node_out, affine=True)
                self._node_fn.add_module('Norm-Layer', norm_layer)
        else:
            self._node_fn = None
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Set edge update function as multilayer feed-forward neural network
        # or recurrent neural network with normalization layer:
        if self._n_edge_in > 0 and self._n_edge_out > 0:
            # Set edge update function
            self._edge_fn = torch.nn.Sequential()
            if self._n_time_edge > 0:
                # Build recurrent neural network (time-dependent features)
                rnn = build_rnn(input_size=self._n_edge_in,
                                output_size=self._n_edge_out,
                                hidden_layer_sizes=n_hidden_layers*[
                                    hidden_layer_size, ],
                                rnn_cell='GRU', bias=True)
                self._edge_fn.add_module('RNN', rnn)
            else:
                # Build multilayer feed-forward neural network
                fnn = build_fnn(
                    input_size=self._n_edge_in,
                    output_size=self._n_edge_out,
                    output_activation=edge_output_activation,
                    hidden_layer_sizes=n_hidden_layers*[hidden_layer_size,],
                    hidden_activation=edge_hidden_activation)
                self._edge_fn.add_module('FNN', fnn)
            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # Add normalization layer (per-feature) to edge update function
            if is_norm_layer:
                norm_layer = torch.nn.BatchNorm1d(
                    num_features=self._n_edge_out, affine=True)
                self._edge_fn.add_module('Norm-Layer', norm_layer)
        else:
            self._edge_fn = None
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Set global update function as multilayer feed-forward neural network
        # or recurrent neural network with normalization layer:
        if self._n_global_in > 0 and self._n_global_out > 0:
            # Set global update function
            self._global_fn = torch.nn.Sequential()
            if self._n_time_global > 0:
                # Build recurrent neural network (rnn_cell defaults to 'GRU')
                rnn = build_rnn(input_size=self._n_global_in,
                                output_size=self._n_global_out,
                                hidden_layer_sizes=n_hidden_layers*[
                                    hidden_layer_size, ],
                                bias=True)
                self._global_fn.add_module('RNN', rnn)
            else:
                # Build multilayer feed-forward neural network
                fnn = build_fnn(
                    input_size=self._n_global_in,
                    output_size=self._n_global_out,
                    output_activation=global_output_activation,
                    hidden_layer_sizes=n_hidden_layers*[hidden_layer_size,],
                    hidden_activation=global_hidden_activation)
                self._global_fn.add_module('FNN', fnn)
            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # Add normalization layer (per-element) to global update function
            if is_norm_layer:
                # NOTE(review): this guard checks _n_global_in, but the
                # LayerNorm below normalizes over _n_global_out — confirm
                # which feature count was intended
                if self._n_global_in < 2:
                    raise RuntimeError(f'Number of global features '
                                       f'({self._n_global_in}) must be '
                                       f'greater than 1 to compute standard '
                                       f'deviation in the corresponding '
                                       f'update function normalization '
                                       f'layer.')
                else:
                    norm_layer = torch.nn.LayerNorm(
                        normalized_shape=self._n_global_out,
                        elementwise_affine=True)
                    self._global_fn.add_module('Norm-Layer', norm_layer)
        else:
            self._global_fn = None
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Check update functions: at least one must have been setup
        if (self._node_fn is None and self._edge_fn is None
                and self._global_fn is None):
            raise RuntimeError('Graph Independent Network was initialized '
                               'without setting up any node, edge or global '
                               'update function. Set positive number of '
                               'features for at least the node, edge or '
                               'global update function.')
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Set flag to handle unset update function output
        self._is_skip_unset_update = is_skip_unset_update
    # -------------------------------------------------------------------------
    def forward(self, node_features_in=None, edge_features_in=None,
                global_features_in=None, batch_vector=None):
        """Forward propagation.

        Each setup update function is applied independently to its feature
        matrix (no aggregation between node, edge and global blocks).

        Parameters
        ----------
        node_features_in : torch.Tensor, default=None
            Nodes features input matrix stored as a torch.Tensor(2d) of
            shape (n_nodes, n_features). Ignored if node update function is
            not setup.
        edge_features_in : torch.Tensor, default=None
            Edges features input matrix stored as a torch.Tensor(2d) of
            shape (n_edges, n_features). Ignored if edge update function is
            not setup.
        global_features_in : torch.Tensor, default=None
            Global features input matrix stored as a torch.Tensor(2d) of
            shape (1, n_features). Ignored if global update function is not
            setup.
        batch_vector : torch.Tensor, default=None
            Batch vector stored as torch.Tensor(1d) of shape (n_nodes,),
            assigning each node to a specific batch subgraph.
            NOTE(review): accepted for interface compatibility but not used
            anywhere in this forward pass.

        Returns
        -------
        node_features_out : {torch.Tensor, None}
            Nodes features output matrix stored as a torch.Tensor(2d) of
            shape (n_nodes, n_features). None (or the input matrix, if
            _is_skip_unset_update) when the node update function is unset.
        edge_features_out : {torch.Tensor, None}
            Edges features output matrix stored as a torch.Tensor(2d) of
            shape (n_edges, n_features). None (or the input matrix, if
            _is_skip_unset_update) when the edge update function is unset.
        global_features_out : {torch.Tensor, None}
            Global features output matrix stored as a torch.Tensor(2d) of
            shape (1, n_features). None (or the input matrix, if
            _is_skip_unset_update) when the global update function is unset.

        Raises
        ------
        RuntimeError
            If a required input matrix is missing or its number of features
            (columns) does not match the model, or if a normalization layer
            would see fewer than 2 rows.
        """
        # Check number of nodes and nodes features
        if self._node_fn is not None:
            if not isinstance(node_features_in, torch.Tensor):
                raise RuntimeError('Nodes features input matrix is not a '
                                   'torch.Tensor.')
            elif self._is_norm_layer and node_features_in.shape[0] < 2:
                raise RuntimeError(f'Number of nodes '
                                   f'({node_features_in.shape[0]}) must be '
                                   f'greater than 1 to compute standard '
                                   f'deviation in the corresponding update '
                                   f'function normalization layer.')
            elif self._n_time_node > 0 and \
                    node_features_in.shape[1] != \
                    self._n_node_in*self._n_time_node:
                raise RuntimeError(f'Mismatch of number of node features of '
                                   f'model '
                                   f'({self._n_node_in*self._n_time_node}) '
                                   f'and nodes input features matrix '
                                   f'({node_features_in.shape[1]}).')
            elif self._n_time_node == 0 and \
                    node_features_in.shape[1] != self._n_node_in:
                raise RuntimeError(f'Mismatch of number of node features of '
                                   f'model ({self._n_node_in}) and nodes '
                                   f'input features matrix '
                                   f'({node_features_in.shape[1]}).')
        # Check number of edges and edges features
        if self._edge_fn is not None:
            if not isinstance(edge_features_in, torch.Tensor):
                raise RuntimeError('Edges features input matrix is not a '
                                   'torch.Tensor.')
            elif self._is_norm_layer and edge_features_in.shape[0] < 2:
                raise RuntimeError(f'Number of edges '
                                   f'({edge_features_in.shape[0]}) must be '
                                   f'greater than 1 to compute standard '
                                   f'deviation in the corresponding update '
                                   f'function normalization layer.')
            elif self._n_time_edge > 0 and \
                    edge_features_in.shape[1] != \
                    self._n_edge_in*self._n_time_edge:
                raise RuntimeError(f'Mismatch of number of edge features of '
                                   f'model '
                                   f'({self._n_edge_in*self._n_time_edge}) '
                                   f'and edges input features matrix '
                                   f'({edge_features_in.shape[1]}).')
            elif self._n_time_edge == 0 and \
                    edge_features_in.shape[1] != self._n_edge_in:
                raise RuntimeError(f'Mismatch of number of edge features of '
                                   f'model ({self._n_edge_in}) and edges '
                                   f'input features matrix '
                                   f'({edge_features_in.shape[1]}).')
        # Check number global features
        if self._global_fn is not None:
            if not isinstance(global_features_in, torch.Tensor):
                raise RuntimeError('Global features input matrix is not a '
                                   'torch.Tensor.')
            elif self._n_time_global > 0 and \
                    global_features_in.shape[1] != \
                    self._n_global_in*self._n_time_global:
                raise RuntimeError(f'Mismatch of number of global features of '
                                   f'model ('
                                   f'{self._n_global_in*self._n_time_global}) '
                                   f'and global input features matrix '
                                   f'({global_features_in.shape[1]}).')
            elif self._n_time_global == 0 and \
                    global_features_in.shape[1] != self._n_global_in:
                raise RuntimeError(f'Mismatch of number of global features of '
                                   f'model ({self._n_global_in}) and global '
                                   f'input features matrix '
                                   f'({global_features_in.shape[1]}).')
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Forward propagation: Node update function
        node_features_out = None
        if self._node_fn is not None:
            if self._n_time_node > 0:
                batch_size_node = node_features_in.shape[0]
                # If we have time series data, reshape to a 3d tensor:
                # (n_time_node, batch_size, n_node_in)
                # NOTE(review): view() reinterprets the (batch, time*feat)
                # buffer directly as (time, batch, feat) without a
                # transpose; this is only correct if the flat feature
                # layout is time-major across rows — verify against the
                # data pipeline
                node_features_in = \
                    node_features_in.view( \
                        self._n_time_node, batch_size_node, -1)
                # Compute node update
                node_features_out = self._node_fn(node_features_in)
                # If we have time series data, reshape back to 2d tensor:
                # (batch_size, n_node_in * n_time_node)
                node_features_out = node_features_out.view(batch_size_node, -1)
            else:
                node_features_out = self._node_fn(node_features_in)
        else:
            # Unset update function: pass input through if requested
            if self._is_skip_unset_update:
                node_features_out = node_features_in
        # Forward propagation: Edge update function
        edge_features_out = None
        if self._edge_fn is not None:
            # If we have time series data, reshape to a 3d tensor:
            # (n_time_edge, batch_size, n_edge_in)
            if self._n_time_edge > 0:
                batch_size_edge = edge_features_in.shape[0]
                edge_features_in = \
                    edge_features_in.view(
                        self._n_time_edge, batch_size_edge, -1)
                # Compute edge update
                edge_features_out = self._edge_fn(edge_features_in)
                # If we have time series data, reshape back to 2d tensor:
                # (batch_size, n_edge_in * n_time_edge)
                edge_features_out = edge_features_out.view(batch_size_edge,-1)
            else:
                # Compute edge update
                edge_features_out = self._edge_fn(edge_features_in)
        else:
            # Unset update function: pass input through if requested
            if self._is_skip_unset_update:
                edge_features_out = edge_features_in
        # Forward propagation: Global update function
        global_features_out = None
        if self._global_fn is not None:
            # If we have time series data, reshape to a 3d tensor:
            # (n_time_global, batch_size, n_global_in)
            if self._n_time_global > 0:
                batch_size_global = global_features_in.shape[0]
                global_features_in = \
                    global_features_in.view(
                        self._n_time_global, batch_size_global, -1)
                # Compute global update
                global_features_out = self._global_fn(global_features_in)
                # If we have time series data, reshape back to 2d tensor:
                # (batch_size, n_global_in * n_time_global)
                global_features_out = \
                    global_features_out.view(batch_size_global, -1)
            else:
                # Compute global update
                global_features_out = self._global_fn(global_features_in)
        else:
            # Unset update function: pass input through if requested
            if self._is_skip_unset_update:
                global_features_out = global_features_in
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        return node_features_out, edge_features_out, global_features_out
# =============================================================================
[docs] class GraphInteractionNetwork(torch_geometric.nn.MessagePassing): """Graph Interaction Network. A Graph Network block with (1) distinct update functions for node, edge and global features implemented as multilayer feed-forward or recurrent neural networks with layer normalization and (2) edge-to-node and node-to-global aggregation functions. Attributes ---------- _node_fn : torch.nn.Sequential Node update function. _n_node_in : int Number of node input features. _n_node_out : int Number of node output features. _edge_fn : torch.nn.Sequential Edge update function. _n_edge_in : int Number of edge input features. _n_edge_out : int Number of edge input features. _global_fn : torch.nn.Sequential Global update function. _n_global_in : int Number of global input features. _n_global_out : int Number of global output features. _n_time_node : int Number of discrete time steps of nodal features. If greater than 0, then nodal input features include a time dimension, and message passing layers are RNNs. _n_time_edge : int Number of discrete time steps of edge features. If greater than 0, then edge input features include a time dimension, and message passing layers are RNNs. _n_time_global : int Number of discrete time steps of global features. If greater than 0, then global input features include a time dimension, and message passing layers are RNNs. _is_norm_layer : bool, default=False If True, then add normalization layer to node, edge and global update functions. Methods ------- forward(self, edges_indexes, node_features_in=None, edge_features_in=None) Forward propagation. message(self, node_features_in_i, node_features_in_j, \ edge_features_in=None) Builds messages to node i from each edge (j, i) (edge update). update(self, node_features_in_aggr, node_features_in=None) Update node features. """
[docs] def __init__(self, n_node_out, n_edge_out, n_hidden_layers, hidden_layer_size, n_node_in=0, n_edge_in=0, n_global_in=0, n_global_out=0, n_time_node=0, n_time_edge=0, n_time_global=0, edge_to_node_aggr='add', node_to_global_aggr='add', node_hidden_activation=torch.nn.Identity(), node_output_activation=torch.nn.Identity(), edge_hidden_activation=torch.nn.Identity(), edge_output_activation=torch.nn.Identity(), global_hidden_activation=torch.nn.Identity(), global_output_activation=torch.nn.Identity(), is_norm_layer=False): """Constructor. Parameters ---------- n_node_out : int Number of node output features. n_edge_out : int Number of edge output features. n_hidden_layers : int Number of hidden layers of multilayer feed-forward neural network update functions. hidden_layer_size : int Number of neurons of hidden layers of multilayer feed-forward neural network update functions. n_node_in : int, default=0 Number of node input features. n_edge_in : int, default=0 Number of edge input features. n_global_in : int, default=0 Number of global input features. n_global_out : int, default=0 Number of global output features. edge_to_node_aggr : {'add',}, default='add' Edge-to-node aggregation scheme. node_to_global_aggr : {'add', 'mean'}, default='add' Node-to-global aggregation scheme. node_hidden_activation : torch.nn.Module, default=torch.nn.Identity Hidden unit activation function of node update function (multilayer feed-forward neural network). Defaults to identity (linear) unit activation function. node_output_activation : torch.nn.Module, default=torch.nn.Identity Output unit activation function of node update function (multilayer feed-forward neural network). Defaults to identity (linear) unit activation function. edge_hidden_activation : torch.nn.Module, default=torch.nn.Identity Hidden unit activation function of edge update function (multilayer feed-forward neural network). Defaults to identity (linear) unit activation function. 
edge_output_activation : torch.nn.Module, default=torch.nn.Identity Output unit activation function of edge update function (multilayer feed-forward neural network). Defaults to identity (linear) unit activation function. global_hidden_activation : torch.nn.Module, default=torch.nn.Identity Hidden unit activation function of global update function (multilayer feed-forward neural network). Defaults to identity (linear) unit activation function. global_output_activation : torch.nn.Module, default=torch.nn.Identity Output unit activation function of global update function (multilayer feed-forward neural network). Defaults to identity (linear) unit activation function. n_time_node : int, default=0 Number of discrete time steps of nodal features. If greater than 0, then nodal input features include a time dimension and message passing layers are RNNs. n_time_edge : int, default=0 Number of discrete time steps of edge features. If greater than 0, then edge input features include a time dimension and message passing layers are RNNs. n_time_global : int, default=0 Number of discrete time steps of global features. If greater than 0, then global input features include a time dimension and message passing layers are RNNs. is_norm_layer : bool, default=False If True, then add normalization layer to node, edge and global update functions. 
""" # Set aggregation scheme if edge_to_node_aggr == 'add': aggregation = torch_geometric.nn.aggr.SumAggregation() else: raise RuntimeError('Unknown aggregation scheme.') # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Set node-to-global aggregation scheme if node_to_global_aggr in ('add', 'mean'): self.node_to_global_aggr = node_to_global_aggr else: raise RuntimeError('Unknown node-to-global aggregation scheme.') # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Set flow direction of message passing flow = 'source_to_target' # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Initialize Graph Network block from base class super(GraphInteractionNetwork, self).__init__(aggr=aggregation, flow=flow) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Set number of features self._n_node_in = int(n_node_in) self._n_node_out = int(n_node_out) self._n_edge_in = int(n_edge_in) self._n_edge_out = int(n_edge_out) self._n_global_in = int(n_global_in) self._n_global_out = int(n_global_out) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Set input with time dimension if n_time_node != 0 and n_time_edge !=0 and n_time_node != n_time_edge: raise RuntimeError(f'Number of time steps must match across ' f'nodal and edge features. Instead, got ' f'n_time_node=({n_time_node}) and ' f'n_time_edge=({n_time_edge})') if n_time_node != 0 and n_time_global !=0 and \ n_time_node != n_time_global: raise RuntimeError(f'Number of time steps must match across ' f'nodal and global features. Instead, got ' f'n_time_global=({n_time_global}) and ' f'n_time_node=({n_time_node})') if n_time_edge != 0 and n_time_global !=0 and \ n_time_edge != n_time_global: raise RuntimeError(f'Number of time steps must match across ' f'edge and global features. 
Instead, got ' f'n_time_global=({n_time_global}) and ' f'n_time_edge=({n_time_edge})') self._n_time_node = int(n_time_node) self._n_time_edge = int(n_time_edge) self._n_time_global = int(n_time_global) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Set normalization layer self._is_norm_layer = is_norm_layer # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Check number of input features if (self._n_node_in < 1 and self._n_edge_in < 1 and self._n_global_in < 1): raise RuntimeError(f'Impossible to setup model without node ' f'({self._n_node_in}), edge ' f'({self._n_edge_in}), or global ' f'({self._n_global_in}) input features.') # Check number of output features if (self._n_node_out < 1 or self._n_edge_out < 1): raise RuntimeError(f'Number of node ({self._n_node_out}) and ' f'edge ({self._n_edge_out}) output features ' f'must be greater than 0.') # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Set node update function as multilayer feed-forward/recurrent neural # network with layer normalization: self._node_fn = torch.nn.Sequential() if self._n_time_node > 0: # Build recurrent neural network rnn = build_rnn(input_size=self._n_node_in + self._n_edge_out, output_size=self._n_node_out, hidden_layer_sizes=n_hidden_layers*[ hidden_layer_size, ], rnn_cell='GRU', bias=True) self._node_fn.add_module('RNN', rnn) else: if self._n_time_edge > 0: # Build recurrent neural network rnn = build_rnn(input_size=self._n_node_in + self._n_edge_out, output_size=self._n_node_out, hidden_layer_sizes=n_hidden_layers*[ hidden_layer_size, ], rnn_cell='GRU', bias=True) self._node_fn.add_module('RNN', rnn) else: # Build multilayer feed-forward neural network fnn = build_fnn( input_size=self._n_node_in + self._n_edge_out, output_size=self._n_node_out, output_activation=node_output_activation, hidden_layer_sizes=n_hidden_layers*[ hidden_layer_size,], hidden_activation=node_hidden_activation) # Set node update function 
self._node_fn.add_module('FNN', fnn) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Add normalization layer (per-feature) to node update function if is_norm_layer: norm_layer = torch.nn.BatchNorm1d( num_features=self._n_node_out, affine=True) self._node_fn.add_module('Norm-Layer', norm_layer) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Set edge update function as multilayer feed-forward or recurrent # neural network with layer normalization: # Set edge update function self._edge_fn = torch.nn.Sequential() if self._n_time_edge > 0: # Build recurrent neural network # regardless if sef._n_time_node > 0 or not rnn = build_rnn(input_size=self._n_edge_in + 2*self._n_node_in, output_size=self._n_edge_out, hidden_layer_sizes=n_hidden_layers*[ hidden_layer_size,], rnn_cell='GRU', bias=True) self._edge_fn.add_module('RNN', rnn) else: if self._n_time_node > 0 or self._n_time_edge > 0: # Build recurrent neural network rnn = build_rnn(input_size=self._n_edge_in + 2*self._n_node_in, output_size=self._n_edge_out, hidden_layer_sizes=n_hidden_layers*[ hidden_layer_size,], rnn_cell='GRU', bias=True) self._edge_fn.add_module('RNN', rnn) else: # Build multilayer feed-forward neural network fnn = build_fnn( input_size=self._n_edge_in + 2*self._n_node_in, output_size=self._n_edge_out, output_activation=edge_output_activation, hidden_layer_sizes=n_hidden_layers*[ hidden_layer_size,], hidden_activation=edge_hidden_activation) # Set edge update function self._edge_fn.add_module('FNN', fnn) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Add normalization layer (per-feature) to edge update function if is_norm_layer: norm_layer = torch.nn.BatchNorm1d( num_features=self._n_edge_out, affine=True) self._edge_fn.add_module('Norm-Layer', norm_layer) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Set global update function as multilayer feed-forward neural network # with layer 
normalization: if self._n_global_out > 0: # Set global update function self._global_fn = torch.nn.Sequential() if self._n_time_global > 0: # Build stacked recurrent neural network # regardless if sef._n_time_node > 0 or self._n_time edge > 0 # or not rnn = build_rnn(input_size=self._n_global_in + self._n_node_out, output_size=self._n_global_out, hidden_layer_sizes=n_hidden_layers*[ hidden_layer_size,], rnn_cell='GRU', bias=True) self._global_fn.add_module('RNN', rnn) else: if self._n_time_node > 0 or self._n_time_edge > 0: # Build stacked recurrent neural network rnn = build_rnn(input_size=self._n_global_in + self._n_node_out, output_size=self._n_global_out, hidden_layer_sizes=n_hidden_layers*[ hidden_layer_size,], rnn_cell='GRU', bias=True) self._global_fn.add_module('RNN', rnn) else: # Build multilayer feed-forward neural network fnn = build_fnn( input_size=self._n_global_in+self._n_node_out, output_size=self._n_global_out, output_activation=global_output_activation, hidden_layer_sizes=n_hidden_layers*[ hidden_layer_size,], hidden_activation=global_hidden_activation) self._global_fn.add_module('FNN', fnn) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Add normalization layer (per-element) to global update function if is_norm_layer: if self._n_global_in < 2: raise RuntimeError(f'Number of global features ' f'({self._n_global_in}) must be ' f'greater than 1 to compute standard ' f'deviation in the corresponding ' f'update function normalization layer.') else: norm_layer = torch.nn.LayerNorm( normalized_shape=self._n_global_out, elementwise_affine=True) self._global_fn.add_module('Norm-Layer', norm_layer) else: self._global_fn = None
# -------------------------------------------------------------------------
def forward(self, edges_indexes, node_features_in=None,
            edge_features_in=None, global_features_in=None,
            batch_vector=None):
    """Forward propagation.

    Validates the input matrices, performs one message-passing step
    (message, aggregation, update) to get updated node and edge features,
    and finally computes updated global features if a global update
    function has been set up.

    Parameters
    ----------
    edges_indexes : torch.Tensor
        Edges indexes matrix stored as torch.Tensor(2d) with shape
        (2, n_edges), where the i-th edge is stored in
        edges_indexes[:, i] as (start_node_index, end_node_index).
    node_features_in : torch.Tensor, default=None
        Nodes features input matrix stored as a torch.Tensor(2d) of shape
        (n_nodes, n_features). If None, the edge-to-node aggregation is
        only built up to the highest receiver node index according with
        edges_indexes. To preserve total number of nodes in edge-to-node
        aggregation, pass torch.empty(n_nodes, 0) instead of None.
    edge_features_in : torch.Tensor, default=None
        Edges features input matrix stored as a torch.Tensor(2d) of shape
        (n_edges, n_features).
    global_features_in : torch.Tensor, default=None
        Global features input matrix stored as a torch.Tensor(2d) of
        shape (1, n_features). Ignored if global update function is not
        setup.
    batch_vector : torch.Tensor, default=None
        Batch vector stored as torch.Tensor(1d) of shape (n_nodes,),
        assigning each node to a specific batch subgraph. Required to
        process a graph holding multiple isolated subgraphs when batch
        size is greater than 1.

    Returns
    -------
    node_features_out : torch.Tensor
        Nodes features output matrix stored as a torch.Tensor(2d) of
        shape (n_nodes, n_features).
    edge_features_out : torch.Tensor
        Edges features output matrix stored as a torch.Tensor(2d) of
        shape (n_edges, n_features).
    global_features_out : {torch.Tensor, None}
        Global features output matrix stored as a torch.Tensor(2d) of
        shape (1, n_features).

    Raises
    ------
    RuntimeError
        If both node and edge input features are missing, or if any
        input matrix has an invalid type or inconsistent dimensions.
    """
    # Check input features matrices
    if node_features_in is None and edge_features_in is None:
        raise RuntimeError('Impossible to compute forward propagation of '
                           'model without node (None) and edge (None) '
                           'input features matrices.')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Check edges indexes
    if not isinstance(edges_indexes, torch.Tensor):
        raise RuntimeError('Edges indexes matrix is not a torch.Tensor.')
    elif len(edges_indexes.shape) != 2 or edges_indexes.shape[0] != 2:
        raise RuntimeError('Edges indexes matrix is not a torch.Tensor '
                           'of shape (2, n_edges).')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Check number of nodes and nodes features
    if node_features_in is not None:
        if not isinstance(node_features_in, torch.Tensor):
            raise RuntimeError('Nodes features input matrix is not a '
                               'torch.Tensor.')
        elif self._is_norm_layer and node_features_in.shape[0] < 2:
            # BatchNorm1d needs at least two samples to compute statistics
            raise RuntimeError(f'Number of nodes '
                               f'({node_features_in.shape[0]}) must be '
                               f'greater than 1 to compute standard '
                               f'deviation in the corresponding update '
                               f'functions normalization layer.')
        elif self._n_time_node > 0 and \
                node_features_in.shape[1] != \
                self._n_node_in*self._n_time_node:
            # Time series data: features are flattened along time
            raise RuntimeError(f'Mismatch of number of node features of '
                               f'model '
                               f'({self._n_node_in*self._n_time_node}) '
                               f'and nodes input features matrix '
                               f'({node_features_in.shape[1]}).')
        elif self._n_time_node == 0 and \
                node_features_in.shape[1] != self._n_node_in:
            raise RuntimeError(f'Mismatch of number of node features of '
                               f'model ({self._n_node_in}) and nodes '
                               f'input features matrix '
                               f'({node_features_in.shape[1]}).')
    # Check number of edges and edges features
    if edge_features_in is not None:
        if not isinstance(edge_features_in, torch.Tensor):
            raise RuntimeError('Edges features input matrix is not a '
                               'torch.Tensor.')
        elif self._is_norm_layer and edge_features_in.shape[0] < 2:
            raise RuntimeError(f'Number of edges '
                               f'({edge_features_in.shape[0]}) must be '
                               f'greater than 1 to compute standard '
                               f'deviation in the corresponding update '
                               f'function normalization layer.')
        elif edge_features_in.shape[0] != edges_indexes.shape[1]:
            # Bug fix: report the number of edges (shape[0]) instead of
            # the number of edge features (shape[1]) in the error message
            raise RuntimeError(f'Mismatch of number of edges of graph '
                               f'edges indexes ({edges_indexes.shape[1]}) '
                               f'and edges input features matrix '
                               f'({edge_features_in.shape[0]}).')
        elif self._n_time_edge > 0 and \
                edge_features_in.shape[1] != \
                self._n_edge_in*self._n_time_edge:
            raise RuntimeError(f'Mismatch of number of edge features of '
                               f'model '
                               f'({self._n_edge_in*self._n_time_edge}) '
                               f'and edges input features matrix '
                               f'({edge_features_in.shape[1]}).')
        elif self._n_time_edge == 0 and \
                edge_features_in.shape[1] != self._n_edge_in:
            raise RuntimeError(f'Mismatch of number of edge features of '
                               f'model ({self._n_edge_in}) and edges '
                               f'input features matrix '
                               f'({edge_features_in.shape[1]}).')
    # Check global features
    if global_features_in is not None:
        if not isinstance(global_features_in, torch.Tensor):
            raise RuntimeError('Global features input matrix is not a '
                               'torch.Tensor.')
        elif self._n_time_global > 0 and \
                global_features_in.shape[1] != \
                self._n_global_in*self._n_time_global:
            raise RuntimeError(f'Mismatch of number of global features of '
                               f'model '
                               f'({self._n_global_in*self._n_time_global}) '
                               f'and global input features matrix '
                               f'({global_features_in.shape[1]}).')
        elif self._n_time_global == 0 and \
                global_features_in.shape[1] != self._n_global_in:
            raise RuntimeError(f'Mismatch of number of global features of '
                               f'model ({self._n_global_in}) and global '
                               f'input features matrix '
                               f'({global_features_in.shape[1]}).')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Perform graph neural network message-passing step (message,
    # aggregation, update) and get updated node features.
    # Time series data are reshaped within the message and update methods.
    node_features_out = self.propagate(edge_index=edges_indexes,
                                       node_features_in=node_features_in,
                                       edge_features_in=edge_features_in)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Get updated edge features (stored as an attribute by the message()
    # method during propagate) and clear the temporary storage
    edge_features_out = self._edge_features_out
    self._edge_features_out = None
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Initialize updated global features
    global_features_out = None
    # Get updated global features.
    # Time series data are reshaped within the update_global method.
    if self._global_fn is not None:
        global_features_out = self.update_global(
            global_features_in=global_features_in,
            node_features_out=node_features_out,
            batch_vector=batch_vector)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    return node_features_out, edge_features_out, global_features_out
# -------------------------------------------------------------------------
def message(self, node_features_in_i, node_features_in_j,
            edge_features_in=None):
    """Builds messages to node i from each edge (j, i) (edge update).

    Assumes that j is the source node and i is the receiver node (flow
    direction set as 'source_to_target').

    For each edge (j, i), the update function input features result from
    concatenation of the edge features and the corresponding nodes
    features. The source and receiver node input features mappings based
    on the edges indexes matrix are built in the _collect() method of
    class torch_geometric.nn.MessagePassing.

    The edges features output matrix is passed as the input tensor to the
    aggregation operator (class torch.nn.aggr.Aggregation) set in the
    initialization of the torch_geometric.nn.MessagePassing class. It is
    called by the "propagate" method within the PyG backend.

    Parameters
    ----------
    node_features_in_i : torch.Tensor
        Receiver node input features for each edge stored as a
        torch.Tensor(2d) of shape (n_edges, n_features). Mapping is
        performed based on the edges indexes matrix.
    node_features_in_j : torch.Tensor
        Source node input features for each edge stored as a
        torch.Tensor(2d) of shape (n_edges, n_features). Mapping is
        performed based on the edges indexes matrix.
    edge_features_in : torch.Tensor, default=None
        Edges features input matrix stored as a torch.Tensor(2d) of shape
        (n_edges, n_features).

    Returns
    -------
    edge_features_out : torch.Tensor
        Edges features output matrix stored as a torch.Tensor(2d) of
        shape (n_edges, n_features).
    """
    # Check input features
    is_node_features_in = (node_features_in_i is not None
                          and node_features_in_j is not None)
    is_edge_features_in = edge_features_in is not None
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Concatenate available input features for each edge
    if is_node_features_in and is_edge_features_in:
        # When only one of node/edge features carries a time dimension,
        # replicate the time-less features along time before concatenation
        if self._n_time_node > 0 and self._n_time_edge == 0:
            edge_features_in = \
                edge_features_in.repeat(1, self._n_time_node)
        elif self._n_time_edge > 0 and self._n_time_node == 0:
            node_features_in_i = \
                node_features_in_i.repeat(1, self._n_time_edge)
            node_features_in_j = \
                node_features_in_j.repeat(1, self._n_time_edge)
        # Concatenate nodes and edges input features
        edge_features_in_cat = \
            torch.cat([node_features_in_i, node_features_in_j,
                       edge_features_in], dim=-1)
    elif is_node_features_in and not is_edge_features_in:
        if self._n_time_edge > 0:
            node_features_in_i = \
                node_features_in_i.repeat(1, self._n_time_edge)
            node_features_in_j = \
                node_features_in_j.repeat(1, self._n_time_edge)
        # Concatenate nodes input features
        edge_features_in_cat = \
            torch.cat([node_features_in_i, node_features_in_j], dim=-1)
    elif is_edge_features_in:
        if self._n_time_node > 0:
            edge_features_in = \
                edge_features_in.repeat(1, self._n_time_node)
        # Concatenate edges input features
        edge_features_in_cat = edge_features_in
    else:
        raise RuntimeError('Impossible to build edge update function '
                           'input features matrix without node (None) and '
                           'edge (None) input features matrices')
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Update edge features
    # If we have time series data, reshape to a 3d tensor:
    # (n_time_edge, batch_size, n_edge_in)
    # If edges feature time series data, then it must be of the same size
    # as the time series data at the nodes
    if self._n_time_edge > 0 or self._n_time_node > 0:
        n_time = max(self._n_time_node, self._n_time_edge)
        batch_size_edge = edge_features_in_cat.shape[0]
        edge_features_in_cat = \
            edge_features_in_cat.view(n_time, batch_size_edge, -1)
        # Compute edge update
        edge_features_out = self._edge_fn(edge_features_in_cat)
        # If we have time series data, reshape back to 2d tensor:
        # (batch_size, n_edge_in * n_time_edge)
        edge_features_out = edge_features_out.view(batch_size_edge, -1)
    else:
        # Compute edge update
        edge_features_out = self._edge_fn(edge_features_in_cat)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Store updated edges features so that forward() can retrieve them
    # after propagate() returns
    self._edge_features_out = edge_features_out
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    return edge_features_out
# -------------------------------------------------------------------------
def update(self, node_features_in_aggr, node_features_in=None):
    """Update node features.

    The nodes features input matrix resulting from message passing and
    aggregation is built in the aggregation operator (class
    torch.nn.aggr.Aggregation) set in the initialization of the
    torch_geometric.nn.MessagePassing class. It is called by the
    propagate method within the PyG backend.

    The aggregated features already carry a time dimension when the edge
    update produced one (see message()); only node_features_in may need
    to be extended along time here, namely when self._n_time_edge > 0 and
    self._n_time_node == 0.

    Parameters
    ----------
    node_features_in_aggr : torch.Tensor
        Nodes features input matrix resulting from message passing and
        edge-to-node aggregation, stored as a torch.Tensor(2d) of shape
        (n_nodes, n_features).
    node_features_in : torch.Tensor, default=None
        Nodes features input matrix stored as a torch.Tensor(2d) of shape
        (n_nodes, n_features).

    Returns
    -------
    node_features_out : torch.Tensor
        Nodes features output matrix stored as a torch.Tensor(2d) of
        shape (n_nodes, n_features).
    """
    # Start from the edge-to-node aggregated features
    combined_features = node_features_in_aggr
    # Append the nodes input features when available
    if node_features_in is not None:
        # Both matrices must describe the same set of nodes
        if node_features_in_aggr.shape[0] != node_features_in.shape[0]:
            raise RuntimeError(f'Mismatch between number of nodes '
                               f'stemming from edge-to-node aggregation '
                               f'({node_features_in_aggr.shape[0]}) '
                               f'and nodes features input matrix '
                               f'({node_features_in.shape[0]}).')
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Replicate time-less node features along the edge time dimension
        if self._n_time_node == 0 and self._n_time_edge > 0:
            node_features_in = \
                node_features_in.repeat(1, self._n_time_edge)
        combined_features = \
            torch.cat((combined_features, node_features_in), dim=-1)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Update node features
    is_time_series = self._n_time_node > 0 or self._n_time_edge > 0
    if not is_time_series:
        # Compute node update directly on the 2d features matrix
        return self._node_fn(combined_features)
    # Time series data: reshape to (n_time, batch_size, -1), apply the
    # node update function, then flatten back to a 2d tensor of shape
    # (batch_size, n_node_out * n_time)
    n_time = max(self._n_time_node, self._n_time_edge)
    n_batch = combined_features.shape[0]
    stacked_features = combined_features.view(n_time, n_batch, -1)
    return self._node_fn(stacked_features).view(n_batch, -1)
# -------------------------------------------------------------------------
def update_global(self, node_features_out, global_features_in=None,
                  batch_vector=None):
    """Update global features.

    Parameters
    ----------
    node_features_out : torch.Tensor
        Nodes features output matrix stored as a torch.Tensor(2d) of
        shape (n_nodes, n_features). If self._n_time_node > 0 or
        self._n_time_edge > 0, then node_features_out has shape
        (n_nodes, n_features*self._n_time_node/edge).
    global_features_in : torch.Tensor, default=None
        Global features input matrix stored as a torch.Tensor(2d) of
        shape (1, n_features).
    batch_vector : torch.Tensor, default=None
        Batch vector stored as torch.Tensor(1d) of shape (n_nodes,),
        assigning each node to a specific batch subgraph. Required to
        process a graph holding multiple isolated subgraphs when batch
        size is greater than 1.

    Returns
    -------
    global_features_out : torch.Tensor
        Global features output matrix stored as a torch.Tensor(2d) of
        shape (1, n_features).
    """
    # Dispatch table mapping the configured node-to-global aggregation
    # scheme onto the corresponding PyG pooling operator
    pooling_operators = {'add': torch_geometric.nn.global_add_pool,
                         'mean': torch_geometric.nn.global_mean_pool}
    if self.node_to_global_aggr not in pooling_operators:
        raise RuntimeError('Unknown node-to-global aggregation scheme.')
    # Perform node-to-global aggregation
    pooled_features = pooling_operators[self.node_to_global_aggr](
        node_features_out, batch_vector)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Flag whether node or edge features carry a time dimension
    is_node_edge_time = self._n_time_node > 0 or self._n_time_edge > 0
    # When only the global features carry a time dimension, replicate the
    # (time-less) aggregated node features along time
    if self._n_time_global > 0 and not is_node_edge_time:
        global_cat = pooled_features.repeat(1, self._n_time_global)
    else:
        global_cat = pooled_features
    # Concatenate available global input features
    if global_features_in is not None:
        # When node or edge features carry a time dimension but global
        # input features do not, replicate the latter along time
        if is_node_edge_time and self._n_time_global == 0:
            n_time_in = max(self._n_time_node, self._n_time_edge)
            global_features_in = global_features_in.repeat(1, n_time_in)
        global_cat = torch.cat([global_cat, global_features_in], dim=-1)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Update global features
    if self._n_time_global > 0 or is_node_edge_time:
        # Time series data: reshape to (n_time, batch_size, -1), apply
        # the global update function, then flatten back to a 2d tensor of
        # shape (batch_size, n_global_out * n_time)
        n_time = max(self._n_time_node, self._n_time_edge,
                     self._n_time_global)
        n_batch = global_cat.shape[0]
        global_features_out = self._global_fn(
            global_cat.view(n_time, n_batch, -1)).view(n_batch, -1)
    else:
        global_features_out = self._global_fn(global_cat)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    return global_features_out