From 36377b7c7ac561fd14afca96e0462c875646895f Mon Sep 17 00:00:00 2001 From: dudeperf3ct Date: Fri, 4 Mar 2022 02:29:40 +0530 Subject: [PATCH] Fix normalizing video classification input (#1213) Co-authored-by: Ethan Harris --- CHANGELOG.md | 2 ++ flash/video/classification/input_transform.py | 7 ++++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f63aa56e74..97a91fc881 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed normalizing inputs to video classification ([#1213](https://github.com/PyTorchLightning/lightning-flash/pull/1213)) + - Fixed examples (question answering), where NLTK's `punkt` module needs to be downloaded first. ([#1215](https://github.com/PyTorchLightning/lightning-flash/pull/1215/files)) - Fixed a bug where DDP would not work with Flash tasks ([#1182](https://github.com/PyTorchLightning/lightning-flash/pull/1182)) diff --git a/flash/video/classification/input_transform.py b/flash/video/classification/input_transform.py index a441a2abb9..5626be5e8e 100644 --- a/flash/video/classification/input_transform.py +++ b/flash/video/classification/input_transform.py @@ -30,6 +30,10 @@ ClipSampler, LabeledVideoDataset, EncodedVideo, ApplyTransformToKey = None, None, None, None +def normalize(x: torch.Tensor) -> torch.Tensor: + return x / 255.0 + + @requires("video") @dataclass class VideoClassificationInputTransform(InputTransform): @@ -48,7 +52,8 @@ def per_sample_transform(self) -> Callable: per_sample_transform = [CenterCrop(self.image_size)] return ApplyToKeys( - "video", Compose([UniformTemporalSubsample(self.temporal_sub_sample)] + per_sample_transform) + "video", + Compose([UniformTemporalSubsample(self.temporal_sub_sample), normalize] + per_sample_transform), ) def per_batch_transform_on_device(self) -> Callable: