From ee488c76f39ef557c1490914b5ea84ab15ada187 Mon Sep 17 00:00:00 2001 From: yang-ruoxi Date: Thu, 22 Jun 2023 10:37:56 -0700 Subject: [PATCH] add patch method for submission resource --- src/maggma/api/resource/submission.py | 88 ++++++++++++++++++++++++++- src/maggma/api/utils.py | 2 +- 2 files changed, 88 insertions(+), 2 deletions(-) diff --git a/src/maggma/api/resource/submission.py b/src/maggma/api/resource/submission.py index 1176de583..5d1936870 100644 --- a/src/maggma/api/resource/submission.py +++ b/src/maggma/api/resource/submission.py @@ -20,7 +20,7 @@ class SubmissionResource(Resource): """ - Implements a REST Compatible Resource as POST and/or GET URL endpoints + Implements a REST Compatible Resource as POST and/or GET and/or PATCH URL endpoints for submitted data. """ @@ -30,6 +30,7 @@ def __init__( model: Type[BaseModel], post_query_operators: List[QueryOperator], get_query_operators: List[QueryOperator], + patch_query_operators: Optional[List[QueryOperator]] = None, tags: Optional[List[str]] = None, timeout: Optional[int] = None, include_in_schema: Optional[bool] = True, @@ -40,6 +41,7 @@ def __init__( calculate_submission_id: Optional[bool] = False, get_sub_path: Optional[str] = "/", post_sub_path: Optional[str] = "/", + patch_sub_path: Optional[str] = "/" ): """ Args: @@ -50,6 +52,7 @@ def __init__( before raising a timeout error post_query_operators: Operators for the query language for post data get_query_operators: Operators for the query language for get data + patch_query_operators: Operators for the query language for patch data include_in_schema: Whether to include the submission resource in the documented schema duplicate_fields_check: Fields in model used to check for duplicates for POST data enable_default_search: Enable default endpoint search behavior. @@ -59,6 +62,7 @@ def __init__( If False, the store key is used instead. get_sub_path: GET sub-URL path for the resource. post_sub_path: POST sub-URL path for the resource. + patch_sub_path: PATCH sub-URL path for the resource. """ if isinstance(state_enum, Enum) and default_state not in [entry.value for entry in state_enum]: # type: ignore @@ -75,12 +79,14 @@ def __init__( if state_enum is not None else get_query_operators ) + self.patch_query_operators = patch_query_operators self.include_in_schema = include_in_schema self.duplicate_fields_check = duplicate_fields_check self.enable_default_search = enable_default_search self.calculate_submission_id = calculate_submission_id self.get_sub_path = get_sub_path self.post_sub_path = post_sub_path + self.patch_sub_path = patch_sub_path new_fields = {} # type: dict if self.calculate_submission_id: @@ -120,6 +126,9 @@ def prepare_endpoint(self): self.build_post_data() + if self.patch_query_operators: + self.build_patch_data() + def build_get_by_key(self): model_name = self.model.__name__ @@ -323,3 +332,80 @@ def post_data(**queries: STORE_PARAMS): response_model_exclude_unset=True, include_in_schema=self.include_in_schema, )(attach_query_ops(post_data, self.post_query_operators)) + + + def build_patch_data(self): + model_name = self.model.__name__ + + def patch_data(**queries: STORE_PARAMS): + + request: Request = queries.pop("request") # type: ignore + queries.pop("temp_response") # type: ignore + + query: STORE_PARAMS = merge_queries(list(queries.values())) + + query_params = [ + entry + for _, i in enumerate(self.patch_query_operators) # type: ignore + for entry in signature(i.query).parameters + ] + + overlap = [key for key in request.query_params.keys() if key not in query_params] + if any(overlap): + raise HTTPException( + status_code=404, + detail="Request contains query parameters which cannot be used: {}".format(", ".join(overlap)), + ) + + self.store.connect(force_reset=True) + + # Check for duplicate entry + if self.duplicate_fields_check: + duplicate = self.store.query_one( + criteria={field: query["criteria"][field] for field in self.duplicate_fields_check} + ) + + if duplicate: + raise HTTPException( + status_code=400, + detail="Submission already exists. Duplicate data found for fields: {}".format( + ", ".join(self.duplicate_fields_check) + ), + ) + + if self.calculate_submission_id: + query["criteria"]["submission_id"] = str(uuid4()) + + if self.state_enum is not None: + query["criteria"]["state"] = [self.default_state] + query["criteria"]["updated"] = [datetime.utcnow()] + + if query.get("update"): + try: + self.store._collection.update_one( + filter=query["criteria"], + update={'$set': query["update"]}, + upsert=False + ) + except Exception: + raise HTTPException( + status_code=400, + detail="Problem when trying to patch data.", + ) + + response = { + "updated_data": query["update"], + "meta": "Submission successful", + } + + return response + + self.router.patch( + self.patch_sub_path, + tags=self.tags, + summary=f"Patch {model_name} data", + response_model=None, + response_description=f"Patch {model_name} data", + response_model_exclude_unset=True, + include_in_schema=self.include_in_schema, + )(attach_query_ops(patch_data, self.patch_query_operators)) diff --git a/src/maggma/api/utils.py b/src/maggma/api/utils.py index 95c3b2bda..0345386fe 100644 --- a/src/maggma/api/utils.py +++ b/src/maggma/api/utils.py @@ -19,7 +19,7 @@ QUERY_PARAMS = ["criteria", "properties", "skip", "limit"] STORE_PARAMS = Dict[ Literal[ - "criteria", "properties", "sort", "skip", "limit", "request", "pipeline", "hint" + "criteria", "properties", "sort", "skip", "limit", "request", "pipeline", "hint", "update" ], Any, ]