From d2a036523945a8e4f7eb00e9f4a353e55bce3733 Mon Sep 17 00:00:00 2001 From: Oliver Batchelor Date: Tue, 14 May 2024 12:52:30 +1200 Subject: [PATCH 1/3] Add conversions for unsigned types, torch > 2.3.0 --- python/taichi/lang/util.py | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/python/taichi/lang/util.py b/python/taichi/lang/util.py index a853daea64a65..1a70f12b74bd4 100644 --- a/python/taichi/lang/util.py +++ b/python/taichi/lang/util.py @@ -166,8 +166,18 @@ def to_pytorch_type(dt): return torch.uint8 if dt == f16: return torch.float16 + if dt in (u16, u32, u64): - raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type.") + if hasattr(torch, "uint16"): + if dt == u16: + return torch.uint16 + if dt == u32: + return torch.uint32 + if dt == u64: + return torch.uint64 + raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type before version 2.3.0.") + + raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type.") assert False @@ -266,9 +276,18 @@ def to_taichi_type(dt): return u8 if dt == torch.float16: return f16 - if dt in (u16, u32, u64): - raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type.") - + + if hasattr(torch, "uint16"): + if dt == torch.uint16: + return u16 + if dt == torch.uint32: + return u32 + if dt == torch.uint64: + return u64 + + raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type before version 2.3.0.") + + if has_paddle(): import paddle # pylint: disable=C0415 From 15cf000d72515f8282500c89ee4242cc0d444327 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 14 May 2024 00:55:44 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- python/taichi/lang/util.py | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/python/taichi/lang/util.py b/python/taichi/lang/util.py index 1a70f12b74bd4..5c395b30229ba 100644 --- a/python/taichi/lang/util.py +++ b/python/taichi/lang/util.py @@ -166,16 +166,16 @@ def to_pytorch_type(dt): return torch.uint8 if dt == f16: return torch.float16 - + if dt in (u16, u32, u64): - if hasattr(torch, "uint16"): - if dt == u16: - return torch.uint16 - if dt == u32: - return torch.uint32 - if dt == u64: - return torch.uint64 - raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type before version 2.3.0.") + if hasattr(torch, "uint16"): + if dt == u16: + return torch.uint16 + if dt == u32: + return torch.uint32 + if dt == u64: + return torch.uint64 + raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type before version 2.3.0.") raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type.") assert False @@ -276,18 +276,17 @@ def to_taichi_type(dt): return u8 if dt == torch.float16: return f16 - + if hasattr(torch, "uint16"): - if dt == torch.uint16: - return u16 - if dt == torch.uint32: - return u32 - if dt == torch.uint64: - return u64 + if dt == torch.uint16: + return u16 + if dt == torch.uint32: + return u32 + if dt == torch.uint64: + return u64 raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type before version 2.3.0.") - if has_paddle(): import paddle # pylint: disable=C0415 From aaf9529c78faed57b6b78ee543e693acc5f3d251 Mon Sep 17 00:00:00 2001 From: Bob Cao Date: Sat, 22 Jun 2024 17:05:57 -0700 Subject: [PATCH 3/3] Update python/taichi/lang/util.py --- python/taichi/lang/util.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/taichi/lang/util.py b/python/taichi/lang/util.py index 5c395b30229ba..071209fe2e957 100644 --- a/python/taichi/lang/util.py +++ b/python/taichi/lang/util.py @@ -178,7 +178,6 @@ def to_pytorch_type(dt): raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type before version 2.3.0.") raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type.") - assert False def to_paddle_type(dt):