Skip to content

Commit

Permalink
[refactor] Split GraphBuilder out of Graph class (#5064)
Browse files Browse the repository at this point in the history
  • Loading branch information
ailzhang authored May 31, 2022
1 parent 69a53fe commit e6cb288
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 21 deletions.
14 changes: 7 additions & 7 deletions python/taichi/examples/graph/mpm88_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,11 @@ def init_particles(x: ti.any_arr(field_dim=1), v: ti.any_arr(field_dim=1),
sym_J = ti.graph.Arg(ti.graph.ArgKind.NDARRAY, 'J', ti.f32)
sym_grid_v = ti.graph.Arg(ti.graph.ArgKind.NDARRAY, 'grid_v', ti.f32, element_shape=(2, ))
sym_grid_m = ti.graph.Arg(ti.graph.ArgKind.NDARRAY, 'grid_m', ti.f32)
g_init = ti.graph.Graph()
g_init.dispatch(init_particles, sym_x, sym_v, sym_J)
g_init_builder = ti.graph.GraphBuilder()
g_init_builder.dispatch(init_particles, sym_x, sym_v, sym_J)

g_update = ti.graph.Graph()
substep = g_update.create_sequential()
g_update_builder = ti.graph.GraphBuilder()
substep = g_update_builder.create_sequential()

substep.dispatch(substep_reset_grid, sym_grid_v, sym_grid_m)
substep.dispatch(substep_p2g, sym_x, sym_v, sym_C, sym_J, sym_grid_v,
Expand All @@ -132,11 +132,11 @@ def init_particles(x: ti.any_arr(field_dim=1), v: ti.any_arr(field_dim=1),
substep.dispatch(substep_g2p, sym_x, sym_v, sym_C, sym_J, sym_grid_v)

for i in range(N_ITER):
g_update.append(substep)
g_update_builder.append(substep)

# Compile
g_init.compile()
g_update.compile()
g_init = g_init_builder.compile()
g_update = g_update_builder.compile()

# Run
g_init.run({'x': x, 'v': v, 'J': J})
Expand Down
12 changes: 8 additions & 4 deletions python/taichi/graph/_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,9 @@ def dispatch(self, kernel_fn, *args):
self.seq_.dispatch(kernel_cpp, args)


class Graph:
class GraphBuilder:
def __init__(self):
self._graph_builder = _ti_core.GraphBuilder()
self._compiled_graph = None

def dispatch(self, kernel_fn, *args):
kernel_cpp = gen_cpp_kernel(kernel_fn, args)
Expand All @@ -43,7 +42,12 @@ def append(self, node):
self._graph_builder.seq().append(node.seq_)

def compile(self):
self._compiled_graph = self._graph_builder.compile()
return Graph(self._graph_builder.compile())


class Graph:
def __init__(self, compiled_graph) -> None:
self._compiled_graph = compiled_graph

def run(self, args):
arg_ptrs = {}
Expand All @@ -65,4 +69,4 @@ def run(self, args):
self._compiled_graph.run(arg_ptrs, arg_ints, arg_floats)


__all__ = ['Graph', 'Arg', 'ArgKind']
__all__ = ['GraphBuilder', 'Graph', 'Arg', 'ArgKind']
14 changes: 7 additions & 7 deletions tests/python/test_aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,11 +736,11 @@ def init_particles(x: ti.any_arr(field_dim=1), v: ti.any_arr(field_dim=1),
'grid_m',
ti.f32,
element_shape=())
g_init = ti.graph.Graph()
g_init.dispatch(init_particles, sym_x, sym_v, sym_J)
g_init_builder = ti.graph.GraphBuilder()
g_init_builder.dispatch(init_particles, sym_x, sym_v, sym_J)

g_update = ti.graph.Graph()
substep = g_update.create_sequential()
g_update_builder = ti.graph.GraphBuilder()
substep = g_update_builder.create_sequential()

substep.dispatch(substep_reset_grid, sym_grid_v, sym_grid_m)
substep.dispatch(substep_p2g, sym_x, sym_v, sym_C, sym_J, sym_grid_v,
Expand All @@ -749,10 +749,10 @@ def init_particles(x: ti.any_arr(field_dim=1), v: ti.any_arr(field_dim=1),
substep.dispatch(substep_g2p, sym_x, sym_v, sym_C, sym_J, sym_grid_v)

for i in range(N_ITER):
g_update.append(substep)
g_update_builder.append(substep)

g_init.compile()
g_update.compile()
g_init = g_init_builder.compile()
g_update = g_update_builder.compile()

x = ti.Vector.ndarray(2, ti.f32, shape=(n_particles))
v = ti.Vector.ndarray(2, ti.f32, shape=(n_particles))
Expand Down
6 changes: 3 additions & 3 deletions tests/python/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ def test(pos: ti.types.ndarray(field_dim=1, element_shape=())):
pos[i] = 1

sym_pos = ti.graph.Arg(ti.graph.ArgKind.NDARRAY, 'pos', ti.i32)
g_init = ti.graph.Graph()
g_init = ti.graph.GraphBuilder()
g_init.dispatch(test, sym_pos)
g_init.compile()
g = g_init.compile()

a = ti.ndarray(ti.i32, shape=(n, ))
g_init.run({'pos': a})
g.run({'pos': a})
assert (a.to_numpy() == np.ones(4)).all()

0 comments on commit e6cb288

Please sign in to comment.