From 2641aa86326d14aabf0442e84c7e7b71bed41275 Mon Sep 17 00:00:00 2001 From: Andrew Halberstadt Date: Wed, 2 Oct 2024 13:32:05 -0400 Subject: [PATCH] fix(optimize): support kwargs in 'register_strategy' decorator Some optimization strategies can take class level kwargs (such as the 'split_args' kwarg for all composite strategies). Ensure these get forwarded when using the 'register_strategy' decorator. --- src/taskgraph/optimize/base.py | 6 ++++-- test/test_optimize.py | 15 ++++++++++++++- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/src/taskgraph/optimize/base.py b/src/taskgraph/optimize/base.py index c40b9a87..c6e84a5b 100644 --- a/src/taskgraph/optimize/base.py +++ b/src/taskgraph/optimize/base.py @@ -28,10 +28,12 @@ registry = {} -def register_strategy(name, args=()): +def register_strategy(name, args=(), kwargs=None): + kwargs = kwargs or {} + def wrap(cls): if name not in registry: - registry[name] = cls(*args) + registry[name] = cls(*args, **kwargs) if not hasattr(registry[name], "description"): registry[name].description = name return cls diff --git a/test/test_optimize.py b/test/test_optimize.py index 86beb13c..0769c544 100644 --- a/test/test_optimize.py +++ b/test/test_optimize.py @@ -9,7 +9,13 @@ from taskgraph.graph import Graph from taskgraph.optimize import base as optimize_mod -from taskgraph.optimize.base import All, Any, Not, OptimizationStrategy +from taskgraph.optimize.base import ( + All, + Any, + Not, + OptimizationStrategy, + register_strategy, +) from taskgraph.task import Task from taskgraph.taskgraph import TaskGraph @@ -467,3 +473,10 @@ def test_get_subgraph_removed_dep(): graph = make_triangle() with pytest.raises(Exception): optimize_mod.get_subgraph(graph, {"t2"}, set(), {}) + + +def test_register_strategy(mocker): + m = mocker.Mock() + func = register_strategy("foo", args=("one", "two"), kwargs={"n": 1}) + func(m) + m.assert_called_with("one", "two", n=1)