-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
/
learning_to_rank.py
215 lines (180 loc) · 6.2 KB
/
learning_to_rank.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
"""
Getting started with learning to rank
=====================================
.. versionadded:: 2.0.0
This is a demonstration of using XGBoost for learning to rank tasks using the
MSLR_10k_letor dataset. For more information about the dataset, please visit its
`description page <https://www.microsoft.com/en-us/research/project/mslr/>`_.
This is a two-part demo: the first part contains a basic example of using XGBoost to
train on relevance degree, and the second part simulates click data and enables the
position debiasing training.
For an overview of learning to rank in XGBoost, please see :doc:`Learning to Rank
</tutorials/learning_to_rank>`.
"""
from __future__ import annotations
import argparse
import json
import os
import pickle as pkl
import numpy as np
import pandas as pd
from sklearn.datasets import load_svmlight_file
import xgboost as xgb
from xgboost.testing.data import RelDataCV, simulate_clicks, sort_ltr_samples
def _load_ltr_split(path: str) -> tuple:
    """Parse one svmlight file into ``(X, y, qid)`` with int32 labels and query ids."""
    X, y, qid = load_svmlight_file(path, query_id=True, dtype=np.float32)
    return X, y.astype(np.int32), qid.astype(np.int32)


def load_mslr_10k(data_path: str, cache_path: str) -> RelDataCV:
    """Load Fold1 of the MSLR10k dataset, caching the parsed result as a pickle.

    On the first call the train and test splits are parsed from the svmlight text
    files under ``<data_path>/Fold1`` and pickled to
    ``<cache_path>/MSLR_10K_LETOR.pkl``.  Later calls only read that pickle back.

    Parameters
    ----------
    data_path :
        Root directory of the dataset (``~`` is expanded).
    cache_path :
        Directory in which the pickle cache is stored (``~`` is expanded).  It is
        created if it does not exist yet.

    Returns
    -------
    A ``RelDataCV`` built with ``train=(X, y, qid)``, ``test=(X, y, qid)`` and
    ``max_rel=4``.

    """
    # Use the function arguments rather than a module-level ``args`` so that the
    # loader also works when it is imported and called from other code.
    root_path = os.path.expanduser(data_path)
    cacheroot_path = os.path.expanduser(cache_path)
    cache_file = os.path.join(cacheroot_path, "MSLR_10K_LETOR.pkl")

    # Use only the Fold1 for demo:
    #         Train,      Valid, Test
    # {S1,S2,S3},   S4,    S5
    fold = 1

    if not os.path.exists(cache_file):
        fold_path = os.path.join(root_path, f"Fold{fold}")
        # Only the train and test splits are stored in the result below, so the
        # validation file (vali.txt) is not parsed.
        train = _load_ltr_split(os.path.join(fold_path, "train.txt"))
        test = _load_ltr_split(os.path.join(fold_path, "test.txt"))
        data = RelDataCV(train=train, test=test, max_rel=4)

        # Make sure the cache directory exists before writing into it.
        os.makedirs(cacheroot_path, exist_ok=True)
        with open(cache_file, "wb") as fd:
            pkl.dump(data, fd)

    # NOTE: unpickling can execute arbitrary code; only point ``cache_path`` at a
    # directory whose contents you trust.
    with open(cache_file, "rb") as fd:
        data = pkl.load(fd)

    return data
def ranking_demo(args: argparse.Namespace) -> None:
    """Fit an ``XGBRanker`` on the relevance labels, evaluating NDCG on the test split.

    ``args`` must provide the ``data`` and ``cache`` attributes that are passed on
    to ``load_mslr_10k``.
    """

    def by_query(X, y, qid):
        """Reorder one split by ascending query id."""
        order = np.argsort(qid)
        return X[order], y[order], qid[order]

    dataset = load_mslr_10k(args.data, args.cache)

    # Sort both splits according to query index before handing them to the ranker.
    train_X, train_y, train_qid = by_query(*dataset.train)
    test_X, test_y, test_qid = by_query(*dataset.test)

    model = xgb.XGBRanker(
        tree_method="hist",
        device="cuda",
        lambdarank_pair_method="topk",
        lambdarank_num_pair_per_sample=13,
        eval_metric=["ndcg@1", "ndcg@8"],
    )
    model.fit(
        train_X,
        train_y,
        qid=train_qid,
        eval_set=[(test_X, test_y)],
        eval_qid=[test_qid],
        verbose=True,
    )
