Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add multi-label support for label-studio integration #4725

Merged
merged 1 commit into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions fiftyone/utils/labelstudio.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def supported_media_types(self):
def supported_label_types(self):
return [
"classification",
"classifications",
"detection",
"detections",
"instance",
Expand Down Expand Up @@ -416,7 +417,14 @@ def _import_annotations(self, tasks, task_map):
# add to dict
sample_id = task_map[t["id"]]
# we save and pass both id and the name of the label field
results[sample_id] = {l.id: (ln, l) for (ln, l) in labels}
results[sample_id] = {}
for ln, l in labels:
if isinstance(l, fol.Classifications):
for classification_obj in l.classifications:
label_id = classification_obj.id
results[sample_id][label_id] = (ln, classification_obj)
else:
results[sample_id][l.id] = (ln, l)

return results

Expand Down Expand Up @@ -891,7 +899,9 @@ def _from_choices(result):
return fol.Classification(label=label_values[0])

# multi-label classification
return [fol.Classification(label=l) for l in label_values]
return fol.Classifications(
classifications=[fol.Classification(label=l) for l in label_values]
)


def _from_rectanglelabels(result):
Expand Down
13 changes: 8 additions & 5 deletions tests/intensive/labelstudio_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,10 +265,12 @@ def label_mappings():
"from_name": "choice",
"type": "choices",
},
"fiftyone": [
fo.Classification(label="Airbus"),
fo.Classification(label="Boeing"),
],
"fiftyone": fo.Classifications(
classifications=[
fo.Classification(label="Airbus"),
fo.Classification(label="Boeing"),
]
),
},
{
"labelstudio": {
Expand Down Expand Up @@ -477,7 +479,6 @@ def test_import_labels(label_mappings):
for case in label_mappings:
label = fouls.import_label_studio_annotation(case["labelstudio"])[1]
expected = case["fiftyone"]

if isinstance(expected, (list, tuple)):
for pair in zip(label, expected):
_assert_labels_equal(*pair)
Expand Down Expand Up @@ -673,6 +674,8 @@ def _assert_labels_equal(converted, expected):
_assert_labels_equal(*pair)
elif expected._cls == "Regression":
assert expected.value == converted.value
elif expected._cls == "Classifications":
assert all(cls_obj.label for cls_obj in expected.classifications)
else:
raise NotImplementedError()

Expand Down
Loading