Skip to content

Commit

Permalink
Allowing clients to provide TaskRun ids (#13683)
Browse files Browse the repository at this point in the history
Co-authored-by: zangell44 <zachary.james.angell@gmail.com>
Co-authored-by: Zach Angell <42625717+zangell44@users.noreply.github.com>
  • Loading branch information
3 people authored May 31, 2024
1 parent 9f9083b commit 255ef98
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 2 deletions.
1 change: 0 additions & 1 deletion .github/workflows/codeql-analysis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ jobs:
languages: ${{ matrix.language }}
config-file: ./.github/codeql-config.yml
queries: security-extended
setup-python-dependencies: false

- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v3
6 changes: 5 additions & 1 deletion src/prefect/server/api/task_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,12 @@ async def create_task_run(
If no state is provided, the task run will be created in a PENDING state.
"""

# hydrate the input model into a full task run / state model
task_run = schemas.core.TaskRun(**task_run.dict())
task_run_dict = task_run.dict()
if not task_run_dict.get("id"):
task_run_dict.pop("id", None)
task_run = schemas.core.TaskRun(**task_run_dict)

if not task_run.state:
task_run.state = schemas.states.Pending()
Expand Down
5 changes: 5 additions & 0 deletions src/prefect/server/schemas/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,11 @@ def default_scheduled_start_time(cls, values):
class TaskRunCreate(ActionBaseModel):
"""Data used by the Prefect REST API to create a task run"""

id: Optional[UUID] = Field(
default=None,
description="The ID to use for the task run. If not provided, a random UUID will be generated.",
)

# TaskRunCreate states must be provided as StateCreate objects
state: Optional[StateCreate] = Field(
default=None, description="The state of the task run to create"
Expand Down
52 changes: 52 additions & 0 deletions tests/server/api/test_task_runs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import uuid
from uuid import uuid4

import pendulum
Expand Down Expand Up @@ -145,6 +146,57 @@ async def test_raises_on_jitter_factor_validation(self, flow_run, client, sessio
== "`retry_jitter_factor` must be >= 0."
)

async def test_create_task_run_with_client_provided_id(self, flow_run, client):
client_provided_id = uuid.uuid4()
task_run_data = {
"flow_run_id": str(flow_run.id),
"task_key": "my-task-key",
"name": "my-cool-task-run-name",
"dynamic_key": "0",
"id": str(client_provided_id),
}
response = await client.post(
"/task_runs/",
json=task_run_data,
)
assert response.status_code == 201
assert response.json()["id"] == str(client_provided_id)

async def test_create_task_run_with_same_client_provided_id(
self,
flow_run,
client,
):
client_provided_id = uuid.uuid4()
task_run_data = {
"flow_run_id": str(flow_run.id),
"task_key": "my-task-key",
"name": "my-cool-task-run-name",
"dynamic_key": "0",
"id": str(client_provided_id),
}
response = await client.post(
"/task_runs/",
json=task_run_data,
)
assert response.status_code == 201
assert response.json()["id"] == str(client_provided_id)

task_run_data = {
"flow_run_id": str(flow_run.id),
"task_key": "my-task-key",
"name": "my-cool-task-run-name",
"dynamic_key": "1",
"id": str(client_provided_id),
}

response = await client.post(
"/task_runs/",
json=task_run_data,
)

assert response.status_code == 409


class TestReadTaskRun:
async def test_read_task_run(self, flow_run, task_run, client):
Expand Down

0 comments on commit 255ef98

Please sign in to comment.