From e4b0bf09c7a27ef513aa4187852c4d78f5ac0e74 Mon Sep 17 00:00:00 2001 From: Oliver Batchelor Date: Sun, 23 Jun 2024 16:02:49 +1200 Subject: [PATCH] [misc] Add conversions for unsigned types, torch > 2.3.0 (#8528) ### Brief Summary pytorch 2.3.0 now has unsigned datatypes, add conversions for those from taichi unsigned types. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Bob Cao --- python/taichi/lang/util.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/python/taichi/lang/util.py b/python/taichi/lang/util.py index a853daea64a65..071209fe2e957 100644 --- a/python/taichi/lang/util.py +++ b/python/taichi/lang/util.py @@ -166,9 +166,18 @@ def to_pytorch_type(dt): return torch.uint8 if dt == f16: return torch.float16 + if dt in (u16, u32, u64): - raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type.") - assert False + if hasattr(torch, "uint16"): + if dt == u16: + return torch.uint16 + if dt == u32: + return torch.uint32 + if dt == u64: + return torch.uint64 + raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type before version 2.3.0.") + + raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type.") def to_paddle_type(dt): @@ -266,8 +275,16 @@ def to_taichi_type(dt): return u8 if dt == torch.float16: return f16 - if dt in (u16, u32, u64): - raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type.") + + if hasattr(torch, "uint16"): + if dt == torch.uint16: + return u16 + if dt == torch.uint32: + return u32 + if dt == torch.uint64: + return u64 + + raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type before version 2.3.0.") if has_paddle(): import paddle # pylint: disable=C0415