
add e5m2 & e4m3fn to nvfuserex_impl dtype map #1551

Merged: 1 commit into main on Dec 16, 2024

Conversation

crcrpar (Collaborator) commented Dec 13, 2024

What does this PR do?

As per the title: add the float8_e5m2 and float8_e4m3fn dtypes to the nvfuserex_impl dtype map, so that the nvFuser executor can claim operations that produce fp8 tensors.
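A minimal sketch of the idea (not the exact diff from this PR; the map name and the nvFuser DataType member names below are assumptions about thunder/executors/nvfuserex_impl.py and the nvfuser Python API):

# Hedged sketch: extend the Thunder-dtype -> nvFuser-DataType map with the two
# fp8 dtypes so fp8-producing ops can be claimed by the nvFuser executor.
# Names are illustrative assumptions, not copied from the PR.
import thunder.core.dtypes as dtypes
from nvfuser import DataType

lcdtype_to_nvdtype_map = {
    # ... existing entries (float32, bfloat16, ...) ...
    dtypes.float8_e5m2: DataType.Float8_e5m2,
    dtypes.float8_e4m3fn: DataType.Float8_e4m3fn,
}

The repro below runs a single fp8 linear layer (converted with torchao's convert_to_float8_training) through thunder.jit and prints the relevant traces: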

import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

import thunder
from thunder.tests.make_tensor import make_tensor


def main():
    batch_size, in_features, out_features = 16, 32, 64
    device = torch.device("cuda")
    dtype = torch.bfloat16
    bias = False

    model = nn.Sequential(
        nn.Linear(in_features, out_features, bias=bias),
        # nn.GELU(approximate="tanh"),
        # nn.Linear(out_features, out_features, bias=bias),
    ).to(device=device, dtype=dtype)
    fp8_model = convert_to_float8_training(model)
    x = make_tensor((batch_size, in_features), device=device, dtype=dtype)

    jitted = thunder.jit(fp8_model)
    y = jitted(x)

    # Print the final execution trace and any trace whose provenance mentions
    # the tensor-subclass transform.
    for i, t in enumerate(traces := thunder.last_traces(jitted)):
        if i == len(traces) - 1:
            print(t)
        if t._provenance is not None and ("subclass" in t._provenance.pss or "Subclass" in t._provenance.pss):
            print(t)


if __name__ == "__main__":
    main()

For this single fp8 linear layer, #1415 produces the following extrace (execution trace):

# Constructed by Delete Last Used (took 0 milliseconds)
import thunder.core.devices as devices
import thunder.core.dtypes as dtypes
from torch import Tensor
import torch
from torchao.float8.float8_tensor import Float8Tensor
from torchao.float8.float8_tensor import ScaledMMConfig
from torchao.float8.float8_tensor import GemmInputRole
from torchao.float8.float8_tensor import LinearMMConfig
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(input, weight):
  # input: "cuda:0 bf16[16, 32]"
  # weight: "cuda:0 bf16[64, 32]"
  [scale, t247, t196, t248, t222] = nvFusion0(input, weight)
    # t173 = prims.convert_element_type(input, dtypes.float32)  # t173: "cuda:0 f32[16, 32]"
    # t174 = prims.abs(t173)  # t174: "cuda:0 f32[16, 32]"
    # t199 = prims.convert_element_type(weight, dtypes.float32)  # t199: "cuda:0 f32[64, 32]"
    # t200 = prims.abs(t199)  # t200: "cuda:0 f32[64, 32]"
    # t177 = prims.amax(t174, (0, 1))  # t177: "cuda:0 f32[]"
    # t203 = prims.amax(t200, (0, 1))  # t203: "cuda:0 f32[]"
    # t9 = prims.convert_element_type(t177, dtypes.float64)  # t9: "cuda:0 f64[]"
    # t64 = prims.convert_element_type(t203, dtypes.float64)  # t64: "cuda:0 f64[]"
    # t180 = prims.ne(t9, t9)  # t180: "cuda:0 b8[]"
    # t181 = prims.gt(t9, 1e-12)  # t181: "cuda:0 b8[]"
    # t182 = prims.where(t181, t9, 1e-12)  # t182: "cuda:0 f64[]"
    # t14 = prims.where(t180, t9, t182)  # t14: "cuda:0 f64[]"
    # t206 = prims.ne(t64, t64)  # t206: "cuda:0 b8[]"
    # t207 = prims.gt(t64, 1e-12)  # t207: "cuda:0 b8[]"
    # t208 = prims.where(t207, t64, 1e-12)  # t208: "cuda:0 f64[]"
    # t68 = prims.where(t206, t64, t208)  # t68: "cuda:0 f64[]"
    # res = prims.div(448.0, t14)  # res: "cuda:0 f64[]"
    # scale = prims.convert_element_type(res, dtypes.float32)  # scale: "cuda:0 f32[]"
    # t187 = prims.broadcast_in_dim(scale, (16, 32), ())  # t187: "cuda:0 f32[16, 32]"
    # t188 = prims.mul(t173, t187)  # t188: "cuda:0 f32[16, 32]"
    # t69 = prims.div(448.0, t68)  # t69: "cuda:0 f64[]"
    # weight_scale = prims.convert_element_type(t69, dtypes.float32)  # weight_scale: "cuda:0 f32[]"
    # t213 = prims.broadcast_in_dim(weight_scale, (64, 32), ())  # t213: "cuda:0 f32[64, 32]"
    # t214 = prims.mul(t199, t213)  # t214: "cuda:0 f32[64, 32]"
    # t247 = prims.reciprocal(scale)  # t247: "cuda:0 f32[]"
    # t189 = prims.ne(t188, t188)  # t189: "cuda:0 b8[16, 32]"
    # t190 = prims.gt(t188, -448.0)  # t190: "cuda:0 b8[16, 32]"
    # t191 = prims.where(t190, t188, -448.0)  # t191: "cuda:0 f32[16, 32]"
    # t192 = prims.where(t189, t188, t191)  # t192: "cuda:0 f32[16, 32]"
    # t193 = prims.ne(t192, t192)  # t193: "cuda:0 b8[16, 32]"
    # t194 = prims.lt(t192, 448.0)  # t194: "cuda:0 b8[16, 32]"
    # t195 = prims.where(t194, t192, 448.0)  # t195: "cuda:0 f32[16, 32]"
    # t196 = prims.where(t193, t192, t195)  # t196: "cuda:0 f32[16, 32]"
    # t248 = prims.reciprocal(weight_scale)  # t248: "cuda:0 f32[]"
    # t215 = prims.ne(t214, t214)  # t215: "cuda:0 b8[64, 32]"
    # t216 = prims.gt(t214, -448.0)  # t216: "cuda:0 b8[64, 32]"
    # t217 = prims.where(t216, t214, -448.0)  # t217: "cuda:0 f32[64, 32]"
    # t218 = prims.where(t215, t214, t217)  # t218: "cuda:0 f32[64, 32]"
    # t219 = prims.ne(t218, t218)  # t219: "cuda:0 b8[64, 32]"
    # t220 = prims.lt(t218, 448.0)  # t220: "cuda:0 b8[64, 32]"
    # t221 = prims.where(t220, t218, 448.0)  # t221: "cuda:0 f32[64, 32]"
    # t222 = prims.where(t219, t218, t221)  # t222: "cuda:0 f32[64, 32]"

  # /opt/pytorch/lightning-thunder/thunder/core/proxies.py:1966: 	                self.requires_grad,
  t197 = Tensor.to(t196, copy=False, dtype=torch.float8_e4m3fn)  # t197: "cuda:0 f8_e4m3fn[16, 32]"
    # t197 = ltorch.to(t196, None, None, device=None, dtype=torch.float8_e4m3fn, copy=False, memory_format=None)  # t197: "cuda:0 f8_e4m3fn[16, 32]"
      # t197 = prims.convert_element_type(t196, dtypes.float8_e4m3fn)  # t197: "cuda:0 f8_e4m3fn[16, 32]"
  del t196

  # /opt/pytorch/lightning-thunder/thunder/core/proxies.py:1966: 	                self.requires_grad,
  t223 = Tensor.to(t222, copy=False, dtype=torch.float8_e4m3fn)  # t223: "cuda:0 f8_e4m3fn[64, 32]"
    # t223 = ltorch.to(t222, None, None, device=None, dtype=torch.float8_e4m3fn, copy=False, memory_format=None)  # t223: "cuda:0 f8_e4m3fn[64, 32]"
      # t223 = prims.convert_element_type(t222, dtypes.float8_e4m3fn)  # t223: "cuda:0 f8_e4m3fn[64, 32]"
  del t222
  t241 = Tensor.view(t197, [-1, 32])  # t241: "cuda:0 f8_e4m3fn[16, 32]"
    # t241 = ltorch.view(t197, [-1, 32])  # t241: "cuda:0 f8_e4m3fn[16, 32]"
      # t241 = ltorch.reshape(t197, [-1, 32])  # t241: "cuda:0 f8_e4m3fn[16, 32]"
        # t241 = prims.reshape(t197, (16, 32))  # t241: "cuda:0 f8_e4m3fn[16, 32]"

  # /opt/pytorch/lightning-thunder/thunder/core/proxies.py:1966: 	                self.requires_grad,
  input_fp8 = Float8Tensor(t197, scale, torch.bfloat16, LinearMMConfig(output=ScaledMMConfig(emulate=False, use_fast_accum=True, fp8_output=False, pad_inner_dim=False), grad_input=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False), grad_weight=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False)), _GemmInputRole_5, None)  # input_fp8: "cuda:0 bf16[16, 32]"
  del t197, scale
  t234 = torch.transpose(t223, 0, 1)  # t234: "cuda:0 f8_e4m3fn[32, 64]"
    # t234 = ltorch.transpose(t223, 0, 1)  # t234: "cuda:0 f8_e4m3fn[32, 64]"
      # t234 = prims.transpose(t223, (1, 0))  # t234: "cuda:0 f8_e4m3fn[32, 64]"
  del t223
  t244 = torch.transpose(t234, 0, 1)  # t244: "cuda:0 f8_e4m3fn[64, 32]"
    # t244 = ltorch.transpose(t234, 0, 1)  # t244: "cuda:0 f8_e4m3fn[64, 32]"
      # t244 = prims.transpose(t234, (1, 0))  # t244: "cuda:0 f8_e4m3fn[64, 32]"
  del t234
  t245 = torch.clone(t244)  # t245: "cuda:0 f8_e4m3fn[64, 32]"
    # t245 = ltorch.clone(t244, memory_format=_torch_memory_format_6)  # t245: "cuda:0 f8_e4m3fn[64, 32]"
      # t245 = prims.clone(t244)  # t245: "cuda:0 f8_e4m3fn[64, 32]"
  del t244
  t246 = torch.transpose(t245, 0, 1)  # t246: "cuda:0 f8_e4m3fn[32, 64]"
    # t246 = ltorch.transpose(t245, 0, 1)  # t246: "cuda:0 f8_e4m3fn[32, 64]"
      # t246 = prims.transpose(t245, (1, 0))  # t246: "cuda:0 f8_e4m3fn[32, 64]"
  del t245
  t249 = torch._scaled_mm(t241, t246, t247, t248, None, None, torch.bfloat16, True)  # t249: "cuda:0 bf16[16, 64]"
  del t241, t246, t247, t248

  # /usr/local/lib/python3.12/dist-packages/torchao/float8/float8_linear.py:104: 	        return grad_input, grad_weight.t()
  t122 = shallow_copy(t249)  # t122: "cuda:0 bf16[16, 64]"
  del t249
  return {'output': (t122,), 'flat_args': [input, weight], 'flat_output': (t122,)}, ((input_fp8,), ())

With this pull request on top of #1415, the extrace becomes:

# Constructed by Delete Last Used (took 0 milliseconds)
import thunder.core.devices as devices
import thunder.core.dtypes as dtypes
import torch
from torchao.float8.float8_tensor import Float8Tensor
from torchao.float8.float8_tensor import ScaledMMConfig
from torchao.float8.float8_tensor import GemmInputRole
from torchao.float8.float8_tensor import LinearMMConfig
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(input, weight):
  # input: "cuda:0 bf16[16, 32]"
  # weight: "cuda:0 bf16[64, 32]"
  [scale, t247, t248, t197, t241, t244] = nvFusion0(input, weight)
    # t173 = prims.convert_element_type(input, dtypes.float32)  # t173: "cuda:0 f32[16, 32]"
    # t174 = prims.abs(t173)  # t174: "cuda:0 f32[16, 32]"
    # t199 = prims.convert_element_type(weight, dtypes.float32)  # t199: "cuda:0 f32[64, 32]"
    # t200 = prims.abs(t199)  # t200: "cuda:0 f32[64, 32]"
    # t177 = prims.amax(t174, (0, 1))  # t177: "cuda:0 f32[]"
    # t203 = prims.amax(t200, (0, 1))  # t203: "cuda:0 f32[]"
    # t9 = prims.convert_element_type(t177, dtypes.float64)  # t9: "cuda:0 f64[]"
    # t64 = prims.convert_element_type(t203, dtypes.float64)  # t64: "cuda:0 f64[]"
    # t180 = prims.ne(t9, t9)  # t180: "cuda:0 b8[]"
    # t181 = prims.gt(t9, 1e-12)  # t181: "cuda:0 b8[]"
    # t182 = prims.where(t181, t9, 1e-12)  # t182: "cuda:0 f64[]"
    # t14 = prims.where(t180, t9, t182)  # t14: "cuda:0 f64[]"
    # t206 = prims.ne(t64, t64)  # t206: "cuda:0 b8[]"
    # t207 = prims.gt(t64, 1e-12)  # t207: "cuda:0 b8[]"
    # t208 = prims.where(t207, t64, 1e-12)  # t208: "cuda:0 f64[]"
    # t68 = prims.where(t206, t64, t208)  # t68: "cuda:0 f64[]"
    # res = prims.div(448.0, t14)  # res: "cuda:0 f64[]"
    # scale = prims.convert_element_type(res, dtypes.float32)  # scale: "cuda:0 f32[]"
    # t187 = prims.broadcast_in_dim(scale, (16, 32), ())  # t187: "cuda:0 f32[16, 32]"
    # t188 = prims.mul(t173, t187)  # t188: "cuda:0 f32[16, 32]"
    # t69 = prims.div(448.0, t68)  # t69: "cuda:0 f64[]"
    # weight_scale = prims.convert_element_type(t69, dtypes.float32)  # weight_scale: "cuda:0 f32[]"
    # t213 = prims.broadcast_in_dim(weight_scale, (64, 32), ())  # t213: "cuda:0 f32[64, 32]"
    # t214 = prims.mul(t199, t213)  # t214: "cuda:0 f32[64, 32]"
    # t247 = prims.reciprocal(scale)  # t247: "cuda:0 f32[]"
    # t189 = prims.ne(t188, t188)  # t189: "cuda:0 b8[16, 32]"
    # t190 = prims.gt(t188, -448.0)  # t190: "cuda:0 b8[16, 32]"
    # t191 = prims.where(t190, t188, -448.0)  # t191: "cuda:0 f32[16, 32]"
    # t192 = prims.where(t189, t188, t191)  # t192: "cuda:0 f32[16, 32]"
    # t193 = prims.ne(t192, t192)  # t193: "cuda:0 b8[16, 32]"
    # t194 = prims.lt(t192, 448.0)  # t194: "cuda:0 b8[16, 32]"
    # t195 = prims.where(t194, t192, 448.0)  # t195: "cuda:0 f32[16, 32]"
    # t196 = prims.where(t193, t192, t195)  # t196: "cuda:0 f32[16, 32]"
    # t248 = prims.reciprocal(weight_scale)  # t248: "cuda:0 f32[]"
    # t215 = prims.ne(t214, t214)  # t215: "cuda:0 b8[64, 32]"
    # t216 = prims.gt(t214, -448.0)  # t216: "cuda:0 b8[64, 32]"
    # t217 = prims.where(t216, t214, -448.0)  # t217: "cuda:0 f32[64, 32]"
    # t218 = prims.where(t215, t214, t217)  # t218: "cuda:0 f32[64, 32]"
    # t219 = prims.ne(t218, t218)  # t219: "cuda:0 b8[64, 32]"
    # t220 = prims.lt(t218, 448.0)  # t220: "cuda:0 b8[64, 32]"
    # t221 = prims.where(t220, t218, 448.0)  # t221: "cuda:0 f32[64, 32]"
    # t222 = prims.where(t219, t218, t221)  # t222: "cuda:0 f32[64, 32]"
    # t197 = prims.convert_element_type(t196, dtypes.float8_e4m3fn)  # t197: "cuda:0 f8_e4m3fn[16, 32]"
    # t223 = prims.convert_element_type(t222, dtypes.float8_e4m3fn)  # t223: "cuda:0 f8_e4m3fn[64, 32]"
    # t234 = prims.transpose(t223, (1, 0))  # t234: "cuda:0 f8_e4m3fn[32, 64]"
    # t241 = prims.reshape(t197, (16, 32))  # t241: "cuda:0 f8_e4m3fn[16, 32]"
    # t244 = prims.transpose(t234, (1, 0))  # t244: "cuda:0 f8_e4m3fn[64, 32]"

  # /opt/pytorch/lightning-thunder/thunder/core/proxies.py:1966: 	                self.requires_grad,
  input_fp8 = Float8Tensor(t197, scale, torch.bfloat16, LinearMMConfig(output=ScaledMMConfig(emulate=False, use_fast_accum=True, fp8_output=False, pad_inner_dim=False), grad_input=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False), grad_weight=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False)), _GemmInputRole_5, None)  # input_fp8: "cuda:0 bf16[16, 32]"
  del t197, scale
  t245 = torch.clone(t244)  # t245: "cuda:0 f8_e4m3fn[64, 32]"
    # t245 = ltorch.clone(t244, memory_format=_torch_memory_format_6)  # t245: "cuda:0 f8_e4m3fn[64, 32]"
      # t245 = prims.clone(t244)  # t245: "cuda:0 f8_e4m3fn[64, 32]"
  del t244
  [t246] = nvFusion1(t245)
    # t246 = prims.transpose(t245, (1, 0))  # t246: "cuda:0 f8_e4m3fn[32, 64]"
  del t245
  t249 = torch._scaled_mm(t241, t246, t247, t248, None, None, torch.bfloat16, True)  # t249: "cuda:0 bf16[16, 64]"
  del t241, t246, t247, t248

  # /usr/local/lib/python3.12/dist-packages/torchao/float8/float8_linear.py:104: 	        return grad_input, grad_weight.t()
  t122 = shallow_copy(t249)  # t122: "cuda:0 bf16[16, 64]"
  del t249
  return {'output': (t122,), 'flat_args': [input, weight], 'flat_output': (t122,)}, ((input_fp8,), ())

If the nvFuser executor were able to also claim the clone and the fp8 matmul, then by postponing the creation of input_fp8 there would be only one nvFuser region.
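For reference, a rough way to count the nvFuser regions in the extrace from the repro above (this assumes the execution trace is the last entry of thunder.last_traces and that fusion symbols are named nvFusion<N>; the trace API may differ across Thunder versions):

# Count nvFusion regions in the final execution trace of the repro above.
extrace = thunder.last_traces(jitted)[-1]
num_fusions = sum(
    bsym.sym.name.startswith("nvFusion") for bsym in extrace.bound_symbols
)
print(num_fusions)  # 2 with this PR; 1 if the clone and fp8 matmul could also be claimed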

mruberry (Collaborator) commented Dec 13, 2024

Exciting!

I added @riccardofelluga, @jjsjann123, @kshitij12345, and @kevinstephano in case they have some thoughts.

@crcrpar force-pushed the crpa/add_nvfuser_fp8s branch 2 times, most recently from 0c59273 to 100fbb6 on December 13, 2024 19:15
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
@crcrpar force-pushed the crpa/add_nvfuser_fp8s branch from 100fbb6 to 9364572 on December 16, 2024 06:30
@mruberry enabled auto-merge (squash) on December 16, 2024 16:06
mruberry (Collaborator) left a comment


@mruberry merged commit a210fba into main on Dec 16, 2024 (41 checks passed)
@mruberry deleted the crpa/add_nvfuser_fp8s branch on December 16, 2024 16:06