Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Recompute] Support ernie for dygraph recompute. #2849

Merged
merged 6 commits into from
Aug 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 31 additions & 23 deletions model_zoo/ernie-1.0/run_pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import paddle
import paddle.distributed as dist
import paddle.distributed.fleet as fleet
from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients
from paddle.io import DataLoader, Dataset
from visualdl import LogWriter

Expand Down Expand Up @@ -327,6 +328,7 @@ def do_train(args):
model_config["hidden_dropout_prob"] = args.hidden_dropout_prob
model_config[
"attention_probs_dropout_prob"] = args.attention_probs_dropout_prob
model_config["enable_recompute"] = args.use_recompute
model = model_class(base_class(**model_config))
else:
model = model_class.from_pretrained(
Expand Down Expand Up @@ -462,33 +464,39 @@ def do_train(args):
input_ids, segment_ids, input_mask, masked_lm_positions, \
masked_lm_labels, next_sentence_labels = batch

with paddle.amp.auto_cast(args.use_amp,
custom_black_list=[
"reduce_sum",
"c_softmax_with_cross_entropy",
"elementwise_div"
],
level='O2'):

# Create the model for the ernie pretrain
prediction_scores, seq_relationship_score = model(
input_ids=input_ids,
token_type_ids=segment_ids,
position_ids=None,
attention_mask=input_mask,
masked_positions=masked_lm_positions)

lm_loss, sop_loss = criterion(prediction_scores,
seq_relationship_score,
masked_lm_labels,
next_sentence_labels)
loss = lm_loss + sop_loss
with model.no_sync():
with paddle.amp.auto_cast(args.use_amp,
custom_black_list=[
"reduce_sum",
"c_softmax_with_cross_entropy",
"elementwise_div"
],
level='O2'):

# Create the model for the ernie pretrain
prediction_scores, seq_relationship_score = model(
input_ids=input_ids,
token_type_ids=segment_ids,
position_ids=None,
attention_mask=input_mask,
masked_positions=masked_lm_positions)

lm_loss, sop_loss = criterion(prediction_scores,
seq_relationship_score,
masked_lm_labels,
next_sentence_labels)
loss = lm_loss + sop_loss

if args.use_amp:
scaler.scale(loss).backward()
else:
loss.backward()

fused_allreduce_gradients(list(model.parameters()), None)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里 recompute 基于 PyLayer,所以要手动 fused_allreduce_gradients


if args.use_amp:
scaler.scale(loss).backward()
scaler.minimize(optimizer, loss)
else:
loss.backward()
optimizer.step()

optimizer.clear_grad()
Expand Down
1 change: 1 addition & 0 deletions model_zoo/ernie-1.0/run_pretrain_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ def main():
model_config[
"attention_probs_dropout_prob"] = model_args.attention_probs_dropout_prob
model = model_class(base_class(**model_config))
# model_config["enable_recompute"] = args.use_recompute
Copy link
Collaborator Author

@ZHUI ZHUI Aug 4, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

需要改造 trainer 中 DP 的 fused_allreduce_gradients 方式。
这块后续需要分布式优化一下体验。

else:
model = model_class.from_pretrained(
model_args.model_name_or_path,
Expand Down
7 changes: 5 additions & 2 deletions paddlenlp/transformers/ernie/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,8 @@ def __init__(self,
pad_token_id=0,
task_type_vocab_size=3,
task_id=0,
use_task_id=False):
use_task_id=False,
enable_recompute=False):
super(ErnieModel, self).__init__()
self.pad_token_id = pad_token_id
self.initializer_range = initializer_range
Expand All @@ -585,7 +586,9 @@ def __init__(self,
act_dropout=0,
weight_attr=weight_attr,
normalize_before=False)
self.encoder = nn.TransformerEncoder(encoder_layer, num_hidden_layers)
self.encoder = nn.TransformerEncoder(encoder_layer,
num_hidden_layers,
enable_recompute=enable_recompute)
self.pooler = ErniePooler(hidden_size, weight_attr)
self.apply(self.init_weights)

Expand Down
78 changes: 74 additions & 4 deletions paddlenlp/transformers/model_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,25 @@
from dataclasses import fields, dataclass
from typing import Any, List, Tuple, Optional
from paddle.nn.layer.transformer import _convert_attention_mask
from paddle.distributed.fleet.utils import recompute

from .utils import adapt_stale_fwd_patch


def layer_init_wrapper(func):

@functools.wraps(func)
def _impl(self, *args, **kwargs):
enable_recompute = kwargs.pop("enable_recompute", False)
func(self, *args, **kwargs)
if paddle.in_dynamic_mode():
self.enable_recompute = enable_recompute
else:
self.enable_recompute = False

return _impl


def _transformer_encoder_layer_fwd(self,
src,
src_mask=None,
Expand Down Expand Up @@ -60,6 +75,46 @@ def _transformer_encoder_layer_fwd(self,
(src, ) + outputs[::-1]) # hidden_states, cache, attentions


def _transformer_decoder_fwd(self,
tgt,
memory,
tgt_mask=None,
memory_mask=None,
cache=None):
tgt_mask = _convert_attention_mask(tgt_mask, tgt.dtype)
memory_mask = _convert_attention_mask(memory_mask, memory.dtype)

output = tgt
new_caches = []
for i, mod in enumerate(self.layers):
if cache is None:
if self.enable_recompute:
output = recompute(mod,
output,
memory,
tgt_mask,
memory_mask,
cache=None)
else:
output = mod(output,
memory,
tgt_mask=tgt_mask,
memory_mask=memory_mask,
cache=None)
else:
output, new_cache = mod(output,
memory,
tgt_mask=tgt_mask,
memory_mask=memory_mask,
cache=cache[i])
new_caches.append(new_cache)

if self.norm is not None:
output = self.norm(output)

return output if cache is None else (output, new_caches)


def _transformer_encoder_fwd(self,
src,
src_mask=None,
Expand All @@ -75,10 +130,16 @@ def _transformer_encoder_fwd(self,
# NOTE: Also includes embeding output which is same as HF.
all_hidden_states = [output] if output_hidden_states else None
for i, mod in enumerate(self.layers):
layer_outputs = mod(output,
src_mask=src_mask,
cache=None if cache is None else cache[i],
output_attentions=output_attentions)
if self.enable_recompute:
layer_outputs = recompute(mod, output, src_mask,
None if cache is None else cache[i],
output_attentions)
else:
layer_outputs = mod(output,
src_mask=src_mask,
cache=None if cache is None else cache[i],
output_attentions=output_attentions)

if isinstance(layer_outputs, tuple):
output = layer_outputs[0]
outputs = layer_outputs[1:]
Expand Down Expand Up @@ -122,6 +183,12 @@ def _transformer_encoder_fwd(self,
# patches of paddle.nn.Transformer to get all hidden_states and attentions
paddle.nn.TransformerEncoderLayer.forward = _transformer_encoder_layer_fwd
paddle.nn.TransformerEncoder.forward = _transformer_encoder_fwd
paddle.nn.TransformerDecoder.forward = _transformer_decoder_fwd

_encoder_init = paddle.nn.TransformerEncoder.__init__
_decoder_init = paddle.nn.TransformerDecoder.__init__
paddle.nn.TransformerEncoder.__init__ = layer_init_wrapper(_encoder_init)
paddle.nn.TransformerDecoder.__init__ = layer_init_wrapper(_decoder_init)


def _get_wrap_setattr(cls):
Expand All @@ -139,6 +206,9 @@ def _wrap_setattr(self, name, value):
paddle.nn.TransformerEncoder.__setattr__ = functools.wraps(
paddle.nn.TransformerEncoder.__setattr__)(_get_wrap_setattr(
paddle.nn.TransformerEncoder))
paddle.nn.TransformerDecoder.__setattr__ = functools.wraps(
paddle.nn.TransformerDecoder.__setattr__)(_get_wrap_setattr(
paddle.nn.TransformerDecoder))


def is_tensor(x):
Expand Down