def click_data_demo(args: argparse.Namespace) -> None:
    """Fit an ``XGBRanker`` on simulated click data with unbiased LambdaRank enabled.

    ``args`` must provide the ``data`` and ``cache`` attributes that are passed on
    to ``load_mslr_10k``.
    """
    dataset = load_mslr_10k(args.data, args.cache)
    train, test = simulate_clicks(dataset)

    # Sanity checks on the simulated samples.
    assert test is not None
    assert train.X.shape[0] == train.click.size
    assert test.X.shape[0] == test.click.size
    assert test.score.dtype == np.float32
    assert test.click.dtype == np.int32

    # The relevance labels of the training split are returned but not used below;
    # training is done on the clicks.
    X_train, clicks_train, _y_train, qid_train = sort_ltr_samples(
        train.X, train.y, train.qid, train.click, train.pos
    )
    X_test, clicks_test, y_test, qid_test = sort_ltr_samples(
        test.X, test.y, test.qid, test.click, test.pos
    )

    class ShowPosition(xgb.callback.TrainingCallback):
        """Print the ``ti+`` / ``tj-`` arrays of the objective after every iteration."""

        def after_iteration(
            self,
            model: xgb.Booster,
            epoch: int,
            evals_log: xgb.callback.TrainingCallback.EvalsLog,
        ) -> bool:
            objective = json.loads(model.save_config())["learner"]["objective"]
            table = pd.DataFrame(
                {name: np.array(objective[name]) for name in ("ti+", "tj-")}
            )
            print(table)
            # Presumably ``False`` tells XGBoost to keep training -- see the
            # TrainingCallback documentation.
            return False

    ranker_params = {
        "n_estimators": 512,
        "tree_method": "hist",
        "device": "cuda",
        "learning_rate": 0.01,
        "reg_lambda": 1.5,
        "subsample": 0.8,
        "sampling_method": "gradient_based",
        # LTR specific parameters
        "objective": "rank:ndcg",
        # - Enable bias estimation
        "lambdarank_unbiased": True,
        # - normalization (1 / (norm + 1))
        "lambdarank_bias_norm": 1,
        # - Focus on the top 12 documents
        "lambdarank_num_pair_per_sample": 12,
        "lambdarank_pair_method": "topk",
        "ndcg_exp_gain": True,
        "eval_metric": ["ndcg@1", "ndcg@3", "ndcg@5", "ndcg@10"],
        "callbacks": [ShowPosition()],
    }
    ranker = xgb.XGBRanker(**ranker_params)

    # Evaluate on the test split twice: against the relevance labels and against
    # the simulated clicks.
    eval_sets = [(X_test, y_test), (X_test, clicks_test)]
    ranker.fit(
        X_train,
        clicks_train,
        qid=qid_train,
        eval_set=eval_sets,
        eval_qid=[qid_test] * len(eval_sets),
        verbose=True,
    )
    ranker.predict(X_test)
if __name__ == "__main__":
    # Script entry point: parse the two required directories, then run both demos.
    parser = argparse.ArgumentParser(
        description="Demonstration of learning to rank using XGBoost."
    )
    for flag, help_text in (
        ("--data", "Root directory of the MSLR-WEB10K data."),
        ("--cache", "Directory for caching processed data."),
    ):
        parser.add_argument(flag, type=str, help=help_text, required=True)

    # ``args`` is bound at module scope (not inside a helper function) so the name
    # stays visible to the rest of the module.
    args = parser.parse_args()

    for demo in (ranking_demo, click_data_demo):
        demo(args)