From 6ed91d5b055702c808df127e12d7b20ddaeaaa46 Mon Sep 17 00:00:00 2001
From: Brandon Butler
Date: Wed, 2 Nov 2022 11:34:33 -0400
Subject: [PATCH] Move aggregator to keyword argument in operation decorator
 (#681)

* refactor: Standardize future(deprecation) warnings

* feat: Add aggregator argument to FlowProject.operation

Also deprecate the use of aggregator as a decorator.

* test: Update tests with new FlowProject.operation

* doc: Update aggregator documentation

* refactor: _deprecated_warning to keyword only arguments

* doc: Fix typo in aggregator spelling

Co-authored-by: Bradley Dice

* fix: Correctly issue warnings from caller

`_deprecated_warning` now issues the warning from the caller's stack level.

Co-authored-by: Bradley Dice

* fix: Remove unnecessary typing quotations

Co-authored-by: Bradley Dice
---
 flow/aggregates.py                     | 23 ++++++---
 flow/project.py                        | 67 +++++++++++++++++++-------
 flow/util/misc.py                      | 25 ++++++++++
 tests/define_aggregate_test_project.py | 27 +++++------
 tests/test_aggregates.py               |  5 ++
 tests/test_project.py                  | 15 +++---
 6 files changed, 118 insertions(+), 44 deletions(-)

diff --git a/flow/aggregates.py b/flow/aggregates.py
index 45a5e17e8..4f47bd922 100644
--- a/flow/aggregates.py
+++ b/flow/aggregates.py
@@ -13,6 +13,7 @@
 from hashlib import md5

 from .errors import FlowProjectDefinitionError
+from .util.misc import _deprecated_warning


 def _get_unique_function_id(func):
@@ -45,7 +46,7 @@ def _get_unique_function_id(func):


 class aggregator:
-    """Decorator for operation functions that operate on aggregates.
+    """Class for generating aggregates for use in operations.

     By default, if the ``aggregator_function`` is ``None``, an aggregate
     of all jobs will be created.
@@ -57,8 +58,7 @@ class aggregator:

     .. code-block:: python

-        @aggregator()
-        @FlowProject.operation
+        @FlowProject.operation(aggregator=aggregator())
         def foo(*jobs):
             print(len(jobs))

@@ -133,8 +133,7 @@ def groupsof(cls, num=1, sort_by=None, sort_ascending=True, select=None):

         .. code-block:: python

-            @aggregator.groupsof(num=2)
-            @FlowProject.operation
+            @FlowProject.operation(aggregator=aggregator.groupsof(num=2))
             def foo(*jobs):
                 print(len(jobs))

@@ -198,8 +197,7 @@ def groupby(cls, key, default=None, sort_by=None, sort_ascending=True, select=No

         .. code-block:: python

-            @aggregator.groupby(key="key", default=-1)
-            @FlowProject.operation
+            @FlowProject.operation(aggregator=aggregator.groupby(key="key", default=-1))
             def foo(*jobs):
                 print(len(jobs))

@@ -348,6 +346,12 @@ def __call__(self, func=None):
             The function to decorate.

         """
+        _deprecated_warning(
+            deprecation="@aggregator(...)",
+            alternative="Use FlowProject.operation(aggregator=aggregator(...)) instead.",
+            deprecated_in="0.23.0",
+            removed_in="0.24.0",
+        )
         if not callable(func):
             raise FlowProjectDefinitionError(
                 "Invalid argument passed while calling the aggregate "
@@ -357,6 +361,11 @@
             raise FlowProjectDefinitionError(
                 "The with_job option cannot be used with aggregation."
             )
+        current_agg = getattr(func, "_flow_aggregate", None)
+        if current_agg is not None and current_agg != aggregator.groupsof(1):
+            raise FlowProjectDefinitionError(
+                "Cannot specify aggregates in function and decorator."
+            )
         setattr(func, "_flow_aggregate", self)
         return func

diff --git a/flow/project.py b/flow/project.py
index d63e83582..0155def1c 100644
--- a/flow/project.py
+++ b/flow/project.py
@@ -68,6 +68,7 @@
     _add_cwd_to_environment_pythonpath,
     _bidict,
     _cached_partial,
+    _deprecated_warning,
     _get_parallel_executor,
     _positive_int,
     _roundrobin,
@@ -615,10 +616,11 @@ def __call__(self, *jobs):
                 format_arguments[match.group(1)] = jobs
         formatted_cmd = cmd.format(**format_arguments)
         if formatted_cmd != cmd:
-            warnings.warn(
-                "Returning format strings in a cmd operation is deprecated as of version 0.22.0 "
-                "and will be removed in 0.23.0. Users should format the command string.",
-                FutureWarning,
+            _deprecated_warning(
+                deprecation="Returning format strings in a cmd operation",
+                alternative="Users should format the command string.",
+                deprecated_in="0.22.0",
+                removed_in="0.23.0",
             )
         return formatted_cmd

@@ -1485,6 +1487,9 @@ def mpi_hello(job):
             ``False``.
         directives : dict, optional, keyword-only
             Directives to use for resource requests and execution.
+        aggregator : flow.aggregator, optional, keyword-only
+            The aggregator to use for the operation. The default value uses an
+            aggregator of size one (i.e., individual jobs).

         Returns
         -------
@@ -1502,20 +1507,38 @@ def __call__(
         cmd=False,
         with_job=False,
         directives=None,
+        aggregator=None,
     ):
         if isinstance(func, str):
             return lambda op: self._internal_call(
-                op, name=func, cmd=cmd, with_job=with_job, directives=directives
+                op,
+                name=func,
+                cmd=cmd,
+                with_job=with_job,
+                directives=directives,
+                op_aggregator=aggregator,
             )
         if func is None:
             return lambda op: self._internal_call(
-                op, name=name, cmd=cmd, with_job=with_job, directives=directives
+                op,
+                name=name,
+                cmd=cmd,
+                with_job=with_job,
+                directives=directives,
+                op_aggregator=aggregator,
             )
         return self._internal_call(
-            func, name=name, cmd=cmd, with_job=with_job, directives=directives
+            func,
+            name=name,
+            cmd=cmd,
+            with_job=with_job,
+            directives=directives,
+            op_aggregator=aggregator,
         )

-    def _internal_call(self, func, name, *, cmd, with_job, directives):
+    def _internal_call(
+        self, func, name, *, cmd, with_job, directives, op_aggregator
+    ):
         if func in chain(
             *self._parent_class._OPERATION_PRECONDITIONS.values(),
             *self._parent_class._OPERATION_POSTCONDITIONS.values(),
@@ -1560,7 +1583,15 @@ def _internal_call(self, func, name, *, cmd, with_job, directives):
             )

         if not getattr(func, "_flow_aggregate", False):
-            func._flow_aggregate = aggregator.groupsof(1)
+            default_aggregator = aggregator.groupsof(1)
+            if op_aggregator is None:
+                op_aggregator = default_aggregator
+            elif op_aggregator != default_aggregator:
+                if getattr(func, "_flow_with_job", False):
+                    raise FlowProjectDefinitionError(
+                        "The with_job option cannot be used with aggregation."
+                    )
+            func._flow_aggregate = op_aggregator

         # Append the name and function to the class registry
         self._parent_class._OPERATION_FUNCTIONS.append((name, func))
@@ -1602,11 +1633,11 @@ def with_directives(self, directives, name=None):
             name and directives as an operation of the :class:`~.FlowProject`
             subclass.
         """
-        warnings.warn(
-            "@FlowProject.operation.with_directives has been deprecated as of 0.22.0 and "
-            "will be removed in 0.23.0. Use @FlowProject.operation(directives={...}) "
-            "instead.",
-            FutureWarning,
+        _deprecated_warning(
+            deprecation="@FlowProject.operation.with_directives",
+            alternative="Use @FlowProject.operation(directives={...}) instead.",
+            deprecated_in="0.22.0",
+            removed_in="0.23.0",
         )

         def add_operation_with_directives(function):
@@ -5153,9 +5184,11 @@ class MyProject(FlowProject):
             sys.exit(2)

         if args.show_traceback:
-            warnings.warn(
-                "--show-traceback is deprecated and to be removed in signac-flow version 0.23.",
-                FutureWarning,
+            _deprecated_warning(
+                deprecation="--show-traceback",
+                alternative="",
+                deprecated_in="0.22.0",
+                removed_in="0.23.0",
             )

         # Manually 'merge' the various global options defined for both the main parser
diff --git a/flow/util/misc.py b/flow/util/misc.py
index 252f3f718..5a92ec8b4 100644
--- a/flow/util/misc.py
+++ b/flow/util/misc.py
@@ -5,6 +5,7 @@
 import argparse
 import logging
 import os
+import warnings
 from collections.abc import MutableMapping
 from contextlib import contextmanager
 from functools import lru_cache, partial
@@ -393,6 +394,30 @@ def parallel_executor(func, iterable, **kwargs):
     return parallel_executor


+def _deprecated_warning(
+    *,
+    deprecation: str,
+    alternative: str,
+    deprecated_in: str,
+    removed_in: str,
+    category: Warning = FutureWarning,
+):
+    warnings.warn(
+        " ".join(
+            (
+                deprecation,
+                "has been deprecated as of",
+                deprecated_in,
+                "and will be removed in",
+                removed_in + ".",
+                alternative,
+            )
+        ),
+        category,
+        stacklevel=2,
+    )
+
+
 __all__ = [
     "redirect_log",
 ]
diff --git a/tests/define_aggregate_test_project.py b/tests/define_aggregate_test_project.py
index dc02f2966..563632c89 100644
--- a/tests/define_aggregate_test_project.py
+++ b/tests/define_aggregate_test_project.py
@@ -28,50 +28,49 @@ def op1(job):
     pass


-@_AggregateTestProject.operation(cmd=True)
-@aggregator.groupby("even")
+@_AggregateTestProject.operation(cmd=True, aggregator=aggregator.groupby("even"))
 def agg_op_parallel(*jobs):
     # This is used to test parallel execution of aggregation operations
     return f"echo '{len(jobs)}'"


-@_AggregateTestProject.operation
-@aggregator.groupby("even")
+@_AggregateTestProject.operation(aggregator=aggregator.groupby("even"))
 def agg_op1(*jobs):
     total = sum(job.sp.i for job in jobs)
     set_all_job_docs(jobs, "sum", total)


-@_AggregateTestProject.operation
-@aggregator.groupby(lambda job: job.sp.i % 2)
+@_AggregateTestProject.operation(
+    aggregator=aggregator.groupby(lambda job: job.sp.i % 2)
+)
 def agg_op1_different(*jobs):
     sum_other = sum(job.sp.i for job in jobs)
     set_all_job_docs(jobs, "sum_other", sum_other)


-@_AggregateTestProject.operation
-@aggregator(statepoint_i_even_odd_aggregator)
+@_AggregateTestProject.operation(
+    aggregator=aggregator(statepoint_i_even_odd_aggregator)
+)
 def agg_op1_custom(*jobs):
     sum_custom = sum(job.sp.i for job in jobs)
     set_all_job_docs(jobs, "sum_custom", sum_custom)


 @group1
-@_AggregateTestProject.operation
-@aggregator.groupsof(30)
+@_AggregateTestProject.operation(aggregator=aggregator.groupsof(30))
 def agg_op2(*jobs):
     set_all_job_docs(jobs, "op2", True)


 @group1
-@_AggregateTestProject.operation
-@aggregator()
+@_AggregateTestProject.operation(aggregator=aggregator())
 def agg_op3(*jobs):
     set_all_job_docs(jobs, "op3", True)


-@_AggregateTestProject.operation(cmd=True)
-@aggregator(sort_by="i", select=lambda job: job.sp.i <= 2)
+@_AggregateTestProject.operation(
+    cmd=True, aggregator=aggregator(sort_by="i", select=lambda job: job.sp.i <= 2)
+)
 def agg_op4(*jobs):
     return "echo '{jobs[0].sp.i} and {jobs[1].sp.i}'"

diff --git a/tests/test_aggregates.py b/tests/test_aggregates.py
index bda96071f..60911cb67 100644
--- a/tests/test_aggregates.py
+++ b/tests/test_aggregates.py
@@ -7,6 +7,8 @@
 from flow.aggregates import aggregator, get_aggregate_id
 from flow.errors import FlowProjectDefinitionError

+ignore_call_warning = pytest.mark.filterwarnings("ignore:@aggregator():FutureWarning")
+

 class AggregateProjectSetup(TestProjectBase):
     project_name = "AggregateTestProject"
@@ -311,17 +313,20 @@ def test_invalid_select(self, select):
         with pytest.raises(TypeError):
             aggregator(select=select)

+    @ignore_call_warning
     @pytest.mark.parametrize("param", ["str", 1, None])
     def test_invalid_call(self, param):
         aggregator_instance = aggregator()
         with pytest.raises(FlowProjectDefinitionError):
             aggregator_instance(param)

+    @ignore_call_warning
     def test_call_without_argument(self):
         aggregate_instance = aggregator()
         with pytest.raises(FlowProjectDefinitionError):
             aggregate_instance()

+    @ignore_call_warning
     def test_call_with_decorator(self):
         @aggregator()
         def test_function(x):
diff --git a/tests/test_project.py b/tests/test_project.py
index ae6d392a4..657626b30 100644
--- a/tests/test_project.py
+++ b/tests/test_project.py
@@ -2245,6 +2245,7 @@ def test_reregister_aggregates(self):
         # The operation agg_op2 adds another aggregate in the project.
         assert len(agg_cursor) == NUM_BEFORE_REREGISTRATION + 2

+    @pytest.mark.filterwarnings("ignore:@aggregator():FutureWarning")
     def test_aggregator_with_job(self):
         class A(FlowProject):
             pass
@@ -2256,15 +2257,17 @@ class A(FlowProject):
             def test_invalid_decorators(job):
                 pass

-    def test_with_job_aggregator(self):
-        class A(FlowProject):
-            pass
-
         with pytest.raises(FlowProjectDefinitionError):

-            @aggregator()
             @A.operation(with_job=True)
-            def test_invalid_decorators(job):
+            @aggregator()
+            def test_invalid_decorators_2(job):
+                pass
+
+        with pytest.raises(FlowProjectDefinitionError):
+
+            @A.operation(with_job=True, aggregator=aggregator())
+            def test_invalid_decorator_combination(job):
                 pass
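
For readers migrating existing workflows, the following sketch contrasts the deprecated decorator form with the keyword-argument form introduced by this patch. It is illustrative only: the `Project` subclass, the `"even"` and `"i"` state point keys, and the operation body are placeholders and are not part of the patch itself.

```python
# Illustrative migration sketch (not part of the patch). Assumes signac-flow
# with this change applied; "Project" and the state point keys are hypothetical.
from flow import FlowProject, aggregator


class Project(FlowProject):
    pass


# Deprecated style -- applying the aggregator decorator now emits a
# FutureWarning via the new _deprecated_warning helper:
#
#     @aggregator.groupby("even")
#     @Project.operation
#     def compute_sum(*jobs):
#         print(sum(job.sp.i for job in jobs))

# New style -- pass the aggregate definition as a keyword argument:
@Project.operation(aggregator=aggregator.groupby("even"))
def compute_sum(*jobs):
    print(sum(job.sp.i for job in jobs))
```

As the project.py hunk above shows, combining `with_job=True` with a non-default `aggregator` (or mixing the keyword argument with the decorator) raises a `FlowProjectDefinitionError` rather than a warning.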
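The new `_deprecated_warning` helper simply joins its keyword arguments into one message and emits it from the caller's stack frame. The snippet below is a standalone sketch of the message produced for the aggregator deprecation; the helper body is copied from the `flow/util/misc.py` hunk above so the example runs without importing a private function.

```python
# Standalone sketch of the warning text assembled by the helper added in
# flow/util/misc.py; the function body below mirrors that implementation.
import warnings


def _deprecated_warning(
    *, deprecation, alternative, deprecated_in, removed_in, category=FutureWarning
):
    warnings.warn(
        " ".join(
            (
                deprecation,
                "has been deprecated as of",
                deprecated_in,
                "and will be removed in",
                removed_in + ".",
                alternative,
            )
        ),
        category,
        stacklevel=2,  # attribute the warning to the caller, not the helper
    )


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _deprecated_warning(
        deprecation="@aggregator(...)",
        alternative="Use FlowProject.operation(aggregator=aggregator(...)) instead.",
        deprecated_in="0.23.0",
        removed_in="0.24.0",
    )

print(caught[0].message)
# @aggregator(...) has been deprecated as of 0.23.0 and will be removed in
# 0.24.0. Use FlowProject.operation(aggregator=aggregator(...)) instead.
```

This message format is also what the `filterwarnings("ignore:@aggregator():FutureWarning")` marks added to the test suite match against.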