Skip to content

Commit

Permalink
FEAT: Add multi-label support for label-studio integration
Browse files Browse the repository at this point in the history
  • Loading branch information
tataganesh committed Aug 25, 2024
1 parent 4f076cc commit e61d137
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 7 deletions.
14 changes: 12 additions & 2 deletions fiftyone/utils/labelstudio.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def supported_media_types(self):
def supported_label_types(self):
return [
"classification",
"classifications",
"detection",
"detections",
"instance",
Expand Down Expand Up @@ -416,7 +417,14 @@ def _import_annotations(self, tasks, task_map):
# add to dict
sample_id = task_map[t["id"]]
# we save and pass both id and the name of the label field
results[sample_id] = {l.id: (ln, l) for (ln, l) in labels}
results[sample_id] = {}
for ln, l in labels:
if isinstance(l, fol.Classifications):
for classification_obj in l.classifications:
label_id = classification_obj.id
results[sample_id][label_id] = (ln, classification_obj)
else:
results[sample_id][l.id] = (ln, l)

return results

Expand Down Expand Up @@ -891,7 +899,9 @@ def _from_choices(result):
return fol.Classification(label=label_values[0])

# multi-label classification
return [fol.Classification(label=l) for l in label_values]
return fol.Classifications(
classifications=[fol.Classification(label=l) for l in label_values]
)


def _from_rectanglelabels(result):
Expand Down
13 changes: 8 additions & 5 deletions tests/intensive/labelstudio_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,10 +265,12 @@ def label_mappings():
"from_name": "choice",
"type": "choices",
},
"fiftyone": [
fo.Classification(label="Airbus"),
fo.Classification(label="Boeing"),
],
"fiftyone": fo.Classifications(
classifications=[
fo.Classification(label="Airbus"),
fo.Classification(label="Boeing"),
]
),
},
{
"labelstudio": {
Expand Down Expand Up @@ -477,7 +479,6 @@ def test_import_labels(label_mappings):
for case in label_mappings:
label = fouls.import_label_studio_annotation(case["labelstudio"])[1]
expected = case["fiftyone"]

if isinstance(expected, (list, tuple)):
for pair in zip(label, expected):
_assert_labels_equal(*pair)
Expand Down Expand Up @@ -673,6 +674,8 @@ def _assert_labels_equal(converted, expected):
_assert_labels_equal(*pair)
elif expected._cls == "Regression":
assert expected.value == converted.value
elif expected._cls == "Classifications":
assert all(cls_obj.label for cls_obj in expected.classifications)
else:
raise NotImplementedError()

Expand Down

0 comments on commit e61d137

Please sign in to comment.