Skip to content

Commit

Permalink
Merge pull request #52 from Project-AgML/dev
Browse files Browse the repository at this point in the history
v0.6.0 - Auto-training Module
  • Loading branch information
amogh7joshi authored Jan 4, 2024
2 parents 73e646f + 1d8892a commit 9a9726e
Show file tree
Hide file tree
Showing 23 changed files with 1,822 additions and 100 deletions.
31 changes: 28 additions & 3 deletions agml/backend/tftorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,12 @@ def _convert_image_to_torch(image):
if image.shape[0] == 1 and image.shape[-1] <= 3 and image.ndim == 4:
return torch.from_numpy(image).permute(0, 3, 1, 2).float()
return image
if image.shape[0] > image.shape[-1]:
return torch.from_numpy(image).permute(2, 0, 1).float()
return torch.from_numpy(image)
if image.ndim == 3:
if image.shape[0] > image.shape[-1]:
return torch.from_numpy(image).permute(2, 0, 1).float()
elif image.ndim == 2:
return torch.from_numpy(image)
return torch.from_numpy(image).float()


def _postprocess_torch_annotation(image):
Expand Down Expand Up @@ -293,3 +296,25 @@ def _add_dataset_to_mro(inst, mode):
if torch_data.Dataset not in inst.__class__.__bases__:
inst.__class__.__bases__ += (torch_data.Dataset,)


def collate_fn_basic(batch):
    """Collate `(image, annotation)` samples into a single batch.

    The first element of every sample is stacked along a new leading
    dimension. The second elements are passed together through `zip`, so
    the returned tuple holds one inner tuple per position produced when
    the annotations are iterated in lockstep.

    Returns:
        A `(images, coco)` pair: the stacked tensor and the zipped tuple.
    """
    per_sample_images = [sample[0] for sample in batch]
    stacked_images = torch.stack(per_sample_images, dim=0)

    per_sample_annotations = [sample[1] for sample in batch]
    zipped_annotations = tuple(zip(*per_sample_annotations))

    return stacked_images, zipped_annotations


def collate_fn_efficientdet(batch):
    """Collate `(image, target)` samples into an EfficientDet-style batch.

    Images are stacked into one float tensor. From every target dictionary
    the `"bboxes"` and `"labels"` entries are cast to float and kept as
    per-sample lists, while the `"img_size"` and `"img_scale"` entries are
    stacked across the batch and cast to float.

    Returns:
        A `(images, annotations, targets)` triple, where `annotations` is a
        dictionary with the keys `"bbox"`, `"cls"`, `"img_size"` and
        `"img_scale"`, and `targets` is the tuple of original target dicts.
    """
    images, targets = tuple(zip(*batch))
    images = torch.stack(images).float()

    def _stack_field(key):
        # Gather one entry from every target and stack it across the batch.
        return torch.stack([target[key] for target in targets]).float()

    annotations = {
        "bbox": [target["bboxes"].float() for target in targets],
        "cls": [target["labels"].float() for target in targets],
        "img_size": _stack_field("img_size"),
        "img_scale": _stack_field("img_scale"),
    }
    return images, annotations, targets
16 changes: 10 additions & 6 deletions agml/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,7 @@ def val_data(self):
return self._val_data
self._val_data = self._generate_split_loader(
self._val_content, split = 'val')
self._val_data.eval()
return self._val_data

@property
Expand All @@ -667,6 +668,7 @@ def test_data(self):
return self._test_data
self._test_data = self._generate_split_loader(
self._test_content, split = 'test')
self._test_data.eval()
return self._test_data

def eval(self) -> "AgMLDataLoader":
Expand Down Expand Up @@ -1714,13 +1716,15 @@ def export_torch(self, **loader_kwargs):
# The `collate_fn` for object detection is different because
# the COCO JSON dictionaries each have different formats. So,
# we need to replace it with a custom function.
collate_fn = loader_kwargs.pop('collate_fn')
collate_fn = loader_kwargs.pop('collate_fn', None)
if obj.task == 'object_detection' and collate_fn is None:
def collate_fn(batch):
images = torch.stack(
[i[0] for i in batch], dim = 0)
coco = tuple(zip(*[i[1] for i in batch]))
return images, coco
if any('efficientdet' in i.__class__.__name__.lower() for i in
self._manager._transform_manager.get_transform_states()['dual_transform']):
from agml.backend.tftorch import collate_fn_efficientdet
collate_fn = collate_fn_efficientdet
else:
from agml.backend.tftorch import collate_fn_basic
collate_fn = collate_fn_basic

