Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEA: Add dynamic negative sampling #1006

Merged
merged 6 commits into from
Oct 20, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/source/user_guide/config/training_settings.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ Training settings are designed to set parameters about model training.
Range in ``['adam', 'sgd', 'adagrad', 'rmsprop', 'sparse_adam']``.
- ``learning_rate (float)`` : Learning rate. Defaults to ``0.001``.
- ``neg_sampling(dict)``: This parameter controls the negative sampling for model training.
The key range is ``['uniform', 'popularity']``, which decides the distribution of negative items in sampling pools.
The key range is ``['uniform', 'popularity']``, which decides the distribution of negative items in sampling pools. In addition, we also support dynamic negative sampling ``['dns']``.
2017pxy marked this conversation as resolved.
Show resolved Hide resolved
``uniform`` means uniformly select negative items while ``popularity`` means select negative items based on
their popularity (Counter(item) in `.inter` file). Note that if your data is labeled, you need to set this parameter as ``None``.
The default value of this parameter is ``{'uniform': 1}``.
When dynamic negative sampling, ``dynamic_sampling(dict)`` decides the dynamic negative sampler and the number of candidate negative items. For example, ``{uniform: 1, dynamic_sampling: {sampler: dns, candidate_num: 2}}`` means sample 2 items for each positive item uniformly, and dynamically choose the item with the higher score as the selected negative item.
- ``eval_step (int)`` : The number of training epochs before an evaluation
on the valid dataset. If it is less than 1, the model will not be
evaluated on the valid dataset. Defaults to ``1``.
Expand Down
18 changes: 15 additions & 3 deletions recbole/config/configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,17 +354,29 @@ def _set_train_neg_sample_args(self):
else:
if not isinstance(neg_sampling, dict):
raise ValueError(f"neg_sampling:[{neg_sampling}] should be a dict.")
if len(neg_sampling) > 1:
raise ValueError(f"the len of neg_sampling [{neg_sampling}] should be 1.")

distribution = list(neg_sampling.keys())[0]
sample_num = neg_sampling[distribution]
if distribution not in ['uniform', 'popularity']:
raise ValueError(f"The distribution [{distribution}] of neg_sampling "
f"should in ['uniform', 'popularity']")

strategy = 'by'
if 'dynamic_sampling' in neg_sampling.keys():
dynamic_sampling = neg_sampling['dynamic_sampling']
if not isinstance(dynamic_sampling, dict):
raise ValueError(f"dynamic_sampling:[{dynamic_sampling}] should be a dict.")
if not ({'sampler', 'candidate_num'} <= set(dynamic_sampling.keys())):
raise ValueError(f"'sampler' and 'candidate_num' should be in "
f"dynamic_sampling:[{dynamic_sampling}].keys()")
dynamic_sampler = dynamic_sampling['sampler']
if dynamic_sampler.lower() not in ['dns']:
raise ValueError(f"The sampler [{sampler}] of dynamic_sampling "
f"should be in ['dns']")
strategy = dynamic_sampling

