Skip to content

Commit

Permalink
[Unity][UnitTest] Enable BindParams test for R.Prim (#15978)
Browse files Browse the repository at this point in the history
This test was implemented in #15626,
but was initially disabled as it depended on functionality not
introduced until #15577.  Since that
PR has landed, cleaning up and enabling the unit test.
  • Loading branch information
Lunderberg authored Oct 30, 2023
1 parent 6936829 commit d932608
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions tests/python/relax/test_bind_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,21 +111,23 @@ def expected() -> R.Shape([16]):
prim_value_dtype = tvm.testing.parameter("int64", "int32", "float32")


@pytest.mark.xfail(reason="Depends on relax.PrimValue holding a tir.PrimExpr, PR#15577")
def test_bind_prim_value(prim_value_dtype):
N = tir.Var("N", prim_value_dtype)
value = tir.const(16, prim_value_dtype)

@R.function
def before(A: R.Prim(value="N", dtype=prim_value_dtype)):
def before(A: R.Prim(value=N)):
R.func_attr({"global_symbol": "main"})
B: R.Prim(value="N", dtype=prim_value_dtype) = A
B: R.Prim(value=N) = A
return B

@R.function
def expected() -> R.Prim(value=16, dtype=prim_value_dtype):
def expected() -> R.Prim(value=value):
R.func_attr({"global_symbol": "main"})
B = R.PrimValue(value=16, dtype=dtype)
B = R.prim_value(value)
return B

after = before.bind_params({"A": relax.PrimValue(tir.const(16, prim_value_dtype))})
after = before.bind_params({"A": relax.PrimValue(value)})

tvm.ir.assert_structural_equal(expected, after)

Expand Down

0 comments on commit d932608

Please sign in to comment.