From 6567cb8536051293a830c73c691f5a2e9f2c901f Mon Sep 17 00:00:00 2001 From: Benjamin Pritchard Date: Thu, 16 May 2024 10:19:17 -0400 Subject: [PATCH] Enable strict queue tags --- .../qcfractal/components/tasks/socket.py | 6 ++- .../components/tasks/test_socket_claim.py | 49 +++++++++++++++++++ qcfractal/qcfractal/config.py | 6 +++ 3 files changed, 59 insertions(+), 2 deletions(-) diff --git a/qcfractal/qcfractal/components/tasks/socket.py b/qcfractal/qcfractal/components/tasks/socket.py index 9b63481c8..9e9b8f799 100644 --- a/qcfractal/qcfractal/components/tasks/socket.py +++ b/qcfractal/qcfractal/components/tasks/socket.py @@ -42,6 +42,7 @@ def __init__(self, root_socket: SQLAlchemySocket): self._logger = logging.getLogger(__name__) self._tasks_claim_limit = root_socket.qcf_config.api_limits.manager_tasks_claim + self._strict_queue_tags = root_socket.qcf_config.strict_queue_tags def update_finished( self, manager_name: str, results_compressed: Dict[int, bytes], *, session: Optional[Session] = None @@ -313,8 +314,9 @@ def claim_tasks( # The sort_date usually comes from the created_on of the record, or the created_on of the record's parent service stmt = stmt.order_by(TaskQueueORM.priority.desc(), TaskQueueORM.sort_date.asc(), TaskQueueORM.id.asc()) - # If tag is "*", then the manager will pull anything - if tag != "*": + # If tag is "*" (and strict_queue_tags is False), then the manager can pull anything + # If tag is "*" and strict_queue_tags is enabled, only pull tasks with tag == '*' + if tag != "*" or self._strict_queue_tags: stmt = stmt.filter(TaskQueueORM.tag == tag) # Skip locked rows - They may be in the process of being claimed by someone else diff --git a/qcfractal/qcfractal/components/tasks/test_socket_claim.py b/qcfractal/qcfractal/components/tasks/test_socket_claim.py index b68400a12..86d6bdbc3 100644 --- a/qcfractal/qcfractal/components/tasks/test_socket_claim.py +++ b/qcfractal/qcfractal/components/tasks/test_socket_claim.py @@ -4,6 +4,7 @@ from __future__ import annotations +from qcarchivetesting.testing_classes import QCATestingSnowflake from typing import TYPE_CHECKING import pytest @@ -300,6 +301,54 @@ def test_task_socket_claim_tag_wildcard(storage_socket: SQLAlchemySocket, sessio assert tasks[2]["id"] == recs[4].task.id +def test_task_socket_claim_tag_wildcard_strict(postgres_server, pytestconfig): + + pg_harness = postgres_server.get_new_harness("claim_tag_wildcard_strict") + encoding = pytestconfig.getoption("--client-encoding") + with QCATestingSnowflake(pg_harness, encoding=encoding, extra_config={"strict_queue_tags": True}) as snowflake: + storage_socket = snowflake.get_storage_socket() + + mname1 = ManagerName(cluster="test_cluster", hostname="a_host1", uuid="1234-5678-1234-5678") + mprog1 = {"qcengine": ["unknown"], "psi4": ["unknown"], "geometric": ["v3.0"]} + storage_socket.managers.activate( + name_data=mname1, + manager_version="v2.0", + username="bill", + programs=mprog1, + tags=["TAG3", "*"], + ) + + meta, id_1 = storage_socket.records.singlepoint.add( + [molecule_1], input_spec_1, "tag1", PriorityEnum.normal, None, None, True + ) + meta, id_2 = storage_socket.records.singlepoint.add( + [molecule_2], input_spec_2, "tag2", PriorityEnum.normal, None, None, True + ) + meta, id_3 = storage_socket.records.singlepoint.add( + [molecule_3], input_spec_3, "*", PriorityEnum.normal, None, None, True + ) + meta, id_4 = storage_socket.records.optimization.add( + [molecule_4], input_spec_4, "taG3", PriorityEnum.normal, None, None, True + ) + meta, id_5 = storage_socket.records.singlepoint.add( + [molecule_5], input_spec_5, "tag1", PriorityEnum.normal, None, None, True + ) + + all_id = id_1 + id_2 + id_3 + id_4 + id_5 + + client = snowflake.client() + recs = client.get_records(all_id, include=["task"]) + + # tag3 should be first, then only the * task (in order) + tasks = storage_socket.tasks.claim_tasks(mname1.fullname, mprog1, ["tag3", "*"], 2) + assert len(tasks) == 2 + assert tasks[0]["id"] == recs[3].task.id + assert tasks[1]["id"] == recs[2].task.id + + tasks = storage_socket.tasks.claim_tasks(mname1.fullname, mprog1, ["tag3", "*"], 3) + assert len(tasks) == 0 + + def test_task_socket_claim_program(storage_socket: SQLAlchemySocket, session: Session): mname1 = ManagerName(cluster="test_cluster", hostname="a_host1", uuid="1234-5678-1234-5678") diff --git a/qcfractal/qcfractal/config.py b/qcfractal/qcfractal/config.py index 6a6785bb9..3299549fa 100644 --- a/qcfractal/qcfractal/config.py +++ b/qcfractal/qcfractal/config.py @@ -338,6 +338,12 @@ class FractalConfig(ConfigBase): True, description="Allows unauthenticated read access to this instance. This does not extend to sensitive tables (such as user information)", ) + strict_queue_tags: bool = Field( + False, + description="If True, disables wildcard behavior for queue tags. This disables managers from claiming all " + "tags if they specify a wildcard ('*') tag. Managers will still be able to claim tasks with an " + "explicit '*' tag if they specifiy the '*' queue tag in their config", + ) # Logging and profiling logfile: Optional[str] = Field(