"""Graph Neural Network based Encoder-Process-Decoder model.
Classes
-------
EncodeProcessDecode(torch.nn.Module)
GNN-based Encoder-Process-Decoder model.
Encoder(GraphIndependentNetwork)
GNN-based encoder.
Processor(torch_geometric.nn.MessagePassing)
GNN-based processor.
Decoder(GraphIndependentNetwork)
GNN-based decoder.
"""
#
# Modules
# =============================================================================
# Third-party
import torch
import torch_geometric.nn
# Local
from gnn_base_model.model.gnn_architectures import build_fnn, \
build_rnn, GraphIndependentNetwork, GraphInteractionNetwork
#
# Authorship & Credits
# =============================================================================
__author__ = 'Bernardo Ferreira (bernardo_ferreira@brown.edu)'
__credits__ = ['Bernardo Ferreira', 'Rui Barreira']
__status__ = 'Planning'
# =============================================================================
#
# =============================================================================
class EncodeProcessDecode(torch.nn.Module):
    """GNN-based Encoder-Process-Decoder model.

    Input features are first encoded (Encoder), then optionally updated by
    a sequence of message-passing steps (Processor) and finally decoded
    into the output features (Decoder).

    Attributes
    ----------
    _n_message_steps : int
        Number of message-passing steps.
    _encoder : Encoder
        GNN-based encoder.
    _processor : {Processor, None}
        GNN-based processor. Set to None if the number of message-passing
        steps is zero.
    _decoder : Decoder
        GNN-based decoder.
    _n_node_out : int
        Number of node output features.
    _n_edge_out : int
        Number of edge output features.
    _n_global_out : int
        Number of global output features.

    Methods
    -------
    forward(self, edges_indexes, node_features_in=None, \
            edge_features_in=None, global_features_in=None, \
            batch_vector=None)
        Forward propagation.
    """
    def __init__(self, n_message_steps, n_node_out, n_edge_out,
                 n_global_out, enc_n_hidden_layers, pro_n_hidden_layers,
                 dec_n_hidden_layers, hidden_layer_size,
                 n_node_in=0, n_edge_in=0, n_global_in=0,
                 n_time_node=0, n_time_edge=0, n_time_global=0,
                 pro_edge_to_node_aggr='add', pro_node_to_global_aggr='add',
                 enc_node_hidden_activation=torch.nn.Identity(),
                 enc_node_output_activation=torch.nn.Identity(),
                 enc_edge_hidden_activation=torch.nn.Identity(),
                 enc_edge_output_activation=torch.nn.Identity(),
                 enc_global_hidden_activation=torch.nn.Identity(),
                 enc_global_output_activation=torch.nn.Identity(),
                 pro_node_hidden_activation=torch.nn.Identity(),
                 pro_node_output_activation=torch.nn.Identity(),
                 pro_edge_hidden_activation=torch.nn.Identity(),
                 pro_edge_output_activation=torch.nn.Identity(),
                 pro_global_hidden_activation=torch.nn.Identity(),
                 pro_global_output_activation=torch.nn.Identity(),
                 dec_node_hidden_activation=torch.nn.Identity(),
                 dec_node_output_activation=torch.nn.Identity(),
                 dec_edge_hidden_activation=torch.nn.Identity(),
                 dec_edge_output_activation=torch.nn.Identity(),
                 dec_global_hidden_activation=torch.nn.Identity(),
                 dec_global_output_activation=torch.nn.Identity(),
                 is_node_res_connect=False,
                 is_edge_res_connect=False,
                 is_global_res_connect=False):
        """Constructor.

        Parameters
        ----------
        n_message_steps : int
            Number of message-passing steps. Setting number of
            message-passing steps to 0 results in Encoder-Decoder model
            (Processor is not initialized).
        n_node_out : int
            Number of node output features.
        n_edge_out : int
            Number of edge output features.
        n_global_out : int
            Number of global output features.
        n_time_node : int, default=0
            Number of discrete time steps of nodal features.
            If greater than 0, then nodal input features include a time
            dimension and message passing layers are RNNs.
        n_time_edge : int, default=0
            Number of discrete time steps of edge features.
            If greater than 0, then edge input features include a time
            dimension and message passing layers are RNNs.
        n_time_global : int, default=0
            Number of discrete time steps of global features.
            If greater than 0, then global input features include a time
            dimension and message passing layers are RNNs.
        enc_n_hidden_layers : int
            Encoder: Number of hidden layers of multilayer feed-forward/
            recurrent neural network update functions.
        pro_n_hidden_layers : int
            Processor: Number of hidden layers of multilayer feed-forward/
            recurrent neural network update functions.
        dec_n_hidden_layers : int
            Decoder: Number of hidden layers of multilayer feed-forward/
            recurrent neural network update functions.
        hidden_layer_size : int
            Number of neurons of hidden layers of multilayer feed-forward/
            recurrent neural network update functions.
        n_node_in : int, default=0
            Number of node input features.
        n_edge_in : int, default=0
            Number of edge input features.
        n_global_in : int, default=0
            Number of global input features.
        pro_edge_to_node_aggr : {'add',}, default='add'
            Processor: Edge-to-node aggregation scheme.
        pro_node_to_global_aggr : {'add', 'mean'}, default='add'
            Processor: Node-to-global aggregation scheme.
        enc_node_hidden_activation : torch.nn.Module, \
default=torch.nn.Identity
            Encoder: Hidden unit activation function of node update function
            (multilayer feed-forward neural network). Defaults to identity
            (linear) unit activation function.
        enc_node_output_activation : torch.nn.Module, \
default=torch.nn.Identity
            Encoder: Output unit activation function of node update function
            (multilayer feed-forward neural network). Defaults to identity
            (linear) unit activation function.
        enc_edge_hidden_activation : torch.nn.Module, \
default=torch.nn.Identity
            Encoder: Hidden unit activation function of edge update function
            (multilayer feed-forward neural network). Defaults to identity
            (linear) unit activation function.
        enc_edge_output_activation : torch.nn.Module, \
default=torch.nn.Identity
            Encoder: Output unit activation function of edge update function
            (multilayer feed-forward neural network). Defaults to identity
            (linear) unit activation function.
        enc_global_hidden_activation : torch.nn.Module, \
default=torch.nn.Identity
            Encoder: Hidden unit activation function of global update
            function (multilayer feed-forward neural network). Defaults to
            identity (linear) unit activation function.
        enc_global_output_activation : torch.nn.Module, \
default=torch.nn.Identity
            Encoder: Output unit activation function of global update
            function (multilayer feed-forward neural network). Defaults to
            identity (linear) unit activation function.
        pro_node_hidden_activation : torch.nn.Module, \
default=torch.nn.Identity
            Processor: Hidden unit activation function of node update
            function (multilayer feed-forward neural network). Defaults to
            identity (linear) unit activation function.
        pro_node_output_activation : torch.nn.Module, \
default=torch.nn.Identity
            Processor: Output unit activation function of node update
            function (multilayer feed-forward neural network). Defaults to
            identity (linear) unit activation function.
        pro_edge_hidden_activation : torch.nn.Module, \
default=torch.nn.Identity
            Processor: Hidden unit activation function of edge update
            function (multilayer feed-forward neural network). Defaults to
            identity (linear) unit activation function.
        pro_edge_output_activation : torch.nn.Module, \
default=torch.nn.Identity
            Processor: Output unit activation function of edge update
            function (multilayer feed-forward neural network). Defaults to
            identity (linear) unit activation function.
        pro_global_hidden_activation : torch.nn.Module, \
default=torch.nn.Identity
            Processor: Hidden unit activation function of global update
            function (multilayer feed-forward neural network). Defaults to
            identity (linear) unit activation function.
        pro_global_output_activation : torch.nn.Module, \
default=torch.nn.Identity
            Processor: Output unit activation function of global update
            function (multilayer feed-forward neural network). Defaults to
            identity (linear) unit activation function.
        dec_node_hidden_activation : torch.nn.Module, \
default=torch.nn.Identity
            Decoder: Hidden unit activation function of node update function
            (multilayer feed-forward neural network). Defaults to identity
            (linear) unit activation function.
        dec_node_output_activation : torch.nn.Module, \
default=torch.nn.Identity
            Decoder: Output unit activation function of node update function
            (multilayer feed-forward neural network). Defaults to identity
            (linear) unit activation function.
        dec_edge_hidden_activation : torch.nn.Module, \
default=torch.nn.Identity
            Decoder: Hidden unit activation function of edge update function
            (multilayer feed-forward neural network). Defaults to identity
            (linear) unit activation function.
        dec_edge_output_activation : torch.nn.Module, \
default=torch.nn.Identity
            Decoder: Output unit activation function of edge update function
            (multilayer feed-forward neural network). Defaults to identity
            (linear) unit activation function.
        dec_global_hidden_activation : torch.nn.Module, \
default=torch.nn.Identity
            Decoder: Hidden unit activation function of global update
            function (multilayer feed-forward neural network). Defaults to
            identity (linear) unit activation function.
        dec_global_output_activation : torch.nn.Module, \
default=torch.nn.Identity
            Decoder: Output unit activation function of global update
            function (multilayer feed-forward neural network). Defaults to
            identity (linear) unit activation function.
        is_node_res_connect : bool, default=False
            Processor: Add residual connections between nodes input and
            output features if True, False otherwise. Number of input and
            output features must match to process residual connections.
            Automatically set to False if number of node input features is
            zero.
        is_edge_res_connect : bool, default=False
            Processor: Add residual connections in between edges input and
            output features if True, False otherwise. Number of input and
            output features must match to process residual connections.
            Automatically set to False if number of edge input features is
            zero.
        is_global_res_connect : bool, default=False
            Processor: Add residual connections in between global input and
            output features if True, False otherwise. Number of input and
            output features must match to process residual connections.
            Automatically set to False if number of global input features is
            zero.

        Raises
        ------
        RuntimeError
            If the numbers of node, edge and global input features are all
            lower than 1.
        """
        # Initialize from base class
        super(EncodeProcessDecode, self).__init__()
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Check number of input features
        if int(n_node_in) < 1 and int(n_edge_in) < 1 and int(n_global_in) < 1:
            raise RuntimeError(f'Impossible to setup model without node '
                               f'({int(n_node_in)}), edge ({int(n_edge_in)}) '
                               f'and global ({int(n_global_in)}) input '
                               f'features.')
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Store number of message-passing steps
        self._n_message_steps = int(n_message_steps)
        # Store number of output features (normalized to int, consistently
        # with the remaining feature counts, because forward() compares them
        # against 1 to discard unsolicited outputs)
        self._n_node_out = int(n_node_out)
        self._n_edge_out = int(n_edge_out)
        self._n_global_out = int(n_global_out)
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Set node update function hidden layer input size
        if int(n_node_in) < 1:
            # Overwrite hidden layer input size when number of node input
            # features is zero
            n_node_hidden_in = 0
            # Turn off node residual connections
            is_node_res_connect = False
        else:
            n_node_hidden_in = hidden_layer_size
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Set edge update function hidden layer input size
        if int(n_edge_in) < 1:
            # Overwrite hidden layer input size when number of edge input
            # features is zero
            n_edge_hidden_in = 0
            # Turn off edge residual connections
            is_edge_res_connect = False
        else:
            n_edge_hidden_in = hidden_layer_size
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Set global update function hidden layer input size
        if int(n_global_in) < 1:
            # Overwrite hidden layer input size when number of global input
            # features is zero
            n_global_hidden_in = 0
            # Turn off global residual connections
            is_global_res_connect = False
        else:
            n_global_hidden_in = hidden_layer_size
        # Set global update function hidden layer output size
        if int(n_global_out) < 1:
            # Overwrite hidden layer output size when number of global
            # output features is zero
            n_global_hidden_out = 0
        else:
            n_global_hidden_out = hidden_layer_size
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Set encoder normalization layer
        is_enc_norm_layer = False
        # Set processor normalization layer
        is_pro_norm_layer = False
        # Set decoder normalization layer
        is_dec_norm_layer = False
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Set model encoder
        self._encoder = \
            Encoder(n_hidden_layers=enc_n_hidden_layers,
                    hidden_layer_size=hidden_layer_size,
                    n_node_in=n_node_in, n_node_out=hidden_layer_size,
                    n_time_node=n_time_node,
                    n_edge_in=n_edge_in, n_edge_out=hidden_layer_size,
                    n_time_edge=n_time_edge,
                    n_global_in=n_global_in,
                    n_global_out=n_global_hidden_out,
                    n_time_global=n_time_global,
                    node_hidden_activation=enc_node_hidden_activation,
                    node_output_activation=enc_node_output_activation,
                    edge_hidden_activation=enc_edge_hidden_activation,
                    edge_output_activation=enc_edge_output_activation,
                    global_hidden_activation=enc_global_hidden_activation,
                    global_output_activation=enc_global_output_activation,
                    is_norm_layer=is_enc_norm_layer,
                    is_skip_unset_update=True)
        # Set model processor if positive number of message-passing steps
        if self._n_message_steps > 0:
            # The int-normalized number of steps is forwarded so that the
            # processor is built from the same value that gates its creation
            self._processor = \
                Processor(n_message_steps=self._n_message_steps,
                          n_node_out=hidden_layer_size,
                          n_edge_out=hidden_layer_size,
                          n_global_out=n_global_hidden_out,
                          n_hidden_layers=pro_n_hidden_layers,
                          hidden_layer_size=hidden_layer_size,
                          n_node_in=n_node_hidden_in,
                          n_edge_in=n_edge_hidden_in,
                          n_global_in=n_global_hidden_in,
                          edge_to_node_aggr=pro_edge_to_node_aggr,
                          node_to_global_aggr=pro_node_to_global_aggr,
                          node_hidden_activation=pro_node_hidden_activation,
                          node_output_activation=pro_node_output_activation,
                          edge_hidden_activation=pro_edge_hidden_activation,
                          edge_output_activation=pro_edge_output_activation,
                          global_hidden_activation=\
                              pro_global_hidden_activation,
                          global_output_activation=\
                              pro_global_output_activation,
                          n_time_node=n_time_node,
                          n_time_edge=n_time_edge,
                          n_time_global=n_time_global,
                          is_norm_layer=is_pro_norm_layer,
                          is_node_res_connect=is_node_res_connect,
                          is_edge_res_connect=is_edge_res_connect,
                          is_global_res_connect=is_global_res_connect)
        else:
            self._processor = None
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # If nodes/edges encode time series data, and the remaining features
        # do not, then, after passing through the process, i.e., a graph
        # interaction network, the remaining features also become time series
        if n_time_edge > 0:
            n_time_node = n_time_edge
            n_time_global = n_time_edge
        elif n_time_node > 0:
            n_time_edge = n_time_node
            n_time_global = n_time_node
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Set model decoder
        self._decoder = \
            Decoder(n_hidden_layers=dec_n_hidden_layers,
                    hidden_layer_size=hidden_layer_size,
                    n_node_in=hidden_layer_size,
                    n_edge_in=hidden_layer_size,
                    n_global_in=hidden_layer_size,
                    n_node_out=n_node_out,
                    n_edge_out=n_edge_out,
                    n_global_out=n_global_out,
                    node_hidden_activation=dec_node_hidden_activation,
                    node_output_activation=dec_node_output_activation,
                    edge_hidden_activation=dec_edge_hidden_activation,
                    edge_output_activation=dec_edge_output_activation,
                    global_hidden_activation=dec_global_hidden_activation,
                    global_output_activation=dec_global_output_activation,
                    n_time_node=n_time_node, n_time_edge=n_time_edge,
                    n_time_global=n_time_global,
                    is_norm_layer=is_dec_norm_layer,
                    is_skip_unset_update=True)
    # -------------------------------------------------------------------------
    def forward(self, edges_indexes, node_features_in=None,
                edge_features_in=None, global_features_in=None,
                batch_vector=None):
        """Forward propagation.

        Processor is skipped if number of message-passing steps is set to
        zero.

        Parameters
        ----------
        edges_indexes : torch.Tensor
            Edges indexes matrix stored as torch.Tensor(2d) with shape
            (2, n_edges), where the i-th edge is stored in
            edges_indexes[:, i] as (start_node_index, end_node_index).
        node_features_in : torch.Tensor, default=None
            Nodes features input matrix stored as a torch.Tensor(2d) of
            shape (n_nodes, n_features).
        edge_features_in : torch.Tensor, default=None
            Edges features input matrix stored as a torch.Tensor(2d) of
            shape (n_edges, n_features).
        global_features_in : torch.Tensor, default=None
            Global features input matrix stored as a torch.Tensor(2d) of
            shape (1, n_features). Ignored if global update function is not
            setup.
        batch_vector : torch.Tensor, default=None
            Batch vector stored as torch.Tensor(1d) of shape (n_nodes,),
            assigning each node to a specific batch subgraph. Required to
            process a graph holding multiple isolated subgraphs when batch
            size is greater than 1.

        Returns
        -------
        node_features_out : {torch.Tensor, None}
            Nodes features output matrix stored as a torch.Tensor(2d) of
            shape (n_nodes, n_features). None if number of node output
            features is lower than 1.
        edge_features_out : {torch.Tensor, None}
            Edges features output matrix stored as a torch.Tensor(2d) of
            shape (n_edges, n_features). None if number of edge output
            features is lower than 1.
        global_features_out : {torch.Tensor, None}
            Global features output matrix stored as a torch.Tensor(2d) of
            shape (1, n_features). None if number of global output features
            is lower than 1.

        Raises
        ------
        RuntimeError
            If node, edge and global input features matrices are all None.
        """
        # Check input features matrices
        if (node_features_in is None and edge_features_in is None
                and global_features_in is None):
            raise RuntimeError('Impossible to compute forward propagation of '
                               'model without node (None), edge (None) and '
                               'global (None) input features matrices.')
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Perform encoding
        node_features, edge_features, global_features = \
            self._encoder(node_features_in=node_features_in,
                          edge_features_in=edge_features_in,
                          global_features_in=global_features_in,
                          batch_vector=batch_vector)
        # Perform processing (message-passing steps)
        if self._n_message_steps > 0:
            node_features, edge_features, global_features = \
                self._processor(edges_indexes=edges_indexes,
                                node_features_in=node_features,
                                edge_features_in=edge_features,
                                global_features_in=global_features,
                                batch_vector=batch_vector)
        # Perform decoding
        node_features_out, edge_features_out, global_features_out = \
            self._decoder(node_features_in=node_features,
                          edge_features_in=edge_features,
                          global_features_in=global_features,
                          batch_vector=batch_vector)
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Discard unsolicited output features
        if self._n_node_out < 1:
            node_features_out = None
        if self._n_edge_out < 1:
            edge_features_out = None
        if self._n_global_out < 1:
            global_features_out = None
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        return node_features_out, edge_features_out, global_features_out
# =============================================================================
class Encoder(GraphIndependentNetwork):
    """GNN-based encoder.

    Graph Independent Network that maps the input state graph data into a
    latent graph. The node, edge and global features update functions are
    multilayer feed-forward or recurrent neural networks that act
    independently of each other (no aggregation).

    No members are added or overridden: all behavior is inherited from
    GraphIndependentNetwork.
    """
