hier comparae (#910)
PhilipDeegan authored Oct 18, 2024
1 parent af3f2be commit c262a75
Showing 9 changed files with 156 additions and 158 deletions.
12 changes: 10 additions & 2 deletions pyphare/pyphare/core/phare_utilities.py
@@ -128,10 +128,18 @@ def is_fp32(item):
return isinstance(item, float)


-def assert_fp_any_all_close(a, b, atol=1e-16, rtol=0, atol_fp32=None):
+def any_fp_tol(a, b, atol=1e-16, rtol=0, atol_fp32=None):
    if any([is_fp32(el) for el in [a, b]]):
        atol = atol_fp32 if atol_fp32 else atol * 1e8
-    np.testing.assert_allclose(a, b, atol=atol, rtol=rtol)
+    return dict(atol=atol, rtol=rtol)


+def fp_any_all_close(a, b, atol=1e-16, rtol=0, atol_fp32=None):
+    return np.allclose(a, b, **any_fp_tol(a, b, atol, rtol, atol_fp32))


+def assert_fp_any_all_close(a, b, atol=1e-16, rtol=0, atol_fp32=None):
+    np.testing.assert_allclose(a, b, **any_fp_tol(a, b, atol, rtol, atol_fp32))


def decode_bytes(input, errors="ignore"):
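For context, a minimal usage sketch of the tolerance helpers introduced above. The arrays and values are illustrative; the only assumption is that the module is importable as pyphare.core.phare_utilities, as the file path suggests.

import numpy as np
from pyphare.core.phare_utilities import fp_any_all_close, assert_fp_any_all_close

a = np.zeros(10, dtype=np.float64)
b = a + 1e-18  # difference below the default atol of 1e-16

ok = fp_any_all_close(a, b)    # boolean form, wraps np.allclose, never raises
assert_fp_any_all_close(a, b)  # raising form, wraps np.testing.assert_allclose
# if either argument is fp32, any_fp_tol relaxes atol (atol * 1e8 unless atol_fp32 is given)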
110 changes: 0 additions & 110 deletions pyphare/pyphare/pharein/examples/job.py

This file was deleted.

2 changes: 2 additions & 0 deletions pyphare/pyphare/pharein/init.py
@@ -6,4 +6,6 @@ def get_user_inputs(jobname):
_init_.PHARE_EXE = True
print(jobname)
jobmodule = importlib.import_module(jobname) # lgtm [py/unused-local-variable]
+    if jobmodule is None:
+        raise RuntimeError("failed to import job")
populateDict()
6 changes: 3 additions & 3 deletions pyphare/pyphare/pharein/load_balancer.py
@@ -33,7 +33,7 @@ class LoadBalancer:

def __post_init__(self):
if self.auto and self.every:
raise RuntimeError(f"LoadBalancer cannot work with both 'every' and 'auto'")
raise RuntimeError("LoadBalancer cannot work with both 'every' and 'auto'")

if self.every is None:
self.auto = True
@@ -50,8 +50,8 @@ def __post_init__(self):
if self._register:
if not gv.sim:
raise RuntimeError(
f"LoadBalancer cannot be registered as no simulation exists"
"LoadBalancer cannot be registered as no simulation exists"
)
if gv.sim.load_balancer:
raise RuntimeError(f"LoadBalancer is already registered to simulation")
raise RuntimeError("LoadBalancer is already registered to simulation")
gv.sim.load_balancer = self
3 changes: 1 addition & 2 deletions pyphare/pyphare/pharesee/hierarchy/fromh5.py
@@ -24,9 +24,8 @@
particle_files_patterns = ("domain", "patchGhost", "levelGhost")


-def get_all_available_quantities_from_h5(filepath, time=0, exclude=["tags"]):
+def get_all_available_quantities_from_h5(filepath, time=0, exclude=["tags"], hier=None):
    time = format_timestamp(time)
-    hier = None
path = Path(filepath)
for h5 in path.glob("*.h5"):
if h5.parent == path and h5.stem not in exclude:
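A sketch of how the new hier argument can be used to accumulate quantities into an existing hierarchy rather than building a fresh one on every call. The run directory path is a placeholder, and it assumes that, as before this change, the function returns the hierarchy it fills.

from pyphare.pharesee.hierarchy.fromh5 import get_all_available_quantities_from_h5

# first call scans the .h5 files and builds a hierarchy from scratch (hier defaults to None)
hier = get_all_available_quantities_from_h5("run_dir", time=0)

# a later call can extend that same hierarchy instead of starting over
hier = get_all_available_quantities_from_h5("run_dir", time=0, hier=hier)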
13 changes: 9 additions & 4 deletions pyphare/pyphare/pharesee/hierarchy/hierarchy.py
@@ -1,12 +1,13 @@
import numpy as np
import matplotlib.pyplot as plt

from .patch import Patch
from .patchlevel import PatchLevel
from ...core.box import Box
from ...core import box as boxm
from ...core.phare_utilities import refinement_ratio
from ...core.phare_utilities import listify

import numpy as np
import matplotlib.pyplot as plt
from ...core.phare_utilities import deep_copy
from ...core.phare_utilities import refinement_ratio


def format_timestamp(timestamp):
@@ -68,6 +69,10 @@ def __init__(

self.update()

+    def __deepcopy__(self, memo):
+        no_copy_keys = ["data_files"]  # do not copy these things
+        return deep_copy(self, memo, no_copy_keys)

def __getitem__(self, qty):
return self.__dict__[qty]

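The __deepcopy__ hook above lets a hierarchy be copied without duplicating its "data_files" entry (typically the open diagnostic file handles). A brief sketch, assuming hier is an existing PatchHierarchy loaded from HDF5:

from copy import deepcopy

hier_copy = deepcopy(hier)   # patch levels and patch data are duplicated
assert hier_copy is not hier
# "data_files" is deliberately excluded from the copy, so the underlying
# diagnostic files are not duplicated or re-opened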
130 changes: 100 additions & 30 deletions pyphare/pyphare/pharesee/hierarchy/hierarchy_utils.py
@@ -1,11 +1,19 @@
-from .hierarchy import PatchHierarchy
-from .patchdata import FieldData
+from dataclasses import dataclass, field
+from copy import deepcopy
+import numpy as np

+from typing import Any, List, Tuple

+from .hierarchy import PatchHierarchy, format_timestamp
+from .patchdata import FieldData, ParticleData
+from .patchlevel import PatchLevel
+from .patch import Patch
from ...core.box import Box
+from ...core.gridlayout import GridLayout
from ...core.phare_utilities import listify
from ...core.phare_utilities import refinement_ratio
+from pyphare.core import phare_utilities as phut

-import numpy as np

field_qties = {
"EM_B_x": "Bx",
@@ -298,7 +306,7 @@ def overlap_mask_2d(x, y, dl, level, qty):
return is_overlaped


-def flat_finest_field(hierarchy, qty, time=None):
+def flat_finest_field(hierarchy, qty, time=None, neghosts=1):
"""
returns 2 flattened arrays containing the data (with shape [Npoints])
and the coordinates (with shape [Npoints, Ndim]) for the given
@@ -311,7 +319,7 @@ def flat_finest_field(hierarchy, qty, time=None):
dim = hierarchy.ndim

if dim == 1:
-        return flat_finest_field_1d(hierarchy, qty, time)
+        return flat_finest_field_1d(hierarchy, qty, time, neghosts)
elif dim == 2:
return flat_finest_field_2d(hierarchy, qty, time)
elif dim == 3:
@@ -321,7 +329,7 @@
raise ValueError("the dim of a hierarchy should be 1, 2 or 3")


-def flat_finest_field_1d(hierarchy, qty, time=None):
+def flat_finest_field_1d(hierarchy, qty, time=None, neghosts=1):
lvl = hierarchy.levels(time)

for ilvl in range(hierarchy.finest_level(time) + 1)[::-1]:
@@ -333,7 +341,7 @@ def flat_finest_field_1d(hierarchy, qty, time=None):
# all but 1 ghost nodes are removed in order to limit
# the overlapping, but to keep enough point to avoid
# any extrapolation for the interpolator
-            needed_points = pdata.ghosts_nbr - 1
+            needed_points = pdata.ghosts_nbr - neghosts

# data = pdata.dataset[patch.box] # TODO : once PR 551 will be merged...
data = pdata.dataset[needed_points[0] : -needed_points[0]]
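A usage sketch of the new neghosts parameter, which sets how many ghost nodes per patch side are kept in the flattened result; the previous behaviour (keep one) stays the default, and only the 1d path forwards it in this change. The hierarchy and quantity name below are placeholders.

# default: keep one ghost node per side, as before
data, coords = flat_finest_field(hier, "Bx", time=0.0)

# strip every ghost node instead
data, coords = flat_finest_field(hier, "Bx", time=0.0, neghosts=0)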
@@ -552,34 +560,55 @@ def _compute_scalardiv(patch_datas, **kwargs):
return tuple(pd_attrs)


-from dataclasses import dataclass


@dataclass
class EqualityReport:
-    ok: bool
-    reason: str
+    failed: List[Tuple[str, Any, Any]] = field(default_factory=lambda: [])

def __bool__(self):
-        return self.ok
+        return not self.failed

+    def __repr__(self):
+        for msg, ref, cmp in self:
+            print(msg)
+            try:
+                if type(ref) is FieldData:
+                    phut.assert_fp_any_all_close(ref[:], cmp[:], atol=1e-16)
+            except AssertionError as e:
+                print(e)
+        return self.failed[0][0]

+    def __call__(self, reason, ref=None, cmp=None):
+        self.failed.append((reason, ref, cmp))
+        return self

+    def __getitem__(self, idx):
+        return (self.failed[idx][1], self.failed[idx][2])

+    def __iter__(self):
+        return self.failed.__iter__()

+    def __reversed__(self):
+        return reversed(self.failed)


-def hierarchy_compare(this, that):
+def hierarchy_compare(this, that, atol=1e-16):
+    eqr = EqualityReport()

if not isinstance(this, PatchHierarchy) or not isinstance(that, PatchHierarchy):
-        return EqualityReport(False, "class type mismatch")
+        return eqr("class type mismatch")

if this.ndim != that.ndim or this.domain_box != that.domain_box:
-        return EqualityReport(False, "dimensional mismatch")
+        return eqr("dimensional mismatch")

if this.time_hier.keys() != that.time_hier.keys():
-        return EqualityReport(False, "timesteps mismatch")
+        return eqr("timesteps mismatch")

for tidx in this.times():
patch_levels_ref = this.time_hier[tidx]
patch_levels_cmp = that.time_hier[tidx]

if patch_levels_ref.keys() != patch_levels_cmp.keys():
-            return EqualityReport(False, "levels mismatch")
+            return eqr("levels mismatch")

for level_idx in patch_levels_cmp.keys():
patch_level_ref = patch_levels_ref[level_idx]
@@ -590,21 +619,62 @@ def hierarchy_compare(this, that):
patch_cmp = patch_level_cmp.patches[patch_idx]

if patch_ref.patch_datas.keys() != patch_cmp.patch_datas.keys():
-                    print(list(patch_ref.patch_datas.keys()))
-                    print(list(patch_cmp.patch_datas.keys()))
-                    return EqualityReport(False, "data keys mismatch")
+                    return eqr("data keys mismatch")

for patch_data_key in patch_ref.patch_datas.keys():
patch_data_ref = patch_ref.patch_datas[patch_data_key]
patch_data_cmp = patch_cmp.patch_datas[patch_data_key]

-                    if patch_data_cmp != patch_data_ref:
-                        return EqualityReport(
-                            False,
-                            "data mismatch: "
-                            + type(patch_data_cmp).__name__
-                            + " "
-                            + type(patch_data_ref).__name__,
-                        )
+                    if not patch_data_cmp.compare(patch_data_ref, atol=atol):
+                        msg = f"data mismatch: {type(patch_data_ref).__name__} {patch_data_key}"
+                        eqr(msg, patch_data_cmp, patch_data_ref)

if not eqr:
return eqr

return eqr


+def single_patch_for_LO(hier, qties=None, skip=None):
+    def _skip(qty):
+        return (qties is not None and qty not in qties) or (
+            skip is not None and qty in skip
+        )

-    return EqualityReport(True, "OK")
+    cier = deepcopy(hier)
+    sim = hier.sim
+    layout = GridLayout(
+        Box(sim.origin, sim.cells), sim.origin, sim.dl, interp_order=sim.interp_order
+    )
+    p0 = Patch(patch_datas={}, patch_id="", layout=layout)
+    for t in cier.times():
+        cier.time_hier[format_timestamp(t)] = {0: cier.level(0, t)}
+        cier.level(0, t).patches = [deepcopy(p0)]
+        l0_pds = cier.level(0, t).patches[0].patch_datas
+        for k, v in hier.level(0, t).patches[0].patch_datas.items():
+            if _skip(k):
+                continue
+            if isinstance(v, FieldData):
+                l0_pds[k] = FieldData(
+                    layout, v.field_name, None, centering=v.centerings
+                )
+                l0_pds[k].dataset = np.zeros(l0_pds[k].size)
+                patch_box = hier.level(0, t).patches[0].box
+                l0_pds[k][patch_box] = v[patch_box]

+            elif isinstance(v, ParticleData):
+                l0_pds[k] = deepcopy(v)
+            else:
+                raise RuntimeError("unexpected state")

+        for patch in hier.level(0, t).patches[1:]:
+            for k, v in patch.patch_datas.items():
+                if _skip(k):
+                    continue
+                if isinstance(v, FieldData):
+                    l0_pds[k][patch.box] = v[patch.box]
+                elif isinstance(v, ParticleData):
+                    l0_pds[k].dataset.add(v.dataset)
+                else:
+                    raise RuntimeError("unexpected state")
+    return cier
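To show how the reworked comparison is meant to be consumed: hierarchy_compare now records every mismatch in an EqualityReport instead of stopping at the first one, the report is falsy when anything failed, and iterating or indexing it exposes the offending patch data. A sketch with two hypothetical hierarchies ref and cmp:

from pyphare.pharesee.hierarchy.hierarchy_utils import hierarchy_compare

eqr = hierarchy_compare(ref, cmp, atol=1e-15)
if not eqr:                               # __bool__ means "no recorded failures"
    print(repr(eqr))                      # prints each failure message, returns the first one
    for msg, bad_ref, bad_cmp in eqr:     # failures are (reason, ref, cmp) tuples
        print(msg, type(bad_ref).__name__)
    first_ref, first_cmp = eqr[0]         # __getitem__ returns just the (ref, cmp) pair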