
Bring up Llama 3.1 405B on two pods #68

Open · wants to merge 31 commits into base: flash_attention_405b

Conversation

tengyifei (Collaborator)

What does this PR do?

These are the changes needed to make Llama 3.1 405B work on two Trillium TPU pods. They include:

  • Initialize the model layer by layer on the CPU, to work around an OOM bug that occurs when initializing all of the model's layers at once on the TPU.
  • Persist the inv_freq buffer in Llama for now, so that it can be initialized via load_state_dict. Otherwise, the inv_freq buffer stays a meta tensor.
  • Add a USE_SINGLE_SLICE env var before importing jax, to prevent jax from re-initializing the MegaScale client (sketched below). A corresponding custom libtpu build is required.
  • Use a custom hybrid ring mesh to improve the performance of collectives.
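
For the env var change above, a minimal sketch of the required ordering. The variable name comes from this PR; the value "1" is an assumption, and the custom libtpu build mentioned above is still required:

import os

# USE_SINGLE_SLICE must be set before `import jax`, so that jax does not
# re-initialize the MegaScale client (per the PR description). The value "1"
# is an assumption.
os.environ.setdefault("USE_SINGLE_SLICE", "1")

import jax  # noqa: E402  (deliberately imported after setting the env var)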

# Each device reports a 'slice_index' attribute; the number of slices is the
# largest slice index plus one.
device_attributes = xr.global_runtime_device_attributes()
num_slices = max(int(d.get('slice_index', 0)) for d in device_attributes) + 1

if num_slices > 1 and model_args.spmd_2d_sharding == 0:
Collaborator


Should we give users the flexibility to choose whether to use 2D sharding or not?

model = AutoModelForCausalLM.from_config(config, trust_remote_code=model_args.trust_remote_code)
# note: at this point, the mode is not materialized
Collaborator


NIT: typo, "mode" should be "model".

zpcore (Collaborator) commented Nov 1, 2024

This is great! Can you point us to some material on how the hybrid ring mesh works? Thanks.
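
Not the PR's custom ring mesh, but for orientation: the stock hybrid-mesh helper in torch_xla builds a global mesh from a per-slice (ICI) mesh and an across-slice (DCN) mesh; the custom ring variant presumably reorders devices within those axes. The shapes and axis names below are assumptions for a 2 x v6e-256 setup:

import torch_xla.distributed.spmd as xs

# Assumed 2-slice layout: within each slice (ICI), 64-way fsdp x 4-way tensor;
# across slices (DCN), 2-way along the fsdp axis. Global mesh shape: (128, 4).
mesh = xs.HybridMesh(
    ici_mesh_shape=(64, 4),
    dcn_mesh_shape=(2, 1),
    axis_names=('fsdp', 'tensor'),
)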

# Multi-slice 2D sharding
tensor_axis = model_args.spmd_2d_sharding
fsdp_axis = num_devices // tensor_axis
mesh_shape = (fsdp_axis, tensor_axis) # Should be (128, 4)
Collaborator


The comment seems to apply only to 2 slices with 4-way tensor parallelism.
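
For reference, a generalized form of the quoted hunk using the stock SPMD mesh (the value 4 for spmd_2d_sharding and the axis names are assumptions; the PR uses a custom hybrid ring mesh instead):

import numpy as np
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

num_devices = xr.global_runtime_device_count()
tensor_axis = 4                          # assumed value of model_args.spmd_2d_sharding
fsdp_axis = num_devices // tensor_axis   # (128, 4) only holds for 512 devices with TP=4
mesh = xs.Mesh(np.arange(num_devices), (fsdp_axis, tensor_axis), ('fsdp', 'tensor'))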


model.load_state_dict(dict_of_params, assign=True)
model.to('xla')
Collaborator


Question: with the meta device, did we skip loading the model into CPU RAM and instead load it onto the XLA device directly?
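
A minimal sketch of the pattern the quoted hunk appears to rely on, assuming the model skeleton is built on the meta device and the checkpoint is loaded as CPU tensors; the model name and the checkpoint-loading helper are placeholders:

import torch
import torch_xla  # registers the 'xla' device
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained('meta-llama/Llama-3.1-405B')  # placeholder

# Build the module structure without allocating real storage for the weights.
with torch.device('meta'):
    model = AutoModelForCausalLM.from_config(config)

# `assign=True` replaces the meta parameters with the loaded tensors instead of
# copying into them (copying into meta tensors is not possible).
dict_of_params = load_checkpoint_shards()  # hypothetical helper returning CPU tensors
model.load_state_dict(dict_of_params, assign=True)

# Only now move the materialized parameters to the XLA device.
model = model.to('xla')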

> /workspaces/torch/transformers/<eval_with_key>.6(210)forward()

206         mul_15 = torch.ops.aten.mul.Tensor(add_7, _to_copy_3);  add_7 = _to_copy_3 = None
207         sum_1 = torch.ops.aten.sum.dim_IntList(mul_15, [0, 1], True);  mul_15 = None
208         view_58 = torch.ops.aten.view.default(sum_1, [4096]);  sum_1 = None
209         _to_copy_4 = torch.ops.aten._to_copy.default(mul_14, dtype = torch.float32, layout = torch.strided, device = device(type='xla', index=0));  mul_14 = None
210  ->     mul_16 = torch.ops.aten.mul.Tensor(_to_copy_4, _to_copy_2)
211         mul_17 = torch.ops.aten.mul.Tensor(_to_copy_4, rsqrt_1);  _to_copy_4 = None
212         sum_2 = torch.ops.aten.sum.dim_IntList(mul_16, [2], True);  mul_16 = None
213         detach_29 = torch.ops.aten.detach.default(rsqrt_1);  rsqrt_1 = None
214         detach_30 = torch.ops.aten.detach.default(detach_29);  detach_29 = None

_to_copy_2 is `meta`.
_to_copy_4 is `xla`.
This adds support for hardcoded XLA device names in the input graph (e.g. due
to casts). We need to trace the graph with XLA devices accordingly.

Why does a cast involve devices? It looks like the core ATen IR doesn't include
a dedicated cast op, only `_to_copy`: https://pytorch.org/docs/stable/torch.compiler_ir.html
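
A small repro of why a plain dtype cast shows up as `_to_copy` in a traced graph. The device kwarg appears when the traced tensors live on an XLA device; this sketch uses CPU tensors, so only the dtype is captured:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    # There is no dedicated cast op in the core ATen IR, so a dtype cast is
    # captured as torch.ops.aten._to_copy.default(..., dtype=torch.float32).
    return x.to(torch.float32) * 2

gm = make_fx(f)(torch.randn(4, dtype=torch.bfloat16))
print(gm.graph)  # look for aten._to_copy.default in the printed graph
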
This helps XLA avoid gathering the logits and saves a few Gigs of RAM.
Before: http://shortn/_tlV88E1Ca3
After:  http://shortn/_GodQzu6GMu

Using a custom aten slice fast path in the ptxla branch.
Before: http://shortn/_1Xyrm0lLdL
After:  http://shortn/_SgU6vG1pNm

After we trace out HLO in scan using placeholder tensors, memory usage
drops enough that we can up the batch size to 16.
Profile: http://shortn/_ESRGhAhKce

These are negative optimizations (probably some are wrong).
On 2D sharded v6e-8, memory usage goes 26 GiB -> 22 GiB.
Step time: 4.76s -> 4.71s.