Skip to content

Commit

Permalink
Fix for albumentations transform
Browse files Browse the repository at this point in the history
  • Loading branch information
JBWilkie committed Sep 13, 2024
1 parent 943684c commit 0cba8ac
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 16 deletions.
12 changes: 8 additions & 4 deletions darwin/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,13 +392,17 @@ def __init__(
super().__init__("Complex polygons not yet supported for dataloop import")


class ExportException(DarwinException): ...
class ExportException(DarwinException):
...


class ExportException_CouldNotAssembleOutputPath(ExportException): ...
class ExportException_CouldNotAssembleOutputPath(ExportException):
...


class ExportException_CouldNotBuildOutput(ExportException): ...
class ExportException_CouldNotBuildOutput(ExportException):
...


class ExportException_CouldNotWriteFile(ExportException): ...
class ExportException_CouldNotWriteFile(ExportException):
...
3 changes: 2 additions & 1 deletion darwin/future/core/types/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@


class Implements_str(Protocol):
def __str__(self) -> str: ...
def __str__(self) -> str:
...


Stringable = Union[str, Implements_str]
Expand Down
2 changes: 0 additions & 2 deletions darwin/future/data_objects/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,6 @@ class FullProperty(DefaultDarwin):
options: Optional[List[PropertyValue]] = None
granularity: PropertyGranularity = PropertyGranularity("section")

# model_config = ConfigDict(use_enum_values=True)

def to_create_endpoint(
self,
) -> dict:
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
class InvalidValueForTest: ...
class InvalidValueForTest:
...
12 changes: 6 additions & 6 deletions darwin/importer/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,9 +283,9 @@ def _get_team_properties_annotation_lookup(client, team_slug):
team_properties = client.get_team_properties(team_slug)

# (property-name, annotation_class_id): FullProperty object
team_properties_annotation_lookup: Dict[Tuple[str, Optional[int]], FullProperty] = (
{}
)
team_properties_annotation_lookup: Dict[
Tuple[str, Optional[int]], FullProperty
] = {}
for prop in team_properties:
team_properties_annotation_lookup[(prop.name, prop.annotation_class_id)] = prop

Expand Down Expand Up @@ -1407,9 +1407,9 @@ def _import_annotations(
# Insert the default slot name if not available in the import source
annotation = _handle_slot_names(annotation, dataset.version, default_slot_name)

annotation_class_ids_map[(annotation_class.name, annotation_type)] = (
annotation_class_id
)
annotation_class_ids_map[
(annotation_class.name, annotation_type)
] = annotation_class_id
serial_obj = {
"annotation_class_id": annotation_class_id,
"data": data,
Expand Down
7 changes: 5 additions & 2 deletions darwin/torch/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,11 @@ def _pre_process(self, image: np.ndarray, annotation: dict) -> dict:
if (
masks is not None and masks.numel() > 0
): # using numel() to check if tensor is non-empty
print("WE GOT MASKS")
albumentation_dict["masks"] = masks.numpy()
if isinstance(masks, torch.Tensor):
masks = masks.numpy()
if masks.ndim == 3: # Ensure masks is a list of numpy arrays
masks = [masks[i] for i in range(masks.shape[0])]
albumentation_dict["masks"] = masks

return albumentation_dict

Expand Down

0 comments on commit 0cba8ac

Please sign in to comment.