
Error when using mixed precision #39881

Closed
jyjfjyjf opened this issue Feb 24, 2022 · 14 comments

@jyjfjyjf
Code:

model, optimizer = paddle.amp.decorate(models=model, optimizers=optimizer, level=f16, master_weight=None, save_dtype=None)

global_step = 0
for epoch in range(1, epochs + 1):
    correct = 0
    data_num = 0
    batch_id = 0

    model.train()

    for data in tqdm(train_loader, desc=f'epoch {epoch} training'):
        input_ids = paddle.to_tensor(data[0])
        token_type_ids = paddle.to_tensor(data[1])
        labels = paddle.to_tensor(data[2])

        input_data = {'input_ids': input_ids,
                      'token_type_ids': token_type_ids,
                      'labels': labels}

        with paddle.amp.auto_cast(enable=True, custom_black_list=[
            "reduce_sum",
            "c_softmax_with_cross_entropy",
            "elementwise_div"
        ], level=f16):
            outputs = model(**input_data)
            logits = outputs[1]
            loss = outputs[0]

        probs = F.softmax(logits, axis=-1)

        correct = metric.compute(paddle.to_tensor(probs.cpu().detach().numpy()),
                                paddle.to_tensor(labels.cpu().detach().numpy()))
        metric.update(correct)
        acc = metric.accumulate()

        global_step += 1
        if global_step % 100 == 0:
            logger.info("global step %d, epoch: %d, batch: %d, loss: %.5f, acc: %.5f" % (
                global_step, epoch, batch_id, loss, acc))

        if f16 == 'O2':
            scaled = scaler.scale(loss)
            scaled.backward()
            scaler.step(optimizer)  
            scaler.update()

Error message:

Traceback (most recent call last):
  File "C:\Users\woait\AppData\Roaming\JetBrains\IntelliJIdea2021.3\plugins\python\helpers\pydev\pydevd.py", line 1483, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "C:\Users\woait\AppData\Roaming\JetBrains\IntelliJIdea2021.3\plugins\python\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "C:/JYN/workspace/DeBERTa_Decoding-enhanced_BERT_with_Disentangled_Attention/tools/train_eval.py", line 187, in <module>
    scaled.backward()
  File "C:\JYN\sd\anaconda\anaconda\envs\paddle-py37\lib\site-packages\decorator.py", line 232, in fun
    return caller(func, *(extras + args), **kw)
  File "C:\JYN\sd\anaconda\anaconda\envs\paddle-py37\lib\site-packages\paddle\fluid\wrapped_decorator.py", line 25, in __impl__
    return wrapped_func(*args, **kwargs)
  File "C:\JYN\sd\anaconda\anaconda\envs\paddle-py37\lib\site-packages\paddle\fluid\framework.py", line 229, in __impl__
    return func(*args, **kwargs)
  File "C:\JYN\sd\anaconda\anaconda\envs\paddle-py37\lib\site-packages\paddle\fluid\dygraph\varbase_patch_methods.py", line 249, in backward
    framework._dygraph_tracer())
OSError: (External) RuntimeError: (NotFound) Operator where does not have kernel for data_type[::paddle::platform::float16]:data_layout[ANY_LAYOUT]:place[CUDAPlace(0)]:library_type[PLAIN].
  [Hint: Expected kernel_iter != kernels.end(), but received kernel_iter == kernels.end().] (at ..\paddle\fluid\imperative\prepared_operator.cc:159)
  [operator < where > error]

At:
  C:\JYN\sd\anaconda\anaconda\envs\paddle-py37\lib\site-packages\paddle\tensor\search.py(522): where
  C:\JYN\workspace\DeBERTa_Decoding-enhanced_BERT_with_Disentangled_Attention\paddle_deberta\paddlenlp\transformers\deberta\modeling_deberta.py(127): backward
  C:\JYN\sd\anaconda\anaconda\envs\paddle-py37\lib\site-packages\paddle\autograd\py_layer.py(181): backward
  C:\JYN\sd\anaconda\anaconda\envs\paddle-py37\lib\site-packages\paddle\fluid\dygraph\varbase_patch_methods.py(249): backward
  C:\JYN\sd\anaconda\anaconda\envs\paddle-py37\lib\site-packages\paddle\fluid\framework.py(229): __impl__
  C:\JYN\sd\anaconda\anaconda\envs\paddle-py37\lib\site-packages\paddle\fluid\wrapped_decorator.py(25): __impl__
  C:\JYN\sd\anaconda\anaconda\envs\paddle-py37\lib\site-packages\decorator.py(232): fun
  C:/JYN/workspace/DeBERTa_Decoding-enhanced_BERT_with_Disentangled_Attention/tools/train_eval.py(187): <module>
  C:\Users\woait\AppData\Roaming\JetBrains\IntelliJIdea2021.3\plugins\python\helpers\pydev\_pydev_imps\_pydev_execfile.py(18): execfile
  C:\Users\woait\AppData\Roaming\JetBrains\IntelliJIdea2021.3\plugins\python\helpers\pydev\pydevd.py(1483): _exec
  C:\Users\woait\AppData\Roaming\JetBrains\IntelliJIdea2021.3\plugins\python\helpers\pydev\pydevd.py(1476): run
  C:\Users\woait\AppData\Roaming\JetBrains\IntelliJIdea2021.3\plugins\python\helpers\pydev\pydevd.py(2164): main
  C:\Users\woait\AppData\Roaming\JetBrains\IntelliJIdea2021.3\plugins\python\helpers\pydev\pydevd.py(2173): <module>
 (at ..\paddle\fluid\imperative\basic_engine.cc:571)

@zhiqiu
Contributor

zhiqiu commented Feb 24, 2022

Hi, could you share your model code so we can take a look?

@jyjfjyjf
Author

Hi, the model code is here: https://github.com/jyjfjyjf/DeBERTa

@zhiqiu
Contributor

zhiqiu commented Feb 25, 2022

> Hi, the model code is here: https://github.com/jyjfjyjf/DeBERTa

This problem occurs because the network uses a PyLayer. The PyLayer's forward pass runs inside auto_cast, so part of its computation may run in fp16, but its backward pass does not run under auto_cast, which is why it errors.

(screenshot omitted)

You can either (1) add `with auto_cast(enable=False)` inside the forward pass to turn mixed precision off there, or (2) wrap the backward pass in the same `with auto_cast(enable=True)` as the forward.
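For option (1), a minimal sketch based on the XDropout PyLayer from the linked repo (the explicit float32 cast and the 0.2 dropout rate are added assumptions: `enable=False` only stops further auto-casting, it does not convert tensors that are already fp16):

import paddle

class XDropout(paddle.autograd.PyLayer):
    @staticmethod
    def forward(ctx, input):
        mask = (1 - paddle.bernoulli(paddle.rand(input.shape))).astype('bool')
        ctx.save_for_backward(mask)
        ctx.scale = 1.0 / (1 - 0.2)  # assumed dropout rate of 0.2
        with paddle.amp.auto_cast(enable=False):  # option (1): disable fp16 here
            # cast explicitly: fp16 inputs stay fp16 under enable=False
            input_fp32 = input.astype('float32')
            zeros = paddle.full(input_fp32.shape, 0, input_fp32.dtype)
            out = paddle.where(mask, zeros, input_fp32)
        return out * ctx.scale

    @staticmethod
    def backward(ctx, grad_output):
        # forward returned fp32, so grad_output arrives here as fp32
        mask, = ctx.saved_tensor()
        zeros = paddle.full(grad_output.shape, 0, grad_output.dtype)
        return paddle.where(mask, zeros, grad_output) * ctx.scale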

@jyjfjyjf
Author

I tried both and neither works. Using (1) raises another error:

Traceback (most recent call last):
  File "C:\JYN\sd\anaconda\anaconda\envs\paddle-py37\lib\contextlib.py", line 130, in __exit__
    self.gen.throw(type, value, traceback)
  File "C:\JYN\sd\anaconda\anaconda\envs\paddle-py37\lib\site-packages\paddle\fluid\dygraph\amp\auto_cast.py", line 290, in amp_guard
    yield
  File "C:/JYN/workspace/DeBERTa_Decoding-enhanced_BERT_with_Disentangled_Attention/tools/train_eval.py", line 148, in <module>
    outputs = model(**input_data)
  File "C:\JYN\sd\anaconda\anaconda\envs\paddle-py37\lib\site-packages\paddle\fluid\dygraph\layers.py", line 914, in __call__
    outputs = self.forward(*inputs, **kwargs)
  File "C:\JYN\workspace\DeBERTa_Decoding-enhanced_BERT_with_Disentangled_Attention\paddle_deberta\paddlenlp\transformers\deberta\modeling_deberta.py", line 1233, in forward
    return_dict=return_dict,
  File "C:\JYN\sd\anaconda\anaconda\envs\paddle-py37\lib\site-packages\paddle\fluid\dygraph\layers.py", line 914, in __call__
    outputs = self.forward(*inputs, **kwargs)
  File "C:\JYN\workspace\DeBERTa_Decoding-enhanced_BERT_with_Disentangled_Attention\paddle_deberta\paddlenlp\transformers\deberta\modeling_deberta.py", line 1030, in forward
    inputs_embeds=inputs_embeds,
  File "C:\JYN\sd\anaconda\anaconda\envs\paddle-py37\lib\site-packages\paddle\fluid\dygraph\layers.py", line 914, in __call__
    outputs = self.forward(*inputs, **kwargs)
  File "C:\JYN\workspace\DeBERTa_Decoding-enhanced_BERT_with_Disentangled_Attention\paddle_deberta\paddlenlp\transformers\deberta\modeling_deberta.py", line 694, in forward
    embeddings = self.dropout(embeddings)
  File "C:\JYN\sd\anaconda\anaconda\envs\paddle-py37\lib\site-packages\paddle\fluid\dygraph\layers.py", line 914, in __call__
    outputs = self.forward(*inputs, **kwargs)
  File "C:\JYN\workspace\DeBERTa_Decoding-enhanced_BERT_with_Disentangled_Attention\paddle_deberta\paddlenlp\transformers\deberta\modeling_deberta.py", line 152, in forward
    return XDropout.apply(x, self.get_context())
  File "C:\JYN\sd\anaconda\anaconda\envs\paddle-py37\lib\site-packages\decorator.py", line 232, in fun
    return caller(func, *(extras + args), **kw)
  File "C:\JYN\sd\anaconda\anaconda\envs\paddle-py37\lib\site-packages\paddle\fluid\wrapped_decorator.py", line 25, in __impl__
    return wrapped_func(*args, **kwargs)
  File "C:\JYN\sd\anaconda\anaconda\envs\paddle-py37\lib\site-packages\paddle\fluid\framework.py", line 229, in __impl__
    return func(*args, **kwargs)
  File "C:\JYN\sd\anaconda\anaconda\envs\paddle-py37\lib\site-packages\paddle\autograd\py_layer.py", line 174, in apply
    return core.pylayer_apply(place, cls, *args, **kwargs)
  File "C:\JYN\workspace\DeBERTa_Decoding-enhanced_BERT_with_Disentangled_Attention\paddle_deberta\paddlenlp\transformers\deberta\modeling_deberta.py", line 113, in forward
    input = paddle.where(mask, tmp_mask, input)
  File "C:\JYN\sd\anaconda\anaconda\envs\paddle-py37\lib\site-packages\paddle\tensor\search.py", line 522, in where
    return _C_ops.where(condition, x, y)
RuntimeError: (NotFound) Operator where does not have kernel for data_type[::paddle::platform::float16]:data_layout[ANY_LAYOUT]:place[CUDAPlace(0)]:library_type[PLAIN].
  [Hint: Expected kernel_iter != kernels.end(), but received kernel_iter == kernels.end().] (at ..\paddle\fluid\imperative\prepared_operator.cc:159)
  [operator < where > error]

Process finished with exit code 1

@jyjfjyjf
Author

Does `where` not support fp16?

@jyjfjyjf
Author

I added this before backward:

with paddle.amp.auto_cast(enable=True, custom_black_list=[
    "reduce_sum",
    "c_softmax_with_cross_entropy",
    "elementwise_div"
], level=f16):

but it still raises the same error as before.

@jyjfjyjf
Author

I added mixed precision before the `where` in backward, but now I get a new error:

ValueError: (InvalidArgument) Tensor holds the wrong type, it holds float, but desires to be ::paddle::platform::float16.
  [Hint: Expected valid == true, but received valid:0 != true:1.] (at ..\paddle/fluid/framework/tensor_impl.h:33)
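One way to work around an op that lacks an fp16 kernel (a minimal sketch; `masked_fill_zero` is a hypothetical helper, and it assumes the surrounding math tolerates the dtype round trip) is to run `where` in float32 and cast the result back, which also sidesteps the float/float16 mismatch above:

import paddle

def masked_fill_zero(mask, x):
    """Zero out x where mask is True; `where` runs in float32 because
    paddlepaddle-gpu 2.2.x has no fp16 kernel for the op."""
    orig_dtype = x.dtype
    x32 = x.astype('float32')
    zeros = paddle.full(x32.shape, 0, 'float32')
    return paddle.where(mask, zeros, x32).astype(orig_dtype)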

@zhiqiu
Contributor

zhiqiu commented Feb 28, 2022

Which version of Paddle are you using?

@jyjfjyjf
Author

This version:
paddlepaddle-gpu 2.2.2.post101

@zhiqiu
Contributor

zhiqiu commented Feb 28, 2022

> I added this before backward:
>
>     with paddle.amp.auto_cast(enable=True, custom_black_list=[
>         "reduce_sum",
>         "c_softmax_with_cross_entropy",
>         "elementwise_div"
>     ], level=f16):
>
> but it still raises the same error as before.

How exactly did you change it? Could you post the surrounding code?

@jyjfjyjf
Author

    input_ids = paddle.to_tensor(data[0])
    token_type_ids = paddle.to_tensor(data[1])
    labels = paddle.to_tensor(data[2])

    input_data = {'input_ids': input_ids,
                  'token_type_ids': token_type_ids,
                  'labels': labels}
    with paddle.amp.auto_cast(enable=True, custom_black_list=[
        "reduce_sum",
        "c_softmax_with_cross_entropy",
        "elementwise_div"
    ], level=f16):
        outputs = model(**input_data)
        logits = outputs[1]
        loss = outputs[0]

    probs = F.softmax(logits, axis=-1)

    correct = metric.compute(paddle.to_tensor(probs.cpu().detach().numpy()),
                             paddle.to_tensor(labels.cpu().detach().numpy()))
    metric.update(correct)
    acc = metric.accumulate()

    global_step += 1
    if global_step % 100 == 0:
        logger.info("global step %d, epoch: %d, batch: %d, loss: %.5f, acc: %.5f" % (
            global_step, epoch, batch_id, loss, acc))

    if f16 == 'O2':
        scaled = scaler.scale(loss)
        with paddle.amp.auto_cast(enable=True, custom_black_list=[
            "reduce_sum",
            "c_softmax_with_cross_entropy",
            "elementwise_div"
        ], level=f16):
            scaled.backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()

        optimizer.step()
    lr_scheduler.step()

    optimizer.clear_grad()

@zhiqiu
Contributor

zhiqiu commented Mar 3, 2022

modeling_deberta.py

For (2), adding the same `with auto_cast(enable=True)` to the backward pass as in the forward, you can adapt your code following the example below:

class XDropout(paddle.autograd.PyLayer):
    """Optimized dropout function to save computation and memory by using mask operation instead of multiplication."""

    @staticmethod
    def forward(ctx, input):
        mask = (1 - paddle.bernoulli(paddle.rand(input.shape))).astype('bool')
        dropout = 0.2
        ctx.scale = 1.0 / (1 - dropout)
        if dropout > 0:
            ctx.save_for_backward(mask)
            stop_gradient = input.stop_gradient
            # implement torch's masked_fill_ with paddle ops
            tmp_mask = paddle.full(input.shape, 0, input.dtype)
            input = paddle.where(mask, tmp_mask, input)
            input.stop_gradient = stop_gradient

            return input * ctx.scale
        else:
            return input

    @staticmethod
    def backward(ctx, grad_output):
        with paddle.amp.auto_cast(level='O2'):   # add auto_cast in backward
            if ctx.scale > 1:
                mask, = ctx.saved_tensor()
                
                stop_gradient = grad_output.stop_gradient
                tmp_mask = paddle.full(grad_output.shape, 0, grad_output.dtype)
                grad_output = paddle.where(mask, tmp_mask, grad_output)
                grad_output.stop_gradient = stop_gradient
                grad_output = grad_output * ctx.scale

                return grad_output
            else:
                return grad_output

x = paddle.rand([10, 10])
x.stop_gradient=False

with paddle.amp.auto_cast(level='O2'):
    res = XDropout.apply(x)
    loss = paddle.mean(res)

loss.backward()
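For completeness, a minimal sketch of the standard dygraph AMP training step this example plugs into (the Linear model, AdamW optimizer, and random data are placeholder stand-ins, not from this issue):

import paddle

model = paddle.nn.Linear(10, 2)
optimizer = paddle.optimizer.AdamW(parameters=model.parameters())
model, optimizer = paddle.amp.decorate(models=model, optimizers=optimizer, level='O2')
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

x = paddle.rand([4, 10])
with paddle.amp.auto_cast(level='O2'):
    loss = paddle.mean(model(x))

scaled = scaler.scale(loss)  # scale the loss to avoid fp16 gradient underflow
scaled.backward()
scaler.step(optimizer)       # unscales gradients, then runs the optimizer step
scaler.update()              # adjusts the loss scale for the next iteration
optimizer.clear_grad()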

@paddle-bot
paddle-bot bot commented Mar 7, 2023

Since you haven't replied for more than a year, we have closed this issue/PR.
If the problem is not solved or there is a follow-up, please reopen it at any time and we will continue to follow up.
