
refactor TransformerBlock for serialization #306

Merged
merged 2 commits into main on Nov 3, 2021

Conversation

sararb
Contributor

@sararb sararb commented Nov 2, 2021

Goals ⚽

  • Improve the signature of the HF transformers call to fix a shape-mismatch error when saving the model. The error was caused by unused arguments being initialized with the wrong tensors: the previous call method provided transformer_kwargs twice, so e.g. mask_perm was set equal to input_embed instead of None.

  • Extract the serializable TF*MainLayer from the HF model. All HF models define their architecture graph inside a single custom Keras layer (e.g. TFXLNetModel wraps TFXLNetMainLayer).
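The argument-binding bug in the first goal can be illustrated with a minimal sketch. The function below is hypothetical, with parameter names modeled loosely on the HF call signature described above; it only shows how positional arguments let a tensor silently land in an unused slot such as mask_perm.

```python
# Hypothetical signature modeled on an HF transformer call; not the real API.
def transformer_call(input_ids=None, attention_mask=None, mask_perm=None,
                     input_embed=None):
    # mask_perm should stay None when the caller never sets it.
    return {"mask_perm": mask_perm, "input_embed": input_embed}

embeds = [[0.1, 0.2]]  # stand-in for the embedding tensor

buggy = transformer_call(None, None, embeds)   # embeds lands in mask_perm
fixed = transformer_call(input_embed=embeds)   # explicit kwargs keep slots right

assert buggy["mask_perm"] == embeds  # the source of the shape mismatch
assert fixed["mask_perm"] is None    # the fixed call leaves it unset
```

Passing every transformer argument by keyword, as the refactor does, makes this class of mistake impossible.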

Testing Details 🔍

  • Update test_transformer.py to include the Albert config.
  • Add a unit test that checks the serialization of the TransformerBlock.
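The serialization check described above can be sketched with a toy Keras layer standing in for an extracted TF*MainLayer. The class name and config key here are illustrative, not the project's actual code; the point is the get_config/round-trip pattern that makes a custom layer serializable.

```python
import tensorflow as tf

class ToyMainLayer(tf.keras.layers.Layer):
    """Stand-in for an extracted TF*MainLayer; not the project's real class."""

    def __init__(self, units=8, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.dense = tf.keras.layers.Dense(units)

    def call(self, inputs):
        return self.dense(inputs)

    def get_config(self):
        # Exposing every constructor argument is what allows
        # tf.keras.layers.serialize/deserialize to rebuild the layer.
        config = super().get_config()
        config.update({"units": self.units})
        return config

# Round-trip: serialize the layer config, then rebuild an identical layer.
layer = ToyMainLayer(units=4)
config = tf.keras.layers.serialize(layer)
restored = tf.keras.layers.deserialize(
    config, custom_objects={"ToyMainLayer": ToyMainLayer}
)
assert restored.units == layer.units
```

A unit test along these lines catches exactly the failure this PR fixes: if call binds a tensor into an unused argument, the rebuilt layer's weights no longer match and the round-trip fails.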

@sararb sararb merged commit 72a1953 into main Nov 3, 2021