Skip to content

Commit

Permalink
feat: make worker skip run when a task is completed
Browse files Browse the repository at this point in the history
  • Loading branch information
hiro-o918 committed Nov 8, 2024
1 parent 47b0b0b commit fd52fc6
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 6 deletions.
33 changes: 28 additions & 5 deletions gokart/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@
from luigi.task_register import TaskClassException, load_task
from luigi.task_status import RUNNING

logger = logging.getLogger('luigi-interface')
from gokart.parameter import ExplicitBoolParameter

logger = logging.getLogger(__name__)

# Prevent fork() from being called during a C-level getaddrinfo() which uses a process-global mutex,
# that may not be unlocked in child process, resulting in the process being locked indefinitely.
Expand Down Expand Up @@ -124,6 +126,7 @@ def __init__(
check_unfulfilled_deps: bool = True,
check_complete_on_run: bool = False,
task_completion_cache: Optional[Dict[str, Any]] = None,
skip_if_completed_pre_run: bool = True,
) -> None:
super(TaskProcess, self).__init__()
self.task = task
Expand All @@ -136,12 +139,19 @@ def __init__(
self.check_unfulfilled_deps = check_unfulfilled_deps
self.check_complete_on_run = check_complete_on_run
self.task_completion_cache = task_completion_cache
self.skip_if_completed_pre_run = skip_if_completed_pre_run

# completeness check using the cache
self.check_complete = functools.partial(luigi.worker.check_complete_cached, completion_cache=task_completion_cache)

def _run_task(self) -> Optional[collections.abc.Generator]:
if self.skip_if_completed_pre_run and self.check_complete(self.task):
logger.warning(f'{self.task} is skipped because the task is already completed.')
return None
return self.task.run()

def _run_get_new_deps(self) -> Optional[List[Tuple[str, str, Dict[str, str]]]]:
task_gen = self.task.run()
task_gen = self._run_task()

if not isinstance(task_gen, collections.abc.Generator):
return None
Expand Down Expand Up @@ -358,6 +368,11 @@ class gokart_worker(luigi.Config):
'dynamic dependencies but assumes that the completion status does not change '
'after it was true the first time.',
)
skip_if_completed_pre_run: bool = ExplicitBoolParameter(
default=True, description='If true, skip running tasks that are already completed just before the Task is run.'
)


class Worker:
"""
Worker object communicates with a scheduler.
Expand All @@ -369,15 +384,22 @@ class Worker:
"""

def __init__(
self, scheduler: Optional[Scheduler] = None, worker_id: Optional[str] = None, worker_processes: int = 1, assistant: bool = False, **kwargs: Any
self,
scheduler: Optional[Scheduler] = None,
worker_id: Optional[str] = None,
worker_processes: int = 1,
assistant: bool = False,
config: Optional[gokart_worker] = None,
) -> None:
if scheduler is None:
scheduler = Scheduler()

self.worker_processes = int(worker_processes)
self._worker_info = self._generate_worker_info()

self._config = luigi.worker.worker(**kwargs)
if config is None:
self._config = gokart_worker()
else:
self._config = config

worker_id = worker_id or self._config.id or self._generate_worker_id(self._worker_info)

Expand Down Expand Up @@ -886,6 +908,7 @@ def _create_task_process(self, task):
check_unfulfilled_deps=self._config.check_unfulfilled_deps,
check_complete_on_run=self._config.check_complete_on_run,
task_completion_cache=self._task_completion_cache,
skip_if_completed_pre_run=self._config.skip_if_completed_pre_run,
)

def _purge_children(self) -> None:
Expand Down
49 changes: 48 additions & 1 deletion test/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
from unittest.mock import Mock

import luigi
import luigi.worker
import pytest
from luigi import scheduler

import gokart
from gokart.worker import Worker
from gokart.worker import Worker, gokart_worker


class _DummyTask(gokart.TaskOnKart):
Expand All @@ -33,3 +34,49 @@ def test_run(self, monkeypatch: pytest.MonkeyPatch):
assert worker.add(task)
assert worker.run()
mock_run.assert_called_once()


class _DummyTaskToCheckSkip(gokart.TaskOnKart[None]):
    """Dummy task whose ``run`` delegates to ``_run`` before dumping ``None``.

    ``_run`` is a separate no-op hook so that a test can swap it for a mock
    and observe whether ``run`` was executed.
    """

    task_namespace = __name__

    def _run(self) -> None:
        """Do nothing; exists only as a patch point."""
        return None

    def run(self) -> None:
        """Call the ``_run`` hook, then dump ``None`` as the task output."""
        self._run()
        self.dump(None)

    def complete(self) -> bool:
        """Always report the task as not complete."""
        return False


class TestWorkerSkipIfCompletedPreRun:
    """Checks how ``skip_if_completed_pre_run`` affects whether a task is run."""

    @pytest.mark.parametrize(
        'skip_if_completed_pre_run,is_completed,expect_skipped',
        [
            pytest.param(True, True, True, id='skipped when completed and skip_if_completed_pre_run is True'),
            pytest.param(True, False, False, id='not skipped when not completed and skip_if_completed_pre_run is True'),
            pytest.param(False, True, False, id='not skipped when completed and skip_if_completed_pre_run is False'),
            pytest.param(False, False, False, id='not skipped when not completed and skip_if_completed_pre_run is False'),
        ],
    )
    def test_skip_task(self, monkeypatch: pytest.MonkeyPatch, skip_if_completed_pre_run: bool, is_completed: bool, expect_skipped: bool):
        """Run one task through the worker and count calls to its ``_run`` hook."""
        central_scheduler = scheduler.Scheduler()
        worker_config = gokart_worker(skip_if_completed_pre_run=skip_if_completed_pre_run)
        worker = Worker(scheduler=central_scheduler, config=worker_config)

        complete_stub = Mock(return_value=is_completed)
        # NOTE: set `complete_check_at_run=False` to avoid using deprecated skip logic.
        task = _DummyTaskToCheckSkip(complete_check_at_run=False)
        run_spy = Mock()
        monkeypatch.setattr(task, '_run', run_spy)

        with worker:
            assert worker.add(task)
            # NOTE: mock `complete` after `add` because `add` calls `complete`
            # to check if the task is already completed.
            monkeypatch.setattr(task, 'complete', complete_stub)
            assert worker.run()

        # A skipped task never reaches `_run`; otherwise it runs exactly once.
        expected_calls = 0 if expect_skipped else 1
        assert run_spy.call_count == expected_calls

0 comments on commit fd52fc6

Please sign in to comment.