A GNN surrogate to estimate the flow rate through a porous medium

Author(s): Guillaume Broggi

This example implements a graph neural network surrogate to predict the flow rate through a porous medium.

Tip

Run this notebook in Google Colab. Execute the following cell to install requirements.

Install required packages for Google Colab (the cell has no effect outside a Colab environment):

try:
    # Check if running in Google Colab
    import google.colab

    # Install PyG as in the PyG documentation
    %pip uninstall --yes torch torchaudio torchvision torchtext torchdata
    %pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu124
    %pip install torch-scatter -f https://data.pyg.org/whl/torch-2.4.1+cu124.html
    %pip install torch-sparse -f https://data.pyg.org/whl/torch-2.4.1+cu124.html
    %pip install torch_geometric==2.6.1 -f https://data.pyg.org/whl/torch-2.4.1+cu124.html

    # Install Graphorge
    !git clone https://github.com/bessagroup/graphorge.git
    %pip install -e ./graphorge
except ImportError:
    # Not running in Google Colab: skip the installation
    pass

# Plotly is required for this example
try:
    import plotly.graph_objects as go
except ImportError:
    # Install Plotly if it is not already available, then import it
    %pip install plotly
    import plotly.graph_objects as go

Imports required to run the notebook:

# Standard
import json
import logging
from pathlib import Path
import pickle
import shutil
import sys

# Third-party
import numpy as np
from plotly import graph_objects as go
import plotly.io as pio
from sklearn.metrics import r2_score
import torch
from tqdm.auto import tqdm, trange

# Locals
from graphorge.gnn_base_model.data.graph_data import GraphData
from graphorge.gnn_base_model.data.graph_dataset import GNNGraphDataset
from graphorge.gnn_base_model.train.training import train_model
from graphorge.gnn_base_model.predict.prediction import predict

# Make utilities available in the path
try:
    # Check if running in Google Colab
    import google.colab
    utils_path = Path("graphorge/benchmarks/utils")
except ImportError:
    utils_path = Path().resolve().parents[1] / "utils"
sys.path.insert(0, str(utils_path))

# Import required utilities
from utils import download_file, plotly_layout, plot_interactive_graph, plot_loss_history

# Set logging configuration
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Prepare directories
directories = dict(
    raw_data=Path("0_data"),
    processed_data=Path("1_dataset"),
    training=Path("2_training_dataset"),
    validation=Path("3_validation_dataset"),
    model=Path("4_model"),
    in_distribution=Path("5_testing_id_dataset"),
    prediction=Path("7_prediction"),
)

# Create directories if they do not exist
for directory in directories.values():
    directory.mkdir(parents=True, exist_ok=True)

# Add `notebook_connected` to the default renderer to ensure the figures are rendered in the documentation
plotly_renderer = f"notebook_connected+{pio.renderers.default}"

# Set device
device_type = "cuda" if torch.cuda.is_available() else "cpu"

1. Raw data

The dataset describes pore networks extracted from porous media. See the dataset repository for more information [1].

Problem definition

The dataset is downloaded from the base URL below. It consists of a JSON file, where each entry is a porous network containing:

  • id: the ID of the network.

  • pore_coordinates (x, y): the coordinates of the pores defining the network.

  • channel_idx: channels connecting the pores together.

  • channel_width (w): the width of each channel.

  • flow_rate: the flow rate through the network when applying a fixed pressure gradient.

Let’s download the dataset and inspect it. Note that the data are, by nature, well described by a graph.
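Schematically, a single entry is expected to look like the following Python dictionary once loaded with json.load (the values below are purely illustrative):

# Schematic structure of one dataset entry (illustrative values only)
example_entry = {
    "id": 0,
    "pore_coordinates": [[92.64, 9.88], [10.21, 33.55]],  # (x, y) per pore
    "channel_idx": [[0, 1]],                               # pore index pairs per channel
    "channel_width": [2.47],                               # one width per channel
    "flow_rate": 2.0e-15,                                  # global target
}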

# Set the base URL for downloading the dataset
base_url = "placeholder"

# Download the dataset file
download_file(
    file="porous_networks.tar.xz",
    base_url=base_url,
    dest_path=directories["raw_data"],
)

# Extract the dataset
shutil.unpack_archive(
    directories["raw_data"] / "porous_networks.tar.xz",
    extract_dir=directories["raw_data"],
    format="xztar",
)

# Load the data
with open(directories["raw_data"] / "porous_networks.json", "r") as f:
    data = json.load(f)

# Explore the data structure
print(f"Number of samples: {len(data)}")
print("Data structure for the first sample:")
for key, value in data[0].items():
    print(
        f"{key}: {len(value)} values - First value: {value[0]}"
        if isinstance(value, list)
        else f"{key}: {value}"
    )
INFO:root:Using 'porous_networks.tar.xz' cached in 0_data
Number of samples: 5250
Data structure for the first sample:
id: 0
pore_coordinates: 73 values - First value: [92.639109, 9.877235]
channel_idx: 94 values - First value: [0, 1]
channel_width: 94 values - First value: 2.466653662021832
flow_rate: 2.01291895376651e-15

2. Graph dataset

Each porous network maps naturally onto a graph: pores become nodes (with their coordinates used as node features), channels become edges (with the channel width as edge feature), and the flow rate is the global target. First, we define a function that builds such a graph. This operation is at the core of the Graphorge framework.

def build_graphorge_graph(
    node_coordinates,
    edge_indexes,
    edge_features,
    node_features,
    global_targets,
):
    """
    Builds a GraphData object from the provided node coordinates, edge indexes,
    edge features, node features, and global targets.

    Parameters
    ----------
    node_coordinates : np.ndarray
        An array of shape (n_nodes, n_dim) containing the coordinates of the
        nodes.
    edge_indexes : np.ndarray
        An array of shape (n_edges, 2) containing the indices of the edges.
    edge_features : np.ndarray
        An array of shape (n_edges, n_features) containing the features of the
        edges.
    node_features : np.ndarray
        An array of shape (n_nodes, n_features) containing the features of the
        nodes.
    global_targets : np.ndarray
        An array of shape (n_samples, n_targets) containing the global targets.

    Returns
    -------
    GraphData
        An instance of GraphData containing the graph structure and features.
    """
    # Instantiate the graph data (graphorge)
    # Note that the n_dim is set to 2 as the porous medium is a 2D network
    graph_data = GraphData(n_dim=2, nodes_coords=node_coordinates)

    # Make the edges undirected, i.e., create target_to_source index pairs
    # from source_to_target index pairs
    edge_indexes = np.concatenate((edge_indexes, edge_indexes[:, ::-1]), axis=0)

    # Set graph edges indexes. The data does not contain duplicated edges,
    # so we can set is_unique to False to avoid unnecessary processing
    graph_data.set_graph_edges_indexes(edges_indexes_mesh=edge_indexes, is_unique=False)

    # Duplicate the edge features to match the undirected edges
    edge_features = np.concatenate((edge_features, edge_features), axis=0)

    # Set graph edge feature matrix
    graph_data.set_edge_features_matrix(edge_features_matrix=edge_features)

    # Set graph node features matrix
    graph_data.set_node_features_matrix(node_features_matrix=node_features)

    # Set global targets matrix
    global_targets_matrix = global_targets

    # Set graph global targets
    graph_data.set_global_targets_matrix(global_targets_matrix)

    return graph_data

Let’s define a function to process the data:

def process_porous_network(porous_network):
    """
    Processes a porous_network dictionary to create a GraphData object.

    Parameters
    ----------
    porous_network : dict
        A dictionary containing the porous network data with keys:
        - "pore_coordinates": Coordinates of the pores.
        - "channel_idx": Indexes of pores connected by channels.
        - "channel_width": Width of the channels.
        - "permeability": Permeability of the porous network.

    Returns
    -------
    GraphData
        An instance of GraphData containing the processed porous network data.
    """

    # Extract the node coordinates, edge indexes, and global targets
    node_coordinates = np.array(porous_network["pore_coordinates"])
    edge_indexes = np.array(porous_network["channel_idx"])
    edge_features = np.array(porous_network["channel_width"]).reshape(-1, 1)
    global_targets = np.array(porous_network["flow_rate"]).reshape(-1, 1)

    # Build the graph data
    graph_data = build_graphorge_graph(
        node_coordinates=node_coordinates,
        edge_indexes=edge_indexes,
        edge_features=edge_features,
        # Using coordinates as node features
        node_features=node_coordinates,
        global_targets=global_targets,
    )

    # Set metadata
    graph_data.set_metadata(
        dict(edge_feature_names=["channel_width"], global_target_names=["flow_rate"])
    )

    return graph_data

We can visualize the obtained graphs.

Note

We use an interactive visualization for pedagogical purposes. Interactive visualization may require significant resources when working with large graphs. Note that the plot_graph() method is available in Graphorge for fast, static plots.

graph_data = process_porous_network(data[400])
figure = plot_interactive_graph(graph_data)
figure.update_layout(
    legend=dict(
        yanchor="top",
        y=0.98,
        xanchor="right",
        x=0.98,
        title_text=None,
        traceorder="normal",
    ),
)
figure.show(renderer=plotly_renderer)

Tip

Hover over the nodes and edges to inspect their features.

Finally, we generate the dataset.

# Parameters
samples = "all"  # Number of samples to process, or "all" for all samples

# Determine the number of samples (porous networks) to process
n_samples = len(data)
if samples != "all":
    n_samples = min(samples, n_samples)

dataset_sample_files = []
# Iterate over each porous_network_id in the dataset
for porous_network_id in trange(n_samples, desc="Processing porous networks"):

    # Process the porous network data to get the graph data
    graph_data = process_porous_network(data[porous_network_id])

    # Get PyG homogeneous graph data object
    pyg_graph = graph_data.get_torch_data_object()

    # Set sample file name
    sample_file_name = f"porous_network_graph_{porous_network_id:04d}.pt"
    # Set sample file path
    sample_file_path = directories["processed_data"] / sample_file_name
    # Save graph sample file
    torch.save(pyg_graph, sample_file_path)
    # Save graph sample file path
    dataset_sample_files.append(sample_file_path)

Let’s split the data into training, validation, and (in-distribution) test datasets.

# This operation is random, so we fix the seed for reproducibility
seed = 42

# Initialize the random number generator
rng = np.random.default_rng(seed)

dataset_split_sizes = {
    "training": 0.7,
    "validation": 0.2,
    "in_distribution": 0.1,
}

split_fractions = np.array(list(dataset_split_sizes.values()))
# Normalize the split fractions to ensure they sum to 1
split_fractions /= split_fractions.sum()
# Get the corresponding split sizes
split_sizes = (split_fractions * len(dataset_sample_files)).astype(int)
# Get the split indices as the cumulative sum of the split sizes
split_indices = np.cumsum(split_sizes)

# Shuffle the dataset sample files
rng.shuffle(dataset_sample_files)

# Split the dataset sample files into training, validation, and test sets
splits_sample_files = np.split(dataset_sample_files, split_indices[:-1])

for split_name, split_files in tqdm(
    zip(dataset_split_sizes.keys(), splits_sample_files),
    desc="Splitting datasets",
):
    split_paths = []
    for file_path in split_files:
        target_path = directories[split_name] / file_path.name
        shutil.copy(
            file_path,
            target_path,
        )

        split_paths.append(target_path.as_posix())

    dataset = GNNGraphDataset(
        directories[split_name].as_posix(),
        dataset_sample_files=split_paths,
        dataset_basename="porous_network_dataset",
        is_store_dataset=False,
    )

    _ = dataset.save_dataset(is_append_n_sample=True)
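As a quick sanity check, we can count how many graph samples ended up in each split directory (a minimal sketch relying on the sample file naming used above):

# Report the number of graph sample files copied into each split directory
for split_name in dataset_split_sizes:
    n_split_samples = len(
        list(directories[split_name].glob("porous_network_graph_*.pt"))
    )
    print(f"{split_name}: {n_split_samples} samples")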

3. GNN model architecture

The Graphorge base model follows an encode-process-decode architecture: node and edge features are first encoded, propagated through a fixed number of message-passing steps, and finally decoded. In this example, the inputs are the two pore coordinates (node features) and the channel width (edge feature), and the single global output is the flow rate.

# Set the GNN model parameters
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
gnn_architecture_parameters = dict(
    # Set number of node input and output features
    n_node_in=2,
    n_node_out=0,
    n_time_node=0,
    # Set number of edge input and output features
    n_edge_in=1,
    n_edge_out=0,
    n_time_edge=0,
    # Set number of global input and output features
    n_global_in=0,
    n_global_out=1,
    n_time_global=0,
    # Set number of message-passing steps (number of processor layers)
    n_message_steps=2,
    # Set number of FNN/RNN hidden layers
    enc_n_hidden_layers=2,
    pro_n_hidden_layers=2,
    dec_n_hidden_layers=2,
    # Set hidden layer size
    hidden_layer_size=128,
    # Set model directory
    model_directory=directories["model"],
    model_name="PorousNetworkGNN",
    # Set model input and output features normalization
    is_model_in_normalized=True,
    is_model_out_normalized=True,
    # Set aggregation schemes
    pro_edge_to_node_aggr="add",
    pro_node_to_global_aggr="add",
    # Set activation functions
    enc_node_hidden_activ_type="leakyrelu",
    enc_node_output_activ_type="identity",
    enc_edge_hidden_activ_type="leakyrelu",
    enc_edge_output_activ_type="identity",
    pro_node_hidden_activ_type="leakyrelu",
    pro_node_output_activ_type="identity",
    pro_edge_hidden_activ_type="leakyrelu",
    pro_edge_output_activ_type="identity",
    dec_node_hidden_activ_type="leakyrelu",
    dec_node_output_activ_type="identity",
    # Set device
    device_type=device_type,
)

4. GNN model training

We load the training and validation datasets saved above, define the training parameters, and train the model with early stopping monitored on the validation loss.

train_dataset_file_path = list(
    directories["training"].glob("porous_network_dataset_*.pkl")
)[0]

validation_dataset_file_path = list(
    directories["validation"].glob("porous_network_dataset_*.pkl")
)[0]

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Load training and validation datasets
train_dataset = GNNGraphDataset.load_dataset(train_dataset_file_path)
validation_dataset = GNNGraphDataset.load_dataset(validation_dataset_file_path)


training_parameters = dict(
    # Set number of epochs
    n_max_epochs=50,
    # Set batch size
    batch_size=16,
    # Set optimizer
    opt_algorithm="adam",
    # Set learning rate
    lr_init=1.0e-03,
    # Set learning rate scheduler
    lr_scheduler_type=None,
    lr_scheduler_kwargs=None,
    # Set loss function
    loss_nature="global_features_out",
    loss_type="mse",
    loss_kwargs=dict(),
    # Set data shuffling
    is_sampler_shuffle=False,
    # Set early stopping
    is_early_stopping=True,
    # Set early stopping parameters
    early_stopping_kwargs=dict(
        validation_dataset=validation_dataset,
        validation_frequency=1,
        trigger_tolerance=20,
        improvement_tolerance=1e-2,
    ),
    # Set seed
    seed=42,
    # Set verbosity
    is_verbose=True,
    tqdm_flavor="notebook",
    # Save loss history
    save_loss_every=1,
    # Set device
    device_type=device_type,
)

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Compute exponential decay (learning rate scheduler)
lr_end = 1.0e-5
gamma = (lr_end / training_parameters["lr_init"]) ** (
    1 / training_parameters["n_max_epochs"]
)
# Set learning rate scheduler
training_parameters["lr_scheduler_type"] = "explr"
training_parameters["lr_scheduler_kwargs"] = dict(gamma=gamma)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Set model state loading
load_model_state = None
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Training of GNN-based model
model, _, _ = train_model(
    dataset=train_dataset,
    model_init_args=gnn_architecture_parameters,
    **training_parameters
)

5. GNN model prediction

First, let’s analyze the loss curves.

# Plot the loss history
figure = plot_loss_history(loss_file=directories["model"] / "loss_history_record.pkl")

figure.show(renderer=plotly_renderer)

Next, we load the in-distribution test dataset and run the prediction with the best model state.

test_dataset_file_path = list(
    directories["in_distribution"].glob("porous_network_dataset_*.pkl")
)[0]

dataset = GNNGraphDataset.load_dataset(test_dataset_file_path)

prediction_set_directory, _ = predict(
    dataset,
    model_directory=directories["model"],
    predict_directory=directories["prediction"],
    load_model_state="best",
    loss_nature="global_features_out",
    loss_type="mse",
    loss_kwargs={},
    is_normalized_loss=False,
    dataset_file_path=test_dataset_file_path,
    device_type=device_type,
    seed=None,
    is_verbose=True,
    tqdm_flavor="notebook",
)
The predictions are stored as pickle files in the prediction set directory. Let’s load them and compare the predicted flow rates against the ground truth in a parity plot.

ground_truth = []
predictions = []
# Load the predictions and ground truth from the prediction files.
# If the prediction cell above was not run in this session, fall back to the
# default prediction set directory.
try:
    prediction_set_directory
except NameError:
    prediction_set_directory = directories["prediction"] / "prediction_set_0"
for result_path in tqdm(
    Path(prediction_set_directory).glob("prediction_sample_*.pkl"),
    desc="Loading predictions",
):
    with open(result_path, "rb") as f:
        result = pickle.load(f)
        predictions.append(result["global_features_out"].detach().cpu().item())
        ground_truth.append(result["global_targets"].detach().cpu().item())

# Scale results for readability (flow rates are on the order of 1e-15)
predictions = [pred * 1e15 for pred in predictions]
ground_truth = [gt * 1e15 for gt in ground_truth]

min_value = min(min(predictions), min(ground_truth))
max_value = max(max(predictions), max(ground_truth))

figure = go.Figure()

figure.add_scatter(
    x=[min_value, max_value],
    y=[min_value, max_value],
    mode="lines",
    line_width=2,
    line_color="black",
    line_dash="dot",
    name="Identity line",
)

figure.add_scatter(
    x=ground_truth,
    y=predictions,
    mode="markers",
    name="Flow rate predictions",
    marker_size=10,
    marker_color="rgba(69, 119, 170, 0.4)",
    marker_line_width=1,
    marker_line_color="rgba(69, 119, 170, 1)",
    hovertemplate="Ground truth: %{x:.4f}<br>Prediction: %{y:.4f}<extra></extra>",
)

figure.add_annotation(
    x=0.1,
    y=0.8,
    xref="paper",
    yref="paper",
    text=f"R² = {r2_score(ground_truth, predictions):.3f}",
    showarrow=False,
    font=dict(size=14),
)

figure.update_layout(
    **plotly_layout,
    xaxis_title="Ground truth",
    yaxis_title="Predictions",
    xaxis_range=[min_value, max_value],
    yaxis_range=[min_value, max_value],
    legend=dict(
        yanchor="top",
        y=0.98,
        xanchor="left",
        x=0.02,
        title_text=None,
        traceorder="normal",
    ),
    width=500,
    height=500,
)

figure.update_yaxes(
    scaleanchor="x",
    scaleratio=1,
)
figure.show(renderer=plotly_renderer)
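Beyond the parity plot and the R² annotation, a couple of scalar error metrics can summarize the in-distribution test performance. Below is a minimal sketch computed from the scaled predictions above:

# Compute simple error metrics on the scaled flow rates (same 1e15 scaling as above)
errors = np.array(predictions) - np.array(ground_truth)
print(f"MAE : {np.mean(np.abs(errors)):.4f}")
print(f"RMSE: {np.sqrt(np.mean(errors ** 2)):.4f}")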