refactor TransformerBlock for serialization (#306)
sararb authored Nov 3, 2021
1 parent 3d638c2 commit 72a1953
Showing 3 changed files with 77 additions and 23 deletions.
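In short, the refactor stores the HF model's inner `TF*MainLayer` (a plain Keras layer) inside `TransformerBlock` instead of the full `TFPreTrainedModel`, and adds `get_config`/`from_config`, so the block survives a Keras serialization round trip. A minimal sketch of what this enables, assuming the transformers4rec TF API at this commit:

import transformers4rec.tf as tr
from transformers4rec.config import transformer as tconf

# Build a small transformer config, as in the tests below.
transformer_config = tconf.XLNetConfig.build(
    d_model=64, n_head=4, n_layer=2, total_seq_length=20
)
block = tr.TransformerBlock(transformer_config)

# get_config/from_config (added in this commit) round-trip the block
# through a plain Python dict, which is what Keras saving relies on.
config = block.get_config()
copy = tr.TransformerBlock.from_config(config)
assert isinstance(copy, tr.TransformerBlock)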
31 changes: 25 additions & 6 deletions tests/tf/block/test_transformer.py
@@ -19,11 +19,14 @@
 from transformers4rec.config import transformer as tconf

 tr = pytest.importorskip("transformers4rec.tf")
+test_utils = pytest.importorskip("transformers4rec.tf.utils.testing_utils")

 config_classes = [
     tconf.XLNetConfig,
     tconf.LongformerConfig,
     tconf.GPT2Config,
     tconf.BertConfig,
+    tconf.RobertaConfig,
+    tconf.AlbertConfig,
 ]

@@ -34,19 +37,17 @@
 # Test output of XLNet with different masking tasks using SequentialBlock
 @pytest.mark.parametrize("task", lm_tasks[2:])
-def test_transformer_block(yoochoose_schema, tf_yoochoose_like, task):
+@pytest.mark.parametrize("config", config_classes)
+def test_transformer_block(yoochoose_schema, tf_yoochoose_like, task, config):

-    col_group = yoochoose_schema
     tab_module = tr.TabularSequenceFeatures.from_schema(
-        col_group,
+        yoochoose_schema,
         max_sequence_length=20,
         aggregation="concat",
         masking=task,
     )

-    transformer_config = tconf.XLNetConfig.build(
-        d_model=64, n_head=4, n_layer=2, total_seq_length=20
-    )
+    transformer_config = config.build(d_model=64, n_head=4, n_layer=2, total_seq_length=20)

     block = tr.SequentialBlock(
         [
@@ -62,6 +63,24 @@ def test_transformer_block(yoochoose_schema, tf_yoochoose_like, task):
     assert outputs.shape[-1] == 64


+@pytest.mark.parametrize("config", config_classes)
+def test_serialization_transformer_block(
+    tf_yoochoose_tabular_sequence_features, tf_yoochoose_like, config
+):
+
+    transformer_config = config.build(d_model=64, n_head=4, n_layer=2, total_seq_length=20)
+    transformer_block = tr.TransformerBlock(transformer_config)
+    copy_transformer = test_utils.assert_serialization(transformer_block)
+
+    body = tr.SequentialBlock(
+        [tf_yoochoose_tabular_sequence_features, tr.MLPBlock([64]), copy_transformer]
+    )
+
+    outputs = body(tf_yoochoose_like)
+    assert outputs.ndim == 3
+    assert outputs.shape[-1] == 64
+
+
 # Test output of XLNet with permutation language model using SequentialBlock
 # def test_xlnet_with_plm(yoochoose_schema, torch_yoochoose_like):
 #
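For reference, `assert_serialization` is a small testing helper whose implementation is not part of this diff; judging by how the test uses its return value, it presumably does something like the following sketch (the exact body in transformers4rec.tf.utils.testing_utils may differ):

import tensorflow as tf

def assert_serialization(layer: tf.keras.layers.Layer) -> tf.keras.layers.Layer:
    # Round-trip the layer through its Keras config and hand back the copy,
    # so the caller can verify the copy still behaves like the original.
    copy_layer = layer.__class__.from_config(layer.get_config())
    assert isinstance(copy_layer, type(layer))
    return copy_layer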
36 changes: 29 additions & 7 deletions transformers4rec/tf/block/transformer.py
@@ -23,9 +23,14 @@

 from ...config.transformer import T4RecConfig, transformer_registry
 from ..masking import MaskSequence
+from ..utils.tf_utils import (
+    get_tf_main_layer,
+    maybe_deserialize_keras_objects,
+    maybe_serialize_keras_objects,
+)
 from .base import Block

-TransformerBody = Union[TFPreTrainedModel, PretrainedConfig]
+TransformerBody = Union[TFPreTrainedModel, PretrainedConfig, tf.keras.layers.Layer]


 class TransformerPrepare(tf.keras.layers.Layer):
@@ -46,7 +51,8 @@ class TransformerBlock(Block):
     Parameters
     ----------
     transformer: TransformerBody
-        The T4RecConfig or a pre-trained HF object related to specific transformer architecture.
+        The T4RecConfig, the pre-trained HF model, or the custom Keras layer TF*MainLayer
+        related to a specific transformer architecture.
     masking:
         Needed when masking is applied on the inputs.
     """
@@ -67,10 +73,12 @@ def __init__(

         self.transformer: TFPreTrainedModel
         if isinstance(transformer, T4RecConfig):
-            self.transformer = transformer.to_huggingface_tf_model()
+            self.transformer = get_tf_main_layer(transformer.to_huggingface_tf_model())
         elif isinstance(transformer, PretrainedConfig):
             model_cls = transformers.TF_MODEL_MAPPING[transformer.__class__]
-            self.transformer = model_cls(transformer)
+            self.transformer = get_tf_main_layer(model_cls(transformer))
+        elif isinstance(transformer, TFPreTrainedModel):
+            self.transformer = get_tf_main_layer(transformer)
         else:
             self.transformer = transformer

@@ -95,6 +103,21 @@ def __init__(
             self.prepare_module = prepare_module(transformer, masking)
         self.output_fn = output_fn

+    def get_config(self):
+        config = super().get_config()
+        config = maybe_serialize_keras_objects(
+            self, config, ["transformer", "prepare_module", "masking"]
+        )
+        return config
+
+    @classmethod
+    def from_config(cls, config):
+        config = maybe_deserialize_keras_objects(
+            config, ["transformer", "prepare_module", "masking"]
+        )
+
+        return super().from_config(config)
+
     @classmethod
     def from_registry(
         cls,
@@ -156,9 +179,8 @@ def call(self, inputs_embeds: tf.Tensor, **kwargs):
             if param in transformer_kwargs:
                 filtered_transformer_kwargs[param] = transformer_kwargs[param]

-        # In Keras the first (inputs) arg always needs to be set, therefore we supply the
-        # transformer_kwargs both as arg and **kwargs
-        model_outputs = self.transformer(transformer_kwargs, **filtered_transformer_kwargs)
+        # In HF, the call accepts inputs as a dictionary containing all needed tensors
+        model_outputs = self.transformer(filtered_transformer_kwargs)
         outputs = self.output_fn(model_outputs)

         # TODO: store the attention outputs for meta-data logging
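The constructor now funnels every accepted input type through `get_tf_main_layer`, so `self.transformer` always ends up a plain, serializable Keras layer. A sketch of the three construction paths, using real HF classes but hypothetical hyperparameters:

import tensorflow as tf
from transformers import GPT2Config, TFGPT2Model

import transformers4rec.tf as tr
from transformers4rec.config import transformer as tconf

# 1. T4RecConfig: converted to a HF TF model, then unwrapped to its main layer.
block_a = tr.TransformerBlock(
    tconf.GPT2Config.build(d_model=64, n_head=4, n_layer=2, total_seq_length=20)
)

# 2. HF PretrainedConfig: the matching model class is looked up in
#    transformers.TF_MODEL_MAPPING, instantiated, then unwrapped.
hf_config = GPT2Config(n_embd=64, n_head=4, n_layer=2)
block_b = tr.TransformerBlock(hf_config)

# 3. New in this commit: an already-instantiated TFPreTrainedModel is
#    also unwrapped instead of being stored whole.
block_c = tr.TransformerBlock(TFGPT2Model(hf_config))

for block in (block_a, block_b, block_c):
    assert isinstance(block.transformer, tf.keras.layers.Layer)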
33 changes: 23 additions & 10 deletions transformers4rec/tf/utils/tf_utils.py
@@ -116,14 +116,24 @@ def calculate_batch_size_from_input_shapes(input_shapes):
     return [i for i in input_shapes.values() if not isinstance(i, tuple)][0][0]


+def get_tf_main_layer(hf_model):
+    """
+    Extract the serializable custom Keras layer `TF*MainLayer` from the HF model
+    """
+    main_layer = [v for _, v in hf_model.__dict__.items() if isinstance(v, tf.keras.layers.Layer)][
+        0
+    ]
+    return main_layer
+
+
 def maybe_serialize_keras_objects(
     self,
     config,
     maybe_serialize_keys,
 ):
     for key in maybe_serialize_keys:
         maybe_value = getattr(self, key, None)
-        if maybe_value:
+        if maybe_value is not None:
             if isinstance(maybe_value, dict):
                 config[key] = {
                     k: tf.keras.utils.serialize_keras_object(v) for k, v in maybe_value.items()
@@ -156,27 +166,30 @@ def maybe_deserialize_keras_objects(


 def extract_topk(ks, scores, labels):
-    max_k = int(max(ks))
+    max_k = tf.reduce_max(ks)
     topk_scores, topk_indices = tf.math.top_k(scores, max_k)
     topk_labels = gather_torch_like(labels, topk_indices, max_k)
     return topk_scores, topk_indices, topk_labels


 def tranform_label_to_onehot(labels, vocab_size):
-    return tf.one_hot(tf.reshape(labels, -1), vocab_size)
+    return tf.one_hot(tf.reshape(labels, (-1,)), vocab_size)


 def create_output_placeholder(scores, ks):
-    return tf.Variable(tf.zeros([scores.shape[0], len(ks)], tf.float32))
+    return tf.Variable(tf.zeros([tf.shape(scores)[0], len(ks)], tf.float32))


 def gather_torch_like(labels, indices, max_k):
-    gather_indices = []
-    for i in range(indices.shape[0]):
-        gather_indices.append(
+    # gather_indices = []
+    gather_indices = tf.TensorArray(tf.int32, size=tf.shape(indices)[0])
+    for i in range(tf.shape(indices)[0]):
+        gather_indices = gather_indices.write(
+            i,
             tf.concat(
                 [i * tf.ones((max_k, 1), tf.int32), tf.expand_dims(indices[i, :], -1)], axis=1
-            )
+            ),
         )
-    all_indices = tf.concat(gather_indices, 0)
-    return tf.reshape(tf.gather_nd(labels, all_indices), indices.shape)
+    all_indices = gather_indices.stack()
+    labels = tf.reshape(tf.gather_nd(labels, all_indices), tf.shape(indices))
+    return labels
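The `gather_torch_like` rewrite swaps the Python list for a `tf.TensorArray` and the static `.shape`/`int(max(ks))` lookups for `tf.shape`/`tf.reduce_max`, which lets the loop trace cleanly into a graph where the batch size and `ks` are dynamic. A toy example of the resulting behavior, with made-up inputs:

import tensorflow as tf

from transformers4rec.tf.utils.tf_utils import extract_topk

# Hypothetical scores and binary labels for a batch of 2, vocabulary of 4.
scores = tf.constant([[0.1, 0.9, 0.3, 0.5],
                      [0.8, 0.2, 0.7, 0.4]])
labels = tf.constant([[0, 1, 0, 0],
                      [1, 0, 1, 0]])

topk_scores, topk_indices, topk_labels = extract_topk([2], scores, labels)
# topk_indices -> [[1, 3], [0, 2]]  (positions of the two highest scores per row)
# topk_labels  -> [[1, 0], [1, 1]]  (labels gathered row-wise at those positions)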
