diff --git a/python/taichi/lang/kernel_impl.py b/python/taichi/lang/kernel_impl.py index b6e124ba81cd6..7f6b7be233476 100644 --- a/python/taichi/lang/kernel_impl.py +++ b/python/taichi/lang/kernel_impl.py @@ -465,6 +465,9 @@ def materialize(self, key=None, args=None, arg_features=None): if self.is_grad: KernelSimplicityASTChecker(self.func).visit(tree) + if impl.current_cfg().use_mesh: + taichi.lang.Mesh.update_relation(tree, ctx) + # Do not change the name of 'taichi_ast_generator' # The warning system needs this identifier to remove unnecessary messages def taichi_ast_generator(kernel_cxx): diff --git a/python/taichi/lang/mesh.py b/python/taichi/lang/mesh.py index 62a3154eb5556..6b9e474b1ae75 100644 --- a/python/taichi/lang/mesh.py +++ b/python/taichi/lang/mesh.py @@ -1,3 +1,4 @@ +import ast import json import numpy as np @@ -284,6 +285,7 @@ class MeshInstance: def __init__(self, _type): self._type = _type self.mesh_ptr = _ti_core.create_mesh() + self.relation_set = set() def set_owned_offset(self, element_type: MeshElementType, owned_offset: ScalarField): @@ -310,12 +312,14 @@ def set_patch_max_element_num(self, element_type: MeshElementType, def set_relation_fixed(self, rel_type: MeshRelationType, value: ScalarField): + self.relation_set.add(rel_type) _ti_core.set_relation_fixed(self.mesh_ptr, rel_type, value.vars[0].ptr.snode()) def set_relation_dynamic(self, rel_type: MeshRelationType, value: ScalarField, patch_offset: ScalarField, offset: ScalarField): + self.relation_set.add(rel_type) _ti_core.set_relation_dynamic(self.mesh_ptr, rel_type, value.vars[0].ptr.snode(), patch_offset.vars[0].ptr.snode(), @@ -325,6 +329,33 @@ def add_mesh_attribute(self, element_type, snode, reorder_type): _ti_core.add_mesh_attribute(self.mesh_ptr, element_type, snode, reorder_type) + def get_relation_size(self, from_index, to_element_type): + return _ti_core.get_relation_size(self.mesh_ptr, from_index.ptr, + to_element_type) + + def get_relation_access(self, from_index, to_element_type, + neighbor_idx_ptr): + return _ti_core.get_relation_access(self.mesh_ptr, from_index.ptr, + to_element_type, neighbor_idx_ptr) + + def update_relation(self, from_order, to_order): + rel_type = MeshRelationType(relation_by_orders(from_order, to_order)) + if rel_type not in self.relation_set: + meta = self.patcher.get_relation_meta(from_order, to_order) + print('new relation') + + def fun(arr, dtype): + field = impl.field(dtype=dtype, shape=arr.shape) + field.from_numpy(arr) + return field + + if from_order <= to_order: + self.set_relation_dynamic(rel_type, fun(meta["value"], u16), + fun(meta["patch_offset"], u32), + fun(meta["offset"], u16)) + else: + self.set_relation_fixed(rel_type, fun(meta["value"], u16)) + class MeshMetadata: def __init__(self, data): @@ -368,6 +399,8 @@ def __init__(self, data): dtype=u16, shape=len(relation["offset"])) self.relation_fields[rel_type]["patch_offset"] = impl.field( dtype=u32, shape=len(relation["patch_offset"])) + self.relation_fields[rel_type]["from_order"] = from_order + self.relation_fields[rel_type]["to_order"] = to_order for element in data["elements"]: element_type = MeshElementType(element["order"]) @@ -397,6 +430,10 @@ def __init__(self, data): self.attrs = {} self.attrs["x"] = np.array(data["attrs"]["x"]).reshape(-1, 3) + if "patcher" in data: + self.patcher = data["patcher"] + else: + self.patcher = None # Define the Mesh Type, stores the field type info @@ -417,6 +454,8 @@ def __init__(self, topology): self.elements = set() self.relations = set() + impl.current_cfg().use_mesh = True + def build(self, metadata: MeshMetadata): """Build and instantiate mesh from model meta data @@ -434,7 +473,8 @@ def build(self, metadata: MeshMetadata): instance.set_num_patches(metadata.num_patches) - for element in self.elements: + for element in metadata.element_fields: + self.elements.add(element) _ti_core.set_num_elements(instance.mesh_ptr, element, metadata.num_elements[element]) instance.set_patch_max_element_num( @@ -459,11 +499,9 @@ def build(self, metadata: MeshMetadata): instance.set_index_mapping(element, ConvType.g2r, metadata.element_fields[element]["g2r"]) - for relation in self.relations: - from_order = element_order(relation[0]) - to_order = element_order(relation[1]) - rel_type = MeshRelationType( - relation_by_orders(from_order, to_order)) + for rel_type in metadata.relation_fields: + from_order = metadata.relation_fields[rel_type]["from_order"] + to_order = metadata.relation_fields[rel_type]["to_order"] if from_order <= to_order: instance.set_relation_dynamic( rel_type, metadata.relation_fields[rel_type]["value"], @@ -476,6 +514,8 @@ def build(self, metadata: MeshMetadata): if "x" in instance.verts.attr_dict: # pylint: disable=E1101 instance.verts.x.from_numpy(metadata.attrs["x"]) # pylint: disable=E1101 + instance.patcher = metadata.patcher + return instance @@ -520,6 +560,55 @@ def load_meta(filename): def generate_meta(data): return MeshMetadata(data) + class RelationVisitor(ast.NodeVisitor): + # TODO: only works for simple cases + + def __init__(self, ctx): + self.vars = {} + self.visits = [] + self.ctx = ctx + + def visit_For(self, node): + if isinstance(node.iter, ast.Attribute): + value = node.iter.value + if isinstance(value, ast.Name): + if value.id in self.ctx.global_vars: + var = self.ctx.global_vars[value.id] + if isinstance(var, MeshInstance): + self.vars[node.target.id] = [var, node.iter.attr] + if isinstance(node.iter, ast.Name): + if node.iter.id in self.ctx.global_vars: + var = self.ctx.global_vars[node.iter.id] + if isinstance(var, MeshElementField): + self.vars[node.target.id] = [ + var.mesh, element_type_name(var._type) + ] + ast.NodeVisitor.generic_visit(self, node) + + def visit_Assign(self, node): + if isinstance(node.targets[0], ast.Name): + if isinstance(node.value, ast.Name): + if node.value.id in self.vars: + self.vars[node.targets[0].id] = self.vars[ + node.value.id] + ast.NodeVisitor.generic_visit(self, node) + + def visit_Attribute(self, node): + if isinstance(node.value, ast.Name): + if node.value.id in self.vars: + self.visits.append(self.vars[node.value.id] + [node.attr]) + ast.NodeVisitor.generic_visit(self, node) + + @staticmethod + def update_relation(tree, ctx): + x = Mesh.RelationVisitor(ctx) + x.visit(tree) + name_to_order = {"verts": 0, "edges": 1, "faces": 2, "cells": 3} + for visit in x.visits: + if visit[1] in name_to_order and visit[2] in name_to_order: + visit[0].update_relation(name_to_order[visit[1]], + name_to_order[visit[2]]) + def TriMesh(): """Create a triangle mesh (a set of vert/edge/face elements, attributes, and connectivity) builder. @@ -594,15 +683,13 @@ def __init__(self, mesh: MeshInstance, from_index: impl.Expr, @property def size(self): return impl.Expr( - _ti_core.get_relation_size(self.mesh.mesh_ptr, self.from_index.ptr, - self.to_element_type)) + self.mesh.get_relation_size(self.from_index, self.to_element_type)) def subscript(self, *indices): assert len(indices) == 1 - entry_expr = _ti_core.get_relation_access(self.mesh.mesh_ptr, - self.from_index.ptr, - self.to_element_type, - impl.Expr(indices[0]).ptr) + entry_expr = self.mesh.get_relation_access(self.from_index, + self.to_element_type, + impl.Expr(indices[0]).ptr) entry_expr.type_check(impl.get_runtime().prog.config) return MeshElementFieldProxy(self.mesh, self.to_element_type, entry_expr) diff --git a/taichi/program/compile_config.cpp b/taichi/program/compile_config.cpp index 356dea7dc61fe..ac0b2784d2651 100644 --- a/taichi/program/compile_config.cpp +++ b/taichi/program/compile_config.cpp @@ -45,6 +45,7 @@ CompileConfig::CompileConfig() { make_block_local = true; detect_read_only = true; ndarray_use_cached_allocator = true; + use_mesh = false; saturating_grid_dim = 0; max_block_dim = 0; diff --git a/taichi/program/compile_config.h b/taichi/program/compile_config.h index fc5c4d8ccdcf9..12ec3f621ba26 100644 --- a/taichi/program/compile_config.h +++ b/taichi/program/compile_config.h @@ -43,6 +43,7 @@ struct CompileConfig { bool make_block_local; bool detect_read_only; bool ndarray_use_cached_allocator; + bool use_mesh; DataType default_fp; DataType default_ip; std::string extra_flags; diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 313cd6469a5d0..8b2800dfcca02 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -178,6 +178,7 @@ void export_lang(py::module &m) { .def_readwrite("detect_read_only", &CompileConfig::detect_read_only) .def_readwrite("ndarray_use_cached_allocator", &CompileConfig::ndarray_use_cached_allocator) + .def_readwrite("use_mesh", &CompileConfig::use_mesh) .def_readwrite("cc_compile_cmd", &CompileConfig::cc_compile_cmd) .def_readwrite("cc_link_cmd", &CompileConfig::cc_link_cmd) .def_readwrite("async_opt_passes", &CompileConfig::async_opt_passes)