Skip to content

Commit

Permalink
add patch method for submission resource
Browse files Browse the repository at this point in the history
  • Loading branch information
yang-ruoxi committed Jun 22, 2023
1 parent 29e5ec8 commit ee488c7
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 2 deletions.
88 changes: 87 additions & 1 deletion src/maggma/api/resource/submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

class SubmissionResource(Resource):
"""
Implements a REST Compatible Resource as POST and/or GET URL endpoints
Implements a REST Compatible Resource as POST and/or GET and/or PATCH URL endpoints
for submitted data.
"""

Expand All @@ -30,6 +30,7 @@ def __init__(
model: Type[BaseModel],
post_query_operators: List[QueryOperator],
get_query_operators: List[QueryOperator],
patch_query_operators: Optional[List[QueryOperator]] = None,
tags: Optional[List[str]] = None,
timeout: Optional[int] = None,
include_in_schema: Optional[bool] = True,
Expand All @@ -40,6 +41,7 @@ def __init__(
calculate_submission_id: Optional[bool] = False,
get_sub_path: Optional[str] = "/",
post_sub_path: Optional[str] = "/",
patch_sub_path: Optional[str] = "/"
):
"""
Args:
Expand All @@ -50,6 +52,7 @@ def __init__(
before raising a timeout error
post_query_operators: Operators for the query language for post data
get_query_operators: Operators for the query language for get data
patch_query_operators: Operators for the query language for patch data
include_in_schema: Whether to include the submission resource in the documented schema
duplicate_fields_check: Fields in model used to check for duplicates for POST data
enable_default_search: Enable default endpoint search behavior.
Expand All @@ -59,6 +62,7 @@ def __init__(
If False, the store key is used instead.
get_sub_path: GET sub-URL path for the resource.
post_sub_path: POST sub-URL path for the resource.
patch_sub_path: PATCH sub-URL path for the resource.
"""

if isinstance(state_enum, Enum) and default_state not in [entry.value for entry in state_enum]: # type: ignore
Expand All @@ -75,12 +79,14 @@ def __init__(
if state_enum is not None
else get_query_operators
)
self.patch_query_operators = patch_query_operators
self.include_in_schema = include_in_schema
self.duplicate_fields_check = duplicate_fields_check
self.enable_default_search = enable_default_search
self.calculate_submission_id = calculate_submission_id
self.get_sub_path = get_sub_path
self.post_sub_path = post_sub_path
self.patch_sub_path = patch_sub_path

new_fields = {} # type: dict
if self.calculate_submission_id:
Expand Down Expand Up @@ -120,6 +126,9 @@ def prepare_endpoint(self):

self.build_post_data()

if self.patch_query_operators:
self.build_patch_data()

def build_get_by_key(self):
model_name = self.model.__name__

Expand Down Expand Up @@ -323,3 +332,80 @@ def post_data(**queries: STORE_PARAMS):
response_model_exclude_unset=True,
include_in_schema=self.include_in_schema,
)(attach_query_ops(post_data, self.post_query_operators))


def build_patch_data(self):
model_name = self.model.__name__

def patch_data(**queries: STORE_PARAMS):

request: Request = queries.pop("request") # type: ignore
queries.pop("temp_response") # type: ignore

query: STORE_PARAMS = merge_queries(list(queries.values()))

query_params = [
entry
for _, i in enumerate(self.patch_query_operators) # type: ignore
for entry in signature(i.query).parameters
]

overlap = [key for key in request.query_params.keys() if key not in query_params]
if any(overlap):
raise HTTPException(
status_code=404,
detail="Request contains query parameters which cannot be used: {}".format(", ".join(overlap)),
)

self.store.connect(force_reset=True)

# Check for duplicate entry
if self.duplicate_fields_check:
duplicate = self.store.query_one(
criteria={field: query["criteria"][field] for field in self.duplicate_fields_check}
)

if duplicate:
raise HTTPException(
status_code=400,
detail="Submission already exists. Duplicate data found for fields: {}".format(
", ".join(self.duplicate_fields_check)
),
)

if self.calculate_submission_id:
query["criteria"]["submission_id"] = str(uuid4())

if self.state_enum is not None:
query["criteria"]["state"] = [self.default_state]
query["criteria"]["updated"] = [datetime.utcnow()]

if query.get("update"):
try:
self.store._collection.update_one(
filter=query["criteria"],
update={'$set': query["update"]},
upsert=False
)
except Exception:
raise HTTPException(
status_code=400,
detail="Problem when trying to patch data.",
)

response = {
"updated_data": query["update"],
"meta": "Submission successful",
}

return response

self.router.patch(
self.patch_sub_path,
tags=self.tags,
summary=f"Patch {model_name} data",
response_model=None,
response_description=f"Patch {model_name} data",
response_model_exclude_unset=True,
include_in_schema=self.include_in_schema,
)(attach_query_ops(patch_data, self.patch_query_operators))
2 changes: 1 addition & 1 deletion src/maggma/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
QUERY_PARAMS = ["criteria", "properties", "skip", "limit"]
STORE_PARAMS = Dict[
Literal[
"criteria", "properties", "sort", "skip", "limit", "request", "pipeline", "hint"
"criteria", "properties", "sort", "skip", "limit", "request", "pipeline", "hint", "update"
],
Any,
]
Expand Down

0 comments on commit ee488c7

Please sign in to comment.