-
Notifications
You must be signed in to change notification settings - Fork 0
/
recall_process.py
138 lines (101 loc) · 5.92 KB
/
recall_process.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import os
import pickle

import numpy as np
import pandas as pd
from tqdm import tqdm

from dataset import data_path, reduce_mem
from item_word2vec import trian_item_word2vec
# Load saved recall results (the input to the ranking-data preprocessing below).
def get_recall_list(save_path, single_recall_model=None, multi_recall=True):
    """Load a pickled recall result from ``save_path``.

    Args:
        save_path: directory prefix; file names are appended to it directly,
            so it is expected to end with a path separator.
        single_recall_model: one of ``'itemcf'``, ``'usercf'``, ``'item2vec'``
            or ``'youtubednn'``. Only consulted when ``multi_recall`` is False.
        multi_recall: if True, load the merged multi-channel recall result
            ``final_recall_items_dict.pkl`` and ignore ``single_recall_model``.

    Returns:
        The object unpickled from the selected file.

    Raises:
        ValueError: if ``multi_recall`` is False and ``single_recall_model``
            is not one of the supported names.
    """
    if multi_recall:
        file_name = 'final_recall_items_dict.pkl'
    else:
        single_model_files = {
            'itemcf': 'itemcf_recall_dict.pkl',
            'usercf': 'usercf_recall_dict.pkl',
            'item2vec': 'embedding_sim_item_recall.pkl',
            'youtubednn': 'youtubednn_recall_dict.pkl',
        }
        if single_recall_model not in single_model_files:
            raise ValueError(f'unknown single_recall_model: {single_recall_model!r}')
        file_name = single_model_files[single_recall_model]
    # NOTE: pickle.load can execute arbitrary code; only load files this pipeline wrote.
    with open(save_path + file_name, 'rb') as f:
        return pickle.load(f)
def get_article_info_df():
    """Read ``articles.csv`` from ``data_path`` and return it after ``reduce_mem``.

    ``reduce_mem`` is a project helper; judging by its name it downcasts the
    frame's dtypes to save memory (TODO confirm against ``dataset``).
    """
    return reduce_mem(pd.read_csv(data_path + 'articles.csv'))
# Convert the recall dict into DataFrame form.
def recall_dict_2_df(recall_list_dict):
    """Flatten ``{user: [(item, score), ...]}`` into a long DataFrame.

    Returns:
        A DataFrame with columns ``user_id``, ``sim_item`` and ``score``: one
        row per recalled (item, score) pair, in dict iteration order.
    """
    rows = [
        [user_id, item_id, item_score]
        for user_id, candidates in tqdm(recall_list_dict.items())
        for item_id, item_score in candidates
    ]
    return pd.DataFrame(rows, columns=['user_id', 'sim_item', 'score'])
# Negative sampling; the sampling ratio is configurable and a default is provided.
def neg_sample_recall_data(recall_items_df, sample_rate=0.001, random_state=None):
    """Down-sample negative recall rows while keeping every positive row.

    Negatives (``label == 0``) are sampled twice: once grouped by ``user_id``
    and once grouped by ``sim_item``. As a result, every user and every item
    that has at least one negative keeps at least one negative row. Each
    group contributes ``max(int(len(group) * sample_rate), 1)`` rows, capped
    at 5, drawn with replacement. The two samples are merged and
    de-duplicated on ``(user_id, sim_item)``.

    Args:
        recall_items_df: DataFrame with at least ``user_id``, ``sim_item``,
            ``score`` and ``label`` columns.
        sample_rate: fraction of each group's negatives to keep, before the
            clamp to [1, 5].
        random_state: optional int seed or ``numpy.random.RandomState``
            for reproducible sampling. ``None`` uses NumPy's global RNG.

    Returns:
        A new DataFrame with all positive rows followed by the sampled
        negatives, with a fresh ``RangeIndex``.
    """
    pos_data = recall_items_df[recall_items_df['label'] == 1]
    neg_data = recall_items_df[recall_items_df['label'] == 0]
    # Guard the ratio so a split without negatives does not raise ZeroDivisionError.
    ratio = len(pos_data) / len(neg_data) if len(neg_data) else float('inf')
    print('pos_data_num:', len(pos_data), 'neg_data_num:', len(neg_data), 'pos/neg:', ratio)

    if neg_data.empty:
        # Nothing to sample.
        return pos_data.reset_index(drop=True)

    # One RandomState shared by all groups: seeded runs are reproducible
    # without every group receiving the identical seed.
    if isinstance(random_state, (int, np.integer)):
        rng = np.random.RandomState(random_state)
    else:
        rng = random_state

    def neg_sample_func(group_df):
        # Per-group sampler.
        neg_num = len(group_df)
        sample_num = max(int(neg_num * sample_rate), 1)  # at least one row per group
        sample_num = min(sample_num, 5)  # at most five rows per group; tune as needed
        return group_df.sample(n=sample_num, replace=True, random_state=rng)

    # Select the columns explicitly (grouping key included). This keeps the key
    # column in every group; implicit inclusion is deprecated in newer pandas.
    columns = list(neg_data.columns)
    # Sample per user so every user with negatives stays in the result.
    neg_data_user_sample = neg_data.groupby('user_id', group_keys=False)[columns].apply(neg_sample_func)
    # Sample per item so every item with negatives stays in the result.
    neg_data_item_sample = neg_data.groupby('sim_item', group_keys=False)[columns].apply(neg_sample_func)
    # DataFrame.append was removed in pandas 2.0; pd.concat replaces it.
    neg_data_new = pd.concat([neg_data_user_sample, neg_data_item_sample])
    # The two samples can overlap, and replace=True can draw one row twice:
    # keep a single row per (user_id, sim_item), the last one after sorting by score.
    neg_data_new = neg_data_new.sort_values(['user_id', 'score']).drop_duplicates(['user_id', 'sim_item'], keep='last')
    # Put the positive rows back in front of the sampled negatives.
    return pd.concat([pos_data, neg_data_new], ignore_index=True)
