Skip to content

Commit

Permalink
fix!(proto): change fused_adamw's arg type float->double (#1359)
Browse files Browse the repository at this point in the history
fix adamw
  • Loading branch information
ustclight-sls authored Nov 22, 2024
1 parent 6747443 commit b179d61
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 7 deletions.
4 changes: 2 additions & 2 deletions diopi_test/python/configs/diopi_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5552,8 +5552,8 @@
'fused_adamw': dict(
name=['fused_adamw'],
interface=["CustomizedTest"],
atol=1e-2,
rtol=2e-3,
atol=3e-5,
rtol=3e-5,
atol_half=1e-2,
rtol_half=2e-3,
para=dict(
Expand Down
5 changes: 2 additions & 3 deletions diopi_test/python/conformance/customized_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,21 +163,20 @@ def fused_adamw(
amsgrad,
maximize,
):
torch.optim._functional.adamw(
torch._fused_adamw_(
params,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
amsgrad=amsgrad,
lr=lr,
beta1=beta1,
beta2=beta2,
lr=lr,
weight_decay=weight_decay,
eps=eps,
maximize=maximize,
fused=True,
)
return params, exp_avgs, exp_avg_sqs, max_exp_avg_sqs

Expand Down
2 changes: 1 addition & 1 deletion impl/torch/functions/functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3360,7 +3360,7 @@ diopiError_t diopiGridSample(diopiContextHandle_t ctx, diopiTensorHandle_t out,

diopiError_t diopiFusedAdamW(diopiContextHandle_t ctx, diopiTensorHandle_t* params, diopiConstTensorHandle_t* grads, diopiTensorHandle_t* exp_avgs,
diopiTensorHandle_t* exp_avg_sqs, diopiTensorHandle_t* max_exp_avg_sqs, diopiConstTensorHandle_t* state_steps, int64_t nums,
float lr, float beta1, float beta2, float eps, float weight_decay, bool amsgrad, bool maximize) {
double lr, double beta1, double beta2, double eps, double weight_decay, bool amsgrad, bool maximize) {
impl::aten::setCurStream(ctx);
DIOPI_CHECK_PTR(params);
DIOPI_IMPL_BUILD_ATEN_LIST(atParam, params, nums);
Expand Down
2 changes: 1 addition & 1 deletion proto/include/diopi/functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -2829,7 +2829,7 @@ DIOPI_API diopiError_t diopiReciprocalInp(diopiContextHandle_t ctx, diopiTensorH
*/
DIOPI_API diopiError_t diopiFusedAdamW(diopiContextHandle_t ctx, diopiTensorHandle_t* params, diopiConstTensorHandle_t* grads, diopiTensorHandle_t* exp_avgs,
diopiTensorHandle_t* exp_avg_sqs, diopiTensorHandle_t* max_exp_avg_sqs, diopiConstTensorHandle_t* state_steps,
int64_t nums, float lr, float beta1, float beta2, float eps, float weight_decay, bool amsgrad, bool maximize);
int64_t nums, double lr, double beta1, double beta2, double eps, double weight_decay, bool amsgrad, bool maximize);

/**
* @brief The function is used to implement the AdamW optimizer. Its functionality is to perform a single parameter update.
Expand Down

0 comments on commit b179d61

Please sign in to comment.