From 54d38e16b4309efe86115698b818571241f25f0e Mon Sep 17 00:00:00 2001 From: Tian Zhen <1204216974@qq.com> Date: Tue, 9 Nov 2021 21:56:04 +0800 Subject: [PATCH 1/3] FIX: fix issue#1038 --- docs/source/user_guide/usage/use_modules.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/user_guide/usage/use_modules.rst b/docs/source/user_guide/usage/use_modules.rst index 6b20556a1..675837d45 100644 --- a/docs/source/user_guide/usage/use_modules.rst +++ b/docs/source/user_guide/usage/use_modules.rst @@ -36,7 +36,7 @@ The complete process is as follows: train_data, valid_data, test_data = data_preparation(config, dataset) # model loading and initialization - model = BPR(config, train_data).to(config['device']) + model = BPR(config, train_data.dataset).to(config['device']) logger.info(model) # trainer loading and initialization From c63fd71478fd2741d7b3a69ee0a2de5f5c918608 Mon Sep 17 00:00:00 2001 From: Tian Zhen <1204216974@qq.com> Date: Tue, 9 Nov 2021 21:57:02 +0800 Subject: [PATCH 2/3] undo --- docs/source/user_guide/usage/use_modules.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/user_guide/usage/use_modules.rst b/docs/source/user_guide/usage/use_modules.rst index 675837d45..6b20556a1 100644 --- a/docs/source/user_guide/usage/use_modules.rst +++ b/docs/source/user_guide/usage/use_modules.rst @@ -36,7 +36,7 @@ The complete process is as follows: train_data, valid_data, test_data = data_preparation(config, dataset) # model loading and initialization - model = BPR(config, train_data.dataset).to(config['device']) + model = BPR(config, train_data).to(config['device']) logger.info(model) # trainer loading and initialization From 9ec41f822a73ca08351bdaac2b83884127af3440 Mon Sep 17 00:00:00 2001 From: Tian Zhen <1204216974@qq.com> Date: Mon, 21 Feb 2022 19:56:10 +0800 Subject: [PATCH 3/3] FEA: Add ADMMSLIM in General models --- .../user_guide/model/general/admmslim.rst | 102 +++++++++++++++ 
docs/source/user_guide/model_intro.rst | 1 + recbole/model/general_recommender/__init__.py | 3 +- recbole/model/general_recommender/admmslim.py | 117 ++++++++++++++++++ recbole/properties/model/ADMMSLIM.yaml | 7 ++ tests/model/test_model_auto.py | 6 + 6 files changed, 235 insertions(+), 1 deletion(-) create mode 100644 docs/source/user_guide/model/general/admmslim.rst create mode 100644 recbole/model/general_recommender/admmslim.py create mode 100644 recbole/properties/model/ADMMSLIM.yaml diff --git a/docs/source/user_guide/model/general/admmslim.rst b/docs/source/user_guide/model/general/admmslim.rst new file mode 100644 index 000000000..d419adddd --- /dev/null +++ b/docs/source/user_guide/model/general/admmslim.rst @@ -0,0 +1,102 @@ +ADMMSLIM +============ + +Introduction +------------------ + +`[paper] <https://doi.org/10.1145/3336191.3371774>`_ + +**Title:** ADMM SLIM: Sparse Recommendations for Many Users + +**Authors:** Harald Steck, Maria Dimakopoulou, Nickolai Riabov, Tony Jebara + + +**Abstract:** The Sparse Linear Method (Slim) is a well-established approach +for top-N recommendations. This article proposes several improvements +that are enabled by the Alternating Directions Method of +Multipliers (ADMM), a well-known optimization method +with many application areas. First, we show that optimizing the +original Slim-objective by ADMM results in an approach where the +training time is independent of the number of users in the training +data, and hence trivially scales to large numbers of users. Second, +the flexibility of ADMM allows us to switch on and off the various +constraints and regularization terms in the original Slim-objective, +in order to empirically assess their contributions to ranking accuracy +on given data. Third, we also propose two extensions to the +original Slim training-objective in order to improve recommendation +accuracy further without increasing the computational cost.
In +our experiments on three well-known data-sets, we first compare +to the original Slim-implementation and find that not only ADMM +reduces training time considerably, but also achieves an improvement +in recommendation accuracy due to better optimization. We +then compare to various state-of-the-art approaches and observe +up to 25% improvement in recommendation accuracy in our experiments. +Finally, we evaluate the importance of sparsity and the +non-negativity constraint in the original Slim-objective with subsampling +experiments that simulate scenarios of cold-starting and +large catalog sizes compared to relatively small user base, which +often occur in practice. + +Running with RecBole +------------------------- + +**Model Hyper-Parameters:** + +- ``lambda1 (float)`` : L1-norm regularization parameter. Defaults to ``3``. + +- ``lambda2 (float)`` : L2-norm regularization parameter. Defaults to ``200``. + +- ``alpha (float)`` : The exponent that controls the power-law in the regularization terms. Defaults to ``0.5``. + +- ``rho (float)`` : The penalty parameter that applies to the squared difference between primal variables. Defaults to ``4000``. + +- ``k (int)`` : The number of ADMM iterations to run. Defaults to ``100``. + +- ``positive_only (bool)`` : Whether to keep only the positive entries of the learned item-similarity matrix. Defaults to ``True``. + +- ``center_columns (bool)`` : Whether to use additional item-bias terms. Defaults to ``False``. + + +**A Running Example:** + +Write the following code to a python file, such as `run.py` + +.. code:: python + + from recbole.quick_start import run_recbole + + run_recbole(model='ADMMSLIM', dataset='ml-100k') + +And then: + +.. code:: bash + + python run.py + +Tuning Hyper Parameters +------------------------- + +If you want to use ``HyperTuning`` to tune hyper parameters of this model, you can copy the following settings and name it as ``hyper.test``. + +..
code:: bash + + lambda1 choice [0.1 , 0.5 , 5 , 10] + lambda2 choice [5 , 50 , 1000 , 5000] + alpha choice [0.25 , 0.5 , 0.75 , 1] + +Note that we provide these hyper parameter ranges for reference only, and we cannot guarantee that they are the optimal ranges for this model. + +Then, with the source code of RecBole (you can download it from GitHub), you can run ``run_hyper.py`` to tune the model: + +.. code:: bash + + python run_hyper.py --model=[model_name] --dataset=[dataset_name] --config_files=[config_files_path] --params_file=hyper.test + +For more details about Parameter Tuning, refer to :doc:`../../../user_guide/usage/parameter_tuning`. + +If you want to change parameters, dataset or evaluation settings, take a look at + +- :doc:`../../../user_guide/config_settings` +- :doc:`../../../user_guide/data_intro` +- :doc:`../../../user_guide/train_eval_intro` +- :doc:`../../../user_guide/usage` \ No newline at end of file diff --git a/docs/source/user_guide/model_intro.rst b/docs/source/user_guide/model_intro.rst index 6d6685fa3..97f12b91f 100644 --- a/docs/source/user_guide/model_intro.rst +++ b/docs/source/user_guide/model_intro.rst @@ -38,6 +38,7 @@ task of top-n recommendation.
All the collaborative filter(CF) based models are model/general/ease model/general/slimelastic model/general/sgl + model/general/admmslim Context-aware Recommendation diff --git a/recbole/model/general_recommender/__init__.py b/recbole/model/general_recommender/__init__.py index 133a87bf9..4a54d0ffe 100644 --- a/recbole/model/general_recommender/__init__.py +++ b/recbole/model/general_recommender/__init__.py @@ -23,4 +23,5 @@ from recbole.model.general_recommender.recvae import RecVAE from recbole.model.general_recommender.slimelastic import SLIMElastic from recbole.model.general_recommender.spectralcf import SpectralCF -from recbole.model.general_recommender.sgl import SGL \ No newline at end of file +from recbole.model.general_recommender.sgl import SGL +from recbole.model.general_recommender.admmslim import ADMMSLIM \ No newline at end of file diff --git a/recbole/model/general_recommender/admmslim.py b/recbole/model/general_recommender/admmslim.py new file mode 100644 index 000000000..d8242904e --- /dev/null +++ b/recbole/model/general_recommender/admmslim.py @@ -0,0 +1,117 @@ +# @Time : 2021/01/09 +# @Author : Deklan Webster + +r""" +ADMMSLIM +################################################ +Reference: + Steck et al. ADMM SLIM: Sparse Recommendations for Many Users. 
https://doi.org/10.1145/3336191.3371774 + +""" + +from recbole.utils.enum_type import ModelType +import numpy as np +import scipy.sparse as sp +import torch + +from recbole.utils import InputType +from recbole.model.abstract_recommender import GeneralRecommender + + +def soft_threshold(x, threshold): + return (np.abs(x) > threshold) * (np.abs(x) - threshold) * np.sign(x) + + +def zero_mean_columns(a): + return a - np.mean(a, axis=0) + + +def add_noise(t, mag=1e-5): + return t + mag * torch.rand(t.shape) + + +class ADMMSLIM(GeneralRecommender): + input_type = InputType.POINTWISE + type = ModelType.TRADITIONAL + + def __init__(self, config, dataset): + super().__init__(config, dataset) + + # need at least one param + self.dummy_param = torch.nn.Parameter(torch.zeros(1)) + + X = dataset.inter_matrix(form='csr').astype(np.float32) + + num_users, num_items = X.shape + + lambda1 = config['lambda1'] + lambda2 = config['lambda2'] + alpha = config['alpha'] + rho = config['rho'] + k = config['k'] + positive_only = config['positive_only'] + self.center_columns = config['center_columns'] + self.item_means = X.mean(axis=0).getA1() + + if self.center_columns: + zero_mean_X = X.toarray() - self.item_means + G = (zero_mean_X.T @ zero_mean_X) + # large memory cost because we need to make X dense to subtract mean, delete asap + del zero_mean_X + else: + G = (X.T @ X).toarray() + + diag = lambda2 * np.diag(np.power(self.item_means, alpha)) + \ + rho * np.identity(num_items) + + P = np.linalg.inv(G + diag).astype(np.float32) + B_aux = (P @ G).astype(np.float32) + # initialize + Gamma = np.zeros_like(G, dtype=np.float32) + C = np.zeros_like(G, dtype=np.float32) + + del diag, G + # fixed number of iterations + for _ in range(k): + B_tilde = B_aux + P @ (rho * C - Gamma) + gamma = np.diag(B_tilde) / (np.diag(P) + 1e-7) + B = B_tilde - P * gamma + C = soft_threshold(B + Gamma / rho, lambda1 / rho) + if positive_only: + C = (C > 0) * C + Gamma += rho * (B - C) + # torch doesn't support 
sparse tensor slicing, so will do everything with np/scipy + self.item_similarity = C + self.interaction_matrix = X + + def forward(self): + pass + + def calculate_loss(self, interaction): + return torch.nn.Parameter(torch.zeros(1)) + + def predict(self, interaction): + user = interaction[self.USER_ID].cpu().numpy() + item = interaction[self.ITEM_ID].cpu().numpy() + + user_interactions = self.interaction_matrix[user, :].toarray() + + if self.center_columns: + r = (((user_interactions - self.item_means) * + self.item_similarity[:, item].T).sum(axis=1)).flatten() + self.item_means[item] + else: + r = (user_interactions * self.item_similarity[:, item].T).sum(axis=1).flatten() + + return add_noise(torch.from_numpy(r)) + + def full_sort_predict(self, interaction): + user = interaction[self.USER_ID].cpu().numpy() + + user_interactions = self.interaction_matrix[user, :].toarray() + + if self.center_columns: + r = ((user_interactions - self.item_means) @ self.item_similarity + self.item_means).flatten() + else: + r = (user_interactions @ self.item_similarity).flatten() + + return add_noise(torch.from_numpy(r)) diff --git a/recbole/properties/model/ADMMSLIM.yaml b/recbole/properties/model/ADMMSLIM.yaml new file mode 100644 index 000000000..84bec809c --- /dev/null +++ b/recbole/properties/model/ADMMSLIM.yaml @@ -0,0 +1,7 @@ +lambda1: 3 +lambda2: 200 +alpha: 0.50 +rho: 4000 +k: 100 +positive_only: True +center_columns: False \ No newline at end of file diff --git a/tests/model/test_model_auto.py b/tests/model/test_model_auto.py index 06e9e0e24..ca397d841 100644 --- a/tests/model/test_model_auto.py +++ b/tests/model/test_model_auto.py @@ -205,6 +205,12 @@ def test_SGL(self): 'model': 'SGL', } quick_test(config_dict) + + def test_ADMMSLIM(self): + config_dict = { + 'model': 'ADMMSLIM', + } + quick_test(config_dict) class TestContextRecommender(unittest.TestCase):