Skip to content

Commit

Permalink
Merge pull request #809 from materialsproject/add_patch_resource
Browse files Browse the repository at this point in the history
add patch method for submission resource
  • Loading branch information
Jason Munro authored Jun 22, 2023
2 parents 29e5ec8 + 841dcb4 commit a01d007
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 18 deletions.
135 changes: 122 additions & 13 deletions src/maggma/api/resource/submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

class SubmissionResource(Resource):
"""
Implements a REST Compatible Resource as POST and/or GET URL endpoints
Implements a REST Compatible Resource as POST and/or GET and/or PATCH URL endpoints
for submitted data.
"""

Expand All @@ -30,6 +30,7 @@ def __init__(
model: Type[BaseModel],
post_query_operators: List[QueryOperator],
get_query_operators: List[QueryOperator],
patch_query_operators: Optional[List[QueryOperator]] = None,
tags: Optional[List[str]] = None,
timeout: Optional[int] = None,
include_in_schema: Optional[bool] = True,
Expand All @@ -40,6 +41,7 @@ def __init__(
calculate_submission_id: Optional[bool] = False,
get_sub_path: Optional[str] = "/",
post_sub_path: Optional[str] = "/",
patch_sub_path: Optional[str] = "/",
):
"""
Args:
Expand All @@ -50,6 +52,7 @@ def __init__(
before raising a timeout error
post_query_operators: Operators for the query language for post data
get_query_operators: Operators for the query language for get data
patch_query_operators: Operators for the query language for patch data
include_in_schema: Whether to include the submission resource in the documented schema
duplicate_fields_check: Fields in model used to check for duplicates for POST data
enable_default_search: Enable default endpoint search behavior.
Expand All @@ -59,10 +62,13 @@ def __init__(
If False, the store key is used instead.
get_sub_path: GET sub-URL path for the resource.
post_sub_path: POST sub-URL path for the resource.
patch_sub_path: PATCH sub-URL path for the resource.
"""

if isinstance(state_enum, Enum) and default_state not in [entry.value for entry in state_enum]: # type: ignore
raise RuntimeError("If data is stateful a state enum and valid default value must be provided")
raise RuntimeError(
"If data is stateful a state enum and valid default value must be provided"
)

self.state_enum = state_enum
self.default_state = default_state
Expand All @@ -75,12 +81,14 @@ def __init__(
if state_enum is not None
else get_query_operators
)
self.patch_query_operators = patch_query_operators
self.include_in_schema = include_in_schema
self.duplicate_fields_check = duplicate_fields_check
self.enable_default_search = enable_default_search
self.calculate_submission_id = calculate_submission_id
self.get_sub_path = get_sub_path
self.post_sub_path = post_sub_path
self.patch_sub_path = patch_sub_path

new_fields = {} # type: dict
if self.calculate_submission_id:
Expand Down Expand Up @@ -120,6 +128,9 @@ def prepare_endpoint(self):

self.build_post_data()

if self.patch_query_operators:
self.build_patch_data()

def build_get_by_key(self):
model_name = self.model.__name__

Expand Down Expand Up @@ -180,11 +191,9 @@ def get_by_key(
)(get_by_key)

def build_search_data(self):

model_name = self.model.__name__

def search(**queries: STORE_PARAMS):

request: Request = queries.pop("request") # type: ignore
queries.pop("temp_response") # type: ignore

Expand All @@ -196,29 +205,41 @@ def search(**queries: STORE_PARAMS):
for entry in signature(i.query).parameters
]

overlap = [key for key in request.query_params.keys() if key not in query_params]
overlap = [
key for key in request.query_params.keys() if key not in query_params
]
if any(overlap):
raise HTTPException(
status_code=404,
detail="Request contains query parameters which cannot be used: {}".format(", ".join(overlap)),
detail="Request contains query parameters which cannot be used: {}".format(
", ".join(overlap)
),
)

self.store.connect(force_reset=True)

try:
with query_timeout(self.timeout):
count = self.store.count( # type: ignore
**{field: query[field] for field in query if field in ["criteria", "hint"]}
**{
field: query[field]
for field in query
if field in ["criteria", "hint"]
}
)
if isinstance(self.store, S3Store):
data = list(self.store.query(**query)) # type: ignore
else:

pipeline = generate_query_pipeline(query, self.store)

data = list(
self.store._collection.aggregate(
pipeline, **{field: query[field] for field in query if field in ["hint"]}
pipeline,
**{
field: query[field]
for field in query
if field in ["hint"]
},
)
)
except (NetworkTimeout, PyMongoError) as e:
Expand Down Expand Up @@ -257,7 +278,6 @@ def build_post_data(self):
model_name = self.model.__name__

def post_data(**queries: STORE_PARAMS):

request: Request = queries.pop("request") # type: ignore
queries.pop("temp_response") # type: ignore

Expand All @@ -269,19 +289,26 @@ def post_data(**queries: STORE_PARAMS):
for entry in signature(i.query).parameters
]

overlap = [key for key in request.query_params.keys() if key not in query_params]
overlap = [
key for key in request.query_params.keys() if key not in query_params
]
if any(overlap):
raise HTTPException(
status_code=404,
detail="Request contains query parameters which cannot be used: {}".format(", ".join(overlap)),
detail="Request contains query parameters which cannot be used: {}".format(
", ".join(overlap)
),
)

self.store.connect(force_reset=True)

