From 64af5325390ff436f77548ad149870457a499a45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20S=C3=A1nchez-Gallego?= Date: Fri, 12 Jul 2024 14:40:38 -0700 Subject: [PATCH] Conditionally import the GatheringTaskGroup class --- src/sdsstools/utils.py | 40 ++++++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/src/sdsstools/utils.py b/src/sdsstools/utils.py index 6d5da8c..56d1ec6 100644 --- a/src/sdsstools/utils.py +++ b/src/sdsstools/utils.py @@ -11,6 +11,7 @@ import asyncio import concurrent.futures import pathlib +import sys import tempfile import time from contextlib import suppress @@ -22,7 +23,6 @@ "get_temporary_file_path", "run_in_executor", "cancel_task", - "GatheringTaskGroup", ] @@ -110,29 +110,33 @@ async def cancel_task(task: asyncio.Future | None): await task -class GatheringTaskGroup(asyncio.TaskGroup): - """An extension to ``asyncio.TaskGroup`` that keeps track of the tasks created. +if sys.version_info >= (3, 11): - Adapted from https://stackoverflow.com/questions/75204560/consuming-taskgroup-response + class GatheringTaskGroup(asyncio.TaskGroup): + """An extension to ``asyncio.TaskGroup`` that keeps track of the tasks created. - """ + Adapted from https://stackoverflow.com/questions/75204560/consuming-taskgroup-response + + """ + + def __init__(self): + super().__init__() + self.__tasks = [] - def __init__(self): - super().__init__() - self.__tasks = [] + def create_task(self, coro, *, name=None, context=None): + """Creates a task and appends it to the list of tasks.""" - def create_task(self, coro, *, name=None, context=None): - """Creates a task and appends it to the list of tasks.""" + task = super().create_task(coro, name=name, context=context) + self.__tasks.append(task) - task = super().create_task(coro, name=name, context=context) - self.__tasks.append(task) + return task - return task + def results(self): + """Returns the results of the tasks in the same order they were created.""" - def results(self): - """Returns the results of the tasks in the same order they were created.""" + if len(self._tasks) > 0: + raise RuntimeError("Not all tasks have completed yet.") - if len(self._tasks) > 0: - raise RuntimeError("Not all tasks have completed yet.") + return [task.result() for task in self.__tasks] - return [task.result() for task in self.__tasks] + __all__.append("GatheringTaskGroup")