Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 6th No.39】XPINN 迁移至 PaddleScience #849

Merged
merged 20 commits into from
May 30, 2024

Conversation

MayYouBeProsperous
Copy link
Contributor

@MayYouBeProsperous MayYouBeProsperous commented Apr 15, 2024

PR types

Others

PR changes

Others

Describe

原案例精度:0.04685
复现精度:0.04177

image

Copy link

paddle-bot bot commented Apr 15, 2024

Thanks for your contribution!


import ppsci

# For the use of the second derivative: paddle.cos, paddle.exp
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以删除注释里的paddle.exp,默认模式已经支持无限阶微分

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的,已修改

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好像没有上传配置文件?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已上传

Comment on lines 34 to 126
def xpinn_2d(
x1: paddle.Tensor,
y1: paddle.Tensor,
u1: paddle.Tensor,
x2: paddle.Tensor,
y2: paddle.Tensor,
u2: paddle.Tensor,
x3: paddle.Tensor,
y3: paddle.Tensor,
u3: paddle.Tensor,
xi1: paddle.Tensor,
yi1: paddle.Tensor,
u1i1: paddle.Tensor,
u2i1: paddle.Tensor,
xi2: paddle.Tensor,
yi2: paddle.Tensor,
u1i2: paddle.Tensor,
u3i2: paddle.Tensor,
ub: paddle.Tensor,
ub_pred: paddle.Tensor,
):
u1_x = get_grad(u1, x1)
u1_y = get_grad(u1, y1)
u1_xx = get_grad(u1_x, x1)
u1_yy = get_grad(u1_y, y1)

u2_x = get_grad(u2, x2)
u2_y = get_grad(u2, y2)
u2_xx = get_grad(u2_x, x2)
u2_yy = get_grad(u2_y, y2)

u3_x = get_grad(u3, x3)
u3_y = get_grad(u3, y3)
u3_xx = get_grad(u3_x, x3)
u3_yy = get_grad(u3_y, y3)

u1i1_x = get_grad(u1i1, xi1)
u1i1_y = get_grad(u1i1, yi1)
u1i1_xx = get_grad(u1i1_x, xi1)
u1i1_yy = get_grad(u1i1_y, yi1)

u2i1_x = get_grad(u2i1, xi1)
u2i1_y = get_grad(u2i1, yi1)
u2i1_xx = get_grad(u2i1_x, xi1)
u2i1_yy = get_grad(u2i1_y, yi1)

u1i2_x = get_grad(u1i2, xi2)
u1i2_y = get_grad(u1i2, yi2)
u1i2_xx = get_grad(u1i2_x, xi2)
u1i2_yy = get_grad(u1i2_y, yi2)

u3i2_x = get_grad(u3i2, xi2)
u3i2_y = get_grad(u3i2, yi2)
u3i2_xx = get_grad(u3i2_x, xi2)
u3i2_yy = get_grad(u3i2_y, yi2)

uavgi1 = (u1i1 + u2i1) / 2
uavgi2 = (u1i2 + u3i2) / 2

# Residuals
f1 = u1_xx + u1_yy - (paddle.exp(x1) + paddle.exp(y1))
f2 = u2_xx + u2_yy - (paddle.exp(x2) + paddle.exp(y2))
f3 = u3_xx + u3_yy - (paddle.exp(x3) + paddle.exp(y3))

# Residual continuity conditions on the interfaces
fi1 = (u1i1_xx + u1i1_yy - (paddle.exp(xi1) + paddle.exp(yi1))) - (
u2i1_xx + u2i1_yy - (paddle.exp(xi1) + paddle.exp(yi1))
)
fi2 = (u1i2_xx + u1i2_yy - (paddle.exp(xi2) + paddle.exp(yi2))) - (
u3i2_xx + u3i2_yy - (paddle.exp(xi2) + paddle.exp(yi2))
)

loss1 = (
20 * paddle.mean(paddle.square(ub - ub_pred))
+ paddle.mean(paddle.square(f1))
+ 1 * paddle.mean(paddle.square(fi1))
+ 1 * paddle.mean(paddle.square(fi2))
+ 20 * paddle.mean(paddle.square(u1i1 - uavgi1))
+ 20 * paddle.mean(paddle.square(u1i2 - uavgi2))
)

loss2 = (
paddle.mean(paddle.square(f2))
+ 1 * paddle.mean(paddle.square(fi1))
+ 20 * paddle.mean(paddle.square(u2i1 - uavgi1))
)

loss3 = (
paddle.mean(paddle.square(f3))
+ 1 * paddle.mean(paddle.square(fi2))
+ 20 * paddle.mean(paddle.square(u3i2 - uavgi2))
)
return loss1, loss2, loss3
Copy link
Collaborator

@HydrogenSulfate HydrogenSulfate Apr 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. 函数内重复代码过多,建议化简下
  2. 这样的XPINN写法感觉不太具备可扩展性,换一个案例这个函数就不能用了

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

重新实现了XPINN的loss函数

