Fix Celery executor getting stuck randomly because of reset_signals in multiprocessing (#15989)

Fixes #15938

multiprocessing.Pool is known to frequently become stuck, which causes celery_executor to hang randomly. This happens at least on Debian and Ubuntu with Python 3.8.7 and Python 3.8.10. The issue is reproducible by running test_send_tasks_to_celery_hang (added in this PR) several times, with the db backend set to something other than SQLite, because SQLite disables some parallelization.
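For illustration, here is a minimal standalone sketch of the failing pattern, condensed from the test added in this PR (the names and structure here are ours, not the commit's literal code): the parent registers scheduler-style signal handlers, then repeatedly builds a fork-based multiprocessing.Pool whose initializer resets those handlers in each worker.

```python
import signal
import sys
from multiprocessing import Pool


def _exit_gracefully(signum, _frame):
    sys.exit(signum)


def _reset_signals():
    # The old pool initializer: restore default handlers in each worker.
    signal.signal(signal.SIGINT, signal.SIG_DFL)
    signal.signal(signal.SIGTERM, signal.SIG_DFL)
    signal.signal(signal.SIGUSR2, signal.SIG_DFL)


def _work(i):
    return i


if __name__ == '__main__':
    # Register handlers the way SchedulerJob does.
    signal.signal(signal.SIGINT, _exit_gracefully)
    signal.signal(signal.SIGTERM, _exit_gracefully)
    signal.signal(signal.SIGUSR2, _exit_gracefully)

    for _ in range(500):
        # On Linux, this loop can intermittently deadlock inside Pool.
        with Pool(processes=4, initializer=_reset_signals) as pool:
            pool.map(_work, range(26))
```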

The issue goes away once we switch to concurrent.futures.ProcessPoolExecutor. In Python 3.6 and earlier, ProcessPoolExecutor has no initializer argument. Fortunately, one is not needed here: reset_signals can be dropped entirely, because the signal handlers now check whether the current process is the parent.
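A hang-free counterpart sketching the new approach (again illustrative, not the PR's literal code): ProcessPoolExecutor with no initializer, and map's lazy result iterator materialized with list().

```python
import signal
import sys
from concurrent.futures import ProcessPoolExecutor


def _exit_gracefully(signum, _frame):
    sys.exit(signum)


def _work(i):
    return i


if __name__ == '__main__':
    signal.signal(signal.SIGINT, _exit_gracefully)
    signal.signal(signal.SIGTERM, _exit_gracefully)
    signal.signal(signal.SIGUSR2, _exit_gracefully)

    for _ in range(500):
        # No initializer needed. In Airflow itself the inherited scheduler
        # handlers become harmless because they now check _is_parent_process()
        # (see the scheduler_job.py diff below).
        with ProcessPoolExecutor(max_workers=4) as pool:
            assert list(pool.map(_work, range(26))) == list(range(26))
```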
yuqian90 authored May 29, 2021
1 parent 2de0692 commit f75dd7a
Showing 4 changed files with 78 additions and 16 deletions.
24 changes: 8 additions & 16 deletions airflow/executors/celery_executor.py
@@ -30,7 +30,8 @@
 import time
 import traceback
 from collections import OrderedDict
-from multiprocessing import Pool, cpu_count
+from concurrent.futures import ProcessPoolExecutor
+from multiprocessing import cpu_count
 from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Set, Tuple, Union

 from celery import Celery, Task, states as celery_states
@@ -318,18 +319,9 @@ def _send_tasks_to_celery(self, task_tuples_to_send: List[TaskInstanceInCelery])
         chunksize = self._num_tasks_per_send_process(len(task_tuples_to_send))
         num_processes = min(len(task_tuples_to_send), self._sync_parallelism)

-        def reset_signals():
-            # Since we are run from inside the SchedulerJob, we don't to
-            # inherit the signal handlers that we registered there.
-            import signal
-
-            signal.signal(signal.SIGINT, signal.SIG_DFL)
-            signal.signal(signal.SIGTERM, signal.SIG_DFL)
-            signal.signal(signal.SIGUSR2, signal.SIG_DFL)
-
-        with Pool(processes=num_processes, initializer=reset_signals) as send_pool:
-            key_and_async_results = send_pool.map(
-                send_task_to_executor, task_tuples_to_send, chunksize=chunksize
+        with ProcessPoolExecutor(max_workers=num_processes) as send_pool:
+            key_and_async_results = list(
+                send_pool.map(send_task_to_executor, task_tuples_to_send, chunksize=chunksize)
             )
         return key_and_async_results
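One behavioral difference worth noting for this hunk: multiprocessing.Pool.map returns a fully built list, while the executor's map returns a lazy iterator, hence the new list() wrapper. A small illustration (hypothetical helper, not from the commit):

```python
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import Pool


def square(x):
    return x * x


if __name__ == '__main__':
    with Pool(processes=2) as p:
        eager = p.map(square, range(4))    # already a list
    with ProcessPoolExecutor(max_workers=2) as ex:
        lazy = ex.map(square, range(4))    # a lazy iterator
        materialized = list(lazy)          # consume while the pool is open
    assert eager == materialized == [0, 1, 4, 9]
```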

@@ -592,11 +584,11 @@ def _prepare_state_and_info_by_task_dict(
     def _get_many_using_multiprocessing(self, async_results) -> Mapping[str, EventBufferValueType]:
         num_process = min(len(async_results), self._sync_parallelism)

-        with Pool(processes=num_process) as sync_pool:
+        with ProcessPoolExecutor(max_workers=num_process) as sync_pool:
             chunksize = max(1, math.floor(math.ceil(1.0 * len(async_results) / self._sync_parallelism)))

-            task_id_to_states_and_info = sync_pool.map(
-                fetch_celery_task_state, async_results, chunksize=chunksize
+            task_id_to_states_and_info = list(
+                sync_pool.map(fetch_celery_task_state, async_results, chunksize=chunksize)
             )

             states_and_info_by_task_id: MutableMapping[str, EventBufferValueType] = {}
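An aside on the chunksize expression kept in this hunk: in Python 3, math.ceil already returns an int, so the surrounding math.floor is a no-op. A worked sketch of the arithmetic (the helper name is hypothetical):

```python
import math


def _chunksize(num_items, sync_parallelism):
    # Mirrors the expression in _get_many_using_multiprocessing; the
    # math.floor is redundant since math.ceil returns an int in Python 3.
    return max(1, math.floor(math.ceil(1.0 * num_items / sync_parallelism)))


assert _chunksize(10, 4) == 3  # 10 async results over 4 workers -> chunks of 3
assert _chunksize(3, 16) == 1  # never drops below one item per chunk
```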
16 changes: 16 additions & 0 deletions airflow/jobs/scheduler_job.py
@@ -671,6 +671,14 @@ def process_file(
         return len(dagbag.dags), len(dagbag.import_errors)


+def _is_parent_process():
+    """
+    Returns True if the current process is the parent process. False if the current process is a child
+    process started by multiprocessing.
+    """
+    return multiprocessing.current_process().name == 'MainProcess'
+
+
 class SchedulerJob(BaseJob):  # pylint: disable=too-many-instance-attributes
     """
     This SchedulerJob runs for a specific time interval and schedules the jobs
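The new helper leans on the fact that multiprocessing names the original interpreter process 'MainProcess', while children get names like 'Process-1' (or 'ForkProcess-1' for executor workers). A quick demonstration, not from the commit:

```python
import multiprocessing


def report():
    name = multiprocessing.current_process().name
    print(f"{name}: is parent = {name == 'MainProcess'}")


if __name__ == '__main__':
    report()                                     # MainProcess: is parent = True
    child = multiprocessing.Process(target=report)
    child.start()                                # e.g. Process-1: is parent = False
    child.join()
```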
@@ -746,12 +754,20 @@ def register_signals(self) -> None:

     def _exit_gracefully(self, signum, frame) -> None:  # pylint: disable=unused-argument
         """Helper method to clean up processor_agent to avoid leaving orphan processes."""
+        if not _is_parent_process():
+            # Only the parent process should perform the cleanup.
+            return
+
         self.log.info("Exiting gracefully upon receiving signal %s", signum)
         if self.processor_agent:
             self.processor_agent.end()
         sys.exit(os.EX_OK)

     def _debug_dump(self, signum, frame):  # pylint: disable=unused-argument
+        if not _is_parent_process():
+            # Only the parent process should perform the debug dump.
+            return
+
         try:
             sig_name = signal.Signals(signum).name  # pylint: disable=no-member
         except Exception:  # pylint: disable=broad-except
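Both early returns exist because, with the fork start method, a child begins life with the parent's Python-level signal handlers still installed; without the guard, a signal delivered to a pool worker would run the scheduler's cleanup there. A small fork-specific sketch of our own (Unix only):

```python
import multiprocessing
import signal


def _handler(signum, _frame):
    print(f"got signal {signum}")


def _child():
    # The forked child starts with the parent's SIGTERM handler installed.
    print(signal.getsignal(signal.SIGTERM) is _handler)  # prints: True


if __name__ == '__main__':
    signal.signal(signal.SIGTERM, _handler)
    p = multiprocessing.get_context('fork').Process(target=_child)
    p.start()
    p.join()
```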
2 changes: 2 additions & 0 deletions scripts/ci/docker-compose/base.yml
@@ -34,6 +34,8 @@ services:
     ports:
       - "${WEBSERVER_HOST_PORT}:8080"
       - "${FLOWER_HOST_PORT}:5555"
+    cap_add:
+      - SYS_PTRACE
 volumes:
   sqlite-db-volume:
   postgres-db-volume:
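Presumably the SYS_PTRACE capability is added so that ptrace-based tools (for example gdb, or a stack dumper such as py-spy) can attach to a hung scheduler process inside the CI container when investigating deadlocks like this one; Docker drops that capability by default.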
52 changes: 52 additions & 0 deletions tests/executors/test_celery_executor.py
@@ -18,6 +18,7 @@
 import contextlib
 import json
 import os
+import signal
 import sys
 import unittest
 from datetime import datetime, timedelta
@@ -484,3 +485,54 @@ def test_should_support_base_backend(self):
         assert [
             'DEBUG:airflow.executors.celery_executor.BulkStateFetcher:Fetched 2 state(s) for 2 task(s)'
         ] == cm.output
+
+
+class MockTask:
+    """
+    A picklable object used to mock tasks sent to Celery. Can't use the mock library
+    here because it's not picklable.
+    """
+
+    def apply_async(self, *args, **kwargs):
+        return 1
+
+
+def _exit_gracefully(signum, _):
+    print(f"{os.getpid()} Exiting gracefully upon receiving signal {signum}")
+    sys.exit(signum)
+
+
+@pytest.fixture
+def register_signals():
+    """
+    Register the same signals as scheduler does to test celery_executor to make sure it does not
+    hang.
+    """
+    orig_sigint = orig_sigterm = orig_sigusr2 = signal.SIG_DFL
+
+    orig_sigint = signal.signal(signal.SIGINT, _exit_gracefully)
+    orig_sigterm = signal.signal(signal.SIGTERM, _exit_gracefully)
+    orig_sigusr2 = signal.signal(signal.SIGUSR2, _exit_gracefully)
+
+    yield
+
+    # Restore original signal handlers after test
+    signal.signal(signal.SIGINT, orig_sigint)
+    signal.signal(signal.SIGTERM, orig_sigterm)
+    signal.signal(signal.SIGUSR2, orig_sigusr2)
+
+
+def test_send_tasks_to_celery_hang(register_signals):  # pylint: disable=unused-argument
+    """
+    Test that celery_executor does not hang after many runs.
+    """
+    executor = celery_executor.CeleryExecutor()
+
+    task = MockTask()
+    task_tuples_to_send = [(None, None, None, None, task) for _ in range(26)]
+
+    for _ in range(500):
+        # This loop can hang on Linux if celery_executor does something wrong with
+        # multiprocessing.
+        results = executor._send_tasks_to_celery(task_tuples_to_send)
+        assert results == [(None, None, 1) for _ in task_tuples_to_send]
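Per the commit message, the original hang reproduces by running this test repeatedly against a non-SQLite metadata database; something like `pytest tests/executors/test_celery_executor.py -k test_send_tasks_to_celery_hang` in a loop (invocation ours, not from the commit) should surface it with the old Pool-based code on Debian or Ubuntu.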
