Skip to content

Commit

Permalink
Cleanup tensorflow dependencies (#530)
Browse files Browse the repository at this point in the history
* Remove tensorflow from requiremnents extras

* Remove import of tensorflow from `transformers4rec/config/transformer`

* Move tensorflow-metadata from dev to base requirements.

This is used in the `merlin_standard_lib` package which is used by `transformers4rec`

* Remove unused tensorflow-estimator and tensorflow-ranking from dev requirements

* Remove test of `to_huggingface_tf_model`

Co-authored-by: Sara Rabhi <sara.rabhi@gmail.com>
  • Loading branch information
oliverholworthy and sararb authored Nov 21, 2022
1 parent a8ae652 commit 1c5736f
Show file tree
Hide file tree
Showing 6 changed files with 2 additions and 62 deletions.
1 change: 1 addition & 0 deletions requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ transformers<4.19
tqdm>=4.27
betterproto<2.0.0
pyarrow>=1.0
tensorflow-metadata
3 changes: 0 additions & 3 deletions requirements/dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,5 @@ flake8
isort
bandit
mypy==0.971
tensorflow-metadata
tensorflow-estimator==2.6.*
tensorflow-ranking>=0.4
codespell
click<8.1.0
1 change: 0 additions & 1 deletion requirements/tensorflow.txt

This file was deleted.

1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def read_requirements(filename):

requirements = {
"base": read_requirements("requirements/base.txt"),
"tensorflow": read_requirements("requirements/tensorflow.txt"),
"pytorch": read_requirements("requirements/pytorch.txt"),
"nvtabular": read_requirements("requirements/nvtabular.txt"),
"docs": read_requirements("requirements/docs.txt"),
Expand Down
11 changes: 1 addition & 10 deletions tests/config/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#

import pytest
from transformers import PreTrainedModel, TFPreTrainedModel
from transformers import PreTrainedModel

from transformers4rec.config import transformer as tconf

Expand All @@ -35,12 +35,3 @@ def test_to_hugginface_torch_model(config_cls):
model = config.to_huggingface_torch_model()

assert isinstance(model, PreTrainedModel)


@pytest.mark.parametrize("config_cls", list(set(config_classes) - {tconf.ReformerConfig}))
def test_to_hugginface_tf_model(config_cls):
config = config_cls.build(100, 4, 2, 20)

model = config.to_huggingface_tf_model()

assert isinstance(model, TFPreTrainedModel)
47 changes: 0 additions & 47 deletions transformers4rec/config/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,6 @@

from merlin_standard_lib import Registry

try:
import tensorflow as tf
except ImportError:
tf = None

transformer_registry: Registry = Registry("transformers")


Expand Down Expand Up @@ -63,48 +58,6 @@ def to_torch_model(
loss_reduction=loss_reduction,
).to_model(**kwargs)

if tf:

def to_tf_model(
self,
input_features,
*prediction_task,
task_blocks=None,
task_weights=None,
loss_reduction=tf.reduce_mean,
**kwargs
):
from .. import tf as tf4rec

if not isinstance(input_features, tf4rec.TabularSequenceFeatures):
raise ValueError("`input_features` must an instance of SequentialTabularFeatures")
if not all(isinstance(t, tf4rec.PredictionTask) for t in prediction_task):
raise ValueError(
"`task` is of the wrong type, please provide one or multiple "
"instance(s) of PredictionTask"
)

body = tf4rec.SequentialBlock(
[input_features, tf4rec.TransformerBlock(self, masking=input_features.masking)]
)

return tf4rec.Model(
tf4rec.Head(
body,
*prediction_task,
task_blocks=task_blocks,
task_weights=task_weights,
loss_reduction=loss_reduction,
inputs=input_features,
),
**kwargs,
)

def to_huggingface_tf_model(self):
model_cls = transformers.TF_MODEL_MAPPING[self.transformers_config_cls]

return model_cls(self)

@property
def transformers_config_cls(self):
return self.__class__.__bases__[1]
Expand Down

0 comments on commit 1c5736f

Please sign in to comment.