# Check for duplicate entry
if self.duplicate_fields_check:
duplicate = self.store.query_one(
criteria={field: query["criteria"][field] for field in self.duplicate_fields_check}
criteria={
field: query["criteria"][field]
for field in self.duplicate_fields_check
}
)

if duplicate:
Expand Down Expand Up @@ -323,3 +350,85 @@ def post_data(**queries: STORE_PARAMS):
response_model_exclude_unset=True,
include_in_schema=self.include_in_schema,
)(attach_query_ops(post_data, self.post_query_operators))

def build_patch_data(self):
model_name = self.model.__name__

def patch_data(**queries: STORE_PARAMS):
request: Request = queries.pop("request") # type: ignore
queries.pop("temp_response") # type: ignore

query: STORE_PARAMS = merge_queries(list(queries.values()))

query_params = [
entry
for _, i in enumerate(self.patch_query_operators) # type: ignore
for entry in signature(i.query).parameters
]

overlap = [
key for key in request.query_params.keys() if key not in query_params
]
if any(overlap):
raise HTTPException(
status_code=404,
detail="Request contains query parameters which cannot be used: {}".format(
", ".join(overlap)
),
)

self.store.connect(force_reset=True)

# Check for duplicate entry
if self.duplicate_fields_check:
duplicate = self.store.query_one(
criteria={
field: query["criteria"][field]
for field in self.duplicate_fields_check
}
)

if duplicate:
raise HTTPException(
status_code=400,
detail="Submission already exists. Duplicate data found for fields: {}".format(
", ".join(self.duplicate_fields_check)
),
)

if self.calculate_submission_id:
query["criteria"]["submission_id"] = str(uuid4())

if self.state_enum is not None:
query["criteria"]["state"] = [self.default_state]
query["criteria"]["updated"] = [datetime.utcnow()]

if query.get("update"):
try:
self.store._collection.update_one(
filter=query["criteria"],
update={"$set": query["update"]},
upsert=False,
)
except Exception:
raise HTTPException(
status_code=400,
detail="Problem when trying to patch data.",
)

response = {
"updated_data": query["update"],
"meta": "Submission successful",
}

return response

self.router.patch(
self.patch_sub_path,
tags=self.tags,
summary=f"Patch {model_name} data",
response_model=None,
response_description=f"Patch {model_name} data",
response_model_exclude_unset=True,
include_in_schema=self.include_in_schema,
)(attach_query_ops(patch_data, self.patch_query_operators))
11 changes: 9 additions & 2 deletions src/maggma/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,21 @@
QUERY_PARAMS = ["criteria", "properties", "skip", "limit"]
STORE_PARAMS = Dict[
Literal[
"criteria", "properties", "sort", "skip", "limit", "request", "pipeline", "hint"
"criteria",
"properties",
"sort",
"skip",
"limit",
"request",
"pipeline",
"hint",
"update",
],
Any,
]


def merge_queries(queries: List[STORE_PARAMS]) -> STORE_PARAMS:

criteria: STORE_PARAMS = {}
properties: List[str] = []

Expand Down
51 changes: 48 additions & 3 deletions tests/api/test_submission_resource.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from random import randint
from urllib.parse import urlencode
from pydantic.utils import get_model

import json
import pytest
from fastapi import FastAPI, Body
from pydantic import BaseModel, Field
Expand Down Expand Up @@ -49,15 +49,24 @@ def query(self, name):

return PostQuery()

@pytest.fixture
def patch_query_op():
class PatchQuery(QueryOperator):
def query(self, name, update):
return {"criteria": {"name": name},
"update": update}

return PatchQuery()

def test_init(owner_store, post_query_op):
def test_init(owner_store, post_query_op, patch_query_op):
resource = SubmissionResource(
store=owner_store,
get_query_operators=[PaginationQuery()],
post_query_operators=[post_query_op],
patch_query_operators=[patch_query_op],
model=Owner,
)
assert len(resource.router.routes) == 4
assert len(resource.router.routes) == 5


def test_msonable(owner_store, post_query_op):
Expand Down Expand Up @@ -93,6 +102,25 @@ def test_submission_search(owner_store, post_query_op):
assert client.post("/?name=test_name").status_code == 200


def test_submission_patch(owner_store, post_query_op, patch_query_op):
endpoint = SubmissionResource(
store=owner_store,
get_query_operators=[PaginationQuery()],
post_query_operators=[post_query_op],
patch_query_operators=[patch_query_op],
calculate_submission_id=True,
model=Owner,
)
app = FastAPI()
app.include_router(endpoint.router)

client = TestClient(app)
update = json.dumps({"last_updated": "2023-06-22T17:32:11.645713"})

assert client.get("/").status_code == 200
assert client.patch(f"/?name=PersonAge9&update={update}").status_code == 200


def test_key_fields(owner_store, post_query_op):
endpoint = SubmissionResource(
store=owner_store,
Expand All @@ -108,3 +136,20 @@ def test_key_fields(owner_store, post_query_op):

assert client.get("/Person1/").status_code == 200
assert client.get("/Person1/").json()["data"][0]["name"] == "Person1"

def test_patch_submission(owner_store, post_query_op):
endpoint = SubmissionResource(
store=owner_store,
get_query_operators=[PaginationQuery()],
post_query_operators=[post_query_op],
calculate_submission_id=False,
model=Owner,
)
app = FastAPI()
app.include_router(endpoint.router)

client = TestClient(app)

assert client.get("/Person1/").status_code == 200
assert client.get("/Person1/").json()["data"][0]["name"] == "Person1"

0 comments on commit a01d007

Please sign in to comment.