SASRec in Tensorflow 2.x #1530
Conversation
Check out this pull request on ReviewNB: see visual diffs & provide feedback on Jupyter Notebooks.
Hey Abir, we use staging as our development branch. So we always merge PRs into staging and never to main (except staging -> main).
Since this PR will depend on upgrading TF, let's keep reviewing but leave it open until we can put everything in the upgrade together.
@@ -0,0 +1,59 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
I don't think we need a separate test pipeline for TF2. Once we upgrade the TF version in setup.py, the current pipeline will switch to version 2, so you don't need this file.
+1
Removed this file.
Could you please add docstrings to those functions that should appear in the documentation page on readthedocs? See https://github.com/Microsoft/Recommenders/wiki/Coding-Guidelines#python-and-docstrings-style
Each new notebook needs to have some description in the beginning. See other notebooks in the examples. The description is a summary of the method and what scenario it addresses. There may be text in other places in the notebook too, if needed for clarification.
Moreover, the method needs to be included in the table of methods we have in the README.
recommenders/models/sasrec/model.py
Outdated
- each tuple (q, k, v) are fed to scaled_dot_product_attention
- all attention outputs are concatenated
"""
class MultiHeadAttention(tf.keras.layers.Layer):
Could you please add docstrings?
Added docstrings to all the methods.
recommenders/models/sasrec/util.py
Outdated
def evaluate(model, dataset, maxlen, num_neg_test):
    [train, valid, test, usernum, itemnum] = copy.deepcopy(dataset)

    NDCG = 0.0
    HT = 0.0
    valid_user = 0.0

    if usernum > 10000:
        users = random.sample(range(1, usernum + 1), 10000)
    else:
        users = range(1, usernum + 1)

    for u in tqdm(users, ncols=70, leave=False, unit='b'):

        if len(train[u]) < 1 or len(test[u]) < 1: continue

        # print(train[u])
In other notebooks we are standardizing the way we evaluate, so that we compare apples to apples. Would it be possible to use the evaluation functions that we have in the repo? https://github.com/microsoft/recommenders/blob/main/recommenders/evaluation/python_evaluation.py
The metrics used here, NDCG@10 and Hit@10, are the same as those available in deeprec_utils.py (https://github.com/microsoft/recommenders/blob/main/recommenders/models/deeprec/deeprec_utils.py). Currently I cannot import them since TF 2.x is not supported (from deeprec_utils import ndcg_score, hit_score does not work). Once we migrate to TF 2.x it will be easy to invoke those functions.
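For reference, this is roughly what the two metrics reduce to for a single ranked list with one positive item (just a sketch for discussion, not the deeprec_utils implementation):

import numpy as np

def hit_at_k(rank, k=10):
    # 1 if the positive item is ranked within the top k candidates, else 0
    return 1.0 if rank < k else 0.0

def ndcg_at_k(rank, k=10):
    # with a single relevant item, DCG reduces to 1 / log2(rank + 2)
    return 1.0 / np.log2(rank + 2) if rank < k else 0.0

# example: the positive item is ranked 3rd (0-based rank 2) among the candidates
rank = 2
print(hit_at_k(rank), ndcg_at_k(rank))  # 1.0, 0.5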
@@ -0,0 +1,59 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
+1
@@ -0,0 +1,321 @@
{
Not sure I understand the dataset creation here. In other notebooks we give a short explanation of the dataset; please see: https://github.com/microsoft/recommenders/blob/main/examples/00_quick_start/lstur_MIND.ipynb
Also, there is a step to download the dataset. We have some light wrappers for that, like https://github.com/microsoft/recommenders/blob/main/recommenders/datasets/amazon_reviews.py
I have elaborated on the data format and dataset creation. Let me know if I need to add more.
@@ -0,0 +1,321 @@
{
Line #8. # conv_dims = kwargs.get("conv_dims", [100, 100])
I think this guy should be:
conv_dims=[100,100]
because kwargs is not defined here?
Actually, you may want to add a parameter here at the beginning, like the other parameters.
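Something like this would work (a hypothetical notebook cell; the surrounding values are just illustrative):

# hyperparameters defined at the top of the notebook, alongside the others
maxlen = 50
hidden_units = 100
dropout_rate = 0.1
conv_dims = [100, 100]  # feed-forward layer dimensions, instead of kwargs.get("conv_dims", [100, 100])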
@@ -0,0 +1,321 @@
{
Would it be possible to run the notebook and show the logs? In the rest of the notebooks we follow that pattern.
Added the logs.
This is really good Abir, I added some comments, please take a look
class MultiHeadAttention(tf.keras.layers.Layer):
    """
    - Q (query), K (key) and V (value) are split into multiple heads (num_heads)
    - each tuple (q, k, v) are fed to scaled_dot_product_attention
    - all attention outputs are concatenated
    """

    def __init__(self, attention_dim, num_heads, dropout_rate):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.attention_dim = attention_dim
        assert attention_dim % self.num_heads == 0
        self.dropout_rate = dropout_rate

        self.depth = attention_dim // self.num_heads
Is this code for multi-head attention the same as in the other file? If so, it would be good to refactor so we don't repeat code.
Yes, the code is refactored to use SASRec as the base class.
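A rough sketch of that refactoring pattern (class and argument names here are simplified assumptions, not the exact code in this PR):

import tensorflow as tf

class SASREC(tf.keras.Model):
    # base model: item embedding + self-attention blocks (details omitted)
    def __init__(self, item_num, embedding_dim, **kwargs):
        super().__init__(**kwargs)
        self.item_embedding = tf.keras.layers.Embedding(item_num + 1, embedding_dim)

class SSEPT(SASREC):
    # SSE-PT reuses the SASREC layers and only adds what is specific to it
    def __init__(self, user_num, user_embedding_dim, **kwargs):
        super().__init__(**kwargs)
        self.user_embedding = tf.keras.layers.Embedding(user_num + 1, user_embedding_dim)

model = SSEPT(user_num=100, user_embedding_dim=10, item_num=500, embedding_dim=100)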
class SSEPT(tf.keras.Model):
    """
    SSE-PT Model

    :Citation:

        Wu L., Li S., Hsieh C-J., Sharpnack J., SSE-PT: Sequential Recommendation
        Via Personalized Transformer, RecSys, 2020.
        TF 1.x codebase: https://github.com/SSE-PT/SSE-PT
        TF 2.x codebase (SASRec): https://github.com/nnkkmto/SASRec-tf2

    Args:
        item_num: number of items in the dataset
        seq_max_len: maximum number of items in user history
        num_blocks: number of Transformer blocks to be used
        embedding_dim: item embedding dimension
        attention_dim: Transformer attention dimension
        conv_dims: list of the dimensions of the Feedforward layer
        dropout_rate: dropout rate
        l2_reg: coefficient of the L2 regularization
        num_neg_test: number of negative examples used in testing
    """
Question: how different is SASRec from SSE-PT?
The difference between SASRec and SSE-PT is that SSE-PT also creates a user embedding, whereas SASRec has only item embeddings. The authors have shown that SSE-PT performance is better than that of SASRec on 5 datasets (by 5%).
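Conceptually, the SSE-PT sequence representation concatenates a user embedding to every item embedding, while SASRec uses the item embeddings alone. A toy sketch with made-up dimensions (not the PR's code):

import tensorflow as tf

user_num, item_num = 1000, 500
user_dim, item_dim, seq_max_len = 50, 100, 10

item_emb = tf.keras.layers.Embedding(item_num + 1, item_dim)  # used by both models
user_emb = tf.keras.layers.Embedding(user_num + 1, user_dim)  # SSE-PT only

item_seq = tf.ones((2, seq_max_len), dtype=tf.int32)  # dummy batch of item id sequences
user_ids = tf.constant([1, 2])                        # dummy batch of user ids

# SASRec: the sequence representation is just the item embeddings
sasrec_repr = item_emb(item_seq)                      # (2, seq_max_len, item_dim)

# SSE-PT: the user embedding is tiled along the sequence and concatenated
u = tf.tile(user_emb(user_ids)[:, tf.newaxis, :], [1, seq_max_len, 1])
ssept_repr = tf.concat([item_emb(item_seq), u], axis=-1)  # (2, seq_max_len, item_dim + user_dim)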
def sample_function(
    user_train, usernum, itemnum, batch_size, maxlen, result_queue, SEED
Minor detail: typically we have all the inputs in lowercase, so we would write SEED as seed.
Changed to seed.
setup.py
Outdated
"tqdm>=4.31.1,<5", | ||
"matplotlib>=2.2.2,<4", | ||
"scikit-learn>=0.22.1,<1", | ||
"numba>=0.38.1,<1", | ||
# "numba>=0.38.1,<1", |
Mmm, I think this would break the repo. Abir, does numba conflict with this code?
FYI @anargyri
numba was causing problems due to the LLVM module. I have installed it separately, so there is no more issue with numba.
Yes, numba is needed for the repo.
setup.py
Outdated
"pandas>1.0.3,<2", | ||
"scipy>=1.0.0,<2", | ||
"scipy==1.4.1", |
Also, not sure if pinning scipy can generate problems. Abir, does the code break with scipy>=1.0.0,<2?
Yes, the problem is with the particular tensorflow-gpu (2.3.0) version, which has a very specific requirement.
Yes, I think we need to be careful when modifying setup.py; it can be a breaking change and it affects any PyPI release that will happen following the change.
Moreover, we are going to TF 2.6 (2.3 still suffers from some vulnerabilities). I think the best way is to ensure your code conforms with 2.6 from this PR already. Could you go through the migration process described here https://www.tensorflow.org/guide/migrate (i.e. run the script they provide and make the suggested changes to the syntax)?
You would need to use this setup.py (note the CUDA version too).
setup.py
Outdated
"tensorflow-gpu==2.3.0", # compiled with CUDA 10.0 | ||
"torch==1.9.1", # last os-common version with CUDA 10.0 support | ||
"six~=1.15.0", | ||
"typing-extensions~=3.7.4", |
Same thing here; at this point, if we add TF 2 we will break the repo.
@anargyri is working on the transition, but so far we can't add this change.
@pytest.mark.notebooks
def test_sasrec_single_node_runs(notebooks, output_notebook, kernel_name):
    notebook_path = notebooks["sasrec_quickstart"]
    pm.execute_notebook(notebook_path, output_notebook, kernel_name=kernel_name)
How long does this test take? If it takes too long, we can add it to smoke or integration.
@pytest.mark.notebooks
def test_template_runs(notebooks, output_notebook, kernel_name):
    notebook_path = notebooks["template"]
    pm.execute_notebook(
        notebook_path,
        output_notebook,
        parameters=dict(PM_VERSION=pm.__version__),
        kernel_name=kernel_name,
    )
    nb = sb.read_notebook(output_notebook)
    df = nb.papermill_dataframe
    assert df.shape[0] == 2
    check_version = df.loc[df["name"] == "checked_version", "value"].values[0]
    assert check_version is True
Is this needed here? My understanding was that there is already one example like this in the same folder.
@@ -0,0 +1,139 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
One thing that is missing is the unit tests of the classes themselves. See for example: https://github.com/microsoft/recommenders/blob/main/tests/unit/recommenders/models/test_deeprec_model.py
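As an illustration, such a unit test could look roughly like this (a sketch using the SASREC constructor arguments shown elsewhere in this PR; the gpu marker and the final assertion are assumptions to adapt):

import pytest

@pytest.mark.gpu
def test_sasrec_component_definition():
    from recommenders.models.sasrec.model import SASREC

    # small, fast-to-build configuration just to exercise the constructor
    model = SASREC(
        item_num=10,
        seq_max_len=5,
        num_blocks=1,
        embedding_dim=4,
        attention_dim=4,
        attention_num_heads=1,
        dropout_rate=0.1,
        conv_dims=[4, 4],
        l2_reg=0.0,
        num_neg_test=10,
    )
    # a real test would assert on specific layers/attributes of the model
    assert model is not None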
@aeroabir not sure if you saw this comment
@aeroabir something important: when you are doing the commits, GitHub is not linking your username to the commits. I think it is important for everyone to see on GitHub that the code you developed is attributed to you. It took you a long time to do this and we should really make sure your work is showcased. Once this is set up, you will see that your name automatically appears in the contributor list https://github.com/microsoft/recommenders/graphs/contributors
setup.py
Outdated
"pandas>1.0.3,<2", | ||
"scipy>=1.0.0,<2", | ||
"scipy==1.4.1", |
import os
import pytest
import tensorflow as tf
I think that we can have the TF2 tests in the same file as the TF tests, because we are going to upgrade soon. FYI @anargyri
@@ -0,0 +1,139 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
@aeroabir not sure if you saw this comment
@aeroabir This PR should be ready to revisit now that the TensorFlow upgrade has been checked into staging.
Also, could you add an entry in the table that lists the algos in the front page README?
Could you also put SASRec alphabetically in the table (after SAR)?
Another place to add SASRec is here.
Codecov Report
@@ Coverage Diff @@
## staging #1530 +/- ##
============================================
+ Coverage 0.00% 59.05% +59.05%
============================================
Files 84 88 +4
Lines 8462 8997 +535
============================================
+ Hits 0 5313 +5313
- Misses 0 3684 +3684
Flags with carried forward coverage won't be shown.
…enders into staging_abir_tf2
Hey Abir, I think the code is in good shape. There is one issue to discuss about the splitter.
README.md
Outdated
@@ -118,8 +118,10 @@ The table below lists the recommender algorithms currently available in the repo
| Restricted Boltzmann Machines (RBM) | Collaborative Filtering | Neural network based algorithm for learning the underlying probability distribution for explicit or implicit user/item feedback. It works in the CPU/GPU enviroment. | [Quick start](examples/00_quick_start/rbm_movielens.ipynb) / [Deep dive](examples/02_model_collaborative_filtering/rbm_deep_dive.ipynb) |
| Riemannian Low-rank Matrix Completion (RLRMC)<sup>*</sup> | Collaborative Filtering | Matrix factorization algorithm using Riemannian conjugate gradients optimization with small memory consumption to predice user/item interactions. It works in the CPU enviroment. | [Quick start](examples/00_quick_start/rlrmc_movielens.ipynb) |
| Simple Algorithm for Recommendation (SAR)<sup>*</sup> | Collaborative Filtering | Similarity-based algorithm for implicit user/item feedback. It works in the CPU environment. | [Quick start](examples/00_quick_start/sar_movielens.ipynb) / [Deep dive](examples/02_model_collaborative_filtering/sar_deep_dive.ipynb) |
| Self-Attentive Sequential Recommendation (SASRec) | Sequential | Transformer based algorithm for sequential recommendation. It works in the CPU/GPU environment. | [Deep dive](examples/00_quick_start/sasrec_amazon.ipynb) |
I think here it would be [Quick start](examples/00_quick_start/sasrec_amazon.ipynb)
In the type, right now we only have as options Collaborative Filtering, Content based and Hybrid. We don't have sequential. However, we actually discussed to change the types in the past, but we didn't take any decision.
If you think it would be a good idea to change the types, would it be ok if we leave this as one of the old types, just for consistency, and then start a discussion about other types?
README.md
Outdated
| Short-term and Long-term Preference Integrated Recommender (SLi-Rec)<sup>*</sup> | Collaborative Filtering | Sequential-based algorithm that aims to capture both long and short-term user preferences using attention mechanism, a time-aware controller and a content-aware controller. It works in the CPU/GPU environment. | [Quick start](examples/00_quick_start/sequential_recsys_amazondataset.ipynb) |
| Multi-Interest-Aware Sequential User Modeling (SUM)<sup>*</sup> | Collaborative Filtering | An enhanced memory network-based sequential user model which aims to capture users' multiple interests. It works in the CPU/GPU environment. | [Quick start](examples/00_quick_start/sequential_recsys_amazondataset.ipynb) |
| Sequential Recommendation Via Personalized Transformer (SSEPT) | Sequential | Transformer based algorithm for sequential recommendation with User embedding. It works in the CPU/GPU environment. | [Deep dive](examples/00_quick_start/sasrec_amazon.ipynb) |
Same note here as before with the type and quick start
tests/conftest.py
Outdated
    # TF2.x
    "sasrec_quickstart": os.path.join(
        folder_notebooks, "00_quick_start", "sasrec_amazon.ipynb"
    ),
}
Minor detail: the other notebooks are ordered based on the folder (00_quick_start, 02_model_content_based_filtering, ...). Would you mind following the same structure?
)
],
)
# @pytest.mark.skipif(tf.__versoin__ > "2.0", reason="We are currently on TF 1.5")
# @pytest.mark.skipif(tf.__versoin__ > "2.0", reason="We are currently on TF 1.5")
This can be updated now.
def split(self, **kwargs):
    self.filename = kwargs.get("filename", self.filename)
    if not self.filename:
        raise ValueError("Filename is required")

    if self.with_time:
        self.data_partition_with_time()
    else:
        self.data_partition()

def data_partition(self):
    # assume user/item index starting from 1
    f = open(self.filename, "r")
    for line in f:
        u, i = line.rstrip().split(self.col_sep)
        u = int(u)
        i = int(i)
        self.usernum = max(u, self.usernum)
        self.itemnum = max(i, self.itemnum)
        self.User[u].append(i)
One thing for you to consider, Abir. In the other algos, we always split the data using the common splitters: https://github.com/microsoft/recommenders/blob/main/recommenders/datasets/python_splitters.py, which helps us compare algorithms in an apples-to-apples fashion. If we use the original splitters, we won't be able to compare these algos with the other ones in the repo.
What are your thoughts on this?
# Amazon Electronics Data
itemnum = 85930
maxlen = 50
num_blocks = 2
hidden_units = 100
num_heads = 1
dropout_rate = 0.1
l2_emb = 0.0
num_neg_test = 100

model = SASREC(
    item_num=itemnum,
    seq_max_len=maxlen,
    num_blocks=num_blocks,
    embedding_dim=hidden_units,
    attention_dim=hidden_units,
    attention_num_heads=num_heads,
    dropout_rate=dropout_rate,
    conv_dims=[100, 100],
    l2_reg=l2_emb,
    num_neg_test=num_neg_test,
)
An elegant way of adding input data to tests is fixtures; see how they work here: https://github.com/miguelgfierro/pybase/blob/master/test/pytest_fixtures.py
You could create a dictionary fixture of the Amazon Electronics data and input it in the test, as in the sketch below.
def random_neq(left, right, s):
    t = np.random.randint(left, right)
    while t in s:
        t = np.random.randint(left, right)
    return t


def sample_function(
    user_train, usernum, itemnum, batch_size, maxlen, result_queue, seed
):
    """
    Batch sampler that creates a sequence of negative items based on the
    original sequence of items (positive) that the user has interacted with.
    """
If you have time, it would be great to add the missing docstrings. This is an example of the format we follow:
def check_column_dtypes_wrapper(
    rating_true,
    rating_pred,
    col_user=DEFAULT_USER_COL,
    col_item=DEFAULT_ITEM_COL,
    col_rating=DEFAULT_RATING_COL,
    col_prediction=DEFAULT_PREDICTION_COL,
    *args,
    **kwargs
):
    """Check columns of DataFrame inputs

    Args:
        rating_true (pandas.DataFrame): True data
        rating_pred (pandas.DataFrame): Predicted data
        col_user (str): column name for user
        col_item (str): column name for item
        col_rating (str): column name for rating
        col_prediction (str): column name for prediction
    """
See here for more info: https://github.com/microsoft/recommenders/blob/main/recommenders/evaluation/python_evaluation.py#L51
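Applied to sample_function above, a docstring in that format could look like this (the argument descriptions are my assumptions based on the names, so please adjust):

def sample_function(
    user_train, usernum, itemnum, batch_size, maxlen, result_queue, seed
):
    """Batch sampler that creates a sequence of negative items based on the
    original sequence of items (positive) that the user has interacted with.

    Args:
        user_train (dict): user-to-item-sequence mapping of the training split
        usernum (int): number of users
        itemnum (int): number of items
        batch_size (int): batch size
        maxlen (int): maximum length of a user sequence
        result_queue (multiprocessing.Queue): queue the sampled batches are pushed to
        seed (int): random seed for the sampler
    """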
@@ -0,0 +1,912 @@
{
This is a very nice function; it removes users and items with fewer than k interactions. @anargyri @simonzhaoms @angusrtaylor @aeroabir do you see us using this function in other parts of the repo?
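For reference, a generic pandas version of that kind of filter (a sketch, not the exact function in the notebook or the repo) is short enough to share here:

import pandas as pd

def filter_k_core(df, k, col_user="userID", col_item="itemID"):
    # iteratively drop users and items with fewer than k interactions until stable
    while True:
        user_counts = df[col_user].value_counts()
        item_counts = df[col_item].value_counts()
        df_filtered = df[
            df[col_user].isin(user_counts[user_counts >= k].index)
            & df[col_item].isin(item_counts[item_counts >= k].index)
        ]
        if len(df_filtered) == len(df):
            return df_filtered
        df = df_filtered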
Actually we have a function already here. You may use it instead.
@aeroabir if you want to do this change now, please let me know, otherwise we can create an issue and leave it for another PR
Looks good to me! Really great job Abir!
Description
Added SASRec (Wang-Cheng Kang, Julian McAuley (2018). Self-Attentive Sequential Recommendation. In Proceedings of the IEEE International Conference on Data Mining (ICDM'18)), coded in TF 2.x.
This is to add newer algorithms, especially the ones based on Transformers.
Related Issues
None
Checklist:
Merging into the staging branch and not to the main branch.