self.final_config_dict['train_neg_sample_args'] = {
'strategy': 'by',
'strategy': strategy,
'by': sample_num,
'distribution': distribution
}
Expand Down
36 changes: 31 additions & 5 deletions recbole/data/dataloader/abstract_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"""

import math
import copy
from logging import getLogger

import torch
Expand Down Expand Up @@ -45,7 +46,7 @@ def __init__(self, config, dataset, sampler, shuffle=False):
self.logger = getLogger()
self.dataset = dataset
self.sampler = sampler
self.batch_size = self.step = None
self.batch_size = self.step = self.model = None
self.shuffle = shuffle
self.pr = 0
self._init_batch_size_and_step()
Expand Down Expand Up @@ -127,7 +128,8 @@ def _set_neg_sample_args(self, config, dataset, dl_format, neg_sample_args):
self.dl_format = dl_format
self.neg_sample_args = neg_sample_args
self.times = 1
if self.neg_sample_args['strategy'] == 'by':
self.strategy = self.neg_sample_args['strategy']
if self.strategy == 'by' or isinstance(self.strategy, dict):
self.neg_sample_num = self.neg_sample_args['by']

if self.dl_format == InputType.POINTWISE:
Expand All @@ -150,11 +152,32 @@ def _set_neg_sample_args(self, config, dataset, dl_format, neg_sample_args):
else:
raise ValueError(f'`neg sampling by` with dl_format [{self.dl_format}] not been implemented.')

elif self.neg_sample_args['strategy'] != 'none':
raise ValueError(f'`neg_sample_args` [{self.neg_sample_args["strategy"]}] is not supported!')
elif self.strategy != 'none':
raise ValueError(f'`neg_sample_args` [{self.strategy}] is not supported!')

def _neg_sampling(self, inter_feat):
if self.neg_sample_args['strategy'] == 'by':
if isinstance(self.strategy, dict):
sampler = self.strategy['sampler']
candidate_num = self.strategy['candidate_num']
if sampler.lower() == 'dns':
user_ids = inter_feat[self.uid_field]
item_ids = inter_feat[self.iid_field]
neg_candidate_ids = self.sampler.sample_by_user_ids(user_ids, item_ids, self.neg_sample_num * candidate_num)

self.model.eval()
interaction = copy.deepcopy(inter_feat).to(self.model.device)
interaction = interaction.repeat(self.neg_sample_num * candidate_num)
neg_item_feat = Interaction({self.iid_field: neg_candidate_ids.to(self.model.device)})
interaction.update(neg_item_feat)
scores = self.model.predict(interaction).reshape(candidate_num, -1)
indices = torch.max(scores, dim=0)[1].detach()
neg_candidate_ids = neg_candidate_ids.reshape(candidate_num, -1)
neg_item_ids = neg_candidate_ids[indices, [i for i in range(neg_candidate_ids.shape[1])]].view(-1)
self.model.train()
return self.sampling_func(inter_feat, neg_item_ids)
else:
raise NotImplementedError
elif self.strategy == 'by':
user_ids = inter_feat[self.uid_field]
item_ids = inter_feat[self.iid_field]
neg_item_ids = self.sampler.sample_by_user_ids(user_ids, item_ids, self.neg_sample_num)
Expand All @@ -179,3 +202,6 @@ def _neg_sample_by_point_wise_sampling(self, inter_feat, neg_item_ids):
labels[:pos_inter_num] = 1.0
new_data.update(Interaction({self.label_field: labels}))
return new_data

def get_model(self, model):
self.model = model
5 changes: 5 additions & 0 deletions recbole/data/dataloader/knowledge_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,8 @@ def set_mode(self, state):
if state not in set(KGDataLoaderState):
raise NotImplementedError(f'Kg data loader has no state named [{self.state}].')
self.state = state

def get_model(self, model):
"""Let the general_dataloader get the model, used for dynamic sampling.
"""
self.general_dataloader.get_model(model)
2 changes: 1 addition & 1 deletion recbole/data/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ def _load_feat(self, filepath, source):
return None

df = pd.read_csv(
filepath, delimiter=self.config['field_separator'], usecols=usecols, dtype=dtype, encoding=encoding
filepath, delimiter=self.config['field_separator'], usecols=usecols, dtype=dtype, encoding=encoding, engine='python'
)
df.columns = columns

Expand Down
3 changes: 3 additions & 0 deletions recbole/model/abstract_recommender.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ def __init__(self, config, dataset):
self.max_seq_length = config['MAX_ITEM_LIST_LENGTH']
self.n_items = dataset.num(self.ITEM_ID)

# load parameters info
self.device = config['device']

def gather_indexes(self, output, gather_index):
"""Gathers the vectors at the specific positions over a minibatch"""
gather_index = gather_index.view(-1, 1, 1).expand(-1, -1, output.shape[-1])
Expand Down
1 change: 0 additions & 1 deletion recbole/model/sequential_recommender/hrm.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,6 @@ def predict(self, interaction):
test_item = interaction[self.ITEM_ID]
user = interaction[self.USER_ID]
seq_output = self.forward(item_seq, user, seq_item_len)
seq_output = seq_output.repeat(1, self.embedding_size)
test_item_emb = self.item_embedding(test_item)
scores = torch.mul(seq_output, test_item_emb).sum(dim=1)

Expand Down
4 changes: 3 additions & 1 deletion recbole/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def __init__(self, config, model):
self.tensorboard = get_tensorboard(self.logger)
self.learner = config['learner']
self.learning_rate = config['learning_rate']
self.sample_strategy = config['train_neg_sample_args']['strategy']
self.epochs = config['epochs']
self.eval_step = min(config['eval_step'], self.epochs)
self.stopping_step = config['stopping_step']
Expand Down Expand Up @@ -308,7 +309,8 @@ def fit(self, train_data, valid_data=None, verbose=True, saved=True, show_progre
self._save_checkpoint(-1)

self.eval_collector.data_collect(train_data)

if isinstance(self.sample_strategy, dict):
train_data.get_model(self.model)
for epoch_idx in range(self.start_epoch, self.epochs):
# train
training_start_time = time()
Expand Down
20 changes: 16 additions & 4 deletions tests/model/test_model_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,17 +656,29 @@ def test_bert4rec(self):
'model': 'BERT4Rec',
'neg_sampling': None
}
objective_function(config_dict=config_dict,
config_file_list=config_file_list, saved=False)
quick_test(config_dict)

def test_bert4rec_with_BPR_loss_and_swish(self):
config_dict = {
'model': 'BERT4Rec',
'loss_type': 'BPR',
'hidden_act': 'swish'
}
objective_function(config_dict=config_dict,
config_file_list=config_file_list, saved=False)
quick_test(config_dict)

def test_lightsans(self):
config_dict = {
'model': 'LightSANs',
'neg_sampling': None
}
quick_test(config_dict)

def test_lightsans_with_BPR_loss(self):
config_dict = {
'model': 'LightSANs',
'loss_type': 'BPR',
}
quick_test(config_dict)

# def test_gru4reckg(self):
# config_dict = {
Expand Down