Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEA: Add dynamic negative sampling #1006

Merged
merged 6 commits into from
Oct 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/source/user_guide/config/training_settings.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ Training settings are designed to set parameters about model training.
Range in ``['adam', 'sgd', 'adagrad', 'rmsprop', 'sparse_adam']``.
- ``learning_rate (float)`` : Learning rate. Defaults to ``0.001``.
- ``neg_sampling(dict)``: This parameter controls the negative sampling for model training.
The key range is ``['uniform', 'popularity']``, which decides the distribution of negative items in sampling pools.
The key range is ``['uniform', 'popularity']``, which decides the distribution of negative items in sampling pools. Dynamic negative sampling is also supported through the additional key ``dynamic``.
``uniform`` means that negative items are selected uniformly, while ``popularity`` means that negative items are selected based on
their popularity (Counter(item) in the `.inter` file). Note that if your data is labeled, you need to set this parameter to ``None``.
The default value of this parameter is ``{'uniform': 1}``.
When dynamic negative sampling is used, ``dynamic`` decides the number of candidate negative items. For example, ``{'uniform': 1, 'dynamic': 2}`` means that 2 candidate items are sampled uniformly for each positive item, and the candidate with the higher score is dynamically chosen as the negative item. In particular, ``'uniform': 1`` means that each positive item is paired with one negative item, and ``'dynamic': 2`` means that each negative item is dynamically selected from two candidates.
- ``eval_step (int)`` : The number of training epochs before an evaluation
on the valid dataset. If it is less than 1, the model will not be
evaluated on the valid dataset. Defaults to ``1``.
Expand Down
9 changes: 6 additions & 3 deletions recbole/config/configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,19 +354,22 @@ def _set_train_neg_sample_args(self):
else:
if not isinstance(neg_sampling, dict):
raise ValueError(f"neg_sampling:[{neg_sampling}] should be a dict.")
if len(neg_sampling) > 1:
raise ValueError(f"the len of neg_sampling [{neg_sampling}] should be 1.")

distribution = list(neg_sampling.keys())[0]
sample_num = neg_sampling[distribution]
if distribution not in ['uniform', 'popularity']:
raise ValueError(f"The distribution [{distribution}] of neg_sampling "
f"should in ['uniform', 'popularity']")

dynamic = 'none'
if 'dynamic' in neg_sampling.keys():
dynamic = neg_sampling['dynamic']

self.final_config_dict['train_neg_sample_args'] = {
'strategy': 'by',
'by': sample_num,
'distribution': distribution
'distribution': distribution,
'dynamic': dynamic
}

