Skip to content

Commit

Permalink
Add Heat PINN
Browse files Browse the repository at this point in the history
  • Loading branch information
Gxinhu committed Oct 6, 2023
1 parent 2acfd1d commit 7aff7f8
Show file tree
Hide file tree
Showing 21 changed files with 296 additions and 6 deletions.
10 changes: 5 additions & 5 deletions examples/ldc/ldc2d_steady_Re10.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@
"iters_per_epoch": ITERS_PER_EPOCH,
}

NPOINT_PDE = 99**2
NPOINT_TOP = 101
NPOINT_BOTTOM = 101
NPOINT_LEFT = 99
NPOINT_RIGHT = 99
NPOINT_PDE = 10000
NPOINT_TOP = 25
NPOINT_BOTTOM =25
NPOINT_LEFT = 25
NPOINT_RIGHT = 25

# set constraint
pde_constraint = ppsci.constraint.InteriorConstraint(
Expand Down
Binary file added fdm.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
22 changes: 22 additions & 0 deletions jointContribution/Heat_PINN/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# A Physics-informed Neural Network to solve 2D steady-state heat equation

## 参考

Physics Informed Deep Learning (Part I): Data-driven Solutions of Nonlinear Partial Differential Equations

<https://arxiv.org/abs/1711.10561>
<https://github.com/314arhaam/heat-pinn>

## 包含的文件

.
|-- README.md 本文件,说明文件
|-- fdm.py 使用有限差分法计算 Heat Function
`-- main.py 主文件

## 步骤

1. 选择当前目录为工作目录
2. python main.py

### 注意
44 changes: 44 additions & 0 deletions jointContribution/Heat_PINN/fdm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import itertools
import numpy as np


def solve(n: int, l: float) -> np.ndarray:
"""
Solves the heat equation using the finite difference method.
Reference: https://github.com/314arhaam/heat-pinn/blob/main/codes/heatman.ipynb
Args:
n (int): The number of grid points in each direction.
l (float): The length of the square domain.
Returns:
np.ndarray: A 2D array containing the temperature values at each grid point.
"""
bc = {"x=-l": 75.0, "x=+l": 0.0, "y=-l": 50.0, "y=+l": 0.0}
B = np.zeros([n, n])
T = np.zeros([n**2, n**2])
for k, (i, j) in enumerate(itertools.product(range(n), range(n))):
M = np.zeros([n, n])
M[i, j] = -4
if i != 0:
M[i - 1, j] = 1
else:
B[i, j] += -bc["y=-l"]
if i != n - 1:
M[i + 1, j] = 1
else:
B[i, j] += -bc["y=+l"]
if j != 0:
M[i, j - 1] = 1
else:
B[i, j] += -bc["x=-l"]
if j != n - 1:
M[i, j + 1] = 1
else:
B[i, j] += -bc["x=+l"]
m = np.reshape(M, (1, n**2))
T[k, :] = m
b = np.reshape(B, (n**2, 1))
T = np.matmul(np.linalg.inv(T), b)
T = T.reshape([n, n])
return T
225 changes: 225 additions & 0 deletions jointContribution/Heat_PINN/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
import argparse
import os

import fdm
import matplotlib.pyplot as plt
import numpy as np

import ppsci
from ppsci.utils import logger


def main():
parser = argparse.ArgumentParser(description="Solve a 2D heat equation by PINN")
parser.add_argument("--epoch", default=1000, type=int, help="max epochs")
parser.add_argument("--lr", default=5e-4, type=float, help="learning rate")
parser.add_argument(
"--output_dir",
default="./output_heat2d",
type=str,
help="output folder",
)

args = parser.parse_args()
# set random seed for reproducibility
ppsci.utils.misc.set_random_seed(2)
# set training hyper-parameters
EPOCHS = args.epoch
ITERS_PER_EPOCH = 1

# set output directory
OUTPUT_DIR = args.output_dir
logger.init_logger("ppsci", os.path.join(OUTPUT_DIR, "train.log"), "info")

# set model
model = ppsci.arch.MLP(("x", "y"), ("u",), 9, 20, "tanh")

# set equation
equation = {"heat": ppsci.equation.Laplace(dim=2)}

# set geometry
geom = {"rect": ppsci.geometry.Rectangle((-1.0, -1.0), (1.0, 1.0))}

# set train dataloader config
train_dataloader_cfg = {
"dataset": "IterableNamedArrayDataset",
"iters_per_epoch": ITERS_PER_EPOCH,
}

NPOINT_PDE = 99**2
NPOINT_TOP = 25
NPOINT_BOTTOM = 25
NPOINT_LEFT = 25
NPOINT_RIGHT = 25

# set constraint
pde_constraint = ppsci.constraint.InteriorConstraint(
equation["heat"].equations,
{"laplace": 0},
geom["rect"],
{**train_dataloader_cfg, "batch_size": NPOINT_PDE},
ppsci.loss.MSELoss("mean"),
weight_dict={
"laplace": 1,
},
evenly=True,
name="EQ",
)
bc_top = ppsci.constraint.BoundaryConstraint(
{"u": lambda out: out["u"]},
{"u": 0},
geom["rect"],
{**train_dataloader_cfg, "batch_size": NPOINT_TOP},
ppsci.loss.MSELoss("mean"),
weight_dict={
"u": 0.25,
},
criteria=lambda x, y: np.isclose(y, 1),
name="BC_top",
)

bc_bottom = ppsci.constraint.BoundaryConstraint(
{"u": lambda out: out["u"]},
{"u": 50 / 75},
geom["rect"],
{**train_dataloader_cfg, "batch_size": NPOINT_BOTTOM},
ppsci.loss.MSELoss("mean"),
weight_dict={
"u": 0.25,
},
criteria=lambda x, y: np.isclose(y, -1),
name="BC_bottom",
)

bc_left = ppsci.constraint.BoundaryConstraint(
{"u": lambda out: out["u"]},
{"u": 1},
geom["rect"],
{**train_dataloader_cfg, "batch_size": NPOINT_LEFT},
ppsci.loss.MSELoss("mean"),
weight_dict={
"u": 0.25,
},
criteria=lambda x, y: np.isclose(x, -1),
name="BC_left",
)

bc_right = ppsci.constraint.BoundaryConstraint(
{"u": lambda out: out["u"]},
{"u": 0},
geom["rect"],
{**train_dataloader_cfg, "batch_size": NPOINT_RIGHT},
ppsci.loss.MSELoss("mean"),
weight_dict={
"u": 0.25,
},
criteria=lambda x, y: np.isclose(x, 1),
name="BC_right",
)
# wrap constraints together
constraint = {
pde_constraint.name: pde_constraint,
bc_top.name: bc_top,
bc_bottom.name: bc_bottom,
bc_left.name: bc_left,
bc_right.name: bc_right,
}

# set optimizer
optimizer = ppsci.optimizer.Adam(learning_rate=args.lr)(model)

# initialize solver
solver = ppsci.solver.Solver(
model,
constraint,
OUTPUT_DIR,
optimizer,
epochs=EPOCHS,
iters_per_epoch=ITERS_PER_EPOCH,
equation=equation,
geom=geom,
)
# train model
solver.train()

# begin eval
n = 100
input_data = geom["rect"].sample_interior(n**2, evenly=True)
pinn_output = solver.predict(input_data, return_numpy=True)["u"].reshape(n, n)
fdm_output = fdm.solve(n, 1).T
mes_loss = np.mean(np.square(pinn_output - (fdm_output / 75.0)))
logger.info(f"The norm MSE loss between the FDM and PINN is {mes_loss}")

x = input_data["x"].reshape(n, n)
y = input_data["y"].reshape(n, n)

plt.subplot(2, 1, 1)
plt.pcolormesh(x, y, pinn_output, cmap="magma")
plt.colorbar()
plt.title("PINN")
plt.xlabel("x")
plt.ylabel("y")
plt.tight_layout()
plt.axis("square")

plt.subplot(2, 1, 2)
plt.pcolormesh(x, y, fdm_output, cmap="magma")
plt.colorbar()
plt.xlabel("x")
plt.ylabel("y")
plt.title("FDM")
plt.tight_layout()
plt.axis("square")
plt.savefig(os.path.join(OUTPUT_DIR, "fdm.png"))
plt.close()

frames_val = np.array([-0.75, -0.5, -0.25, 0.0, +0.25, +0.5, +0.75])
frames = [*map(int, (frames_val + 1) / 2 * (n - 1))]
height = 3
plt.figure("", figsize=(len(frames) * height, 2 * height))

for i, var_index in enumerate(frames):
plt.subplot(2, len(frames), i + 1)
plt.title(f"y = {frames_val[i]:.2f}")
plt.plot(
x[:, var_index],
pinn_output[:, var_index] * 75.0,
"r--",
lw=4.0,
label="pinn",
)
plt.plot(
x[:, var_index], fdm_output[:, var_index] * 75.0, "b", lw=2.0, label="FDM"
)
plt.ylim(0.0, 100.0)
plt.xlim(-1.0, +1.0)
plt.xlabel("x")
plt.ylabel("T")
plt.tight_layout()
plt.legend()

for i, var_index in enumerate(frames):
plt.subplot(2, len(frames), len(frames) + i + 1)
plt.title(f"x = {frames_val[i]:.2f}")
plt.plot(
y[var_index, :],
pinn_output[var_index, :] * 75.0,
"r--",
lw=4.0,
label="pinn",
)
plt.plot(
y[var_index, :], fdm_output[var_index, :] * 75.0, "b", lw=2.0, label="FDM"
)
plt.ylim(0.0, 100.0)
plt.xlim(-1.0, +1.0)
plt.xlabel("y")
plt.ylabel("T")
plt.tight_layout()
plt.legend()

plt.savefig(os.path.join(OUTPUT_DIR, "profiles.png"))


if __name__ == "__main__":
main()
Binary file added output_darcy2d/checkpoints/latest.pdeqn
Binary file not shown.
Binary file added output_darcy2d/checkpoints/latest.pdopt
Binary file not shown.
Binary file added output_darcy2d/checkpoints/latest.pdparams
Binary file not shown.
Binary file added output_darcy2d/checkpoints/latest.pdstates
Binary file not shown.
Binary file added output_heat2d/checkpoints/latest.pdeqn
Binary file not shown.
Binary file added output_heat2d/checkpoints/latest.pdopt
Binary file not shown.
Binary file added output_heat2d/checkpoints/latest.pdparams
Binary file not shown.
Binary file added output_heat2d/checkpoints/latest.pdstates
Binary file not shown.
Binary file added output_heat2d/fdm.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added output_heat2d/profiles.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added output_laplace2d/checkpoints/latest.pdeqn
Binary file not shown.
Binary file added output_laplace2d/checkpoints/latest.pdopt
Binary file not shown.
Binary file added output_laplace2d/checkpoints/latest.pdparams
Binary file not shown.
Binary file added output_laplace2d/checkpoints/latest.pdstates
Binary file not shown.
Binary file added profiles.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
numpy<=1.23.1
sympy
matplotlib
vtk
Expand Down

0 comments on commit 7aff7f8

Please sign in to comment.