Use ProcessPoolExecutor instead of multiprocessing.Pool
yuqian90 committed May 28, 2021
1 parent e674711 commit 51fdcff
Showing 2 changed files with 30 additions and 27 deletions.
24 changes: 8 additions & 16 deletions airflow/executors/celery_executor.py
@@ -30,7 +30,8 @@
 import time
 import traceback
 from collections import OrderedDict
-from multiprocessing import Pool, cpu_count
+from concurrent.futures import ProcessPoolExecutor
+from multiprocessing import cpu_count
 from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Set, Tuple, Union
 
 from celery import Celery, Task, states as celery_states
@@ -318,18 +319,9 @@ def _send_tasks_to_celery(self, task_tuples_to_send: List[TaskInstanceInCelery])
         chunksize = self._num_tasks_per_send_process(len(task_tuples_to_send))
         num_processes = min(len(task_tuples_to_send), self._sync_parallelism)
 
-        def reset_signals():
-            # Since we are run from inside the SchedulerJob, we don't to
-            # inherit the signal handlers that we registered there.
-            import signal
-
-            signal.signal(signal.SIGINT, signal.SIG_DFL)
-            signal.signal(signal.SIGTERM, signal.SIG_DFL)
-            signal.signal(signal.SIGUSR2, signal.SIG_DFL)
-
-        with Pool(processes=num_processes, initializer=reset_signals) as send_pool:
-            key_and_async_results = send_pool.map(
-                send_task_to_executor, task_tuples_to_send, chunksize=chunksize
+        with ProcessPoolExecutor(max_workers=num_processes) as send_pool:
+            key_and_async_results = list(
+                send_pool.map(send_task_to_executor, task_tuples_to_send, chunksize=chunksize)
             )
         return key_and_async_results

@@ -592,11 +584,11 @@ def _prepare_state_and_info_by_task_dict(
     def _get_many_using_multiprocessing(self, async_results) -> Mapping[str, EventBufferValueType]:
         num_process = min(len(async_results), self._sync_parallelism)
 
-        with Pool(processes=num_process) as sync_pool:
+        with ProcessPoolExecutor(max_workers=num_process) as sync_pool:
             chunksize = max(1, math.floor(math.ceil(1.0 * len(async_results) / self._sync_parallelism)))
 
-            task_id_to_states_and_info = sync_pool.map(
-                fetch_celery_task_state, async_results, chunksize=chunksize
+            task_id_to_states_and_info = list(
+                sync_pool.map(fetch_celery_task_state, async_results, chunksize=chunksize)
             )
 
             states_and_info_by_task_id: MutableMapping[str, EventBufferValueType] = {}
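
Note: the new list(...) wrappers above are behavioural, not cosmetic. multiprocessing.Pool.map blocks and returns a fully built list, whereas concurrent.futures.ProcessPoolExecutor.map returns a lazy iterator, so its results must be materialized to preserve the old semantics. A minimal sketch of the equivalence; square is an illustrative stand-in for send_task_to_executor and fetch_celery_task_state, not code from this commit:

    from concurrent.futures import ProcessPoolExecutor
    from multiprocessing import Pool

    def square(x):
        # Illustrative stand-in for the real worker functions.
        return x * x

    if __name__ == '__main__':
        values = [1, 2, 3, 4]

        # Pool.map blocks and returns a list directly.
        with Pool(processes=2) as pool:
            eager = pool.map(square, values, chunksize=2)

        # ProcessPoolExecutor.map returns a lazy iterator; list() restores
        # the eager, all-results-at-once behaviour the call sites expect.
        with ProcessPoolExecutor(max_workers=2) as executor:
            materialized = list(executor.map(square, values, chunksize=2))

        assert eager == materialized == [1, 4, 9, 16]
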
33 changes: 22 additions & 11 deletions airflow/jobs/scheduler_job.py
@@ -671,6 +671,14 @@ def process_file(
         return len(dagbag.dags), len(dagbag.import_errors)
 
 
+def _is_parent_process():
+    """
+    Returns True if the current process is the parent process. False if the current process is a child
+    process started by multiprocessing.
+    """
+    return multiprocessing.current_process().name == 'MainProcess'
+
+
 class SchedulerJob(BaseJob):  # pylint: disable=too-many-instance-attributes
     """
     This SchedulerJob runs for a specific time interval and schedules the jobs
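
Note: this helper relies on the interpreter's original process always being named 'MainProcess', while multiprocessing names worker processes along the lines of 'ForkProcess-1' or 'SpawnProcess-1' depending on the start method. A small sketch using only the standard library (the exact worker name below is platform-dependent):

    import multiprocessing
    from concurrent.futures import ProcessPoolExecutor

    def current_name(_):
        # Inside a pool worker this is e.g. 'ForkProcess-1', never 'MainProcess'.
        return multiprocessing.current_process().name

    if __name__ == '__main__':
        print(multiprocessing.current_process().name)        # -> MainProcess
        with ProcessPoolExecutor(max_workers=1) as executor:
            print(list(executor.map(current_name, [0])))     # -> e.g. ['ForkProcess-1']
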
@@ -746,21 +754,24 @@ def register_signals(self) -> None:
 
     def _exit_gracefully(self, signum, frame) -> None:  # pylint: disable=unused-argument
         """Helper method to clean up processor_agent to avoid leaving orphan processes."""
-        self.log.info("Exiting gracefully upon receiving signal %s", signum)
-        if self.processor_agent:
-            self.processor_agent.end()
-        sys.exit(os.EX_OK)
+        if _is_parent_process():
+            # Only the parent process should perform the cleanup.
+            self.log.info("Exiting gracefully upon receiving signal %s", signum)
+            if self.processor_agent:
+                self.processor_agent.end()
+            sys.exit(os.EX_OK)
 
     def _debug_dump(self, signum, frame):  # pylint: disable=unused-argument
-        try:
-            sig_name = signal.Signals(signum).name  # pylint: disable=no-member
-        except Exception:  # pylint: disable=broad-except
-            sig_name = str(signum)
+        if _is_parent_process():
+            try:
+                sig_name = signal.Signals(signum).name  # pylint: disable=no-member
+            except Exception:  # pylint: disable=broad-except
+                sig_name = str(signum)
 
-        self.log.info("%s\n%s received, printing debug\n%s", "-" * 80, sig_name, "-" * 80)
+            self.log.info("%s\n%s received, printing debug\n%s", "-" * 80, sig_name, "-" * 80)
 
-        self.executor.debug_dump()
-        self.log.info("-" * 80)
+            self.executor.debug_dump()
+            self.log.info("-" * 80)
 
     def is_alive(self, grace_multiplier: Optional[float] = None) -> bool:
         """
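
Note on why the handlers now guard on _is_parent_process instead of resetting signals in a pool initializer: ProcessPoolExecutor only gained an initializer argument in Python 3.7, and on fork-based platforms pool workers inherit the parent's Python-level signal handlers, so a handler registered by SchedulerJob can fire inside a worker. A minimal sketch of the guard pattern, assuming a fork start method on Linux; handle_term and poke_self are illustrative names, not Airflow code:

    import os
    import signal
    import multiprocessing
    from concurrent.futures import ProcessPoolExecutor

    def _is_parent_process():
        return multiprocessing.current_process().name == 'MainProcess'

    def handle_term(signum, frame):
        # Forked workers inherit this handler; the guard confines the
        # cleanup and exit to the parent process.
        if _is_parent_process():
            raise SystemExit(os.EX_OK)

    def poke_self(_):
        # Runs inside a worker: the inherited handler fires but does nothing.
        os.kill(os.getpid(), signal.SIGTERM)
        return 'worker survived'

    if __name__ == '__main__':
        signal.signal(signal.SIGTERM, handle_term)
        with ProcessPoolExecutor(max_workers=1) as executor:
            print(list(executor.map(poke_self, [0])))  # -> ['worker survived']
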
