Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fix backbone freezing for question answering and speech recognition (#…
Browse files Browse the repository at this point in the history
…1275)

Co-authored-by: Kushashwa Ravi Shrimali <kushashwaravishrimali@gmail.com>
  • Loading branch information
ethanwharris and krshrimali committed Apr 13, 2022
1 parent 48a1556 commit 6fd639b
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 96 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed

- Fixed a bug where some backbones were incorrectly listed as available for the `ObjectDetector`, `InstanceSegmentation`, and `KeypointDetector` ([#1267](https://github.com/PyTorchLightning/lightning-flash/pull/1267))
- Fixed a bug where the backbone would not be frozen when finetuning the `SpeechRecognition` task ([#1275](https://github.com/PyTorchLightning/lightning-flash/pull/1275))
- Fixed a bug where the backbone would not be frozen when finetuning the `QuestionAnswering` task with certain model types ([#1275](https://github.com/PyTorchLightning/lightning-flash/pull/1275))

## [0.7.2] - 2022-03-30

Expand Down
3 changes: 3 additions & 0 deletions flash/audio/speech_recognition/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ def __init__(
else AutoProcessor.from_pretrained(processor_backbone)
)

def modules_to_freeze(self) -> Optional[nn.Module]:
return self.model.base_model

def forward(self, batch: Dict[str, torch.Tensor]):
return self.model(batch["input_values"])

Expand Down
93 changes: 0 additions & 93 deletions flash/text/question_answering/finetuning.py

This file was deleted.

3 changes: 1 addition & 2 deletions flash/text/question_answering/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
from flash.core.utilities.types import LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE
from flash.text.ort_callback import ORTCallback
from flash.text.question_answering.collate import TextQuestionAnsweringCollate
from flash.text.question_answering.finetuning import _get_question_answering_bacbones_for_freezing
from flash.text.question_answering.output_transform import QuestionAnsweringOutputTransform

if _TEXT_AVAILABLE:
Expand Down Expand Up @@ -322,7 +321,7 @@ def _initialize_model_specific_parameters(self):

def modules_to_freeze(self) -> Union[Module, Iterable[Union[Module, Iterable]]]:
"""Return the module attributes of the model to be frozen."""
return _get_question_answering_bacbones_for_freezing(self.model)
return self.model.base_model

def configure_callbacks(self) -> List[Callback]:
callbacks = super().configure_callbacks() or []
Expand Down
2 changes: 1 addition & 1 deletion flash_examples/speech_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="no_freeze")
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Predict on audio files!
datamodule = SpeechRecognitionData.from_files(predict_files=["data/timit/example.wav"], batch_size=4)
Expand Down
6 changes: 6 additions & 0 deletions tests/audio/speech_recognition/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ def __len__(self) -> int:
TEST_BACKBONE = "patrickvonplaten/wav2vec2_tiny_random_robust" # tiny model for testing


@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
def test_modules_to_freeze():
model = SpeechRecognition(backbone=TEST_BACKBONE)
assert model.modules_to_freeze() is model.model.wav2vec2


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
def test_init_train(tmpdir):
Expand Down
6 changes: 6 additions & 0 deletions tests/text/question_answering/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ def __len__(self) -> int:
TEST_BACKBONE = "distilbert-base-uncased"


@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_modules_to_freeze():
model = QuestionAnsweringTask(backbone=TEST_BACKBONE)
assert model.modules_to_freeze() is model.model.distilbert


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_init_train(tmpdir):
Expand Down

0 comments on commit 6fd639b

Please sign in to comment.