Impossible to generate random numbers with constexpr seed #3390

Closed
GuillaumeLeclerc opened this issue Mar 15, 2024 · 8 comments · Fixed by #3396

@GuillaumeLeclerc

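import torch as ch  # ch is assumed to be the torch alias (ch.zeros, .cuda())
import triton
import triton.language as tl

@triton.jit
def blop(
    output,
    RAND_SIZE: tl.constexpr,
    SEED: tl.constexpr
):
    o = tl.arange(0, RAND_SIZE) + tl.program_id(axis=0)
    random_values = tl.rand(SEED, o)  # fails here because SEED is a tl.constexpr


result = ch.zeros(10, dtype=ch.int32).cuda()
blop[(1,)](result, RAND_SIZE=16, SEED=0)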

This fails with:

AttributeError("'constexpr' object has no attribute 'to'")

@jlebar
Collaborator

jlebar commented Mar 15, 2024

cc @manman-ren

@manman-ren
Collaborator

I will try to triage this later this afternoon. If anyone wants to take a look before then, feel free!

@manman-ren
Collaborator

This "help wanted" label looks new :] @jlebar

@manman-ren
Collaborator

I can reproduce this at top of trunk. With a simple fix:

-def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
+def philox(seedin, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
+    seed = seedin

it works. The error occurs because we can't apply "to(tl.uint64)" to a constexpr. The fix (which I am not sure is correct) copies the constexpr into another variable and applies "to(tl.uint64)" to that variable instead.
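A toy, pure-Python model of the failure mode, for readers unfamiliar with the frontend. The classes `constexpr` and `Tensor` and the helper `philox_seed_cast` are hypothetical stand-ins, not Triton's real classes: a constexpr just wraps a Python value and has no `.to` method, while a tensor does. The model does not explain why the copy in the diff above helps; it only shows why `.to` fails on a constexpr and succeeds on a tensor.

```python
class constexpr:
    """Hypothetical stand-in for a compile-time constant: it only wraps a Python value."""

    def __init__(self, value):
        self.value = value


class Tensor:
    """Hypothetical stand-in for a runtime tensor that supports dtype casts."""

    def __init__(self, value, dtype="int32"):
        self.value = value
        self.dtype = dtype

    def to(self, dtype):
        # Return a new tensor tagged with the requested dtype; the value is unchanged.
        return Tensor(self.value, dtype)


def philox_seed_cast(seed):
    # Mirrors the failing step in philox: the seed is cast with .to(...)
    return seed.to("uint64")


try:
    philox_seed_cast(constexpr(0))
except AttributeError as err:
    print(err)  # 'constexpr' object has no attribute 'to'

print(philox_seed_cast(Tensor(0)).dtype)  # uint64
```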

@Jokeren
Contributor

Jokeren commented Mar 15, 2024

  1. You can convert seed to a Triton tensor inside the philox function.
  2. Or, as a workaround, you can remove the tl.constexpr annotation from SEED.
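Option 1 can be sketched in the same toy style. Again, `constexpr`, `Tensor`, `to_tensor`, and `philox_seed_cast` are hypothetical stand-ins, not Triton's actual API: the idea is to normalize the seed to a tensor first, so constexpr, tensor, and raw-number seeds all reach `.to(...)` as tensors.

```python
class constexpr:
    """Hypothetical compile-time constant wrapper (no .to method)."""

    def __init__(self, value):
        self.value = value


class Tensor:
    """Hypothetical runtime tensor supporting dtype casts."""

    def __init__(self, value, dtype="int32"):
        self.value = value
        self.dtype = dtype

    def to(self, dtype):
        return Tensor(self.value, dtype)


def to_tensor(x):
    # Tensors pass through unchanged; constexprs are unwrapped, then wrapped as tensors.
    if isinstance(x, Tensor):
        return x
    if isinstance(x, constexpr):
        x = x.value
    return Tensor(x)


def philox_seed_cast(seed):
    seed = to_tensor(seed)  # option 1: convert first, then cast
    return seed.to("uint64")


print(philox_seed_cast(constexpr(0)).dtype)  # uint64
print(philox_seed_cast(Tensor(7)).value)     # 7
```

Converting at the top of the function keeps callers unchanged: the kernel can still declare SEED as tl.constexpr.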

lijinpei added a commit to lijinpei/triton that referenced this issue Mar 16, 2024
- Otherwise, frontend crashes for non-tensor arguments.

@lijinpei
Contributor

This can be fixed by inserting a to_tensor call in philox, but maybe the AST visitor should automatically insert a to_tensor call when we are about to hit an attribute access error?

@Jokeren
Contributor

Jokeren commented Mar 16, 2024

That might be difficult. You could have a string or other objects as a constexpr, which cannot be converted to a tensor.

@Jokeren
Contributor

Jokeren commented Mar 16, 2024

Then in the AST visitor, you'd have to check whether the constexpr's value is an integer or a float; only then can you convert it to a tensor.
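A toy sketch of the check described above. The names `constexpr`, `Tensor`, and `maybe_to_tensor` are hypothetical, and this is not how Triton's AST visitor is actually implemented: convert only when the wrapped value is an int or a float, and leave everything else untouched so non-numeric constexprs such as strings keep working.

```python
class constexpr:
    """Hypothetical compile-time constant wrapper; may hold any Python object."""

    def __init__(self, value):
        self.value = value


class Tensor:
    """Hypothetical runtime tensor."""

    def __init__(self, value):
        self.value = value


def maybe_to_tensor(x):
    # Only numeric constexprs have an obvious tensor form.
    # Note: in Python, bool is a subclass of int, so True/False are converted too.
    if isinstance(x, constexpr) and isinstance(x.value, (int, float)):
        return Tensor(x.value)
    # Strings and other objects are returned unchanged.
    return x


print(type(maybe_to_tensor(constexpr(0))).__name__)  # Tensor
s = constexpr("ieee")
print(maybe_to_tensor(s) is s)  # True
```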

Jokeren pushed a commit that referenced this issue Mar 16, 2024
- Otherwise, frontend crashes for non-tensor arguments.

Fixes #3390
karupayun pushed a commit to openxla/triton that referenced this issue Apr 3, 2024

5 participants