Skip to content

Commit

Permalink
Enable strict queue tags
Browse files Browse the repository at this point in the history
  • Loading branch information
bennybp committed May 16, 2024
1 parent e51b369 commit 6567cb8
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 2 deletions.
6 changes: 4 additions & 2 deletions qcfractal/qcfractal/components/tasks/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(self, root_socket: SQLAlchemySocket):
self._logger = logging.getLogger(__name__)

self._tasks_claim_limit = root_socket.qcf_config.api_limits.manager_tasks_claim
self._strict_queue_tags = root_socket.qcf_config.strict_queue_tags

def update_finished(
self, manager_name: str, results_compressed: Dict[int, bytes], *, session: Optional[Session] = None
Expand Down Expand Up @@ -313,8 +314,9 @@ def claim_tasks(
# The sort_date usually comes from the created_on of the record, or the created_on of the record's parent service
stmt = stmt.order_by(TaskQueueORM.priority.desc(), TaskQueueORM.sort_date.asc(), TaskQueueORM.id.asc())

# If tag is "*", then the manager will pull anything
if tag != "*":
# If tag is "*" (and strict_queue_tags is False), then the manager can pull anything
# If tag is "*" and strict_queue_tags is enabled, only pull tasks with tag == '*'
if tag != "*" or self._strict_queue_tags:
stmt = stmt.filter(TaskQueueORM.tag == tag)

# Skip locked rows - They may be in the process of being claimed by someone else
Expand Down
49 changes: 49 additions & 0 deletions qcfractal/qcfractal/components/tasks/test_socket_claim.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from __future__ import annotations

from qcarchivetesting.testing_classes import QCATestingSnowflake
from typing import TYPE_CHECKING

import pytest
Expand Down Expand Up @@ -300,6 +301,54 @@ def test_task_socket_claim_tag_wildcard(storage_socket: SQLAlchemySocket, sessio
assert tasks[2]["id"] == recs[4].task.id


def test_task_socket_claim_tag_wildcard_strict(postgres_server, pytestconfig):

pg_harness = postgres_server.get_new_harness("claim_tag_wildcard_strict")
encoding = pytestconfig.getoption("--client-encoding")
with QCATestingSnowflake(pg_harness, encoding=encoding, extra_config={"strict_queue_tags": True}) as snowflake:
storage_socket = snowflake.get_storage_socket()

mname1 = ManagerName(cluster="test_cluster", hostname="a_host1", uuid="1234-5678-1234-5678")
mprog1 = {"qcengine": ["unknown"], "psi4": ["unknown"], "geometric": ["v3.0"]}
storage_socket.managers.activate(
name_data=mname1,
manager_version="v2.0",
username="bill",
programs=mprog1,
tags=["TAG3", "*"],
)

meta, id_1 = storage_socket.records.singlepoint.add(
[molecule_1], input_spec_1, "tag1", PriorityEnum.normal, None, None, True
)
meta, id_2 = storage_socket.records.singlepoint.add(
[molecule_2], input_spec_2, "tag2", PriorityEnum.normal, None, None, True
)
meta, id_3 = storage_socket.records.singlepoint.add(
[molecule_3], input_spec_3, "*", PriorityEnum.normal, None, None, True
)
meta, id_4 = storage_socket.records.optimization.add(
[molecule_4], input_spec_4, "taG3", PriorityEnum.normal, None, None, True
)
meta, id_5 = storage_socket.records.singlepoint.add(
[molecule_5], input_spec_5, "tag1", PriorityEnum.normal, None, None, True
)

all_id = id_1 + id_2 + id_3 + id_4 + id_5

client = snowflake.client()
recs = client.get_records(all_id, include=["task"])

# tag3 should be first, then only the * task (in order)
tasks = storage_socket.tasks.claim_tasks(mname1.fullname, mprog1, ["tag3", "*"], 2)
assert len(tasks) == 2
assert tasks[0]["id"] == recs[3].task.id
assert tasks[1]["id"] == recs[2].task.id

tasks = storage_socket.tasks.claim_tasks(mname1.fullname, mprog1, ["tag3", "*"], 3)
assert len(tasks) == 0


def test_task_socket_claim_program(storage_socket: SQLAlchemySocket, session: Session):
mname1 = ManagerName(cluster="test_cluster", hostname="a_host1", uuid="1234-5678-1234-5678")

Expand Down
6 changes: 6 additions & 0 deletions qcfractal/qcfractal/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,12 @@ class FractalConfig(ConfigBase):
True,
description="Allows unauthenticated read access to this instance. This does not extend to sensitive tables (such as user information)",
)
strict_queue_tags: bool = Field(
False,
description="If True, disables wildcard behavior for queue tags. This disables managers from claiming all "
"tags if they specify a wildcard ('*') tag. Managers will still be able to claim tasks with an "
"explicit '*' tag if they specifiy the '*' queue tag in their config",
)

# Logging and profiling
logfile: Optional[str] = Field(
Expand Down

0 comments on commit 6567cb8

Please sign in to comment.