def _set_eval_neg_sample_args(self):
Expand Down
24 changes: 22 additions & 2 deletions recbole/data/dataloader/abstract_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"""

import math
import copy
from logging import getLogger

import torch
Expand Down Expand Up @@ -45,7 +46,7 @@ def __init__(self, config, dataset, sampler, shuffle=False):
self.logger = getLogger()
self.dataset = dataset
self.sampler = sampler
self.batch_size = self.step = None
self.batch_size = self.step = self.model = None
self.shuffle = shuffle
self.pr = 0
self._init_batch_size_and_step()
Expand Down Expand Up @@ -154,7 +155,23 @@ def _set_neg_sample_args(self, config, dataset, dl_format, neg_sample_args):
raise ValueError(f'`neg_sample_args` [{self.neg_sample_args["strategy"]}] is not supported!')

def _neg_sampling(self, inter_feat):
if self.neg_sample_args['strategy'] == 'by':
if 'dynamic' in self.neg_sample_args.keys() and self.neg_sample_args['dynamic'] != 'none':
candidate_num = self.neg_sample_args['dynamic']
user_ids = inter_feat[self.uid_field]
item_ids = inter_feat[self.iid_field]
neg_candidate_ids = self.sampler.sample_by_user_ids(user_ids, item_ids, self.neg_sample_num * candidate_num)
self.model.eval()
interaction = copy.deepcopy(inter_feat).to(self.model.device)
interaction = interaction.repeat(self.neg_sample_num * candidate_num)
neg_item_feat = Interaction({self.iid_field: neg_candidate_ids.to(self.model.device)})
interaction.update(neg_item_feat)
scores = self.model.predict(interaction).reshape(candidate_num, -1)
indices = torch.max(scores, dim=0)[1].detach()
neg_candidate_ids = neg_candidate_ids.reshape(candidate_num, -1)
neg_item_ids = neg_candidate_ids[indices, [i for i in range(neg_candidate_ids.shape[1])]].view(-1)
self.model.train()
return self.sampling_func(inter_feat, neg_item_ids)
elif self.neg_sample_args['strategy'] == 'by':
user_ids = inter_feat[self.uid_field]
item_ids = inter_feat[self.iid_field]
neg_item_ids = self.sampler.sample_by_user_ids(user_ids, item_ids, self.neg_sample_num)
Expand All @@ -179,3 +196,6 @@ def _neg_sample_by_point_wise_sampling(self, inter_feat, neg_item_ids):
labels[:pos_inter_num] = 1.0
new_data.update(Interaction({self.label_field: labels}))
return new_data

def get_model(self, model):
    """Store ``model`` on this dataloader as ``self.model``.

    Despite the ``get_`` prefix this is a setter and returns nothing.
    ``_neg_sampling`` in this file calls ``self.model.predict`` when dynamic
    negative sampling is enabled, so the model has to be stored before then.
    """
    self.model = model
5 changes: 5 additions & 0 deletions recbole/data/dataloader/knowledge_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,8 @@ def set_mode(self, state):
if state not in set(KGDataLoaderState):
raise NotImplementedError(f'Kg data loader has no state named [{self.state}].')
self.state = state

def get_model(self, model):
    """Forward ``model`` to ``self.general_dataloader.get_model``, used for dynamic sampling.

    Nothing is stored on this object and nothing is returned; the call is
    purely a delegation to the wrapped general dataloader.
    """
    self.general_dataloader.get_model(model)
2 changes: 1 addition & 1 deletion recbole/data/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ def _load_feat(self, filepath, source):
return None

df = pd.read_csv(
filepath, delimiter=self.config['field_separator'], usecols=usecols, dtype=dtype, encoding=encoding
filepath, delimiter=self.config['field_separator'], usecols=usecols, dtype=dtype, encoding=encoding, engine='python'
)
df.columns = columns

Expand Down
3 changes: 3 additions & 0 deletions recbole/model/abstract_recommender.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ def __init__(self, config, dataset):
self.max_seq_length = config['MAX_ITEM_LIST_LENGTH']
self.n_items = dataset.num(self.ITEM_ID)

# load parameters info
self.device = config['device']

def gather_indexes(self, output, gather_index):
"""Gathers the vectors at the specific positions over a minibatch"""
gather_index = gather_index.view(-1, 1, 1).expand(-1, -1, output.shape[-1])
Expand Down
1 change: 0 additions & 1 deletion recbole/model/sequential_recommender/hrm.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,6 @@ def predict(self, interaction):
test_item = interaction[self.ITEM_ID]
user = interaction[self.USER_ID]
seq_output = self.forward(item_seq, user, seq_item_len)
seq_output = seq_output.repeat(1, self.embedding_size)
test_item_emb = self.item_embedding(test_item)
scores = torch.mul(seq_output, test_item_emb).sum(dim=1)

Expand Down
3 changes: 2 additions & 1 deletion recbole/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,8 @@ def fit(self, train_data, valid_data=None, verbose=True, saved=True, show_progre
self._save_checkpoint(-1)

self.eval_collector.data_collect(train_data)

if 'dynamic' in self.config['train_neg_sample_args'].keys() and self.config['train_neg_sample_args']['dynamic'] != 'none':
train_data.get_model(self.model)
for epoch_idx in range(self.start_epoch, self.epochs):
# train
training_start_time = time()
Expand Down
30 changes: 26 additions & 4 deletions tests/model/test_model_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ def test_bpr(self):
}
quick_test(config_dict)

def test_bpr_with_dns(self):
    """Run BPR through ``quick_test`` with dynamic negative sampling configured."""
    # One negative per positive ('uniform': 1), using 'dynamic': 2.
    sampling_setup = {'uniform': 1, 'dynamic': 2}
    quick_test({'model': 'BPR', 'neg_sampling': sampling_setup})

def test_neumf(self):
config_dict = {
'model': 'NeuMF',
Expand Down Expand Up @@ -656,17 +666,29 @@ def test_bert4rec(self):
'model': 'BERT4Rec',
'neg_sampling': None
}
objective_function(config_dict=config_dict,
config_file_list=config_file_list, saved=False)
quick_test(config_dict)

def test_bert4rec_with_BPR_loss_and_swish(self):
    # Exercises BERT4Rec configured with loss_type 'BPR' and hidden_act 'swish'.
    # NOTE(review): this span comes from a diff view, so both runner calls below
    # are shown; presumably objective_function(...) is the removed line and
    # quick_test(...) is its replacement -- confirm against the merged file.
    config_dict = {
        'model': 'BERT4Rec',
        'loss_type': 'BPR',
        'hidden_act': 'swish'
    }
    objective_function(config_dict=config_dict,
                       config_file_list=config_file_list, saved=False)
    quick_test(config_dict)

def test_lightsans(self):
    """Run LightSANs through ``quick_test`` with ``neg_sampling`` set to None."""
    quick_test({'model': 'LightSANs', 'neg_sampling': None})

def test_lightsans_with_BPR_loss(self):
    """Run LightSANs through ``quick_test`` with the BPR loss type."""
    bpr_settings = {}
    bpr_settings['model'] = 'LightSANs'
    bpr_settings['loss_type'] = 'BPR'
    quick_test(bpr_settings)

# def test_gru4reckg(self):
# config_dict = {
Expand Down