Re-implementation of MeshGraphNet for turbulent flow around a cylinder

Author(s): Guillaume Broggi

This example re-implements Google DeepMind’s MeshGraphNet [1] to predict the turbulent flow around a cylinder.

Note

This example runs best on a GPU. The reported results were obtained on an H100 GPU with an average training time of 35 min per epoch.

Tip

Run this notebook in Google Colab. Execute the following cell to install requirements.

Install required packages for Google Colab (run this cell in Colab environment):

try:
    # Check if running in Google Colab
    import google.colab

    # Install PyG as in PyG documentation
    %pip uninstall --yes torch torchaudio torchvision torchtext torchdata
    %pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu124
    %pip install torch-scatter -f https://data.pyg.org/whl/torch-2.4.1+cu124.html
    %pip install torch-sparse -f https://data.pyg.org/whl/torch-2.4.1+cu124.html
    %pip install torch_geometric==2.6.1 -f https://data.pyg.org/whl/torch-2.4.1+cu124.html
    
    # Install Graphorge
    !git clone https://github.com/bessagroup/graphorge.git
    %pip install -e ./graphorge
except ImportError:
    # Not running in Colab: nothing to install
    pass

# tfrecord, Plotly, PyVista, and cmcrameri are required for this example
try:
    from tfrecord.torch.dataset import TFRecordDataset
except ImportError:
    %pip install 'tfrecord[torch]'

try:
    import plotly.graph_objects as go
except ImportError:
    %pip install plotly

try:
    import pyvista as pv
except ImportError:
    %pip install pyvista trame trame-vtk trame-vuetify trame-components 'imageio[ffmpeg]'

try:
    import cmcrameri.cm as cm
except ImportError:
    %pip install cmcrameri

Imports required to run the notebook:

# Standard
import json
import logging
from pathlib import Path
import pickle
import shutil
import sys

# Third-party
import cmcrameri.cm as cm
import numpy as np
from plotly import graph_objects as go
import plotly.io as pio
import pyvista as pv
from sklearn.metrics import r2_score
from sklearn.preprocessing import OneHotEncoder
from tfrecord.torch.dataset import TFRecordDataset
import torch
from tqdm.auto import tqdm, trange

# Locals
from graphorge.gnn_base_model.data.graph_data import GraphData
from graphorge.gnn_base_model.data.graph_dataset import GNNGraphDataset
from graphorge.gnn_base_model.model.gnn_model import GNNEPDBaseModel
from graphorge.gnn_base_model.train.training import train_model
from graphorge.gnn_base_model.predict.prediction import predict

# Make utilities available in the path
try:
    # Check if running in Google Colab
    import google.colab
    utils_path = Path("graphorge/benchmarks/utils")
except ImportError:
    utils_path = Path().resolve().parents[1] / "utils"
sys.path.insert(0, str(utils_path))

# Import required utilities
from utils import (
    download_file,
    plotly_layout,
    plot_interactive_graph,
    plot_loss_history,
    graph_to_pyvista_mesh,
)

# Set logging configuration
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Prepare directories
directories = dict(
    raw_data=Path("0_data"),
    processed_data=Path("1_dataset"),
    training=Path("2_training_dataset"),
    validation=Path("3_validation_dataset"),
    model=Path("4_model"),
    in_distribution=Path("5_testing_id_dataset"),
    prediction=Path("7_prediction"),
)

# Create directories if they do not exist
for directory in directories.values():
    directory.mkdir(parents=True, exist_ok=True)

# Add `notebook_connected` to the default renderer to ensure the figures are rendered in the documentation
plotly_renderer = f"notebook_connected+{pio.renderers.default}"

# Set device
device_type = "cuda" if torch.cuda.is_available() else "cpu"

1. Raw data

The dataset contains 2D simulations of incompressible flow around a cylinder, released by Google DeepMind alongside the MeshGraphNet paper. See the dataset repository for more information.

Problem definition

The dataset is downloaded from the base URL below. It consists of a metadata file (meta.json) and TFRecord files, where each record is a simulated trajectory containing:

  • mesh_pos: the coordinates of the mesh nodes.

  • node_type: the type of each node (interior, wall, inlet, or outlet).

  • velocity: the velocity field at each node, for each time step.

  • pressure: the pressure field at each node, for each time step.

  • cells: the triangular cell connectivity of the mesh.

Let’s download the dataset and inspect it. Note that the data are, by nature, well described by a graph.

# Set the base URL for downloading the dataset
base_url = "https://storage.googleapis.com/dm-meshgraphnets/cylinder_flow"

# Download the dataset files
for file in ["meta.json", "train.tfrecord", "valid.tfrecord", "test.tfrecord"]:
    download_file(
        file=file,
        base_url=base_url,
        dest_path=directories["raw_data"],
    )

with open(directories["raw_data"] / "meta.json", "r") as file:
    metadata = json.load(file)
description = {field: "byte" for field in metadata["field_names"]}


def decode_features(features, metadata):
    """Decode raw TFRecord bytes into NumPy arrays using the metadata."""
    out = {}
    for feature, raw_bytes in features.items():
        dtype = metadata["features"][feature]["dtype"]
        shape = metadata["features"][feature]["shape"]
        data_array = np.frombuffer(raw_bytes, dtype=getattr(np, dtype)).copy()
        data_array = data_array.reshape(shape)
        out[feature] = data_array
    return out


dataset_name_to_file = dict(
    training="train.tfrecord",
    validation="valid.tfrecord",
    in_distribution="test.tfrecord",
)
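
Before building graphs, we can peek at one raw record to confirm the available fields and their shapes. This is a quick sanity check reusing the description and decode_features helpers defined above:

# Inspect the first record of the training file (arrays are [time, ...])
inspect_dataset = TFRecordDataset(
    data_path=directories["raw_data"] / dataset_name_to_file["training"],
    index_path=None,
    description=description,
    transform=lambda rec: decode_features(rec, metadata),
)
first_record = next(iter(inspect_dataset))
for name, array in first_record.items():
    print(f"{name}: shape={array.shape}, dtype={array.dtype}")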

2. Graph dataset

First, we define a function to build the graph. This operation is at the core of the Graphorge framework.

def build_graphorge_graph(
    node_coordinates,
    edge_indexes,
    edge_features,
    node_features,
    node_targets,
):
    """
    Builds a GraphData object from the provided node coordinates, edge indexes,
    node features, and global targets.

    Parameters
    ----------
    node_coordinates : np.ndarray
        An array of shape (n_nodes, n_dim) containing the coordinates of the
        nodes.
    edge_indexes : np.ndarray
        An array of shape (n_edges, 2) containing the indices of the edges.
    edge_features : np.ndarray
        An array of shape (n_edges, n_features) containing the features of the
        edges.
    node_features : np.ndarray
        An array of shape (n_nodes, n_features) containing the features of the
        nodes.
    node_targets : np.ndarray
        An array of shape (n_nodes, n_targets) containing the targets for the
        nodes.

    Returns
    -------
    GraphData
        An instance of GraphData containing the graph structure and features.
    """
    # Instantiate the graph data (graphorge)
    # Note that the n_dim is set to 2 as the geometry is 2D
    graph_data = GraphData(n_dim=2, nodes_coords=node_coordinates)

    # Set graph edges indexes. The data does not contain duplicated edges,
    # so we can set is_unique to False to avoid unnecessary processing
    graph_data.set_graph_edges_indexes(edges_indexes_mesh=edge_indexes, is_unique=False)

    # Set graph edge feature matrix
    graph_data.set_edge_features_matrix(edge_features_matrix=edge_features)

    # Set graph node features matrix
    graph_data.set_node_features_matrix(node_features_matrix=node_features)

    # Set graph node targets matrix
    graph_data.set_node_targets_matrix(node_targets_matrix=node_targets)

    return graph_data
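
A minimal usage sketch for a single triangular cell, with placeholder feature values (the real features are constructed by the processing function defined below):

# Three nodes forming one triangle; node ids are 1-based for Graphorge
toy_coordinates = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
toy_edge_indexes = GraphData.get_edges_indexes_mesh([(1, 2), (2, 3), (3, 1)])
# Placeholder features: 3 per edge, 6 per node, and 3 targets per node
toy_graph = build_graphorge_graph(
    node_coordinates=toy_coordinates,
    edge_indexes=toy_edge_indexes,
    edge_features=np.zeros((len(toy_edge_indexes), 3)),
    node_features=np.zeros((3, 6)),
    node_targets=np.zeros((3, 3)),
)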

Note

The dataset defines 4 node types, encoded by integers:

  • Interior nodes (the actual computational domain): 0

  • Wall nodes with no-slip condition: 4

  • Inlet nodes: 5

  • Outlet nodes: 6

See PhysicsNeMo’s implementation for more details.

# Define node types
domain_nodes = 0
wall_nodes = 4
inlet_nodes = 5
outlet_nodes = 6
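
As a quick illustration of the encoding used below, scikit-learn's OneHotEncoder sorts the categories, so the codes map to one-hot rows in the order 0, 4, 5, 6:

# Fit the encoder on the four node type codes and encode two examples
encoder = OneHotEncoder(sparse_output=False, dtype=int)
encoder.fit([[domain_nodes], [wall_nodes], [inlet_nodes], [outlet_nodes]])
print(encoder.transform([[domain_nodes], [inlet_nodes]]))
# [[1 0 0 0]
#  [0 0 1 0]]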

Let’s define a function to process the data:

def process_cylinder_flow_sample(sample, time_steps=1, velocity_noise=0, rng=None):
    """
    Processes a cylinder flow geometry to create a GraphData object.

    Parameters
    ----------
    sample : dict
        A dictionary containing the cylinder flow data with keys:
        - "mesh_pos": Coordinates of the mesh nodes.
        - "node_type": Type of each node (domain, wall, inlet, outlet).
        - "velocity": Velocity field at each time step.
        - "pressure": Pressure field at each time step.
        - "cells": Connectivity of the mesh cells.
    time_steps : int, optional
        The number of time steps to process. Default is 1.
    velocity_noise : float, optional
        Standard deviation of Gaussian noise to add to the velocity field at
        each time step. Default is 0 (no noise).
    rng : np.random.Generator, optional
        A NumPy random number generator for reproducibility. If None, the
        default NumPy random generator is used.

    Returns
    -------
    list of GraphData
        A list of GraphData instances, one per processed time step.
    """

    # Sample data is stored as [time, data] (see DeepMind MeshGraphNet repository)
    # Extract the node coordinates, edge indexes, and global targets
    # The mesh is invariant over time, so we only have one time step
    node_coordinates = sample["mesh_pos"][0]
    # Cell connectivity is defined by the "cells" entry
    # Each cell is a triangle defined by 3 node indexes
    # Node ids start at 0 in the data, but Graphorge expects them to start at 1
    connected_nodes = [
        (cell[i] + 1, cell[j] + 1)
        for cell in sample["cells"][0]
        for i, j in [(0, 1), (1, 2), (2, 0)]
    ]
    # Get the edges indexes from the connected nodes
    # Note the method returns undirected edges
    edge_indexes = GraphData.get_edges_indexes_mesh(connected_nodes)
    # Prepare edge features as the distance vector between connected nodes
    # and the Euclidean distance (norm of the distance vector)
    distance_vector = (
        node_coordinates[edge_indexes[:, 1]] - node_coordinates[edge_indexes[:, 0]]
    )
    distance_norm = np.linalg.norm(distance_vector, axis=1, keepdims=True)
    edge_features = np.concatenate((distance_vector, distance_norm), axis=1)

    # Prepare the node type one-hot which is static over time
    encoder = OneHotEncoder(sparse_output=False, dtype=int)
    encoder.fit([[domain_nodes], [wall_nodes], [inlet_nodes], [outlet_nodes]])
    node_type = encoder.transform(sample["node_type"].reshape(-1, 1))

    if velocity_noise:
        # Create a random number generator if not provided
        rng = np.random.default_rng(rng)
        # Create a mask for the domain nodes where noise will be added
        noise_mask = sample["node_type"][0].ravel() == 0

    # Iterate over the time steps to create a graph for each time step
    graphs = []
    for time_step in range(time_steps):
        # Extract the velocity field at the current time step
        velocity_field = sample["velocity"][time_step].copy()

        if velocity_noise:
            # Add Gaussian noise to the velocity field at the domain nodes
            noise = rng.normal(
                loc=0.0,
                scale=velocity_noise,
                size=velocity_field[noise_mask].shape,
            )
            velocity_field[noise_mask] += noise

        # Prepare the node features by concatenating the velocity field
        # and the one-hot encoded node type
        node_features = np.concatenate((velocity_field, node_type), axis=1)

        # Prepare the node target, note the velocity update is predicted
        # The velocity update accounts for the noise if any, see MeshGraphNet
        # paper, sec. A.2.2
        node_targets = np.concatenate(
            (
                sample["velocity"][time_step + 1] - velocity_field,
                sample["pressure"][time_step + 1],
            ),
            axis=1,
        )

        # Build the graph data
        graph_data = build_graphorge_graph(
            node_coordinates=node_coordinates,
            edge_indexes=edge_indexes,
            edge_features=edge_features,
            node_features=node_features,
            node_targets=node_targets,
        )

        # Set metadata
        graph_data.set_metadata(
            dict(
                node_feature_names=[
                    "u",
                    "v",
                    "one_hot",
                    "one_hot",
                    "one_hot",
                    "one_hot",
                ],
                edge_feature_names=["dx", "dy", "dist"],
                node_target_names=["du", "dv", "p"],
            )
        )

        graphs.append(graph_data)

    return graphs

We can visualize the obtained graphs.

Note

We use an interactive visualization for pedagogical purposes. Interactive visualization may require significant resources when working with large graphs. Note that the plot_graph() method is available in Graphorge for fast, static plots.

# Get the first sample from the training dataset
tfrecord_path = directories["raw_data"] / dataset_name_to_file["training"]
dataset = TFRecordDataset(
    data_path=tfrecord_path,
    index_path=None,
    description=description,
    transform=lambda rec: decode_features(rec, metadata),
)
sample = next(iter(dataset))
# Process the first time step
graph_data = process_cylinder_flow_sample(
    sample, time_steps=1, velocity_noise=2e-2, rng=np.random.default_rng(42)
)[0]
# Visualize the graph
figure = plot_interactive_graph(graph_data, node_size=5)
figure.update_layout(
    legend=dict(
        yanchor="top",
        y=0.98,
        xanchor="right",
        x=0.98,
        title_text=None,
        traceorder="normal",
    ),
    width=800,
)
figure.show(renderer=plotly_renderer)

Tip

Hover the nodes and edges to inspect features.

Finally, we can generate the dataset.

# Parameters
samples = 0  # Number of samples to process (0 skips the on-disk generation)
time_steps = 0  # Number of time steps to process for each sample
rng = np.random.default_rng(42)

# Noise standard deviation, see MeshGraphNet paper, sec. A.2.2
velocity_noise_std = dict(training=2e-2, validation=0, in_distribution=0)

for dataset_name in ["training", "validation", "in_distribution"]:

    tfrecord_path = directories["raw_data"] / dataset_name_to_file[dataset_name]

    dataset = TFRecordDataset(
        data_path=tfrecord_path,
        index_path=None,
        description=description,
        transform=lambda rec: decode_features(rec, metadata),
    )

    dataset_sample_files = []
    dataset_iterator = iter(dataset)
    # Iterate over each sample in the dataset
    for sample_id in trange(samples, desc=f"Processing {dataset_name}"):
        try:
            sample = next(dataset_iterator)
        except StopIteration:
            print("Reached the end of the dataset.")
            break
        # Process the cylinder flow data to get the graph data
        pyg_graphs = process_cylinder_flow_sample(
            sample=sample,
            time_steps=time_steps,
            velocity_noise=velocity_noise_std[dataset_name],
            rng=rng,
        )

        for time_step, pyg_graph in enumerate(pyg_graphs):
            sample_file_name = (
                f"cylinder_flow_graph-{sample_id:04d}_time-{time_step:03d}.pt"
            )

            # Set sample file path
            sample_file_path = directories[dataset_name] / sample_file_name
            # Save graph sample file
            torch.save(pyg_graph, sample_file_path)
            # Save graph sample file path
            dataset_sample_files.append(sample_file_path)

    # Create a dataset from the processed samples
    dataset = GNNGraphDataset(
        directories[dataset_name],
        dataset_sample_files=dataset_sample_files,
        dataset_basename="cylinder_flow_dataset_",
        is_store_dataset=False,
    )

    # Save the dataset
    _ = dataset.save_dataset(is_append_n_sample=True)
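
Why inject noise during training? Because the target velocity update is computed from the noisy input velocity, the model learns to correct small perturbations, which stabilizes long autoregressive rollouts (see the MeshGraphNet paper, sec. A.2.2). A toy numerical illustration:

# Toy illustration of the noise-compensating target (sec. A.2.2)
v_t = np.array([1.0, 2.0])    # velocity at time t
v_t1 = np.array([1.1, 2.1])   # velocity at time t + 1
noise = np.array([0.05, -0.05])
v_noisy = v_t + noise
target = v_t1 - v_noisy  # the model must undo the injected noise
print(target)  # [0.05 0.15]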

Note

Graphs require a significant amount of memory. The graphs generated from this dataset would not fit in the memory of an H100 GPU.

Graphorge reduces the memory footprint required to train the network by saving and loading individual graphs to and from disk.

This approach implies an overhead from disk I/O operations. Training this MGN implementation on the OSCAR HPC cluster takes about 50 min per epoch.

Tip

The raw data (about 10 GB) fit in memory. The mesh is static and does not need to be stored for each time step. In such a situation, it may be advantageous to define a custom dataset that generates the graphs on the fly, following Google DeepMind's and PhysicsNeMo's approach.

# Custom dataset class for the cylinder flow samples
class CylinderFlowDataset(torch.utils.data.Dataset):
    """In-memory dataset that generates one graph per (sample, time step).

    The static mesh of each sample is stored once; the time-dependent
    velocity fields and targets are swapped in on the fly in __getitem__.
    """

    def __init__(
        self,
        tfrecord_path,
        samples=600,  # Number of samples to process
        time_steps=400,  # Number of time steps to process for each sample
        train=False,
        velocity_noise_std=0,
        save_cells=False,
        rng=None,
    ):
        # Set attributes
        self.rng = np.random.default_rng(rng)
        self.train = train
        self.velocity_noise = velocity_noise_std

        # Prepare lists to store the data
        self.pyg_graphs = []
        self.noise_masks = []
        self.velocity_fields = []
        self.node_targets = []
        self.sample_idx_to_graph_idx = []

        # Load the TF dataset
        raw_dataset = TFRecordDataset(
            data_path=tfrecord_path,
            index_path=None,
            description=description,
            transform=lambda rec: decode_features(rec, metadata),
        )

        # Create a dataset iterator
        dataset_iterator = iter(raw_dataset)

        # Iterate over each sample in the dataset
        for sample_id in trange(samples, desc=f"Processing {tfrecord_path.name}"):
            try:
                sample = next(dataset_iterator)
            except StopIteration:
                print("Reached the end of the dataset.")
                break

            # Sample data is stored as [time, data] (see DeepMind MeshGraphNet
            # repository)

            # Extract the node coordinates
            # The mesh is invariant over time, so we only have one time step
            node_coordinates = sample["mesh_pos"][0]

            # Cell connectivity is defined by the "cells" entry
            # Each cell is a triangle defined by 3 node indexes
            # Node ids start at 0 in the data, but Graphorge expects them to start at 1
            connected_nodes = [
                (cell[i] + 1, cell[j] + 1)
                for cell in sample["cells"][0]
                for i, j in [(0, 1), (1, 2), (2, 0)]
            ]

            # Get the edges indexes from the connected nodes
            # Note the method returns undirected edges
            edge_indexes = GraphData.get_edges_indexes_mesh(connected_nodes)

            # Prepare edge features as the distance vector between connected
            # nodes and the euclidean distance (norm of the distance vector)
            distance_vector = (
                node_coordinates[edge_indexes[:, 1]]
                - node_coordinates[edge_indexes[:, 0]]
            )
            distance_norm = np.linalg.norm(distance_vector, axis=1, keepdims=True)

            edge_features = np.concatenate((distance_vector, distance_norm), axis=1)

            # Prepare the node type one-hot which is static over time
            encoder = OneHotEncoder(sparse_output=False, dtype=int)
            encoder.fit([[domain_nodes], [wall_nodes], [inlet_nodes], [outlet_nodes]])

            node_type = encoder.transform(sample["node_type"].reshape(-1, 1))

            if self.train and self.velocity_noise:
                # Create a mask for the domain nodes where noise will be added
                noise_mask = sample["node_type"][0].ravel() == 0

            for time_step in range(time_steps):
                # Extract the velocity field at the current time step
                velocity_field = sample["velocity"][time_step].copy()

                if self.train and self.velocity_noise:
                    # Add Gaussian noise to the velocity field at the domain
                    # nodes only (see MeshGraphNet paper, sec. A.2.2)
                    noise = self.rng.normal(
                        loc=0.0,
                        scale=self.velocity_noise,
                        size=velocity_field.shape,
                    )
                    noise[~noise_mask] = 0.0

                    velocity_field += noise

                self.velocity_fields.append(
                    torch.tensor(velocity_field, dtype=torch.float32)
                )

                node_target = np.concatenate(
                    (
                        sample["velocity"][time_step + 1] - velocity_field,
                        sample["pressure"][time_step + 1],
                    ),
                    axis=1,
                )

                self.node_targets.append(torch.tensor(node_target, dtype=torch.float32))

                self.sample_idx_to_graph_idx.append(sample_id)

            # Node features and targets are placeholders here: the actual
            # velocity field and targets are set per time step in __getitem__
            node_features = np.concatenate(
                (np.empty_like(self.velocity_fields[-1]), node_type), axis=1
            )

            node_targets = np.empty_like(node_target)

            graph_data = build_graphorge_graph(
                node_coordinates=node_coordinates,
                edge_indexes=edge_indexes,
                edge_features=edge_features,
                node_features=node_features,
                node_targets=node_targets,
            )

            # Prepare metadata
            graph_metadata = dict(
                node_feature_names=[
                    "u",
                    "v",
                    "one_hot",
                    "one_hot",
                    "one_hot",
                    "one_hot",
                ],
                edge_feature_names=["dx", "dy", "dist"],
                node_target_names=["du", "dv", "p"],
            )

            if save_cells:
                graph_metadata["cells"] = sample["cells"][0]

            # Set metadata
            graph_data.set_metadata(graph_metadata)

            self.pyg_graphs.append(graph_data.get_torch_data_object())

    def __len__(self):
        return len(self.velocity_fields)

    def __getitem__(self, idx):
        graph_idx = self.sample_idx_to_graph_idx[idx]

        graph_data = self.pyg_graphs[graph_idx]

        # Update the node features
        graph_data.x[:, :2] = self.velocity_fields[idx]

        # Update the node targets
        graph_data.y = self.node_targets[idx]

        return graph_data

for dataset_name in ["training", "validation", "in_distribution"]:

    tfrecord_path = directories["raw_data"] / dataset_name_to_file[dataset_name]

    dataset = CylinderFlowDataset(
        tfrecord_path=tfrecord_path,
        samples=400,  # Number of samples to process
        time_steps=300,  # Number of time steps to process for each sample
        train=(dataset_name == "training"),
        velocity_noise_std=2e-2,
        save_cells=(dataset_name == "in_distribution"),
        rng=42,
    )

    # Save the dataset
    dataset_file_path = directories[dataset_name] / "cylinder_flow_dataset.pt"
    torch.save(dataset, dataset_file_path)

We can visualize a graph from the newly generated dataset:

figure = plot_interactive_graph(dataset[0], node_size=5)
figure.update_layout(
    legend=dict(
        yanchor="top",
        y=0.98,
        xanchor="right",
        x=0.98,
        title_text=None,
        traceorder="normal",
    ),
    width=800,
)
figure.show(renderer=plotly_renderer)

3. GNN model architecture
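
The architecture follows MeshGraphNet's encode-process-decode scheme: MLP encoders embed the 6 node input features (velocity and one-hot node type) and the 3 edge features (distance vector and its norm) into a 128-dimensional latent space, 15 message-passing steps propagate information along the mesh edges, and an MLP decoder maps the latent node states to the 3 node targets (du, dv, p).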

# Set the GNN model parameters
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
gnn_architecture_parameters = dict(
    # Set number of node input and output features
    n_node_in=6,
    n_node_out=3,
    n_time_node=0,
    # Set number of edge input and output features
    n_edge_in=3,
    n_edge_out=0,
    n_time_edge=0,
    # Set number of global input and output features
    n_global_in=0,
    n_global_out=0,
    n_time_global=0,
    # Set number of message-passing steps (number of processor layers)
    n_message_steps=15,
    # Set number of FNN/RNN hidden layers
    enc_n_hidden_layers=2,
    pro_n_hidden_layers=2,
    dec_n_hidden_layers=2,
    # Set hidden layer size
    hidden_layer_size=128,
    # Set model directory
    model_directory=directories["model"],
    model_name="CylinderFlowMGN",
    # Set model input and output features normalization
    is_model_in_normalized=False,
    is_model_out_normalized=False,
    # Set aggregation schemes
    pro_edge_to_node_aggr="add",
    pro_node_to_global_aggr="add",
    # Set activation functions
    enc_node_hidden_activ_type="relu",
    enc_node_output_activ_type="identity",
    enc_edge_hidden_activ_type="relu",
    enc_edge_output_activ_type="identity",
    pro_node_hidden_activ_type="relu",
    pro_node_output_activ_type="identity",
    pro_edge_hidden_activ_type="relu",
    pro_edge_output_activ_type="identity",
    dec_node_hidden_activ_type="relu",
    dec_node_output_activ_type="identity",
    # Set device
    device_type=device_type,
)

4. GNN model training

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Load training datasets
train_dataset = torch.load(directories["training"] / "cylinder_flow_dataset.pt")
validation_dataset = torch.load(directories["validation"] / "cylinder_flow_dataset.pt")

training_parameters = dict(
    # Set number of epochs
    n_max_epochs=25,
    # Set batch size
    batch_size=1,
    # Set optimizer
    opt_algorithm="adam",
    # Set learning rate
    lr_init=1.0e-04,
    # Set learning rate scheduler
    lr_scheduler_type=None,
    lr_scheduler_kwargs=None,
    # Set loss function
    loss_nature="node_features_out",
    loss_type="mse",
    loss_kwargs=dict(),
    # Set data shuffling
    is_sampler_shuffle=True,
    # Set early stopping
    is_early_stopping=True,
    # Set early stopping parameters
    early_stopping_kwargs=dict(
        validation_dataset=validation_dataset,
        validation_frequency=1,
        trigger_tolerance=20,
        improvement_tolerance=1e-2,
    ),
    # Set seed
    seed=42,
    # Set verbosity
    is_verbose=True,
    tqdm_flavor="notebook",
    # Save loss history
    save_loss_every=1,
    # Set device
    device_type=device_type,
)

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Compute exponential decay (learning rate scheduler)
lr_end = 1.0e-5
gamma = (lr_end / training_parameters["lr_init"]) ** (
    1 / training_parameters["n_max_epochs"]
)
# Set learning rate scheduler
training_parameters["lr_scheduler_type"] = "explr"
training_parameters["lr_scheduler_kwargs"] = dict(gamma=gamma)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Set model state loading
load_model_state = None
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Training of GNN-based model
model, _, _ = train_model(
    dataset=train_dataset,
    model_init_args=gnn_architecture_parameters,
    load_model_state=load_model_state,
    **training_parameters
)

5. GNN model prediction

First, let's analyze the loss curves.

# Plot the loss history
figure = plot_loss_history(loss_file=directories["model"] / "loss_history_record.pkl")

figure.show(renderer=plotly_renderer)

We will define a custom rollout function to animate the predictions.

def predict_rollout_as_pyvista_movie(
    dataset, model, output_path, sample_idx, time_steps=300
):
    """
    Predicts a rollout using the provided dataset and model, and saves the
    results as a PyVista movie.

    Parameters
    ----------
    dataset : torch.utils.data.Dataset
        The dataset to use for predictions.

    model : GNNEPDBaseModel
        The GNN model to use for predictions.

    sample_idx : int
        The index of the sample in the dataset to start the rollout from.

    output_path : str
        The path to save the output movie.

    time_steps : int, optional
        The number of time steps to predict. Default is 300.

    Returns
    -------
    None
        The function saves the movie to the specified output path.
    """

    if (
        dataset.sample_idx_to_graph_idx[sample_idx]
        != dataset.sample_idx_to_graph_idx[sample_idx + time_steps - 1]
    ):
        raise ValueError(
            f"The provided index {sample_idx} and the number of time steps {time_steps} exceed the sample size."
        )

    pyg_graph = dataset[sample_idx].detach().cpu()
    mesh = graph_to_pyvista_mesh(graph=pyg_graph)

    updated_velocity_magnitudes = []
    predicted_velocity_magnitudes = []
    velocity_magnitude_errors = []

    for time_step in trange(time_steps, desc="Inference"):
        pyg_graph = dataset[sample_idx + time_step]

        # Get input features from input graph
        node_features_in, edge_features_in, global_features_in, edges_indexes = (
            model.get_input_features_from_graph(
                pyg_graph, is_normalized=model.is_model_in_normalized
            )
        )

        features_out, _, _ = model(
            node_features_in=node_features_in,
            edge_features_in=edge_features_in,
            edges_indexes=edges_indexes,
        )

        features_out = (
            model.data_scaler_transform(
                tensor=features_out,
                features_type="node_features_out",
                mode="denormalize",
            )
            .detach()
            .cpu()
        )

        pyg_graph = pyg_graph.detach().cpu()
        # Update velocity magnitude
        updated_velocity_magnitudes.append(
            np.linalg.norm(
                pyg_graph.x[:, :2].numpy() + pyg_graph.y[:, :2].numpy(), axis=1
            )
        )

        predicted_velocity_magnitudes.append(
            np.linalg.norm(
                pyg_graph.x[:, :2].numpy() + features_out[:, :2].numpy(), axis=1
            )
        )

        velocity_magnitude_errors.append(
            (predicted_velocity_magnitudes[-1] - updated_velocity_magnitudes[-1])
            / (updated_velocity_magnitudes[-1] + 1e-32)
        )

    mesh.point_data["velocity_magnitude"] = updated_velocity_magnitudes[0]
    mesh.point_data["predicted_velocity_magnitude"] = predicted_velocity_magnitudes[0]
    mesh.point_data["velocity_magnitude_error"] = velocity_magnitude_errors[0]

    for time_step in trange(time_steps, desc="Generating frames"):

        if time_step == 0:
            plotter = pv.Plotter(
                shape=(3, 1),
                notebook=False,
                off_screen=True,
                border=False,
                window_size=(1920, 1200),
            )
            # Add ground truth
            plotter.subplot(0, 0)
            plotter.add_text("Ground truth", font_size=15)
            plotter.add_mesh(
                mesh,
                scalars="velocity_magnitude",
                lighting=False,
                show_edges=False,
                cmap=cm.oslo,
                clim=[
                    np.min(updated_velocity_magnitudes),
                    np.max(updated_velocity_magnitudes),
                ],
                scalar_bar_args=dict(
                    title="Velocity magnitude",
                    vertical=True,
                    height=0.8,
                    width=0.04,
                    position_x=0.93,
                    position_y=0.1,
                ),
            )
            plotter.camera.tight(adjust_render_window=False, padding=0.1, view="xy")

            # Add predictions
            # A shallow copy of the mesh is made when plotting each scalar
            # array, since a mesh can have only one active scalar (see the
            # PyVista documentation)
            prediction_mesh = mesh.copy(deep=False)

            plotter.subplot(1, 0)
            plotter.add_text("Predictions", font_size=15)
            plotter.add_mesh(
                prediction_mesh,
                scalars="predicted_velocity_magnitude",
                lighting=False,
                show_edges=False,
                cmap=cm.oslo,
                clim=[
                    np.min(updated_velocity_magnitudes),
                    np.max(updated_velocity_magnitudes),
                ],
                scalar_bar_args=dict(
                    title="Velocity magnitude",
                    vertical=True,
                    height=0.8,
                    width=0.04,
                    position_x=0.93,
                    position_y=0.1,
                ),
            )
            plotter.camera.tight(adjust_render_window=False, padding=0.1, view="xy")

            # Add errors
            error_mesh = mesh.copy(deep=False)

            plotter.subplot(2, 0)
            plotter.add_text("Error", font_size=15)
            plotter.add_mesh(
                error_mesh,
                scalars="velocity_magnitude_error",
                lighting=False,
                show_edges=False,
                cmap=cm.vik,
                clim=[-0.05, 0.05],
                scalar_bar_args=dict(
                    title="Relative error",
                    vertical=True,
                    height=0.8,
                    width=0.04,
                    position_x=0.93,
                    position_y=0.1,
                ),
            )
            plotter.camera.tight(adjust_render_window=False, padding=0.1, view="xy")

            plotter.open_movie(output_path)
            plotter.write_frame()

        else:

            mesh.point_data["velocity_magnitude"] = updated_velocity_magnitudes[
                time_step
            ]

            prediction_mesh.point_data["predicted_velocity_magnitude"] = (
                predicted_velocity_magnitudes[time_step]
            )

            error_mesh.point_data["velocity_magnitude_error"] = (
                velocity_magnitude_errors[time_step]
            )
            # Write a frame and trigger a render
            plotter.write_frame()

    # Close and finalize the movie
    plotter.close()

We perform the rollout:

# Set logging configuration to WARNING to avoid excessive output from PyVista
logger = logging.getLogger()
logger.setLevel(logging.WARNING)

# Load the in-distribution dataset
test_dataset = torch.load(directories["in_distribution"] / "cylinder_flow_dataset.pt")

# Load the trained model from the best model state
model = GNNEPDBaseModel.init_model_from_file(directories["model"])
_ = model.load_model_state(load_model_state="best", is_remove_posterior=False)
# Set the model to evaluation mode
model.eval()

# Predict a rollout and save it as a PyVista movie
output_path = Path("_assets/cylinder_flow.mp4")
# Ensure the output directory exists before writing the movie
output_path.parent.mkdir(parents=True, exist_ok=True)
predict_rollout_as_pyvista_movie(
    dataset=test_dataset,
    model=model,
    output_path=str(output_path),
    sample_idx=0,
    time_steps=300,
)
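
To complement the animation with a quantitative score, we can evaluate one-step predictions with the coefficient of determination. This is a minimal sketch under the same assumptions as the rollout above (ground-truth inputs at every step), here over the first 10 time steps of the first sample:

# Score one-step predictions with the R^2 metric
targets, predictions = [], []
for time_step in range(10):
    pyg_graph = test_dataset[time_step]
    node_in, edge_in, _, edges_idx = model.get_input_features_from_graph(
        pyg_graph, is_normalized=model.is_model_in_normalized
    )
    features_out, _, _ = model(
        node_features_in=node_in,
        edge_features_in=edge_in,
        edges_indexes=edges_idx,
    )
    features_out = model.data_scaler_transform(
        tensor=features_out,
        features_type="node_features_out",
        mode="denormalize",
    ).detach().cpu()
    targets.append(pyg_graph.y.detach().cpu().numpy())
    predictions.append(features_out.numpy())

print(f"R2 = {r2_score(np.concatenate(targets), np.concatenate(predictions)):.4f}")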

Let’s inspect the result as saved on Graphorge’s main branch: