Skip to content

Commit

Permalink
[tk] Switch symbolic/shape placeholders to sympy (nod-ai#375)
Browse files Browse the repository at this point in the history
By using sympy instead of the stand-in symbolic support, we get full
support for index (partial) evaluation, dynamic dimensions and
validation checks.
  • Loading branch information
stellaraccident authored Jan 31, 2024
1 parent d6821d3 commit c1dc94c
Show file tree
Hide file tree
Showing 13 changed files with 951 additions and 716 deletions.
573 changes: 345 additions & 228 deletions python/shark_turbine/kernel/_support/indexing.py

Large diffs are not rendered by default.

30 changes: 22 additions & 8 deletions python/shark_turbine/kernel/_support/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@
import torch.fx as fx

from .indexing import (
BoundedSymbolicValue,
backed_sym_index_type,
BoundedRelation,
IndexExpr,
Grid,
KernelBuffer,
sym_0,
SymIndex,
)

from ..lang.types import (
Expand Down Expand Up @@ -98,10 +100,17 @@ class KernelTracer(SubgraphTracer):
# Register our custom proxies.
def proxy(self, node: fx.Node) -> fx.Proxy:
t = node.type
if t is not None and issubclass(t, KernelBuffer):
return KernelBufferProxy(node, self, t)
if t is not None:
if issubclass(t, KernelBuffer):
return KernelBufferProxy(node, self, t)
return super().proxy(node)

def create_arg(self, a):
# Let IndexExpr persist as arguments.
if isinstance(a, IndexExpr):
return a
return super().create_arg(a)


class CapturedTrace:
def __init__(self, region_graph: RegionGraph, root_graph: str):
Expand Down Expand Up @@ -163,23 +172,28 @@ def __init__(self, region_graph: RegionGraph, *, grid_type: Type[Grid]):
super().__init__(eager=False)
self.region_graph = region_graph
self.grid_type = grid_type
self.current_thread_types = [
backed_sym_index_type(BoundedRelation(0, n, upper_inclusive=False))
for n in grid_type.symbolic_shape
]

### ========================================================================
### Core Operations
### ========================================================================

def handle_thread_program_id(self, op, axis: int) -> Index:
grid_shape = self.grid_type.symbolic_shape
if axis < 0 or axis >= len(grid_shape):
grid_types = self.current_thread_types
if axis < 0 or axis >= len(grid_types):
raise IndexError(
f"Illegal index into grid of rank {len(grid_shape)}: {axis}"
f"Illegal index into grid of rank {len(grid_types)}: {axis}"
)

proxy = self.region_graph.create_proxy(
"call_function",
op,
args=(axis,),
kwargs={},
type_expr=BoundedSymbolicValue.bound(sym_0, grid_shape[axis]),
type_expr=grid_types[axis],
)
return proxy

Expand Down
252 changes: 0 additions & 252 deletions python/shark_turbine/kernel/compiler/analysis.py

This file was deleted.

3 changes: 3 additions & 0 deletions python/shark_turbine/kernel/compiler/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
NDEBUG = False


class CodegenError(Exception):
...

Expand Down
Loading

0 comments on commit c1dc94c

Please sign in to comment.