# =============================================================================
class Processor(torch_geometric.nn.MessagePassing):
    """GNN-based processor.

    Performs a given number of graph message-passing steps to generate a
    sequence of updated latent graphs as the information is propagated
    through the graph neural network. Each message-passing step is performed
    by its own Graph Interaction Network (identical architecture, unshared
    parameters), where node, edge and global features update functions are
    implemented as multilayer feed-forward or recurrent neural networks.

    Residual connections between the input and output latent features of
    nodes, edges and global features can be optionally adopted at each
    message-passing step.

    Attributes
    ----------
    _processor : torch.nn.ModuleList
        Sequence of graph neural networks.
    _n_node_in : int
        Number of node input features.
    _n_node_out : int
        Number of node output features.
    _n_edge_in : int
        Number of edge input features.
    _n_edge_out : int
        Number of edge output features.
    _n_global_in : int
        Number of global input features.
    _n_global_out : int
        Number of global output features.
    _n_time_node : int
        Number of discrete time steps of nodal features.
        If greater than 0, then nodal input features include a time
        dimension, and message passing layers are RNNs.
    _n_time_edge : int
        Number of discrete time steps of edge features.
        If greater than 0, then edge input features include a time
        dimension, and message passing layers are RNNs.
    _n_time_global : int
        Number of discrete time steps of global features.
        If greater than 0, then global input features include a time
        dimension, and message passing layers are RNNs.
    _is_node_res_connect : bool
        If True, node residual connections are requested.
    _is_edge_res_connect : bool
        If True, edge residual connections are requested.
    _is_global_res_connect : bool
        If True, global residual connections are requested.

    Methods
    -------
    forward(self, edges_indexes, node_features_in=None, \
            edge_features_in=None, global_features_in=None, \
            batch_vector=None)
        Forward propagation.
    """
    def __init__(self, n_message_steps, n_node_out, n_edge_out,
                 n_hidden_layers, hidden_layer_size,
                 n_node_in=0, n_edge_in=0, n_global_in=0, n_global_out=0,
                 n_time_node=0, n_time_edge=0, n_time_global=0,
                 edge_to_node_aggr='add', node_to_global_aggr='add',
                 node_hidden_activation=torch.nn.Identity(),
                 node_output_activation=torch.nn.Identity(),
                 edge_hidden_activation=torch.nn.Identity(),
                 edge_output_activation=torch.nn.Identity(),
                 global_hidden_activation=torch.nn.Identity(),
                 global_output_activation=torch.nn.Identity(),
                 is_norm_layer=False, is_node_res_connect=False,
                 is_edge_res_connect=False, is_global_res_connect=False):
        """Constructor.

        Parameters
        ----------
        n_message_steps : int
            Number of message-passing steps.
        n_node_out : int
            Number of node output features.
        n_edge_out : int
            Number of edge output features.
        n_hidden_layers : int
            Number of hidden layers of multilayer feed-forward neural network
            update functions.
        hidden_layer_size : int
            Number of neurons of hidden layers of multilayer feed-forward
            neural network update functions.
        n_node_in : int, default=0
            Number of node input features.
        n_edge_in : int, default=0
            Number of edge input features.
        n_global_in : int, default=0
            Number of global input features.
        n_global_out : int, default=0
            Number of global output features.
        n_time_node : int, default=0
            Number of discrete time steps of nodal features.
            If greater than 0, then nodal input features include a time
            dimension and message passing layers are RNNs.
        n_time_edge : int, default=0
            Number of discrete time steps of edge features.
            If greater than 0, then edge input features include a time
            dimension and message passing layers are RNNs.
        n_time_global : int, default=0
            Number of discrete time steps of global features.
            If greater than 0, then global input features include a time
            dimension and message passing layers are RNNs.
        edge_to_node_aggr : {'add',}, default='add'
            Edge-to-node aggregation scheme.
        node_to_global_aggr : {'add', 'mean'}, default='add'
            Node-to-global aggregation scheme.
        node_hidden_activation : torch.nn.Module, default=torch.nn.Identity
            Hidden unit activation function of node update function
            (multilayer feed-forward neural network). Defaults to identity
            (linear) unit activation function.
        node_output_activation : torch.nn.Module, default=torch.nn.Identity
            Output unit activation function of node update function
            (multilayer feed-forward neural network). Defaults to identity
            (linear) unit activation function.
        edge_hidden_activation : torch.nn.Module, default=torch.nn.Identity
            Hidden unit activation function of edge update function
            (multilayer feed-forward neural network). Defaults to identity
            (linear) unit activation function.
        edge_output_activation : torch.nn.Module, default=torch.nn.Identity
            Output unit activation function of edge update function
            (multilayer feed-forward neural network). Defaults to identity
            (linear) unit activation function.
        global_hidden_activation : torch.nn.Module, default=torch.nn.Identity
            Hidden unit activation function of global update function
            (multilayer feed-forward neural network). Defaults to identity
            (linear) unit activation function.
        global_output_activation : torch.nn.Module, default=torch.nn.Identity
            Output unit activation function of global update function
            (multilayer feed-forward neural network). Defaults to identity
            (linear) unit activation function.
        is_norm_layer : bool, default=False
            If True, then add normalization layer to node, edge and global
            update functions.
        is_node_res_connect : bool, default=False
            Add residual connections between nodes input and output features
            if True, False otherwise. Number of input and output features
            must match to process residual connections.
        is_edge_res_connect : bool, default=False
            Add residual connections between edges input and output features
            if True, False otherwise. Number of input and output features
            must match to process residual connections.
        is_global_res_connect : bool, default=False
            Add residual connections between global input and output features
            if True, False otherwise. Number of input and output features
            must match to process residual connections.

        Raises
        ------
        RuntimeError
            If there are no node, edge or global input features, if the
            number of node or edge output features is lower than 1, if the
            number of message-passing steps is lower than 1, if multiple
            message-passing steps are requested with mismatching numbers of
            input and output features, or if residual connections are
            requested with mismatching numbers of input and output features.
        """
        # Initialize from base class
        super(Processor, self).__init__()
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Set number of features
        self._n_node_in = int(n_node_in)
        self._n_node_out = int(n_node_out)
        self._n_edge_in = int(n_edge_in)
        self._n_edge_out = int(n_edge_out)
        self._n_global_in = int(n_global_in)
        self._n_global_out = int(n_global_out)
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Set input with time dimension
        self._n_time_node = n_time_node
        self._n_time_edge = n_time_edge
        self._n_time_global = n_time_global
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Check number of input features
        if (self._n_node_in < 1 and self._n_edge_in < 1
                and self._n_global_in < 1):
            raise RuntimeError(f'Impossible to setup model without node '
                               f'({self._n_node_in}), edge '
                               f'({self._n_edge_in}), or global '
                               f'({self._n_global_in}) input features.')
        # Check number of output features
        if self._n_node_out < 1 or self._n_edge_out < 1:
            raise RuntimeError(f'Number of node ({self._n_node_out}) and '
                               f'edge ({self._n_edge_out}) output features '
                               f'must be greater than 0.')
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Check number of message-passing steps (normalized to int so that
        # it can be safely used to build the sequence of networks)
        n_message_steps = int(n_message_steps)
        if n_message_steps < 1:
            raise RuntimeError('Number of message-passing steps must be at '
                               'least 1.')
        # Output of a message-passing step is the input of the next one, so
        # any available input features must keep their size
        is_in_out_mismatch = (
            (self._n_node_in > 0 and self._n_node_in != self._n_node_out)
            or (self._n_edge_in > 0
                and self._n_edge_in != self._n_edge_out)
            or (self._n_global_in > 0
                and self._n_global_in != self._n_global_out))
        if n_message_steps > 1 and is_in_out_mismatch:
            raise RuntimeError('Number of node/edge/global input and output '
                               'features must match to process multiple '
                               'message-passing steps in sequence.')
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Set node residual connections
        if is_node_res_connect and self._n_node_in != self._n_node_out:
            raise RuntimeError('Number of node input and output features '
                               'must match to process residual '
                               'connections.')
        self._is_node_res_connect = is_node_res_connect
        # Set edge residual connections
        if is_edge_res_connect and self._n_edge_in != self._n_edge_out:
            raise RuntimeError('Number of edge input and output features '
                               'must match to process residual '
                               'connections.')
        self._is_edge_res_connect = is_edge_res_connect
        # Set global residual connections
        if (is_global_res_connect
                and self._n_global_in != self._n_global_out):
            raise RuntimeError('Number of global input and output features '
                               'must match to process residual connections.')
        self._is_global_res_connect = is_global_res_connect
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Set sequence of identical Graph Interaction Networks
        self._processor = torch.nn.ModuleList(
            [GraphInteractionNetwork(
                n_node_out=n_node_out, n_edge_out=n_edge_out,
                n_hidden_layers=n_hidden_layers,
                hidden_layer_size=hidden_layer_size,
                n_node_in=n_node_in, n_edge_in=n_edge_in,
                n_global_in=n_global_in, n_global_out=n_global_out,
                edge_to_node_aggr=edge_to_node_aggr,
                node_to_global_aggr=node_to_global_aggr,
                node_hidden_activation=node_hidden_activation,
                node_output_activation=node_output_activation,
                edge_hidden_activation=edge_hidden_activation,
                edge_output_activation=edge_output_activation,
                n_time_node=n_time_node,
                n_time_edge=n_time_edge,
                n_time_global=n_time_global,
                is_norm_layer=is_norm_layer,
                global_hidden_activation=global_hidden_activation,
                global_output_activation=global_output_activation)
             for _ in range(n_message_steps)])
    # -------------------------------------------------------------------------
    @staticmethod
    def _check_features_width(features_in, n_features, n_time, label,
                              matrix_label, is_skip_empty=False):
        """Check number of features of a features input matrix.

        Parameters
        ----------
        features_in : {torch.Tensor, None}
            Features input matrix. Check is skipped if None.
        n_features : int
            Number of input features of model.
        n_time : int
            Number of discrete time steps. If greater than 0, the expected
            size of the last dimension is n_features*n_time.
        label : str
            Features label used in error message (e.g., 'node').
        matrix_label : str
            Features matrix label used in error message (e.g., 'nodes').
        is_skip_empty : bool, default=False
            If True, check is skipped when features input matrix has no
            elements.

        Raises
        ------
        RuntimeError
            If the size of the last dimension of the features input matrix
            does not match the expected number of features.
        """
        if features_in is None:
            return
        if is_skip_empty and features_in.numel() == 0:
            return
        # Set expected size of last dimension
        n_expected = n_features*n_time if n_time > 0 else n_features
        # Report the same dimension that is checked (last dimension)
        if features_in.shape[-1] != n_expected:
            raise RuntimeError(f'Mismatch of number of {label} features of '
                               f'model ({n_expected}) and {matrix_label} '
                               f'input features matrix '
                               f'({features_in.shape[-1]}).')
    # -------------------------------------------------------------------------
    def forward(self, edges_indexes, node_features_in=None,
                edge_features_in=None, global_features_in=None,
                batch_vector=None):
        """Forward propagation.

        Parameters
        ----------
        edges_indexes : torch.Tensor
            Edges indexes matrix stored as torch.Tensor(2d) with shape
            (2, n_edges), where the i-th edge is stored in
            edges_indexes[:, i] as (start_node_index, end_node_index).
        node_features_in : torch.Tensor, default=None
            Nodes features input matrix stored as a torch.Tensor(2d) of
            shape (n_nodes, n_features). If None, the edge-to-node
            aggregation is only built up to the highest receiver node index
            according with edges_indexes. To preserve total number of nodes
            in edge-to-node aggregation, pass torch.empty(n_nodes, 0)
            instead of None.
        edge_features_in : torch.Tensor, default=None
            Edges features input matrix stored as a torch.Tensor(2d) of
            shape (n_edges, n_features).
        global_features_in : torch.Tensor, default=None
            Global features input matrix stored as a torch.Tensor(2d) of
            shape (1, n_features).
        batch_vector : torch.Tensor, default=None
            Batch vector stored as torch.Tensor(1d) of shape (n_nodes,),
            assigning each node to a specific batch subgraph. Required to
            process a graph holding multiple isolated subgraphs when batch
            size is greater than 1.

        Returns
        -------
        node_features_out : torch.Tensor
            Nodes features output matrix stored as a torch.Tensor(2d) of
            shape (n_nodes, n_features).
        edge_features_out : torch.Tensor
            Edges features output matrix stored as a torch.Tensor(2d) of
            shape (n_edges, n_features).
        global_features_out : {torch.Tensor, None}
            Global features output matrix stored as a torch.Tensor(2d) of
            shape (1, n_features).

        Raises
        ------
        RuntimeError
            If the number of features of any input features matrix does not
            match the corresponding number of input features of the model.
        """
        # Check number of nodal, edge and global features. Empty node
        # features matrices are accepted, e.g., torch.empty(n_nodes, 0)
        self._check_features_width(node_features_in, self._n_node_in,
                                   self._n_time_node, 'node', 'nodes',
                                   is_skip_empty=True)
        self._check_features_width(edge_features_in, self._n_edge_in,
                                   self._n_time_edge, 'edge', 'edges')
        self._check_features_width(global_features_in, self._n_global_in,
                                   self._n_time_global, 'global', 'global')
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Collect number of nodes to preserve total number of nodes in
        # edge-to-node aggregation when number of node input features is zero
        n_nodes = None
        if self._n_node_in < 1 and node_features_in is not None:
            n_nodes = node_features_in.shape[-2]
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Initialize residual update node features
        node_features_out = None
        if node_features_in is not None:
            node_features_out = node_features_in.clone()
        # Initialize residual update edge features
        edge_features_out = None
        if edge_features_in is not None:
            edge_features_out = edge_features_in.clone()
        # Initialize residual update global features
        global_features_out = None
        if global_features_in is not None:
            global_features_out = global_features_in.clone()
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Set active residual connections. Residual connections are only
        # processed when the corresponding input features are available, so
        # that saving and adding the residual are always consistent
        is_node_res = self._is_node_res_connect and self._n_node_in > 0
        is_edge_res = self._is_edge_res_connect and self._n_edge_in > 0
        is_global_res = \
            self._is_global_res_connect and self._n_global_in > 0
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Loop over message-passing steps
        n_steps = len(self._processor)
        for i, gnn_model in enumerate(self._processor):
            # Save features matrix (residual connection)
            if is_node_res:
                node_features_res = node_features_out.clone()
            if is_edge_res:
                edge_features_res = edge_features_out.clone()
            if is_global_res:
                global_features_res = global_features_out.clone()
            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # Perform graph neural network message-passing step
            node_features_out, edge_features_out, global_features_out = \
                gnn_model(edges_indexes=edges_indexes,
                          node_features_in=node_features_out,
                          edge_features_in=edge_features_out,
                          global_features_in=global_features_out,
                          batch_vector=batch_vector)
            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # Check if last message-passing step
            is_last_step = i == n_steps - 1
            # Discard features output matrices of unavailable input features
            # except in the last message-passing step
            if not is_last_step:
                if self._n_node_in < 1:
                    if isinstance(n_nodes, int):
                        # Keep device and data type of node features input
                        # matrix in the empty placeholder
                        node_features_out = \
                            node_features_in.new_empty((n_nodes, 0))
                    else:
                        node_features_out = None
                if self._n_edge_in < 1:
                    edge_features_out = None
                if self._n_global_in < 1:
                    global_features_out = None
            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # Add residual connections to features output. Out-of-place
            # addition avoids modifying in-place a tensor that autograd may
            # have saved for the backward pass (e.g., output of a tanh or
            # sigmoid output activation)
            if is_node_res:
                node_features_out = node_features_out + node_features_res
            if is_edge_res:
                edge_features_out = edge_features_out + edge_features_res
            if is_global_res:
                global_features_out = \
                    global_features_out + global_features_res
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        return node_features_out, edge_features_out, global_features_out
# =============================================================================
class Decoder(GraphIndependentNetwork):
    """GNN-based decoder.

    Graph Independent Network that maps the latent graph into the output
    graph. The node, edge and global features update functions are
    multilayer feed-forward or recurrent neural networks that act
    independently of each other (no aggregation).

    No members are added or overridden: all behavior is inherited from
    GraphIndependentNetwork.
    """