Skip to content

Commit

Permalink
[lang] Add tests for refine_coordinates (#2382)
Browse files Browse the repository at this point in the history
* [lang] Add tests for refine_coordinates

* fix fmt

* mess up fmt

* Auto Format

Co-authored-by: Taichi Gardener <taichigardener@gmail.com>
  • Loading branch information
k-ye and taichi-gardener authored Jun 4, 2021
1 parent a35085f commit 1d9be89
Show file tree
Hide file tree
Showing 12 changed files with 280 additions and 96 deletions.
9 changes: 7 additions & 2 deletions cmake/TaichiTests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,13 @@ endif()
# TODO(#2195):
# 1. "cpp" -> "cpp_legacy", "cpp_new" -> "cpp"
# 2. Re-implement the legacy CPP tests using googletest
file(GLOB_RECURSE TAICHI_TESTS_SOURCE "tests/cpp/analysis/*.cpp" "tests/cpp/common/*.cpp" "tests/cpp/ir/*.cpp"
"tests/cpp/program/*.cpp" "tests/cpp/transforms/*.cpp")
file(GLOB_RECURSE TAICHI_TESTS_SOURCE
"tests/cpp/analysis/*.cpp"
"tests/cpp/codegen/*.cpp"
"tests/cpp/common/*.cpp"
"tests/cpp/ir/*.cpp"
"tests/cpp/program/*.cpp"
"tests/cpp/transforms/*.cpp")

include_directories(
${PROJECT_SOURCE_DIR},
Expand Down
4 changes: 2 additions & 2 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ CodeGenLLVM::CodeGenLLVM(Kernel *kernel, IRNode *ir)
initialize_context();

context_ty = get_runtime_type("Context");
physical_coordinate_ty = get_runtime_type("PhysicalCoordinates");
physical_coordinate_ty = get_runtime_type(kLLVMPhysicalCoordinatesName);

kernel_name = kernel->name + "_kernel";
}
Expand Down Expand Up @@ -1652,7 +1652,7 @@ void CodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt, bool spmd) {
}
}

