Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add text embedder #996

Merged
Merged
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
98ed529
Sentence Embedder API using sentence transformers
abhijithneilabraham Nov 24, 2021
83fbf1e
remove train, test and pred step
abhijithneilabraham Nov 24, 2021
4ad3abb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 24, 2021
8cd9395
sentence embedders with forward step and predict step
abhijithneilabraham Dec 5, 2021
47bbce7
Merge branch 'master' into ST_embeddings
abhijithneilabraham Dec 5, 2021
34f39d1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 5, 2021
219042a
Update __init__.py
abhijithneilabraham Dec 5, 2021
9f9a965
Merge branch 'ST_embeddings' of https://github.com/abhijithneilabraha…
abhijithneilabraham Dec 5, 2021
4b3c772
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 5, 2021
5fecb22
Merge branch 'master' into ST_embeddings
ethanwharris Dec 8, 2021
06e35bd
Updates
ethanwharris Dec 8, 2021
2071c9f
Create test_model.py
abhijithneilabraham Dec 8, 2021
5477415
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 8, 2021
8db110d
__init__ for embedding
abhijithneilabraham Dec 8, 2021
2c09c71
Merge branch 'master' into ST_embeddings
abhijithneilabraham Dec 8, 2021
a6bfc9f
remove download_data()
abhijithneilabraham Dec 8, 2021
602d9e1
Merge branch 'ST_embeddings' of https://github.com/abhijithneilabraha…
abhijithneilabraham Dec 8, 2021
b29b6e0
Merge branch 'master' into ST_embeddings
abhijithneilabraham Dec 9, 2021
e771883
Merge branch 'master' into ST_embeddings
abhijithneilabraham Dec 9, 2021
21305d6
lower size model for text embededer examples and test
abhijithneilabraham Dec 9, 2021
5d1b4c6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 9, 2021
9570522
text embedder example entry to CI
abhijithneilabraham Dec 9, 2021
8cfd877
Merge branch 'master' into ST_embeddings
abhijithneilabraham Dec 9, 2021
bb98d77
change `SentenceEmbedder` to `TextEmbedder`
abhijithneilabraham Dec 9, 2021
923e6ec
remove `download_data` import
abhijithneilabraham Dec 9, 2021
8c90286
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 9, 2021
20233f2
fix bug - test_model.py
abhijithneilabraham Dec 9, 2021
33f30e6
Merge branch 'ST_embeddings' of https://github.com/abhijithneilabraha…
abhijithneilabraham Dec 9, 2021
57aa577
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 9, 2021
fdbb2de
Update test_model.py
abhijithneilabraham Dec 9, 2021
6f5b9e8
Merge branch 'ST_embeddings' of https://github.com/abhijithneilabraha…
abhijithneilabraham Dec 9, 2021
14a5e27
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 9, 2021
3d14659
Update CHANGELOG.md
abhijithneilabraham Dec 9, 2021
f69f207
Merge branch 'master' into ST_embeddings
tchaton Dec 9, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [Unreleased] - YYYY-DD-MM

