From 7edaee71a9583f8f1eedf91187aa77a0797c38e0 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Tue, 16 Jul 2024 11:30:20 -0700 Subject: [PATCH] allow matrix_to_quaternion onnx export Summary: Attempt to allow torch.onnx.dynamo_export(matrix_to_quaternion) to work. Differential Revision: D59812279 fbshipit-source-id: 4497e5b543bec9d5c2bdccfb779d154750a075ad --- pytorch3d/transforms/rotation_conversions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pytorch3d/transforms/rotation_conversions.py b/pytorch3d/transforms/rotation_conversions.py index 0164e413..a9fcae22 100644 --- a/pytorch3d/transforms/rotation_conversions.py +++ b/pytorch3d/transforms/rotation_conversions.py @@ -97,7 +97,10 @@ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: """ ret = torch.zeros_like(x) positive_mask = x > 0 - ret[positive_mask] = torch.sqrt(x[positive_mask]) + if torch.is_grad_enabled(): + ret[positive_mask] = torch.sqrt(x[positive_mask]) + else: + ret = torch.where(positive_mask, torch.sqrt(x), ret) return ret