【Hackathon 6th No.37】Migrate GraphCastNet code to PaddleScience #897

Merged
merged 32 commits into from
Jun 26, 2024

@MayYouBeProsperous MayYouBeProsperous commented May 14, 2024

PR types

Others

PR changes

APIs, Docs

Describe

Reproduced run results:
(screenshot)

paddle-bot bot commented May 14, 2024

Thanks for your contribution!

@MayYouBeProsperous
Contributor Author

Please create download links for the dataset: https://aistudio.baidu.com/datasetdetail/252766

@MayYouBeProsperous MayYouBeProsperous changed the title from "[WIP]【Hackathon 6th No.37】Migrate GraphCastNet code to PaddleScience" to "【Hackathon 6th No.37】Migrate GraphCastNet code to PaddleScience" on Jun 6, 2024
@lijialin03
Contributor

lijialin03 commented Jun 7, 2024

> Please create download links for the dataset: https://aistudio.baidu.com/datasetdetail/252766

The dataset download links are:
dataset.zip
dataset-step12.zip
params.zip
template_graph.zip
stats.zip
graphcast-jax2paddle.csv
jax_graphcast_small_output.npy

@lijialin03
Contributor

I see the WIP tag has been removed from the title. Is the code ready for review now?

@MayYouBeProsperous
Contributor Author

Yes, it is ready for review.
The CI failure appears to be a bracket-matching problem in a comment. How should I fix it?
(two screenshots)

@lijialin03
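Contributor

> Yes, it is ready for review. The CI failure appears to be a bracket-matching problem in a comment. How should I fix it?

OK, let's leave that for now and come back to it after the review.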

data: xarray.Dataset,
center: Optional[float] = None,
robust: bool = False,
) -> tuple[xarray.Dataset, matplotlib.colors.Normalize, str]:
Contributor

Change this to Tuple.

Contributor Author

Done.



def plot_data(
data: dict[str, xarray.Dataset],
Contributor

Change this to Dict.

Contributor Author

Done.


import datetime
import math
from typing import Optional
Contributor

Add `from typing import Tuple` and `from typing import Dict`.

Contributor Author

Done.
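
For context on the three requests above, here is a minimal sketch of the annotation style being asked for. Built-in generics such as `tuple[...]` and `dict[...]` are only subscriptable from Python 3.9 onward (PEP 585), so on older interpreters the original signatures would fail when the function is defined, whereas the `typing` aliases work on both. The function name and body are hypothetical and not taken from the PR, and the thread does not state the reviewer's actual motivation.

```python
from typing import Dict, Tuple


def value_range(values: Dict[str, float]) -> Tuple[float, float]:
    """Return the (min, max) of a mapping's values.

    Hypothetical helper: it only illustrates typing.Dict / typing.Tuple
    annotations, which are also accepted by Python versions before 3.9.
    """
    return min(values.values()), max(values.values())


if __name__ == "__main__":
    print(value_range({"a": 1.0, "b": 3.0}))
```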

robust: bool = False,
cols: int = 4,
file: str = "result.png",
) -> tuple[xarray.Dataset, matplotlib.colors.Normalize, str]:
Contributor

Same as above: change this to Tuple.

Contributor Author

Done.

grid_longitude: np.ndarray,
mesh: TriangularMesh,
radius: float,
) -> tuple[np.ndarray, np.ndarray]:
Contributor

Change this to Tuple.

Contributor Author

Done.

Contributor

This file seems to exist only to define a new data type, "GraphGridMesh", so please move the whole file into atmospheric_dataset.py.

Contributor Author

OK, I have moved it into atmospheric_dataset.py.

import paddle
import paddle.nn as nn

import ppsci.data.dataset.atmospheric_utils as atmospheric_utils
Contributor

Change this to:

from typing import TYPE_CHECKING
if TYPE_CHECKING:
    import ppsci.data.dataset.atmospheric_utils as atmospheric_utils

Then wrap every atmospheric_utils.GraphGridMesh annotation in this file in double quotes, i.e. "atmospheric_utils.GraphGridMesh".
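
A minimal, self-contained sketch of the pattern suggested here. `typing.TYPE_CHECKING` is `False` at runtime, so the guarded import is only seen by static type checkers, and the quoted annotation is stored as a plain string that is never evaluated when the function is defined. The standard-library `fractions` module stands in for `ppsci.data.dataset.atmospheric_utils`, and `describe` is a hypothetical function; neither comes from the PR.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by static type checkers only; skipped at runtime.
    # `fractions` is a stand-in for ppsci.data.dataset.atmospheric_utils.
    import fractions


def describe(value: "fractions.Fraction") -> str:
    """Format a fraction; the quoted annotation needs no runtime import."""
    return f"{value.numerator}/{value.denominator}"


if __name__ == "__main__":
    from fractions import Fraction

    print(describe(Fraction(1, 2)))
```

This keeps the type hints while avoiding a runtime import of the annotated module, which is a common way to break import cycles or avoid loading heavy dependencies.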

Contributor Author

The contents of atmospheric_utils.py have been moved into atmospheric_dataset.py.

should_init = any(var is None for var in all_input_vars)

if should_init:
# 初始化构建
Contributor

Delete the Chinese comment.

Contributor Author

Done.

self._mesh2grid_graph_structure = self._init_mesh2grid_graph()
else:
# 直接构建图数据
# 图结构信息
Contributor

Same as above: delete the Chinese comments.

Contributor Author

Done.

self.grid2mesh_num_edges = grid2mesh_num_edges
self.mesh2grid_num_edges = mesh2grid_num_edges

# 图特征信息
Contributor

Same as above.

Contributor Author

Done.

Contributor

@lijialin03 lijialin03 left a comment

Calling compare in jointContribution/graphcast/run with the run results as input gives the following alignment results:

10m_u_component_of_wind diff is -2.356016193110792e-05
10m_v_component_of_wind diff is 7.523250922509292e-06
2m_temperature diff is 2.4777781459127155e-05
geopotential diff is -0.000539646573504715
mean_sea_level_pressure diff is -0.0005250472340893714
specific_humidity diff is 6.035347363899133e-10
temperature diff is 8.158555083697042e-07
total_precipitation_6hr diff is 1.8224237754237203e-09
u_component_of_wind diff is 1.7749217200032017e-06
v_component_of_wind diff is 7.323719328271271e-06
vertical_velocity diff is -2.375314757111974e-07
All diff is -8.92277301779838e-05

Based on #699, I consider the results aligned.
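
As a rough illustration of how a per-variable report like the one above could be produced, the sketch below computes the mean signed difference between a reference output and a candidate output for each variable. It is not the `compare` function from jointContribution/graphcast/run, whose implementation is not shown in this thread; the function name, the use of flat Python sequences instead of arrays, and the sample values are all assumptions.

```python
from statistics import fmean
from typing import Dict, Sequence


def mean_signed_diff(
    reference: Dict[str, Sequence[float]],
    candidate: Dict[str, Sequence[float]],
) -> Dict[str, float]:
    """Return the mean of (candidate - reference) for every variable."""
    report = {}
    for name in sorted(reference):
        ref, cand = reference[name], candidate[name]
        if len(ref) != len(cand):
            raise ValueError(f"{name}: length mismatch")
        # Signed mean: positive and negative errors can cancel out.
        report[name] = fmean(c - r for r, c in zip(ref, cand))
    return report


if __name__ == "__main__":
    # Made-up sample values, for illustration only.
    ref = {"temperature": [280.0, 281.0], "specific_humidity": [0.01, 0.02]}
    out = {"temperature": [280.5, 281.5], "specific_humidity": [0.01, 0.02]}
    for name, diff in mean_signed_diff(ref, out).items():
        print(f"{name} diff is {diff}")
```

Because a signed mean lets errors of opposite sign cancel, a mean absolute difference would be a stricter check if one is needed.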

Collaborator

@HydrogenSulfate HydrogenSulfate left a comment

Thanks for submitting the PR. Once the code changes are finished, please also update the line ranges that the code blocks in the documentation point to.

Comment on lines 3 to 5
=== "模型评估命令"
``` sh

Collaborator

Suggested change
=== "模型评估命令"
``` sh
=== "模型评估命令"
``` sh

Contributor Author

Done.

@@ -0,0 +1,63 @@
hydra:
Collaborator

Suggested change
hydra:
defaults:
  - ppsci_default
  - TRAIN: train_default
  - TRAIN/ema: ema_default
  - TRAIN/swa: swa_default
  - EVAL: eval_default
  - INFER: infer_default
  - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default
  - _self_
hydra:

Contributor Author

Done.

Comment on lines 8 to 16
config:
  override_dirname:
    exclude_keys:
      - TRAIN.checkpoint_path
      - TRAIN.pretrained_model_path
      - EVAL.pretrained_model_path
      - mode
      - output_dir
      - log_freq
Collaborator

Suggested change
config:
  override_dirname:
    exclude_keys:
      - TRAIN.checkpoint_path
      - TRAIN.pretrained_model_path
      - EVAL.pretrained_model_path
      - mode
      - output_dir
      - log_freq

Contributor Author

Done.


EVAL:
  batch_size: 1
  pretrained_model_path: "data/params/GraphCast_small---ERA5-1979-2015---resolution-1.0---pressure-levels-13---mesh-2to5---precipitation-input-and-output.pdparams"
Collaborator

  1. Change the default value of pretrained_model_path to null; model evaluation should require the pretrained model to be specified manually.

Contributor Author

Done.

unzip -q stats.zip -d data/
unzip -q template_graph.zip -d data/

python graphcast.py
Collaborator

Model evaluation requires the pretrained model to be specified manually, so the command should look something like:
python graphcast.py mode=eval EVAL.pretrained_model_path=xxxx

Contributor Author

Done.
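
With the default set to null, a script can fail early when the path is missing instead of silently evaluating an uninitialized model. The guard below is a hypothetical sketch of that idea; `resolve_pretrained_path` is not a function from the PR or from PaddleScience, and whether graphcast.py performs such a check is not shown in this thread.

```python
from typing import Optional


def resolve_pretrained_path(path: Optional[str]) -> str:
    """Hypothetical guard: require an explicit checkpoint path for evaluation."""
    if not path:
        # A null YAML default arrives here as None.
        raise ValueError(
            "EVAL.pretrained_model_path is not set; pass it on the command "
            "line, e.g. EVAL.pretrained_model_path=<path>"
        )
    return path


if __name__ == "__main__":
    print(resolve_pretrained_path("model.pdparams"))
```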

def stacked_to_dataset(
stacked_array: "xarray.Variable",
template_dataset: "xarray.Dataset",
preserved_dims: typing.Tuple[str, ...] = ("batch", "lat", "lon"),
Collaborator

I suggest changing typing.XXX to plain XXX.

Contributor Author

OK, changed.

faces: np.ndarray


def merge_meshes(mesh_list: typing.Sequence[TriangularMesh]) -> TriangularMesh:
Collaborator

Same as above.

Contributor Author

Done.


def get_hierarchy_of_triangular_meshes_for_sphere(
splits: int,
) -> typing.List[TriangularMesh]:
Collaborator

Same as above.

Contributor Author

Done.

)


class _ChildVerticesBuilder(object):
Collaborator

Python 3 does not need to inherit from object explicitly.

Suggested change
class _ChildVerticesBuilder(object):
class _ChildVerticesBuilder:

Contributor Author

Done.
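
A short check of the point made in this thread: in Python 3 every class is a new-style class, so the two spellings below produce classes with the same bases. The class names are hypothetical and chosen only for this illustration.

```python
class WithExplicitObject(object):
    """Python 2 era spelling, still valid but redundant in Python 3."""


class WithoutBase:
    """Preferred Python 3 spelling; object is the implicit base."""


if __name__ == "__main__":
    # Both classes list `object` as their only direct base.
    print(WithExplicitObject.__bases__ == WithoutBase.__bases__)
```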

Collaborator

No line break is needed in places like this. The same applies to the other files: please do not add line breaks.

Contributor Author

This change actually removes an extra line break.

Collaborator

@HydrogenSulfate HydrogenSulfate left a comment

LGTM

Collaborator

@HydrogenSulfate HydrogenSulfate left a comment

LGTM

@HydrogenSulfate HydrogenSulfate merged commit 2eb2639 into PaddlePaddle:develop Jun 26, 2024
2 of 3 checks passed
huohuohuohuohuo123 pushed a commit to huohuohuohuohuo123/PaddleScience that referenced this pull request Aug 12, 2024
* add GraphGridMeshDataSet

* add graphcast model and example

* refine code and visualization

* add license

* fix bugs

* fix ci errors

* fix ci errors

* fix ci errors

* fix ci errors

* fix ci errors

* add docs

* resolve conflicts

* resolve conflicts

* resolve conflicts

* resolve conflicts

* refine docs

* fix comments

* fix comments

* fix comments

* fix

* delete atmospheric_utils.py

* fix ci errors

* fix

* fix comments

* fix

* add transform in graphcast

* fix model bugs

* fix

* fix docs