From 2c7f71d56e4280bfa20c17b6e1e2511a1c168393 Mon Sep 17 00:00:00 2001 From: Lanling Xu Date: Mon, 18 Oct 2021 11:45:10 +0000 Subject: [PATCH 1/6] FEA: Add dynamic negative sampling and fix some bugs --- .../user_guide/config/training_settings.rst | 3 +- recbole/config/configurator.py | 18 ++++++++-- .../data/dataloader/abstract_dataloader.py | 36 ++++++++++++++++--- .../data/dataloader/knowledge_dataloader.py | 5 +++ recbole/data/dataset/dataset.py | 2 +- recbole/model/abstract_recommender.py | 3 ++ recbole/model/sequential_recommender/hrm.py | 1 - recbole/trainer/trainer.py | 6 +++- 8 files changed, 62 insertions(+), 12 deletions(-) diff --git a/docs/source/user_guide/config/training_settings.rst b/docs/source/user_guide/config/training_settings.rst index 2e3a51499..a877f55b7 100644 --- a/docs/source/user_guide/config/training_settings.rst +++ b/docs/source/user_guide/config/training_settings.rst @@ -9,10 +9,11 @@ Training settings are designed to set parameters about model training. Range in ``['adam', 'sgd', 'adagrad', 'rmsprop', 'sparse_adam']``. - ``learning_rate (float)`` : Learning rate. Defaults to ``0.001``. - ``neg_sampling(dict)``: This parameter controls the negative sampling for model training. - The key range is ``['uniform', 'popularity']``, which decides the distribution of negative items in sampling pools. + The key range is ``['uniform', 'popularity']``, which decides the distribution of negative items in sampling pools. In addition, we also support dynamic negative sampling ``['dns']``. ``uniform`` means uniformly select negative items while ``popularity`` means select negative items based on their popularity (Counter(item) in `.inter` file). Note that if your data is labeled, you need to set this parameter as ``None``. The default value of this parameter is ``{'uniform': 1}``. + When dynamic negative sampling, ``dynamic_sampling(dict)`` decides the dynamic negative sampler and the number of candidate negative items. For example, ``{uniform: 1, dynamic_sampling: {sampler: dns, candidate_num: 2}}`` means sample 2 items for each positive item uniformly, and dynamically choose the item with the higher score as the selected negative item. - ``eval_step (int)`` : The number of training epochs before an evaluation on the valid dataset. If it is less than 1, the model will not be evaluated on the valid dataset. Defaults to ``1``. 
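A minimal sketch of how the dynamic-sampling options documented above would look in a RecBole ``config_dict``, using the nested format as introduced in this first patch (later patches in this series simplify it). This is an illustration only, not part of the patch:

    # Hypothetical config_dict illustrating the `dynamic_sampling` keys documented above.
    # `sampler` and `candidate_num` are the two sub-keys validated in configurator.py below.
    config_dict = {
        'neg_sampling': {
            'uniform': 1,                    # one uniformly-sampled negative per positive item
            'dynamic_sampling': {
                'sampler': 'dns',            # dynamic negative sampling
                'candidate_num': 2,          # score two candidates, keep the higher-scoring one
            },
        },
    }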
diff --git a/recbole/config/configurator.py b/recbole/config/configurator.py index c7307818c..3cc733ed6 100644 --- a/recbole/config/configurator.py +++ b/recbole/config/configurator.py @@ -354,8 +354,6 @@ def _set_train_neg_sample_args(self): else: if not isinstance(neg_sampling, dict): raise ValueError(f"neg_sampling:[{neg_sampling}] should be a dict.") - if len(neg_sampling) > 1: - raise ValueError(f"the len of neg_sampling [{neg_sampling}] should be 1.") distribution = list(neg_sampling.keys())[0] sample_num = neg_sampling[distribution] @@ -363,8 +361,22 @@ def _set_train_neg_sample_args(self): raise ValueError(f"The distribution [{distribution}] of neg_sampling " f"should in ['uniform', 'popularity']") + strategy = 'by' + if 'dynamic_sampling' in neg_sampling.keys(): + dynamic_sampling = neg_sampling['dynamic_sampling'] + if not isinstance(dynamic_sampling, dict): + raise ValueError(f"dynamic_sampling:[{dynamic_sampling}] should be a dict.") + if not ({'sampler', 'candidate_num'} <= set(dynamic_sampling.keys())): + raise ValueError(f"'sampler' and 'candidate_num' should be in " + f"dynamic_sampling:[{dynamic_sampling}].keys()") + dynamic_sampler = dynamic_sampling['sampler'] + if dynamic_sampler.lower() not in ['dns']: + raise ValueError(f"The sampler [{sampler}] of dynamic_sampling " + f"should be in ['dns']") + strategy = dynamic_sampling + self.final_config_dict['train_neg_sample_args'] = { - 'strategy': 'by', + 'strategy': strategy, 'by': sample_num, 'distribution': distribution } diff --git a/recbole/data/dataloader/abstract_dataloader.py b/recbole/data/dataloader/abstract_dataloader.py index 638b8cb16..2f32bc338 100644 --- a/recbole/data/dataloader/abstract_dataloader.py +++ b/recbole/data/dataloader/abstract_dataloader.py @@ -13,6 +13,7 @@ """ import math +import copy from logging import getLogger import torch @@ -45,7 +46,7 @@ def __init__(self, config, dataset, sampler, shuffle=False): self.logger = getLogger() self.dataset = dataset self.sampler = sampler - self.batch_size = self.step = None + self.batch_size = self.step = self.model = None self.shuffle = shuffle self.pr = 0 self._init_batch_size_and_step() @@ -127,7 +128,8 @@ def _set_neg_sample_args(self, config, dataset, dl_format, neg_sample_args): self.dl_format = dl_format self.neg_sample_args = neg_sample_args self.times = 1 - if self.neg_sample_args['strategy'] == 'by': + self.strategy = self.neg_sample_args['strategy'] + if self.strategy == 'by' or isinstance(self.strategy, dict): self.neg_sample_num = self.neg_sample_args['by'] if self.dl_format == InputType.POINTWISE: @@ -150,11 +152,32 @@ def _set_neg_sample_args(self, config, dataset, dl_format, neg_sample_args): else: raise ValueError(f'`neg sampling by` with dl_format [{self.dl_format}] not been implemented.') - elif self.neg_sample_args['strategy'] != 'none': - raise ValueError(f'`neg_sample_args` [{self.neg_sample_args["strategy"]}] is not supported!') + elif self.strategy != 'none': + raise ValueError(f'`neg_sample_args` [{self.strategy}] is not supported!') def _neg_sampling(self, inter_feat): - if self.neg_sample_args['strategy'] == 'by': + if isinstance(self.strategy, dict): + sampler = self.strategy['sampler'] + candidate_num = self.strategy['candidate_num'] + if sampler.lower() == 'dns': + user_ids = inter_feat[self.uid_field] + item_ids = inter_feat[self.iid_field] + neg_candidate_ids = self.sampler.sample_by_user_ids(user_ids, item_ids, self.neg_sample_num * candidate_num) + + self.model.eval() + interaction = 
copy.deepcopy(inter_feat).to(self.model.device) + interaction = interaction.repeat(self.neg_sample_num * candidate_num) + neg_item_feat = Interaction({self.iid_field: neg_candidate_ids.to(self.model.device)}) + interaction.update(neg_item_feat) + scores = self.model.predict(interaction).reshape(candidate_num, -1) + indices = torch.max(scores, dim=0)[1].detach() + neg_candidate_ids = neg_candidate_ids.reshape(candidate_num, -1) + neg_item_ids = neg_candidate_ids[indices, [i for i in range(neg_candidate_ids.shape[1])]].view(-1) + self.model.train() + return self.sampling_func(inter_feat, neg_item_ids) + else: + raise NotImplementedError + elif self.strategy == 'by': user_ids = inter_feat[self.uid_field] item_ids = inter_feat[self.iid_field] neg_item_ids = self.sampler.sample_by_user_ids(user_ids, item_ids, self.neg_sample_num) @@ -179,3 +202,6 @@ def _neg_sample_by_point_wise_sampling(self, inter_feat, neg_item_ids): labels[:pos_inter_num] = 1.0 new_data.update(Interaction({self.label_field: labels})) return new_data + + def get_model(self, model): + self.model = model \ No newline at end of file diff --git a/recbole/data/dataloader/knowledge_dataloader.py b/recbole/data/dataloader/knowledge_dataloader.py index 23e80e6f8..e3d2a42ba 100644 --- a/recbole/data/dataloader/knowledge_dataloader.py +++ b/recbole/data/dataloader/knowledge_dataloader.py @@ -182,3 +182,8 @@ def set_mode(self, state): if state not in set(KGDataLoaderState): raise NotImplementedError(f'Kg data loader has no state named [{self.state}].') self.state = state + + def get_model(self, model): + """Let the general_dataloader get the model, used for dynamic sampling. + """ + self.general_dataloader.get_model(model) \ No newline at end of file diff --git a/recbole/data/dataset/dataset.py b/recbole/data/dataset/dataset.py index 05ce96ff1..1177f44d3 100644 --- a/recbole/data/dataset/dataset.py +++ b/recbole/data/dataset/dataset.py @@ -433,7 +433,7 @@ def _load_feat(self, filepath, source): return None df = pd.read_csv( - filepath, delimiter=self.config['field_separator'], usecols=usecols, dtype=dtype, encoding=encoding + filepath, delimiter=self.config['field_separator'], usecols=usecols, dtype=dtype, encoding=encoding, engine='python' ) df.columns = columns diff --git a/recbole/model/abstract_recommender.py b/recbole/model/abstract_recommender.py index 2aaae77a6..9cb80fbfa 100644 --- a/recbole/model/abstract_recommender.py +++ b/recbole/model/abstract_recommender.py @@ -124,6 +124,9 @@ def __init__(self, config, dataset): self.max_seq_length = config['MAX_ITEM_LIST_LENGTH'] self.n_items = dataset.num(self.ITEM_ID) + # load parameters info + self.device = config['device'] + def gather_indexes(self, output, gather_index): """Gathers the vectors at the specific positions over a minibatch""" gather_index = gather_index.view(-1, 1, 1).expand(-1, -1, output.shape[-1]) diff --git a/recbole/model/sequential_recommender/hrm.py b/recbole/model/sequential_recommender/hrm.py index 421835266..3dd1b0ae1 100644 --- a/recbole/model/sequential_recommender/hrm.py +++ b/recbole/model/sequential_recommender/hrm.py @@ -156,7 +156,6 @@ def predict(self, interaction): test_item = interaction[self.ITEM_ID] user = interaction[self.USER_ID] seq_output = self.forward(item_seq, user, seq_item_len) - seq_output = seq_output.repeat(1, self.embedding_size) test_item_emb = self.item_embedding(test_item) scores = torch.mul(seq_output, test_item_emb).sum(dim=1) diff --git a/recbole/trainer/trainer.py b/recbole/trainer/trainer.py index 10238d3cb..b7a655355 100644 
--- a/recbole/trainer/trainer.py +++ b/recbole/trainer/trainer.py @@ -80,6 +80,7 @@ def __init__(self, config, model): self.tensorboard = get_tensorboard(self.logger) self.learner = config['learner'] self.learning_rate = config['learning_rate'] + self.sample_strategy = config['train_neg_sample_args']['strategy'] self.epochs = config['epochs'] self.eval_step = min(config['eval_step'], self.epochs) self.stopping_step = config['stopping_step'] @@ -308,7 +309,8 @@ def fit(self, train_data, valid_data=None, verbose=True, saved=True, show_progre self._save_checkpoint(-1) self.eval_collector.data_collect(train_data) - + if isinstance(self.sample_strategy, dict): + train_data.get_model(self.model) for epoch_idx in range(self.start_epoch, self.epochs): # train training_start_time = time() @@ -770,6 +772,8 @@ def _save_checkpoint(self, epoch): torch.save(state, self.saved_model_file) def fit(self, train_data, valid_data=None, verbose=True, saved=True, show_progress=False): + if isinstance(self.sample_strategy, dict): + train_data.get_model(self.model) for epoch_idx in range(self.epochs): self._train_at_once(train_data, valid_data) From 726bc9f6ff95f187edad19b851618a7211c5e261 Mon Sep 17 00:00:00 2001 From: Lanling Xu Date: Mon, 18 Oct 2021 11:59:36 +0000 Subject: [PATCH 2/6] FIX: fix trainer.py --- recbole/trainer/trainer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/recbole/trainer/trainer.py b/recbole/trainer/trainer.py index b7a655355..2571d6464 100644 --- a/recbole/trainer/trainer.py +++ b/recbole/trainer/trainer.py @@ -772,8 +772,6 @@ def _save_checkpoint(self, epoch): torch.save(state, self.saved_model_file) def fit(self, train_data, valid_data=None, verbose=True, saved=True, show_progress=False): - if isinstance(self.sample_strategy, dict): - train_data.get_model(self.model) for epoch_idx in range(self.epochs): self._train_at_once(train_data, valid_data) From e7688b664adb416983d55e54fdf44b1ff8f17605 Mon Sep 17 00:00:00 2001 From: Lanling Xu Date: Mon, 18 Oct 2021 12:24:15 +0000 Subject: [PATCH 3/6] FEA: Add test_lightsans and test_lightsans_with_BPR_loss in test_model_auto.py --- tests/model/test_model_auto.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/tests/model/test_model_auto.py b/tests/model/test_model_auto.py index 17360cd6f..fd0198782 100644 --- a/tests/model/test_model_auto.py +++ b/tests/model/test_model_auto.py @@ -656,8 +656,7 @@ def test_bert4rec(self): 'model': 'BERT4Rec', 'neg_sampling': None } - objective_function(config_dict=config_dict, - config_file_list=config_file_list, saved=False) + quick_test(config_dict) def test_bert4rec_with_BPR_loss_and_swish(self): config_dict = { @@ -665,8 +664,21 @@ def test_bert4rec_with_BPR_loss_and_swish(self): 'loss_type': 'BPR', 'hidden_act': 'swish' } - objective_function(config_dict=config_dict, - config_file_list=config_file_list, saved=False) + quick_test(config_dict) + + def test_lightsans(self): + config_dict = { + 'model': 'LightSANs', + 'neg_sampling': None + } + quick_test(config_dict) + + def test_lightsans_with_BPR_loss(self): + config_dict = { + 'model': 'LightSANs', + 'loss_type': 'BPR', + } + quick_test(config_dict) # def test_gru4reckg(self): # config_dict = { From 0bca35cfe7c94deadce421c9844550ba35962a8c Mon Sep 17 00:00:00 2001 From: Lanling Xu Date: Tue, 19 Oct 2021 05:51:07 +0000 Subject: [PATCH 4/6] FIX: Update dynamic sampling --- .../user_guide/config/training_settings.rst | 4 +- recbole/config/configurator.py | 21 +++------ .../data/dataloader/abstract_dataloader.py 
| 46 ++++++++----------- recbole/trainer/trainer.py | 3 +- 4 files changed, 29 insertions(+), 45 deletions(-) diff --git a/docs/source/user_guide/config/training_settings.rst b/docs/source/user_guide/config/training_settings.rst index a877f55b7..38395c1cf 100644 --- a/docs/source/user_guide/config/training_settings.rst +++ b/docs/source/user_guide/config/training_settings.rst @@ -9,11 +9,11 @@ Training settings are designed to set parameters about model training. Range in ``['adam', 'sgd', 'adagrad', 'rmsprop', 'sparse_adam']``. - ``learning_rate (float)`` : Learning rate. Defaults to ``0.001``. - ``neg_sampling(dict)``: This parameter controls the negative sampling for model training. - The key range is ``['uniform', 'popularity']``, which decides the distribution of negative items in sampling pools. In addition, we also support dynamic negative sampling ``['dns']``. + The key range is ``['uniform', 'popularity']``, which decides the distribution of negative items in sampling pools. In addition, we also support dynamic negative sampling ``['dynamic']``. ``uniform`` means uniformly select negative items while ``popularity`` means select negative items based on their popularity (Counter(item) in `.inter` file). Note that if your data is labeled, you need to set this parameter as ``None``. The default value of this parameter is ``{'uniform': 1}``. - When dynamic negative sampling, ``dynamic_sampling(dict)`` decides the dynamic negative sampler and the number of candidate negative items. For example, ``{uniform: 1, dynamic_sampling: {sampler: dns, candidate_num: 2}}`` means sample 2 items for each positive item uniformly, and dynamically choose the item with the higher score as the selected negative item. + When dynamic negative sampling, ``dynamic`` decides the number of candidate negative items. For example, ``{'uniform': 1, 'dynamic': 2}`` means sampling 2 items for each positive item uniformly, and dynamically choosing the item with the higher score as the selected negative item. In particular, ``'uniform': 1`` means that a positive item pairs with one negative item, and ``'dynamic': 2`` means dynamically selecting each negative item from two candidates. - ``eval_step (int)`` : The number of training epochs before an evaluation on the valid dataset. If it is less than 1, the model will not be evaluated on the valid dataset. Defaults to ``1``. 
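A minimal usage sketch of the simplified configuration documented above, matching the ``{'uniform': 1, 'dynamic': 2}`` format used by the new test in this series; the quick-start call and the ``ml-100k`` dataset are assumptions for illustration only:

    # Assumed quick-start usage; only the `neg_sampling` dict comes from this patch series.
    from recbole.quick_start import run_recbole

    config_dict = {
        'neg_sampling': {
            'uniform': 1,    # each positive item is paired with one negative item
            'dynamic': 2     # each negative is the higher-scoring of two uniform candidates
        }
    }
    run_recbole(model='BPR', dataset='ml-100k', config_dict=config_dict)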
diff --git a/recbole/config/configurator.py b/recbole/config/configurator.py index 3cc733ed6..e61853778 100644 --- a/recbole/config/configurator.py +++ b/recbole/config/configurator.py @@ -361,24 +361,15 @@ def _set_train_neg_sample_args(self): raise ValueError(f"The distribution [{distribution}] of neg_sampling " f"should in ['uniform', 'popularity']") - strategy = 'by' - if 'dynamic_sampling' in neg_sampling.keys(): - dynamic_sampling = neg_sampling['dynamic_sampling'] - if not isinstance(dynamic_sampling, dict): - raise ValueError(f"dynamic_sampling:[{dynamic_sampling}] should be a dict.") - if not ({'sampler', 'candidate_num'} <= set(dynamic_sampling.keys())): - raise ValueError(f"'sampler' and 'candidate_num' should be in " - f"dynamic_sampling:[{dynamic_sampling}].keys()") - dynamic_sampler = dynamic_sampling['sampler'] - if dynamic_sampler.lower() not in ['dns']: - raise ValueError(f"The sampler [{sampler}] of dynamic_sampling " - f"should be in ['dns']") - strategy = dynamic_sampling + dynamic = 'none' + if 'dynamic' in neg_sampling.keys(): + dynamic = neg_sampling['dynamic'] self.final_config_dict['train_neg_sample_args'] = { - 'strategy': strategy, + 'strategy': 'by', 'by': sample_num, - 'distribution': distribution + 'distribution': distribution, + 'dynamic': dynamic } def _set_eval_neg_sample_args(self): diff --git a/recbole/data/dataloader/abstract_dataloader.py b/recbole/data/dataloader/abstract_dataloader.py index 2f32bc338..808f3cb60 100644 --- a/recbole/data/dataloader/abstract_dataloader.py +++ b/recbole/data/dataloader/abstract_dataloader.py @@ -128,8 +128,7 @@ def _set_neg_sample_args(self, config, dataset, dl_format, neg_sample_args): self.dl_format = dl_format self.neg_sample_args = neg_sample_args self.times = 1 - self.strategy = self.neg_sample_args['strategy'] - if self.strategy == 'by' or isinstance(self.strategy, dict): + if self.neg_sample_args['strategy'] == 'by': self.neg_sample_num = self.neg_sample_args['by'] if self.dl_format == InputType.POINTWISE: @@ -152,32 +151,27 @@ def _set_neg_sample_args(self, config, dataset, dl_format, neg_sample_args): else: raise ValueError(f'`neg sampling by` with dl_format [{self.dl_format}] not been implemented.') - elif self.strategy != 'none': - raise ValueError(f'`neg_sample_args` [{self.strategy}] is not supported!') + elif self.neg_sample_args['strategy'] != 'none': + raise ValueError(f'`neg_sample_args` [{self.neg_sample_args["strategy"]}] is not supported!') def _neg_sampling(self, inter_feat): - if isinstance(self.strategy, dict): - sampler = self.strategy['sampler'] - candidate_num = self.strategy['candidate_num'] - if sampler.lower() == 'dns': - user_ids = inter_feat[self.uid_field] - item_ids = inter_feat[self.iid_field] - neg_candidate_ids = self.sampler.sample_by_user_ids(user_ids, item_ids, self.neg_sample_num * candidate_num) - - self.model.eval() - interaction = copy.deepcopy(inter_feat).to(self.model.device) - interaction = interaction.repeat(self.neg_sample_num * candidate_num) - neg_item_feat = Interaction({self.iid_field: neg_candidate_ids.to(self.model.device)}) - interaction.update(neg_item_feat) - scores = self.model.predict(interaction).reshape(candidate_num, -1) - indices = torch.max(scores, dim=0)[1].detach() - neg_candidate_ids = neg_candidate_ids.reshape(candidate_num, -1) - neg_item_ids = neg_candidate_ids[indices, [i for i in range(neg_candidate_ids.shape[1])]].view(-1) - self.model.train() - return self.sampling_func(inter_feat, neg_item_ids) - else: - raise NotImplementedError - elif 
self.strategy == 'by': + if self.neg_sample_args['dynamic'] != 'none': + candidate_num = self.neg_sample_args['dynamic'] + user_ids = inter_feat[self.uid_field] + item_ids = inter_feat[self.iid_field] + neg_candidate_ids = self.sampler.sample_by_user_ids(user_ids, item_ids, self.neg_sample_num * candidate_num) + self.model.eval() + interaction = copy.deepcopy(inter_feat).to(self.model.device) + interaction = interaction.repeat(self.neg_sample_num * candidate_num) + neg_item_feat = Interaction({self.iid_field: neg_candidate_ids.to(self.model.device)}) + interaction.update(neg_item_feat) + scores = self.model.predict(interaction).reshape(candidate_num, -1) + indices = torch.max(scores, dim=0)[1].detach() + neg_candidate_ids = neg_candidate_ids.reshape(candidate_num, -1) + neg_item_ids = neg_candidate_ids[indices, [i for i in range(neg_candidate_ids.shape[1])]].view(-1) + self.model.train() + return self.sampling_func(inter_feat, neg_item_ids) + elif self.neg_sample_args['strategy'] == 'by': user_ids = inter_feat[self.uid_field] item_ids = inter_feat[self.iid_field] neg_item_ids = self.sampler.sample_by_user_ids(user_ids, item_ids, self.neg_sample_num) diff --git a/recbole/trainer/trainer.py b/recbole/trainer/trainer.py index 2571d6464..da66d798a 100644 --- a/recbole/trainer/trainer.py +++ b/recbole/trainer/trainer.py @@ -80,7 +80,6 @@ def __init__(self, config, model): self.tensorboard = get_tensorboard(self.logger) self.learner = config['learner'] self.learning_rate = config['learning_rate'] - self.sample_strategy = config['train_neg_sample_args']['strategy'] self.epochs = config['epochs'] self.eval_step = min(config['eval_step'], self.epochs) self.stopping_step = config['stopping_step'] @@ -309,7 +308,7 @@ def fit(self, train_data, valid_data=None, verbose=True, saved=True, show_progre self._save_checkpoint(-1) self.eval_collector.data_collect(train_data) - if isinstance(self.sample_strategy, dict): + if self.config['train_neg_sample_args']['dynamic'] != 'none': train_data.get_model(self.model) for epoch_idx in range(self.start_epoch, self.epochs): # train From 2da7da00f1a76743a94981cb674da08357b23182 Mon Sep 17 00:00:00 2001 From: Lanling Xu Date: Tue, 19 Oct 2021 07:10:16 +0000 Subject: [PATCH 5/6] FIX: fix error in dynamic sampling --- recbole/data/dataloader/abstract_dataloader.py | 2 +- recbole/trainer/trainer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/recbole/data/dataloader/abstract_dataloader.py b/recbole/data/dataloader/abstract_dataloader.py index 808f3cb60..f382d00f5 100644 --- a/recbole/data/dataloader/abstract_dataloader.py +++ b/recbole/data/dataloader/abstract_dataloader.py @@ -155,7 +155,7 @@ def _set_neg_sample_args(self, config, dataset, dl_format, neg_sample_args): raise ValueError(f'`neg_sample_args` [{self.neg_sample_args["strategy"]}] is not supported!') def _neg_sampling(self, inter_feat): - if self.neg_sample_args['dynamic'] != 'none': + if 'dynamic' in self.neg_sample_args.keys() and self.neg_sample_args['dynamic'] != 'none': candidate_num = self.neg_sample_args['dynamic'] user_ids = inter_feat[self.uid_field] item_ids = inter_feat[self.iid_field] diff --git a/recbole/trainer/trainer.py b/recbole/trainer/trainer.py index da66d798a..e1eeaf24c 100644 --- a/recbole/trainer/trainer.py +++ b/recbole/trainer/trainer.py @@ -308,7 +308,7 @@ def fit(self, train_data, valid_data=None, verbose=True, saved=True, show_progre self._save_checkpoint(-1) self.eval_collector.data_collect(train_data) - if self.config['train_neg_sample_args']['dynamic'] != 
'none': + if 'dynamic' in self.config['train_neg_sample_args'].keys() and self.config['train_neg_sample_args']['dynamic'] != 'none': train_data.get_model(self.model) for epoch_idx in range(self.start_epoch, self.epochs): # train From 86fff6a0bfc3dc8b0a7f8663b1b9a3d26e3e5c86 Mon Sep 17 00:00:00 2001 From: Lanling Xu Date: Wed, 20 Oct 2021 02:58:41 +0000 Subject: [PATCH 6/6] FEA: add test_bpr_with_dns in test_model_auto.py --- tests/model/test_model_auto.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/model/test_model_auto.py b/tests/model/test_model_auto.py index fd0198782..3cc684e67 100644 --- a/tests/model/test_model_auto.py +++ b/tests/model/test_model_auto.py @@ -41,6 +41,16 @@ def test_bpr(self): } quick_test(config_dict) + def test_bpr_with_dns(self): + config_dict = { + 'model': 'BPR', + 'neg_sampling': { + 'uniform': 1, + 'dynamic': 2 + } + } + quick_test(config_dict) + def test_neumf(self): config_dict = { 'model': 'NeuMF',
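A small standalone sketch of the candidate-selection step performed by ``_neg_sampling`` in the patched dataloader: candidates are scored by the model, the scores are reshaped to ``(candidate_num, -1)`` so that the alternatives for each positive interaction share a column, and ``torch.max`` over dim 0 keeps the highest-scoring (hardest) candidate. The ids and scores below are dummy values, not RecBole output:

    import torch

    candidate_num = 2
    # 3 positive interactions with 2 candidate negatives each, laid out so that
    # reshape(candidate_num, -1) puts each positive's candidates in one column
    neg_candidate_ids = torch.tensor([11, 12, 13, 21, 22, 23])
    scores = torch.tensor([0.2, 0.9, 0.4, 0.7, 0.1, 0.5])   # stand-in for model.predict(...)

    scores = scores.reshape(candidate_num, -1)               # (2, 3): one row per candidate slot
    indices = torch.max(scores, dim=0)[1]                    # best candidate index per positive
    neg_candidate_ids = neg_candidate_ids.reshape(candidate_num, -1)
    neg_item_ids = neg_candidate_ids[indices, torch.arange(neg_candidate_ids.shape[1])]
    print(neg_item_ids)  # tensor([21, 12, 23]) -- the harder negative for each positive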