Multi-fidelity Residual VeBRNN¤
This notebook provides guidance on how to train the MF-Residual-VeBRNN model from the MF-VeBRNN repo.
import torch
import matplotlib.pyplot as plt
from MFVeBRNN.dataset.load_dataset import MultiFidelityDataset
from MFVeBRNN.method.mf_residual_vebrnn_trainer import MFResidualVeBRNNTrainer
from VeBNN.networks import MeanNet, GammaVarNet
from torch import nn
import warnings
warnings.filterwarnings("ignore")
Load dataset¤
# Assemble the multi-fidelity dataset from the pickled LF/HF DNS data,
# including in-distribution (id) and out-of-distribution (ood) ground truth.
dataset_kwargs = dict(
    lf_train_data_path="lf_dns_sve_0d1.pickle",
    hf_train_data_path="hf_dns_rve_0d1.pickle",
    id_ground_truth=True,
    id_hf_test_data_path="hf_dns_rve_0d1_gt.pickle",
    id_lf_ground_truth_data_path="lf_dns_sve_0d1_gt.pickle",
    ood_ground_truth=True,
    ood_hf_test_data_path="hf_dns_rve_0d125_gt.pickle",
    ood_lf_ground_truth_data_path="lf_dns_sve_0d125_gt.pickle",
)
dataset = MultiFidelityDataset(**dataset_kwargs)
# Hold out 10 HF trajectories for validation; all selected LF data is training-only.
dataset.get_hf_train_val_split(num_hf_train=100, num_hf_val=10, seed=0)
dataset.get_lf_train_val_split(num_lf_train=100, num_lf_val=0, seed=0)
Define a simple GRU network¤
Since the history-dependent constitutive law has a recurrent data structure, we need a recurrent neural network for this task.
class SimpleGRU(nn.Module):
    """A minimal GRU regressor: stacked GRU layers followed by a linear read-out.

    Maps a sequence of shape (batch, time, input_size) to a sequence of
    shape (batch, time, output_size).
    """

    def __init__(self,
                 input_size: int,
                 hidden_size: int,
                 num_layers: int,
                 output_size: int,
                 bias: bool = True):
        super().__init__()
        # batch_first=True: tensors are laid out as (batch, time, features).
        self.gru = nn.GRU(input_size=input_size,
                          hidden_size=hidden_size,
                          num_layers=num_layers,
                          bias=bias,
                          batch_first=True)
        # Per-timestep projection from the hidden state to the output space.
        self.h2y = nn.Linear(hidden_size, output_size)

    def forward(self, x, hx=None):
        """Run the GRU over x (optionally from initial state hx) and project."""
        hidden_seq, _ = self.gru(x, hx)   # (B, T, hidden_size)
        return self.h2y(hidden_seq)       # (B, T, output_size)
# --- mean and variance networks ---
# NOTE: construction order (mean GRU first, then variance GRU) is kept exactly
# as before so the global torch RNG state is consumed identically.
mean_network_ = SimpleGRU(input_size=67, hidden_size=64, num_layers=2, output_size=3)
mean_network = MeanNet(net=mean_network_, prior_mu=0.0, prior_sigma=1.0)

# A smaller recurrent net models the predictive variance.
# NOTE(review): the mean net emits 3 components while the variance net emits
# 6 — confirm this asymmetry matches the trainer's expectation.
var_network_ = SimpleGRU(input_size=67, hidden_size=4, num_layers=1, output_size=6)
var_network = GammaVarNet(net=var_network_, prior_mu=0.0, prior_sigma=1.0)
# Load the pre-trained low-fidelity model.
# SECURITY NOTE: weights_only=False unpickles arbitrary objects — only load
# checkpoint files from a trusted source.
lf_pre_trained_model = torch.load("single_fidelity_rnn_model.pth", weights_only=False)

# Run on GPU when available, otherwise fall back to CPU.
run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mf_model = MFResidualVeBRNNTrainer(mean_net=mean_network,
                                   var_net=var_network,
                                   pre_trained_lf_model=lf_pre_trained_model,
                                   device=run_device,
                                   job_id=1)
# ====== configs (kept small so test is fast) ======
# Deterministic warm-up training of the mean network.
init_config = dict(
    loss_name="MSE",
    optimizer_name="Adam",
    lr=1e-3,
    weight_decay=1e-6,
    num_epochs=1000,   # warm-up epochs
    batch_size=200,
    verbose=False,
    print_iter=50,
    split_ratio=0.8,
)
# Variance-network training; early stopping is wired up but disabled here.
var_config = dict(
    optimizer_name="Adam",
    lr=1e-3,
    num_epochs=1000,
    batch_size=200,
    verbose=True,
    print_iter=50,
    early_stopping=False,
    early_stopping_iter=100,
    early_stopping_tol=1e-4,
)
# SGMCMC sampler settings; burn-in samples are discarded, the rest are
# thinned every `mix_epochs` epochs.
sampler_config = dict(
    sampler="pSGLD",        # must match your VeBNN.samplers names
    lr=1e-3,
    gamma=0.9999,
    num_epochs=2000,        # SGMCMC epochs
    mix_epochs=10,          # thinning interval
    burn_in_epochs=500,
    batch_size=200,
    verbose=False,
    print_iter=100,
)
# ====== run cooperative training ======
# For a quick smoke test a single cooperative iteration (iteration=1, as used
# below) is enough to see if everything works; increase it for real runs.
mf_model.cooperative_train(
x_train=dataset.hx_train,
y_train=dataset.hy_train,
iteration=1,
init_config=init_config,
var_config=var_config,
sampler_config=sampler_config,
delete_model_raw_data=True, # delete temporary folder after training
)
Get MF-Residual-VeBRNN's prediction¤
# Visualize one in-distribution test trajectory: inputs on the top row,
# ground truth vs. Bayesian prediction (with 95% CI) on the bottom row.
index = 2
hy_pred, hf_var = mf_model.hf_bayes_predict(dataset.hx_id_gt_scaled)

fig, ax = plt.subplots(2, 3, figsize=(12, 5))
for comp in range(3):
    # Top row: scaled input trajectory for this component.
    ax[0, comp].plot(dataset.hx_id_gt_scaled[index, :, comp], label='hx test')
    ax[0, comp].legend()
    # Bottom row: ground truth, predictive mean, and 95% credible band.
    ax[1, comp].plot(dataset.hy_id_gt_scaled[index, :, comp], label='hy test')
    pred_mean = hy_pred[index, :, comp].cpu()
    pred_std = torch.sqrt(hf_var[index, :, comp].cpu())
    ax[1, comp].fill_between(
        range(len(pred_mean)),
        pred_mean - 1.96 * pred_std,
        pred_mean + 1.96 * pred_std,
        alpha=0.2,
        label='95% CI'
    )
    ax[1, comp].plot(pred_mean, label='hy pred')
    ax[1, comp].legend()
    ax[1, comp].set_xlabel('Time step')

# Axis labels and per-component titles.
ax[0, 0].set_ylabel('strain')
ax[1, 0].set_ylabel('stress')
ax[0, 0].set_title('Component 11')
ax[0, 1].set_title('Component 12')
ax[0, 2].set_title('Component 22')
plt.show()