Bring up Llama 3.1 405B on two pods #68
base: flash_attention_405b
Conversation
…stributed model initialization
Instead of leaving it as a meta tensor, mark it as persistent=True
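For context, a minimal sketch of what that suggestion looks like, assuming a rotary-embedding-style module (the module name and shapes here are illustrative, not the actual Llama source):

```python
import torch
import torch.nn as nn

class RotaryEmbedding(nn.Module):  # hypothetical stand-in for the Llama rotary embedding
    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        # persistent=True puts `inv_freq` into the state dict, so a later
        # load_state_dict(..., assign=True) can materialize it. With
        # persistent=False it never appears in the state dict and stays a
        # meta tensor when the module is built on the meta device.
        self.register_buffer("inv_freq", inv_freq, persistent=True)
```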
device_attributes = xr.global_runtime_device_attributes()
num_slices = max(int(d.get('slice_index', 0)) for d in device_attributes) + 1

if num_slices > 1 and model_args.spmd_2d_sharding == 0:
Should we give the flexibility to choose whether to use the 2d sharding or not?
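One way to surface that choice explicitly rather than inferring it from the slice count; a hedged sketch (`spmd_2d_sharding` and the surrounding names come from the diff above, the 0-means-1D-FSDP convention and the message are assumptions):

```python
import torch_xla.runtime as xr

device_attributes = xr.global_runtime_device_attributes()
num_slices = max(int(d.get('slice_index', 0)) for d in device_attributes) + 1

# Assumed convention: spmd_2d_sharding == 0 means "1D FSDP only", any positive
# value is the tensor-parallel degree for 2D sharding. Multi-slice runs are
# allowed either way; the user decides, and we just report the choice.
use_2d = model_args.spmd_2d_sharding > 0
if num_slices > 1 and not use_2d:
    print(f"Running {num_slices} slices with 1D FSDP; "
          f"pass --spmd_2d_sharding > 0 to enable 2D sharding.")
```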
model = AutoModelForCausalLM.from_config(config, trust_remote_code=model_args.trust_remote_code)
# note: at this point, the mode is not materialized
NIT: typo: `mode` should be `model`.
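For reference, a minimal sketch of the deferred-initialization pattern the quoted line relies on (the checkpoint name is an assumption; only the meta-device mechanics matter here):

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("meta-llama/Llama-3.1-405B")  # assumed checkpoint name

# Building under the meta device allocates parameter shapes/dtypes only,
# with no real storage, so constructing the 405B model is cheap.
with torch.device("meta"):
    model = AutoModelForCausalLM.from_config(config)

assert next(model.parameters()).is_meta  # nothing is materialized yet
```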
This is great, can you point us to the material on how …
# Multi-slice 2D sharding
tensor_axis = model_args.spmd_2d_sharding
fsdp_axis = num_devices // tensor_axis
mesh_shape = (fsdp_axis, tensor_axis)  # Should be (128, 4)
The comment seems to only apply to 2 slices with 4-way tensor parallelism.
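A sketch of deriving the mesh shape from the runtime device count instead of relying on the hard-coded `(128, 4)` comment (names follow the diff above; the assert and the device counts in the comments are assumptions, presumably two v6e-256 slices here):

```python
import torch_xla.runtime as xr

num_devices = xr.global_runtime_device_count()
tensor_axis = model_args.spmd_2d_sharding      # e.g. 4
assert num_devices % tensor_axis == 0, "tensor_axis must divide the device count"
fsdp_axis = num_devices // tensor_axis         # 512 devices / 4 -> 128 on two 256-chip slices
mesh_shape = (fsdp_axis, tensor_axis)          # (128, 4) only for that particular topology
```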
model.load_state_dict(dict_of_params, assign=True)
model.to('xla')
Question: with the meta device, did we skip loading the model into CPU RAM and load it into the XLA device directly?
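For reference, the semantics of the two quoted lines, annotated under the assumption that `model` was built on the meta device and `dict_of_params` holds real tensors read from a checkpoint:

```python
# assign=True swaps the module's meta parameters for the tensors in
# dict_of_params rather than copying into them (copying into meta storage
# would fail), so materialization happens via assignment.
model.load_state_dict(dict_of_params, assign=True)

# After the assign, parameters live wherever dict_of_params' tensors live
# (typically CPU when read from a checkpoint file); .to('xla') then transfers
# them to the XLA device. Whether host RAM is touched at all therefore depends
# on how dict_of_params was produced, which is exactly the question above.
model.to('xla')
```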
> /workspaces/torch/transformers/<eval_with_key>.6(210)forward()
206     mul_15 = torch.ops.aten.mul.Tensor(add_7, _to_copy_3);  add_7 = _to_copy_3 = None
207     sum_1 = torch.ops.aten.sum.dim_IntList(mul_15, [0, 1], True);  mul_15 = None
208     view_58 = torch.ops.aten.view.default(sum_1, [4096]);  sum_1 = None
209     _to_copy_4 = torch.ops.aten._to_copy.default(mul_14, dtype = torch.float32, layout = torch.strided, device = device(type='xla', index=0));  mul_14 = None
210  -> mul_16 = torch.ops.aten.mul.Tensor(_to_copy_4, _to_copy_2)
211     mul_17 = torch.ops.aten.mul.Tensor(_to_copy_4, rsqrt_1);  _to_copy_4 = None
212     sum_2 = torch.ops.aten.sum.dim_IntList(mul_16, [2], True);  mul_16 = None
213     detach_29 = torch.ops.aten.detach.default(rsqrt_1);  rsqrt_1 = None
214     detach_30 = torch.ops.aten.detach.default(detach_29);  detach_29 = None

`_to_copy_2` is `meta`. `_to_copy_4` is `xla`.
This supports hardcoded XLA device names in the input graph (e.g. due to casts). We need to trace the graph with XLA devices correspondingly. Why does a cast involve devices? It looks like the ATen op set doesn't include a dedicated cast op, only `_to_copy`: https://pytorch.org/docs/stable/torch.compiler_ir.html
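A quick way to see why the device shows up: a dtype cast has no dedicated ATen op and lowers to `aten._to_copy`, whose schema bundles dtype, layout, and device, so the device the graph was traced on can get baked into what is semantically just a cast:

```python
import torch

# Printing the op schema shows that _to_copy carries an optional Device
# argument alongside dtype/layout; exact formatting may vary by version.
print(torch.ops.aten._to_copy.default._schema)
# aten::_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None,
#                Device? device=None, bool? pin_memory=None,
#                bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor
```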
This lets XLA avoid gathering the logits and saves a few GiB of RAM. Before: http://shortn/_tlV88E1Ca3 After: http://shortn/_GodQzu6GMu
Using a custom aten slice fast path in the ptxla branch. Before: http://shortn/_1Xyrm0lLdL After: http://shortn/_SgU6vG1pNm
After we trace out HLO in scan using placeholder tensors, memory usage drops enough that we can up the batch size to 16. Profile: http://shortn/_ESRGhAhKce
These are negative optimizations (probably some are wrong). On 2D sharded v6e-8, memory usage goes 26 GiB -> 22 GiB. Step time: 4.76s -> 4.71s.
What does this PR do?
These are the changes needed to make Llama 3.1 405B work on two Trillium TPU pods. It includes:
- Marking the `inv_freq` buffer in Llama as persistent for now, to ensure we can initialize it using `load_state_dict`. Otherwise, the `inv_freq` buffer would stay as a meta tensor.
- Setting the `USE_SINGLE_SLICE` env var before importing jax, to prevent jax from re-initializing the MegaScale client (see the sketch after this list). A corresponding custom libtpu build is required.
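A minimal sketch of the import-order constraint for that env var, assuming the custom libtpu build reads it when jax is imported (the exact value it expects is an assumption):

```python
import os

# Must happen before `import jax`, so jax does not re-initialize the
# MegaScale client on a process that torch_xla already set up.
os.environ["USE_SINGLE_SLICE"] = "true"  # assumed value; see the custom libtpu build

import jax        # noqa: E402  (deliberately imported after the env var is set)
import torch_xla  # noqa: E402
```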