-
Notifications
You must be signed in to change notification settings - Fork 2
/
app.py
44 lines (32 loc) · 1011 Bytes
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import argparse
import warnings

import taichi as ti
# Nominal chess-board dimensions. Note that the kernel's loop bounds come
# from the ndarray passed in at run time, not from these constants.
WIDTH, HEIGHT = 16, 16
def compile_graph_aot(arch, save_dir="module"):
    """Compile the ``chess_board`` kernel into a graph and save it as an AOT module.

    Args:
        arch: Taichi backend to initialize and compile for (e.g. ``ti.vulkan``).
        save_dir: Path passed to ``ti.aot.Module.save``. Defaults to
            ``"module"``, the location this script has always written to.

    If the backend active after ``ti.init`` is not ``arch``, nothing is
    compiled or saved, and a ``RuntimeWarning`` is emitted.
    """
    ti.init(arch=arch)
    # ti.init may not end up on the requested backend (presumably it falls
    # back when that backend is unavailable). A module built for the wrong
    # backend is useless, so bail out, but say so instead of exiting silently.
    if ti.lang.impl.current_cfg().arch != arch:
        warnings.warn(
            f"Requested arch {arch} is not active after ti.init; "
            "skipping AOT compilation, no module was saved.",
            RuntimeWarning,
            stacklevel=2,
        )
        return

    @ti.kernel
    def chess_board(arr: ti.types.ndarray(ndim=2)):
        for i, j in arr:
            # (i + j) parity alternates along both axes, so this is a chess
            # board for any array shape. A linear-index form such as
            # (j * (WIDTH + 1) + i) % 2 only alternates correctly when WIDTH
            # is even.
            value = ti.cast((i + j) % 2, ti.f32)
            arr[i, j] = value

    # Placeholder for the graph's 2-D f32 ndarray input named 'arr'. The
    # kernel iterates over whatever shape that array has.
    _arr = ti.graph.Arg(ti.graph.ArgKind.NDARRAY,
                        'arr',
                        ti.f32,
                        ndim=2,
                        element_shape=())
    g_builder = ti.graph.GraphBuilder()
    g_builder.dispatch(chess_board, _arr)
    run_graph = g_builder.compile()

    mod = ti.aot.Module(arch)
    # Register the compiled graph under the name 'g_run'.
    mod.add_graph('g_run', run_graph)
    mod.save(save_dir)
if __name__ == "__main__":
    # Map each accepted --arch value to its Taichi backend. Only Vulkan is
    # wired up; add entries here to expose more backends.
    supported_archs = {"vulkan": ti.vulkan}

    parser = argparse.ArgumentParser(
        description="Compile the chess_board graph into a Taichi AOT module.")
    # `required` + `choices` make argparse reject a missing or unknown value
    # with a usage message. The previous `assert False` gave no message and
    # is stripped entirely under `python -O`.
    parser.add_argument("--arch", type=str, required=True,
                        choices=sorted(supported_archs),
                        help="target backend for the AOT module")
    args = parser.parse_args()
    compile_graph_aot(arch=supported_archs[args.arch])