Quick fix of compute_loss to be able to use fit method in tensorflow #275
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Goals ⚽
The current implementation of compute_loss does not use the call methods of the
Head
andModel
classes. This raises an issue for NextItemPredictionTask as the build method requiresbody
to retrieve necessary information about items embeddings and masking. This class is only accessed via the build method ofHead
.Implementation Details 🚧
The current PR presents a quick fix by exposing an argument
call_task
in thecompute_loss
ofPredictionTask
classes. The compute_loss methods of Head / Model are first computing the predictions through their call methods (to correctly build the related tasks), then computing the tasks losses withcall_task=False
.Testing Details 🔍
This PR adds a test in test_model.py that defines a model with NextItemPredictionTask head, call the
fit
andevaluate
methods, then check that losses and metrics are correctly computed.