Skip to content

Commit

Permalink
[AMP]fix embedding model weight type mismatch error (#53770) (#53827)
Browse files Browse the repository at this point in the history
Pcard-70458
cherry-pick: #53770
  • Loading branch information
shaojiewang authored May 16, 2023
1 parent e6464f3 commit 4a08f7e
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 4 deletions.
22 changes: 18 additions & 4 deletions python/paddle/static/amp/fp16_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ def set_param_dtype(program, dtype, amp_lists, use_fp16_guard, level):
# input in transformer, so the weight is also in to_fp16_var_names.
# TODO(zhangting2020): consider fix auto_parallel_fp16 and remove lookup_table
# from black_list and unsupport_list.
if op in ['lookup_table', 'lookup_table_v2']:
if op.type in amp_lists.black_list:
continue
if _need_keep_fp32(op, amp_lists.unsupported_list, use_fp16_guard):
for in_name in op.input_names:
Expand All @@ -461,8 +461,9 @@ def set_param_dtype(program, dtype, amp_lists, use_fp16_guard, level):
return keep_fp32_var_names


def op_need_keep_fp32(op, amp_lists, use_fp16_guard):
def op_need_keep_fp32(op, amp_lists, use_fp16_guard, params_list):
need_keep_fp32 = False
fp16_varname_list_in_fp32_op = set()
if _need_keep_fp32(
op,
amp_lists.unsupported_list,
Expand All @@ -475,8 +476,14 @@ def op_need_keep_fp32(op, amp_lists, use_fp16_guard):
need_keep_fp32 = True
elif op.type in amp_lists.black_list:
need_keep_fp32 = True
for in_name in op.input_names:
for params in params_list:
if op.input(in_name)[0] == params.name:
fp16_varname_list_in_fp32_op = (
fp16_varname_list_in_fp32_op.union(op.input(in_name))
)

return need_keep_fp32
return need_keep_fp32, fp16_varname_list_in_fp32_op


def get_promote_dtype(op, amp_dtype, block):
Expand Down Expand Up @@ -651,7 +658,14 @@ def need_process(op):
if not need_process(op):
_logger.debug("---- The op does not need to be processed ----.")
continue
if op_need_keep_fp32(op, amp_lists, use_fp16_guard):
all_params = global_block.all_parameters()
op_keep_fp32, fp16_var_names_in_fp32_op = op_need_keep_fp32(
op, amp_lists, use_fp16_guard, all_params
)
to_fp16_var_names = to_fp16_var_names.union(
fp16_var_names_in_fp32_op
)
if op_keep_fp32:
keep_fp32_ops.add(op)
process_op_input_and_outputs(
op, block, global_block, core.VarDesc.VarType.FP32
Expand Down
146 changes: 146 additions & 0 deletions test/amp/test_amp_o2_embedding_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import random
import unittest

import numpy as np
from amp_base_models import AmpTestBase, _build_optimizer

import paddle
from paddle import nn

paddle.enable_static()

_fixed_param = np.random.random(size=[64, 64]).astype("float32")


class SimpleUnittedEmbeddingNet(nn.Layer):
def __init__(self):
super().__init__()
self.vocab_size = 64
self.hidden_size = 64
global _fixed_param

self.param_attr = paddle.ParamAttr(
initializer=paddle.nn.initializer.Assign(_fixed_param)
)
self.embedding = nn.Embedding(
self.vocab_size, self.hidden_size, weight_attr=self.param_attr
)
self.linear = nn.Linear(
in_features=self.hidden_size,
out_features=self.vocab_size,
weight_attr=self.param_attr,
)

def forward(self, x):
out = self.embedding(x)
scale = paddle.full(shape=[1], fill_value=2, dtype="int64")
out = paddle.multiply(out, scale.astype("float32"))
out = self.linear(out)
out = nn.functional.dropout(out, p=0.2)
return out


def build_unitted_embedding_model(
use_amp,
amp_dtype="float16",
amp_level="O1",
use_promote=False,
):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.utils.unique_name.guard():
with paddle.static.program_guard(main_program, startup_program):
model = SimpleUnittedEmbeddingNet()
x = paddle.static.data(name='x', shape=[None, 32], dtype='int64')
out = model(x)
loss = paddle.mean(out)
if use_amp:
amp_lists = paddle.static.amp.AutoMixedPrecisionLists(
custom_white_list=["elementwise_mul"],
custom_black_list=["reduce_mean"],
dtype=amp_dtype,
)
else:
amp_lists = None
optimizer = _build_optimizer(
use_amp,
amp_dtype,
amp_level,
amp_lists,
True,
use_promote=use_promote,
)
optimizer.minimize(loss)

feed_vars = [x]
fetch_vars = [loss]
return main_program, startup_program, optimizer, feed_vars, fetch_vars


class TestUnittedEmbedding(AmpTestBase):
def _generate_feed_x(self):
seed = 0
paddle.seed(seed)
np.random.seed(seed)
random.seed(seed)

x = np.random.randint(1, 64, size=[1, 32]).astype("int64")
return x

def test_compare_o1_and_o2_master_grad(self):
def _run(place, exe, x_np, max_iters, level):
(
main_program,
startup_program,
optimizer,
feed_vars,
fetch_vars,
) = build_unitted_embedding_model(
True,
"float16",
level,
)

seed = 0
paddle.seed(seed)
np.random.seed(seed)
random.seed(seed)

losses = self.run_program(
main_program,
startup_program,
optimizer,
feed_vars,
fetch_vars,
place,
exe,
x_np,
max_iters,
level,
)
return losses

max_iters = 5
x = self._generate_feed_x()
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
losses_o2 = _run(place, exe, x, max_iters, 'O2')


if __name__ == "__main__":
unittest.main()

0 comments on commit 4a08f7e

Please sign in to comment.