# Return the DataLoader with a copy of this AgMLDataLoader, so
# that changes to this will not affect the returned loader.
Expand Down
7 changes: 7 additions & 0 deletions agml/data/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,13 @@ def assign_resize(self, image_size, method):
image_size = 'default'
self._resize_manager.assign(image_size, method)

def _warn_training_resize(self):
    """Log a warning when the resize manager is still on the `'default'` type.

    Nothing is logged for any other resize type.
    """
    if self._resize_manager._resize_type != 'default':
        return
    message = (
        f"Warning: you have not applied any resizing method to the "
        f"dataset `{self._dataset_name}`. This may cause errors during "
        f"training if the image aspect ratios are not consistent."
    )
    log(message)

def push_transforms(self, **transform_dict):
"""Pushes a transformation to the data transform pipeline."""
# Check if any transforms are being reset and assign them as such.
Expand Down
20 changes: 8 additions & 12 deletions agml/data/managers/resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ def _tuple_euclidean(t1, t2):
return np.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)

def _method_resize(self, image, size):
return cv2.resize(image, size, interpolation = self._interpolation)
return cv2.resize(image, size, interpolation=self._interpolation)

def assign(self, kind, method = None):
def assign(self, kind, method=None):
"""Assigns the resize parameter (and does necessary calculations)."""
if kind == 'default':
self._resize_type = 'default'
Expand Down Expand Up @@ -157,14 +157,11 @@ def assign(self, kind, method = None):
def apply(self, contents):
"""Applies the resizing operation to the input data."""
if self._task in ['image_classification', 'image_regression']:
return self._resize_image_input(
contents, self._image_size)
return self._resize_image_input(contents, self._image_size)
elif self._task == 'object_detection':
return self._resize_image_and_coco(
contents, self._image_size)
return self._resize_image_and_coco(contents, self._image_size)
elif self._task == 'semantic_segmentation':
return self._resize_image_and_mask(
contents, self._image_size)
return self._resize_image_and_mask(contents, self._image_size)

def _inference_shape(self, info):
"""Attempts to inference a shape for the `auto` method.
Expand Down Expand Up @@ -231,7 +228,7 @@ def _random_inference_shape(self):
f"Attempting to randomly inference the dataset shape.")

image_path = os.path.join(self._dataset_root, 'images')
images = np.random.choice(os.listdir(image_path), size = 25)
images = np.random.choice(os.listdir(image_path), size=25)

# Get all of the shapes from the random sample of images.
shapes = []
Expand All @@ -242,7 +239,7 @@ def _random_inference_shape(self):

# Inference a valid shape from the shapes. We dispatch to the
# regular inferencing method once we have the shapes and counts.
unique_shapes, counts = np.unique(shapes, return_counts = True, axis = 0)
unique_shapes, counts = np.unique(shapes, return_counts=True, axis=0)
return self._inference_shape((unique_shapes, counts))

def _maybe_load_shape_info(self):
Expand Down Expand Up @@ -275,8 +272,7 @@ def _resize_image_input(self, contents, image_size):
return self._resize_single_image(contents, image_size)
if image_size is not None:
return {
k: self._method_resize(
i.astype(np.uint16), image_size).astype(np.int32)
k: self._method_resize(i.astype(np.uint16), image_size).astype(np.int32)
for k, i in image.items()}, label
return image, label

Expand Down
6 changes: 2 additions & 4 deletions agml/data/managers/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,16 +151,14 @@ def update_state(self, state):
self._state = TrainState.TORCH
if get_backend() == 'tf':
if user_changed_backend():
raise StrictBackendError(
change = 'torch', obj = t_(state))
raise StrictBackendError(change = 'torch', obj = t_(state))
set_backend('torch')
self._resize_manager.assign('train-auto')
elif t_(state) == TrainState.TF:
self._state = TrainState.TF
if get_backend() == 'torch':
if user_changed_backend():
raise StrictBackendError(
change = 'tf', obj = t_(state))
raise StrictBackendError(change = 'tf', obj = t_(state))
set_backend('tf')
self._resize_manager.assign('train-auto')

Expand Down
16 changes: 10 additions & 6 deletions agml/data/managers/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ def assign(self, kind, transform):
kind = 'dual_transform'
if self._task == 'semantic_segmentation':
kind = 'dual_transform'
elif 'agml.models' in transform.__module__:
from agml.models.preprocessing import EfficientDetPreprocessor
if isinstance(transform, EfficientDetPreprocessor):
kind = 'dual_transform'
except AttributeError:
# Some type of object that doesn't have `__module__`.
pass
Expand All @@ -154,9 +158,9 @@ def assign(self, kind, transform):
if t_(kind) == TransformKind.Transform:
transform = self._maybe_normalization_or_regular_transform(transform)
elif t_(kind) == TransformKind.TargetTransform:
if isinstance(transform, tuple): # a special convenience case
if isinstance(transform, tuple): # a special convenience case
if transform[0] == 'one_hot':
if transform[2] is not True: # removing the transform
if transform[2] is not True: # removing the transform
self._pop_transform(OneHotLabelTransform, kind)
return
transform = OneHotLabelTransform(transform[1])
Expand All @@ -168,9 +172,9 @@ def assign(self, kind, transform):
if t_(kind) == TransformKind.Transform:
transform = self._maybe_normalization_or_regular_transform(transform)
elif t_(kind) == TransformKind.TargetTransform:
if isinstance(transform, tuple): # a special convenience case
if isinstance(transform, tuple): # a special convenience case
if transform[0] == 'one_hot':
if transform[2] is not True: # removing the transform
if transform[2] is not True: # removing the transform
self._pop_transform(OneHotLabelTransform, kind)
return
transform = OneHotLabelTransform(transform[1])
Expand All @@ -180,9 +184,9 @@ def assign(self, kind, transform):
if t_(kind) == TransformKind.Transform:
transform = self._maybe_normalization_or_regular_transform(transform)
elif t_(kind) == TransformKind.TargetTransform:
if isinstance(transform, tuple): # a special convenience case
if isinstance(transform, tuple): # a special convenience case
if transform[0] == 'channel_basis':
if transform[2] is not True: # removing the transform
if transform[2] is not True: # removing the transform
self._pop_transform(MaskToChannelBasisTransform, kind)
transform = MaskToChannelBasisTransform(transform[1])
else:
Expand Down
2 changes: 2 additions & 0 deletions agml/data/multi_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,7 @@ def val_data(self):
return self._val_data
self._val_data = self._generate_split_loader(
self._loaders.get_attributes('val_data'), split = 'val')
self._val_data.eval()
return self._val_data

@property
Expand All @@ -640,6 +641,7 @@ def test_data(self):
return self._test_data
self._test_data = self._generate_split_loader(
self._loaders.get_attributes('test_data'), split = 'test')
self._test_data.eval()
return self._test_data

def eval(self):
Expand Down
31 changes: 31 additions & 0 deletions agml/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class AgMLModelBase(AgMLSerializable, LightningModule):
and image input preprocessing, as well as other stubs for common methods.
"""

_ml_task: str

def __init__(self):
self._benchmark = BenchmarkMetadata(None)
super(AgMLModelBase, self).__init__()
Expand Down Expand Up @@ -183,5 +185,34 @@ def evaluate(self, loader, **kwargs):
"""Evaluates the model on the given loader."""
raise NotImplementedError

@abc.abstractmethod
def _prepare_for_training(self, **kwargs):
    """Set the model up for training; concrete models must override this.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError()

def on_train_epoch_end(self):
    """Reset the tracked metrics once a training epoch finishes.

    Nothing is reset when `_ml_task` is `'object_detection'`. Otherwise
    every entry of `self._metrics` is unpacked as a pair and `reset()` is
    called on its second element.
    """
    if self._ml_task == 'object_detection':
        return
    for entry in self._metrics:
        _, tracked_metric = entry
        tracked_metric.reset()

def on_validation_epoch_end(self):
    """Reset the tracked metrics once a validation epoch finishes.

    Nothing is reset when `_ml_task` is `'object_detection'`. Otherwise
    every entry of `self._metrics` is unpacked as a pair and `reset()` is
    called on its second element.
    """
    if self._ml_task == 'object_detection':
        return
    for entry in self._metrics:
        _, tracked_metric = entry
        tracked_metric.reset()

def get_progress_bar_dict(self):
    """Return the parent's progress bar dictionary without its `'v_num'` key.

    Returns `None` when the parent class does not provide a
    `get_progress_bar_dict` method.
    """
    parent = super()
    if hasattr(parent, 'get_progress_bar_dict'):
        progress_items = parent.get_progress_bar_dict()
        # Drop the version-number entry if it is present.
        progress_items.pop('v_num', None)
        return progress_items
    return None

def get_metrics(self):
    """Return the parent's metrics dictionary without its `'v_num'` key.

    Returns `None` when the parent class does not provide a `get_metrics`
    method.
    """
    parent = super()
    if hasattr(parent, 'get_metrics'):
        metric_items = parent.get_metrics()
        # Drop the version-number entry if it is present.
        metric_items.pop('v_num', None)
        return metric_items
    return None



Loading

0 comments on commit 9a9726e

Please sign in to comment.