diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1af7603426..e21d662d6f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,6 +17,7 @@ ENHANCEMENTS:
 * Block anonymous access to 2 storage accounts ([#2524](https://github.com/microsoft/AzureTRE/pull/2524))
 * Gitea shared service support app-service standard SKUs ([#2523](https://github.com/microsoft/AzureTRE/pull/2523))
 * Keyvault diagnostic settings in base workspace ([#2521](https://github.com/microsoft/AzureTRE/pull/2521))
+* Airlock requests contain a field with information about the files that were submitted ([#2504](https://github.com/microsoft/AzureTRE/pull/2504))
 
 BUG FIXES:
 
diff --git a/airlock_processor/StatusChangedQueueTrigger/__init__.py b/airlock_processor/StatusChangedQueueTrigger/__init__.py
index 83fda858d9..9af33a8792 100644
--- a/airlock_processor/StatusChangedQueueTrigger/__init__.py
+++ b/airlock_processor/StatusChangedQueueTrigger/__init__.py
@@ -31,17 +31,18 @@ def __init__(self, source_account_name: str, dest_account_name: str):
 def main(msg: func.ServiceBusMessage, outputEvent: func.Out[func.EventGridOutputEvent]):
     try:
         request_properties = extract_properties(msg)
-        handle_status_changed(request_properties)
+        request_files = get_request_files(request_properties) if request_properties.status == constants.STAGE_SUBMITTED else None
+        handle_status_changed(request_properties, outputEvent, request_files)
 
     except NoFilesInRequestException:
-        report_failure(outputEvent, request_properties, failure_reason=constants.NO_FILES_IN_REQUEST_MESSAGE)
+        set_output_event_to_report_failure(outputEvent, request_properties, failure_reason=constants.NO_FILES_IN_REQUEST_MESSAGE, request_files=request_files)
     except TooManyFilesInRequestException:
-        report_failure(outputEvent, request_properties, failure_reason=constants.TOO_MANY_FILES_IN_REQUEST_MESSAGE)
+        set_output_event_to_report_failure(outputEvent, request_properties, failure_reason=constants.TOO_MANY_FILES_IN_REQUEST_MESSAGE, request_files=request_files)
     except Exception:
-        report_failure(outputEvent, request_properties, failure_reason=constants.UNKNOWN_REASON_MESSAGE)
+        set_output_event_to_report_failure(outputEvent, request_properties, failure_reason=constants.UNKNOWN_REASON_MESSAGE, request_files=request_files)
 
 
-def handle_status_changed(request_properties: RequestProperties):
+def handle_status_changed(request_properties: RequestProperties, outputEvent: func.Out[func.EventGridOutputEvent], request_files):
     new_status = request_properties.status
     req_id = request_properties.request_id
     ws_id = request_properties.workspace_id
@@ -65,6 +66,9 @@ def handle_status_changed(request_properties: RequestProperties):
         blob_operations.create_container(account_name, req_id)
         return
 
+    if new_status == constants.STAGE_SUBMITTED:
+        set_output_event_to_report_request_files(outputEvent, request_properties, request_files)
+
     if (is_require_data_copy(new_status)):
         logging.info('Request with id %s. requires data copy between storage accounts', req_id)
         containers_metadata = get_source_dest_for_copy(new_status, request_type, ws_id)
@@ -148,13 +152,30 @@ def get_source_dest_for_copy(new_status: str, request_type: str, short_workspace
     return ContainersCopyMetadata(source_account_name, dest_account_name)
 
 
-def report_failure(outputEvent, request_properties, failure_reason):
+def set_output_event_to_report_failure(outputEvent, request_properties, failure_reason, request_files):
     logging.exception(f"Failed processing Airlock request with ID: '{request_properties.request_id}', changing request status to '{constants.STAGE_FAILED}'.")
     outputEvent.set(
         func.EventGridOutputEvent(
             id=str(uuid.uuid4()),
-            data={"completed_step": request_properties.status, "new_status": constants.STAGE_FAILED, "request_id": request_properties.request_id, "error_message": failure_reason},
+            data={"completed_step": request_properties.status, "new_status": constants.STAGE_FAILED, "request_id": request_properties.request_id, "request_files": request_files, "error_message": failure_reason},
+            subject=request_properties.request_id,
+            event_type="Airlock.StepResult",
+            event_time=datetime.datetime.utcnow(),
+            data_version=constants.STEP_RESULT_EVENT_DATA_VERSION))
+
+
+def set_output_event_to_report_request_files(outputEvent, request_properties, request_files):
+    outputEvent.set(
+        func.EventGridOutputEvent(
+            id=str(uuid.uuid4()),
+            data={"completed_step": request_properties.status, "request_id": request_properties.request_id, "request_files": request_files},
             subject=request_properties.request_id,
             event_type="Airlock.StepResult",
             event_time=datetime.datetime.utcnow(),
             data_version=constants.STEP_RESULT_EVENT_DATA_VERSION))
+
+
+def get_request_files(request_properties):
+    containers_metadata = get_source_dest_for_copy(request_properties.status, request_properties.type, request_properties.workspace_id)
+    storage_account_name = containers_metadata.source_account_name
+    return blob_operations.get_request_files(account_name=storage_account_name, request_id=request_properties.request_id)
diff --git a/airlock_processor/_version.py b/airlock_processor/_version.py
index 98a433b310..3dd3d2d51b 100644
--- a/airlock_processor/_version.py
+++ b/airlock_processor/_version.py
@@ -1 +1 @@
-__version__ = "0.4.5"
+__version__ = "0.4.6"
diff --git a/airlock_processor/shared_code/blob_operations.py b/airlock_processor/shared_code/blob_operations.py
index b90461c121..42459ea2d6 100644
--- a/airlock_processor/shared_code/blob_operations.py
+++ b/airlock_processor/shared_code/blob_operations.py
@@ -34,6 +34,17 @@ def create_container(account_name: str, request_id: str):
         logging.info(f'Did not create a new container. Container already exists for request id: {request_id}.')
 
 
+def get_request_files(account_name: str, request_id: str) -> list:
+    files = []
+    blob_service_client = BlobServiceClient(account_url=get_account_url(account_name), credential=get_credential())
+    container_client = blob_service_client.get_container_client(container=request_id)
+
+    for blob in container_client.list_blobs():
+        files.append({"name": blob.name, "size": blob.size})
+
+    return files
+
+
 def copy_data(source_account_name: str, destination_account_name: str, request_id: str):
     credential = get_credential()
     container_name = request_id
diff --git a/airlock_processor/tests/test_status_change_queue_trigger.py b/airlock_processor/tests/test_status_change_queue_trigger.py
index 59dc14fb63..bb198931d2 100644
--- a/airlock_processor/tests/test_status_change_queue_trigger.py
+++ b/airlock_processor/tests/test_status_change_queue_trigger.py
@@ -1,9 +1,13 @@
 from json import JSONDecodeError
+import os
 import unittest
+from unittest import mock
+from unittest.mock import MagicMock, patch
 
 from pydantic import ValidationError
-from StatusChangedQueueTrigger import extract_properties, get_source_dest_for_copy, is_require_data_copy
+from StatusChangedQueueTrigger import get_request_files, main, extract_properties, get_source_dest_for_copy, is_require_data_copy
 from azure.functions.servicebus import ServiceBusMessage
+from shared_code import constants
 
 
 class TestPropertiesExtraction(unittest.TestCase):
@@ -62,6 +66,49 @@ def test_wrong_type_raises_when_getting_storage_account_properties(self):
         self.assertRaises(Exception, get_source_dest_for_copy, "accepted", "somethingelse")
 
 
+class TestFileEnumeration(unittest.TestCase):
+    @patch("StatusChangedQueueTrigger.set_output_event_to_report_request_files")
+    @patch("StatusChangedQueueTrigger.get_request_files")
+    @patch("StatusChangedQueueTrigger.is_require_data_copy", return_value=False)
+    @mock.patch.dict(os.environ, {"TRE_ID": "tre-id"}, clear=True)
+    def test_get_request_files_should_be_called_on_submit_stage(self, _, mock_get_request_files, mock_set_output_event_to_report_request_files):
+        message_body = "{ \"data\": { \"request_id\":\"123\",\"status\":\"submitted\" , \"type\":\"import\", \"workspace_id\":\"ws1\" }}"
+        message = _mock_service_bus_message(body=message_body)
+        main(msg=message, outputEvent=MagicMock())
+        self.assertTrue(mock_get_request_files.called)
+        self.assertTrue(mock_set_output_event_to_report_request_files.called)
+
+    @patch("StatusChangedQueueTrigger.set_output_event_to_report_failure")
+    @patch("StatusChangedQueueTrigger.get_request_files")
+    @patch("StatusChangedQueueTrigger.handle_status_changed")
+    def test_get_request_files_should_not_be_called_if_new_status_is_not_submit(self, _, mock_get_request_files, mock_set_output_event_to_report_failure):
+        message_body = "{ \"data\": { \"request_id\":\"123\",\"status\":\"fake-status\" , \"type\":\"import\", \"workspace_id\":\"ws1\" }}"
+        message = _mock_service_bus_message(body=message_body)
+        main(msg=message, outputEvent=MagicMock())
+        self.assertFalse(mock_get_request_files.called)
+        self.assertFalse(mock_set_output_event_to_report_failure.called)
+
+    @patch("StatusChangedQueueTrigger.set_output_event_to_report_failure")
+    @patch("StatusChangedQueueTrigger.get_request_files")
+    @patch("StatusChangedQueueTrigger.handle_status_changed", side_effect=Exception)
+    def test_get_request_files_should_be_called_when_failing_during_submit_stage(self, _, mock_get_request_files, mock_set_output_event_to_report_failure):
+        message_body = "{ \"data\": { \"request_id\":\"123\",\"status\":\"submitted\" , \"type\":\"import\", \"workspace_id\":\"ws1\" }}"
= "{ \"data\": { \"request_id\":\"123\",\"status\":\"submitted\" , \"type\":\"import\", \"workspace_id\":\"ws1\" }}" + message = _mock_service_bus_message(body=message_body) + main(msg=message, outputEvent=MagicMock()) + self.assertTrue(mock_get_request_files.called) + self.assertTrue(mock_set_output_event_to_report_failure.called) + + @patch("StatusChangedQueueTrigger.blob_operations.get_request_files") + @mock.patch.dict(os.environ, {"TRE_ID": "tre-id"}, clear=True) + def test_get_request_files_called_with_correct_storage_account(self, mock_get_request_files): + source_storage_account_for_submitted_stage = constants.STORAGE_ACCOUNT_NAME_EXPORT_INTERNAL + 'ws1' + message_body = "{ \"data\": { \"request_id\":\"123\",\"status\":\"submitted\" , \"type\":\"export\", \"workspace_id\":\"ws1\" }}" + message = _mock_service_bus_message(body=message_body) + request_properties = extract_properties(message) + get_request_files(request_properties) + mock_get_request_files.assert_called_with(account_name=source_storage_account_for_submitted_stage, request_id=request_properties.request_id) + + def _mock_service_bus_message(body: str): encoded_body = str.encode(body, "utf-8") message = ServiceBusMessage(body=encoded_body, message_id="123", user_properties={}) diff --git a/api_app/_version.py b/api_app/_version.py index e427a55476..ece529aa19 100644 --- a/api_app/_version.py +++ b/api_app/_version.py @@ -1 +1 @@ -__version__ = "0.4.21" +__version__ = "0.4.22" diff --git a/api_app/api/routes/airlock_resource_helpers.py b/api_app/api/routes/airlock_resource_helpers.py index f3ae0a8387..79d0db8568 100644 --- a/api_app/api/routes/airlock_resource_helpers.py +++ b/api_app/api/routes/airlock_resource_helpers.py @@ -6,7 +6,7 @@ from fastapi import HTTPException from starlette import status from db.repositories.airlock_requests import AirlockRequestRepository -from models.domain.airlock_request import AirlockActions, AirlockRequest, AirlockRequestStatus, AirlockRequestType, AirlockReview +from models.domain.airlock_request import AirlockActions, AirlockFile, AirlockRequest, AirlockRequestStatus, AirlockRequestType, AirlockReview from event_grid.event_sender import send_status_changed_event, send_airlock_notification_event from models.domain.authentication import User from models.domain.workspace import Workspace @@ -42,10 +42,10 @@ async def save_and_publish_event_airlock_request(airlock_request: AirlockRequest raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=strings.EVENT_GRID_GENERAL_ERROR_MESSAGE) -async def update_and_publish_event_airlock_request(airlock_request: AirlockRequest, airlock_request_repo: AirlockRequestRepository, user: User, new_status: AirlockRequestStatus, workspace: Workspace, error_message: str = None, airlock_review: AirlockReview = None): +async def update_and_publish_event_airlock_request(airlock_request: AirlockRequest, airlock_request_repo: AirlockRequestRepository, user: User, new_status: AirlockRequestStatus, workspace: Workspace, request_files: List[AirlockFile] = None, error_message: str = None, airlock_review: AirlockReview = None): try: logging.debug(f"Updating airlock request item: {airlock_request.id}") - updated_airlock_request = airlock_request_repo.update_airlock_request(airlock_request=airlock_request, new_status=new_status, user=user, error_message=error_message, airlock_review=airlock_review) + updated_airlock_request = airlock_request_repo.update_airlock_request(original_request=airlock_request, user=user, new_status=new_status, 
     except Exception as e:
         logging.error(f'Failed updating airlock_request item {airlock_request}: {e}')
         # If the validation failed, the error was not related to the saving itself
diff --git a/api_app/db/repositories/airlock_requests.py b/api_app/db/repositories/airlock_requests.py
index cb8e138500..a2cf5bc205 100644
--- a/api_app/db/repositories/airlock_requests.py
+++ b/api_app/db/repositories/airlock_requests.py
@@ -4,18 +4,19 @@
 from datetime import datetime
 from typing import List
 from pydantic import UUID4
-from azure.cosmos.exceptions import CosmosResourceNotFoundError
+from azure.cosmos.exceptions import CosmosResourceNotFoundError, CosmosAccessConditionFailedError
 from azure.cosmos import CosmosClient
 from starlette import status
 from fastapi import HTTPException
 from pydantic import parse_obj_as
 from models.domain.authentication import User
 from db.errors import EntityDoesNotExist
-from models.domain.airlock_request import AirlockRequest, AirlockRequestStatus, AirlockReview, AirlockReviewDecision, AirlockRequestHistoryItem, AirlockRequestType
+from models.domain.airlock_request import AirlockFile, AirlockRequest, AirlockRequestStatus, AirlockReview, AirlockReviewDecision, AirlockRequestHistoryItem, AirlockRequestType
 from models.schemas.airlock_request import AirlockRequestInCreate, AirlockReviewInCreate
 from core import config
 from resources import strings
 from db.repositories.base import BaseRepository
+import logging
 
 
 class AirlockRequestRepository(BaseRepository):
@@ -43,7 +44,7 @@ def update_airlock_request_item(self, original_request: AirlockRequest, new_requ
         new_request.user = user
         new_request.updatedWhen = self.get_timestamp()
 
-        self.update_item(new_request)
+        self.upsert_item_with_etag(new_request, new_request.etag)
         return new_request
 
     @staticmethod
@@ -126,21 +127,17 @@ def get_airlock_request_by_id(self, airlock_request_id: UUID4) -> AirlockRequest
             raise EntityDoesNotExist
         return parse_obj_as(AirlockRequest, airlock_requests)
 
-    def update_airlock_request(self, airlock_request: AirlockRequest, new_status: AirlockRequestStatus, user: User, error_message: str = None, airlock_review: AirlockReview = None) -> AirlockRequest:
-        current_status = airlock_request.status
-        if self.validate_status_update(current_status, new_status):
-            updated_request = copy.deepcopy(airlock_request)
-            updated_request.status = new_status
-            if new_status == AirlockRequestStatus.Failed:
-                updated_request.errorMessage = error_message
-            if airlock_review is not None:
-                if updated_request.reviews is None:
-                    updated_request.reviews = [airlock_review]
-                else:
-                    updated_request.reviews.append(airlock_review)
-            return self.update_airlock_request_item(airlock_request, updated_request, user, {"previousStatus": current_status})
-        else:
-            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=strings.AIRLOCK_REQUEST_ILLEGAL_STATUS_CHANGE)
+    def update_airlock_request(self, original_request: AirlockRequest, user: User, new_status: AirlockRequestStatus = None, request_files: List[AirlockFile] = None, error_message: str = None, airlock_review: AirlockReview = None) -> AirlockRequest:
+        updated_request = self._build_updated_request(original_request=original_request, new_status=new_status, request_files=request_files, error_message=error_message, airlock_review=airlock_review)
+        try:
+            db_response = self.update_airlock_request_item(original_request, updated_request, user, {"previousStatus": original_request.status})
+        except CosmosAccessConditionFailedError:
+            logging.warning(f"ETag mismatch for request ID: '{original_request.id}'. Retrying.")
+            original_request = self.get_airlock_request_by_id(original_request.id)
+            updated_request = self._build_updated_request(original_request=original_request, new_status=new_status, request_files=request_files, error_message=error_message, airlock_review=airlock_review)
+            db_response = self.update_airlock_request_item(original_request, updated_request, user, {"previousStatus": original_request.status})
+
+        return db_response
 
     def get_airlock_request_spec_params(self):
         return self.get_resource_base_spec_params()
@@ -158,3 +155,28 @@ def create_airlock_review_item(self, airlock_review_input: AirlockReviewInCreate
         )
 
         return airlock_review
+
+    def _build_updated_request(self, original_request: AirlockRequest, new_status: AirlockRequestStatus = None, request_files: List[AirlockFile] = None, error_message: str = None, airlock_review: AirlockReview = None) -> AirlockRequest:
+        updated_request = copy.deepcopy(original_request)
+
+        if new_status is not None:
+            self._validate_status_update(current_status=original_request.status, new_status=new_status)
+            updated_request.status = new_status
+
+        if error_message is not None:
+            updated_request.errorMessage = error_message
+
+        if request_files is not None:
+            updated_request.files = request_files
+
+        if airlock_review is not None:
+            if updated_request.reviews is None:
+                updated_request.reviews = [airlock_review]
+            else:
+                updated_request.reviews.append(airlock_review)
+
+        return updated_request
+
+    def _validate_status_update(self, current_status, new_status):
+        if not self.validate_status_update(current_status=current_status, new_status=new_status):
+            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=strings.AIRLOCK_REQUEST_ILLEGAL_STATUS_CHANGE)
diff --git a/api_app/db/repositories/base.py b/api_app/db/repositories/base.py
index 211813d3a2..a55527b0f3 100644
--- a/api_app/db/repositories/base.py
+++ b/api_app/db/repositories/base.py
@@ -44,6 +44,9 @@ def update_item_with_etag(self, item: BaseModel, etag: str) -> BaseModel:
         self.container.replace_item(item=item.id, body=item.dict(), etag=etag, match_condition=MatchConditions.IfNotModified)
         return self.read_item_by_id(item.id)
 
+    def upsert_item_with_etag(self, item: BaseModel, etag: str) -> BaseModel:
+        return self.container.upsert_item(body=item.dict(), etag=etag, match_condition=MatchConditions.IfNotModified)
+
     def update_item_dict(self, item_dict: dict):
         self.container.upsert_item(body=item_dict)
 
diff --git a/api_app/models/domain/airlock_operations.py b/api_app/models/domain/airlock_operations.py
index 0c60f1f92c..1f0fc57f65 100644
--- a/api_app/models/domain/airlock_operations.py
+++ b/api_app/models/domain/airlock_operations.py
@@ -2,12 +2,15 @@
 from pydantic.types import UUID4
 from pydantic.schema import Optional
 from models.domain.azuretremodel import AzureTREModel
+from typing import List
+from models.domain.airlock_request import AirlockFile
 
 
 class EventGridMessageData(AzureTREModel):
     completed_step: str = Field(title="", description="")
-    new_status: str = Field(title="", description="")
+    new_status: Optional[str] = Field(title="", description="")
     request_id: str = Field(title="", description="")
+    request_files: Optional[List[AirlockFile]] = Field(title="", description="")
     error_message: Optional[str] = Field(title="", description="")
 
 
diff --git a/api_app/models/domain/airlock_request.py b/api_app/models/domain/airlock_request.py
index d6e6df5503..9c08fb7414 100644
--- a/api_app/models/domain/airlock_request.py
+++ b/api_app/models/domain/airlock_request.py
@@ -1,6 +1,6 @@
 from typing import List
 from enum import Enum
-from pydantic import Field
+from pydantic import Field, validator
 from pydantic.schema import Optional
 from resources import strings
 from models.domain.azuretremodel import AzureTREModel
@@ -35,6 +35,11 @@ class AirlockActions(str, Enum):
     Submit = strings.AIRLOCK_ACTION_SUBMIT
 
 
+class AirlockFile(AzureTREModel):
+    name: str = Field(title="name", description="name of the file")
+    size: float = Field(title="size", description="size of the file in bytes")
+
+
 class AirlockReviewDecision(str, Enum):
     Approved = strings.AIRLOCK_RESOURCE_STATUS_APPROVAL_INPROGRESS
     Rejected = strings.AIRLOCK_RESOURCE_STATUS_REJECTION_INPROGRESS
@@ -72,9 +77,16 @@ class AirlockRequest(AzureTREModel):
     history: List[AirlockRequestHistoryItem] = []
     workspaceId: str = Field("", title="Workspace ID", description="Service target Workspace id")
     requestType: AirlockRequestType = Field("", title="Airlock request type")
-    files: List[str] = Field([], title="Files of the request")
+    files: List[AirlockFile] = Field([], title="Files of the request")
     businessJustification: str = Field("Business Justifications", title="Explanation that will be provided to the request reviewer")
     status = AirlockRequestStatus.Draft
     creationTime: float = Field(None, title="Creation time of the request")
     errorMessage: Optional[str] = Field(title="Present only if the request have failed, provides the reason of the failure.")
     reviews: Optional[List[AirlockReview]]
+    etag: Optional[str] = Field(title="_etag", alias="_etag")
+
+    # SQL API CosmosDB saves ETag as an escaped string: https://github.com/microsoft/AzureTRE/issues/1931
+    @validator("etag", pre=True)
+    def parse_etag_to_remove_escaped_quotes(cls, value):
+        if value:
+            return value.replace('\"', '')
diff --git a/api_app/service_bus/airlock_request_status_update.py b/api_app/service_bus/airlock_request_status_update.py
index f9acc9d54d..095d252c66 100644
--- a/api_app/service_bus/airlock_request_status_update.py
+++ b/api_app/service_bus/airlock_request_status_update.py
@@ -56,15 +56,16 @@ async def update_status_in_database(airlock_request_repo: AirlockRequestReposito
     step_result_data = step_result_message.data
     airlock_request_id = step_result_data.request_id
     current_status = step_result_data.completed_step
-    new_status = AirlockRequestStatus(step_result_data.new_status)
+    new_status = AirlockRequestStatus(step_result_data.new_status) if step_result_data.new_status else None
     error_message = step_result_data.error_message
+    request_files = step_result_data.request_files
     # Find the airlock request by id
     airlock_request = await get_airlock_request_by_id_from_path(airlock_request_id=airlock_request_id, airlock_request_repo=airlock_request_repo)
     # Validate that the airlock request status is the same as current status
     if airlock_request.status == current_status:
         workspace = workspace_repo.get_workspace_by_id(airlock_request.workspaceId)
         # update to new status and send to event grid
-        await update_and_publish_event_airlock_request(airlock_request=airlock_request, airlock_request_repo=airlock_request_repo, user=airlock_request.user, new_status=new_status, workspace=workspace, error_message=error_message)
+        await update_and_publish_event_airlock_request(airlock_request=airlock_request, airlock_request_repo=airlock_request_repo, user=airlock_request.user, new_status=new_status, workspace=workspace, request_files=request_files, error_message=error_message)
         result = True
     else:
         error_string = strings.STEP_RESULT_MESSAGE_STATUS_DOES_NOT_MATCH.format(airlock_request_id, current_status, airlock_request.status)
diff --git a/api_app/tests_ma/test_db/test_repositories/test_airlock_request_repository.py b/api_app/tests_ma/test_db/test_repositories/test_airlock_request_repository.py
index 8c3fc7728d..754f99eac8 100644
--- a/api_app/tests_ma/test_db/test_repositories/test_airlock_request_repository.py
+++ b/api_app/tests_ma/test_db/test_repositories/test_airlock_request_repository.py
@@ -7,7 +7,7 @@
 from db.repositories.airlock_requests import AirlockRequestRepository
 from db.errors import EntityDoesNotExist
 
-from azure.cosmos.exceptions import CosmosResourceNotFoundError
+from azure.cosmos.exceptions import CosmosResourceNotFoundError, CosmosAccessConditionFailedError
 
 
 WORKSPACE_ID = "abc000d3-82da-4bfc-b6e9-9a7853ef753e"
@@ -113,7 +113,7 @@ def test_create_airlock_request_item_creates_an_airlock_request_with_the_right_v
 def test_update_airlock_request_with_allowed_new_status_should_update_request_status(airlock_request_repo, current_status, new_status, verify_dictionary_contains_all_enum_values):
     user = create_test_user()
     mock_existing_request = airlock_request_mock(status=current_status)
-    airlock_request = airlock_request_repo.update_airlock_request(mock_existing_request, new_status, user)
+    airlock_request = airlock_request_repo.update_airlock_request(mock_existing_request, user, new_status)
     assert airlock_request.status == new_status
 
 
@@ -122,7 +122,17 @@ def test_update_airlock_request_with_forbidden_status_should_fail_on_validation(
     user = create_test_user()
     mock_existing_request = airlock_request_mock(status=current_status)
     with pytest.raises(HTTPException):
-        airlock_request_repo.update_airlock_request(mock_existing_request, new_status, user)
+        airlock_request_repo.update_airlock_request(mock_existing_request, user, new_status)
+
+
+@patch("db.repositories.airlock_requests.AirlockRequestRepository.update_airlock_request_item", side_effect=[CosmosAccessConditionFailedError, None])
+@patch("db.repositories.airlock_requests.AirlockRequestRepository.get_airlock_request_by_id", return_value=airlock_request_mock(status=DRAFT))
+def test_update_airlock_request_should_retry_update_when_etag_is_not_up_to_date(_, update_airlock_request_item_mock, airlock_request_repo):
+    expected_update_attempts = 2
+    user = create_test_user()
+    mock_existing_request = airlock_request_mock(status=DRAFT)
+    airlock_request_repo.update_airlock_request(original_request=mock_existing_request, user=user, new_status=SUBMITTED)
+    assert update_airlock_request_item_mock.call_count == expected_update_attempts
 
 
 def test_get_airlock_requests_queries_db(airlock_request_repo):
diff --git a/api_app/tests_ma/test_service_bus/test_airlock_request_status_update.py b/api_app/tests_ma/test_service_bus/test_airlock_request_status_update.py
index d0c97e1719..e4cdf1762b 100644
--- a/api_app/tests_ma/test_service_bus/test_airlock_request_status_update.py
+++ b/api_app/tests_ma/test_service_bus/test_airlock_request_status_update.py
@@ -99,7 +99,7 @@ async def test_receiving_good_message(_, app, sb_client, logging_mock, workspace
     await receive_step_result_message_and_update_status(app)
 
     airlock_request_repo().get_airlock_request_by_id.assert_called_once_with(test_sb_step_result_message["data"]["request_id"])
-    airlock_request_repo().update_airlock_request.assert_called_once_with(airlock_request=expected_airlock_request, new_status=test_sb_step_result_message["data"]["new_status"], user=expected_airlock_request.user, error_message=None, airlock_review=None)
+    airlock_request_repo().update_airlock_request.assert_called_once_with(original_request=expected_airlock_request, user=expected_airlock_request.user, new_status=test_sb_step_result_message["data"]["new_status"], request_files=None, error_message=None, airlock_review=None)
     assert eg_client().send.call_count == 2
     logging_mock.assert_not_called()
     sb_client().get_queue_receiver().complete_message.assert_called_once_with(service_bus_received_message_mock)