diff --git a/docs/zh/examples/phycrnet.md b/docs/zh/examples/phycrnet.md
new file mode 100644
index 000000000..e3936f3da
--- /dev/null
+++ b/docs/zh/examples/phycrnet.md
@@ -0,0 +1,124 @@
+# PhyCRNet
+
+## 1. Background
+
+Complex spatiotemporal systems can usually be modeled by partial differential equations (PDEs), which are ubiquitous in applied mathematics, physics, biology, chemistry and engineering; solving PDE systems has long been a key part of scientific computing.
+This work proposes a novel physics-informed convolutional-recurrent learning architecture (PhyCRNet), together with a lightweight variant (PhyCRNet-s), for solving multi-dimensional spatiotemporal PDEs without any labeled data. The goal is not to compete with classical numerical solvers, but to offer a spatiotemporal deep-learning perspective on surrogate modeling of complex PDEs.
+
+## 2. Problem Definition
+
+We consider the generic form of an n-dimensional, nonlinear, coupled, parameterized PDE system:
+
+$$
+\mathbf{u}_t+\mathcal{F}\left[\mathbf{u}, \mathbf{u}^2, \cdots, \nabla_{\mathbf{x}} \mathbf{u}, \nabla_{\mathbf{x}}^2 \mathbf{u}, \nabla_{\mathbf{x}} \mathbf{u} \cdot \mathbf{u}, \cdots ; \boldsymbol{\lambda}\right]=\mathbf{0}
+$$
+
+Our goal is to develop deep-neural-network (DNN) based methods for the forward analysis of the spatiotemporal PDE system given above.
+
+## 3. Solving the Problem
+
+The following explains, step by step, how the problem is translated into PaddleScience code and solved with deep learning.
+To keep the walkthrough concise, only the key steps such as model construction, equation construction and computational-domain construction are described; for the remaining details please refer to the [API documentation](../api/arch.md).
+
+### 3.1 Model construction
+
+For the PhyCRNet problem, the network is built as follows in PaddleScience code:
+
+``` py linenums="163"
+--8<--
+examples/phycrnet/main.py:163:174
+--8<--
+```
+
+Among the PhyCRNet arguments, input_channels is the number of input channels, hidden_channels is the list of hidden channels per layer, and input_kernel_size is the convolution kernel size of each layer.
+
+### 3.2 Data construction
+
+Before running the code for this example, generate the dataset with the command below:
+
+``` shell
+python burgers_data.py
+```
+
+The example then reads the generated data, as shown below:
+
+``` py linenums="182"
+--8<--
+examples/phycrnet/main.py:182:191
+--8<--
+```
+
+### 3.3 Constraint construction
+
+Set up the training dataset, the loss function and the returned fields, as shown below:
+
+``` py linenums="200"
+--8<--
+examples/phycrnet/main.py:200:213
+--8<--
+```
+
+### 3.4 Validator construction
+
+Set up the evaluation dataset, the loss function and the returned fields, as shown below:
+
+``` py linenums="216"
+--8<--
+examples/phycrnet/main.py:216:230
+--8<--
+```
+
+### 3.5 Hyperparameter settings
+
+Next, we specify the number of training epochs; based on experience from experiments, 200 epochs are used here.
+
+``` py linenums="143"
+--8<--
+examples/phycrnet/main.py:143:143
+--8<--
+```
+
+### 3.6 Optimizer construction
+
+Training updates the model parameters with an optimizer; here the `Adam` optimizer is used with a `learning_rate` of 1e-4.
+
+``` py linenums="242"
+--8<--
+examples/phycrnet/main.py:242:242
+--8<--
+```
+
+### 3.7 Model training and evaluation
+
+After the settings above, simply pass the instantiated objects, in order, to `ppsci.solver.Solver`.
+
+``` py linenums="243"
+--8<--
+examples/phycrnet/main.py:243:254
+--8<--
+```
+
+Finally, start training and evaluation:
+
+``` py linenums="256"
+--8<--
+examples/phycrnet/main.py:256:260
+--8<--
+```
+
+## 4. Complete code
+
+``` py linenums="1" title="phycrnet"
+--8<--
+examples/phycrnet/main.py
+--8<--
+```
+
+## 5. Results
+
+The PhyCRNet example was trained with epochs=200 and learning_rate=1e-4; the resulting loss is 17.86.
+
+## 6. References
+
+- [PhyCRNet: Physics-informed Convolutional-Recurrent Network for Solving Spatiotemporal PDEs](https://arxiv.org/abs/2106.14103)
+- [PhyCRNet reference implementation](https://github.com/isds-neu/PhyCRNet/)
diff --git a/examples/phycrnet/README.md b/examples/phycrnet/README.md
new file mode 100644
index 000000000..113dd7c24
--- /dev/null
+++ b/examples/phycrnet/README.md
@@ -0,0 +1,43 @@
+# PhyCRNet
+
+Physics-informed convolutional-recurrent neural networks for solving spatiotemporal PDEs
+
+Paper link: [[Journal Paper](https://www.sciencedirect.com/science/article/pii/S0045782521006514)], [[ArXiv](https://arxiv.org/pdf/2106.14103.pdf)]
+
+By: [Pu Ren](https://scholar.google.com/citations?user=7FxlSHEAAAAJ&hl=en), [Chengping Rao](https://github.com/Raocp), [Yang Liu](https://coe.northeastern.edu/people/liu-yang/), [Jian-Xun Wang](http://sites.nd.edu/jianxun-wang/) and [Hao Sun](https://web.mit.edu/haosun/www/#/home)
+
+## Highlights
+
+- Present a physics-informed discrete learning framework for solving spatiotemporal PDEs without any labeled data
+- Propose an encoder-decoder convolutional-recurrent scheme for low-dimensional feature extraction
+- Employ hard encoding of initial and boundary conditions
+- Incorporate autoregressive and residual connections to explicitly simulate the time marching
+
+## Reference
+
+- <https://github.com/isds-neu/PhyCRNet/>
+
+## Environment of the original repository
+
+- Python 3.6.13 with PyTorch 1.6.0
+- [PyTorch](https://pytorch.org/) 1.6.0: `torch.ifft`, as used in random_fields.py, is not supported in later versions. If you do not need to generate the dataset, other versions also work.
+- matplotlib, numpy, scipy
+- the line `x = x[:-1]` in post_process needs to be commented out
+
+## Dataset
+
+Generate the test dataset:
+
+``` shell
+python burgers_data.py
+```
+
+## Run
+
+``` shell
+python main.py
+```
+
+## Note
+
+The number of training steps can start from a smaller value, e.g. 100, and then be increased to 200.
diff --git a/examples/phycrnet/burgers_data.py b/examples/phycrnet/burgers_data.py
new file mode 100644
index 000000000..c64faf2ce
--- /dev/null
+++ b/examples/phycrnet/burgers_data.py
@@ -0,0 +1,273 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
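+
+# NOTE: running this script integrates the 2D Burgers system with RK4, saves plots to
+# ./output/figures/2dBurgers/ and writes ./output/burgers_1501x2x128x128.mat, which
+# main.py loads as the reference solution (key "uv", shape [t, c, h, w]).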
+ +""" +high-order finite difference solver for 2d Burgers equation +spatial diff: 4th order laplacian +temporal diff: O(dt^5) due to RK4 +""" + +import os + +import functions +import matplotlib.pyplot as plt +import numpy as np +import scipy.io + +import ppsci + +ppsci.utils.misc.set_random_seed(5) + + +def apply_laplacian(mat, dx=1.0): + # dx is inversely proportional to N + """This function applies a discretized Laplacian + in periodic boundary conditions to a matrix + + For more information see + https://en.wikipedia.org/wiki/Discrete_Laplace_operator#Implementation_via_operator_discretization + """ + + # the cell appears 4 times in the formula to compute + # the total difference + neigh_mat = -5 * mat.copy() + + # Each direct neighbor on the lattice is counted in + # the discrete difference formula + neighbors = [ + (4 / 3, (-1, 0)), + (4 / 3, (0, -1)), + (4 / 3, (0, 1)), + (4 / 3, (1, 0)), + (-1 / 12, (-2, 0)), + (-1 / 12, (0, -2)), + (-1 / 12, (0, 2)), + (-1 / 12, (2, 0)), + ] + + # shift matrix according to demanded neighbors + # and add to this cell with corresponding weight + for weight, neigh in neighbors: + neigh_mat += weight * np.roll(mat, neigh, (0, 1)) + + return neigh_mat / dx**2 + + +def apply_dx(mat, dx=1.0): + """central diff for dx""" + + # np.roll, axis=0 -> row + # the total difference + neigh_mat = -0 * mat.copy() + + # Each direct neighbor on the lattice is counted in + # the discrete difference formula + neighbors = [ + (1.0 / 12, (2, 0)), + (-8.0 / 12, (1, 0)), + (8.0 / 12, (-1, 0)), + (-1.0 / 12, (-2, 0)), + ] + + # shift matrix according to demanded neighbors + # and add to this cell with corresponding weight + for weight, neigh in neighbors: + neigh_mat += weight * np.roll(mat, neigh, (0, 1)) + + return neigh_mat / dx + + +def apply_dy(mat, dy=1.0): + """central diff for dy""" + + # the total difference + neigh_mat = -0 * mat.copy() + + # Each direct neighbor on the lattice is counted in + # the discrete difference formula + neighbors = [ + (1.0 / 12, (0, 2)), + (-8.0 / 12, (0, 1)), + (8.0 / 12, (0, -1)), + (-1.0 / 12, (0, -2)), + ] + + # shift matrix according to demanded neighbors + # and add to this cell with corresponding weight + for weight, neigh in neighbors: + neigh_mat += weight * np.roll(mat, neigh, (0, 1)) + + return neigh_mat / dy + + +def get_temporal_diff(U, V, R, dx): + # u and v in (h, w) + laplace_u = apply_laplacian(U, dx) + laplace_v = apply_laplacian(V, dx) + + u_x = apply_dx(U, dx) + v_x = apply_dx(V, dx) + + u_y = apply_dy(U, dx) + v_y = apply_dy(V, dx) + + # governing equation + u_t = (1.0 / R) * laplace_u - U * u_x - V * u_y + v_t = (1.0 / R) * laplace_v - U * v_x - V * v_y + + return u_t, v_t + + +def update_rk4(U0, V0, R=100, dt=0.05, dx=1.0): + """Update with Runge-kutta-4 method + See https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta_methods + """ + ############# Stage 1 ############## + # compute the diffusion part of the update + + u_t, v_t = get_temporal_diff(U0, V0, R, dx) + + K1_u = u_t + K1_v = v_t + + ############# Stage 2 ############## + U1 = U0 + K1_u * dt / 2.0 + V1 = V0 + K1_v * dt / 2.0 + + u_t, v_t = get_temporal_diff(U1, V1, R, dx) + + K2_u = u_t + K2_v = v_t + + ############# Stage 3 ############## + U2 = U0 + K2_u * dt / 2.0 + V2 = V0 + K2_v * dt / 2.0 + + u_t, v_t = get_temporal_diff(U2, V2, R, dx) + + K3_u = u_t + K3_v = v_t + + ############# Stage 4 ############## + U3 = U0 + K3_u * dt + V3 = V0 + K3_v * dt + + u_t, v_t = get_temporal_diff(U3, V3, R, dx) + + K4_u = u_t + K4_v = v_t + + # Final solution + U = U0 + 
dt * (K1_u + 2 * K2_u + 2 * K3_u + K4_u) / 6.0 + V = V0 + dt * (K1_v + 2 * K2_v + 2 * K3_v + K4_v) / 6.0 + + return U, V + + +def postProcess(output, reso, xmin, xmax, ymin, ymax, num, fig_save_dir): + """num: Number of time step""" + x = np.linspace(0, reso, reso + 1) + y = np.linspace(0, reso, reso + 1) + x_star, y_star = np.meshgrid(x, y) + x_star, y_star = x_star[:-1, :-1], y_star[:-1, :-1] + + u_pred = output[num, 0, :, :] + v_pred = output[num, 1, :, :] + + fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(6, 3)) + fig.subplots_adjust(hspace=0.3, wspace=0.3) + + cf = ax[0].scatter( + x_star, + y_star, + c=u_pred, + alpha=0.95, + edgecolors="none", + cmap="RdYlBu", + marker="s", + s=3, + ) + ax[0].axis("square") + ax[0].set_xlim([xmin, xmax]) + ax[0].set_ylim([ymin, ymax]) + cf.cmap.set_under("black") + cf.cmap.set_over("whitesmoke") + ax[0].set_title("u-FDM") + fig.colorbar(cf, ax=ax[0], fraction=0.046, pad=0.04) + + cf = ax[1].scatter( + x_star, + y_star, + c=v_pred, + alpha=0.95, + edgecolors="none", + cmap="RdYlBu", + marker="s", + s=3, + ) + ax[1].axis("square") + ax[1].set_xlim([xmin, xmax]) + ax[1].set_ylim([ymin, ymax]) + cf.cmap.set_under("black") + cf.cmap.set_over("whitesmoke") + ax[1].set_title("v-FDM") + fig.colorbar(cf, ax=ax[1], fraction=0.046, pad=0.04) + + plt.savefig(fig_save_dir + "/uv_[i=%d].png" % (num)) + plt.close("all") + + +if __name__ == "__main__": + # grid size + M, N = 128, 128 + n_simu_steps = 30000 + dt = 0.0001 # maximum 0.003 + dx = 1.0 / M + R = 200.0 + + # get initial condition from random field + GRF = functions.GaussianRF(2, M, alpha=2, tau=5) + U, V = GRF.sample(2) # U and V have shape of [128, 128] + U = U.cpu().numpy() + V = V.cpu().numpy() + + U_record = U.copy()[None, ...] + V_record = V.copy()[None, ...] + + for step in range(n_simu_steps): + + U, V = update_rk4(U, V, R, dt, dx) # [h, w] + + if (step + 1) % 20 == 0: + print(step) + U_record = np.concatenate((U_record, U[None, ...]), axis=0) # [t,h,w] + V_record = np.concatenate((V_record, V[None, ...]), axis=0) + + UV = np.concatenate((U_record[None, ...], V_record[None, ...]), axis=0) # (c,t,h,w) + UV = np.transpose(UV, [1, 0, 2, 3]) # (t,c,h,w) + + fig_save_dir = "./output/figures/2dBurgers/" + os.makedirs(fig_save_dir, exist_ok=True) + for i in range(0, 30): + postProcess(UV, M, 0, M, 0, M, 50 * i, fig_save_dir) + + # save data + data_save_dir = "./output/" + os.makedirs(data_save_dir, exist_ok=True) + scipy.io.savemat( + os.path.join(data_save_dir, "burgers_1501x2x128x128.mat"), {"uv": UV} + ) + +# [umin, umax] = [-0.7, 0.7] +# [vmin, vmax] = [-1.0, 1.0] diff --git a/examples/phycrnet/functions.py b/examples/phycrnet/functions.py new file mode 100644 index 000000000..bf54594d0 --- /dev/null +++ b/examples/phycrnet/functions.py @@ -0,0 +1,210 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
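+
+# Helper utilities for the PhyCRNet example: a Gaussian random field sampler
+# (GaussianRF) used to build initial conditions, the periodic-padding physics loss
+# (compute_loss), and a small Dataset wrapper that packs the initial hidden state
+# and the input field into the train/eval dictionaries consumed by main.py.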
+ +import math +from typing import Dict + +import numpy as np +import paddle +import paddle.nn as nn + + +def metric_expr(output_dict, *args) -> Dict[str, paddle.Tensor]: + return {"dummy_loss": paddle.to_tensor(0.0)} + + +class GaussianRF(object): + def __init__(self, dim, size, alpha=2, tau=3, sigma=None, boundary="periodic"): + self.dim = dim + + if sigma is None: + sigma = tau ** (0.5 * (2 * alpha - self.dim)) + + k_max = size // 2 + + if dim == 1: + k = paddle.concat( + ( + paddle.arange(start=0, end=k_max, step=1), + paddle.arange(start=-k_max, end=0, step=1), + ), + 0, + ) + + self.sqrt_eig = ( + size + * math.sqrt(2.0) + * sigma + * ((4 * (math.pi**2) * (k**2) + tau**2) ** (-alpha / 2.0)) + ) + self.sqrt_eig[0] = 0.0 + + elif dim == 2: + wavenumers = paddle.concat( + ( + paddle.arange(start=0, end=k_max, step=1), + paddle.arange(start=-k_max, end=0, step=1), + ), + 0, + ).tile((size, 1)) + + perm = list(range(wavenumers.ndim)) + perm[1] = 0 + perm[0] = 1 + k_x = wavenumers.transpose(perm=perm) + k_y = wavenumers + + self.sqrt_eig = ( + (size**2) + * math.sqrt(2.0) + * sigma + * ( + (4 * (math.pi**2) * (k_x**2 + k_y**2) + tau**2) + ** (-alpha / 2.0) + ) + ) + self.sqrt_eig[0, 0] = 0.0 + + elif dim == 3: + wavenumers = paddle.concat( + ( + paddle.arange(start=0, end=k_max, step=1), + paddle.arange(start=-k_max, end=0, step=1), + ), + 0, + ).tile((size, size, 1)) + + perm = list(range(wavenumers.ndim)) + perm[1] = 2 + perm[2] = 1 + k_x = wavenumers.transpose(perm=perm) + k_y = wavenumers + + perm = list(range(wavenumers.ndim)) + perm[0] = 2 + perm[2] = 0 + k_z = wavenumers.transpose(perm=perm) + + self.sqrt_eig = ( + (size**3) + * math.sqrt(2.0) + * sigma + * ( + (4 * (math.pi**2) * (k_x**2 + k_y**2 + k_z**2) + tau**2) + ** (-alpha / 2.0) + ) + ) + self.sqrt_eig[0, 0, 0] = 0.0 + + self.size = [] + for j in range(self.dim): + self.size.append(size) + + self.size = tuple(self.size) + + def sample(self, N): + + coeff = paddle.randn((N, *self.size, 2)) + + coeff[..., 0] = self.sqrt_eig * coeff[..., 0] + coeff[..., 1] = self.sqrt_eig * coeff[..., 1] + + if self.dim == 2: + u = paddle.as_real(paddle.fft.ifft2(paddle.as_complex(coeff))) + else: + raise f"self.dim not in (2): {self.dim}" + + u = u[..., 0] + + return u + + +def compute_loss(output, loss_func): + """calculate the physics loss""" + + # Padding x axis due to periodic boundary condition + # shape: [t, c, h, w] + output = paddle.concat((output[:, :, :, -2:], output, output[:, :, :, 0:3]), axis=3) + + # Padding y axis due to periodic boundary condition + # shape: [t, c, h, w] + output = paddle.concat((output[:, :, -2:, :], output, output[:, :, 0:3, :]), axis=2) + + # get physics loss + mse_loss = nn.MSELoss() + f_u, f_v = loss_func.get_phy_Loss(output) + loss = mse_loss(f_u, paddle.zeros_like(f_u).cuda()) + mse_loss( + f_v, paddle.zeros_like(f_v).cuda() + ) + + return loss + + +def count_parameters(model): + return sum(p.numel() for p in model.parameters() if not p.stop_gradient) + + +def post_process(output, true, num): + """ + num: Number of time step + """ + u_star = true[num, 0, 1:-1, 1:-1] + u_pred = output[num, 0, 1:-1, 1:-1].detach().cpu().numpy() + + v_star = true[num, 1, 1:-1, 1:-1] + v_pred = output[num, 1, 1:-1, 1:-1].detach().cpu().numpy() + + return u_star, u_pred, v_star, v_pred + + +def frobenius_norm(tensor): + return np.sqrt(np.sum(tensor**2)) + + +class Dataset: + def __init__(self, initial_state, input): + self.initial_state = initial_state + self.input = input + + def get(self, epochs=1): + input_dict_train = { + 
"initial_state": [], + "initial_state_shape": [], + "input": [], + } + label_dict_train = {"dummy_loss": []} + input_dict_val = { + "initial_state": [], + "initial_state_shape": [], + "input": [], + } + label_dict_val = {"dummy_loss": []} + for i in range(epochs): + # paddle not support rand >=7, so reshape, and then recover in input_transform + shape = self.initial_state.shape + input_dict_train["initial_state"].append(self.initial_state.reshape((-1,))) + input_dict_train["initial_state_shape"].append(paddle.to_tensor(shape)) + input_dict_train["input"].append(self.input) + label_dict_train["dummy_loss"].append(paddle.to_tensor(0.0)) + + if i == epochs - 1: + shape = self.initial_state.shape + input_dict_val["initial_state"].append( + self.initial_state.reshape((-1,)) + ) + input_dict_val["initial_state_shape"].append(paddle.to_tensor(shape)) + input_dict_val["input"].append(self.input) + label_dict_val["dummy_loss"].append(paddle.to_tensor(0.0)) + + return input_dict_train, label_dict_train, input_dict_val, label_dict_val diff --git a/examples/phycrnet/main.py b/examples/phycrnet/main.py new file mode 100644 index 000000000..613623d61 --- /dev/null +++ b/examples/phycrnet/main.py @@ -0,0 +1,341 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +PhyCRNet for solving spatiotemporal PDEs +Reference: https://github.com/isds-neu/PhyCRNet/ +""" +import functions +import matplotlib.pyplot as plt +import numpy as np +import paddle +import scipy.io as scio + +import ppsci +from ppsci.arch import phycrnet +from ppsci.utils import config +from ppsci.utils import logger + + +# transform +def transform_in(input): + shape = input["initial_state_shape"][0] + input_transformed = { + "initial_state": input["initial_state"][0].reshape(shape.tolist()), + "input": input["input"][0], + } + return input_transformed + + +def transform_out(input, out, model): + # Stop the transform. 
+ model.enable_transform = False + global dt, dx + global num_time_batch + + loss_func = phycrnet.loss_generator(dt, dx) + batch_loss = 0 + state_detached = [] + prev_output = [] + for time_batch_id in range(num_time_batch): + # update the first input for each time batch + if time_batch_id == 0: + hidden_state = input["initial_state"] + u0 = input["input"] + else: + hidden_state = state_detached + u0 = prev_output[-2:-1].detach() # second last output + out = model({"initial_state": hidden_state, "input": u0}) + + # output is a list + output = out["outputs"] + second_last_state = out["second_last_state"] + + # [t, c, height (Y), width (X)] + output = paddle.concat(tuple(output), axis=0) + + # concatenate the initial state to the output for central diff + output = paddle.concat((u0.cuda(), output), axis=0) + + # get loss + loss = functions.compute_loss(output, loss_func) + # loss.backward(retain_graph=True) + batch_loss += loss + + # update the state and output for next batch + prev_output = output + state_detached = [] + for i in range(len(second_last_state)): + (h, c) = second_last_state[i] + state_detached.append((h.detach(), c.detach())) # hidden state + + model.enable_transform = True + return {"loss": batch_loss} + + +def tranform_output_val(input, out): + global uv + output = out["outputs"] + input = input["input"] + + # shape: [t, c, h, w] + output = paddle.concat(tuple(output), axis=0) + output = paddle.concat((input.cuda(), output), axis=0) + + # Padding x and y axis due to periodic boundary condition + output = paddle.concat((output[:, :, :, -1:], output, output[:, :, :, 0:2]), axis=3) + output = paddle.concat((output[:, :, -1:, :], output, output[:, :, 0:2, :]), axis=2) + + # [t, c, h, w] + truth = uv[0:1001, :, :, :] + + # [101, 2, 131, 131] + truth = np.concatenate((truth[:, :, :, -1:], truth, truth[:, :, :, 0:2]), axis=3) + truth = np.concatenate((truth[:, :, -1:, :], truth, truth[:, :, 0:2, :]), axis=2) + + # post-process + ten_true = [] + ten_pred = [] + for i in range(0, 50): + u_star, u_pred, v_star, v_pred = functions.post_process( + output, + truth, + num=20 * i, + ) + + ten_true.append([u_star, v_star]) + ten_pred.append([u_pred, v_pred]) + + # compute the error + error = functions.frobenius_norm( + np.array(ten_pred) - np.array(ten_true) + ) / functions.frobenius_norm(np.array(ten_true)) + return {"loss": paddle.to_tensor([error])} + + +def train_loss_func(result_dict, *args) -> paddle.Tensor: + return result_dict["loss"] + + +def val_loss_func(result_dict, *args) -> paddle.Tensor: + return result_dict["loss"] + + +def output_graph(model, input_dataset, fig_save_path): + output_dataset = model(input_dataset) + output = output_dataset["outputs"] + input = input_dataset["input"][0] + + # shape: [t, c, h, w] + output = paddle.concat(tuple(output), axis=0) + output = paddle.concat((input.cuda(), output), axis=0) + + # Padding x and y axis due to periodic boundary condition + output = paddle.concat((output[:, :, :, -1:], output, output[:, :, :, 0:2]), axis=3) + output = paddle.concat((output[:, :, -1:, :], output, output[:, :, 0:2, :]), axis=2) + + # [t, c, h, w] + truth = uv[0:1001, :, :, :] + + # [101, 2, 131, 131] + truth = np.concatenate((truth[:, :, :, -1:], truth, truth[:, :, :, 0:2]), axis=3) + truth = np.concatenate((truth[:, :, -1:, :], truth, truth[:, :, 0:2, :]), axis=2) + + # post-process + ten_true = [] + ten_pred = [] + for i in range(0, 50): + u_star, u_pred, v_star, v_pred = functions.post_process( + output, truth, num=20 * i + ) + + ten_true.append([u_star, 
v_star]) + ten_pred.append([u_pred, v_pred]) + + # compute the error + error = functions.frobenius_norm( + np.array(ten_pred) - np.array(ten_true) + ) / functions.frobenius_norm(np.array(ten_true)) + + print("The predicted error is: ", error) + + u_pred = output[:-1, 0, :, :].detach().cpu().numpy() + u_pred = np.swapaxes(u_pred, 1, 2) # [h,w] = [y,x] + u_true = truth[:, 0, :, :] + + t_true = np.linspace(0, 2, 1001) + t_pred = np.linspace(0, 2, time_steps) + + plt.plot(t_pred, u_pred[:, 32, 32], label="x=32, y=32, CRL") + plt.plot(t_true, u_true[:, 32, 32], "--", label="x=32, y=32, Ref.") + plt.xlabel("t") + plt.ylabel("u") + plt.xlim(0, 2) + plt.legend() + plt.savefig(fig_save_path + "x=32,y=32.png") + plt.close("all") + + # # plot train loss + # plt.figure() + # plt.plot(train_loss, label="train loss") + # plt.yscale("log") + # plt.legend() + # plt.savefig(fig_save_path + "train loss.png", dpi=300) + + +if __name__ == "__main__": + args = config.parse_args() + # set random seed for reproducibility + ppsci.utils.misc.set_random_seed(5) + # set output directory + OUTPUT_DIR = "./output_PhyCRNet" if not args.output_dir else args.output_dir + # initialize logger + logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info") + # set training hyper-parameters + EPOCHS = 200 if not args.epochs else args.epochs + + # set initial states for convlstm + num_convlstm = 1 + (h0, c0) = (paddle.randn((1, 128, 16, 16)), paddle.randn((1, 128, 16, 16))) + initial_state = [] + for i in range(num_convlstm): + initial_state.append((h0, c0)) + + global num_time_batch + global uv, dt, dx + # grid parameters + time_steps = 1001 + dt = 0.002 + dx = 1.0 / 128 + + time_batch_size = 1000 + steps = time_batch_size + 1 + effective_step = list(range(0, steps)) + num_time_batch = int(time_steps / time_batch_size) + model = ppsci.arch.PhyCRNet( + input_channels=2, + hidden_channels=[8, 32, 128, 128], + input_kernel_size=[4, 4, 4, 3], + input_stride=[2, 2, 2, 1], + input_padding=[1, 1, 1, 1], + dt=dt, + num_layers=[3, 1], + upscale_factor=8, + step=steps, + effective_step=effective_step, + ) + + def transform_out_wrap(_in, _out): + return transform_out(_in, _out, model) + + model.register_input_transform(transform_in) + model.register_output_transform(transform_out_wrap) + + # use burgers_data.py to generate data + data_file = "./output/burgers_1501x2x128x128.mat" + data = scio.loadmat(data_file) + uv = data["uv"] # [t,c,h,w] + + # initial condition + uv0 = uv[0:1, ...] 
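+    # uv0 keeps the first snapshot of the reference solution, shape [1, 2, 128, 128]
+    # (a single (u, v) field); it is the u0 input that the network marches forward in time.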
+ input = paddle.to_tensor(uv0, dtype=paddle.get_default_dtype()) + + initial_state = paddle.to_tensor(initial_state) + dataset_obj = functions.Dataset(initial_state, input) + ( + input_dict_train, + label_dict_train, + input_dict_val, + label_dict_val, + ) = dataset_obj.get(200) + + sup_constraint_pde = ppsci.constraint.SupervisedConstraint( + { + "dataset": { + "name": "NamedArrayDataset", + "input": input_dict_train, + "label": label_dict_train, + }, + }, + ppsci.loss.FunctionalLoss(train_loss_func), + { + "loss": lambda out: out["loss"], + }, + name="sup_train", + ) + constraint_pde = {sup_constraint_pde.name: sup_constraint_pde} + + sup_validator_pde = ppsci.validate.SupervisedValidator( + { + "dataset": { + "name": "NamedArrayDataset", + "input": input_dict_val, + "label": label_dict_val, + }, + }, + ppsci.loss.FunctionalLoss(val_loss_func), + { + "loss": lambda out: out["loss"], + }, + metric={"metric": ppsci.metric.FunctionalMetric(functions.metric_expr)}, + name="sup_valid", + ) + validator_pde = {sup_validator_pde.name: sup_validator_pde} + + # initialize solver + ITERS_PER_EPOCH = 1 + scheduler = ppsci.optimizer.lr_scheduler.Step( + epochs=EPOCHS, + iters_per_epoch=ITERS_PER_EPOCH, + step_size=100, + gamma=0.97, + learning_rate=1e-4, + )() + optimizer = ppsci.optimizer.Adam(scheduler)(model) + solver = ppsci.solver.Solver( + model, + constraint_pde, + OUTPUT_DIR, + optimizer, + scheduler, + EPOCHS, + ITERS_PER_EPOCH, + save_freq=50, + validator=validator_pde, + eval_with_no_grad=True, + ) + + # Used to set whether the graph is generated + graph = True + + if not graph: + # train model + solver.train() + # evaluate after finished training + model.register_output_transform(tranform_output_val) + solver.eval() + + # save the model + layer_state_dict = model.state_dict() + paddle.save(layer_state_dict, "output/phycrnet.pdparams") + else: + import os + + fig_save_path = "output/figures/" + if not os.path.exists(fig_save_path): + os.makedirs(fig_save_path, True) + layer_state_dict = paddle.load("output/phycrnet.pdparams") + model.set_state_dict(layer_state_dict) + model.register_output_transform(None) + output_graph(model, input_dict_val, fig_save_path) diff --git a/mkdocs.yml b/mkdocs.yml index 4cf50e9d4..c4ef97aff 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -58,6 +58,7 @@ nav: - DeepHPMs: zh/examples/deephpms.md - Lorenz_transform_physx: zh/examples/lorenz.md - Rossler_transform_physx: zh/examples/rossler.md + - PhyCRNet: zh/examples/phycrnet.md - 算子学习: - DeepONet: zh/examples/deeponet.md - 气象预测: diff --git a/ppsci/arch/__init__.py b/ppsci/arch/__init__.py index 8f1e417d7..ef4c426bc 100644 --- a/ppsci/arch/__init__.py +++ b/ppsci/arch/__init__.py @@ -22,6 +22,7 @@ from ppsci.arch.embedding_koopman import CylinderEmbedding # isort:skip from ppsci.arch.gan import Generator # isort:skip from ppsci.arch.gan import Discriminator # isort:skip +from ppsci.arch.phycrnet import PhyCRNet # isort:skip from ppsci.arch.phylstm import DeepPhyLSTM # isort:skip from ppsci.arch.physx_transformer import PhysformerGPT2 # isort:skip from ppsci.arch.model_list import ModelList # isort:skip @@ -46,6 +47,7 @@ "AFNONet", "PrecipNet", "UNetEx", + "PhyCRNet", "build_model", ] diff --git a/ppsci/arch/phycrnet.py b/ppsci/arch/phycrnet.py new file mode 100644 index 000000000..70eeac823 --- /dev/null +++ b/ppsci/arch/phycrnet.py @@ -0,0 +1,534 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle +import paddle.nn as nn +from paddle.nn import utils + +from ppsci.arch import base + +# define the high-order finite difference kernels +lapl_op = [ + [ + [ + [0, 0, -1 / 12, 0, 0], + [0, 0, 4 / 3, 0, 0], + [-1 / 12, 4 / 3, -5, 4 / 3, -1 / 12], + [0, 0, 4 / 3, 0, 0], + [0, 0, -1 / 12, 0, 0], + ] + ] +] + +partial_y = [ + [ + [ + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1 / 12, -8 / 12, 0, 8 / 12, -1 / 12], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + ] + ] +] + +partial_x = [ + [ + [ + [0, 0, 1 / 12, 0, 0], + [0, 0, -8 / 12, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 8 / 12, 0, 0], + [0, 0, -1 / 12, 0, 0], + ] + ] +] + +# specific parameters for burgers equation +def initialize_weights(module): + if isinstance(module, nn.Conv2D): + c = 1.0 # 0.5 + initializer = nn.initializer.Uniform( + -c * np.sqrt(1 / (3 * 3 * 320)), c * np.sqrt(1 / (3 * 3 * 320)) + ) + initializer(module.weight) + elif isinstance(module, nn.Linear): + initializer = nn.initializer.Constant(0.0) + initializer(module.bias) + + +class PhyCRNet(base.Arch): + """Physics-informed convolutional-recurrent neural networks + + Args: + input_channels (int): The input channels. + hidden_channels (List[int]): The hidden channels. + input_kernel_size (List[int]): The input kernel size. + input_stride (List[int]): The input stride. + input_padding (List[int]): The input padding. + dt (float): The dt parameter. + num_layers (List[int]): The number of layers. + upscale_factor (int): The upscale factor. + step (int, optional): The step. Defaults to 1. + effective_step (list, optional): The effective step. Defaults to [1]. 
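+
+    The network stacks num_layers[0] encoder_block layers (strided convolutions for
+    downsampling), num_layers[1] ConvLSTMCell layers for time marching, and a
+    PixelShuffle upscaling followed by a Conv2D output layer; at every step the
+    output is added back to the previous state as a residual scaled by dt.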
+ + Examples: + >>> import ppsci + >>> model = ppsci.arch.PhyCRNet( + input_channels=2, + hidden_channels=[8, 32, 128, 128], + input_kernel_size=[4, 4, 4, 3], + input_stride=[2, 2, 2, 1], + input_padding=[1, 1, 1, 1], + dt=0.002, + num_layers=[3, 1], + upscale_factor=8 + ) + """ + + def __init__( + self, + input_channels, + hidden_channels, + input_kernel_size, + input_stride, + input_padding, + dt, + num_layers, + upscale_factor, + step=1, + effective_step=[1], + ): + super(PhyCRNet, self).__init__() + + # input channels of layer includes input_channels and hidden_channels of cells + self.input_channels = [input_channels] + hidden_channels + self.hidden_channels = hidden_channels + self.input_kernel_size = input_kernel_size + self.input_stride = input_stride + self.input_padding = input_padding + self.step = step + self.effective_step = effective_step + self._all_layers = [] + self.dt = dt + self.upscale_factor = upscale_factor + + # number of layers + self.num_encoder = num_layers[0] + self.num_convlstm = num_layers[1] + + # encoder - downsampling + for i in range(self.num_encoder): + name = "encoder{}".format(i) + cell = encoder_block( + input_channels=self.input_channels[i], + hidden_channels=self.hidden_channels[i], + input_kernel_size=self.input_kernel_size[i], + input_stride=self.input_stride[i], + input_padding=self.input_padding[i], + ) + + setattr(self, name, cell) + self._all_layers.append(cell) + + # ConvLSTM + for i in range(self.num_encoder, self.num_encoder + self.num_convlstm): + name = "convlstm{}".format(i) + cell = ConvLSTMCell( + input_channels=self.input_channels[i], + hidden_channels=self.hidden_channels[i], + input_kernel_size=self.input_kernel_size[i], + input_stride=self.input_stride[i], + input_padding=self.input_padding[i], + ) + + setattr(self, name, cell) + self._all_layers.append(cell) + + # output layer + self.output_layer = nn.Conv2D( + 2, 2, kernel_size=5, stride=1, padding=2, padding_mode="circular" + ) + + # pixelshuffle - upscale + self.pixelshuffle = nn.PixelShuffle(self.upscale_factor) + + # initialize weights + self.apply(initialize_weights) + initializer_0 = paddle.nn.initializer.Constant(0.0) + initializer_0(self.output_layer.bias) + self.enable_transform = True + + def forward(self, x): + if self.enable_transform: + if self._input_transform is not None: + x = self._input_transform(x) + output_x = x + + self.initial_state = x["initial_state"] + x = x["input"] + internal_state = [] + outputs = [] + second_last_state = [] + + for step in range(self.step): + xt = x + + # encoder + for i in range(self.num_encoder): + name = "encoder{}".format(i) + x = getattr(self, name)(x) + + # convlstm + for i in range(self.num_encoder, self.num_encoder + self.num_convlstm): + name = "convlstm{}".format(i) + if step == 0: + (h, c) = getattr(self, name).init_hidden_tensor( + prev_state=self.initial_state[i - self.num_encoder] + ) + internal_state.append((h, c)) + + # one-step forward + (h, c) = internal_state[i - self.num_encoder] + x, new_c = getattr(self, name)(x, h, c) + internal_state[i - self.num_encoder] = (x, new_c) + + # output + x = self.pixelshuffle(x) + x = self.output_layer(x) + + # residual connection + x = xt + self.dt * x + + if step == (self.step - 2): + second_last_state = internal_state.copy() + + if step in self.effective_step: + outputs.append(x) + + result_dict = {"outputs": outputs, "second_last_state": second_last_state} + if self.enable_transform: + if self._output_transform is not None: + result_dict = self._output_transform(output_x, 
result_dict) + return result_dict + + +class ConvLSTMCell(nn.Layer): + """Convolutional LSTM""" + + def __init__( + self, + input_channels, + hidden_channels, + input_kernel_size, + input_stride, + input_padding, + ): + super(ConvLSTMCell, self).__init__() + + self.input_channels = input_channels + self.hidden_channels = hidden_channels + self.hidden_kernel_size = 3 + self.input_kernel_size = input_kernel_size + self.input_stride = input_stride + self.input_padding = input_padding + self.num_features = 4 + + # padding for hidden state + self.padding = int((self.hidden_kernel_size - 1) / 2) + + self.Wxi = nn.Conv2D( + self.input_channels, + self.hidden_channels, + self.input_kernel_size, + self.input_stride, + self.input_padding, + bias_attr=None, + padding_mode="circular", + ) + + self.Whi = nn.Conv2D( + self.hidden_channels, + self.hidden_channels, + self.hidden_kernel_size, + 1, + padding=1, + bias_attr=False, + padding_mode="circular", + ) + + self.Wxf = nn.Conv2D( + self.input_channels, + self.hidden_channels, + self.input_kernel_size, + self.input_stride, + self.input_padding, + bias_attr=None, + padding_mode="circular", + ) + + self.Whf = nn.Conv2D( + self.hidden_channels, + self.hidden_channels, + self.hidden_kernel_size, + 1, + padding=1, + bias_attr=False, + padding_mode="circular", + ) + + self.Wxc = nn.Conv2D( + self.input_channels, + self.hidden_channels, + self.input_kernel_size, + self.input_stride, + self.input_padding, + bias_attr=None, + padding_mode="circular", + ) + + self.Whc = nn.Conv2D( + self.hidden_channels, + self.hidden_channels, + self.hidden_kernel_size, + 1, + padding=1, + bias_attr=False, + padding_mode="circular", + ) + + self.Wxo = nn.Conv2D( + self.input_channels, + self.hidden_channels, + self.input_kernel_size, + self.input_stride, + self.input_padding, + bias_attr=None, + padding_mode="circular", + ) + + self.Who = nn.Conv2D( + self.hidden_channels, + self.hidden_channels, + self.hidden_kernel_size, + 1, + padding=1, + bias_attr=False, + padding_mode="circular", + ) + + initializer_0 = paddle.nn.initializer.Constant(0.0) + initializer_1 = paddle.nn.initializer.Constant(1.0) + + initializer_0(self.Wxi.bias) + initializer_0(self.Wxf.bias) + initializer_0(self.Wxc.bias) + initializer_1(self.Wxo.bias) + + def forward(self, x, h, c): + ci = paddle.nn.functional.sigmoid(self.Wxi(x) + self.Whi(h)) + cf = paddle.nn.functional.sigmoid(self.Wxf(x) + self.Whf(h)) + cc = cf * c + ci * paddle.tanh(self.Wxc(x) + self.Whc(h)) + co = paddle.nn.functional.sigmoid(self.Wxo(x) + self.Who(h)) + ch = co * paddle.tanh(cc) + return ch, cc + + def init_hidden_tensor(self, prev_state): + return ((prev_state[0]).cuda(), (prev_state[1]).cuda()) + + +class encoder_block(nn.Layer): + """encoder with CNN""" + + def __init__( + self, + input_channels, + hidden_channels, + input_kernel_size, + input_stride, + input_padding, + ): + super(encoder_block, self).__init__() + + self.input_channels = input_channels + self.hidden_channels = hidden_channels + self.input_kernel_size = input_kernel_size + self.input_stride = input_stride + self.input_padding = input_padding + + self.conv = utils.weight_norm( + nn.Conv2D( + self.input_channels, + self.hidden_channels, + self.input_kernel_size, + self.input_stride, + self.input_padding, + bias_attr=None, + padding_mode="circular", + ) + ) + + self.act = nn.ReLU() + + initializer_0 = paddle.nn.initializer.Constant(0.0) + initializer_0(self.conv.bias) + + def forward(self, x): + return self.act(self.conv(x)) + + +class Conv2DDerivative(nn.Layer): + 
def __init__(self, der_filter, resol, kernel_size=3, name=""): + super(Conv2DDerivative, self).__init__() + + self.resol = resol # constant in the finite difference + self.name = name + self.input_channels = 1 + self.output_channels = 1 + self.kernel_size = kernel_size + + self.padding = int((kernel_size - 1) / 2) + self.filter = nn.Conv2D( + self.input_channels, + self.output_channels, + self.kernel_size, + 1, + padding=0, + bias_attr=False, + ) + + # Fixed gradient operator + self.filter.weight = self.create_parameter( + shape=self.filter.weight.shape, + dtype=self.filter.weight.dtype, + default_initializer=paddle.nn.initializer.Assign( + paddle.to_tensor( + der_filter, dtype=paddle.get_default_dtype(), stop_gradient=True + ) + ), + ) + self.filter.weight.stop_gradient = True + + def forward(self, input): + derivative = self.filter(input) + return derivative / self.resol + + +class Conv1DDerivative(nn.Layer): + def __init__(self, der_filter, resol, kernel_size=3, name=""): + super(Conv1DDerivative, self).__init__() + + self.resol = resol # $\delta$*constant in the finite difference + self.name = name + self.input_channels = 1 + self.output_channels = 1 + self.kernel_size = kernel_size + + self.padding = int((kernel_size - 1) / 2) + self.filter = nn.Conv1D( + self.input_channels, + self.output_channels, + self.kernel_size, + 1, + padding=0, + bias_attr=False, + ) + + # Fixed gradient operator + self.filter.weight = self.create_parameter( + shape=self.filter.weight.shape, + dtype=self.filter.weight.dtype, + default_initializer=paddle.nn.initializer.Assign( + paddle.to_tensor( + der_filter, dtype=paddle.get_default_dtype(), stop_gradient=True + ) + ), + ) + self.filter.weight.stop_gradient = True + + def forward(self, input): + derivative = self.filter(input) + return derivative / self.resol + + +class loss_generator(nn.Layer): + """Loss generator for physics loss""" + + def __init__(self, dt=(10.0 / 200), dx=(20.0 / 128)): + """Construct the derivatives, X = Width, Y = Height""" + super(loss_generator, self).__init__() + + # spatial derivative operator + self.laplace = Conv2DDerivative( + der_filter=lapl_op, resol=(dx**2), kernel_size=5, name="laplace_operator" + ) + + self.dx = Conv2DDerivative( + der_filter=partial_x, resol=(dx * 1), kernel_size=5, name="dx_operator" + ) + + self.dy = Conv2DDerivative( + der_filter=partial_y, resol=(dx * 1), kernel_size=5, name="dy_operator" + ) + + # temporal derivative operator + self.dt = Conv1DDerivative( + der_filter=[[[-1, 0, 1]]], resol=(dt * 2), kernel_size=3, name="partial_t" + ) + + def get_phy_Loss(self, output): + # spatial derivatives + laplace_u = self.laplace(output[1:-1, 0:1, :, :]) # [t,c,h,w] + laplace_v = self.laplace(output[1:-1, 1:2, :, :]) + + u_x = self.dx(output[1:-1, 0:1, :, :]) + u_y = self.dy(output[1:-1, 0:1, :, :]) + v_x = self.dx(output[1:-1, 1:2, :, :]) + v_y = self.dy(output[1:-1, 1:2, :, :]) + + # temporal derivative - u + u = output[:, 0:1, 2:-2, 2:-2] + lent = u.shape[0] + lenx = u.shape[3] + leny = u.shape[2] + u_conv1d = u.transpose((2, 3, 1, 0)) # [height(Y), width(X), c, step] + u_conv1d = u_conv1d.reshape((lenx * leny, 1, lent)) + u_t = self.dt(u_conv1d) # lent-2 due to no-padding + u_t = u_t.reshape((leny, lenx, 1, lent - 2)) + u_t = u_t.transpose((3, 2, 0, 1)) # [step-2, c, height(Y), width(X)] + + # temporal derivative - v + v = output[:, 1:2, 2:-2, 2:-2] + v_conv1d = v.transpose((2, 3, 1, 0)) # [height(Y), width(X), c, step] + v_conv1d = v_conv1d.reshape((lenx * leny, 1, lent)) + v_t = self.dt(v_conv1d) # 
lent-2 due to no-padding + v_t = v_t.reshape((leny, lenx, 1, lent - 2)) + v_t = v_t.transpose((3, 2, 0, 1)) # [step-2, c, height(Y), width(X)] + + u = output[1:-1, 0:1, 2:-2, 2:-2] # [t, c, height(Y), width(X)] + v = output[1:-1, 1:2, 2:-2, 2:-2] # [t, c, height(Y), width(X)] + + assert laplace_u.shape == u_t.shape + assert u_t.shape == v_t.shape + assert laplace_u.shape == u.shape + assert laplace_v.shape == v.shape + + R = 200.0 + + # 2D burgers eqn + f_u = u_t + u * u_x + v * u_y - (1 / R) * laplace_u + f_v = v_t + u * v_x + v * v_y - (1 / R) * laplace_v + + return f_u, f_v
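+
+
+# For reference, the residuals above correspond to the 2D Burgers system with R = 200
+# that the training loss drives toward zero:
+#   u_t + u * u_x + v * u_y - (1 / R) * (u_xx + u_yy) = 0
+#   v_t + u * v_x + v * v_y - (1 / R) * (v_xx + v_yy) = 0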