Re-implementation of MeshGraphNet for turbulent flow around a cylinder¶
Author(s): Guillaume Broggi
This example re-implements Google DeepMind’s MeshGraphNet [1] to predict the turbulent flow around a cylinder.
Refer to DeepMind’s repository for the data.
See NVIDIA’s PhysicsNeMo for another re-implementation.
Note
This example runs better on a GPU. The reported results were obtained on an H100 GPU with an average training time of 35 min per epoch.
Tip
Run this notebook in Google Colab. Execute the following cell to install requirements.
Install required packages for Google Colab (run this cell in a Colab environment):
try:
# Check if running in Google Colab
import google.colab
# Install PyG as in PyG documentation
%pip uninstall --yes torch torchaudio torchvision torchtext torchdata
%pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu124
%pip install torch-scatter -f https://data.pyg.org/whl/torch-2.4.1+cu124.html
%pip install torch-sparse -f https://data.pyg.org/whl/torch-2.4.1+cu124.html
%pip install torch_geometric==2.6.1 -f https://data.pyg.org/whl/torch-2.4.1+cu124.html
# Install Graphorge
!git clone https://github.com/bessagroup/graphorge.git
%pip install -e ./graphorge
except ImportError:
    # Not running in Google Colab
    pass
# tfrecord, Plotly, and PyVista are required for this example
try:
from tfrecord.torch.dataset import TFRecordDataset
except ImportError:
%pip install 'tfrecord[torch]'
try:
import plotly.graph_objects as go
except ImportError:
%pip install plotly
try:
import pyvista as pv
except ImportError:
%pip install pyvista trame trame-vtk trame-vuetify trame-components 'imageio[ffmpeg]'
try:
import cmcrameri.cm as cm
except ImportError:
%pip install cmcrameri
Imports required to run the notebook:
# Standard
import json
import logging
from pathlib import Path
import pickle
import shutil
import sys
# Third-party
import cmcrameri.cm as cm
import numpy as np
from plotly import graph_objects as go
import plotly.io as pio
import pyvista as pv
from sklearn.metrics import r2_score
from sklearn.preprocessing import OneHotEncoder
from tfrecord.torch.dataset import TFRecordDataset
import torch
from tqdm.auto import tqdm, trange
# Locals
from graphorge.gnn_base_model.data.graph_data import GraphData
from graphorge.gnn_base_model.data.graph_dataset import GNNGraphDataset
from graphorge.gnn_base_model.model.gnn_model import GNNEPDBaseModel
from graphorge.gnn_base_model.train.training import train_model
from graphorge.gnn_base_model.predict.prediction import predict
# Make utilities available in the path
try:
# Check if running in Google Colab
import google.colab
utils_path = Path("graphorge/benchmarks/utils")
except ImportError:
utils_path = (
Path().resolve().parents[1] / "utils"
)
sys.path.insert(0, str(utils_path))
# Import required utilities
from utils import download_file, plotly_layout, plot_interactive_graph, plot_loss_history, graph_to_pyvista_mesh
# Set logging configuration
logger = logging.getLogger()
logger.setLevel(logging.INFO)
# Prepare directories
directories = dict(
raw_data=Path("0_data"),
processed_data=Path("1_dataset"),
training=Path("2_training_dataset"),
validation=Path("3_validation_dataset"),
model=Path("4_model"),
in_distribution=Path("5_testing_id_dataset"),
prediction=Path("7_prediction"),
)
# Create directories if they do not exist
for directory in directories.values():
directory.mkdir(parents=True, exist_ok=True)
# Add `notebook_connected` to the default renderer to ensure the figures are rendered in the documentation
plotly_renderer = f"notebook_connected+{pio.renderers.default}"
# Set device
device_type = "cuda" if torch.cuda.is_available() else "cpu"
1. Raw data¶
The dataset contains simulations of the flow around a cylinder on triangular meshes. See the dataset repository for more information.
The dataset is downloaded from the base URL below. It consists of a metadata file (meta.json) describing the stored fields, and three TFRecord files (train, valid, test), where each entry is a simulated trajectory containing:
mesh_pos (x, y): the coordinates of the mesh nodes.
node_type: the type of each node (interior, wall, inlet, or outlet).
cells: the triangular cells connecting the nodes.
velocity (u, v): the velocity field at each node and time step.
pressure (p): the pressure field at each node and time step.
Let’s download the dataset and inspect it. Note that the data are, by nature, well described by a graph.
# Set the base URL for downloading the dataset
base_url = "https://storage.googleapis.com/dm-meshgraphnets/cylinder_flow"
# Download the dataset files
for file in ["meta.json", "train.tfrecord", "valid.tfrecord", "test.tfrecord"]:
download_file(
file=file,
base_url=base_url,
dest_path=directories["raw_data"],
)
with open(directories["raw_data"] / "meta.json", "r") as file:
metadata = json.load(file)
description = {field: "byte" for field in metadata["field_names"]}
def decode_features(features, metadata):
    """Decode raw TFRecord bytes into NumPy arrays using the metadata."""
    out = {}
    for feature, raw_bytes in features.items():
        dtype = metadata["features"][feature]["dtype"]
        shape = metadata["features"][feature]["shape"]
        data_array = np.frombuffer(raw_bytes, dtype=getattr(np, dtype)).copy()
        data_array = data_array.reshape(shape)
        out[feature] = data_array
    return out
dataset_name_to_file = dict(
training="train.tfrecord",
validation="valid.tfrecord",
in_distribution="test.tfrecord",
)
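Before building graphs, let's sanity-check the download by decoding the first training record and printing the shape and dtype of each stored field (a minimal sketch that reuses the description, metadata, and decode_features objects defined above):
# Decode the first training record and inspect the stored fields
inspection_dataset = TFRecordDataset(
    data_path=directories["raw_data"] / dataset_name_to_file["training"],
    index_path=None,
    description=description,
    transform=lambda rec: decode_features(rec, metadata),
)
first_record = next(iter(inspection_dataset))
for name, array in first_record.items():
    print(f"{name}: shape={array.shape}, dtype={array.dtype}")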
2. Graph dataset¶
First, we define a function to build the graph. This operation is at the core of the Graphorge framework.
def build_graphorge_graph(
node_coordinates,
edge_indexes,
edge_features,
node_features,
node_targets,
):
"""
Builds a GraphData object from the provided node coordinates, edge indexes,
node features, and global targets.
Parameters
----------
node_coordinates : np.ndarray
An array of shape (n_nodes, n_dim) containing the coordinates of the
nodes.
edge_indexes : np.ndarray
An array of shape (n_edges, 2) containing the indices of the edges.
edge_features : np.ndarray
An array of shape (n_edges, n_features) containing the features of the
edges.
node_features : np.ndarray
An array of shape (n_nodes, n_features) containing the features of the
nodes.
node_targets : np.ndarray
An array of shape (n_nodes, n_targets) containing the targets for the
nodes.
Returns
-------
GraphData
An instance of GraphData containing the graph structure and features.
"""
# Instantiate the graph data (graphorge)
# Note that the n_dim is set to 2 as the geometry is 2D
graph_data = GraphData(n_dim=2, nodes_coords=node_coordinates)
# Set graph edges indexes. The data does not contain duplicated edges,
# so we can set is_unique to False to avoid unnecessary processing
graph_data.set_graph_edges_indexes(edges_indexes_mesh=edge_indexes, is_unique=False)
# Set graph edge feature matrix
graph_data.set_edge_features_matrix(edge_features_matrix=edge_features)
# Set graph node features matrix
graph_data.set_node_features_matrix(node_features_matrix=node_features)
# Set graph node targets matrix
graph_data.set_node_targets_matrix(node_targets_matrix=node_targets)
return graph_data
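As a toy illustration of the expected interface (hypothetical values, chosen only to make the array shapes concrete), here is how a single 3-node triangle would be assembled:
# Toy example (hypothetical data): a single 3-node triangle
toy_coordinates = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
toy_edge_indexes = np.array([[0, 1], [1, 2], [2, 0]])
# One row per edge: (dx, dy, dist), consistent with the edge features below
toy_edge_features = np.array(
    [[1.0, 0.0, 1.0], [-1.0, 1.0, np.sqrt(2.0)], [0.0, -1.0, 1.0]]
)
# One row per node: placeholder features and targets with the shapes used below
toy_node_features = np.zeros((3, 6))
toy_node_targets = np.zeros((3, 3))
toy_graph = build_graphorge_graph(
    node_coordinates=toy_coordinates,
    edge_indexes=toy_edge_indexes,
    edge_features=toy_edge_features,
    node_features=toy_node_features,
    node_targets=toy_node_targets,
)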
Note
The dataset defines 4 node types, encoded by integers:
Interior nodes (the actual computational domain): 0
Wall nodes with no-slip condition: 4
Inlet nodes: 5
Outlet nodes: 6
See PhysicsNeMo’s implementation for more details.
# Defines node types
domain_nodes = 0
wall_nodes = 4
inlet_nodes = 5
outlet_nodes = 6
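To make the encoding concrete, fitting scikit-learn's OneHotEncoder on these four integer codes (as done in the processing function below) maps each node type to a 4-dimensional indicator vector:
# Minimal sketch: one-hot encoding of the four node types
encoder = OneHotEncoder(sparse_output=False, dtype=int)
encoder.fit([[domain_nodes], [wall_nodes], [inlet_nodes], [outlet_nodes]])
print(encoder.transform([[domain_nodes], [wall_nodes], [inlet_nodes], [outlet_nodes]]))
# [[1 0 0 0]
#  [0 1 0 0]
#  [0 0 1 0]
#  [0 0 0 1]]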
Let’s define a function to process the data:
def process_cylinder_flow_sample(sample, time_steps=1, velocity_noise=0, rng=None):
"""
Processes a cylinder flow geometry to create a GraphData object.
Parameters
----------
sample : dict
A dictionary containing the cylinder flow data with keys:
- "mesh_pos": Coordinates of the mesh nodes.
- "node_type": Type of each node (domain, wall, inlet, outlet).
- "velocity": Velocity field at each time step.
- "pressure": Pressure field at each time step.
- "cells": Connectivity of the mesh cells.
time_steps : int, optional
The number of time steps to process. Default is 1.
velocity_noise : float, optional
Standard deviation of Gaussian noise to add to the velocity field at
each time step. Default is 0 (no noise).
rng : np.random.Generator, optional
A NumPy random number generator for reproducibility. If None, the
default NumPy random generator is used.
Returns
-------
GraphData
An instance of GraphData containing the processed sample mesh data.
"""
# Sample data is stored as [time, data] (see DeepMind MeshGraphNet repository)
# Extract the node coordinates, edge indexes, and global targets
# The mesh is invariant over time, so we only have one time step
node_coordinates = sample["mesh_pos"][0]
# Cell connectivity is defined by the "cells" entry
# Each cell is a triangle defined by 3 node indexes
# Node ids start at 0 in the data, but Graphorge expects them to start at 1
connected_nodes = [
(cell[i] + 1, cell[j] + 1)
for cell in sample["cells"][0]
for i, j in [(0, 1), (1, 2), (2, 0)]
]
# Get the edges indexes from the connected nodes
# Note the method returns undirected edges
edge_indexes = GraphData.get_edges_indexes_mesh(connected_nodes)
# Prepare edge features as the distance vector between connected nodes
# and the euclidean distance (norm of the distance vector)
distance_vector = (
node_coordinates[edge_indexes[:, 1]] - node_coordinates[edge_indexes[:, 0]]
)
distance_norm = np.linalg.norm(distance_vector, axis=1, keepdims=True)
edge_features = np.concatenate((distance_vector, distance_norm), axis=1)
# Prepare the node type one-hot which is static over time
encoder = OneHotEncoder(sparse_output=False, dtype=int)
encoder.fit([[domain_nodes], [wall_nodes], [inlet_nodes], [outlet_nodes]])
node_type = encoder.transform(sample["node_type"].reshape(-1, 1))
if velocity_noise:
# Create a random number generator if not provided
rng = np.random.default_rng(rng)
# Create a mask for the domain nodes where noise will be added
noise_mask = sample["node_type"][0].ravel() == 0
# Iterate over the time steps to create a graph for each time step
graphs = []
for time_step in range(time_steps):
# Extract the velocity field at the current time step
velocity_field = sample["velocity"][time_step].copy()
if velocity_noise:
# Add Gaussian noise to the velocity field at the domain nodes
noise = rng.normal(
loc=0.0,
scale=velocity_noise,
size=velocity_field[noise_mask].shape,
)
velocity_field[noise_mask] += noise
# Prepare the node features by concatenating the velocity field
# and the one-hot encoded node type
node_features = np.concatenate((velocity_field, node_type), axis=1)
# Prepare the node targets; note that the velocity update (rather than
# the absolute velocity) is predicted. The velocity update accounts for
# the noise, if any, see MeshGraphNet paper, sec. A.2.2
node_targets = np.concatenate(
(
sample["velocity"][time_step + 1] - velocity_field,
sample["pressure"][time_step + 1],
),
axis=1,
)
# Build the graph data
graph_data = build_graphorge_graph(
node_coordinates=node_coordinates,
edge_indexes=edge_indexes,
edge_features=edge_features,
node_features=node_features,
node_targets=node_targets,
)
# Set metadata
graph_data.set_metadata(
dict(
node_feature_names=[
"u",
"v",
"one_hot",
"one_hot",
"one_hot",
"one_hot",
],
edge_feature_names=["dx", "dy", "dist"],
node_target_names=["du", "dv", "p"],
)
)
graphs.append(graph_data)
return graphs
We can visualize the obtained graphs.
Note
We use an interactive visualization for pedagogical purposes. Interactive visualization may require significant resources when working with large graphs. Note that the plot_graph() method is available in Graphorge for fast, static plots.
# Get the first sample from the training dataset
tfrecord_path = directories["raw_data"] / dataset_name_to_file["training"]
dataset = TFRecordDataset(
data_path=tfrecord_path,
index_path=None,
description=description,
transform=lambda rec: decode_features(rec, metadata),
)
sample = next(iter(dataset))
# Process the first time step
graph_data = process_cylinder_flow_sample(
sample, time_steps=1, velocity_noise=2e-2, rng=np.random.default_rng(42)
)[0]
# Visualize the graph
figure = plot_interactive_graph(graph_data, node_size=5)
figure.update_layout(
legend=dict(
yanchor="top",
y=0.98,
xanchor="right",
x=0.98,
title_text=None,
traceorder="normal",
),
width=800,
)
figure.show(renderer=plotly_renderer)
Tip
Hover over the nodes and edges to inspect their features.
Finally, we can generate the dataset.
# Parameters
samples = 0  # Number of samples to process (0 skips this disk-based pipeline; the on-the-fly dataset below is used instead)
time_steps = 0  # Number of time steps to process for each sample
rng = np.random.default_rng(42)
# Noise standard deviation, see MeshGraphNet paper, sec. A.2.2
velocity_noise_std = dict(training=2e-2, validation=0, in_distribution=0)
for dataset_name in ["training", "validation", "in_distribution"]:
tfrecord_path = directories["raw_data"] / dataset_name_to_file[dataset_name]
dataset = TFRecordDataset(
data_path=tfrecord_path,
index_path=None,
description=description,
transform=lambda rec: decode_features(rec, metadata),
)
dataset_sample_files = []
dataset_iterator = iter(dataset)
# Iterate over each sample in the dataset
for sample_id in trange(samples, desc=f"Processing {dataset_name}"):
try:
sample = next(dataset_iterator)
except StopIteration:
print("Reached the end of the dataset.")
break
# Process the cylinder flow data to get the graph data
pyg_graphs = process_cylinder_flow_sample(
sample=sample,
time_steps=time_steps,
velocity_noise=velocity_noise_std[dataset_name],
rng=rng,
)
for time_step, pyg_graph in enumerate(pyg_graphs):
sample_file_name = (
f"cylinder_flow_graph-{sample_id:04d}_time-{time_step:03d}.pt"
)
# Set sample file path
sample_file_path = directories[dataset_name] / sample_file_name
# Save graph sample file
torch.save(pyg_graph, sample_file_path)
# Save graph sample file path
dataset_sample_files.append(sample_file_path)
# Create a dataset from the processed samples
dataset = GNNGraphDataset(
directories[dataset_name],
dataset_sample_files=dataset_sample_files,
dataset_basename="cylinder_flow_dataset_",
is_store_dataset=False,
)
# Save the dataset
_ = dataset.save_dataset(is_append_n_sample=True)
Note
Graphs require a significant amount of memory. The graphs generated from this dataset would not fit in the memory of an H100 GPU.
Graphorge reduces the memory footprint required to train the network by saving and loading individual graphs to and from disk.
This approach implies an overhead from the disk I/O operations. Training this MGN implementation on the OSCAR HPC cluster requires about 50 min per epoch.
Tip
The raw data (about 10 GB) fit in memory. The mesh is static and does not need to be stored for each time step. In such a situation, it may be advantageous to define a custom dataset that generates the graphs on the fly, following the approach of Google DeepMind and PhysicsNeMo.
# Custom dataset class for the cylinder flow samples
class CylinderFlowDataset(torch.utils.data.Dataset):
def __init__(
self,
tfrecord_path,
samples=600, # Number of samples to process
time_steps=400, # Number of time steps to process for each sample
train=False,
velocity_noise_std=0,
save_cells=False,
rng=None,
):
# Set attributes
self.rng = np.random.default_rng(rng)
self.train = train
self.velocity_noise = velocity_noise_std
# Prepare lists to store the data
self.pyg_graphs = []
self.noise_masks = []
self.velocity_fields = []
self.node_targets = []
self.sample_idx_to_graph_idx = []
# Load the TF dataset
raw_dataset = TFRecordDataset(
data_path=tfrecord_path,
index_path=None,
description=description,
transform=lambda rec: decode_features(rec, metadata),
)
# Create a dataset iterator
dataset_iterator = iter(raw_dataset)
# Iterate over each sample in the dataset
for sample_id in trange(samples, desc=f"Processing {Path(tfrecord_path).name}"):
try:
sample = next(dataset_iterator)
except StopIteration:
print("Reached the end of the dataset.")
break
# Sample data is stored as [time, data] (see DeepMind MeshGraphNet
# repository)
# Extract the node coordinates
# The mesh is invariant over time, so we only have one time step
node_coordinates = sample["mesh_pos"][0]
# Cell connectivity is defined by the "cells" entry
# Each cell is a triangle defined by 3 node indexes
# Node ids start at 0 in the data, but Graphorge expects them to start at 1
connected_nodes = [
(cell[i] + 1, cell[j] + 1)
for cell in sample["cells"][0]
for i, j in [(0, 1), (1, 2), (2, 0)]
]
# Get the edges indexes from the connected nodes
# Note the method returns undirected edges
edge_indexes = GraphData.get_edges_indexes_mesh(connected_nodes)
# Prepare edge features as the distance vector between connected
# nodes and the euclidean distance (norm of the distance vector)
distance_vector = (
node_coordinates[edge_indexes[:, 1]]
- node_coordinates[edge_indexes[:, 0]]
)
distance_norm = np.linalg.norm(distance_vector, axis=1, keepdims=True)
edge_features = np.concatenate((distance_vector, distance_norm), axis=1)
# Prepare the node type one-hot which is static over time
encoder = OneHotEncoder(sparse_output=False, dtype=int)
encoder.fit([[domain_nodes], [wall_nodes], [inlet_nodes], [outlet_nodes]])
node_type = encoder.transform(sample["node_type"].reshape(-1, 1))
if self.train and self.velocity_noise:
# Create a mask for the domain nodes where noise will be added
noise_mask = sample["node_type"][0].ravel() == 0
for time_step in range(time_steps):
# Extract the velocity field at the current time step
velocity_field = sample["velocity"][time_step].copy()
if self.train and self.velocity_noise:
# Add Gaussian noise to the velocity field at the domain nodes
noise = self.rng.normal(
loc=0.0,
scale=self.velocity_noise,
size=velocity_field.shape,
)
noise[~noise_mask] = 0.0
velocity_field += noise
self.velocity_fields.append(
torch.tensor(velocity_field, dtype=torch.float32)
)
node_target = np.concatenate(
(
sample["velocity"][time_step + 1] - velocity_field,
sample["pressure"][time_step + 1],
),
axis=1,
)
self.node_targets.append(torch.tensor(node_target, dtype=torch.float32))
self.sample_idx_to_graph_idx.append(sample_id)
# Placeholder node features and targets: the actual velocity field and
# targets for each time step are injected in __getitem__
node_features = np.concatenate(
(np.empty_like(self.velocity_fields[-1]), node_type), axis=1
)
node_targets = np.empty_like(node_target)
graph_data = build_graphorge_graph(
node_coordinates=node_coordinates,
edge_indexes=edge_indexes,
edge_features=edge_features,
node_features=node_features,
node_targets=node_targets,
)
# Prepare metadata
graph_metadata = dict(
node_feature_names=[
"u",
"v",
"one_hot",
"one_hot",
"one_hot",
"one_hot",
],
edge_feature_names=["dx", "dy", "dist"],
node_target_names=["du", "dv", "p"],
)
if save_cells:
graph_metadata["cells"] = sample["cells"][0]
# Set metadata
graph_data.set_metadata(graph_metadata)
self.pyg_graphs.append(graph_data.get_torch_data_object())
def __len__(self):
return len(self.velocity_fields)
def __getitem__(self, idx):
graph_idx = self.sample_idx_to_graph_idx[idx]
graph_data = self.pyg_graphs[graph_idx]
# Update the node features
graph_data.x[:, :2] = self.velocity_fields[idx]
# Update the node targets
graph_data.y = self.node_targets[idx]
return graph_data
for dataset_name in ["training", "validation", "in_distribution"]:
tfrecord_path = directories["raw_data"] / dataset_name_to_file[dataset_name]
dataset = CylinderFlowDataset(
tfrecord_path=tfrecord_path,
samples=400, # Number of samples to process
time_steps=300, # Number of time steps to process for each sample
train=(dataset_name == "training"),
velocity_noise_std=2e-2,
save_cells=(dataset_name == "in_distribution"),
rng=42,
)
# Save the dataset
dataset_file_path = directories[dataset_name] / "cylinder_flow_dataset.pt"
torch.save(dataset, dataset_file_path)
Reached the end of the dataset.
figure = plot_interactive_graph(dataset[0], node_size=5)
figure.update_layout(
legend=dict(
yanchor="top",
y=0.98,
xanchor="right",
x=0.98,
title_text=None,
traceorder="normal",
),
width=800,
)
figure.show(renderer=plotly_renderer)
3. GNN model architecture¶
# Set the GNN model parameters
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
gnn_architecture_parameters = dict(
# Set number of node input and output features
n_node_in=6,
n_node_out=3,
n_time_node=0,
# Set number of edge input and output features
n_edge_in=3,
n_edge_out=0,
n_time_edge=0,
# Set number of global input and output features
n_global_in=0,
n_global_out=0,
n_time_global=0,
# Set number of message-passing steps (number of processor layers)
n_message_steps=15,
# Set number of FNN/RNN hidden layers
enc_n_hidden_layers=2,
pro_n_hidden_layers=2,
dec_n_hidden_layers=2,
# Set hidden layer size
hidden_layer_size=128,
# Set model directory
model_directory=directories["model"],
model_name="CylinderFlowMGN",
# Set model input and output features normalization
is_model_in_normalized=False,
is_model_out_normalized=False,
# Set aggregation schemes
pro_edge_to_node_aggr="add",
pro_node_to_global_aggr="add",
# Set activation functions
enc_node_hidden_activ_type="relu",
enc_node_output_activ_type="identity",
enc_edge_hidden_activ_type="relu",
enc_edge_output_activ_type="identity",
pro_node_hidden_activ_type="relu",
pro_node_output_activ_type="identity",
pro_edge_hidden_activ_type="relu",
pro_edge_output_activ_type="identity",
dec_node_hidden_activ_type="relu",
dec_node_output_activ_type="identity",
# Set device
device_type=device_type,
)
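The parameters above define an encode-process-decode architecture with 15 message-passing steps and 128-unit hidden layers. As a quick check of the model size, we can instantiate the model directly from these parameters and count its trainable parameters (a minimal sketch; it assumes GNNEPDBaseModel accepts the same keyword arguments as the model_init_args passed to train_model below):
# Instantiate the model from the architecture parameters (assumption: the
# constructor accepts the same keyword arguments as model_init_args below)
sketch_model = GNNEPDBaseModel(**gnn_architecture_parameters)
n_trainable = sum(p.numel() for p in sketch_model.parameters() if p.requires_grad)
print(f"Trainable parameters: {n_trainable:,}")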
4. GNN model training¶
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Load training datasets
train_dataset = torch.load(directories["training"] / "cylinder_flow_dataset.pt")
validation_dataset = torch.load(directories["validation"] / "cylinder_flow_dataset.pt")
training_parameters = dict(
# Set number of epochs
n_max_epochs=25,
# Set batch size
batch_size=1,
# Set optimizer
opt_algorithm="adam",
# Set learning rate
lr_init=1.0e-04,
# Set learning rate scheduler
lr_scheduler_type=None,
lr_scheduler_kwargs=None,
# Set loss function
loss_nature="node_features_out",
loss_type="mse",
loss_kwargs=dict(),
# Set data shuffling
is_sampler_shuffle=True,
# Set early stopping
is_early_stopping=True,
# Set early stopping parameters
early_stopping_kwargs=dict(
validation_dataset=validation_dataset,
validation_frequency=1,
trigger_tolerance=20,
improvement_tolerance=1e-2,
),
# Set seed
seed=42,
# Set verbosity
is_verbose=True,
tqdm_flavor="notebook",
# Save loss history
save_loss_every=1,
# Set device
device_type=device_type,
)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Compute exponential decay (learning rate scheduler)
lr_end = 1.0e-5
gamma = (lr_end / training_parameters["lr_init"]) ** (
1 / training_parameters["n_max_epochs"]
)
# Set learning rate scheduler
training_parameters["lr_scheduler_type"] = "explr"
training_parameters["lr_scheduler_kwargs"] = dict(gamma=gamma)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Set model state loading
load_model_state = None
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Training of GNN-based model
model, _, _ = train_model(
dataset=train_dataset,
model_init_args=gnn_architecture_parameters,
load_model_state=load_model_state,
**training_parameters
)
5. GNN model prediction¶
First, let's analyze the loss curves.
# Plot the loss history
figure = plot_loss_history(loss_file=directories["model"] / "loss_history_record.pkl")
figure.show(renderer=plotly_renderer)
We will define a custom rollout function to animate the predictions.
def predict_rollout_as_pyvista_movie(
dataset, model, output_path, sample_idx, time_steps=300
):
"""
Predicts a rollout using the provided dataset and model, and saves the
results as a PyVista movie.
Parameters
----------
dataset : torch.utils.data.Dataset
The dataset to use for predictions.
model : GNNEPDBaseModel
The GNN model to use for predictions.
output_path : str
The path to save the output movie.
sample_idx : int
The index of the sample in the dataset to start the rollout from.
time_steps : int, optional
The number of time steps to predict. Default is 300.
Returns
-------
None
The function saves the movie to the specified output path.
"""
if (
dataset.sample_idx_to_graph_idx[sample_idx]
!= dataset.sample_idx_to_graph_idx[sample_idx + time_steps - 1]
):
raise ValueError(
f"The provided index {sample_idx} and the number of time steps {time_steps} exceed the sample size."
)
pyg_graph = dataset[sample_idx].detach().cpu()
mesh = graph_to_pyvista_mesh(graph=pyg_graph)
updated_velocity_magnitudes = []
predicted_velocity_magnitudes = []
velocity_magnitude_errors = []
for time_step in trange(time_steps, desc="Inference"):
pyg_graph = dataset[sample_idx + time_step]
# Get input features from input graph
node_features_in, edge_features_in, global_features_in, edges_indexes = (
model.get_input_features_from_graph(
pyg_graph, is_normalized=model.is_model_in_normalized
)
)
features_out, _, _ = model(
node_features_in=node_features_in,
edge_features_in=edge_features_in,
edges_indexes=edges_indexes,
)
features_out = (
model.data_scaler_transform(
tensor=features_out,
features_type="node_features_out",
mode="denormalize",
)
.detach()
.cpu()
)
pyg_graph = pyg_graph.detach().cpu()
# Update velocity magnitude
updated_velocity_magnitudes.append(
np.linalg.norm(
pyg_graph.x[:, :2].numpy() + pyg_graph.y[:, :2].numpy(), axis=1
)
)
predicted_velocity_magnitudes.append(
np.linalg.norm(
pyg_graph.x[:, :2].numpy() + features_out[:, :2].numpy(), axis=1
)
)
velocity_magnitude_errors.append(
(predicted_velocity_magnitudes[-1] - updated_velocity_magnitudes[-1])
/ (updated_velocity_magnitudes[-1] + 1e-32)
)
mesh.point_data["velocity_magnitude"] = updated_velocity_magnitudes[0]
mesh.point_data["predicted_velocity_magnitude"] = predicted_velocity_magnitudes[0]
mesh.point_data["velocity_magnitude_error"] = velocity_magnitude_errors[0]
for time_step in trange(time_steps, desc="Generating frames"):
if time_step == 0:
plotter = pv.Plotter(
shape=(3, 1),
notebook=False,
off_screen=True,
border=False,
window_size=(1920, 1200),
)
# Add ground truth
plotter.subplot(0, 0)
plotter.add_text("Ground truth", font_size=15)
plotter.add_mesh(
mesh,
scalars="velocity_magnitude",
lighting=False,
show_edges=False,
cmap=cm.oslo,
clim=[
np.min(updated_velocity_magnitudes),
np.max(updated_velocity_magnitudes),
],
scalar_bar_args=dict(
title="Velocity magnitude",
vertical=True,
height=0.8,
width=0.04,
position_x=0.93,
position_y=0.1,
),
)
plotter.camera.tight(adjust_render_window=False, padding=0.1, view="xy")
# Add predictions
# A shallow copy of the mesh is made when plotting each scalar array since
# a mesh can have only one active scalar, see the PyVista documentation
prediction_mesh = mesh.copy(deep=False)
plotter.subplot(1, 0)
plotter.add_text("Predictions", font_size=15)
plotter.add_mesh(
prediction_mesh,
scalars="predicted_velocity_magnitude",
lighting=False,
show_edges=False,
cmap=cm.oslo,
clim=[
np.min(updated_velocity_magnitudes),
np.max(updated_velocity_magnitudes),
],
scalar_bar_args=dict(
title="Velocity magnitude",
vertical=True,
height=0.8,
width=0.04,
position_x=0.93,
position_y=0.1,
),
)
plotter.camera.tight(adjust_render_window=False, padding=0.1, view="xy")
# Add errors
error_mesh = mesh.copy(deep=False)
plotter.subplot(2, 0)
plotter.add_text("Error", font_size=15)
plotter.add_mesh(
error_mesh,
scalars="velocity_magnitude_error",
lighting=False,
show_edges=False,
cmap=cm.vik,
clim=[-0.05, 0.05],
scalar_bar_args=dict(
title="Relative error",
vertical=True,
height=0.8,
width=0.04,
position_x=0.93,
position_y=0.1,
),
)
plotter.camera.tight(adjust_render_window=False, padding=0.1, view="xy")
plotter.open_movie(output_path)
plotter.write_frame()
else:
mesh.point_data["velocity_magnitude"] = updated_velocity_magnitudes[
time_step
]
prediction_mesh.point_data["predicted_velocity_magnitude"] = (
predicted_velocity_magnitudes[time_step]
)
error_mesh.point_data["velocity_magnitude_error"] = (
velocity_magnitude_errors[time_step]
)
# Write a frame and trigger a render
plotter.write_frame()
# Close and finalize the movie
plotter.close()
We perform the rollout:
# Set logging configuration to WARNING to avoid excessive output from PyVista
logger = logging.getLogger()
logger.setLevel(logging.WARNING)
# Load the in-distribution dataset
test_dataset = torch.load(directories["in_distribution"] / "cylinder_flow_dataset.pt")
# Load the trained model from the best model state
model = GNNEPDBaseModel.init_model_from_file(directories["model"])
_ = model.load_model_state(load_model_state="best", is_remove_posterior=False)
# Set the model to evaluation mode
model.eval()
# Predict a rollout and save it as a PyVista movie
predict_rollout_as_pyvista_movie(
dataset=test_dataset,
model=model,
output_path="_assets/cylinder_flow.mp4",
sample_idx=0,
time_steps=300,
)
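As a complementary quantitative check, we can compute the coefficient of determination on the one-step targets for a single test graph, reusing the forward pattern from the rollout function above (a minimal sketch; r2_score was imported at the top of the notebook):
# R^2 of the one-step target predictions on the first test graph
pyg_graph = test_dataset[0]
node_features_in, edge_features_in, _, edges_indexes = (
    model.get_input_features_from_graph(
        pyg_graph, is_normalized=model.is_model_in_normalized
    )
)
features_out, _, _ = model(
    node_features_in=node_features_in,
    edge_features_in=edge_features_in,
    edges_indexes=edges_indexes,
)
features_out = (
    model.data_scaler_transform(
        tensor=features_out,
        features_type="node_features_out",
        mode="denormalize",
    )
    .detach()
    .cpu()
)
print(f"R2 score: {r2_score(pyg_graph.y.numpy(), features_out.numpy()):.4f}")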
Finally, let's inspect the resulting animation, as saved on Graphorge's main branch: