From c9b89a9ea134744e56e06d47b81420763336b992 Mon Sep 17 00:00:00 2001 From: Luca Actis Grosso Date: Mon, 22 Nov 2021 16:14:29 +0100 Subject: [PATCH 1/7] add predict_kwargs in ObjectDetectionModel in order to filter the prediction using custom treshold and manage other parameters --- flash/core/integrations/icevision/adapter.py | 9 ++++++--- flash/image/detection/model.py | 8 ++++++++ tests/image/detection/test_model.py | 17 +++++++++++++++++ 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/flash/core/integrations/icevision/adapter.py b/flash/core/integrations/icevision/adapter.py index e723bc2cd5..37ee729da3 100644 --- a/flash/core/integrations/icevision/adapter.py +++ b/flash/core/integrations/icevision/adapter.py @@ -46,13 +46,14 @@ class IceVisionAdapter(Adapter): required_extras: str = "image" - def __init__(self, model_type, model, icevision_adapter, backbone): + def __init__(self, model_type, model, icevision_adapter, backbone, predict_kwargs): super().__init__() self.model_type = model_type self.model = model self.icevision_adapter = icevision_adapter self.backbone = backbone + self.predict_kwargs = predict_kwargs @classmethod @catch_url_error @@ -62,6 +63,7 @@ def from_task( num_classes: int, backbone: str, head: str, + predict_kwargs: Dict, pretrained: bool = True, metrics: Optional["IceVisionMetric"] = None, image_size: Optional = None, @@ -77,7 +79,7 @@ def from_task( **kwargs, ) icevision_adapter = icevision_adapter(model=model, metrics=metrics) - return cls(model_type, model, icevision_adapter, backbone) + return cls(model_type, model, icevision_adapter, backbone, predict_kwargs) @staticmethod def _collate_fn(collate_fn, samples, metadata: Optional[List[Dict[str, Any]]] = None): @@ -198,7 +200,8 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A return batch def forward(self, batch: Any) -> Any: - return from_icevision_predictions(self.model_type.predict_from_dl(self.model, [batch], show_pbar=False)) + return from_icevision_predictions(self.model_type.predict_from_dl(self.model, [batch], show_pbar=False, + **self.predict_kwargs)) def training_epoch_end(self, outputs) -> None: return self.icevision_adapter.training_epoch_end(outputs) diff --git a/flash/image/detection/model.py b/flash/image/detection/model.py index 0a080af611..2d2ee4d306 100644 --- a/flash/image/detection/model.py +++ b/flash/image/detection/model.py @@ -33,6 +33,7 @@ class ObjectDetector(AdapterTask): lr_scheduler: The LR scheduler to use during training. learning_rate: The learning rate to use for training. output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs. + predict_kwargs: dictionary containing parameters that will be used during the prediction phase. kwargs: additional kwargs nessesary for initializing the backbone task """ @@ -50,10 +51,12 @@ def __init__( lr_scheduler: LR_SCHEDULER_TYPE = None, learning_rate: float = 5e-3, output: OUTPUT_TYPE = None, + predict_kwargs: Dict = None, **kwargs: Any, ): self.save_hyperparameters() + predict_kwargs = predict_kwargs if predict_kwargs else {} metadata = self.heads.get(head, with_metadata=True) adapter = metadata["metadata"]["adapter"].from_task( self, @@ -61,6 +64,7 @@ def __init__( backbone=backbone, head=head, pretrained=pretrained, + predict_kwargs=predict_kwargs, **kwargs, ) @@ -75,3 +79,7 @@ def __init__( def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None: """This function is used only for debugging usage with CI.""" # todo + + def set_predict_kwargs(self, value): + """This function is used to update the kwargs used for the prediction step""" + self.adapter.predict_kwargs = value diff --git a/tests/image/detection/test_model.py b/tests/image/detection/test_model.py index 5c4997b151..a73cdc059c 100644 --- a/tests/image/detection/test_model.py +++ b/tests/image/detection/test_model.py @@ -134,3 +134,20 @@ def test_cli(): main() except SystemExit: pass + + +@pytest.mark.parametrize("head", ["retinanet"]) +@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _ICEVISION_AVAILABLE, reason="IceVision is not installed for testing") +def test_predict(tmpdir, head): + model = ObjectDetector(num_classes=2, head=head, pretrained=False) + ds = DummyDetectionDataset((128, 128, 3), 1, 2, 10) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + dl = model.process_train_dataset(ds, trainer, 2, 0, False, None) + trainer.fit(model, dl) + dl = model.process_predict_dataset(ds, batch_size=2) + predictions = trainer.predict(model, dl) + assert len(predictions[0][0]["bboxes"]) > 0 + model.set_predict_kwargs({"detection_threshold": 2}) + predictions = trainer.predict(model, dl) + assert len(predictions[0][0]["bboxes"]) == 0 From c07ee055970c9036c897db4e292a6d01347ec50a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 22 Nov 2021 15:21:43 +0000 Subject: [PATCH 2/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- flash/core/integrations/icevision/adapter.py | 5 +++-- flash/image/detection/model.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/flash/core/integrations/icevision/adapter.py b/flash/core/integrations/icevision/adapter.py index 37ee729da3..7d155dd5b9 100644 --- a/flash/core/integrations/icevision/adapter.py +++ b/flash/core/integrations/icevision/adapter.py @@ -200,8 +200,9 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A return batch def forward(self, batch: Any) -> Any: - return from_icevision_predictions(self.model_type.predict_from_dl(self.model, [batch], show_pbar=False, - **self.predict_kwargs)) + return from_icevision_predictions( + self.model_type.predict_from_dl(self.model, [batch], show_pbar=False, **self.predict_kwargs) + ) def training_epoch_end(self, outputs) -> None: return self.icevision_adapter.training_epoch_end(outputs) diff --git a/flash/image/detection/model.py b/flash/image/detection/model.py index 2d2ee4d306..277fadb169 100644 --- a/flash/image/detection/model.py +++ b/flash/image/detection/model.py @@ -81,5 +81,5 @@ def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None: # todo def set_predict_kwargs(self, value): - """This function is used to update the kwargs used for the prediction step""" + """This function is used to update the kwargs used for the prediction step.""" self.adapter.predict_kwargs = value From e02c77320b1a4faa710a7b14f574220ed6b217e6 Mon Sep 17 00:00:00 2001 From: Luca Actis Grosso Date: Mon, 22 Nov 2021 18:32:24 +0100 Subject: [PATCH 3/7] add predict_kwargs in InstanceSegmentation and in KeypointDetector --- flash/image/instance_segmentation/model.py | 8 ++++++++ flash/image/keypoint_detection/model.py | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/flash/image/instance_segmentation/model.py b/flash/image/instance_segmentation/model.py index 50c1936b9e..ba10e8cd8f 100644 --- a/flash/image/instance_segmentation/model.py +++ b/flash/image/instance_segmentation/model.py @@ -40,6 +40,7 @@ class InstanceSegmentation(AdapterTask): lr_scheduler: The LR scheduler to use during training. learning_rate: The learning rate to use for training. output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs. + predict_kwargs: dictionary containing parameters that will be used during the prediction phase. **kwargs: additional kwargs used for initializing the task """ @@ -57,10 +58,12 @@ def __init__( lr_scheduler: LR_SCHEDULER_TYPE = None, learning_rate: float = 5e-4, output: OUTPUT_TYPE = None, + predict_kwargs: Dict = None, **kwargs: Any, ): self.save_hyperparameters() + predict_kwargs = predict_kwargs if predict_kwargs else {} metadata = self.heads.get(head, with_metadata=True) adapter = metadata["metadata"]["adapter"].from_task( self, @@ -68,6 +71,7 @@ def __init__( backbone=backbone, head=head, pretrained=pretrained, + predict_kwargs=predict_kwargs, **kwargs, ) @@ -96,3 +100,7 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: input_transform=InstanceSegmentationInputTransform(), output_transform=InstanceSegmentationOutputTransform(), ) + + def set_predict_kwargs(self, value): + """This function is used to update the kwargs used for the prediction step""" + self.adapter.predict_kwargs = value diff --git a/flash/image/keypoint_detection/model.py b/flash/image/keypoint_detection/model.py index 3b404d8235..55f02af209 100644 --- a/flash/image/keypoint_detection/model.py +++ b/flash/image/keypoint_detection/model.py @@ -34,6 +34,7 @@ class KeypointDetector(AdapterTask): lr_scheduler: The LR scheduler to use during training. learning_rate: The learning rate to use for training. output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs. + predict_kwargs: dictionary containing parameters that will be used during the prediction phase. **kwargs: additional kwargs used for initializing the task """ @@ -52,10 +53,12 @@ def __init__( lr_scheduler: LR_SCHEDULER_TYPE = None, learning_rate: float = 5e-4, output: OUTPUT_TYPE = None, + predict_kwargs: Dict = None, **kwargs: Any, ): self.save_hyperparameters() + predict_kwargs = predict_kwargs if predict_kwargs else {} metadata = self.heads.get(head, with_metadata=True) adapter = metadata["metadata"]["adapter"].from_task( self, @@ -64,6 +67,7 @@ def __init__( backbone=backbone, head=head, pretrained=pretrained, + predict_kwargs=predict_kwargs, **kwargs, ) @@ -78,3 +82,7 @@ def __init__( def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None: """This function is used only for debugging usage with CI.""" # todo + + def set_predict_kwargs(self, value): + """This function is used to update the kwargs used for the prediction step""" + self.adapter.predict_kwargs = value From 91fa8cb8a3e777a93e59d865cab729ee6e4ac8f6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 22 Nov 2021 17:33:29 +0000 Subject: [PATCH 4/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- flash/image/instance_segmentation/model.py | 2 +- flash/image/keypoint_detection/model.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flash/image/instance_segmentation/model.py b/flash/image/instance_segmentation/model.py index ba10e8cd8f..2331e2f8fc 100644 --- a/flash/image/instance_segmentation/model.py +++ b/flash/image/instance_segmentation/model.py @@ -102,5 +102,5 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: ) def set_predict_kwargs(self, value): - """This function is used to update the kwargs used for the prediction step""" + """This function is used to update the kwargs used for the prediction step.""" self.adapter.predict_kwargs = value diff --git a/flash/image/keypoint_detection/model.py b/flash/image/keypoint_detection/model.py index 55f02af209..921c170a02 100644 --- a/flash/image/keypoint_detection/model.py +++ b/flash/image/keypoint_detection/model.py @@ -84,5 +84,5 @@ def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None: # todo def set_predict_kwargs(self, value): - """This function is used to update the kwargs used for the prediction step""" + """This function is used to update the kwargs used for the prediction step.""" self.adapter.predict_kwargs = value From 00ebdd2bb075885a0660518855f2cc3254afad30 Mon Sep 17 00:00:00 2001 From: Luca Actis Grosso Date: Tue, 23 Nov 2021 17:17:24 +0100 Subject: [PATCH 5/7] update predict_kwargs using python property for IceVision Integration and update CHANGELOG.md --- CHANGELOG.md | 2 ++ flash/image/detection/model.py | 11 ++++++++--- flash/image/instance_segmentation/model.py | 11 ++++++++--- flash/image/keypoint_detection/model.py | 11 ++++++++--- tests/image/detection/test_model.py | 2 +- 5 files changed, 27 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d26cdcd97c..4934a0a2ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added predict_kwargs in ObjectDetector, InstanceSegmentation, KeypointDetector ([#990](https://github.com/PyTorchLightning/lightning-flash/pull/990)) + - Added backbones for `GraphClassifier` ([#592](https://github.com/PyTorchLightning/lightning-flash/pull/592)) - Added `GraphEmbedder` task ([#592](https://github.com/PyTorchLightning/lightning-flash/pull/592)) diff --git a/flash/image/detection/model.py b/flash/image/detection/model.py index 277fadb169..94905f81e5 100644 --- a/flash/image/detection/model.py +++ b/flash/image/detection/model.py @@ -80,6 +80,11 @@ def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None: """This function is used only for debugging usage with CI.""" # todo - def set_predict_kwargs(self, value): - """This function is used to update the kwargs used for the prediction step.""" - self.adapter.predict_kwargs = value + @property + def predict_kwargs(self) -> Dict[str, Any]: + """The kwargs used for the prediction step.""" + return self.adapter.predict_kwargs + + @predict_kwargs.setter + def predict_kwargs(self, predict_kwargs: Dict[str, Any]): + self.adapter.predict_kwargs = predict_kwargs diff --git a/flash/image/instance_segmentation/model.py b/flash/image/instance_segmentation/model.py index 2331e2f8fc..eb0e257653 100644 --- a/flash/image/instance_segmentation/model.py +++ b/flash/image/instance_segmentation/model.py @@ -101,6 +101,11 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: output_transform=InstanceSegmentationOutputTransform(), ) - def set_predict_kwargs(self, value): - """This function is used to update the kwargs used for the prediction step.""" - self.adapter.predict_kwargs = value + @property + def predict_kwargs(self) -> Dict[str, Any]: + """The kwargs used for the prediction step.""" + return self.adapter.predict_kwargs + + @predict_kwargs.setter + def predict_kwargs(self, predict_kwargs: Dict[str, Any]): + self.adapter.predict_kwargs = predict_kwargs diff --git a/flash/image/keypoint_detection/model.py b/flash/image/keypoint_detection/model.py index 921c170a02..1993ee1ac9 100644 --- a/flash/image/keypoint_detection/model.py +++ b/flash/image/keypoint_detection/model.py @@ -83,6 +83,11 @@ def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None: """This function is used only for debugging usage with CI.""" # todo - def set_predict_kwargs(self, value): - """This function is used to update the kwargs used for the prediction step.""" - self.adapter.predict_kwargs = value + @property + def predict_kwargs(self) -> Dict[str, Any]: + """The kwargs used for the prediction step.""" + return self.adapter.predict_kwargs + + @predict_kwargs.setter + def predict_kwargs(self, predict_kwargs: Dict[str, Any]): + self.adapter.predict_kwargs = predict_kwargs diff --git a/tests/image/detection/test_model.py b/tests/image/detection/test_model.py index a73cdc059c..903948a6de 100644 --- a/tests/image/detection/test_model.py +++ b/tests/image/detection/test_model.py @@ -148,6 +148,6 @@ def test_predict(tmpdir, head): dl = model.process_predict_dataset(ds, batch_size=2) predictions = trainer.predict(model, dl) assert len(predictions[0][0]["bboxes"]) > 0 - model.set_predict_kwargs({"detection_threshold": 2}) + model.predict_kwargs = {"detection_threshold": 2} predictions = trainer.predict(model, dl) assert len(predictions[0][0]["bboxes"]) == 0 From 29824046adb1068968017eff1b1883febb33df7e Mon Sep 17 00:00:00 2001 From: Luca Actis Grosso Date: Tue, 23 Nov 2021 17:18:03 +0100 Subject: [PATCH 6/7] update CHANGELOG.md --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4934a0a2ce..b750483eff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- Added predict_kwargs in ObjectDetector, InstanceSegmentation, KeypointDetector ([#990](https://github.com/PyTorchLightning/lightning-flash/pull/990)) +- Added predict_kwargs in `ObjectDetector`, `InstanceSegmentation`, `KeypointDetector` ([#990](https://github.com/PyTorchLightning/lightning-flash/pull/990)) - Added backbones for `GraphClassifier` ([#592](https://github.com/PyTorchLightning/lightning-flash/pull/592)) From 98563c41eac65d436f03b5a176d91d14f5915321 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Nov 2021 16:19:02 +0000 Subject: [PATCH 7/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b750483eff..5f7860d7fa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added - Added predict_kwargs in `ObjectDetector`, `InstanceSegmentation`, `KeypointDetector` ([#990](https://github.com/PyTorchLightning/lightning-flash/pull/990)) - + - Added backbones for `GraphClassifier` ([#592](https://github.com/PyTorchLightning/lightning-flash/pull/592)) - Added `GraphEmbedder` task ([#592](https://github.com/PyTorchLightning/lightning-flash/pull/592))