-
Notifications
You must be signed in to change notification settings - Fork 171
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon 6th No.37】GraphCastNet 代码迁移至 PaddleScience #897
Conversation
Thanks for your contribution! |
麻烦制作数据集下载链接:https://aistudio.baidu.com/datasetdetail/252766 |
我看标题中的WIP去掉了,代码现在是可以review的状态嘛 |
examples/graphcast/plot.py
Outdated
data: xarray.Dataset, | ||
center: Optional[float] = None, | ||
robust: bool = False, | ||
) -> tuple[xarray.Dataset, matplotlib.colors.Normalize, str]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
改成Tuple
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
examples/graphcast/plot.py
Outdated
|
||
|
||
def plot_data( | ||
data: dict[str, xarray.Dataset], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
改成Dict
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
|
||
import datetime | ||
import math | ||
from typing import Optional |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
增加from typing import Tuple
和from typing import Dict
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
examples/graphcast/plot.py
Outdated
robust: bool = False, | ||
cols: int = 4, | ||
file: str = "result.png", | ||
) -> tuple[xarray.Dataset, matplotlib.colors.Normalize, str]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上改成Tuple
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
grid_longitude: np.ndarray, | ||
mesh: TriangularMesh, | ||
radius: float, | ||
) -> tuple[np.ndarray, np.ndarray]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
改成Tuple
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个文件似乎是为了定义一种新的数据类型“GraphGridMesh”,因此把这个文件整个挪到atmospheric_dataset.py中吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的,移到了atmospheric_dataset.py中
ppsci/arch/graphcast.py
Outdated
import paddle | ||
import paddle.nn as nn | ||
|
||
import ppsci.data.dataset.atmospheric_utils as atmospheric_utils |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这边改成:
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import ppsci.data.dataset.atmospheric_utils as atmospheric_utils
然后将这个文件中所有atmospheric_utils.GraphGridMesh加上双引号,改为“atmospheric_utils.GraphGridMesh”
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
atmospheric_utils.py 内容移到了 atmospheri_dataset.py
should_init = any(var is None for var in all_input_vars) | ||
|
||
if should_init: | ||
# 初始化构建 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
删除中文注释
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
self._mesh2grid_graph_structure = self._init_mesh2grid_graph() | ||
else: | ||
# 直接构建图数据 | ||
# 图结构信息 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上删除中文注释
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
self.grid2mesh_num_edges = grid2mesh_num_edges | ||
self.mesh2grid_num_edges = mesh2grid_num_edges | ||
|
||
# 图特征信息 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
调用jointContribution/graphcast/run中的compare
输入运行结果后得到对齐结果为:
10m_u_component_of_wind diff is -2.356016193110792e-05
10m_v_component_of_wind diff is 7.523250922509292e-06
2m_temperature diff is 2.4777781459127155e-05
geopotential diff is -0.000539646573504715
mean_sea_level_pressure diff is -0.0005250472340893714
specific_humidity diff is 6.035347363899133e-10
temperature diff is 8.158555083697042e-07
total_precipitation_6hr diff is 1.8224237754237203e-09
u_component_of_wind diff is 1.7749217200032017e-06
v_component_of_wind diff is 7.323719328271271e-06
vertical_velocity diff is -2.375314757111974e-07
All diff is -8.92277301779838e-05
参考#699,认为可以对齐
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
辛苦提交PR,代码内容修改完毕之后,可以同步更新一下文档的代码块定位位置
docs/zh/examples/graphcast.md
Outdated
=== "模型评估命令" | ||
``` sh | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
=== "模型评估命令" | |
``` sh | |
=== "模型评估命令" | |
``` sh |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
@@ -0,0 +1,63 @@ | |||
hydra: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hydra: | |
defaults: | |
- ppsci_default | |
- TRAIN: train_default | |
- TRAIN/ema: ema_default | |
- TRAIN/swa: swa_default | |
- EVAL: eval_default | |
- INFER: infer_default | |
- hydra/job/config/override_dirname/exclude_keys: exclude_keys_default | |
- _self_ | |
hydra: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
config: | ||
override_dirname: | ||
exclude_keys: | ||
- TRAIN.checkpoint_path | ||
- TRAIN.pretrained_model_path | ||
- EVAL.pretrained_model_path | ||
- mode | ||
- output_dir | ||
- log_freq |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
config: | |
override_dirname: | |
exclude_keys: | |
- TRAIN.checkpoint_path | |
- TRAIN.pretrained_model_path | |
- EVAL.pretrained_model_path | |
- mode | |
- output_dir | |
- log_freq |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
|
||
EVAL: | ||
batch_size: 1 | ||
pretrained_model_path: "data/params/GraphCast_small---ERA5-1979-2015---resolution-1.0---pressure-levels-13---mesh-2to5---precipitation-input-and-output.pdparams" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- pretrained_model_path默认值改为null,模型评估要求手动指定预训练模型
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
docs/zh/examples/graphcast.md
Outdated
unzip -q stats.zip -d data/ | ||
unzip -q template_graph.zip -d data/ | ||
|
||
python graphcast.py |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
模型评估要求手动指定预训练模型,命令应该类似:
python graphcast.py mode=eval EVAL.pretrained_model_path=xxxx
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
def stacked_to_dataset( | ||
stacked_array: "xarray.Variable", | ||
template_dataset: "xarray.Dataset", | ||
preserved_dims: typing.Tuple[str, ...] = ("batch", "lat", "lon"), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typing.XXX建议改为XXX
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的,已修改
faces: np.ndarray | ||
|
||
|
||
def merge_meshes(mesh_list: typing.Sequence[TriangularMesh]) -> TriangularMesh: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
|
||
def get_hierarchy_of_triangular_meshes_for_sphere( | ||
splits: int, | ||
) -> typing.List[TriangularMesh]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
) | ||
|
||
|
||
class _ChildVerticesBuilder(object): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
python3不需要显式继承object
class _ChildVerticesBuilder(object): | |
class _ChildVerticesBuilder: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这类位置不需要加换行,其余文件同理,不需要添加换行
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里是删除多余换行的
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
* add GraphGridMeshDataSet * add graphcast model and example * refine code and visualization * add license * fix bugs * fix ci errors * fix ci errors * fix ci errors * fix ci errors * fix ci errors * add docs * resolve conflicts * resolve conflicts * resolve conflicts * resolve conflicts * refine docs * fix comments * fix comments * fix comments * fix * delete atmospheric_utils.py * fix ci errors * fix * fix comments * fix * add transform in graphcast * fix model bugs * fix * fix docs
PR types
Others
PR changes
APIs, Docs
Describe
复现运行结果