Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ingest/backend): add and use 'x-total-records' header for /get-original-metadata endpoint #2857

Merged
merged 6 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,13 @@ open class SubmissionController(
@ApiResponse(
responseCode = "200",
description = GET_ORIGINAL_METADATA_RESPONSE_DESCRIPTION,
headers = [
Header(
name = "x-total-records",
description = "The total number of records sent in responseBody",
schema = Schema(type = "integer"),
),
],
)
@ApiResponse(
responseCode = "423",
Expand All @@ -369,16 +376,29 @@ open class SubmissionController(
@HiddenParam authenticatedUser: AuthenticatedUser,
@RequestParam compression: CompressionFormat?,
): ResponseEntity<StreamingResponseBody> {
val stillProcessing = submitModel.checkIfStillProcessingSubmittedData()
if (stillProcessing) {
return ResponseEntity.status(HttpStatus.LOCKED).build()
}

val headers = HttpHeaders()
headers.contentType = MediaType.parseMediaType(MediaType.APPLICATION_NDJSON_VALUE)
if (compression != null) {
headers.add(HttpHeaders.CONTENT_ENCODING, compression.compressionName)
}

val stillProcessing = submitModel.checkIfStillProcessingSubmittedData()
if (stillProcessing) {
return ResponseEntity.status(HttpStatus.LOCKED).build()
}
val totalRecords = submissionDatabaseService.countOriginalMetadata(
authenticatedUser,
organism,
groupIdsFilter?.takeIf { it.isNotEmpty() },
statusesFilter?.takeIf { it.isNotEmpty() },
)
headers.add("x-total-records", totalRecords.toString())
// TODO(https://github.com/loculus-project/loculus/issues/2778)
// There's a possibility that the totalRecords change between the count and the actual query
// this is not too bad, if the client ends up with a few more records than expected
// We just need to make sure the etag used is from before the count
// Alternatively, we could read once to file while counting and then stream the file

val streamBody = streamTransactioned(compression) {
submissionDatabaseService.streamOriginalMetadata(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -950,13 +950,12 @@ open class SubmissionDatabaseService(
)
}

fun streamOriginalMetadata(
private fun originalMetadataFilter(
authenticatedUser: AuthenticatedUser,
organism: Organism,
groupIdsFilter: List<Int>?,
statusesFilter: List<Status>?,
fields: List<String>?,
): Sequence<AccessionVersionOriginalMetadata> {
): Op<Boolean> {
val organismCondition = SequenceEntriesView.organismIs(organism)
val groupCondition = getGroupCondition(groupIdsFilter, authenticatedUser)
val statusCondition = if (statusesFilter != null) {
Expand All @@ -966,6 +965,33 @@ open class SubmissionDatabaseService(
}
val conditions = organismCondition and groupCondition and statusCondition

return conditions
}

/**
 * Counts the rows of [SequenceEntriesView] that match the condition built by
 * [originalMetadataFilter] from the given user, organism, group ids and statuses.
 *
 * @return the number of matching rows.
 */
fun countOriginalMetadata(
    authenticatedUser: AuthenticatedUser,
    organism: Organism,
    groupIdsFilter: List<Int>?,
    statusesFilter: List<Status>?,
): Long {
    // Build the condition through the shared helper, then count the rows it selects.
    val matchingEntries = originalMetadataFilter(
        authenticatedUser,
        organism,
        groupIdsFilter,
        statusesFilter,
    )
    return SequenceEntriesView
        .selectAll()
        .where(matchingEntries)
        .count()
}

fun streamOriginalMetadata(
authenticatedUser: AuthenticatedUser,
organism: Organism,
groupIdsFilter: List<Int>?,
statusesFilter: List<Status>?,
fields: List<String>?,
): Sequence<AccessionVersionOriginalMetadata> {
val originalMetadata = SequenceEntriesView.originalDataColumn
.extract<Map<String, String>>("metadata")
.alias("original_metadata")
Expand All @@ -976,7 +1002,14 @@ open class SubmissionDatabaseService(
SequenceEntriesView.accessionColumn,
SequenceEntriesView.versionColumn,
)
.where(conditions)
.where(
originalMetadataFilter(
authenticatedUser,
organism,
groupIdsFilter,
statusesFilter,
),
)
.fetchSize(streamBatchSize)
.asSequence()
.map {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,10 @@ class GetOriginalMetadataEndpointTest(
// Empty-database case: the endpoint should answer 200, report "0" in the
// x-total-records header, and deliver an empty NDJSON body.
@Test
fun `GIVEN no sequence entries in database THEN returns empty response`() {
val response = submissionControllerClient.getOriginalMetadata()

// Read the NDJSON body first; the status and header assertions follow on the same result.
val responseBody = response.expectNdjsonAndGetContent<MetadataMap>()

response.andExpect(status().isOk)
// The header value is a string, so the expected count is compared as "0".
.andExpect(header().string("x-total-records", `is`("0")))
assertThat(responseBody, `is`(emptyList()))
}

Expand All @@ -63,6 +65,9 @@ class GetOriginalMetadataEndpointTest(
val response = submissionControllerClient.getOriginalMetadata()

val responseBody = response.expectNdjsonAndGetContent<AccessionVersionOriginalMetadata>()

response.andExpect(status().isOk)
.andExpect(header().string("x-total-records", `is`(DefaultFiles.NUMBER_OF_SEQUENCES.toString())))
assertThat(responseBody.size, `is`(DefaultFiles.NUMBER_OF_SEQUENCES))
}

Expand Down Expand Up @@ -150,6 +155,8 @@ class GetOriginalMetadataEndpointTest(
groupIdsFilter = listOf(g0),
statusesFilter = listOf(Status.APPROVED_FOR_RELEASE),
)
response.andExpect(status().isOk)
.andExpect(header().string("x-total-records", `is`(expectedAccessionVersions.count().toString())))
val responseBody = response.expectNdjsonAndGetContent<AccessionVersionOriginalMetadata>()

assertThat(responseBody, hasSize(expected.size))
Expand Down
33 changes: 21 additions & 12 deletions ingest/scripts/call_loculus.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,20 +312,29 @@ def get_submitted(config: Config):
"statusesFilter": [],
}

logger.info("Getting previously submitted sequences")
while True:
logger.info("Getting previously submitted sequences")

response = make_request(HTTPMethod.GET, url, config, params=params)
response = make_request(HTTPMethod.GET, url, config, params=params)
expected_record_count = int(response.headers["x-total-records"])

entries: list[dict[str, Any]] = []
try:
entries = list(jsonlines.Reader(response.iter_lines()).iter())
except jsonlines.Error as err:
response_summary = response.text
max_error_length = 100
if len(response_summary) > max_error_length:
response_summary = response_summary[:50] + "\n[..]\n" + response_summary[-50:]
logger.error(f"Error decoding JSON from /get-original-metadata: {response_summary}")
raise ValueError from err
entries: list[dict[str, Any]] = []
try:
entries = list(jsonlines.Reader(response.iter_lines()).iter())
except jsonlines.Error as err:
response_summary = response.text
max_error_length = 100
if len(response_summary) > max_error_length:
response_summary = response_summary[:50] + "\n[..]\n" + response_summary[-50:]
logger.error(f"Error decoding JSON from /get-original-metadata: {response_summary}")
raise ValueError from err

if len(entries) == expected_record_count:
f"Got {len(entries)} records as expected"
break
logger.error(f"Got incomplete original metadata stream: expected {len(entries)}"
f"records but got {expected_record_count}. Retrying after 60 seconds.")
sleep(60)

# Initialize the dictionary to store results
submitted_dict: dict[str, dict[str, str | list]] = {}
Expand Down
Loading