auto coord_object = RuntimeObject("PhysicalCoordinates", this,
auto coord_object = RuntimeObject(kLLVMPhysicalCoordinatesName, this,
builder.get(), new_coordinates);
for (int i = 0; i < snode->num_active_indices; i++) {
auto j = snode->physical_index_position[i];
Expand Down
3 changes: 2 additions & 1 deletion taichi/ir/ir_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class IRBuilder {
[[nodiscard]] LoopGuard get_loop_guard(XStmt *loop) {
return LoopGuard(*this, loop);
}

[[nodiscard]] IfGuard get_if_guard(IfStmt *if_stmt, bool true_branch) {
return IfGuard(*this, if_stmt, true_branch);
}
Expand Down Expand Up @@ -192,7 +193,7 @@ class IRBuilder {

// Print values and strings. Arguments can be Stmt* or std::string.
template <typename... Args>
PrintStmt *create_print(Args &&...args) {
PrintStmt *create_print(Args &&... args) {
return insert(Stmt::make_typed<PrintStmt>(std::forward<Args>(args)...));
}

Expand Down
6 changes: 4 additions & 2 deletions taichi/llvm/llvm_codegen_utils.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "llvm_codegen_utils.h"

TLANG_NAMESPACE_BEGIN
namespace taichi {
namespace lang {

std::string type_name(llvm::Type *type) {
std::string type_name_str;
Expand Down Expand Up @@ -41,4 +42,5 @@ void check_func_call_signature(llvm::Value *func,
}
}

TLANG_NAMESPACE_END
} // namespace lang
} // namespace taichi
30 changes: 19 additions & 11 deletions taichi/llvm/llvm_codegen_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@

#include "llvm_context.h"

TLANG_NAMESPACE_BEGIN
namespace taichi {
namespace lang {

inline constexpr char kLLVMPhysicalCoordinatesName[] = "PhysicalCoordinates";

std::string type_name(llvm::Type *type);

Expand All @@ -50,11 +53,11 @@ inline bool check_func_call_signature(llvm::Value *func, Args &&... args) {

class LLVMModuleBuilder {
public:
std::unique_ptr<llvm::Module> module;
llvm::BasicBlock *entry_block;
std::unique_ptr<llvm::IRBuilder<>> builder;
TaichiLLVMContext *tlctx;
llvm::LLVMContext *llvm_context;
std::unique_ptr<llvm::Module> module{nullptr};
llvm::BasicBlock *entry_block{nullptr};
std::unique_ptr<llvm::IRBuilder<>> builder{nullptr};
TaichiLLVMContext *tlctx{nullptr};
llvm::LLVMContext *llvm_context{nullptr};

LLVMModuleBuilder(std::unique_ptr<llvm::Module> &&module,
TaichiLLVMContext *tlctx)
Expand Down Expand Up @@ -145,10 +148,10 @@ class LLVMModuleBuilder {
class RuntimeObject {
public:
std::string cls_name;
llvm::Value *ptr;
LLVMModuleBuilder *mb;
llvm::Type *type;
llvm::IRBuilder<> *builder;
llvm::Value *ptr{nullptr};
LLVMModuleBuilder *mb{nullptr};
llvm::Type *type{nullptr};
llvm::IRBuilder<> *builder{nullptr};

RuntimeObject(const std::string &cls_name,
LLVMModuleBuilder *mb,
Expand Down Expand Up @@ -179,6 +182,10 @@ class RuntimeObject {
call(fmt::format("set_{}", field), val);
}

void set(const std::string &field, llvm::Value *index, llvm::Value *val) {
call(fmt::format("set_{}", field), index, val);
}

template <typename... Args>
llvm::Value *call(const std::string &func_name, Args &&... args) {
auto func = get_func(func_name);
Expand All @@ -192,4 +199,5 @@ class RuntimeObject {
}
};

TLANG_NAMESPACE_END
} // namespace lang
} // namespace taichi
36 changes: 19 additions & 17 deletions taichi/program/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ TLANG_NAMESPACE_END
namespace std {
template <>
struct hash<taichi::lang::JITEvaluatorId> {
std::size_t operator()(taichi::lang::JITEvaluatorId const &id) const
noexcept {
std::size_t operator()(
taichi::lang::JITEvaluatorId const &id) const noexcept {
return ((std::size_t)id.op | (id.ret.hash() << 8) | (id.lhs.hash() << 16) |
(id.rhs.hash() << 24) | ((std::size_t)id.is_binary << 31)) ^
(std::hash<std::thread::id>{}(id.thread_id) << 32);
Expand All @@ -84,29 +84,31 @@ class AsyncEngine;
class Program {
public:
using Kernel = taichi::lang::Kernel;
Callable *current_callable;
std::unique_ptr<SNode> snode_root; // pointer to the data structure.
void *llvm_runtime;
Callable *current_callable{nullptr};
std::unique_ptr<SNode> snode_root{nullptr}; // pointer to the data structure.
void *llvm_runtime{nullptr};
CompileConfig config;
std::unique_ptr<TaichiLLVMContext> llvm_context_host, llvm_context_device;
bool sync; // device/host synchronized?
bool finalized;
float64 total_compilation_time;
std::unique_ptr<TaichiLLVMContext> llvm_context_host{nullptr};
std::unique_ptr<TaichiLLVMContext> llvm_context_device{nullptr};
bool sync{false}; // device/host synchronized?
bool finalized{false};
float64 total_compilation_time{0.0};
static std::atomic<int> num_instances;
std::unique_ptr<ThreadPool> thread_pool;
std::unique_ptr<MemoryPool> memory_pool;
uint64 *result_buffer; // TODO: move this
void *preallocated_device_buffer; // TODO: move this to memory allocator
std::unique_ptr<ThreadPool> thread_pool{nullptr};
std::unique_ptr<MemoryPool> memory_pool{nullptr};
uint64 *result_buffer{nullptr}; // TODO: move this
void *preallocated_device_buffer{
nullptr}; // TODO: move this to memory allocator
std::unordered_map<int, SNode *> snodes;

std::unique_ptr<Runtime> runtime;
std::unique_ptr<AsyncEngine> async_engine;
std::unique_ptr<Runtime> runtime{nullptr};
std::unique_ptr<AsyncEngine> async_engine{nullptr};

std::vector<std::unique_ptr<Kernel>> kernels;
std::vector<std::unique_ptr<Function>> functions;
std::unordered_map<FunctionKey, Function *> function_map;

std::unique_ptr<KernelProfilerBase> profiler;
std::unique_ptr<KernelProfilerBase> profiler{nullptr};

std::unordered_map<JITEvaluatorId, std::unique_ptr<Kernel>>
jit_evaluator_cache;
Expand All @@ -119,7 +121,7 @@ class Program {
Program() : Program(default_compile_config.arch) {
}

Program(Arch arch);
explicit Program(Arch arch);

void kernel_profiler_print() {
profiler->print();
Expand Down
4 changes: 0 additions & 4 deletions taichi/struct/struct.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@

TLANG_NAMESPACE_BEGIN

StructCompiler::StructCompiler(Program *prog) : prog(prog) {
root_size = 0;
}

void StructCompiler::collect_snodes(SNode &snode) {
snodes.push_back(&snode);
for (int ch_id = 0; ch_id < (int)snode.ch.size(); ch_id++) {
Expand Down
5 changes: 1 addition & 4 deletions taichi/struct/struct.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@ class StructCompiler {
std::vector<SNode *> stack;
std::vector<SNode *> snodes;
std::vector<SNode *> ambient_snodes;
std::size_t root_size;
Program *prog;

explicit StructCompiler(Program *prog);
std::size_t root_size{0};

virtual ~StructCompiler() = default;

Expand Down
Loading

0 comments on commit 1d9be89

Please sign in to comment.