Comment on lines 129 to 156
def loss_fun(
output_dict: Dict[str, paddle.Tensor],
label_dict: Dict[str, paddle.Tensor],
*args,
):
loss1, loss2, loss3 = xpinn_2d(
output_dict["x_f1"],
output_dict["y_f1"],
output_dict["u1"],
output_dict["x_f2"],
output_dict["y_f2"],
output_dict["u2"],
output_dict["x_f3"],
output_dict["y_f3"],
output_dict["u3"],
output_dict["xi1"],
output_dict["yi1"],
output_dict["u1i1"],
output_dict["u2i1"],
output_dict["xi2"],
output_dict["yi2"],
output_dict["u1i2"],
output_dict["u3i2"],
label_dict["ub"],
output_dict["ub_pred"],
)

return loss1 + loss2 + loss3
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上,代码化简

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

)

error_u_total = paddle.linalg.norm(
paddle.squeeze(u_exact) - u_pred.flatten(), 2
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

squeeze改为flatten,两边语义一致

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Comment on lines 225 to 238
"input_keys": (
"x_f1",
"y_f1",
"x_f2",
"y_f2",
"x_f3",
"y_f3",
"xi1",
"yi1",
"xi2",
"yi2",
"xb",
"yb",
),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

分解后的input_keys可以由基本的input_keys+XPINN生成,不建议使用hard code

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

input_keys 放在了配置文件中

"xb",
"yb",
),
"label_keys": ("ub", "u_exact", "u_exact2", "u_exact3"),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

u_exact==>u_exact1?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

数据集中的名字是 u_exact

Comment on lines 352 to 363
x1=input_["x_f1"],
y1=input_["y_f1"],
x2=input_["x_f2"],
y2=input_["y_f2"],
x3=input_["x_f3"],
y3=input_["y_f3"],
xi1=input_["xi1"],
yi1=input_["yi1"],
xi2=input_["xi2"],
yi2=input_["yi2"],
xb=input_["xb"],
yb=input_["yb"],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不建议使用x_f1、xb、yi1这类命名方式,看不太出来每个字段的含义

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的,修改了命名

Comment on lines 283 to 296
"input_keys": (
"x_f1",
"y_f1",
"x_f2",
"y_f2",
"x_f3",
"y_f3",
"xi1",
"yi1",
"xi2",
"yi2",
"xb",
"yb",
),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不建议使用x_f1、xb、yi1这类命名方式,看不太出来每个字段的含义

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的,修改了命名

):
for key in in_:
in_[key] = paddle.cast(in_[key], paddle.float64)
return in_, _label, _weight
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上,下划线前后保持一致

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

if not np.issubdtype(data_dict[key].dtype, np.integer):
if (
not np.issubdtype(data_dict[key].dtype, np.integer)
and data_dict[key].dtype != np.float64
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里这么改会导致float64类型的数据不能被转换为float32,模型训练时数据与权重参数类型不匹配而报错。

Copy link
Contributor Author

@MayYouBeProsperous MayYouBeProsperous May 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

之前误以为有些算子只支持float64,已经修改。文档还在编写中。

Comment on lines 33 to 36
def get_grad(outputs: paddle.Tensor, inputs: paddle.Tensor) -> paddle.Tensor:
grad = paddle.grad(outputs, inputs, retain_graph=True, create_graph=True)
return grad[0]

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个函数看起来只被单独调用,建议放到get_second_derivatives上方,并且函数名前加上单个下划线

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

residual_func (Callable, optional): residual calculation function. Defaults to lambda x,y : x - y.
"""

def get_second_derivatives(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

函数前面加上单下划线

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

return grad[0]


def xpinn_loss(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

函数前面加上单下划线

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

]
)

# the shape of label_dict["residual_u_exact"] is [22387, 1], and be cut into [18211, 1] `_eval_by_dataset`(ppsci/solver/eval.py).
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个注释是什么意思呢?是输入的样本点数不等于输出的样本点数,导致在_eval_by_dataset被丢掉?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是的,输入的样本点数小于标签点数,数据集提供的label_dict["residual_u_exact"] 本来包含了label_dict["residual2_u_exact"] 和 label_dict["residual3_u_exact"]。

error_u_total = paddle.linalg.norm(
u_exact.flatten() - u_pred.flatten(), 2
) / paddle.linalg.norm(u_exact.flatten(), 2)
return {"total": error_u_total}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

total==>l2_error

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@MayYouBeProsperous
Copy link
Contributor Author

@HydrogenSulfate
Copy link
Collaborator

HydrogenSulfate

This comment was marked as outdated.

Copy link
Collaborator

@HydrogenSulfate HydrogenSulfate left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. cfg.TRAIN_DATA_FILE好像没有出现在cfg里面,应该是cfg.DATA_FILE?
  2. 精度本地测试没有问题,辛苦修改一下第一点

Copy link
Collaborator

@HydrogenSulfate HydrogenSulfate left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@HydrogenSulfate HydrogenSulfate merged commit 92c654d into PaddlePaddle:develop May 30, 2024
3 of 4 checks passed
huohuohuohuohuo123 pushed a commit to huohuohuohuohuo123/PaddleScience that referenced this pull request Aug 12, 2024
* add XPINNs example

* comment

* add conf file

* refine code

* fix comment

* fix data type

* fix data type

* Update examples/xpinn/plotting.py

* Update examples/xpinn/conf/xpinn.yaml

* refine code

* refine doc

* refine doc

* fix doc

* fix bugs

---------

Co-authored-by: HydrogenSulfate <490868991@qq.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants