Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Improve nutrition extraction #1484

Merged
merged 2 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions robotoff/prediction/nutrition_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,9 @@ def postprocess_aggregated_entities(
return postprocessed_entities


SERVING_SIZE_MISSING_G = re.compile(r"([0-9]+[,.]?[0-9]*)\s*9")


def postprocess_aggregated_entities_single(entity: JSONType) -> JSONType:
"""Postprocess a single aggregated entity and return an entity with the extracted
information. This is the first step in the postprocessing of aggregated entities.
Expand Down Expand Up @@ -466,6 +469,11 @@ def postprocess_aggregated_entities_single(entity: JSONType) -> JSONType:

if entity_label == "serving_size":
value = words_str
# Sometimes the unit 'g' in the `serving_size` is detected as a '9'.
# In such cases, we replace the '9' with 'g'.
match = SERVING_SIZE_MISSING_G.match(value)
if match:
value = f"{match.group(1)} g"
elif words_str in ("trace", "traces"):
value = "traces"
else:
Expand Down Expand Up @@ -549,13 +557,15 @@ def match_nutrient_value(
for target in (
"proteins",
"sugars",
"added-sugars",
"carbohydrates",
"fat",
"saturated-fat",
"fiber",
"salt",
"trans-fat",
# we use "_" here as separator as '-' is only used in
# Product Opener, the label names are all separated by '_'
"saturated_fat",
"added_sugars",
"trans_fat",
)
)
and value.endswith("9")
Expand Down
53 changes: 43 additions & 10 deletions robotoff/workers/tasks/import_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,10 @@ def rerun_import_all_images(
where_clauses.append(ImageModel.server_type == server_type.name)
query = (
ImageModel.select(
ImageModel.barcode, ImageModel.image_id, ImageModel.server_type
ImageModel.id,
ImageModel.barcode,
ImageModel.image_id,
ImageModel.server_type,
)
.where(*where_clauses)
.order_by(ImageModel.uploaded_at.desc())
Expand All @@ -104,18 +107,16 @@ def rerun_import_all_images(
if return_count:
return query.count()

for barcode, image_id, server_type_str in query:
for image_model_id, barcode, image_id, server_type_str in query:
if not isinstance(barcode, str) and not barcode.isdigit():
raise ValueError("Invalid barcode: %s" % barcode)

product_id = ProductIdentifier(barcode, ServerType[server_type_str])
image_url = generate_image_url(product_id, image_id)
ocr_url = generate_json_ocr_url(product_id, image_id)
enqueue_job(
run_import_image_job,
get_high_queue(product_id),
job_kwargs={"result_ttl": 0},
run_import_image(
product_id=product_id,
image_model_id=image_model_id,
image_url=image_url,
ocr_url=ocr_url,
flags=flags,
Expand Down Expand Up @@ -144,16 +145,16 @@ def run_import_image_job(
What tasks are performed can be controlled using the `flags` parameter. By
default, all tasks are performed. A new rq job is enqueued for each task.

Before running the tasks, the image is downloaded and stored in the Robotoff
DB.

:param product_id: the product identifier
:param image_url: the URL of the image to import
:param ocr_url: the URL of the OCR JSON file
:param flags: the list of flags to run, defaults to None (all)
"""
logger.info("Running `import_image` for %s, image %s", product_id, image_url)

if flags is None:
flags = [flag for flag in ImportImageFlag]

source_image = get_source_from_url(image_url)
product = get_product_store(product_id.server_type)[product_id]
if product is None and settings.ENABLE_MONGODB_ACCESS:
Expand Down Expand Up @@ -185,13 +186,45 @@ def run_import_image_job(
ImageModel.bulk_update([image_model], fields=["deleted"])
return

run_import_image(
product_id=product_id,
image_model_id=image_model.id,
image_url=image_url,
ocr_url=ocr_url,
flags=flags,
)


def run_import_image(
product_id: ProductIdentifier,
image_model_id: int,
image_url: str,
ocr_url: str,
flags: list[ImportImageFlag] | None = None,
) -> None:
"""Launch all extraction tasks on an image.

We assume that the image exists in the Robotoff DB.

What tasks are performed can be controlled using the `flags` parameter. By
default, all tasks are performed. A new rq job is enqueued for each task.

:param product_id: the product identifier
:param image_model_id: the DB ID of the image
:param image_url: the URL of the image to import
:param ocr_url: the URL of the OCR JSON file
:param flags: the list of flags to run, defaults to None (all)
"""
if flags is None:
flags = [flag for flag in ImportImageFlag]

if ImportImageFlag.add_image_fingerprint in flags:
# Compute image fingerprint, this job is low priority
enqueue_job(
add_image_fingerprint_job,
low_queue,
job_kwargs={"result_ttl": 0},
image_model_id=image_model.id,
image_model_id=image_model_id,
)

if product_id.server_type.is_food():
Expand Down
57 changes: 57 additions & 0 deletions tests/unit/prediction/test_nutrition_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
aggregate_entities,
match_nutrient_value,
postprocess_aggregated_entities,
postprocess_aggregated_entities_single,
)


Expand Down Expand Up @@ -392,8 +393,64 @@ def test_aggregate_entities_multiple_entities(self):
("25.9", "iron_100g", ("25.9", None, True)),
("O g", "salt_100g", ("0", "g", True)),
("O", "salt_100g", ("0", None, True)),
("0,19", "saturated_fat_100g", ("0.1", "g", True)),
],
)
def test_match_nutrient_value(words_str: str, entity_label: str, expected_output):

assert match_nutrient_value(words_str, entity_label) == expected_output


@pytest.mark.parametrize(
    "aggregated_entity,expected_output",
    [
        # Nutrient value case: the raw word "0,19" has its trailing '9'
        # interpreted as a misread 'g' unit, so the extracted value is
        # "0.1" with unit "g"; the entity label is also normalized
        # ("SATURATED_FAT_100G" -> "saturated-fat_100g").
        (
            {
                "end": 90,
                "score": 0.9985358715057373,
                "start": 89,
                "words": ["0,19\n"],
                "entity": "SATURATED_FAT_100G",
                "char_end": 459,
                "char_start": 454,
            },
            {
                "char_end": 459,
                "char_start": 454,
                "end": 90,
                "entity": "saturated-fat_100g",
                "score": 0.9985358715057373,
                "start": 89,
                "text": "0,19",
                "unit": "g",
                "valid": True,
                "value": "0.1",
            },
        ),
        # Serving-size case: a trailing '9' after a number ("42.5 9") is
        # replaced by 'g' in the extracted value ("42.5 g"); the unit field
        # itself stays None for serving_size entities.
        (
            {
                "end": 92,
                "score": 0.9985358715057373,
                "start": 90,
                "words": ["42.5 9"],
                "entity": "SERVING_SIZE",
                "char_end": 460,
                "char_start": 454,
            },
            {
                "char_end": 460,
                "char_start": 454,
                "end": 92,
                "entity": "serving_size",
                "score": 0.9985358715057373,
                "start": 90,
                "text": "42.5 9",
                "unit": None,
                "valid": True,
                "value": "42.5 g",
            },
        ),
    ],
)
def test_postprocess_aggregated_entities_single(aggregated_entity, expected_output):
    """Check that a single aggregated entity is postprocessed into the
    expected entity dict: normalized entity label, extracted text/value/unit,
    and validity flag."""
    assert postprocess_aggregated_entities_single(aggregated_entity) == expected_output
Loading