[Feature] Add hardware flops for pretraining (#9069)
* Fix hardware TFLOPS.

* Support MFU for pretraining.
ZHUI authored Sep 13, 2024
1 parent e2f4c33 commit 5c1779c
Showing 9 changed files with 240 additions and 1 deletion.
7 changes: 7 additions & 0 deletions paddlenlp/trainer/trainer.py
@@ -1354,15 +1354,22 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
            )
            num_steps = self.state.global_step - self._globalstep_last_logged
            seq_length = None
            model_flops = None
            if getattr(self, "is_pretraining", False) and hasattr(self.model, "config"):
                seq_length = getattr(self.model.config, "seq_length", None)
                try:
                    model_flops = self.model.get_hardware_flops(seq_length=seq_length, recompute=self.args.recompute)
                except NotImplementedError:
                    model_flops = None

            logs.update(
                speed_metrics(
                    "interval",
                    self._globalstep_last_start_time,
                    num_samples=total_train_batch_size * num_steps,
                    num_steps=num_steps,
                    seq_length=seq_length,
                    model_flops=model_flops,
                )
            )

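For context, here is a minimal standalone sketch of the logging path this hunk adds, using a hypothetical TinyModel and made-up numbers (none of these names or figures come from the commit): the trainer asks the model for its hardware FLOPs per sample at the configured sequence length, and models that do not implement get_hardware_flops opt out via NotImplementedError, leaving model_flops as None so the extra metric is simply not logged.

# Illustrative sketch only; TinyConfig/TinyModel and the FLOPs value are assumptions.
class TinyConfig:
    seq_length = 4096

class TinyModel:
    config = TinyConfig()

    def get_hardware_flops(self, seq_length=None, recompute=False, **kwargs):
        # Stand-in for the per-model implementations added below; FLOPs per sample.
        return 1.5e13

model = TinyModel()
seq_length = getattr(model.config, "seq_length", None)
try:
    model_flops = model.get_hardware_flops(seq_length=seq_length, recompute=False)
except NotImplementedError:
    model_flops = None  # metric is skipped for models without an implementation
print(seq_length, f"{model_flops:.2e}")  # 4096 1.50e+13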
7 changes: 6 additions & 1 deletion paddlenlp/trainer/trainer_utils.py
@@ -344,7 +344,7 @@ def total_processes_number(local_rank):
    return 1


def speed_metrics(split, start_time, num_samples=None, num_steps=None, seq_length=None):
def speed_metrics(split, start_time, num_samples=None, num_steps=None, seq_length=None, model_flops=None):
    """
    Measure and return speed performance metrics.
@@ -365,6 +365,11 @@ def speed_metrics(split, start_time, num_samples=None, num_steps=None, seq_lengt
    if seq_length is not None:
        tokens_per_second_per_device = samples_per_second * seq_length / paddle.distributed.get_world_size()
        result[f"{split}_tokens_per_second_per_device"] = round(tokens_per_second_per_device, 4)
        if model_flops is not None:
            result[f"{split}_hardware_tflops_per_device"] = round(
                tokens_per_second_per_device * model_flops / seq_length / 2**40, 2
            )

    if num_steps is not None:
        steps_per_second = num_steps / runtime
        result[f"{split}_steps_per_second"] = round(steps_per_second, 4)
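To make the new metric concrete, a short worked example with assumed numbers (pure Python, no Paddle dependency): model_flops is the hardware FLOPs for one sample of seq_length tokens, so model_flops / seq_length is FLOPs per token, and multiplying by the per-device token throughput gives achieved FLOPS, which is then scaled by 2**40.

# Worked example; the throughput and FLOPs figures are assumptions, not values from the commit.
model_flops = 4.6e14             # hardware FLOPs for one sample of seq_length tokens
seq_length = 4096
tokens_per_second_per_device = 3500.0

flops_per_token = model_flops / seq_length
hardware_tflops_per_device = tokens_per_second_per_device * flops_per_token / 2**40
print(round(hardware_tflops_per_device, 2))  # 357.49

Note that the divisor is 2**40 rather than 1e12, so the reported "TFLOPS" are on a binary (tebi) scale and come out roughly 9% lower than a decimal-TFLOPS figure would.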
34 changes: 34 additions & 0 deletions paddlenlp/transformers/gemma/modeling.py
@@ -55,6 +55,7 @@
from .. import linear_utils
from ..linear_utils import Linear
from ..segment_parallel_utils import ReshardLayer
from ..utils import caculate_llm_flops
from .configuration import (
GEMMA_PRETRAINED_INIT_CONFIGURATION,
GEMMA_PRETRAINED_RESOURCE_FILES_MAP,
@@ -1074,6 +1075,39 @@ def __init__(self, config: GemmaConfig):

        self.gradient_checkpointing = False

    def get_model_flops(self, batch_size=1, seq_length=None, **kwargs):
        if seq_length is None:
            if hasattr(self.config, "seq_length"):
                seq_length = self.config.seq_length
            else:
                seq_length = 2048

        return caculate_llm_flops(
            hidden_size=self.config.hidden_size,
            intermediate_size=self.config.intermediate_size,
            layer_num=self.config.num_hidden_layers,
            vocab_size=self.config.vocab_size,
            seq_length=seq_length,
            recompute=False,
        )

    def get_hardware_flops(self, batch_size=1, seq_length=None, recompute=False, **kwargs):
        if seq_length is None:
            if hasattr(self.config, "seq_length"):
                seq_length = self.config.seq_length
            else:
                seq_length = 2048

        return caculate_llm_flops(
            hidden_size=self.config.hidden_size,
            intermediate_size=self.config.intermediate_size,
            layer_num=self.config.num_hidden_layers,
            vocab_size=self.config.vocab_size,
            seq_length=seq_length,
            recompute=recompute,
            recompute_granularity=self.config.recompute_granularity,
        )

    def get_input_embeddings(self):
        return self.embed_tokens
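A hedged usage sketch of the two new accessors (the dimensions below are illustrative, roughly 2B-sized, and not an official Gemma configuration): both methods delegate to the same helper, and get_hardware_flops differs only in that it also counts the extra forward work implied by recomputation.

# Comparison via the helper both methods call; dimensions are assumptions for illustration.
from paddlenlp.transformers.utils import caculate_llm_flops

common = dict(hidden_size=2048, intermediate_size=16384, layer_num=18, vocab_size=256000, seq_length=4096)

model_flops = caculate_llm_flops(recompute=False, **common)
core_attn_only = caculate_llm_flops(recompute=True, recompute_granularity="core_attn", **common)
full_recompute = caculate_llm_flops(recompute=True, recompute_granularity="full", **common)

extra_core_attn = core_attn_only - model_flops  # extra recomputed attention work per sample
extra_full = full_recompute - model_flops       # a full extra forward pass per layer
print(extra_core_attn, extra_full)              # both positive; extra_full is the larger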

34 changes: 34 additions & 0 deletions paddlenlp/transformers/gpt/modeling.py
@@ -53,6 +53,7 @@
TokenClassifierOutput,
)
from ..model_utils import dy2st_nocheck_guard_context
from ..utils import caculate_llm_flops
from .configuration import (
GPT_PRETRAINED_INIT_CONFIGURATION,
GPT_PRETRAINED_RESOURCE_FILES_MAP,
@@ -1105,6 +1106,39 @@ def __init__(self, config: GPTConfig):
            decoder_layers,
        )

    def get_model_flops(self, batch_size=1, seq_length=None, **kwargs):
        if seq_length is None:
            if hasattr(self.config, "seq_length"):
                seq_length = self.config.seq_length
            else:
                seq_length = 2048

        return caculate_llm_flops(
            hidden_size=self.config.hidden_size,
            intermediate_size=self.config.intermediate_size,
            layer_num=self.config.num_hidden_layers,
            vocab_size=self.config.vocab_size,
            seq_length=seq_length,
            recompute=False,
        )

    def get_hardware_flops(self, batch_size=1, seq_length=None, recompute=False, **kwargs):
        if seq_length is None:
            if hasattr(self.config, "seq_length"):
                seq_length = self.config.seq_length
            else:
                seq_length = 2048

        return caculate_llm_flops(
            hidden_size=self.config.hidden_size,
            intermediate_size=self.config.intermediate_size,
            layer_num=self.config.num_hidden_layers,
            vocab_size=self.config.vocab_size,
            seq_length=seq_length,
            recompute=recompute,
            recompute_granularity=self.config.recompute_granularity,
        )

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

34 changes: 34 additions & 0 deletions paddlenlp/transformers/llama/modeling.py
@@ -70,6 +70,7 @@ def swiglu(x, y=None):
from .. import linear_utils
from ..linear_utils import Linear
from ..segment_parallel_utils import ReshardLayer
from ..utils import caculate_llm_flops
from .configuration import (
LLAMA_PRETRAINED_INIT_CONFIGURATION,
LLAMA_PRETRAINED_RESOURCE_FILES_MAP,
@@ -1468,6 +1469,39 @@ def __init__(self, config: LlamaConfig):

        self.gradient_checkpointing = False

    def get_model_flops(self, batch_size=1, seq_length=None, **kwargs):
        if seq_length is None:
            if hasattr(self.config, "seq_length"):
                seq_length = self.config.seq_length
            else:
                seq_length = 2048

        return caculate_llm_flops(
            hidden_size=self.config.hidden_size,
            intermediate_size=self.config.intermediate_size,
            layer_num=self.config.num_hidden_layers,
            vocab_size=self.config.vocab_size,
            seq_length=seq_length,
            recompute=False,
        )

    def get_hardware_flops(self, batch_size=1, seq_length=None, recompute=False, **kwargs):
        if seq_length is None:
            if hasattr(self.config, "seq_length"):
                seq_length = self.config.seq_length
            else:
                seq_length = 2048

        return caculate_llm_flops(
            hidden_size=self.config.hidden_size,
            intermediate_size=self.config.intermediate_size,
            layer_num=self.config.num_hidden_layers,
            vocab_size=self.config.vocab_size,
            seq_length=seq_length,
            recompute=recompute,
            recompute_granularity=self.config.recompute_granularity,
        )

    def get_input_embeddings(self):
        return self.embed_tokens

14 changes: 14 additions & 0 deletions paddlenlp/transformers/model_utils.py
@@ -1102,6 +1102,20 @@ def get_memory_footprint(self, return_buffers=True):
            mem = mem + mem_bufs
        return mem

    def get_model_flops(self, *args, **kwargs):
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            return base_model.get_model_flops(*args, **kwargs)

        raise NotImplementedError(f"model of type {type(base_model)} has not implemented `get_model_flops`")

    def get_hardware_flops(self, *args, **kwargs):
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            return base_model.get_hardware_flops(*args, **kwargs)

        raise NotImplementedError(f"model of type {type(base_model)} has not implemented `get_hardware_flops`")
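A minimal toy sketch of the delegation pattern these fallbacks rely on (the Toy* classes are hypothetical, not PaddleNLP code): a task head exposes the same method and forwards to the base model found under base_model_prefix, while a class with no implementation raises NotImplementedError so the trainer can skip the metric.

# Toy illustration of the base_model_prefix fallback; names and numbers are assumptions.
class ToyBaseModel:
    def get_hardware_flops(self, seq_length=None, recompute=False, **kwargs):
        return 2.0e13  # made-up FLOPs per sample

class ToyPretrainedModel:
    base_model_prefix = "toy"

    def get_hardware_flops(self, *args, **kwargs):
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            return base_model.get_hardware_flops(*args, **kwargs)
        raise NotImplementedError(f"model of type {type(base_model)} has not implemented `get_hardware_flops`")

class ToyForCausalLM(ToyPretrainedModel):
    def __init__(self):
        self.toy = ToyBaseModel()  # attribute name matches base_model_prefix

print(f"{ToyForCausalLM().get_hardware_flops(seq_length=4096, recompute=True):.1e}")  # 2.0e+13, served by ToyBaseModel

Calling get_hardware_flops() on a bare ToyPretrainedModel() would raise NotImplementedError, which is exactly what the trainer hunk above catches.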

    def get_input_embeddings(self) -> nn.Embedding:
        """get input embedding of model
34 changes: 34 additions & 0 deletions paddlenlp/transformers/qwen/modeling.py
@@ -49,6 +49,7 @@ def swiglu(x, y=None):
from .. import linear_utils
from ..linear_utils import Linear
from ..model_outputs import ModelOutput
from ..utils import caculate_llm_flops
from .configuration import QWenConfig

try:
@@ -690,6 +691,39 @@ def __init__(self, config):
        )
        self.ln_f = QWenRMSNorm(config)

    def get_model_flops(self, batch_size=1, seq_length=None, **kwargs):
        if seq_length is None:
            if hasattr(self.config, "seq_length"):
                seq_length = self.config.seq_length
            else:
                seq_length = 2048

        return caculate_llm_flops(
            hidden_size=self.config.hidden_size,
            intermediate_size=self.config.intermediate_size,
            layer_num=self.config.num_hidden_layers,
            vocab_size=self.config.vocab_size,
            seq_length=seq_length,
            recompute=False,
        )

    def get_hardware_flops(self, batch_size=1, seq_length=None, recompute=False, **kwargs):
        if seq_length is None:
            if hasattr(self.config, "seq_length"):
                seq_length = self.config.seq_length
            else:
                seq_length = 2048

        return caculate_llm_flops(
            hidden_size=self.config.hidden_size,
            intermediate_size=self.config.intermediate_size,
            layer_num=self.config.num_hidden_layers,
            vocab_size=self.config.vocab_size,
            seq_length=seq_length,
            recompute=recompute,
            recompute_granularity=self.config.recompute_granularity,
        )

    def get_input_embeddings(self):
        return self.wte

34 changes: 34 additions & 0 deletions paddlenlp/transformers/qwen2/modeling.py
@@ -44,6 +44,7 @@
TokenClassifierOutput,
)
from ..model_utils import PretrainedModel, register_base_model
from ..utils import caculate_llm_flops
from .configuration import Qwen2Config

try:
@@ -914,6 +915,39 @@ def __init__(self, config: Qwen2Config):
        )
        self.norm = Qwen2RMSNorm(config)

    def get_model_flops(self, batch_size=1, seq_length=None, **kwargs):
        if seq_length is None:
            if hasattr(self.config, "seq_length"):
                seq_length = self.config.seq_length
            else:
                seq_length = 2048

        return caculate_llm_flops(
            hidden_size=self.config.hidden_size,
            intermediate_size=self.config.intermediate_size,
            layer_num=self.config.num_hidden_layers,
            vocab_size=self.config.vocab_size,
            seq_length=seq_length,
            recompute=False,
        )

    def get_hardware_flops(self, batch_size=1, seq_length=None, recompute=False, **kwargs):
        if seq_length is None:
            if hasattr(self.config, "seq_length"):
                seq_length = self.config.seq_length
            else:
                seq_length = 2048

        return caculate_llm_flops(
            hidden_size=self.config.hidden_size,
            intermediate_size=self.config.intermediate_size,
            layer_num=self.config.num_hidden_layers,
            vocab_size=self.config.vocab_size,
            seq_length=seq_length,
            recompute=recompute,
            recompute_granularity=self.config.recompute_granularity,
        )

    def get_input_embeddings(self):
        return self.embed_tokens

43 changes: 43 additions & 0 deletions paddlenlp/transformers/utils.py
@@ -958,3 +958,46 @@ def __repr__(self):
        if self.err_buf:
            msg += f"stderr: {self.err}\n"
        return msg


def caculate_llm_flops(
    hidden_size,
    intermediate_size,
    layer_num,
    vocab_size,
    batch_size=1,
    seq_length=None,
    recompute=False,
    recompute_granularity=None,
):

    # TFLOPs formula (from Equation 3 in Section 5.1 of https://arxiv.org/pdf/2104.04473.pdf).
    flops_per_transformer = 0
    flops_recompute_transformer = 0

    # qkvo matmul
    flops_qkvo_matmul = seq_length * hidden_size**2 * 4

    # [b,s,h] x [b,h,s] -> bs^2h
    # [b,s,s] x [b,s,h] -> bs^2h
    # q_states * k_states + attn_weight * v_states
    flops_core_attn = seq_length**2 * hidden_size * 2

    # swiglu: matmul + dot
    flops_ffn = seq_length * hidden_size * intermediate_size * 3 + seq_length * intermediate_size

    flops_per_transformer = flops_qkvo_matmul + flops_core_attn + flops_ffn
    if recompute:
        if recompute_granularity == "full":
            flops_recompute_transformer = flops_per_transformer
        if recompute_granularity == "full_attn":
            flops_recompute_transformer = flops_qkvo_matmul + flops_core_attn
        if recompute_granularity == "core_attn":
            flops_recompute_transformer = flops_core_attn

    # final logits
    flops_loggits = seq_length * hidden_size * vocab_size

    # 2 for the multiply-add in each matmul
    # 1x forward, 2x backward since we calculate gradients for both input_x and input_y
    return 2 * batch_size * (layer_num * (flops_per_transformer * 3 + flops_recompute_transformer) + 3 * flops_loggits)
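To make the formula concrete, here is a worked example with LLaMA-7B-like dimensions and a rough MFU-style utilization estimate of the kind the commit message refers to; the 312 TFLOPS peak is an assumed A100 bf16 figure and the throughput is invented, neither comes from this commit.

# Worked example; dimensions, peak FLOPS, and throughput are assumptions for illustration.
from paddlenlp.transformers.utils import caculate_llm_flops

flops_per_sample = caculate_llm_flops(
    hidden_size=4096,
    intermediate_size=11008,
    layer_num=32,
    vocab_size=32000,
    seq_length=4096,
    recompute=False,
)
print(f"{flops_per_sample / 1e12:.1f} TFLOPs per 4096-token sample")  # 188.8

# Rough utilization: with recompute disabled, model FLOPs and hardware FLOPs coincide.
tokens_per_second_per_device = 3500          # assumed observed throughput
peak_flops_per_device = 312e12               # assumed bf16 peak of the accelerator
achieved = flops_per_sample / 4096 * tokens_per_second_per_device
print(f"utilization ~ {achieved / peak_flops_per_device:.1%}")  # 51.7%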
