diff --git a/src/jobflow/core/flow.py b/src/jobflow/core/flow.py index 0cffcf80..c9056746 100644 --- a/src/jobflow/core/flow.py +++ b/src/jobflow/core/flow.py @@ -2,9 +2,9 @@ from __future__ import annotations -import copy import logging import warnings +from copy import deepcopy from typing import TYPE_CHECKING, Sequence from monty.json import MSONable @@ -182,7 +182,7 @@ def __add__(self, other: Job | Flow | Sequence[Flow | Job]) -> Flow: """Add a job or subflow to the flow.""" if not isinstance(other, (Flow, jobflow.Job, tuple, list)): return NotImplemented - new_flow = self.__deepcopy__() + new_flow = deepcopy(self) new_flow.add_jobs(other) return new_flow @@ -190,18 +190,18 @@ def __sub__(self, other: Flow | Job) -> Flow: """Remove a job or subflow from the flow.""" if other not in self.jobs: raise ValueError(f"{other!r} not found in flow") - new_flow = self.__deepcopy__() + new_flow = deepcopy(self) new_flow.jobs = tuple([job for job in new_flow.jobs if job != other]) return new_flow - def __repr__(self, level=0, index=None) -> str: + def __repr__(self, level: int = 0, prefix: str = "") -> str: """Get a string representation of the flow.""" indent = " " * level name, uuid = self.name, self.uuid - flow_index = f"{index}." if index is not None else "" + _prefix = f"{prefix}." if prefix else "" job_reprs = "\n".join( - f"{indent}{flow_index}{i}. " - f"{j.__repr__(level + 1, f'{flow_index}{i}') if isinstance(j, Flow) else j}" + f"{indent}{_prefix}{i}. " + f"{j.__repr__(level + 1, f'{_prefix}{i}') if isinstance(j, Flow) else j}" for i, j in enumerate(self.jobs, 1) ) return f"Flow({name=}, {uuid=})\n{job_reprs}" @@ -216,22 +216,6 @@ def __hash__(self) -> int: """Get the hash of the flow.""" return hash(self.uuid) - def __deepcopy__(self, memo: dict[int, Any] = None) -> Flow: - """Get a deep copy of the flow. - - Shallow copy doesn't make sense; jobs aren't allowed to belong to multiple flows - """ - kwds = self.as_dict() - for key in ("jobs", "@class", "@module", "@version"): - kwds.pop(key) - jobs = copy.deepcopy(self.jobs, memo) - new_flow = Flow(jobs=[], **kwds) - # reassign host - for job in jobs: - job.hosts = [new_flow.uuid] - new_flow.jobs = jobs - return new_flow - @property def jobs(self) -> tuple[Flow | Job, ...]: """ diff --git a/tests/core/test_flow.py b/tests/core/test_flow.py index cad582e1..d34fc2ae 100644 --- a/tests/core/test_flow.py +++ b/tests/core/test_flow.py @@ -905,11 +905,6 @@ def test_flow_magic_methods(): assert flow1 != flow2 assert hash(flow1) != hash(flow2) - # test __deepcopy__ - flow_copy = flow1.__deepcopy__() - assert flow_copy == flow1 - assert id(flow_copy) != id(flow1) - # test __getitem__ with out of range index with pytest.raises(IndexError): _ = flow1[10]