Run the following command to check that the learning rate module is configured correctly.
python test_lr_scheduler.py
The final output is as follows.
[2021/11/17 21:44:19] root INFO: step_100_linear_lr:
[2021/11/17 21:44:19] root INFO: mean diff: check passed: True, value: 0.0
[2021/11/17 21:44:19] root INFO: step_300_linear_lr:
[2021/11/17 21:44:19] root INFO: mean diff: check passed: True, value: 0.0
[2021/11/17 21:44:19] root INFO: step_500_linear_lr:
[2021/11/17 21:44:19] root INFO: mean diff: check passed: True, value: 0.0
[2021/11/17 21:44:19] root INFO: step_700_linear_lr:
[2021/11/17 21:44:19] root INFO: mean diff: check passed: True, value: 0.0
[2021/11/17 21:44:19] root INFO: step_900_linear_lr:
[2021/11/17 21:44:19] root INFO: mean diff: check passed: True, value: 0.0
[2021/11/17 21:44:19] root INFO: step_100_cosine_lr:
[2021/11/17 21:44:19] root INFO: mean diff: check passed: True, value: 0.0
[2021/11/17 21:44:19] root INFO: step_300_cosine_lr:
[2021/11/17 21:44:19] root INFO: mean diff: check passed: True, value: 0.0
[2021/11/17 21:44:19] root INFO: step_500_cosine_lr:
[2021/11/17 21:44:19] root INFO: mean diff: check passed: False, value: 9.35605818719964e-06
[2021/11/17 21:44:19] root INFO: step_700_cosine_lr:
[2021/11/17 21:44:19] root INFO: mean diff: check passed: False, value: 1.3681476625617212e-05
[2021/11/17 21:44:19] root INFO: step_900_cosine_lr:
[2021/11/17 21:44:19] root INFO: mean diff: check passed: False, value: 1.8924391285779562e-05
[2021/11/17 21:44:19] root INFO: step_100_polynomial_lr:
[2021/11/17 21:44:19] root INFO: mean diff: check passed: True, value: 0.0
[2021/11/17 21:44:19] root INFO: step_300_polynomial_lr:
[2021/11/17 21:44:19] root INFO: mean diff: check passed: True, value: 0.0
[2021/11/17 21:44:19] root INFO: step_500_polynomial_lr:
[2021/11/17 21:44:19] root INFO: mean diff: check passed: True, value: 0.0
[2021/11/17 21:44:19] root INFO: step_700_polynomial_lr:
[2021/11/17 21:44:19] root INFO: mean diff: check passed: True, value: 0.0
[2021/11/17 21:44:19] root INFO: step_900_polynomial_lr:
[2021/11/17 21:44:19] root INFO: mean diff: check passed: True, value: 0.0
[2021/11/17 21:44:19] root INFO: diff check failed
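The learning rates under linear and polynomial decay have a diff of 0 at every step and pass the check. The cosine decay learning rate fails at steps 500, 700, and 900, possibly because of numerical error in the computation, which is why the overall diff check is reported as failed.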
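To make the pass/fail flags in the log above more concrete, here is a minimal sketch of a mean-absolute-diff check against a threshold. The helper names `mean_abs_diff` and `check`, the sample learning rates, and the default threshold of `1e-6` are all assumptions made for illustration; they are not taken from `test_lr_scheduler.py` or from reprod_log.

```python
def mean_abs_diff(reference, candidate):
    """Average of the element-wise absolute differences (hypothetical helper)."""
    if len(reference) != len(candidate) or not reference:
        raise ValueError("expected two non-empty sequences of equal length")
    return sum(abs(r - c) for r, c in zip(reference, candidate)) / len(reference)


def check(reference, candidate, threshold=1e-6):
    """Return (passed, diff); the 1e-6 default is an assumed threshold."""
    diff = mean_abs_diff(reference, candidate)
    return diff <= threshold, diff


# Hypothetical learning rates recorded at a few steps by two implementations.
reference = [3e-5, 2e-5, 1e-5]
shifted = [lr + 1e-5 for lr in reference]  # deviates by about 1e-5 everywhere

print(check(reference, reference))                # (True, 0.0)
print(check(reference, shifted)[0])               # False: diff is above 1e-6
print(check(reference, shifted, threshold=1e-4))  # passes with a looser threshold
```

Under these assumptions, a diff on the order of 1e-5, like the cosine entries in the log, fails a strict threshold even though the two schedules are nearly identical, so whether the result is acceptable depends on the tolerance chosen.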
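Taking PaddlePaddle as an example, the core code of the training procedure is shown below. Every iteration feeds in the same fake data and fake label, computes the loss, and runs backpropagation and the parameter update. The losses from all iterations are returned together for later verification.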
import paddle

# PDBertForSequenceClassification is assumed to be imported elsewhere;
# its import is not shown in this excerpt.
def pd_train_some_iters(fake_data, fake_label, max_iter=2):
    # The model, loss, and optimizer are built inside the function,
    # so they are not taken as arguments.
    model = PDBertForSequenceClassification.from_pretrained(
        "bert-base-uncased", num_classes=2)
    classifier_weights = paddle.load(
        "../classifier_weights/paddle_classifier_weights.bin")
    model.load_dict(classifier_weights)
    model.eval()  # eval mode turns off dropout, so the loss is reproducible
    criterion = paddle.nn.CrossEntropyLoss()

    # Apply weight decay to every parameter except biases and normalization layers.
    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]
    optimizer = paddle.optimizer.AdamW(
        learning_rate=3e-5,
        parameters=model.parameters(),
        weight_decay=1e-2,
        epsilon=1e-6,
        apply_decay_param_fun=lambda x: x in decay_params)

    loss_list = []
    for idx in range(max_iter):
        # The same fake batch is fed in every iteration.
        input_ids = paddle.to_tensor(fake_data)
        labels = paddle.to_tensor(fake_label)

        output = model(input_ids)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        optimizer.clear_grad()
        loss_list.append(loss)
    return loss_list
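The `decay_params` filter in the function above can be tried out without loading a model. The sketch below applies the same rule to a hand-written list of parameters; the structured names and internal names are hypothetical stand-ins for what `model.named_parameters()` would return, and the real names will differ.

```python
# Hypothetical (structured_name, internal_name) pairs; in the real code these
# come from model.named_parameters() as (n, p.name).
named_parameters = [
    ("bert.encoder.layers.0.linear1.weight", "linear_0.w_0"),
    ("bert.encoder.layers.0.linear1.bias", "linear_0.b_0"),
    ("bert.encoder.layers.0.norm1.weight", "layer_norm_0.w_0"),
    ("classifier.weight", "linear_1.w_0"),
    ("classifier.bias", "linear_1.b_0"),
]

NO_DECAY_KEYWORDS = ["bias", "norm"]

# Keep the internal name of every parameter whose structured name contains
# none of the no-decay keywords.
decay_params = [
    internal_name
    for structured_name, internal_name in named_parameters
    if not any(keyword in structured_name for keyword in NO_DECAY_KEYWORDS)
]


def apply_decay(param_name):
    """Same role as the lambda passed to apply_decay_param_fun."""
    return param_name in decay_params


print(decay_params)  # ['linear_0.w_0', 'linear_1.w_0']
```

Only the two weight matrices receive weight decay here. Biases and normalization parameters are skipped because their structured names contain `bias` or `norm`.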
Run the following commands to generate and save several iterations of loss data from the fake data and fake label, then use the reprod_log tool to check for diffs.
# Generate the training loss data for Paddle and PyTorch
python test_bp.py
# Compare the results and generate the log
python check_step4.py
The final output is shown below; it is also saved to the file bp_align_diff.log.
[2021/11/17 22:08:30] root INFO: loss_0:
[2021/11/17 22:08:30] root INFO: mean diff: check passed: True, value: 0.0
[2021/11/17 22:08:30] root INFO: loss_1:
[2021/11/17 22:08:30] root INFO: mean diff: check passed: True, value: 0.0
[2021/11/17 22:08:30] root INFO: loss_2:
[2021/11/17 22:08:30] root INFO: mean diff: check passed: True, value: 0.0
[2021/11/17 22:08:30] root INFO: loss_3:
[2021/11/17 22:08:30] root INFO: mean diff: check passed: True, value: 0.0
[2021/11/17 22:08:30] root INFO: loss_4:
[2021/11/17 22:08:30] root INFO: mean diff: check passed: True, value: 0.0
[2021/11/17 22:08:30] root INFO: loss_5:
[2021/11/17 22:08:30] root INFO: mean diff: check passed: True, value: 0.0
[2021/11/17 22:08:30] root INFO: loss_6:
[2021/11/17 22:08:30] root INFO: mean diff: check passed: True, value: 0.0
[2021/11/17 22:08:30] root INFO: loss_7:
[2021/11/17 22:08:30] root INFO: mean diff: check passed: True, value: 0.0
[2021/11/17 22:08:30] root INFO: loss_8:
[2021/11/17 22:08:30] root INFO: mean diff: check passed: True, value: 0.0
[2021/11/17 22:08:30] root INFO: loss_9:
[2021/11/17 22:08:30] root INFO: mean diff: check passed: True, value: 0.0
[2021/11/17 22:08:30] root INFO: diff check passed
The loss diff is 0 for all of the first 10 iterations, so the check passes.
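A toy, framework-free sketch may help show why identical initial weights and identical fake data should give a per-iteration loss diff of exactly 0. It is not the BERT training code: the one-parameter model, the squared-error loss, the plain SGD update, and the name `train_some_iters` are all assumptions chosen to keep the example small.

```python
def train_some_iters(init_weight, fake_x, fake_y, lr=0.1, max_iter=10):
    """Train a one-parameter model y = w * x with plain SGD and collect the losses."""
    weight = init_weight
    loss_list = []
    for _ in range(max_iter):
        pred = weight * fake_x                 # forward pass
        loss = (pred - fake_y) ** 2            # squared-error loss
        grad = 2.0 * (pred - fake_y) * fake_x  # d(loss) / d(weight)
        weight -= lr * grad                    # parameter update
        loss_list.append(loss)
    return loss_list


# Two runs with the same initial weight and the same fake batch.
reference_losses = train_some_iters(0.5, fake_x=1.0, fake_y=2.0)
candidate_losses = train_some_iters(0.5, fake_x=1.0, fake_y=2.0)

diffs = {
    f"loss_{idx}": abs(ref - cand)
    for idx, (ref, cand) in enumerate(zip(reference_losses, candidate_losses))
}
for key, diff in diffs.items():
    print(f"{key}: diff={diff}")

# A run that starts from a different weight diverges from the first iteration.
mismatched_losses = train_some_iters(0.6, fake_x=1.0, fake_y=2.0)
print(mismatched_losses[0] != reference_losses[0])  # True
```

In this toy setting, any nonzero diff at `loss_0` would point to a mismatch in the initial weights or the forward pass, while a diff that first appears at a later iteration would point to the backward pass or the optimizer update.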