Skip to content

Commit

Permalink
nvfp8 with Hopper+ check
Browse files Browse the repository at this point in the history
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
  • Loading branch information
crcrpar committed Dec 13, 2024
1 parent 673bdb9 commit 0c59273
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions thunder/executors/nvfuserex_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,17 @@
}


# Translation table from Thunder float8 dtypes to nvFuser's ``DataType`` enum.
# Both variants of each fp8 dtype (with and without the trailing underscore)
# map to the same nvFuser DataType.
_lcfp8_to_nvfp8_map: dict[dtypes.dtype, DataType] = {
    dtypes.float8_e5m2: DataType.Float8_e5m2,
    dtypes.float8_e5m2_: DataType.Float8_e5m2,
    dtypes.float8_e4m3fn: DataType.Float8_e4m3fn,
    dtypes.float8_e4m3fn_: DataType.Float8_e4m3fn,
}


# Fold the fp8 entries into the general dtype-translation table so that
# lookups through ``_lcdtype_to_nvdtype_map`` (e.g. lcdtype_to_nvdtype)
# resolve fp8 dtypes as well.  Must run after the base map is defined above.
_lcdtype_to_nvdtype_map.update(_lcfp8_to_nvfp8_map)


def lcdtype_to_nvdtype(lcdtype: type | dtypes.dtype) -> DataType:
    """Translate a Thunder dtype (or Python type) into nvFuser's ``DataType``.

    Raises:
        KeyError: if ``lcdtype`` has no nvFuser equivalent registered in
            ``_lcdtype_to_nvdtype_map``.
    """
    nvdtype = _lcdtype_to_nvdtype_map[lcdtype]
    return nvdtype

Expand Down Expand Up @@ -144,7 +155,14 @@ def is_supported_devicetype(devicetype: DeviceType) -> bool:
return devicetype is DeviceType.CUDA


_low_precision_floats = (dtypes.float16, dtypes.float16_, dtypes.bfloat16, dtypes.bfloat16_)
# Float dtypes considered "low precision": fp16 and bf16 (both variants)
# plus every fp8 dtype from the fp8 translation table.  Membership in this
# tuple is what ``allow_low_precision_floats=False`` filters on.
_low_precision_floats = (dtypes.float16, dtypes.float16_, dtypes.bfloat16, dtypes.bfloat16_) + tuple(
    _lcfp8_to_nvfp8_map.keys()
)


def device_supports_fp8() -> bool:
    """Return True when the current CUDA device is Hopper (SM 9.x) or newer.

    NOTE(review): gating on compute-capability major >= 9 deliberately
    excludes Ada (SM 8.9) even though Ada has fp8 tensor-core support —
    confirm that Hopper-only is the intended policy here.
    """
    major, _minor = torch.cuda.get_device_capability()
    return major >= 9


def is_supported_dtype(dtype: type | dtypes.dtype, *, allow_low_precision_floats: bool = True) -> bool:
Expand All @@ -154,7 +172,7 @@ def is_supported_dtype(dtype: type | dtypes.dtype, *, allow_low_precision_floats
if dtype in _low_precision_floats:
return False

return dtype in _lcdtype_to_nvdtype_map
return dtype in _lcdtype_to_nvdtype_map and (device_supports_fp8() if dtype in _lcfp8_to_nvfp8_map else True)


def is_supported_tensor(a: TensorProxy, *, allow_low_precision_floats: bool = True) -> bool:
Expand Down

0 comments on commit 0c59273

Please sign in to comment.