Task iterators: backend performance improvements (#218)
psrok1 authored Jun 23, 2023
1 parent 4fb07c1 commit a677b49
Showing 6 changed files with 209 additions and 57 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
@@ -1,4 +1,4 @@
FROM python:3.7
FROM python:3.11

WORKDIR /app/service
COPY ./requirements.txt ./requirements.txt
124 changes: 114 additions & 10 deletions karton/core/backend.py
@@ -15,7 +15,7 @@

from .exceptions import InvalidIdentityError
from .task import Task, TaskPriority, TaskState
from .utils import chunks
from .utils import chunks, chunks_iter

KARTON_TASKS_QUEUE = "karton.tasks"
KARTON_OPERATIONS_QUEUE = "karton.operations"
@@ -377,12 +377,20 @@ def get_task(self, task_uid: str) -> Optional[Task]:
return None
return Task.unserialize(task_data, backend=self)

def get_tasks(self, task_uid_list: List[str], chunk_size: int = 1000) -> List[Task]:
def get_tasks(
self,
task_uid_list: List[str],
chunk_size: int = 1000,
parse_resources: bool = True,
) -> List[Task]:
"""
Get multiple tasks for a given identifier list
:param task_uid_list: List of task identifiers
:param chunk_size: Size of chunks passed to the Redis MGET command
:param parse_resources: If set to False, resources are not parsed.
It speeds up deserialization. Read :py:meth:`Task.unserialize`
documentation to learn more.
:return: List of task objects
"""
keys = chunks(
@@ -391,25 +399,89 @@ def get_tasks(self, task_uid_list: List[str], chunk_size: int = 1000) -> List[Task]:
)
return [
Task.unserialize(task_data, backend=self)
if parse_resources
else Task.unserialize(task_data, parse_resources=False)
for chunk in keys
for task_data in self.redis.mget(chunk)
if task_data is not None
]
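
As a rough usage sketch (not part of this changeset; it assumes a standard karton.ini is reachable and that the unrouted-tasks queue is non-empty), the batched getter with resource parsing disabled might be driven like this:

from karton.core.backend import KartonBackend, KARTON_TASKS_QUEUE
from karton.core.config import Config

# Hypothetical setup: a backend bound to the Redis instance from karton.ini.
backend = KartonBackend(config=Config())
# Fetch every unrouted task in one shot, skipping resource deserialization
# because only headers and status are needed here.
task_uids = backend.get_task_ids_from_queue(KARTON_TASKS_QUEUE)
tasks = backend.get_tasks(task_uids, chunk_size=500, parse_resources=False)
for task in tasks:
    print(task.uid, task.headers.get("type"))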

def get_all_tasks(self, chunk_size: int = 1000) -> List[Task]:
def _iter_tasks(
self,
task_keys: Iterator[str],
chunk_size: int = 1000,
parse_resources: bool = True,
) -> Iterator[Task]:
for chunk in chunks_iter(task_keys, chunk_size):
yield from (
Task.unserialize(task_data, backend=self)
if parse_resources
else Task.unserialize(task_data, parse_resources=False)
for task_data in self.redis.mget(chunk)
if task_data is not None
)

def iter_tasks(
self,
task_uid_list: Iterable[str],
chunk_size: int = 1000,
parse_resources: bool = True,
) -> Iterator[Task]:
"""
Get multiple tasks for a given identifier list as an iterator
:param task_uid_list: List of fully-qualified task identifiers
:param chunk_size: Size of chunks passed to the Redis MGET command
:param parse_resources: If set to False, resources are not parsed.
It speeds up deserialization. Read :py:meth:`Task.unserialize` documentation
to learn more.
:return: Iterator with task objects
"""
return self._iter_tasks(
map(
lambda task_uid: f"{KARTON_TASK_NAMESPACE}:{task_uid}",
task_uid_list,
),
chunk_size=chunk_size,
parse_resources=parse_resources,
)
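
The iterator variant can then be used to lazily inspect only some of those tasks; for example (a sketch continuing with the hypothetical backend and task_uids from the previous snippet, with an arbitrary filter):

from karton.core.task import TaskState

# Tasks are fetched in MGET chunks but deserialized one at a time.
for task in backend.iter_tasks(task_uids, chunk_size=1000, parse_resources=False):
    if task.status is TaskState.CRASHED:
        print(task.uid, task.error)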

def iter_all_tasks(
self, chunk_size: int = 1000, parse_resources: bool = True
) -> Iterator[Task]:
"""
Iterates over all tasks registered in Redis
:param chunk_size: Size of chunks passed to the Redis SCAN and MGET command
:param parse_resources: If set to False, resources are not parsed.
It speeds up deserialization. Read :py:meth:`Task.unserialize` documentation
to learn more.
:return: Iterator with Task objects
"""
task_keys = self.redis.scan_iter(
match=f"{KARTON_TASK_NAMESPACE}:*", count=chunk_size
)
return self._iter_tasks(
task_keys, chunk_size=chunk_size, parse_resources=parse_resources
)

def get_all_tasks(
self, chunk_size: int = 1000, parse_resources: bool = True
) -> List[Task]:
"""
Get all tasks registered in Redis
.. warning::
This method loads all tasks into memory.
It's recommended to use :py:meth:`iter_all_tasks` instead.
:param chunk_size: Size of chunks passed to the Redis MGET command
:param parse_resources: If set to False, resources are not parsed.
It speeds up deserialization. Read :py:meth:`Task.unserialize` documentation
to learn more.
:return: List with Task objects
"""
tasks = self.redis.keys(f"{KARTON_TASK_NAMESPACE}:*")
return [
Task.unserialize(task_data)
for chunk in chunks(tasks, chunk_size)
for task_data in self.redis.mget(chunk)
if task_data is not None
]
return list(
self.iter_all_tasks(chunk_size=chunk_size, parse_resources=parse_resources)
)
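
To illustrate the memory trade-off (again a sketch with the same assumed backend): the lazy scan keeps at most one chunk of task documents in flight, while the eager variant materializes everything up front.

# Lazy: SCAN + chunked MGET, bounded memory even with a large backlog.
declared_count = sum(
    1
    for task in backend.iter_all_tasks(chunk_size=1000, parse_resources=False)
    if task.status is TaskState.DECLARED
)

# Eager: still available, but builds the full list of Task objects at once.
all_tasks = backend.get_all_tasks(parse_resources=False)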

def register_task(self, task: Task, pipe: Optional[Pipeline] = None) -> None:
"""
@@ -451,6 +523,11 @@ def delete_task(self, task: Task) -> None:
"""
Remove task from Redis
.. warning::
Used internally by karton.system.
If you want to cancel a task, mark it as finished and let it be deleted
by karton.system.
:param task: Task object
"""
self.redis.delete(f"{KARTON_TASK_NAMESPACE}:{task.uid}")
@@ -459,6 +536,11 @@ def delete_tasks(self, tasks: Iterable[Task], chunk_size: int = 1000) -> None:
"""
Remove multiple tasks from Redis
.. warning::
Used internally by karton.system.
If you want to cancel a task, mark it as finished and let it be deleted
by karton.system.
:param tasks: List of Task objects
:param chunk_size: Size of chunks passed to the Redis DELETE command
"""
@@ -485,6 +567,14 @@ def get_task_ids_from_queue(self, queue: str) -> List[str]:
"""
return self.redis.lrange(queue, 0, -1)

def delete_consumer_queues(self, identity: str) -> None:
"""
Deletes consumer queues for the given identity
:param identity: Consumer identity
"""
self.redis.delete(*self.get_queue_names(identity))

def remove_task_queue(self, queue: str) -> List[Task]:
"""
Remove task queue with all contained tasks
@@ -535,6 +625,20 @@ def consume_queues(
"""
return self.redis.blpop(queues, timeout=timeout)

def increment_multiple_metrics(
self, metric: KartonMetrics, increments: Dict[str, int]
) -> None:
"""
Increments metrics for multiple identities by given values via a single pipeline
:param metric: Operation metric type
:param increments: Dictionary mapping Karton service identities to the value
to be added to the metric
"""
p = self.redis.pipeline()
for identity, increment in increments.items():
p.hincrby(metric.value, identity, increment)
p.execute()
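
A sketch of the batched metric update (the identities and counts below are made up; KartonMetrics is the existing enum from the same module):

from karton.core.backend import KartonMetrics

# One HINCRBY per identity, all flushed in a single pipelined round-trip
# instead of one Redis call per consumer.
backend.increment_multiple_metrics(
    KartonMetrics.TASK_ASSIGNED,
    {"karton.classifier": 25, "karton.mwdb-reporter": 3},
)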

def consume_queues_batch(self, queue: str, max_count: int) -> List[str]:
"""
Get a batch of items from the queue
79 changes: 58 additions & 21 deletions karton/core/task.py
@@ -22,6 +22,8 @@
if TYPE_CHECKING:
from .backend import KartonBackend # noqa

import orjson


class TaskState(enum.Enum):
DECLARED = "Declared" # Task declared in TASKS_QUEUE
@@ -62,6 +64,21 @@ class Task(object):
:param error: Traceback of an exception that happened while performing this task
"""

__slots__ = (
"uid",
"root_uid",
"orig_uid",
"parent_uid",
"error",
"headers",
"status",
"last_update",
"priority",
"payload",
"payload_persistent",
"_headers_persistent_keys",
)

def __init__(
self,
headers: Dict[str, Any],
@@ -74,6 +91,8 @@ def __init__(
orig_uid: Optional[str] = None,
uid: Optional[str] = None,
error: Optional[List[str]] = None,
_status: Optional[TaskState] = None,
_last_update: Optional[float] = None,
) -> None:
payload = payload or {}
payload_persistent = payload_persistent or {}
@@ -102,9 +121,9 @@ def __init__(
self.error = error
self.headers = {**headers, **headers_persistent}
self._headers_persistent_keys = set(headers_persistent.keys())
self.status = TaskState.DECLARED
self.status = _status or TaskState.DECLARED

self.last_update: float = time.time()
self.last_update: float = _last_update or time.time()
self.priority = priority or TaskPriority.NORMAL

self.payload = dict(payload)
@@ -114,6 +133,10 @@ def __init__(
def headers_persistent(self) -> Dict[str, Any]:
return {k: v for k, v in self.headers.items() if self.is_header_persistent(k)}

@property
def receiver(self) -> Optional[str]:
return self.headers.get("receiver")

def fork_task(self) -> "Task":
"""
Fork task to transfer single task to many queues (but use different UID).
@@ -362,13 +385,24 @@ def iterate_resources(self) -> Iterator[ResourceBase]:

@staticmethod
def unserialize(
data: Union[str, bytes], backend: Optional["KartonBackend"] = None
data: Union[str, bytes],
backend: Optional["KartonBackend"] = None,
parse_resources: bool = True,
) -> "Task":
"""
Unserialize Task instance from JSON string
:param data: JSON-serialized task
:param backend: Backend instance to be bound to RemoteResource objects
:param parse_resources: |
If set to False (default is True), the method doesn't
deserialize '__karton_resource__' entries, which speeds up the deserialization
process. This flag is used mainly for bulk task processing, e.g.
filtering based on status.
When resource deserialization is turned off, Task.unserialize will try
to use the faster third-party JSON parser (orjson) if it's installed. It's not
added as a required dependency, but it can speed things up if you need to check
the status of multiple tasks at once.
:return: Unserialized Task object
:meta private:
@@ -386,7 +420,10 @@ def unserialize_resources(value: Any) -> Any:
if not isinstance(data, str):
data = data.decode("utf8")

task_data = json.loads(data, object_hook=unserialize_resources)
if parse_resources:
task_data = json.loads(data, object_hook=unserialize_resources)
else:
task_data = orjson.loads(data)

# Compatibility with Karton <5.2.0
headers_persistent_fallback = task_data["payload_persistent"].get(
@@ -399,24 +436,24 @@ def unserialize_resources(value: Any) -> Any:
task = Task(
task_data["headers"],
headers_persistent=headers_persistent,
uid=task_data["uid"],
root_uid=task_data["root_uid"],
parent_uid=task_data["parent_uid"],
# Compatibility with <= 3.x.x (get)
orig_uid=task_data.get("orig_uid", None),
payload=task_data["payload"],
payload_persistent=task_data["payload_persistent"],
# Compatibility with <= 3.x.x (get)
error=task_data.get("error"),
# Compatibility with <= 2.x.x (get)
priority=(
TaskPriority(task_data.get("priority"))
if "priority" in task_data
else TaskPriority.NORMAL
),
_status=TaskState(task_data["status"]),
_last_update=task_data.get("last_update", None),
)
task.uid = task_data["uid"]
task.root_uid = task_data["root_uid"]
task.parent_uid = task_data["parent_uid"]
# Compatibility with <= 3.x.x (get)
task.orig_uid = task_data.get("orig_uid", None)
task.status = TaskState(task_data["status"])
# Compatibility with <= 3.x.x (get)
task.error = task_data.get("error")
# Compatibility with <= 2.x.x (get)
task.priority = (
TaskPriority(task_data.get("priority"))
if "priority" in task_data
else TaskPriority.NORMAL
)
task.last_update = task_data.get("last_update", None)
task.payload = task_data["payload"]
task.payload_persistent = task_data["payload_persistent"]
return task

def __repr__(self) -> str:
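For the task.py changes, a round-trip sketch of the fast deserialization path described in the unserialize docstring (the task contents here are arbitrary):

from karton.core import Task

task = Task(headers={"type": "sample", "kind": "raw"}, payload={"tags": ["test"]})
serialized = task.serialize()

# Fast path: '__karton_resource__' entries stay as plain dicts and the
# orjson-based loader added in this commit is used for parsing.
cheap_copy = Task.unserialize(serialized, parse_resources=False)
print(cheap_copy.uid, cheap_copy.status, cheap_copy.headers["type"])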
11 changes: 11 additions & 0 deletions karton/core/utils.py
@@ -1,4 +1,5 @@
import functools
import itertools
import signal
from contextlib import contextmanager
from typing import Any, Callable, Iterator, Sequence, Tuple, TypeVar
@@ -12,6 +13,16 @@ def chunks(seq: Sequence[T], size: int) -> Iterator[Sequence[T]]:
return (seq[pos : pos + size] for pos in range(0, len(seq), size))


def chunks_iter(seq: Iterator[T], size: int) -> Iterator[Sequence[T]]:
# We need to ensure that seq is an iterator so that this method works correctly
it = iter(seq)
while True:
elements = list(itertools.islice(it, size))
if len(elements) == 0:
return
yield elements


def recursive_iter(obj: Any) -> Iterator[Any]:
"""
Yields all values recursively from nested list/dict structures
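Finally, the new utils helper differs from chunks only in what it accepts (a quick sketch; the values are arbitrary):

from karton.core.utils import chunks, chunks_iter

# chunks() needs a sliceable sequence; chunks_iter() also works on plain
# iterators/generators, which is what the Redis SCAN cursor yields.
print(list(chunks([1, 2, 3, 4, 5], 2)))      # [[1, 2], [3, 4], [5]]
print(list(chunks_iter(iter(range(5)), 2)))  # [[0, 1], [2, 3], [4]]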
