Skip to content

Commit

Permalink
fix issue#3269: unwrap tensor shape without opt val (#3279)
Browse files Browse the repository at this point in the history
  • Loading branch information
lanluo-nvidia authored Nov 5, 2024
1 parent 8e2c82d commit aa36e9a
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
3 changes: 2 additions & 1 deletion py/torch_tensorrt/dynamo/partitioning/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def construct_dynamic_input(
if isinstance(dim, torch.SymInt):
min_max_opt = extract_var_range_info(dim)
min_shape.append(min_max_opt["min"])
opt_shape.append(min_max_opt["opt"])
# opt might not exist
opt_shape.append(min_max_opt.get("opt"))
max_shape.append(min_max_opt["max"])
else:
min_shape.append(dim)
Expand Down
7 changes: 4 additions & 3 deletions py/torch_tensorrt/dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import sympy
import tensorrt as trt
import torch
from torch._subclasses.fake_tensor import FakeTensor
Expand Down Expand Up @@ -342,14 +343,14 @@ def extract_var_range_info(symbolic_integer: torch.SymInt) -> Dict[str, int]:
shape_env.var_to_val
)
assert var_range, var_val
min_val, max_val, opt_val = int(var_range.lower), int(var_range.upper), int(var_val)
min_val, max_val = int(var_range.lower), int(var_range.upper)
# Torchdynamo 0/1 specialization outlier
min_val = 1 if min_val == 2 else min_val
min_max_opt = {}
min_max_opt["min"] = min_val
min_max_opt["max"] = max_val
min_max_opt["opt"] = opt_val

if isinstance(var_val, sympy.core.numbers.Integer):
min_max_opt["opt"] = int(var_val)
return min_max_opt


Expand Down

0 comments on commit aa36e9a

Please sign in to comment.