diff --git a/include/tvm/script/ir_builder/ir/frame.h b/include/tvm/script/ir_builder/ir/frame.h new file mode 100644 index 000000000000..181774bc53bc --- /dev/null +++ b/include/tvm/script/ir_builder/ir/frame.h @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_IR_BUILDER_IR_FRAME_H_ +#define TVM_SCRIPT_IR_BUILDER_IR_FRAME_H_ + +#include +#include +#include +#include + +#include + +namespace tvm { +namespace script { +namespace ir_builder { + +/*! + * \brief A frame that represents the IRModule frame with functions and global variables. + * + * \sa IRModuleFrame + */ +class IRModuleFrameNode : public IRBuilderFrameNode { + public: + Array global_vars; + Array functions; + + void VisitAttrs(tvm::AttrVisitor* v) { + IRBuilderFrameNode::VisitAttrs(v); + v->Visit("global_vars", &global_vars); + v->Visit("functions", &functions); + } + + static constexpr const char* _type_key = "script.ir_builder.IRModuleFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleFrameNode, IRBuilderFrameNode); + + public: + void ExitWithScope() final; +}; + +/*! + * \brief Managed reference to IRModuleFrameNode. + * + * \sa IRModuleFrameNode + */ +class IRModuleFrame : public IRBuilderFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRModuleFrame, IRBuilderFrame, + IRModuleFrameNode); +}; + +} // namespace ir_builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_IR_BUILDER_IR_FRAME_H_ diff --git a/include/tvm/script/ir_builder/ir/ir.h b/include/tvm/script/ir_builder/ir/ir.h new file mode 100644 index 000000000000..0bd5473c7eaf --- /dev/null +++ b/include/tvm/script/ir_builder/ir/ir.h @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_IR_BUILDER_IR_IR_H_ +#define TVM_SCRIPT_IR_BUILDER_IR_IR_H_ + +#include +#include +#include +#include + +#include + +namespace tvm { +namespace script { +namespace ir_builder { + +/*! + * \brief The IRModule declaration statement. + * \return The IRModuleFrame. + */ +TVM_DLL IRModuleFrame IRModule(); + +} // namespace ir_builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_IR_BUILDER_IR_IR_H_ diff --git a/python/tvm/script/ir_builder/ir/__init__.py b/python/tvm/script/ir_builder/ir/__init__.py new file mode 100644 index 000000000000..ebb9728737ad --- /dev/null +++ b/python/tvm/script/ir_builder/ir/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Package tvm.script.ir_builder.ir""" +from .frame import IRModuleFrame +from .ir import ir_module diff --git a/python/tvm/script/ir_builder/ir/_ffi_api.py b/python/tvm/script/ir_builder/ir/_ffi_api.py new file mode 100644 index 000000000000..874cc278af83 --- /dev/null +++ b/python/tvm/script/ir_builder/ir/_ffi_api.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs""" +import tvm._ffi + +tvm._ffi._init_api("script.ir_builder.ir", __name__) # pylint: disable=protected-access diff --git a/python/tvm/script/ir_builder/ir/frame.py b/python/tvm/script/ir_builder/ir/frame.py new file mode 100644 index 000000000000..e16d86dc227e --- /dev/null +++ b/python/tvm/script/ir_builder/ir/frame.py @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Package tvm.script.ir_builder.ir.frame""" + +from tvm._ffi import register_object as _register_object + +from ..base import IRBuilderFrame + + +@_register_object("script.ir_builder.IRModuleFrame") +class IRModuleFrame(IRBuilderFrame): + ... diff --git a/python/tvm/script/ir_builder/ir/ir.py b/python/tvm/script/ir_builder/ir/ir.py new file mode 100644 index 000000000000..df920364356b --- /dev/null +++ b/python/tvm/script/ir_builder/ir/ir.py @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Package tvm.script.ir_builder.ir.ir""" + +from . import _ffi_api +from .frame import IRModuleFrame + + +def ir_module() -> IRModuleFrame: + return _ffi_api.IRModule() # pylint: disable=no-member # type: ignore diff --git a/src/script/ir_builder/ir/frame.cc b/src/script/ir_builder/ir/frame.cc new file mode 100644 index 000000000000..c85e30544aca --- /dev/null +++ b/src/script/ir_builder/ir/frame.cc @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include + +namespace tvm { +namespace script { +namespace ir_builder { + +void IRModuleFrameNode::ExitWithScope() { + ICHECK_EQ(functions.size(), global_vars.size()); + int n = functions.size(); + Map func_map; + for (int i = 0; i < n; ++i) { + func_map.Set(global_vars[i], functions[i]); + } + IRBuilder builder = IRBuilder::Current(); + ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; + builder->result = tvm::IRModule(func_map); +} + +TVM_REGISTER_NODE_TYPE(IRModuleFrameNode); + +} // namespace ir_builder +} // namespace script +} // namespace tvm diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc new file mode 100644 index 000000000000..bcd21de144bb --- /dev/null +++ b/src/script/ir_builder/ir/ir.cc @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include + +namespace tvm { +namespace script { +namespace ir_builder { + +IRModuleFrame IRModule() { + ObjectPtr n = make_object(); + n->global_vars.clear(); + n->functions.clear(); + return IRModuleFrame(n); +} + +TVM_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule); + +} // namespace ir_builder +} // namespace script +} // namespace tvm diff --git a/tests/python/unittest/test_tvmscript_ir_builder_irmodule.py b/tests/python/unittest/test_tvmscript_ir_builder_irmodule.py new file mode 100644 index 000000000000..4ca1af4c4445 --- /dev/null +++ b/tests/python/unittest/test_tvmscript_ir_builder_irmodule.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Unittests for tvm.script.ir_builder.ir""" +import pytest +from tvm.script.ir_builder import IRBuilder +from tvm.script.ir_builder import ir as I +from tvm import ir +from tvm.ir.base import assert_structural_equal + + +def test_ir_builder_irmodule(): + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module(): + pass + + # the ir_module generated by IRBuilder + ir_module_actual = ib.get() + + # the expected prim_func + ir_module_expected = ir.IRModule(None, None) + + assert_structural_equal(ir_module_actual, ir_module_expected, map_free_vars=True) + + +if __name__ == "__main__": + test_ir_builder_irmodule()