### Added
- Added `TextEmbedder` task ([#996](https://github.com/PyTorchLightning/lightning-flash/pull/996))

- Added predict_kwargs in `ObjectDetector`, `InstanceSegmentation`, `KeypointDetector` ([#990](https://github.com/PyTorchLightning/lightning-flash/pull/990))

Expand Down
2 changes: 2 additions & 0 deletions flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def _compare_version(package: str, op, version) -> bool:
_ALBUMENTATIONS_AVAILABLE = _module_available("albumentations")
_BAAL_AVAILABLE = _module_available("baal")
_TORCH_OPTIMIZER_AVAILABLE = _module_available("torch_optimizer")
_SENTENCE_TRANSFORMERS_AVAILABLE = _module_available("sentence_transformers")


if _PIL_AVAILABLE:
Expand All @@ -130,6 +131,7 @@ class Image:
_SENTENCEPIECE_AVAILABLE,
_DATASETS_AVAILABLE,
_TM_TEXT_AVAILABLE,
_SENTENCE_TRANSFORMERS_AVAILABLE,
]
)
_TABULAR_AVAILABLE = _TABNET_AVAILABLE and _PANDAS_AVAILABLE and _FORECASTING_AVAILABLE
Expand Down
1 change: 1 addition & 0 deletions flash/core/utilities/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __str__(self):
_LEARN2LEARN = Provider("learnables/learn2learn", "https://github.com/learnables/learn2learn")
_PYSTICHE = Provider("pystiche/pystiche", "https://github.com/pystiche/pystiche")
_HUGGINGFACE = Provider("Hugging Face/transformers", "https://github.com/huggingface/transformers")
_SENTENCE_TRANSFORMERS = Provider("UKPLab/sentence-transformers", "https://github.com/UKPLab/sentence-transformers")
_FAIRSEQ = Provider("PyTorch/fairseq", "https://github.com/pytorch/fairseq")
_OPEN3D_ML = Provider("Intelligent Systems Lab Org/Open3D-ML", "https://github.com/isl-org/Open3D-ML")
_PYTORCHVIDEO = Provider("Facebook Research/PyTorchVideo", "https://github.com/facebookresearch/pytorchvideo")
Expand Down
1 change: 1 addition & 0 deletions flash/text/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from flash.text.classification import TextClassificationData, TextClassifier # noqa: F401
from flash.text.embedding import TextEmbedder # noqa: F401
from flash.text.question_answering import QuestionAnsweringData, QuestionAnsweringTask # noqa: F401
from flash.text.seq2seq import ( # noqa: F401
Seq2SeqData,
Expand Down
1 change: 1 addition & 0 deletions flash/text/embedding/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from flash.text.embedding.model import TextEmbedder # noqa: F401
14 changes: 14 additions & 0 deletions flash/text/embedding/backbones.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from flash.core.registry import ExternalRegistry, FlashRegistry
from flash.core.utilities.imports import _TEXT_AVAILABLE
from flash.core.utilities.providers import _HUGGINGFACE

if not _TEXT_AVAILABLE:
    # Text extras are missing: expose a plain registry under the same name so that
    # importing this module (and referencing ``HUGGINGFACE_BACKBONES``) still works.
    HUGGINGFACE_BACKBONES = FlashRegistry("backbones")
else:
    from transformers import AutoModel

    # Registry backed by ``AutoModel.from_pretrained`` and attributed to the
    # Hugging Face provider.
    HUGGINGFACE_BACKBONES = ExternalRegistry(AutoModel.from_pretrained, "backbones", _HUGGINGFACE)
106 changes: 106 additions & 0 deletions flash/text/embedding/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import warnings
from typing import Any, Dict, List, Optional

import torch
from pytorch_lightning import Callback

from flash.core.integrations.transformers.states import TransformersBackboneState
from flash.core.model import Task
from flash.core.registry import FlashRegistry, print_provider_info
from flash.core.utilities.imports import _TEXT_AVAILABLE
from flash.core.utilities.providers import _SENTENCE_TRANSFORMERS
from flash.text.embedding.backbones import HUGGINGFACE_BACKBONES
from flash.text.ort_callback import ORTCallback

# ``sentence_transformers`` is only imported when the text extras are available.
# NOTE(review): indentation was lost in this view; the ``Pooling`` re-binding is
# assumed to live inside the ``if`` (it would otherwise fail when the import is skipped).
if _TEXT_AVAILABLE:
    from sentence_transformers.models import Pooling

    # Re-bind ``Pooling`` to the object returned by ``print_provider_info`` for the
    # sentence-transformers provider.
    Pooling = print_provider_info("Pooling", _SENTENCE_TRANSFORMERS, Pooling)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class TextEmbedder(Task):
    """The ``TextEmbedder`` is a :class:`~flash.Task` for generating sentence embeddings.
    For more details, see `embeddings`.

    Only prediction is supported: the training, validation and test steps all raise
    ``NotImplementedError``.

    You can change the backbone to any sentence embedding model from `UKPLab/sentence-transformers
    <https://github.com/UKPLab/sentence-transformers>`_ using the ``backbone``
    argument.

    Args:
        backbone: backbone model to use for the task.
        tokenizer_backbone: backbone used for the tokenizer state. Defaults to ``backbone`` when ``None``.
        tokenizer_kwargs: optional keyword arguments stored alongside ``tokenizer_backbone`` in the
            :class:`TransformersBackboneState`.
        enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training
    """

    required_extras: str = "text"

    backbones: FlashRegistry = HUGGINGFACE_BACKBONES

    def __init__(
        self,
        backbone: str = "sentence-transformers/all-MiniLM-L6-v2",
        tokenizer_backbone: Optional[str] = None,
        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
        enable_ort: bool = False,
    ):
        os.environ["TOKENIZERS_PARALLELISM"] = "TRUE"
        # disable HF thousand warnings
        warnings.simplefilter("ignore")
        # set os environ variable for multiprocesses
        os.environ["PYTHONWARNINGS"] = "ignore"
        super().__init__()

        # The tokenizer follows the model backbone unless one is given explicitly.
        if tokenizer_backbone is None:
            tokenizer_backbone = backbone
        self.set_state(TransformersBackboneState(tokenizer_backbone, tokenizer_kwargs=tokenizer_kwargs))
        self.model = self.backbones.get(backbone)()
        # Pooling layer sized to the backbone's hidden dimension.
        self.pooling = Pooling(self.model.config.hidden_size)
        self.enable_ort = enable_ort

    def training_step(self, batch: Any, batch_idx: int) -> Any:
        """Always raises: training is not supported by this task."""
        raise NotImplementedError("Training a `TextEmbedder` is not supported. Use a different text task instead.")

    def validation_step(self, batch: Any, batch_idx: int) -> Any:
        """Always raises: validation is not supported by this task."""
        raise NotImplementedError("Validating a `TextEmbedder` is not supported. Use a different text task instead.")

    def test_step(self, batch: Any, batch_idx: int) -> Any:
        """Always raises: testing is not supported by this task."""
        raise NotImplementedError("Testing a `TextEmbedder` is not supported. Use a different text task instead.")

    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Compute one pooled embedding per input sequence.

        Adapted from sentence-transformers:

        https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/models/Transformer.py#L45

        Args:
            batch: tokenized inputs; must contain ``"input_ids"`` and ``"attention_mask"`` and may
                contain ``"token_type_ids"``.

        Returns:
            The ``"sentence_embedding"`` entry produced by the pooling layer.
        """
        trans_features = {"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]}
        if "token_type_ids" in batch:
            trans_features["token_type_ids"] = batch["token_type_ids"]

        # ``return_dict=False`` yields a tuple whose first element holds the token embeddings.
        output_states = self.model(**trans_features, return_dict=False)
        output_tokens = output_states[0]

        # Work on a shallow copy so the caller's batch is not modified by this method
        # or by the pooling layer it is passed to.
        features = dict(batch)
        features["token_embeddings"] = output_tokens

        return self.pooling(features)["sentence_embedding"]

    def configure_callbacks(self) -> List[Callback]:
        """Return the parent's callbacks, plus an ``ORTCallback`` when ``enable_ort`` is set."""
        callbacks = super().configure_callbacks() or []
        if self.enable_ort:
            callbacks.append(ORTCallback())
        return callbacks
34 changes: 34 additions & 0 deletions flash_examples/text_embedder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

import flash
from flash.text import TextClassificationData, TextEmbedder

# 1. Create the DataModule from a few raw sentences (prediction data only)
sentences = [
    "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
    "The worst movie in the history of cinema.",
    "I come from Bulgaria where it 's almost impossible to have a tornado.",
]
datamodule = TextClassificationData.from_lists(predict_data=sentences)

# 2. Build a TextEmbedder with the chosen backbone
model = TextEmbedder(backbone="sentence-transformers/all-MiniLM-L6-v2")

# 3. Generate embeddings for the 3 sentences, using every visible GPU (0 means CPU)
num_gpus = torch.cuda.device_count()
trainer = flash.Trainer(gpus=num_gpus)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)
1 change: 1 addition & 0 deletions requirements/datatype_text.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ filelock
transformers>=4.5
torchmetrics[text]>=0.5.1
datasets>=1.8,<1.13
sentence-transformers
4 changes: 4 additions & 0 deletions tests/examples/test_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@
"text_classification.py",
marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed"),
),
pytest.param(
"text_embedder.py",
marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed"),
),
# pytest.param(
# "text_classification_multi_label.py",
# marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed")
Expand Down
Empty file.
43 changes: 43 additions & 0 deletions tests/text/embedding/test_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import pytest
import torch

import flash
from flash.text import TextClassificationData, TextEmbedder
from tests.helpers.utils import _TEXT_TESTING

# ======== Mock data ========

# Three raw sentences used as prediction inputs by the test in this module.
predict_data = [
    "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
    "The worst movie in the history of cinema.",
    "I come from Bulgaria where it 's almost impossible to have a tornado.",
]
# ==============================

# Backbone under test; the prediction test expects 384-dimensional embeddings from it.
TEST_BACKBONE = "sentence-transformers/all-MiniLM-L6-v2"  # super small model for testing


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_predict(tmpdir):
    """Predicting with a ``TextEmbedder`` yields one 384-dimensional vector per input sentence."""
    data = TextClassificationData.from_lists(predict_data=predict_data)
    embedder = TextEmbedder(backbone=TEST_BACKBONE)

    # Use every visible GPU; ``device_count()`` is 0 on CPU-only machines.
    runner = flash.Trainer(gpus=torch.cuda.device_count())
    batches = runner.predict(embedder, datamodule=data)

    # Only the first predicted batch is inspected: three sentences, three vectors.
    first_batch_sizes = [embedding.size() for embedding in batches[0]]
    expected_sizes = [torch.Size([384]), torch.Size([384]), torch.Size([384])]
    assert first_batch_sizes == expected_sizes