Skip to content

Commit

Permalink
SpecDB: Add spec: scatter.value
Browse files Browse the repository at this point in the history
Reviewed By: JacobSzwejbka

Differential Revision: D61874510

fbshipit-source-id: 11622265cc0fda1b75104697b7cc4443fe6aeca5
  • Loading branch information
manuelcandales authored and facebook-github-bot committed Aug 27, 2024
1 parent b8d1403 commit 4c7affb
Showing 1 changed file with 72 additions and 0 deletions.
72 changes: 72 additions & 0 deletions specdb/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3478,6 +3478,78 @@
],
outspec=[OutArg(ArgType.Tensor)],
),
Spec( # TODO(mcandales): Calibrate.
op="scatter.value", # (Tensor self, int dim, Tensor index, Scalar value) -> Tensor
inspec=[
InPosArg(ArgType.Tensor, name="self"),
InPosArg(
ArgType.Dim,
name="dim",
deps=[0],
constraints=[
cp.Value.In(lambda deps: fn.dim_non_zero_size(deps[0])),
],
),
InPosArg(
ArgType.Tensor,
name="index",
deps=[0, 1],
# TODO(mcandales) Handle index.numel() == 0 case
constraints=[
cp.Dtype.Eq(lambda deps: torch.long),
cp.Rank.Eq(
lambda deps: deps[0].dim() if deps[0].dim() >= 2 else None
),
cp.Rank.In(
lambda deps: [0, 1] if deps[0].dim() in [0, 1] else None
),
cp.Size.Le(
lambda deps, r, d: (
fn.safe_size(deps[0], d)
if d != fn.normalize(deps[1], deps[0].dim())
else None
)
),
cp.Value.Ge(lambda deps, dtype, struct: 0),
cp.Value.Le(
lambda deps, dtype, struct: (
0
if deps[0].dim() == 0
else max(0, fn.safe_size(deps[0], deps[1]) - 1)
)
),
],
),
InPosArg(
ArgType.Scalar,
name="value",
deps=[0],
constraints=[
cp.Value.NotIn(
lambda deps, dtype: (
[float("-inf"), float("inf")]
if deps[0].dtype not in dt._floating
else None
)
),
cp.Value.Ge(
lambda deps, dtype: fn.dtype_lower_bound(deps[0].dtype)
),
cp.Value.Le(
lambda deps, dtype: fn.dtype_upper_bound(deps[0].dtype)
),
],
),
],
outspec=[
OutArg(
ArgType.Tensor,
constraints=[
cp.Dtype.Eq(lambda deps: deps[0].dtype),
],
),
],
),
Spec( # TODO(mcandales): Calibrate.
op="scatter_add.default", # (Tensor self, int dim, Tensor index, Tensor src) -> Tensor
inspec=[
Expand Down

0 comments on commit 4c7affb

Please sign in to comment.