Skip to content

Commit

Permalink
Add Heat PINN
Browse files Browse the repository at this point in the history
  • Loading branch information
Gxinhu committed Oct 6, 2023
1 parent 2acfd1d commit 722ab8b
Show file tree
Hide file tree
Showing 3 changed files with 287 additions and 0 deletions.
22 changes: 22 additions & 0 deletions jointContribution/Heat_PINN/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# A Physics-informed Neural Network to solve 2D steady-state heat equation

## 参考

Physics Informed Deep Learning (Part I): Data-driven Solutions of Nonlinear Partial Differential Equations

<https://arxiv.org/abs/1711.10561>
<https://github.com/314arhaam/heat-pinn>

## 包含的文件

.
|-- README.md 本文件,说明文件
|-- fdm.py 使用有限差分法计算 Heat Function
`-- main.py 主文件

## 步骤

1. 选择当前目录为工作目录
2. python main.py

### 注意
44 changes: 44 additions & 0 deletions jointContribution/Heat_PINN/fdm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import itertools
import numpy as np


def solve(n: int, l: float) -> np.ndarray:
"""
Solves the heat equation using the finite difference method.
Reference: https://github.com/314arhaam/heat-pinn/blob/main/codes/heatman.ipynb
Args:
n (int): The number of grid points in each direction.
l (float): The length of the square domain.
Returns:
np.ndarray: A 2D array containing the temperature values at each grid point.
"""
bc = {"x=-l": 75.0, "x=+l": 0.0, "y=-l": 50.0, "y=+l": 0.0}
B = np.zeros([n, n])
T = np.zeros([n**2, n**2])
for k, (i, j) in enumerate(itertools.product(range(n), range(n))):
M = np.zeros([n, n])
M[i, j] = -4
if i != 0:
M[i - 1, j] = 1
else:
B[i, j] += -bc["y=-l"]
if i != n - 1:
M[i + 1, j] = 1
else:
B[i, j] += -bc["y=+l"]
if j != 0:
M[i, j - 1] = 1
else:
B[i, j] += -bc["x=-l"]
if j != n - 1:
M[i, j + 1] = 1
else:
B[i, j] += -bc["x=+l"]
m = np.reshape(M, (1, n**2))
T[k, :] = m
b = np.reshape(B, (n**2, 1))
T = np.matmul(np.linalg.inv(T), b)
T = T.reshape([n, n])
return T
221 changes: 221 additions & 0 deletions jointContribution/Heat_PINN/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
import argparse
import os

import fdm
import matplotlib.pyplot as plt
import numpy as np

import ppsci
from ppsci.utils import logger


def main():
parser = argparse.ArgumentParser(description="Solve a 2D heat equation by PINN")
parser.add_argument("--epoch", default=1000, type=int, help="max epochs")
parser.add_argument("--lr", default=5e-4, type=float, help="learning rate")
parser.add_argument(
"--output_dir",
default="./output_heat2d",
type=str,
help="output folder",
)

args = parser.parse_args()
# set random seed for reproducibility
ppsci.utils.misc.set_random_seed(2)
# set training hyper-parameters
EPOCHS = args.epoch
ITERS_PER_EPOCH = 1

# set output directory
OUTPUT_DIR = args.output_dir
logger.init_logger("ppsci", os.path.join(OUTPUT_DIR, "train.log"), "info")

# set model
model = ppsci.arch.MLP(("x", "y"), ("u",), 9, 20, "tanh")

# set equation
equation = {"heat": ppsci.equation.Laplace(dim=2)}

# set geometry
geom = {"rect": ppsci.geometry.Rectangle((-1.0, -1.0), (1.0, 1.0))}

# set train dataloader config
train_dataloader_cfg = {
"dataset": "IterableNamedArrayDataset",
"iters_per_epoch": ITERS_PER_EPOCH,
}

NPOINT_PDE = 99**2
NPOINT_TOP = 25
NPOINT_BOTTOM = 25
NPOINT_LEFT = 25
NPOINT_RIGHT = 25

# set constraint
pde_constraint = ppsci.constraint.InteriorConstraint(
equation["heat"].equations,
{"laplace": 0},
geom["rect"],
{**train_dataloader_cfg, "batch_size": NPOINT_PDE},
ppsci.loss.MSELoss("mean"),
weight_dict={
"laplace": 1,
},
evenly=True,
name="EQ",
)
bc_top = ppsci.constraint.BoundaryConstraint(
{"u": lambda out: out["u"]},
{"u": 0},
geom["rect"],
{**train_dataloader_cfg, "batch_size": NPOINT_TOP},
ppsci.loss.MSELoss("mean"),
weight_dict={
"u": 0.25,
},
criteria=lambda x, y: np.isclose(y, 1),
name="BC_top",
)

bc_bottom = ppsci.constraint.BoundaryConstraint(
{"u": lambda out: out["u"]},
{"u": 50 / 75},
geom["rect"],
{**train_dataloader_cfg, "batch_size": NPOINT_BOTTOM},
ppsci.loss.MSELoss("mean"),
weight_dict={
"u": 0.25,
},
criteria=lambda x, y: np.isclose(y, -1),
name="BC_bottom",
)

bc_left = ppsci.constraint.BoundaryConstraint(
{"u": lambda out: out["u"]},
{"u": 1},
geom["rect"],
{**train_dataloader_cfg, "batch_size": NPOINT_LEFT},
ppsci.loss.MSELoss("mean"),
weight_dict={
"u": 0.25,
},
criteria=lambda x, y: np.isclose(x, -1),
name="BC_left",
)

bc_right = ppsci.constraint.BoundaryConstraint(
{"u": lambda out: out["u"]},
{"u": 0},
geom["rect"],
{**train_dataloader_cfg, "batch_size": NPOINT_RIGHT},
ppsci.loss.MSELoss("mean"),
weight_dict={
"u": 0.25,
},
criteria=lambda x, y: np.isclose(x, 1),
name="BC_right",
)
# wrap constraints together
constraint = {
pde_constraint.name: pde_constraint,
bc_top.name: bc_top,
bc_bottom.name: bc_bottom,
bc_left.name: bc_left,
bc_right.name: bc_right,
}

# set optimizer
optimizer = ppsci.optimizer.Adam(learning_rate=args.lr)(model)

# initialize solver
solver = ppsci.solver.Solver(
model,
constraint,
OUTPUT_DIR,
optimizer,
epochs=EPOCHS,
iters_per_epoch=ITERS_PER_EPOCH,
equation=equation,
geom=geom,
)
# train model
solver.train()

# begin eval
n = 100
input_data = geom["rect"].sample_interior(n**2, evenly=True)
pinn_output = solver.predict(input_data, return_numpy=True)["u"].reshape(n, n)
fdm_output = fdm.solve(n, 1).T
mes_loss = np.mean(np.square(pinn_output - (fdm_output / 75.0)))
logger.info(f"The norm MSE loss between the FDM and PINN is {mes_loss}")

x = input_data["x"].reshape(n, n)
y = input_data["y"].reshape(n, n)

plt.subplot(2, 1, 1)
plt.pcolormesh(x, y, pinn_output, cmap="magma")
plt.colorbar()
plt.title("PINN")
plt.xlabel("x")
plt.ylabel("y")
plt.tight_layout()
plt.axis("square")

plt.subplot(2, 1, 2)
plt.pcolormesh(x, y, fdm_output, cmap="magma")
plt.colorbar()
plt.xlabel("x")
plt.ylabel("y")
plt.title("FDM")
plt.tight_layout()
plt.axis("square")
plt.savefig(os.path.join(OUTPUT_DIR, "fdm.png"))
plt.close()

frames_val = np.array([-0.75, -0.5, -0.25, 0.0, +0.25, +0.5, +0.75])
frames = [*map(int, (frames_val + 1) / 2 * (n - 1))]
height = 3
plt.figure("", figsize=(len(frames) * height, 2 * height))

for i, var_index in enumerate(frames):
plt.subplot(2, len(frames), i + 1)
plt.title(f"y = {frames_val[i]:.2f}")
plt.plot(
x[:, var_index],
pinn_output[:, var_index] * 75.0,
"r--",
lw=4.0,
label="pinn",
)
plt.plot(x[:, var_index], fdm_output[:, var_index], "b", lw=2.0, label="FDM")
plt.ylim(0.0, 100.0)
plt.xlim(-1.0, +1.0)
plt.xlabel("x")
plt.ylabel("T")
plt.tight_layout()
plt.legend()

for i, var_index in enumerate(frames):
plt.subplot(2, len(frames), len(frames) + i + 1)
plt.title(f"x = {frames_val[i]:.2f}")
plt.plot(
y[var_index, :],
pinn_output[var_index, :] * 75.0,
"r--",
lw=4.0,
label="pinn",
)
plt.plot(y[var_index, :], fdm_output[var_index, :], "b", lw=2.0, label="FDM")
plt.ylim(0.0, 100.0)
plt.xlim(-1.0, +1.0)
plt.xlabel("y")
plt.ylabel("T")
plt.tight_layout()
plt.legend()

plt.savefig(os.path.join(OUTPUT_DIR, "profiles.png"))


if __name__ == "__main__":
main()

0 comments on commit 722ab8b

Please sign in to comment.