[fix] torch compile executor trace building does not add names (#1545)
ali-alshaar7 authored Dec 12, 2024
1 parent de2bbd5 commit 673bdb9
Showing 2 changed files with 38 additions and 0 deletions.
7 changes: 7 additions & 0 deletions thunder/executors/torch_compile.py
@@ -72,16 +72,23 @@ def make_compiled(
    region_trace = TraceCtx(None)
    region_trace.args = sorted_unique_inputs
    region_trace.kwargs = {}
    region_trace.names = {a.name for a in region_trace.args}
    with tracectx(region_trace):
        for a in sorted_unique_inputs:
            prims.unpack_trivial(a, name=a.name)

    region_trace.bound_symbols += list(bsyms)
    region_trace.bound_symbols.append(prims.python_return.bind(sorted_unique_outputs, output=None))
    for bsym in region_trace.bound_symbols:
        if bsym.sym == prims.unpack_trivial:
            continue
        for o in bsym.flat_outs:
            if o is not None:
                region_trace.add_name(o.name)
        for sbsym in bsym.subsymbols:
            for o in sbsym.flat_outs:
                if o is not None and o.name not in region_trace.names:
                    region_trace.add_name(o.name)

    # maybe make this the default if no sig info is present?
    region_trace._siginfo = SigInfo("to_be_compiled")
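The added loop registers, on the region trace, the name of every output produced by the region's bound symbols and their subsymbols, so that later name generation during trace construction cannot hand out a name the fused region already uses. A minimal sketch of the invariant this restores, using only the attributes that appear in the diff above (bound_symbols, names, flat_outs, subsymbols, .name); the helper itself is hypothetical and not part of the commit:

def all_output_names_registered(trace) -> bool:
    # Hypothetical check: once the region trace has been built, every
    # (sub)symbol output name should already be registered in trace.names.
    for bsym in trace.bound_symbols:
        for o in bsym.flat_outs:
            if o is not None and o.name not in trace.names:
                return False
        for sbsym in bsym.subsymbols:
            for o in sbsym.flat_outs:
                if o is not None and o.name not in trace.names:
                    return False
    return True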
31 changes: 31 additions & 0 deletions thunder/tests/test_torch_compile_executor.py
@@ -90,3 +90,34 @@ def fn(a):
    a = torch.randn(3)
    jfn = thunder.jit(fn, executors=(thunder.executors.torch_compile.torch_compile_ex,))
    assert_close(jfn(a), fn(a))


@pytest.mark.skipif(not is_inductor_supported(), reason="inductor unsupported")
@requiresCUDA
@pytest.mark.skipif(not device_supports_bf16(torch.device("cuda")), reason="bf16 is not supported")
def test_litgpt_fabric_for_callable():
    from typing import Any, Optional, Tuple, Union, List, Dict
    from collections.abc import Callable
    from litgpt.model import Config, GPT
    import torch.nn as nn

    def jit(fn: Callable, executors: list[str]) -> Any:
        assert executors is not None
        return thunder.jit(fn, executors=executors)

    def forward_and_loss(model: nn.Module, input_ids: torch.Tensor) -> torch.Tensor:
        logits = model(input_ids)
        return logits

    forward_and_loss_jitted = jit(forward_and_loss, executors=("sdpa", "torchcompile", "nvfuser", "torch"))

    config = Config(block_size=2, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8)

    with torch.device("cuda"):
        model = GPT(config)

    input_ids = torch.zeros(1, 2, dtype=torch.int64, device="cuda")
    out = forward_and_loss(model, input_ids)
    out_jitted = forward_and_loss_jitted(model, input_ids)

    assert_close(out, out_jitted)
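The test mirrors how Fabric hands a callable to Thunder: a thin jit wrapper that forwards executor names given as plain strings. A minimal sketch of the same calling pattern on a toy function, runnable without a GPU; it assumes only the public thunder.jit entry point and the "torchcompile" and "torch" executor names used in the test above:

import torch
import thunder


def toy(x: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.gelu(x) * 2.0


# Executor names passed as strings, as in the new test; on a CUDA machine the
# tuple could also include "sdpa" and "nvfuser".
jtoy = thunder.jit(toy, executors=("torchcompile", "torch"))
print(jtoy(torch.randn(4)))

The new test itself needs a CUDA device with bf16 support and litgpt installed; pytest's -k filter (e.g. pytest thunder/tests/test_torch_compile_executor.py -k test_litgpt_fabric_for_callable) runs it in isolation.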
