SASRec in Tensorflow 2.x #1530

Merged 23 commits on Jan 18, 2022

aeroabir
Collaborator

Description

Added SASRec (Wang-Cheng Kang and Julian McAuley (2018). Self-Attentive Sequential Recommendation. In Proceedings of the IEEE International Conference on Data Mining (ICDM'18)), implemented in TF 2.x.

This is part of adding newer algorithms to the repo, especially Transformer-based ones.

Related Issues

None

Checklist:

  • [x] I have followed the contribution guidelines and code style for this project.
  • [x] I have added tests covering my contributions.
  • [ ] I have updated the documentation accordingly.
  • [x] This PR is being made to staging branch and not to main branch.

anargyri changed the base branch from main to staging on September 20, 2021, 10:47
@anargyri
Collaborator

Hey Abir, we use staging as our development branch. So we always merge PRs into staging and never to main (except staging -> main).

@anargyri
Collaborator

Since this PR will depend on upgrading TF, let's keep reviewing but leave it open until we can put everything in the upgrade together.

@@ -0,0 +1,59 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
Collaborator

I don't think we need a separate test pipeline for TF2. Once we upgrade the TF version in setup.py the current pipeline will switch to version 2. So you don't need this file.

Collaborator

+1

Collaborator Author

removed this file.

@anargyri
Collaborator

anargyri commented Sep 21, 2021

Could you please add docstrings to those functions that should appear in the documentation page on readthedocs? See https://github.com/Microsoft/Recommenders/wiki/Coding-Guidelines#python-and-docstrings-style
Don't forget to edit docs/source too.

Collaborator

@anargyri anargyri left a comment

Each new notebook needs a description at the beginning; see the other notebooks in the examples. The description is a summary of the method and the scenario it addresses. There may be text in other places in the notebook too, if needed for clarification.

Moreover, the method needs to be included in the table of methods we have in the README.

- each tuple (q, k, v) are fed to scaled_dot_product_attention
- all attention outputs are concatenated
"""
class MultiHeadAttention(tf.keras.layers.Layer):
Collaborator

Could you please add docstrings?

Collaborator Author

Added docstrings to all the methods.

recommenders/models/sasrec/model.py (resolved)
recommenders/models/sasrec/sampler.py (resolved)
Comment on lines 85 to 101
def evaluate(model, dataset, maxlen, num_neg_test):
    [train, valid, test, usernum, itemnum] = copy.deepcopy(dataset)

    NDCG = 0.0
    HT = 0.0
    valid_user = 0.0

    if usernum > 10000:
        users = random.sample(range(1, usernum + 1), 10000)
    else:
        users = range(1, usernum + 1)

    for u in tqdm(users, ncols=70, leave=False, unit='b'):

        if len(train[u]) < 1 or len(test[u]) < 1: continue

        # print(train[u])
Collaborator

In other notebooks we are standardizing the way we evaluate, so that we compare apples to apples. Would it be possible to use the evaluation functions that we have in the repo? https://github.com/microsoft/recommenders/blob/main/recommenders/evaluation/python_evaluation.py

Collaborator Author

The metrics used here, NDCG@10 and Hit@10, are the same as those available in deeprec_utils.py (https://github.com/microsoft/recommenders/blob/main/recommenders/models/deeprec/deeprec_utils.py). Currently I cannot import them because TF 2.x is not supported (from deeprec_utils import ndcg_score, hit_score does not work). Once we migrate to TF 2.x, it will be easy to invoke those functions.
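
For readers following the metric discussion: when a single held-out item is ranked against sampled negatives, NDCG@10 and Hit@10 reduce to simple functions of that item's rank. The sketch below is an illustration under that assumption, not the deeprec_utils implementation; the function names and the convention that the first score belongs to the held-out item are hypothetical.

```python
import math


def hit_at_k(rank, k=10):
    """Return 1.0 if the held-out item is inside the top k (rank is 0-based)."""
    return 1.0 if rank < k else 0.0


def ndcg_at_k(rank, k=10):
    """NDCG@k with a single relevant item: 1 / log2(rank + 2) inside the top k, else 0."""
    return 1.0 / math.log2(rank + 2) if rank < k else 0.0


def rank_of_positive(scores):
    """0-based rank of the positive item among all candidates.

    Assumption: scores[0] belongs to the held-out item, the remaining scores
    belong to sampled negatives, and a higher score means a better candidate.
    """
    positive = scores[0]
    return sum(1 for s in scores[1:] if s > positive)


# One user: two of the four negatives outscore the positive item.
rank = rank_of_positive([0.9, 0.95, 0.1, 0.99, 0.3])
print(rank, hit_at_k(rank), ndcg_at_k(rank))
```

Averaging these per-user values over the evaluated users gives the reported metric.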

tests/integration/examples/test_notebooks_tf2.py (outdated, resolved)
examples/00_quick_start/sasrec-tf2.ipynb (outdated, resolved; 5 threads)
@@ -0,0 +1,321 @@
{
Collaborator

I'm not sure I understand the dataset creation here. In other notebooks we give a short explanation of the dataset; please see: https://github.com/microsoft/recommenders/blob/main/examples/00_quick_start/lstur_MIND.ipynb

Also, there is a step to download the dataset. We have some light wrappers on them like https://github.com/microsoft/recommenders/blob/main/recommenders/datasets/amazon_reviews.py


Collaborator Author

I have elaborated on the data format and dataset creation. Let me know if I need to add more.

@@ -0,0 +1,321 @@
{
Collaborator

Line #8:    # conv_dims = kwargs.get("conv_dims", [100, 100])

I think this should be:

conv_dims = [100, 100]

because kwargs is not defined here?

Actually, you may want to add a parameter for it at the beginning, like the other parameters.



examples/00_quick_start/sasrec-tf2.ipynb (outdated, resolved; 2 threads)
@@ -0,0 +1,321 @@
{
Collaborator

Would it be possible to run the notebook and show the logs? The rest of the notebooks follow that pattern.


Collaborator Author

Added the logs.

Collaborator

@miguelgfierro miguelgfierro left a comment

This is really good Abir, I added some comments, please take a look

Comment on lines 8 to 22
class MultiHeadAttention(tf.keras.layers.Layer):
    """
    - Q (query), K (key) and V (value) are split into multiple heads (num_heads)
    - each tuple (q, k, v) are fed to scaled_dot_product_attention
    - all attention outputs are concatenated
    """

    def __init__(self, attention_dim, num_heads, dropout_rate):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.attention_dim = attention_dim
        assert attention_dim % self.num_heads == 0
        self.dropout_rate = dropout_rate

        self.depth = attention_dim // self.num_heads
Collaborator

Is this multi-head attention code the same as in the other file? If so, it would be good to refactor to avoid repeating code.

Collaborator Author

Yes, the code is refactored to use SASRec as base class.
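
As background for this thread, the three steps listed in the MultiHeadAttention docstring (split into heads, scaled dot-product attention per head, concatenate) can be sketched in plain Python. This is a minimal illustration, not the layer from the PR: it assumes unbatched inputs given as lists of vectors and leaves out the learned projection layers, masking and dropout that a Keras layer would normally include. All function names are hypothetical.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(q, k, v):
    """q, k, v are lists of vectors (seq_len x depth); returns seq_len x depth."""
    depth = len(q[0])
    out = []
    for qi in q:
        # Similarity of this query with every key, scaled by sqrt(depth).
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(depth) for kj in k]
        weights = softmax(scores)
        # Weighted average of the value vectors.
        out.append([sum(w * vj[d] for w, vj in zip(weights, v)) for d in range(len(v[0]))])
    return out


def split_heads(x, num_heads):
    """Split every vector in x into num_heads contiguous chunks of equal size."""
    dim = len(x[0])
    assert dim % num_heads == 0
    depth = dim // num_heads
    return [[row[h * depth:(h + 1) * depth] for row in x] for h in range(num_heads)]


def multi_head_attention(q, k, v, num_heads):
    """Run attention independently per head, then concatenate per position."""
    per_head = zip(split_heads(q, num_heads), split_heads(k, num_heads), split_heads(v, num_heads))
    heads = [scaled_dot_product_attention(qh, kh, vh) for qh, kh, vh in per_head]
    return [[val for head in heads for val in head[t]] for t in range(len(q))]


x = [[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0]]
out = multi_head_attention(x, x, x, num_heads=2)
print(len(out), len(out[0]))
```

The divisibility assertion mirrors the `attention_dim % num_heads == 0` check in the quoted `__init__`.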

Comment on lines 266 to 287
class SSEPT(tf.keras.Model):
    """
    SSE-PT Model

    :Citation:

        Wu L., Li S., Hsieh C-J., Sharpnack J., SSE-PT: Sequential Recommendation
        Via Personalized Transformer, RecSys, 2020.
        TF 1.x codebase: https://github.com/SSE-PT/SSE-PT
        TF 2.x codebase (SASRec): https://github.com/nnkkmto/SASRec-tf2

    Args:
        item_num: number of items in the dataset
        seq_max_len: maximum number of items in user history
        num_blocks: number of Transformer blocks to be used
        embedding_dim: item embedding dimension
        attention_dim: Transformer attention dimension
        conv_dims: list of the dimensions of the Feedforward layer
        dropout_rate: dropout rate
        l2_reg: coefficient of the L2 regularization
        num_neg_test: number of negative examples used in testing
    """
Collaborator

Question: how different is SASRec from SSE-PT?

Collaborator Author

@aeroabir aeroabir Oct 7, 2021

The difference is that SSE-PT also learns user embeddings, whereas SASRec has only item embeddings. The authors report that SSE-PT outperforms SASRec on 5 datasets (by 5%).
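
To make the distinction concrete, here is a small sketch of how the per-position input could differ between the two models. It assumes SSE-PT concatenates the user's embedding to each item embedding, and it ignores positional embeddings and the stochastic shared embedding regularization; the tables, sizes and function names are hypothetical and are not taken from the PR.

```python
import random


def make_table(num_rows, dim, seed):
    """Hypothetical embedding table: num_rows random vectors of length dim."""
    rng = random.Random(seed)
    return [[rng.uniform(-1, 1) for _ in range(dim)] for _ in range(num_rows)]


item_emb = make_table(num_rows=6, dim=4, seed=0)  # used by both models
user_emb = make_table(num_rows=3, dim=2, seed=1)  # only the SSE-PT sketch uses this


def sasrec_inputs(item_seq):
    """SASRec: each position is represented by its item embedding only."""
    return [item_emb[i] for i in item_seq]


def ssept_inputs(user, item_seq):
    """SSE-PT sketch: item embedding concatenated with the user's embedding.

    The `+` below is list concatenation, not element-wise addition.
    """
    return [item_emb[i] + user_emb[user] for i in item_seq]


print(len(sasrec_inputs([1, 2, 3])[0]), len(ssept_inputs(0, [1, 2, 3])[0]))
```

Under this assumption, the Transformer blocks in SSE-PT see a wider input whose extra dimensions are the same for every position of a given user.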



def sample_function(
    user_train, usernum, itemnum, batch_size, maxlen, result_queue, SEED
Collaborator

Minor detail: we typically write all inputs in lowercase, so SEED would become seed.

Collaborator Author

changed to seed.

recommenders/models/sasrec/sampler.py (outdated, resolved)
setup.py Outdated
      "tqdm>=4.31.1,<5",
      "matplotlib>=2.2.2,<4",
      "scikit-learn>=0.22.1,<1",
-     "numba>=0.38.1,<1",
+     # "numba>=0.38.1,<1",
Collaborator

Hmm, I think this would break the repo. Abir, does numba conflict with this code?

FYI @anargyri

Collaborator Author

@aeroabir aeroabir Oct 7, 2021

numba was causing a problem because of the LLVM module. I have installed it separately, so there is no longer an issue with numba.

Collaborator

Yes, numba is needed for the repo.

setup.py Outdated
      "pandas>1.0.3,<2",
-     "scipy>=1.0.0,<2",
+     "scipy==1.4.1",
Collaborator

Also, I'm not sure whether pinning scipy could cause problems. Abir, does the code break with scipy>=1.0.0,<2?

Collaborator Author

Yes, the problem is that this particular tensorflow-gpu version (2.3.0) has a very specific requirement.

Collaborator

@aeroabir, we can't pin scipy; it is too restrictive. @anargyri is working to upgrade the repo to TF>2.5. Restricting numpy to <1.19 could also be a problem.
Do you know if this code works with the same TF version that @anargyri is using?

Collaborator

Yes, I think we need to be careful when modifying setup.py; it can be a breaking change and it affects any PyPI release that will happen following the change.
Moreover, we are going to TF 2.6 (2.3 still suffers from some vulnerabilities). I think the best way is to ensure your code conforms with 2.6 from this PR already. Could you go through the migration process described here https://www.tensorflow.org/guide/migrate (i.e. run the script they provide and make the suggested changes in the syntax)?

Collaborator

You would need to use this setup.py (note the CUDA version too).

setup.py Outdated
Comment on lines 63 to 66
"tensorflow-gpu==2.3.0", # compiled with CUDA 10.0
"torch==1.9.1", # last os-common version with CUDA 10.0 support
"six~=1.15.0",
"typing-extensions~=3.7.4",
Collaborator

Same thing here: at this point, if we add TF 2 we will break the repo.
@anargyri is working on the transition, but so far we can't add this change.

Comment on lines 34 to 37
@pytest.mark.notebooks
def test_sasrec_single_node_runs(notebooks, output_notebook, kernel_name):
    notebook_path = notebooks["sasrec_quickstart"]
    pm.execute_notebook(notebook_path, output_notebook, kernel_name=kernel_name)
Collaborator

How long does this test take? If it takes too long, we can move it to the smoke or integration tests.

Comment on lines 18 to 31
@pytest.mark.notebooks
def test_template_runs(notebooks, output_notebook, kernel_name):
    notebook_path = notebooks["template"]
    pm.execute_notebook(
        notebook_path,
        output_notebook,
        parameters=dict(PM_VERSION=pm.__version__),
        kernel_name=kernel_name,
    )
    nb = sb.read_notebook(output_notebook)
    df = nb.papermill_dataframe
    assert df.shape[0] == 2
    check_version = df.loc[df["name"] == "checked_version", "value"].values[0]
    assert check_version is True
Collaborator

is this needed here? my understanding was that there was one example like this in the same folder?

@@ -0,0 +1,139 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
Collaborator

One thing that is missing is unit tests for the classes themselves. See for example: https://github.com/microsoft/recommenders/blob/main/tests/unit/recommenders/models/test_deeprec_model.py

Collaborator

@aeroabir not sure if you saw this comment

@miguelgfierro
Collaborator

@aeroabir something important: when you make commits, GitHub is not linking your username with them. I think it is important that GitHub shows the code you developed as attributed to you. It took you a long time to do this and we should really make sure your work is showcased.
To fix this just follow these instructions: https://github.com/Microsoft/Recommenders/wiki/How-to-add-your-name-as-a-contributor-to-the-repo#make-sure-that-your-github-user-is-setup-correctly

Once this is set up, you would see that your name will automatically appear in the contributor list https://github.com/microsoft/recommenders/graphs/contributors


import os
import pytest
import tensorflow as tf
Collaborator

I think that we can have the TF2 tests in the same file as TF, because we are going to upgrade soon FYI @anargyri

@anargyri
Collaborator

@aeroabir This PR should be ready to revisit now that the TensorFlow upgrade has been checked into staging.
Feel free to merge staging branch into this branch (taking setup.py from staging) and address any comments left when you have time.

@anargyri
Collaborator

Also could you add an entry in the table that lists the algos in the front page README?

@anargyri
Collaborator

Could you also put SASRec in its alphabetical position in the table (after SAR)?

@anargyri
Collaborator

Another place to add SASRec is here.

@codecov-commenter

codecov-commenter commented Jan 17, 2022

Codecov Report

Merging #1530 (6f39521) into staging (13072e7) will increase coverage by 59.05%.
The diff coverage is 72.14%.

❗ Current head 6f39521 differs from pull request most recent head f15d8b3. Consider uploading reports for the commit f15d8b3 to get more accurate results

@@             Coverage Diff              @@
##           staging    #1530       +/-   ##
============================================
+ Coverage     0.00%   59.05%   +59.05%     
============================================
  Files           84       88        +4     
  Lines         8462     8997      +535     
============================================
+ Hits             0     5313     +5313     
- Misses           0     3684     +3684     
Flag       Coverage Δ
nightly    ?
pr-gate    59.05% <72.14%> (?)

Impacted Files                                 Coverage Δ
recommenders/models/sasrec/ssept.py            22.35% <22.35%> (ø)
recommenders/models/sasrec/util.py             57.74% <57.74%> (ø)
recommenders/models/sasrec/model.py            85.19% <85.19%> (ø)
recommenders/models/sasrec/sampler.py          91.66% <91.66%> (ø)
recommenders/datasets/mind.py                  0.00% <0.00%> (ø)
recommenders/datasets/movielens.py             69.46% <0.00%> (+69.46%) ⬆️
recommenders/datasets/download_utils.py        90.00% <0.00%> (+90.00%) ⬆️
recommenders/models/newsrec/models/npa.py      95.58% <0.00%> (+95.58%) ⬆️
recommenders/models/newsrec/models/naml.py     92.43% <0.00%> (+92.43%) ⬆️
recommenders/models/newsrec/models/nrms.py     91.37% <0.00%> (+91.37%) ⬆️
... and 14 more

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 13072e7...f15d8b3.

Collaborator

@miguelgfierro miguelgfierro left a comment

hey Abir, I think the code is in good shape. There is one issue to discuss about the splitter

README.md Outdated
@@ -118,8 +118,10 @@ The table below lists the recommender algorithms currently available in the repo
| Restricted Boltzmann Machines (RBM) | Collaborative Filtering | Neural network based algorithm for learning the underlying probability distribution for explicit or implicit user/item feedback. It works in the CPU/GPU enviroment. | [Quick start](examples/00_quick_start/rbm_movielens.ipynb) / [Deep dive](examples/02_model_collaborative_filtering/rbm_deep_dive.ipynb) |
| Riemannian Low-rank Matrix Completion (RLRMC)<sup>*</sup> | Collaborative Filtering | Matrix factorization algorithm using Riemannian conjugate gradients optimization with small memory consumption to predice user/item interactions. It works in the CPU enviroment. | [Quick start](examples/00_quick_start/rlrmc_movielens.ipynb) |
| Simple Algorithm for Recommendation (SAR)<sup>*</sup> | Collaborative Filtering | Similarity-based algorithm for implicit user/item feedback. It works in the CPU environment. | [Quick start](examples/00_quick_start/sar_movielens.ipynb) / [Deep dive](examples/02_model_collaborative_filtering/sar_deep_dive.ipynb) |
| Self-Attentive Sequential Recommendation (SASRec) | Sequential | Transformer based algorithm for sequential recommendation. It works in the CPU/GPU environment. | [Deep dive](examples/00_quick_start/sasrec_amazon.ipynb) |
Collaborator

I think here it would be [Quick start](examples/00_quick_start/sasrec_amazon.ipynb)

Collaborator

For the type, the only options right now are Collaborative Filtering, Content-Based and Hybrid; we don't have Sequential. We did discuss changing the types in the past, but we didn't reach a decision.

If you think it would be a good idea to change the types, would it be ok if we leave this as one of the old types, just for consistency, and then start a discussion about other types?

README.md Outdated
| Short-term and Long-term Preference Integrated Recommender (SLi-Rec)<sup>*</sup> | Collaborative Filtering | Sequential-based algorithm that aims to capture both long and short-term user preferences using attention mechanism, a time-aware controller and a content-aware controller. It works in the CPU/GPU environment. | [Quick start](examples/00_quick_start/sequential_recsys_amazondataset.ipynb) |
| Multi-Interest-Aware Sequential User Modeling (SUM)<sup>*</sup> | Collaborative Filtering | An enhanced memory network-based sequential user model which aims to capture users' multiple interests. It works in the CPU/GPU environment. | [Quick start](examples/00_quick_start/sequential_recsys_amazondataset.ipynb) |
| Sequential Recommendation Via Personalized Transformer (SSEPT) | Sequential | Transformer based algorithm for sequential recommendation with User embedding. It works in the CPU/GPU environment. | [Deep dive](examples/00_quick_start/sasrec_amazon.ipynb) |
Collaborator

Same note here as before with the type and quick start

Comment on lines 332 to 336
    # TF2.x
    "sasrec_quickstart": os.path.join(
        folder_notebooks, "00_quick_start", "sasrec_amazon.ipynb"
    ),
}
Collaborator

Minor detail: the other notebooks are ordered by folder ("00_quick_start", "02_model_content_based_filtering"). Would you mind following the same structure?

)
],
)
# @pytest.mark.skipif(tf.__versoin__ > "2.0", reason="We are currently on TF 1.5")
Collaborator

# @pytest.mark.skipif(tf.__versoin__ > "2.0", reason="We are currently on TF 1.5")
this can be updated now
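
A side note on the quoted skip condition: the attribute is spelled `tf.__version__` (the `__versoin__` spelling would raise an AttributeError if the line were uncommented), and comparing version strings with `>` is lexicographic, so "2.10.0" sorts before "2.6.0". A hedged sketch of a numeric comparison follows; `parse_version` and `is_tf2` are hypothetical helpers written for this illustration and are not part of the repo or of TensorFlow.

```python
def parse_version(version):
    """Turn '2.6.0' into (2, 6, 0); stops at the first non-numeric part (e.g. 'rc1')."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
        if digits != piece:
            # The piece had a suffix such as 'rc1'; ignore everything after it.
            break
    return tuple(parts)


def is_tf2(version):
    """True when the major version is 2 or higher."""
    return parse_version(version) >= (2,)


# String comparison gets this wrong, tuple comparison does not.
print("2.10.0" > "2.6.0", parse_version("2.10.0") > parse_version("2.6.0"))
```

In a test module, such a helper could feed a condition like `parse_version(tf.__version__) < (2,)` into `pytest.mark.skipif`, assuming TensorFlow is importable in that environment.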

Comment on lines +43 to +63
    def split(self, **kwargs):
        self.filename = kwargs.get("filename", self.filename)
        if not self.filename:
            raise ValueError("Filename is required")

        if self.with_time:
            self.data_partition_with_time()
        else:
            self.data_partition()

    def data_partition(self):
        # assume user/item index starting from 1
        f = open(self.filename, "r")
        for line in f:
            u, i = line.rstrip().split(self.col_sep)
            u = int(u)
            i = int(i)
            self.usernum = max(u, self.usernum)
            self.itemnum = max(i, self.itemnum)
            self.User[u].append(i)

Collaborator

One thing for you to consider, Abir. In the other algos, we always split the data using common splitters: https://github.com/microsoft/recommenders/blob/main/recommenders/datasets/python_splitters.py This helps us compare algorithms in an apples-to-apples fashion. If we use the original splitter instead of the ones we have, we won't be able to compare these algos with the other ones in the repo.

What are your thoughts on this?
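
For context on what is being compared with the common splitters, the partition quoted above groups item ids per user in file order. The quoted diff does not show the split step itself, so the leave-last-out rule below (last item for test, second to last for validation, short histories kept in train) is an assumption about the usual sequential-recommendation setup, and the function names are hypothetical.

```python
from collections import defaultdict


def read_sequences(lines, col_sep=" "):
    """Group item ids by user, keeping file order (assumed chronological)."""
    user_items = defaultdict(list)
    for line in lines:
        u, i = line.rstrip().split(col_sep)
        user_items[int(u)].append(int(i))
    return user_items


def leave_last_out(user_items):
    """Hold out the last item for test and the one before it for validation.

    Assumption: users with fewer than 3 interactions stay entirely in train.
    """
    train, valid, test = {}, {}, {}
    for u, items in user_items.items():
        if len(items) < 3:
            train[u], valid[u], test[u] = list(items), [], []
        else:
            train[u], valid[u], test[u] = items[:-2], [items[-2]], [items[-1]]
    return train, valid, test


lines = ["1 10", "1 11", "1 12", "1 13", "2 20"]
train, valid, test = leave_last_out(read_sequences(lines))
print(train, valid, test)
```

A split like this is per-user and order-dependent, which is why results from it are not directly comparable with splits that sample interactions at random.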

Comment on lines 167 to 188
# Amazon Electronics Data
itemnum = 85930
maxlen = 50
num_blocks = 2
hidden_units = 100
num_heads = 1
dropout_rate = 0.1
l2_emb = 0.0
num_neg_test = 100

model = SASREC(
    item_num=itemnum,
    seq_max_len=maxlen,
    num_blocks=num_blocks,
    embedding_dim=hidden_units,
    attention_dim=hidden_units,
    attention_num_heads=num_heads,
    dropout_rate=dropout_rate,
    conv_dims=[100, 100],
    l2_reg=l2_emb,
    num_neg_test=num_neg_test,
)
Collaborator

An elegant way of adding input data to tests is to use fixtures; see how they work here: https://github.com/miguelgfierro/pybase/blob/master/test/pytest_fixtures.py
You could create a dictionary of fixtures for the Amazon Electronics data and pass it to the test.

Comment on lines 9 to 23
def random_neq(left, right, s):
    t = np.random.randint(left, right)
    while t in s:
        t = np.random.randint(left, right)
    return t


def sample_function(
    user_train, usernum, itemnum, batch_size, maxlen, result_queue, seed
):
    """
    Batch sampler that creates a sequence of negative items based on the
    original sequence of items (positive) that the user has interacted with.
    """

Collaborator

if you have time, it would be great to add the missing docstrings. This is an example of the format we follow:

    def check_column_dtypes_wrapper(
        rating_true,
        rating_pred,
        col_user=DEFAULT_USER_COL,
        col_item=DEFAULT_ITEM_COL,
        col_rating=DEFAULT_RATING_COL,
        col_prediction=DEFAULT_PREDICTION_COL,
        *args,
        **kwargs
    ):
        """Check columns of DataFrame inputs
        Args:
            rating_true (pandas.DataFrame): True data
            rating_pred (pandas.DataFrame): Predicted data
            col_user (str): column name for user
            col_item (str): column name for item
            col_rating (str): column name for rating
            col_prediction (str): column name for prediction
        """

See here for more info: https://github.com/microsoft/recommenders/blob/main/recommenders/evaluation/python_evaluation.py#L51
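
Since the sampler's docstring is brief, here is a sketch of what sampling one training example could look like, written with documented helpers in the style requested above. It uses the standard library `random` module instead of NumPy, handles a single user rather than a batch, and pads with 0 on the left; `sample_one` and its arguments are hypothetical and may differ from the sampler in the PR.

```python
import random


def random_neq(left, right, seen, rng):
    """Draw an integer in [left, right) that is not in `seen`.

    Args:
        left (int): inclusive lower bound
        right (int): exclusive upper bound
        seen (set): values that must not be returned
        rng (random.Random): random number generator
    """
    t = rng.randrange(left, right)
    while t in seen:
        t = rng.randrange(left, right)
    return t


def sample_one(user_seq, itemnum, maxlen, rng):
    """Build (seq, pos, neg) lists of length maxlen for one user.

    seq[t] is an input item, pos[t] is the item that actually followed it,
    and neg[t] is a random item the user never interacted with.
    Unused leading positions are left as 0 (padding).
    """
    seq, pos, neg = [0] * maxlen, [0] * maxlen, [0] * maxlen
    seen = set(user_seq)
    nxt = user_seq[-1]
    idx = maxlen - 1
    # Walk the history backwards so the most recent items fill the right end.
    for item in reversed(user_seq[:-1]):
        seq[idx] = item
        pos[idx] = nxt
        neg[idx] = random_neq(1, itemnum + 1, seen, rng)
        nxt = item
        idx -= 1
        if idx == -1:
            break
    return seq, pos, neg


seq, pos, neg = sample_one([3, 7, 2, 9], itemnum=20, maxlen=5, rng=random.Random(0))
print(seq, pos)
```

Item ids are assumed to start at 1, matching the "index starting from 1" comment in the quoted partition code, so 0 is free to act as padding.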

@@ -0,0 +1,912 @@
{
Collaborator

This is a very nice function; it removes users and items with fewer than k interactions. @anargyri @simonzhaoms @angusrtaylor @aeroabir do you see us using this function in other parts of the repo?


Collaborator

Actually we have a function already here. You may use it instead.

Collaborator

@aeroabir if you want to do this change now, please let me know, otherwise we can create an issue and leave it for another PR
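
For readers unfamiliar with this kind of filter, a minimal sketch of removing users and items with fewer than k interactions is shown below. It is neither the notebook's function nor the existing repo function mentioned above; `filter_min_k` is a hypothetical name, and the single-pass behaviour is a simplifying assumption.

```python
from collections import Counter


def filter_min_k(interactions, k):
    """Keep (user, item) pairs whose user and item each appear at least k times.

    Counts are computed once over the input, so after sparse items are dropped
    a kept user may end up with fewer than k interactions; repeating the call
    until the output stops changing would enforce the threshold strictly.
    """
    user_counts = Counter(u for u, _ in interactions)
    item_counts = Counter(i for _, i in interactions)
    return [
        (u, i)
        for u, i in interactions
        if user_counts[u] >= k and item_counts[i] >= k
    ]


data = [("a", 1), ("a", 2), ("b", 1), ("b", 2), ("c", 3)]
print(filter_min_k(data, k=2))
```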

Collaborator

@miguelgfierro miguelgfierro left a comment

Looks good to me! really great job Abir!