# Label the recall candidates.
def get_rank_label_df(recall_list_df, label_df, is_test=False):
    """Attach a ``label`` column to recalled (user_id, sim_item) pairs.

    A candidate gets ``1.0`` when ``label_df`` has a click by that user on that
    article (``click_article_id`` matched against ``sim_item``), otherwise
    ``0.0``. The test split has no labels, so every row gets ``-1``.

    Args:
        recall_list_df: DataFrame with ``user_id`` and ``sim_item`` columns.
        label_df: click DataFrame with ``user_id``, ``click_article_id`` and
            ``click_timestamp`` columns. Ignored when ``is_test`` is True.
        is_test: if True, skip matching and label every row ``-1``.

    Returns:
        A new DataFrame; ``recall_list_df`` itself is not modified.
    """
    if is_test:
        # The test set has no labels; -1 is a placeholder. assign() returns a
        # copy, so the caller's frame (often a slice of a larger one) is untouched.
        return recall_list_df.assign(label=-1)
    label_df = label_df.rename(columns={'click_article_id': 'sim_item'})
    # Deduplicate so a repeated click on the same article cannot duplicate candidate rows.
    clicks = label_df[['user_id', 'sim_item', 'click_timestamp']].drop_duplicates(['user_id', 'sim_item'])
    labeled = recall_list_df.merge(clicks, how='left', on=['user_id', 'sim_item'])
    # A row has a click_timestamp exactly when the left merge found a matching click.
    labeled['label'] = labeled['click_timestamp'].notna().astype(float)
    del labeled['click_timestamp']
    return labeled
def get_user_recall_item_label_df(click_trn_hist, click_val_hist, click_tst_hist, click_trn_last, click_val_last,
                                  recall_list_df, click_val=None):
    """Split recall candidates into labeled train / validation / test sets.

    Candidates are assigned to a split by whether their ``user_id`` appears in
    that split's click-history frame.

    - Train candidates are labeled against ``click_trn_last`` and then
      negative-sampled.
    - Validation candidates get the same treatment against
      ``click_val_last``, but only when ``click_val`` is not None.
      ``click_val`` itself is used purely as that switch.
    - Test candidates are labeled ``-1`` and are not negative-sampled.

    Returns:
        ``(train_df, val_df, test_df)``, where ``val_df`` is None when
        ``click_val`` is None.
    """
    def candidates_for(hist_df):
        # Recall rows belonging to users present in the given click history.
        return recall_list_df[recall_list_df['user_id'].isin(hist_df['user_id'].unique())]

    trn_labeled = get_rank_label_df(candidates_for(click_trn_hist), click_trn_last, is_test=False)
    trn_labeled = neg_sample_recall_data(trn_labeled)

    val_labeled = None
    if click_val is not None:
        val_labeled = get_rank_label_df(candidates_for(click_val_hist), click_val_last, is_test=False)
        val_labeled = neg_sample_recall_data(val_labeled)

    # Test candidates keep every recalled item; they only receive the -1 placeholder label.
    tst_labeled = get_rank_label_df(candidates_for(click_tst_hist), None, is_test=True)
    return trn_labeled, val_labeled, tst_labeled
# Convert the final recall DataFrame into dict-ready tuples for ranking features.
def make_tuple_func(group_df):
    """Return ``[(sim_item, score, label), ...]`` for the rows of ``group_df``.

    Rows are emitted in the frame's row order. Zipping the columns keeps each
    column's dtype. ``iterrows`` builds one Series per row, which upcasts
    integer item ids to float whenever the row mixes int and float columns,
    and it is much slower.
    """
    return list(zip(group_df['sim_item'], group_df['score'], group_df['label']))
# Load the dicts used to look up an item's embedding.
def get_embedding(save_path, all_click_df):
    """Load the content and word2vec item-embedding dicts.

    Args:
        save_path: directory prefix; file names are appended to it directly.
        all_click_df: click log passed to ``trian_item_word2vec`` when no
            cached ``item_w2v_emb.pkl`` exists.

    Returns:
        ``(item_content_emb_dict, item_w2v_emb_dict)``. The first element is
        None when ``item_content_emb.pkl`` does not exist; a message is
        printed in that case.
    """
    content_path = save_path + 'item_content_emb.pkl'
    if os.path.exists(content_path):
        with open(content_path, 'rb') as f:
            item_content_emb_dict = pickle.load(f)
    else:
        print('item_content_emb.pkl 文件不存在...')
        # Continue so the w2v embeddings can still be returned; callers must handle None.
        item_content_emb_dict = None

    # The w2v embeddings are expected to be trained beforehand; train now if no cache exists.
    w2v_path = save_path + 'item_w2v_emb.pkl'
    if os.path.exists(w2v_path):
        with open(w2v_path, 'rb') as f:
            item_w2v_emb_dict = pickle.load(f)
    else:
        item_w2v_emb_dict = trian_item_word2vec(all_click_df)
    return item_content_emb_dict, item_w2v_emb_dict