Skip to content

Commit

Permalink
Move aggregator to keyword argument in operation decorator (#681)
Browse files Browse the repository at this point in the history
* refactor: Standardize future(deprecation) warnings

* feat: Add aggregator argument to FlowProject.operation

Also deprecate the use of aggregator as a decorator.

* test: Update tests with new FlowProject.operation

* doc: Update aggregator documentation

* refactor: _deprecated_warning to keyword only arguments

* doc: Fix typo in aggregator spelling

Co-authored-by: Bradley Dice <bdice@bradleydice.com>

* fix: Correctly issue warnings from caller

`_deprecated_warning` now issues warning from the caller's stack level.

Co-authored-by: Bradley Dice <bdice@bradleydice.com>

* fix: Remove unnecessary typing quotations

Co-authored-by: Bradley Dice <bdice@bradleydice.com>
  • Loading branch information
b-butler and bdice authored Nov 2, 2022
1 parent 0054e49 commit 6ed91d5
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 44 deletions.
23 changes: 16 additions & 7 deletions flow/aggregates.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from hashlib import md5

from .errors import FlowProjectDefinitionError
from .util.misc import _deprecated_warning


def _get_unique_function_id(func):
Expand Down Expand Up @@ -45,7 +46,7 @@ def _get_unique_function_id(func):


class aggregator:
"""Decorator for operation functions that operate on aggregates.
"""Class for generating aggregates for use in operations.
By default, if the ``aggregator_function`` is ``None``, an aggregate of all
jobs will be created.
Expand All @@ -57,8 +58,7 @@ class aggregator:
.. code-block:: python
@aggregator()
@FlowProject.operation
@FlowProject.operation(aggregator=aggregator())
def foo(*jobs):
print(len(jobs))
Expand Down Expand Up @@ -133,8 +133,7 @@ def groupsof(cls, num=1, sort_by=None, sort_ascending=True, select=None):
.. code-block:: python
@aggregator.groupsof(num=2)
@FlowProject.operation
@FlowProject.operation(aggregator=aggregator.groupsof(num=2))
def foo(*jobs):
print(len(jobs))
Expand Down Expand Up @@ -198,8 +197,7 @@ def groupby(cls, key, default=None, sort_by=None, sort_ascending=True, select=No
.. code-block:: python
@aggregator.groupby(key="key", default=-1)
@FlowProject.operation
@FlowProject.operation(aggregator=aggregator.groupby(key="key", default=-1))
def foo(*jobs):
print(len(jobs))
Expand Down Expand Up @@ -348,6 +346,12 @@ def __call__(self, func=None):
The function to decorate.
"""
_deprecated_warning(
deprecation="@aggregator(...)",
alternative="Use FlowProject.operation(aggregator=aggregator(...)) instead.",
deprecated_in="0.23.0",
removed_in="0.24.0",
)
if not callable(func):
raise FlowProjectDefinitionError(
"Invalid argument passed while calling the aggregate "
Expand All @@ -357,6 +361,11 @@ def __call__(self, func=None):
raise FlowProjectDefinitionError(
"The with_job option cannot be used with aggregation."
)
current_agg = getattr(func, "_flow_aggregate", None)
if current_agg is not None and current_agg != aggregator.groupsof(1):
raise FlowProjectDefinitionError(
"Cannot specify aggregates in function and decorator."
)
setattr(func, "_flow_aggregate", self)
return func

Expand Down
67 changes: 50 additions & 17 deletions flow/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
_add_cwd_to_environment_pythonpath,
_bidict,
_cached_partial,
_deprecated_warning,
_get_parallel_executor,
_positive_int,
_roundrobin,
Expand Down Expand Up @@ -615,10 +616,11 @@ def __call__(self, *jobs):
format_arguments[match.group(1)] = jobs
formatted_cmd = cmd.format(**format_arguments)
if formatted_cmd != cmd:
warnings.warn(
"Returning format strings in a cmd operation is deprecated as of version 0.22.0 "
"and will be removed in 0.23.0. Users should format the command string.",
FutureWarning,
_deprecated_warning(
deprecation="Returning format strings in a cmd operation",
alternative="Users should format the command string.",
deprecated_in="0.22.0",
removed_in="0.23.0",
)
return formatted_cmd

Expand Down Expand Up @@ -1485,6 +1487,9 @@ def mpi_hello(job):
``False``.
directives : dict, optional, keyword-only
Directives to use for resource requests and execution.
aggregator : flow.aggregator, optional, keyword-only
The aggregator to use for the operation. Default value uses aggregator of size one
(i.e. individual jobs).
Returns
-------
Expand All @@ -1502,20 +1507,38 @@ def __call__(
cmd=False,
with_job=False,
directives=None,
aggregator=None,
):
if isinstance(func, str):
return lambda op: self._internal_call(
op, name=func, cmd=cmd, with_job=with_job, directives=directives
op,
name=func,
cmd=cmd,
with_job=with_job,
directives=directives,
op_aggregator=aggregator,
)
if func is None:
return lambda op: self._internal_call(
op, name=name, cmd=cmd, with_job=with_job, directives=directives
op,
name=name,
cmd=cmd,
with_job=with_job,
directives=directives,
op_aggregator=aggregator,
)
return self._internal_call(
func, name=name, cmd=cmd, with_job=with_job, directives=directives
func,
name=name,
cmd=cmd,
with_job=with_job,
directives=directives,
op_aggregator=aggregator,
)

def _internal_call(self, func, name, *, cmd, with_job, directives):
def _internal_call(
self, func, name, *, cmd, with_job, directives, op_aggregator
):
if func in chain(
*self._parent_class._OPERATION_PRECONDITIONS.values(),
*self._parent_class._OPERATION_POSTCONDITIONS.values(),
Expand Down Expand Up @@ -1560,7 +1583,15 @@ def _internal_call(self, func, name, *, cmd, with_job, directives):
)

if not getattr(func, "_flow_aggregate", False):
func._flow_aggregate = aggregator.groupsof(1)
default_aggregator = aggregator.groupsof(1)
if op_aggregator is None:
op_aggregator = default_aggregator
elif op_aggregator != default_aggregator:
if getattr(func, "_flow_with_job", False):
raise FlowProjectDefinitionError(
"The with_job option cannot be used with aggregation."
)
func._flow_aggregate = op_aggregator

# Append the name and function to the class registry
self._parent_class._OPERATION_FUNCTIONS.append((name, func))
Expand Down Expand Up @@ -1602,11 +1633,11 @@ def with_directives(self, directives, name=None):
name and directives as an operation of the
:class:`~.FlowProject` subclass.
"""
warnings.warn(
"@FlowProject.operation.with_directives has been deprecated as of 0.22.0 and "
"will be removed in 0.23.0. Use @FlowProject.operation(directives={...}) "
"instead.",
FutureWarning,
_deprecated_warning(
deprecation="@FlowProject.operation.with_directives",
alternative="Use @FlowProject.operation(directives={...}) instead.",
deprecated_in="0.22.0",
removed_in="0.23.0",
)

def add_operation_with_directives(function):
Expand Down Expand Up @@ -5153,9 +5184,11 @@ class MyProject(FlowProject):
sys.exit(2)

if args.show_traceback:
warnings.warn(
"--show-traceback is deprecated and to be removed in signac-flow version 0.23.",
FutureWarning,
_deprecated_warning(
deprecation="--show-traceback",
alternative="",
deprecated_in="0.22.0",
removed_in="0.23.0",
)

# Manually 'merge' the various global options defined for both the main parser
Expand Down
25 changes: 25 additions & 0 deletions flow/util/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import argparse
import logging
import os
import warnings
from collections.abc import MutableMapping
from contextlib import contextmanager
from functools import lru_cache, partial
Expand Down Expand Up @@ -393,6 +394,30 @@ def parallel_executor(func, iterable, **kwargs):
return parallel_executor


def _deprecated_warning(
    *,
    deprecation: str,
    alternative: str,
    deprecated_in: str,
    removed_in: str,
    category: "type[Warning]" = FutureWarning,
):
    """Issue a standardized deprecation warning.

    The emitted message has the form ``"<deprecation> has been deprecated as
    of <deprecated_in> and will be removed in <removed_in>. <alternative>"``.
    When ``alternative`` is empty, the message ends after the removal version
    (no trailing whitespace is emitted).

    Parameters
    ----------
    deprecation : str
        Name or description of the deprecated feature.
    alternative : str
        Sentence describing what to use instead. May be an empty string if
        there is no replacement.
    deprecated_in : str
        Version in which the feature was deprecated.
    removed_in : str
        Version in which the feature will be removed.
    category : type[Warning]
        Warning class to issue. (Default value = FutureWarning)

    """
    message = (
        f"{deprecation} has been deprecated as of {deprecated_in} "
        f"and will be removed in {removed_in}."
    )
    # Only append the alternative when one is given so that the message does
    # not end with a dangling space.
    if alternative:
        message = f"{message} {alternative}"
    # stacklevel=2 attributes the warning to the code calling this helper
    # rather than to this helper itself.
    warnings.warn(message, category, stacklevel=2)


__all__ = [
"redirect_log",
]
27 changes: 13 additions & 14 deletions tests/define_aggregate_test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,50 +28,49 @@ def op1(job):
pass


@_AggregateTestProject.operation(cmd=True)
@aggregator.groupby("even")
@_AggregateTestProject.operation(cmd=True, aggregator=aggregator.groupby("even"))
def agg_op_parallel(*jobs):
# This is used to test parallel execution of aggregation operations
return f"echo '{len(jobs)}'"


@_AggregateTestProject.operation
@aggregator.groupby("even")
@_AggregateTestProject.operation(aggregator=aggregator.groupby("even"))
def agg_op1(*jobs):
total = sum(job.sp.i for job in jobs)
set_all_job_docs(jobs, "sum", total)


@_AggregateTestProject.operation
@aggregator.groupby(lambda job: job.sp.i % 2)
@_AggregateTestProject.operation(
aggregator=aggregator.groupby(lambda job: job.sp.i % 2)
)
def agg_op1_different(*jobs):
sum_other = sum(job.sp.i for job in jobs)
set_all_job_docs(jobs, "sum_other", sum_other)


@_AggregateTestProject.operation
@aggregator(statepoint_i_even_odd_aggregator)
@_AggregateTestProject.operation(
aggregator=aggregator(statepoint_i_even_odd_aggregator)
)
def agg_op1_custom(*jobs):
sum_custom = sum(job.sp.i for job in jobs)
set_all_job_docs(jobs, "sum_custom", sum_custom)


@group1
@_AggregateTestProject.operation
@aggregator.groupsof(30)
@_AggregateTestProject.operation(aggregator=aggregator.groupsof(30))
def agg_op2(*jobs):
set_all_job_docs(jobs, "op2", True)


@group1
@_AggregateTestProject.operation
@aggregator()
@_AggregateTestProject.operation(aggregator=aggregator())
def agg_op3(*jobs):
set_all_job_docs(jobs, "op3", True)


@_AggregateTestProject.operation(cmd=True)
@aggregator(sort_by="i", select=lambda job: job.sp.i <= 2)
@_AggregateTestProject.operation(
cmd=True, aggregator=aggregator(sort_by="i", select=lambda job: job.sp.i <= 2)
)
def agg_op4(*jobs):
return "echo '{jobs[0].sp.i} and {jobs[1].sp.i}'"

Expand Down
5 changes: 5 additions & 0 deletions tests/test_aggregates.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from flow.aggregates import aggregator, get_aggregate_id
from flow.errors import FlowProjectDefinitionError

ignore_call_warning = pytest.mark.filterwarnings("ignore:@aggregator():FutureWarning")


class AggregateProjectSetup(TestProjectBase):
project_name = "AggregateTestProject"
Expand Down Expand Up @@ -311,17 +313,20 @@ def test_invalid_select(self, select):
with pytest.raises(TypeError):
aggregator(select=select)

@ignore_call_warning
@pytest.mark.parametrize("param", ["str", 1, None])
def test_invalid_call(self, param):
aggregator_instance = aggregator()
with pytest.raises(FlowProjectDefinitionError):
aggregator_instance(param)

@ignore_call_warning
def test_call_without_argument(self):
aggregate_instance = aggregator()
with pytest.raises(FlowProjectDefinitionError):
aggregate_instance()

@ignore_call_warning
def test_call_with_decorator(self):
@aggregator()
def test_function(x):
Expand Down
15 changes: 9 additions & 6 deletions tests/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -2245,6 +2245,7 @@ def test_reregister_aggregates(self):
# The operation agg_op2 adds another aggregate in the project.
assert len(agg_cursor) == NUM_BEFORE_REREGISTRATION + 2

@pytest.mark.filterwarnings("ignore:@aggregator():FutureWarning")
def test_aggregator_with_job(self):
class A(FlowProject):
pass
Expand All @@ -2256,15 +2257,17 @@ class A(FlowProject):
def test_invalid_decorators(job):
pass

def test_with_job_aggregator(self):
class A(FlowProject):
pass

with pytest.raises(FlowProjectDefinitionError):

@aggregator()
@A.operation(with_job=True)
def test_invalid_decorators(job):
@aggregator()
def test_invalid_decorators_2(job):
pass

with pytest.raises(FlowProjectDefinitionError):

@A.operation(with_job=True, aggregator=aggregator())
def test_invalid_decorator_combination(job):
pass


Expand Down

0 comments on commit 6ed91d5

Please sign in to comment.