Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

convenient utils for mpi/hierarchies #915

Merged
merged 1 commit into from
Oct 24, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions pyphare/pyphare/pharesee/hierarchy/hierarchy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,35 @@
}


def nbr_ranks(hier):
"""
returns the number of mpi ranks used in the given hierarchy
"""
max_rank = 0
t0 = hier.times()[0]
for _, lvl in hier.levels(t0).items():
for patch in lvl.patches:
rank = patch.attrs["mpi_rank"]
if rank > max_rank:
max_rank = rank
return max_rank

Comment on lines +43 to +55
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Add input validation and type hints for better reliability.

The function should include input validation and type hints to improve reliability and maintainability.

Consider applying these improvements:

-def nbr_ranks(hier):
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+    from .hierarchy import PatchHierarchy
+
+def nbr_ranks(hier: 'PatchHierarchy') -> int:
     """
     returns the number of mpi ranks used in the given hierarchy
+
+    Args:
+        hier: The hierarchy to analyze
+
+    Returns:
+        int: The maximum MPI rank found in the hierarchy
+
+    Raises:
+        KeyError: If a patch is missing the 'mpi_rank' attribute
     """
+    if not hier or not hier.times():
+        return 0
+
     max_rank = 0
     t0 = hier.times()[0]
     for _, lvl in hier.levels(t0).items():
         for patch in lvl.patches:
+            if "mpi_rank" not in patch.attrs:
+                raise KeyError(f"Patch {patch.id} is missing 'mpi_rank' attribute")
             rank = patch.attrs["mpi_rank"]
             if rank > max_rank:
                 max_rank = rank
     return max_rank
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def nbr_ranks(hier):
"""
returns the number of mpi ranks used in the given hierarchy
"""
max_rank = 0
t0 = hier.times()[0]
for _, lvl in hier.levels(t0).items():
for patch in lvl.patches:
rank = patch.attrs["mpi_rank"]
if rank > max_rank:
max_rank = rank
return max_rank
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .hierarchy import PatchHierarchy
def nbr_ranks(hier: 'PatchHierarchy') -> int:
"""
returns the number of mpi ranks used in the given hierarchy
Args:
hier: The hierarchy to analyze
Returns:
int: The maximum MPI rank found in the hierarchy
Raises:
KeyError: If a patch is missing the 'mpi_rank' attribute
"""
if not hier or not hier.times():
return 0
max_rank = 0
t0 = hier.times()[0]
for _, lvl in hier.levels(t0).items():
for patch in lvl.patches:
if "mpi_rank" not in patch.attrs:
raise KeyError(f"Patch {patch.id} is missing 'mpi_rank' attribute")
rank = patch.attrs["mpi_rank"]
if rank > max_rank:
max_rank = rank
return max_rank


def patch_per_rank(hier):
"""
returns the number of patch per mpi rank for each time step
"""
nbranks = nbr_ranks(hier)
ppr = {}
for t in hier.times():
ppr[t] = {ir: 0 for ir in np.arange(nbranks + 1)}
for _, lvl in hier.levels(t).items():
for patch in lvl.patches:
ppr[t][patch.attrs["mpi_rank"]] += 1

return ppr

Comment on lines +57 to +70
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Enhance function with type hints, validation, and better documentation.

The function needs improvements in type hints, validation, and documentation. Also, the rank count initialization can be more efficient.

Consider applying these improvements:

-def patch_per_rank(hier):
+from typing import Dict, TYPE_CHECKING
+if TYPE_CHECKING:
+    from .hierarchy import PatchHierarchy
+
+def patch_per_rank(hier: 'PatchHierarchy') -> Dict[float, Dict[int, int]]:
     """
     returns the number of patch per mpi rank for each time step
+
+    Args:
+        hier: The hierarchy to analyze
+
+    Returns:
+        Dict[float, Dict[int, int]]: A dictionary mapping timestamps to rank counts
+            where each rank count is a dictionary mapping rank IDs to patch counts
+
+    Example:
+        >>> ppr = patch_per_rank(hierarchy)
+        >>> print(ppr[0.0])  # Prints patch counts per rank at t=0.0
+        {0: 2, 1: 3, 2: 2}  # Example output: rank 0 has 2 patches, rank 1 has 3, etc.
+
+    Raises:
+        KeyError: If a patch is missing the 'mpi_rank' attribute
     """
+    if not hier or not hier.times():
+        return {}
+
     nbranks = nbr_ranks(hier)
     ppr = {}
     for t in hier.times():
-        ppr[t] = {ir: 0 for ir in np.arange(nbranks + 1)}
+        ppr[t] = dict.fromkeys(range(nbranks + 1), 0)
         for _, lvl in hier.levels(t).items():
             for patch in lvl.patches:
+                if "mpi_rank" not in patch.attrs:
+                    raise KeyError(f"Patch {patch.id} is missing 'mpi_rank' attribute")
                 ppr[t][patch.attrs["mpi_rank"]] += 1

     return ppr
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def patch_per_rank(hier):
"""
returns the number of patch per mpi rank for each time step
"""
nbranks = nbr_ranks(hier)
ppr = {}
for t in hier.times():
ppr[t] = {ir: 0 for ir in np.arange(nbranks + 1)}
for _, lvl in hier.levels(t).items():
for patch in lvl.patches:
ppr[t][patch.attrs["mpi_rank"]] += 1
return ppr
from typing import Dict, TYPE_CHECKING
if TYPE_CHECKING:
from .hierarchy import PatchHierarchy
def patch_per_rank(hier: 'PatchHierarchy') -> Dict[float, Dict[int, int]]:
"""
returns the number of patch per mpi rank for each time step
Args:
hier: The hierarchy to analyze
Returns:
Dict[float, Dict[int, int]]: A dictionary mapping timestamps to rank counts
where each rank count is a dictionary mapping rank IDs to patch counts
Example:
>>> ppr = patch_per_rank(hierarchy)
>>> print(ppr[0.0]) # Prints patch counts per rank at t=0.0
{0: 2, 1: 3, 2: 2} # Example output: rank 0 has 2 patches, rank 1 has 3, etc.
Raises:
KeyError: If a patch is missing the 'mpi_rank' attribute
"""
if not hier or not hier.times():
return {}
nbranks = nbr_ranks(hier)
ppr = {}
for t in hier.times():
ppr[t] = dict.fromkeys(range(nbranks + 1), 0)
for _, lvl in hier.levels(t).items():
for patch in lvl.patches:
if "mpi_rank" not in patch.attrs:
raise KeyError(f"Patch {patch.id} is missing 'mpi_rank' attribute")
ppr[t][patch.attrs["mpi_rank"]] += 1
return ppr


def are_compatible_hierarchies(hierarchies):
ref = hierarchies[0]
same_box = True
Expand Down
Loading