[do not merge] for test llama ernie in new architecture dy2st + prim + cinn
jeff41404 committed Aug 9, 2023
1 parent 25c25cf commit 0baaceb
Showing 4 changed files with 52 additions and 10 deletions.
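
This commit wires up the dy2st + prim + cinn test path: the trainer converts the model from dynamic to static graph (dy2st), the paddle.amp.auto_cast(False) guards in the Llama modeling code are commented out (presumably because such guards do not play well with this path), and the TIPC benchmark script gains an nsys profiling switch plus a commented-out block that compiles the resulting program with CINN. As rough orientation, the sketch below mirrors that commented-out block in tests/test_tipc/train.py: a static program is routed through CINN by setting BuildStrategy.build_cinn_pass. This is a minimal sketch, not the commit's code, and whether CINN actually engages depends on the Paddle build and environment flags.

# Minimal sketch (assumes a Paddle build with CINN support); mirrors the
# commented-out CompiledProgram block in tests/test_tipc/train.py below.
import paddle

paddle.enable_static()
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    x = paddle.static.data(name="x", shape=[-1, 8], dtype="float32")
    y = paddle.static.nn.fc(x, size=4)

build_strategy = paddle.static.BuildStrategy()
build_strategy.build_cinn_pass = True  # route the graph through the CINN compiler
compiled_prog = paddle.static.CompiledProgram(main_prog, build_strategy=build_strategy)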
15 changes: 15 additions & 0 deletions paddlenlp/trainer/trainer.py
@@ -573,6 +573,21 @@ def train(
# so, the trainable numel is a little bigger than real.
logger.info(f" Number of trainable parameters = {trainable_numel:,} (all devices, roughly)")

model = paddle.jit.to_static(
    model,
    input_spec=[
        paddle.static.InputSpec(name="input_ids", shape=[-1, -1], dtype="int64"), # input_ids
        None, # position_ids
        None, # attention_mask
        None, # inputs_embeds
        paddle.static.InputSpec(name="labels", shape=[-1, -1], dtype="int64"), # labels
        False, # use_cache
        None, # past_key_values
        None, # output_attentions
        None, # output_hidden_states
        None, # return_dict
    ],
)
start_time = time.time()
self._globalstep_last_start_time = time.time()
self.state.epoch = 0
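
The hunk above converts the trainer's model to a static graph right before the training loop, pinning input_ids and labels with InputSpec (batch and sequence dimensions left dynamic as -1) and fixing the remaining optional arguments of the Llama forward signature to None/False. Below is a self-contained sketch of the same dy2st pattern on a toy layer; the ToyLM class and its sizes are hypothetical stand-ins, not PaddleNLP code.

# Minimal dy2st sketch: pin input shapes/dtypes with InputSpec and convert a
# dynamic-graph Layer to a static graph with paddle.jit.to_static.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class ToyLM(nn.Layer):  # hypothetical stand-in for the real model
    def __init__(self, vocab_size=100, hidden=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.head = nn.Linear(hidden, vocab_size)

    def forward(self, input_ids, labels=None):
        logits = self.head(self.embed(input_ids))
        if labels is None:
            return logits
        return F.cross_entropy(
            logits.reshape([-1, logits.shape[-1]]), labels.reshape([-1])
        )


model = ToyLM()
model = paddle.jit.to_static(
    model,
    input_spec=[
        # -1 marks dimensions whose size is only known at run time (batch, seq_len)
        paddle.static.InputSpec(name="input_ids", shape=[-1, -1], dtype="int64"),
        paddle.static.InputSpec(name="labels", shape=[-1, -1], dtype="int64"),
    ],
)
loss = model(
    paddle.randint(0, 100, shape=[2, 8], dtype="int64"),
    paddle.randint(0, 100, shape=[2, 8], dtype="int64"),
)
print(float(loss))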
20 changes: 10 additions & 10 deletions paddlenlp/transformers/llama/modeling.py
@@ -233,8 +233,8 @@ def scaled_dot_product_attention(
)

attn_weights = attn_weights + attention_mask
with paddle.amp.auto_cast(False):
attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query_states.dtype)
# with paddle.amp.auto_cast(False):
attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query_states.dtype)

attn_output = paddle.matmul(attn_weights, value_states)
attn_output = attn_output.transpose([0, 2, 1, 3])
@@ -299,9 +299,9 @@ def forward(self, hidden_states):
if self.config.use_fused_rms_norm:
return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon)

with paddle.amp.auto_cast(False):
variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
# with paddle.amp.auto_cast(False):
variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states

if self.weight.dtype in [paddle.float16, paddle.bfloat16]:
hidden_states = paddle.cast(hidden_states, self.weight.dtype)
@@ -1129,11 +1129,11 @@ def forward(self, prediction_scores, masked_lm_labels):
prediction_scores = prediction_scores[..., :-1, :]
masked_lm_labels = masked_lm_labels[..., 1:]

with paddle.amp.auto_cast(False):
masked_lm_loss = self.loss_func(prediction_scores.astype("float32"), masked_lm_labels.unsqueeze(2))
# skip ignore_index which loss == 0
masked_lm_loss = masked_lm_loss[masked_lm_loss > 0].astype("float32")
loss = paddle.mean(masked_lm_loss)
# with paddle.amp.auto_cast(False):
masked_lm_loss = self.loss_func(prediction_scores.astype("float32"), masked_lm_labels.unsqueeze(2))
# skip ignore_index which loss == 0
masked_lm_loss = masked_lm_loss[masked_lm_loss > 0].astype("float32")
loss = paddle.mean(masked_lm_loss)

return loss

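
All three modeling.py hunks follow the same recipe: the paddle.amp.auto_cast(False) context managers around the attention softmax, RMSNorm, and the pretraining loss are commented out (presumably because such guards get in the way of the dy2st + prim + cinn path being tested), while the explicit astype("float32") casts are kept so the numerically sensitive math still runs in float32 under mixed precision. A minimal stand-in showing the kept pattern for RMSNorm follows; SimpleRMSNorm is a simplified hypothetical layer, not the PaddleNLP LlamaRMSNorm class.

# Sketch: compute the RMSNorm variance in float32 via an explicit cast instead
# of wrapping the block in paddle.amp.auto_cast(False).
import paddle
import paddle.nn as nn


class SimpleRMSNorm(nn.Layer):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = self.create_parameter(
            shape=[hidden_size], default_initializer=nn.initializer.Constant(1.0)
        )
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Upcast explicitly: the mean of squares is computed in float32 even if
        # hidden_states arrives in float16/bfloat16.
        variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
        hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
        if self.weight.dtype in [paddle.float16, paddle.bfloat16]:
            hidden_states = paddle.cast(hidden_states, self.weight.dtype)
        return hidden_states * self.weight


x = paddle.randn([2, 4, 8], dtype="float32")
print(SimpleRMSNorm(8)(x).shape)  # [2, 4, 8]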
3 changes: 3 additions & 0 deletions tests/test_tipc/benchmark/options.py
@@ -32,6 +32,8 @@
from .modules.stablediffusion import StableDiffusionBenchmark
except Exception:
StableDiffusionBenchmark = None
from paddlenlp.trainer.argparser import strtobool

from .modules.t5_for_conditional_generation import T5ForConditionalGenerationBenchmark
from .modules.xlnet import XLNetBenchmark

@@ -156,6 +158,7 @@ def get_parser():
help='The option of profiler, which should be in format "key1=value1;key2=value2;key3=value3".',
)
parser.add_argument("--save_model", type=str, default=None, help="Directory to save models. ")
parser.add_argument("--use_nsys", type=strtobool, default=False, help="Enable nsys.")

return parser

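
The benchmark parser gains a --use_nsys switch whose values are parsed with PaddleNLP's strtobool, so strings such as "true", "False", or "1" given on the command line all become a Python bool. A standalone sketch of the same argparse pattern follows; the local strtobool below is a stand-in with the usual semantics, whereas the real flag imports paddlenlp.trainer.argparser.strtobool.

# Sketch of a boolean-valued CLI flag parsed with a strtobool-style converter.
import argparse


def strtobool(v):  # local stand-in for paddlenlp.trainer.argparser.strtobool
    if str(v).lower() in ("true", "t", "yes", "y", "1"):
        return True
    if str(v).lower() in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Truthy value expected, got {v!r}.")


parser = argparse.ArgumentParser()
parser.add_argument("--use_nsys", type=strtobool, default=False, help="Enable nsys.")

print(parser.parse_args(["--use_nsys", "true"]).use_nsys)  # True
print(parser.parse_args(["--use_nsys", "0"]).use_nsys)     # False
print(parser.parse_args([]).use_nsys)                      # False (default)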
24 changes: 24 additions & 0 deletions tests/test_tipc/train.py
@@ -25,6 +25,7 @@
from benchmark.modules.benchmark_utils import clone_inputs
from benchmark.options import LR_SCHEDULER_REGISTRY, MODEL_REGISTRY, OPTIMIZER_REGISTRY
from benchmark.utils.record import AverageStatistical
from paddle.fluid import core

from paddlenlp.utils import profiler
from paddlenlp.utils.log import logger
@@ -223,6 +224,17 @@ def do_train(args):
input_spec = benchmark_model.create_input_specs()
model = paddle.jit.to_static(model, input_spec=input_spec)
logger.info("Successfully to apply @to_static with specs: {}".format(input_spec))
# paddle.jit.save(model, "/data/run_model/ernie/ernie", input_spec=input_spec)
# composite_program = model.forward.get_concrete_program(**tmp_input_data)[1].train_program
# for op in composite_program.block(0).ops:
# print(op)
# #print(op.type)
# build_strategy = paddle.static.BuildStrategy()
# build_strategy.build_cinn_pass = True
# build_strategy.debug_graphviz_path = "/data/run_model/ernie/paddle_to_cinn_graph/"
# program = paddle.static.CompiledProgram(composite_program, build_strategy=build_strategy)
# program._compile(paddle.fluid.executor.global_scope(), paddle.CUDAPlace(0))
# import pdb; pdb.set_trace()

if args.lr_scheduler is not None:
benchmark_lr_scheduler = LR_SCHEDULER_REGISTRY[args.lr_scheduler]()
@@ -255,6 +267,18 @@ def do_train(args):
batch_id = 0
batch_start = time.time()
for input_data in train_loader:
if args.use_nsys:
    iter_id = step_id
    if iter_id == 100:
        core.nvprof_start()
        core.nvprof_enable_record_event()
        core.nvprof_nvtx_push(str(iter_id))
    if iter_id == 110:
        core.nvprof_nvtx_pop()
        core.nvprof_stop()
    if iter_id > 100 and iter_id < 110:
        core.nvprof_nvtx_pop()
        core.nvprof_nvtx_push(str(iter_id))
train_reader_cost = time.time() - batch_start

if args.use_amp:
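
With --use_nsys enabled, the loop above profiles only iterations 100-110: the CUDA profiler is started at step 100, each step in the window is wrapped in an NVTX range named after the iteration id, and the profiler is stopped at step 110, so an Nsight Systems run with a capture range tied to the CUDA profiler API (e.g. nsys profile --capture-range=cudaProfilerApi, to be checked against your nsys version) records just that window. Below is a hedged sketch of the same windowing logic as a reusable helper; it assumes a GPU build of Paddle and uses the same paddle.fluid.core hooks as the diff.

# Sketch: wrap training steps [start, stop) in NVTX ranges and bound the CUDA
# profiler capture to that window (same core.nvprof_* hooks as the diff above).
from paddle.fluid import core


def profile_window(step, start=100, stop=110):
    """Call once per training step, before the step's work."""
    if step == start:
        core.nvprof_start()                # begin CUDA profiler capture
        core.nvprof_enable_record_event()  # record framework-level events
        core.nvprof_nvtx_push(str(step))   # open the NVTX range for this step
    elif start < step < stop:
        core.nvprof_nvtx_pop()             # close the previous step's range
        core.nvprof_nvtx_push(str(step))   # open this step's range
    elif step == stop:
        core.nvprof_nvtx_pop()             # close the last range
        core.nvprof_stop()                 # end CUDA profiler capture


for step in range(120):
    profile_window(step)
    # ... one training iteration would run here ...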
