Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Add CLIP backbones for text / image classification (#1458)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Oct 1, 2022
1 parent 63eeec1 commit eb11c35
Show file tree
Hide file tree
Showing 20 changed files with 405 additions and 54 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added support for CLIP backbones to the `TextClassifier` and `ImageClassifier` tasks ([#1458](https://github.com/Lightning-AI/lightning-flash/pull/1458))


### Changed

Expand Down
2 changes: 1 addition & 1 deletion docs/source/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ Here's an example of inference:
from flash.text import TextClassifier, TextClassificationData

# 1. Init the finetuned task from URL
model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.7.0/text_classification_model.pt")
model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.9.0/text_classification_model.pt")

# 2. Perform inference from list of sequences
trainer = Trainer()
Expand Down
14 changes: 12 additions & 2 deletions flash/core/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,11 +205,13 @@ def __init__(
name: str,
providers: Optional[Union[Provider, List[Provider]]] = None,
verbose: bool = False,
**metadata,
):
super().__init__(name, verbose=verbose)

self.getter = getter
self.providers = providers if providers is None or isinstance(providers, list) else [providers]
self.metadata = metadata

def __contains__(self, item):
"""Contains is always ``True`` for an ``ExternalRegistry`` as we can't know whether the getter will fail
Expand All @@ -228,7 +230,10 @@ def get(
fn = functools.partial(self.getter, key)
if self.providers is not None:
fn = print_provider_info(key, self.providers, fn)
return fn

if not with_metadata:
return fn
return {"fn": fn, "metadata": self.metadata}

def available_keys(self) -> List[str]:
"""Since we don't know the available keys, just give a generic message."""
Expand All @@ -242,7 +247,12 @@ class ConcatRegistry(FlashRegistry):

def __init__(self, *registries: FlashRegistry):
super().__init__(
",".join({registry.name for registry in registries}),
",".join(
{
registry.name
for registry in sorted(registries, key=lambda r: 1 if isinstance(r, ExternalRegistry) else 0)
}
),
verbose=any(registry._verbose for registry in registries),
)

Expand Down
32 changes: 31 additions & 1 deletion flash/core/utilities/embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Any, Optional

from pytorch_lightning import LightningModule
from torch.utils.data import DataLoader, Sampler

import flash
from flash.core.data.io.input import InputBase
from flash.core.data.io.input_transform import InputTransform
from flash.core.model import Task


Expand All @@ -33,6 +37,32 @@ def __init__(self, model: LightningModule, layer: str):
self._handle = None
self._out = None

def process_predict_dataset(
self,
dataset: InputBase,
batch_size: int,
num_workers: int = 0,
pin_memory: bool = False,
shuffle: bool = False,
drop_last: bool = False,
sampler: Optional[Sampler] = None,
persistent_workers: bool = False,
input_transform: Optional[InputTransform] = None,
trainer: Optional["flash.Trainer"] = None,
) -> DataLoader:
return self.model.process_predict_dataset(
dataset,
batch_size,
num_workers,
pin_memory,
shuffle,
drop_last,
sampler,
persistent_workers,
input_transform,
trainer,
)

def _make_hook(self):
def hook(_, __, output):
self._out = output
Expand Down
1 change: 1 addition & 0 deletions flash/core/utilities/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __str__(self):

_TIMM = Provider("rwightman/pytorch-image-models", "https://github.com/rwightman/pytorch-image-models")
_DINO = Provider("Facebook Research/dino", "https://github.com/facebookresearch/dino")
_CLIP = Provider("OpenAI/CLIP", "https://github.com/openai/CLIP")
_ICEVISION = Provider("airctic/IceVision", "https://github.com/airctic/icevision")
_TORCHVISION = Provider("PyTorch/torchvision", "https://github.com/pytorch/vision")
_ULTRALYTICS = Provider("Ultralytics/YOLOV5", "https://github.com/ultralytics/yolov5")
Expand Down
2 changes: 1 addition & 1 deletion flash/image/classification/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ def test_step(self, batch: Any, batch_idx: int) -> Any:

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
batch[DataKeys.PREDS] = Task.predict_step(
self._task, (batch[DataKeys.INPUT]), batch_idx, dataloader_idx=dataloader_idx
self._task, batch[DataKeys.INPUT], batch_idx, dataloader_idx=dataloader_idx
)
return batch

Expand Down
4 changes: 3 additions & 1 deletion flash/image/classification/backbones/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from flash.core.registry import FlashRegistry # noqa: F401
from flash.core.registry import FlashRegistry
from flash.image.classification.backbones.clip import register_clip_backbones # noqa: F401
from flash.image.classification.backbones.resnet import register_resnet_backbones # noqa: F401
from flash.image.classification.backbones.timm import register_timm_backbones # noqa: F401
from flash.image.classification.backbones.torchvision import ( # noqa: F401
Expand All @@ -12,6 +13,7 @@

register_resnet_backbones(IMAGE_CLASSIFIER_BACKBONES)
register_dino_backbones(IMAGE_CLASSIFIER_BACKBONES)
register_clip_backbones(IMAGE_CLASSIFIER_BACKBONES)

register_mobilenet_vgg_backbones(IMAGE_CLASSIFIER_BACKBONES)
register_resnext_model(IMAGE_CLASSIFIER_BACKBONES)
Expand Down
58 changes: 58 additions & 0 deletions flash/image/classification/backbones/clip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial

import torch
from torch import nn

from flash.core.registry import FlashRegistry
from flash.core.utilities.providers import _CLIP
from flash.core.utilities.url_error import catch_url_error

# Paper: Learning Transferable Visual Models From Natural Language Supervision
# https://arxiv.org/abs/2103.00020 from Alec Radford et. al. (26 Feb 2021)
# weights from https://github.com/openai/CLIP


_CLIP_MODELS = {
"RN50": "resnet50",
"RN101": "resnet101",
"RN50x4": "resrnet50x4",
"RN50x16": "resrnet50x16",
"RN50x64": "resrnet50x64",
"ViT_B_32": "vitb32",
"ViT_B_16": "vitb16",
"ViT_L_14": "vitl14",
"ViT_L_14_336px": "vitl14@336px",
}


class _CLIPWrapper(nn.Module):
def __init__(self, clip_model: nn.Module):
super().__init__()

self.clip_model = clip_model

def forward(self, x):
return self.clip_model.encode_image(x)


def _load_clip(model_name: str, **kwargs):
backbone, _ = torch.hub.load("openai/CLIP:main", model_name)
return _CLIPWrapper(backbone), backbone.visual.output_dim


def register_clip_backbones(register: FlashRegistry):
for clip_model_name, flash_model_name in _CLIP_MODELS.items():
register(catch_url_error(partial(_load_clip, clip_model_name)), f"clip_{flash_model_name}", providers=_CLIP)
2 changes: 1 addition & 1 deletion flash/image/segmentation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,4 +210,4 @@ def serve(
@staticmethod
def _ci_benchmark_fn(history: List[Dict[str, Any]]):
"""This function is used only for debugging usage with CI."""
assert history[-1]["val_jaccardindex"] > 0.2
assert history[-1]["val_jaccardindex"] > 0.1
165 changes: 165 additions & 0 deletions flash/text/classification/adapters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
from dataclasses import dataclass
from types import FunctionType
from typing import Any, Callable, Dict

import torch
from torch import Tensor

from flash.core.adapter import Adapter, AdapterTask
from flash.core.data.io.input import DataKeys
from flash.core.model import Task
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _TRANSFORMERS_AVAILABLE
from flash.image.classification.heads import IMAGE_CLASSIFIER_HEADS
from flash.text.classification.collate import TextClassificationCollate

if _TRANSFORMERS_AVAILABLE:
from transformers.modeling_outputs import Seq2SeqSequenceClassifierOutput, SequenceClassifierOutput


class HuggingFaceAdapter(Adapter):
def __init__(self, backbone, num_classes: int, max_length: int = 128):
super().__init__()

os.environ["TOKENIZERS_PARALLELISM"] = "TRUE"
# disable HF thousand warnings
warnings.simplefilter("ignore")
# set os environ variable for multiprocesses
os.environ["PYTHONWARNINGS"] = "ignore"

self.model, tokenizer = backbone(num_classes)
self.collate_fn = TextClassificationCollate(tokenizer, max_length=max_length)

@classmethod
def from_task(
cls,
task: AdapterTask,
backbone: str,
num_classes: int,
**kwargs,
) -> Adapter:
adapter = cls(backbone, num_classes, **kwargs)
adapter.__dict__["_task"] = task
return adapter

@property
def backbone(self):
return self.model.base_model

def forward(self, batch: Dict[str, Tensor]):
result = self.model(input_ids=batch.get("input_ids", None), attention_mask=batch.get("attention_mask", None))
if isinstance(result, (SequenceClassifierOutput, Seq2SeqSequenceClassifierOutput)):
result = result.logits
return result

def training_step(self, batch: Any, batch_idx: int) -> Any:
target = batch.pop(DataKeys.TARGET)
batch = (batch, target)
return Task.training_step(self._task, batch, batch_idx)

def validation_step(self, batch: Any, batch_idx: int) -> None:
target = batch.pop(DataKeys.TARGET)
batch = (batch, target)
return Task.validation_step(self._task, batch, batch_idx)

def test_step(self, batch: Any, batch_idx: int) -> None:
target = batch.pop(DataKeys.TARGET)
batch = (batch, target)
return Task.test_step(self._task, batch, batch_idx)

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
return self(batch)


@dataclass
class GenericCollate:

tokenizer: Callable[[str], Any]

@staticmethod
def to_tensor(sample: Dict[str, Any]) -> Dict[str, Any]:
tensor_sample = {}
for key in sample:
if key is DataKeys.METADATA:
tensor_sample[key] = sample[key]
else:
tensor_sample[key] = torch.tensor(sample[key])
return tensor_sample

def tokenize(self, sample):
sample[DataKeys.INPUT] = self.tokenizer(sample[DataKeys.INPUT])
return sample

def __call__(self, samples):
return self.to_tensor(self.tokenize({key: [sample[key] for sample in samples] for key in samples[0].keys()}))


class GenericAdapter(Adapter):

# TODO: Move IMAGE_CLASSIFIIER_HEADS out for general classification tasks
heads: FlashRegistry = IMAGE_CLASSIFIER_HEADS

def __init__(self, backbone, num_classes: int, max_length: int = 128, head="linear"):
super().__init__()

self.backbone, tokenizer, num_features = backbone()

self.collate_fn = GenericCollate(tokenizer)

if isinstance(head, str):
head = self.heads.get(head)(num_features=num_features, num_classes=num_classes)
else:
head = head(num_features, num_classes) if isinstance(head, FunctionType) else head

self.head = head

@classmethod
def from_task(
cls,
task: AdapterTask,
backbone: str,
num_classes: int,
**kwargs,
) -> Adapter:
adapter = cls(backbone, num_classes, **kwargs)
adapter.__dict__["_task"] = task
return adapter

def training_step(self, batch: Any, batch_idx: int) -> Any:
batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET])
return Task.training_step(self._task, batch, batch_idx)

def validation_step(self, batch: Any, batch_idx: int) -> Any:
batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET])
return Task.validation_step(self._task, batch, batch_idx)

def test_step(self, batch: Any, batch_idx: int) -> Any:
batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET])
return Task.test_step(self._task, batch, batch_idx)

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
batch[DataKeys.PREDS] = Task.predict_step(
self._task, batch[DataKeys.INPUT], batch_idx, dataloader_idx=dataloader_idx
)
return batch

def forward(self, x) -> Tensor:
x = self.backbone(x)
if x.dim() == 4:
x = x.mean(-1).mean(-1)
return self.head(x)
5 changes: 5 additions & 0 deletions flash/text/classification/backbones/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from flash.core.registry import FlashRegistry
from flash.text.classification.backbones.clip import CLIP_BACKBONES
from flash.text.classification.backbones.huggingface import HUGGINGFACE_BACKBONES

TEXT_CLASSIFIER_BACKBONES = FlashRegistry("backbones") + CLIP_BACKBONES + HUGGINGFACE_BACKBONES
Loading

0 comments on commit eb11c35

Please sign in to comment.