diff --git a/.clang-format b/.clang-format
new file mode 100644
index 000000000000..9b3aa8b7213b
--- /dev/null
+++ b/.clang-format
@@ -0,0 +1 @@
+BasedOnStyle: LLVM
diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml
deleted file mode 100644
index e921709ba1d9..000000000000
--- a/.github/workflows/documentation.yml
+++ /dev/null
@@ -1,51 +0,0 @@
-name: Documentation
-on:
-  workflow_dispatch:
-  schedule:
-    - cron: "0 0 * * *"
-
-jobs:
-
-  Build-Documentation:
-
-    runs-on: self-hosted
-
-    steps:
-
-
-    - name: Checkout gh-pages
-      uses: actions/checkout@v1
-      with:
-        ref: 'gh-pages'
-
-    - name: Checkout branch
-      uses: actions/checkout@v1
-
-    - name: Build docs
-      run: |
-        git fetch origin master:master
-        cd docs
-        sphinx-multiversion . _build/html/
-
-    - name: Publish docs
-      run: |
-        git branch
-        # update docs
-        rm -r /tmp/triton-docs;
-        mkdir /tmp/triton-docs;
-        mv docs/_build/html/* /tmp/triton-docs/
-        git checkout gh-pages
-        cp -r CNAME /tmp/triton-docs/
-        cp -r index.html /tmp/triton-docs/
-        cp -r .nojekyll /tmp/triton-docs/
-        rm -r *
-        cp -r /tmp/triton-docs/* .
-        # ln -s master/index.html .
-        # mv master docs
-        git add .
-        git commit -am "[GH-PAGES] Updated website"
-        # publish docs
-        eval `ssh-agent -s`
-        DISPLAY=:0 SSH_ASKPASS=~/.ssh/give_pass.sh ssh-add ${{ secrets.SSH_KEY }} <<< ${{ secrets.SSH_PASS }}
-        git remote set-url origin git@github.com:openai/triton.git
-        git push
diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml
index 45798e62882c..7891fbe435c8 100644
--- a/.github/workflows/integration-tests.yml
+++ b/.github/workflows/integration-tests.yml
@@ -4,8 +4,7 @@ on:
   workflow_dispatch:
   pull_request:
     branches:
-      - master
-      - v2.0
+      - main
 
 jobs:
 
@@ -21,7 +20,7 @@ jobs:
 
     - name: Clear cache
       run: |
-        rm -r /tmp/triton/
+        rm -r ~/.triton/
       continue-on-error: true
 
     - name: Install Triton
diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml
deleted file mode 100644
index 1d8d450f2693..000000000000
--- a/.github/workflows/wheels.yml
+++ /dev/null
@@ -1,40 +0,0 @@
-name: Wheels
-on:
-  workflow_dispatch:
-  schedule:
-    - cron: "0 0 * * *"
-
-jobs:
-
-  Build-Wheels:
-
-    runs-on: self-hosted
-
-    steps:
-
-    - name: Checkout
-      uses: actions/checkout@v2
-
-    - name: Patch setup.py
-      run: |
-        #sed -i 's/name\=\"triton\"/name="triton-nightly"/g' python/setup.py
-        export LATEST_DATE=$(git show -s --format=%ci `git rev-parse HEAD` | cut -d ' ' -f 1 | sed 's/-//g')
-        sed -i -r "s/version\=\"(.*)\"/version=\"\1-dev"$LATEST_DATE"\"/g" python/setup.py
-        echo "" >> python/setup.cfg
-        echo "[build_ext]" >> python/setup.cfg
-        echo "base-dir=/project" >> python/setup.cfg
-
-    - name: Build wheels
-      run: |
-        export CIBW_MANYLINUX_X86_64_IMAGE="manylinux2014"
-        export CIBW_MANYLINUX_PYPY_X86_64_IMAGE="manylinux2014"
-        export CIBW_BEFORE_BUILD="pip install cmake;\
-          yum install -y llvm11 llvm11-devel llvm11-static llvm11-libs zlib-devel;"
-        export CIBW_SKIP="{cp,pp}35-*"
-        export CIBW_BUILD="{cp,pp}3*-manylinux_x86_64"
-        python3 -m cibuildwheel python --output-dir wheelhouse
-
-
-    - name: Upload wheels to PyPI
-      run: |
-        python3 -m twine upload wheelhouse/* --skip-existing
\ No newline at end of file
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 6eac81fe1005..3b04f69d65c4 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -126,18 +126,10 @@ include_directories(${LLVM_INCLUDE_DIRS})
 # Python module
 if(BUILD_PYTHON_MODULE)
   message(STATUS "Adding Python module")
-  # Build CUTLASS python wrapper if requested
set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src) - set(CUTLASS_INCLUDE_DIR "$ENV{CUTLASS_INCLUDE_DIR}") - set(CUTLASS_LIBRARY_DIR "$ENV{CUTLASS_LIBRARY_DIR}") - if(NOT("${CUTLASS_INCLUDE_DIR}" STREQUAL "") AND NOT("${CUTLASS_LIBRARY_DIR}" STREQUAL "")) - set(CUTLASS_SRC ${PYTHON_SRC_PATH}/cutlass.cc) - add_definitions(-DWITH_CUTLASS_BINDINGS) - set(CUTLASS_LIBRARIES "cutlass.a") - endif() - include_directories("." ${PYTHON_SRC_PATH} ${PYTHON_INCLUDE_DIRS} ${CUTLASS_INCLUDE_DIR}) - link_directories(${PYTHON_LINK_DIRS} ${CUTLASS_LIBRARY_DIR}) - set(PYTHON_SRC ${PYTHON_SRC_PATH}/main.cc ${PYTHON_SRC_PATH}/triton.cc ${PYTHON_SRC_PATH}/superblock.cc ${CUTLASS_SRC}) + include_directories("." ${PYTHON_SRC_PATH} ${PYTHON_INCLUDE_DIRS}) + link_directories(${PYTHON_LINK_DIRS}) + set(PYTHON_SRC ${PYTHON_SRC_PATH}/main.cc ${PYTHON_SRC_PATH}/triton.cc) endif() diff --git a/deps/dlfcn-win32 b/deps/dlfcn-win32 deleted file mode 160000 index 522c301ec366..000000000000 --- a/deps/dlfcn-win32 +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 522c301ec366e9b42205ae21617780d37cc0e9f0 diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 4df2482ccbc8..0f0e570ae938 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -59,13 +59,13 @@ For example, a row-major coalesced layout may distribute a 64x16 tensor over 2 w thread tile size 2 - - - - - - /\ - - - - - - -block| thread || A_{0, 0}[T0] A_{0, 1}[T0] ... A_{0, 6}[T3] A_{0, 7}[T3] A_{0, 8}[T0] A_{0, 9}[T0] ... A_{0, 14}[T3] A_{0, 15}[T3] +warp | thread || A_{0, 0}[T0] A_{0, 1}[T0] ... A_{0, 6}[T3] A_{0, 7}[T3] A_{0, 8}[T0] A_{0, 9}[T0] ... A_{0, 14}[T3] A_{0, 15}[T3] tile | tile size 2 || A_{1, 0}[T0] A_{1, 1}[T0] ... A_{1, 6}[T3] A_{1, 7}[T3] A_{1, 8}[T0] A_{1, 9}[T0] ... A_{1, 14}[T3] A_{1, 15}[T3] size } .... 32 | A_{30, 0}[T60] A_{14, 1}[T60] ... A_{14, 6}[T63] A_{14, 7}[T63] A_{14, 8}[T60] A_{14, 9}[T60] ... A_{14, 14}[T63] A_{14, 15}[T63] | A_{31, 0}[T60] A_{15, 1}[T60] ... A_{15, 6}[T63] A_{15, 7}[T63] A_{15, 8}[T60] A_{15, 9}[T60] ... A_{15, 14}[T63] A_{15, 15}[T63] -----------------------------/\----------------------------------- - block tile size 8 + warp tile size 8 A_{32, 0}[T0] A_{32, 1}[T0] ... A_{32, 6}[T3] A_{32, 7}[T3] A_{32, 8}[T0] A_{32, 9}[T0] ... 
A_{32, 14}[T3] A_{32, 15}[T3] diff --git a/include/triton/codegen/analysis/align.h b/include/triton/codegen/analysis/align.h deleted file mode 100644 index 2393603cbeda..000000000000 --- a/include/triton/codegen/analysis/align.h +++ /dev/null @@ -1,80 +0,0 @@ -#ifndef TDL_INCLUDE_CODEGEN_ALIGNMENT_INFO_PASS_H -#define TDL_INCLUDE_CODEGEN_ALIGNMENT_INFO_PASS_H - -#include -#include - -namespace triton { - -namespace ir { - class value; - class module; - class phi_node; - class splat_inst; - class cast_inst; - class reshape_inst; - class broadcast_inst; - class binary_operator; - class getelementptr_inst; -} - -namespace codegen{ -namespace analysis{ - -class align { -private: - struct cst_info { - unsigned num_cst; - unsigned value; - }; - // helpers - std::vector get_shapes(ir::value *v); - // populate is_constant - std::vector populate_is_constant_phi(ir::phi_node* x); - std::vector populate_is_constant_splat(ir::splat_inst* x); - std::vector populate_is_constant_reshape(ir::reshape_inst* x); - std::vector populate_is_constant_broadcast(ir::broadcast_inst* x); - std::vector populate_is_constant_binop(ir::binary_operator* x); - std::vector populate_is_constant_gep(ir::getelementptr_inst* x); - std::vector populate_is_constant_default(ir::value* v); - std::vector populate_is_constant(ir::value *v); - // populate max_contiguous - std::vector populate_max_contiguous_phi(ir::phi_node* x); - std::vector populate_max_contiguous_splat(ir::splat_inst* x); - std::vector populate_max_contiguous_reshape(ir::reshape_inst* x); - std::vector populate_max_contiguous_broadcast(ir::broadcast_inst* x); - std::vector populate_max_contiguous_binop(ir::binary_operator* x); - std::vector populate_max_contiguous_gep(ir::getelementptr_inst* x); - std::vector populate_max_contiguous_cast(ir::cast_inst* x); - std::vector populate_max_contiguous_default(ir::value* v); - std::vector populate_max_contiguous(ir::value *v); - // populate starting_multiple - std::vector populate_starting_multiple_phi(ir::phi_node* x); - std::vector populate_starting_multiple_splat(ir::splat_inst* x); - std::vector populate_starting_multiple_reshape(ir::reshape_inst* x); - std::vector populate_starting_multiple_broadcast(ir::broadcast_inst* x); - std::vector populate_starting_multiple_binop(ir::binary_operator* x); - std::vector populate_starting_multiple_gep(ir::getelementptr_inst* x); - std::vector populate_starting_multiple_cast(ir::cast_inst* x); - std::vector populate_starting_multiple_default(ir::value* v); - std::vector populate_starting_multiple(ir::value *v); - // populate all maps - void populate(ir::value *v); - -public: - void run(ir::module &mod); - unsigned get(ir::value* v, unsigned ax) const; - std::vector contiguous(ir::value* v) const; - -private: - std::map> is_constant_; - std::map> max_contiguous_; - std::map> starting_multiple_; -}; - - -} -} -} - -#endif diff --git a/include/triton/codegen/analysis/allocation.h b/include/triton/codegen/analysis/allocation.h deleted file mode 100644 index e49f5c591026..000000000000 --- a/include/triton/codegen/analysis/allocation.h +++ /dev/null @@ -1,47 +0,0 @@ -#ifndef TDL_INCLUDE_IR_CODEGEN_STORAGE_ALLOC_H -#define TDL_INCLUDE_IR_CODEGEN_STORAGE_ALLOC_H - -#include -#include -#include -#include "triton/codegen/analysis/liveness.h" - -namespace triton{ - -namespace ir{ - class value; - class function; - class module; -} - -namespace codegen{ -namespace analysis{ - -class tiles; - -class liveness; -class cts; - -class allocation { -public: - allocation(liveness *live) - : 
liveness_(live) { } - // accessors - bool has_offset(const data_layout *x) const { return offsets_.find(x) != offsets_.end(); } - unsigned offset(const data_layout *x) const { return offsets_.at(x); } - unsigned allocated_size() const { return allocated_size_; } - // run - void run(ir::module& mod); - -private: - std::map offsets_; - size_t allocated_size_; - // dependences - liveness *liveness_; -}; - -} -} -} - -#endif diff --git a/include/triton/codegen/analysis/axes.h b/include/triton/codegen/analysis/axes.h deleted file mode 100644 index 759ed0f8f9f3..000000000000 --- a/include/triton/codegen/analysis/axes.h +++ /dev/null @@ -1,52 +0,0 @@ -#ifndef _TRITON_CODEGEN_ANALYSIS_AXES_H_ -#define _TRITON_CODEGEN_ANALYSIS_AXES_H_ - -#include "triton/tools/graph.h" -#include -#include - -namespace triton{ - -namespace ir{ - class value; - class module; - class instruction; -} - -namespace codegen{ -namespace analysis{ - -class axes { - typedef std::pair node_t; - -private: - // update graph - void update_graph_store(ir::instruction *i); - void update_graph_reduce(ir::instruction *i); - void update_graph_reshape(ir::instruction *i); - void update_graph_trans(ir::instruction *i); - void update_graph_broadcast(ir::instruction *i); - void update_graph_dot(ir::instruction *i); - void update_graph_elementwise(ir::instruction *i, - bool is_masked_load_async=false); - void update_graph_no_edge(ir::instruction *i); - void update_graph(ir::instruction *i); - -public: - axes(); - void run(ir::module &mod); - // accessors - int get(ir::value *value, unsigned dim); - std::vector get(ir::value *value); - -private: - tools::graph graph_; - std::map axes_; -}; - -} -} - -} - -#endif diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h deleted file mode 100644 index 56fb1e4b9885..000000000000 --- a/include/triton/codegen/analysis/layout.h +++ /dev/null @@ -1,345 +0,0 @@ -#ifndef _TRITON_CODEGEN_ANALYSIS_GRID_H_ -#define _TRITON_CODEGEN_ANALYSIS_GRID_H_ - -#include -#include -#include -#include -#include "triton/tools/graph.h" -#include "triton/codegen/target.h" - -namespace triton{ - -namespace ir{ - class value; - class type; - class module; - class instruction; - class phi_node; -} - -namespace codegen{ -namespace analysis{ - -class axes; -class align; -class layout_visitor; -class data_layout; -class mma_layout; -class scanline_layout; -class shared_layout; - - -class layout_visitor { -public: - virtual void visit_layout(data_layout *); - virtual void visit_layout_mma(mma_layout*) = 0; - virtual void visit_layout_scanline(scanline_layout*) = 0; - virtual void visit_layout_shared(shared_layout*) = 0; -}; - -class data_layout { -protected: - enum id_t { - MMA, - SCANLINE, - SHARED - }; - - typedef std::vector axes_t; - typedef std::vector shape_t; - typedef std::vector order_t; - typedef std::vector values_t; - -private: - template - T* downcast(id_t id) { - if(id_ == id) - return static_cast(this); - return nullptr; - } - -public: - data_layout(id_t id, - const std::vector& axes, - const std::vector &shape, - const std::vector &values, - analysis::align* align); - // visitor - virtual void accept(layout_visitor* vst) = 0; - // downcast - mma_layout* to_mma() { return downcast(MMA); } - scanline_layout* to_scanline() { return downcast(SCANLINE); } - shared_layout* to_shared() { return downcast(SHARED); } - // accessors - size_t get_rank() { return shape_.size(); } - const shape_t& get_shape() const { return shape_; } - const order_t& get_order() const { return 
order_; } - const values_t& get_values() const { return values_;} - int get_axis(size_t k) const { return axes_.at(k); } - std::vector get_axes() const { return axes_; } - const int get_order(size_t k) const { return order_.at(k); } - // find the position of given axis - int find_axis(int to_find) const; - - -private: - id_t id_; - axes_t axes_; - values_t values_; - -protected: - order_t order_; - shape_t shape_; -}; - -class distributed_layout: public data_layout{ -public: - distributed_layout(id_t id, - const std::vector& axes, - const std::vector& shape, - const std::vector& values, - analysis::align* align); - - int shape_per_cta(size_t k) { return shape_per_cta_.at(k); } - int rep_per_cta(size_t k) { return shape_[k] / shape_per_cta_[k]; } - virtual int contig_per_thread(size_t k) = 0; - -protected: - std::vector shape_per_cta_; -}; - -class mma_layout: public distributed_layout { -public: - enum TensorCoreType : uint8_t { - // floating-point tensor core instr - FP32_FP16_FP16_FP32 = 0, // default - FP32_BF16_BF16_FP32, - FP32_TF32_TF32_FP32, - // integer tensor core instr - INT32_INT1_INT1_INT32, // Not implemented - INT32_INT4_INT4_INT32, // Not implemented - INT32_INT8_INT8_INT32, // Not implemented - // - NOT_APPLICABLE, - }; - - // Used on nvidia GPUs with sm >= 80 - inline static const std::map> mma_instr_shape_ = { - {FP32_FP16_FP16_FP32, {16, 8, 16}}, - {FP32_BF16_BF16_FP32, {16, 8, 16}}, - {FP32_TF32_TF32_FP32, {16, 8, 8}}, - - {INT32_INT1_INT1_INT32, {16, 8, 256}}, - {INT32_INT4_INT4_INT32, {16, 8, 64}}, - {INT32_INT8_INT8_INT32, {16, 8, 32}}, - }; - - // shape of matrices loaded by ldmatrix (m-n-k, for mxk & kxn matrices) - inline static const std::map> mma_mat_shape_ = { - {FP32_FP16_FP16_FP32, {8, 8, 8}}, - {FP32_BF16_BF16_FP32, {8, 8, 8}}, - {FP32_TF32_TF32_FP32, {8, 8, 4}}, - - {INT32_INT1_INT1_INT32, {8, 8, 64}}, - {INT32_INT4_INT4_INT32, {8, 8, 32}}, - {INT32_INT8_INT8_INT32, {8, 8, 16}}, - }; - - inline static const std::map mma_instr_ptx_ = { - {FP32_FP16_FP16_FP32, "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"}, - {FP32_BF16_BF16_FP32, "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"}, - {FP32_TF32_TF32_FP32, "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32"}, - - {INT32_INT1_INT1_INT32, "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc"}, - {INT32_INT4_INT4_INT32, "mma.sync.aligned.m16n8k64.row.col.satfinite.s32.s4.s4.s32"}, - {INT32_INT8_INT8_INT32, "mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32"}, - }; - - // vector length per ldmatrix (16*8/elelment_size_in_bits) - inline static const std::map mma_instr_vec_ = { - {FP32_FP16_FP16_FP32, 8}, - {FP32_BF16_BF16_FP32, 8}, - {FP32_TF32_TF32_FP32, 4}, - - {INT32_INT1_INT1_INT32, 128}, - {INT32_INT4_INT4_INT32, 32}, - {INT32_INT8_INT8_INT32, 16}, - }; - -public: - mma_layout(size_t num_warps, - const std::vector& axes, - const std::vector& shapes, - const std::vector &values, - analysis::align* align, target *tgt, - shared_layout* layout_a, - shared_layout* layout_b, - ir::value *dot); - void accept(layout_visitor* vst) { vst->visit_layout_mma(this); } - // accessor - int fpw(size_t k) { return fpw_.at(k); } - int wpt(size_t k) { return wpt_.at(k); } - int spw(size_t k) { return spw_.at(k); } - int rep(size_t k) { return rep_.at(k); } - int contig_per_thread(size_t k) { return contig_per_thread_.at(k); } - - // helpers for generator.cc - std::string get_ptx_instr() const { return mma_instr_ptx_.at(tensor_core_type_); } - std::vector get_mma_instr_shape() const { return 
mma_instr_shape_.at(tensor_core_type_); } - std::vector get_mma_mat_shape() const { return mma_mat_shape_.at(tensor_core_type_); } - int get_vec_a() const { return mma_instr_vec_.at(tensor_core_type_); } - int get_vec_b() const { return mma_instr_vec_.at(tensor_core_type_); } - - // setter - void set_tensor_core_type(TensorCoreType type) { tensor_core_type_ = type; } - -private: - // fragment per warp - std::vector fpw_; - // shape per warp - std::vector spw_; - // warp per tile - std::vector wpt_; - // shape per tile - std::vector spt_; - // repetitions - std::vector rep_; - // contiguous per thread - std::vector contig_per_thread_; - - TensorCoreType tensor_core_type_ = FP32_FP16_FP16_FP32; -}; - -struct scanline_layout: public distributed_layout { - scanline_layout(size_t num_warps, - const std::vector& axes, - const std::vector& shape, - const std::vector &values, - analysis::align* align, - target* tgt); - void accept(layout_visitor* vst) { vst->visit_layout_scanline(this); } - // accessor - int mts(size_t k) { return mts_.at(k); } - int nts(size_t k) { return nts_.at(k); } - int contig_per_thread(size_t k) { return nts_.at(k); } - -public: - // micro tile size. The size of a tile held by a thread block. - std::vector mts_; - // nano tile size. The size of a tile held by a thread. - std::vector nts_; -}; - -struct double_buffer_info_t { - ir::value* first; - ir::value* latch; - ir::phi_node* phi; -}; - -struct N_buffer_info_t { - std::vector firsts; // not necessarily ordered as input order - ir::value* latch; - ir::phi_node* phi; - std::map firsts_idx; -}; - -// abstract for dot and coresponding smem values -class shared_layout: public data_layout { -private: - static bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator); - static void extract_double_bufferable(ir::value *v, std::shared_ptr& res); - static void extract_N_bufferable(ir::value *v, std::shared_ptr& res, int &prev_stages); - -public: - shared_layout(data_layout *arg, - const std::vector& axes, - const std::vector& shapes, - const std::vector &values_, - ir::type *ty, - analysis::align* align, target *tgt); - void accept(layout_visitor* vst) { vst->visit_layout_shared(this); } - // accessors - size_t get_size() { return size_; } - ir::type* get_type() { return ty_; } - double_buffer_info_t* get_double_buffer() { return double_buffer_.get(); } - N_buffer_info_t* get_N_buffer() { return N_buffer_.get(); } - int get_num_stages() const; - size_t get_per_stage_size() const { return size_ / get_num_stages(); } - size_t get_per_stage_elements() const; - size_t get_num_per_phase() { return num_per_phase_; } - ir::value* hmma_dot_a() { return hmma_dot_a_; } - ir::value* hmma_dot_b() { return hmma_dot_b_; } - void set_mma_vec(int mma_vec) { mma_vec_ = mma_vec; } - int get_mma_vec() { return mma_vec_;} - int get_mma_strided() { return mma_strided_; } - bool allow_swizzle() const { return allow_swizzle_; } - data_layout* get_arg_layout() { return arg_layout_; } - -private: - size_t size_; - ir::type *ty_; - std::shared_ptr double_buffer_; - std::shared_ptr N_buffer_; - size_t num_per_phase_; - ir::value* hmma_dot_a_; - ir::value* hmma_dot_b_; - data_layout* arg_layout_; - int mma_vec_; - int mma_strided_; - bool allow_swizzle_ = true; - target *tgt_; -}; - - - -class layouts { - typedef ir::value* node_t; - typedef std::map > graph_t; - -private: - // graph creation - void connect(ir::value *x, ir::value *y); - void make_graph(ir::instruction *i); - - void init_hmma_tile(data_layout& layouts); - void 
init_scanline_tile(data_layout &layouts); - - void create(size_t id, const std::vector& values); - -public: - // constructor - layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt); - - // accessors - unsigned layout_of(ir::value *value) const { return groups_.at(value); } - bool has(ir::value* value) const { return groups_.find(value) != groups_.end(); } - const std::vector& values_of(unsigned id) const { return values_.at(id); } - size_t num_layouts() const { return values_.size();} - data_layout* get(size_t id) { return layouts_.at(id); } - data_layout* get(ir::value *v) { return get(layout_of(v));} - std::map &get_all() { return layouts_; } - bool has_tmp(ir::value* i) { return tmp_.find(i) != tmp_.end(); } - int tmp(ir::value* i) { return tmp_.at(i);} - void copy(ir::value* dst, ir::value* src) { groups_[dst] = groups_[src]; } - // execution - void run(ir::module &mod); - -private: - analysis::axes* axes_; - analysis::align* align_; - size_t num_warps_; - target* tgt_; - tools::graph graph_; - std::map groups_; - std::map> values_; - std::map layouts_; - std::map tmp_; -}; - -} -} - -} - -#endif diff --git a/include/triton/codegen/analysis/liveness.h b/include/triton/codegen/analysis/liveness.h deleted file mode 100644 index a95d62a065cc..000000000000 --- a/include/triton/codegen/analysis/liveness.h +++ /dev/null @@ -1,67 +0,0 @@ -#ifndef TDL_INCLUDE_IR_CODEGEN_LIVENESS_H -#define TDL_INCLUDE_IR_CODEGEN_LIVENESS_H - -#include -#include -#include -#include "triton/codegen/analysis/layout.h" -#include "triton/tools/graph.h" - -namespace triton{ - -namespace ir{ - class value; - class phi_node; - class function; - class module; - class instruction; -} - -namespace codegen{ -namespace analysis{ - -typedef unsigned slot_index; - -class tiles; -class layouts; -class data_layout; - -struct segment { - slot_index start; - slot_index end; - - bool contains(slot_index idx) const { - return start <= idx && idx < end; - } - - bool intersect(const segment &Other){ - return contains(Other.start) || Other.contains(start); - } -}; - - -class liveness { -private: - typedef std::map intervals_map_t; - -public: - // constructor - liveness(layouts *l): layouts_(l){ } - // accessors - const intervals_map_t& get() const { return intervals_; } - segment get(shared_layout* v) const { return intervals_.at(v); } - // run - void run(ir::module &mod); - -private: - // analysis - layouts *layouts_; - intervals_map_t intervals_; -}; - -} -} -} - - -#endif diff --git a/include/triton/codegen/analysis/swizzle.h b/include/triton/codegen/analysis/swizzle.h deleted file mode 100644 index 6f2833a6851b..000000000000 --- a/include/triton/codegen/analysis/swizzle.h +++ /dev/null @@ -1,43 +0,0 @@ -#ifndef TRITON_INCLUDE_IR_CODEGEN_SWIZZLE_H -#define TRITON_INCLUDE_IR_CODEGEN_SWIZZLE_H - -#include - -namespace triton{ - -namespace ir{ - class module; -} - -namespace codegen{ -class target; - -namespace analysis{ - -class layouts; -class data_layout; - -class swizzle { -public: - // constructor - swizzle(layouts *l, target* tgt): layouts_(l), tgt_(tgt){ } - // accessors - int get_per_phase(data_layout* layout) { return per_phase_.at(layout); } - int get_max_phase(data_layout* layout) { return max_phase_.at(layout); } - int get_vec (data_layout* layout) { return vec_.at(layout); } - // run - void run(ir::module &mod); -private: - layouts* layouts_; - target* tgt_; - std::map per_phase_; - std::map max_phase_; - std::map vec_; -}; - -} -} -} - - -#endif diff --git a/include/triton/codegen/pass.h 
b/include/triton/codegen/pass.h deleted file mode 100644 index 0c8f1131593a..000000000000 --- a/include/triton/codegen/pass.h +++ /dev/null @@ -1,42 +0,0 @@ -#ifndef _TRITON_CODEGEN_PASS_H_ -#define _TRITON_CODEGEN_PASS_H_ - - -#include - -namespace llvm{ - class Module; - class LLVMContext; -} - -namespace triton{ - -namespace codegen { - class target; -} - -namespace ir{ - class module; -} -namespace driver{ - class device; - class module; - class kernel; -} -} - -namespace triton{ -namespace codegen{ - -// TODO: -// There should be a proper pass manager there! -std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx, - codegen::target* target, - int sm, int num_warps, - int num_stages, int &shared_static); - - -} -} - -#endif diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h deleted file mode 100644 index ad7d01a55435..000000000000 --- a/include/triton/codegen/selection/generator.h +++ /dev/null @@ -1,258 +0,0 @@ -#pragma once - -#ifndef _TRITON_SELECTION_GENERATOR_H_ -#define _TRITON_SELECTION_GENERATOR_H_ - -#include "triton/ir/visitor.h" -#include "triton/codegen/analysis/layout.h" -#include - -// forward -namespace llvm{ - class Type; - class Value; - class PHINode; - class BasicBlock; - class Attribute; - class Instruction; - class Constant; - class LLVMContext; - class Module; - class ConstantFolder; - class IRBuilderDefaultInserter; - template - class IRBuilder; - class ArrayType; - class Function; -} - -namespace triton{ - -namespace ir{ -class attribute; -class load_inst; -class store_inst; -} - -namespace codegen{ - -// forward -namespace analysis{ -class liveness; -class tiles; -class align; -class allocation; -class cts; -class axes; -class layouts; -class swizzle; -} -// typedef -typedef llvm::IRBuilder Builder; -typedef llvm::LLVMContext LLVMContext; -typedef llvm::Type Type; -typedef llvm::Value Value; -typedef llvm::Attribute Attribute; -typedef llvm::BasicBlock BasicBlock; -typedef llvm::Module Module; -typedef llvm::Instruction Instruction; -typedef llvm::Constant Constant; -typedef llvm::ArrayType ArrayType; -typedef llvm::Function Function; -typedef std::vector indices_t; -class target; - -} -} - -namespace triton{ -namespace codegen{ - -struct distributed_axis { - int contiguous; - std::vector values; - Value* thread_id; -}; - -class adder{ -public: - adder(Builder** builder): builder_(builder) { } - Value* operator()(Value* x, Value* y, const std::string& name = ""); - -private: - Builder** builder_; -}; - -class multiplier{ -public: - multiplier(Builder** builder): builder_(builder) { } - Value* operator()(Value* x, Value* y, const std::string& name = ""); -private: - Builder** builder_; -}; - -class geper{ -public: - geper(Builder** builder): builder_(builder) { } - Value* operator()(Value *ptr, Value* off, const std::string& name = ""); - Value* operator()(Type* ty, Value*ptr, std::vector vals, const std::string& name = ""); - -private: - Builder** builder_; -}; - -class generator: public ir::visitor, public analysis::layout_visitor { -private: - void init_idx(ir::value *x); - Instruction* add_barrier(); - Value* shared_off(const std::vector& shapes, const std::vector& order, indices_t idx); - void finalize_shared_layout(analysis::shared_layout*); - void finalize_function(ir::function*); - void finalize_phi_node(ir::phi_node*); - -private: - Type *cvt(ir::type *ty); - llvm::Attribute cvt(ir::attribute attr); - -public: - generator(analysis::axes *a_axes, - analysis::layouts *layouts, - 
analysis::align *alignment, - analysis::allocation *alloc, - analysis::swizzle *swizzle, - target *tgt, - unsigned num_warps); - - void visit_value(ir::value* v); - void visit_phi_node(ir::phi_node*); - void visit_binary_operator(ir::binary_operator*); - void visit_getelementptr_inst(ir::getelementptr_inst*); - void visit_icmp_inst(ir::icmp_inst*); - void visit_fcmp_inst(ir::fcmp_inst*); - std::tuple fp8x4_to_fp32x4(Value *in0, Value *in1, Value *in2, Value *in3); - std::tuple fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3); - std::tuple fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3); - std::tuple fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3); - Value* bf16_to_fp32(Value *in0); - Value* fp32_to_bf16(Value *in0); - - void visit_cast_inst(ir::cast_inst*); - void visit_return_inst(ir::return_inst*); - void visit_cond_branch_inst(ir::cond_branch_inst*); - void visit_uncond_branch_inst(ir::uncond_branch_inst*); - void visit_load_inst(ir::load_inst*); - void visit_unmasked_load_inst(ir::unmasked_load_inst*); - void visit_masked_load_inst(ir::masked_load_inst*); - void visit_store_inst(ir::store_inst*); - void visit_unmasked_store_inst(ir::unmasked_store_inst*); - void visit_masked_store_inst(ir::masked_store_inst*); - void visit_cat_inst(ir::cat_inst*); - void visit_reshape_inst(ir::reshape_inst*); - void visit_splat_inst(ir::splat_inst*); - void visit_broadcast_inst(ir::broadcast_inst*); - void visit_downcast_inst(ir::downcast_inst*); - void visit_exp_inst(ir::exp_inst*); - void visit_cos_inst(ir::cos_inst*); - void visit_umulhi_inst(ir::umulhi_inst* x); - void visit_sin_inst(ir::sin_inst*); - void visit_log_inst(ir::log_inst*); - void visit_get_program_id_inst(ir::get_program_id_inst*); - void visit_get_num_programs_inst(ir::get_num_programs_inst*); - void visit_atomic_cas_inst(ir::atomic_cas_inst*); - void visit_atomic_rmw_inst(ir::atomic_rmw_inst*); - void visit_mma884(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK); - void visit_mma16816(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK); - void visit_fmadot(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK, Type *c_ty, Function *f_mul_add); - void visit_dot_inst(ir::dot_inst*); - void visit_trans_inst(ir::trans_inst*); - void visit_sqrt_inst(ir::sqrt_inst*); - Value* shfl_sync(Value* acc, int32_t i); - void visit_reduce1d_inst(ir::reduce_inst*, std::function, Value*); - void visit_reducend_inst(ir::reduce_inst*, std::function, Value*); - void visit_reduce_inst(ir::reduce_inst*); - void visit_select_inst(ir::select_inst*); - void visit_layout_convert(ir::value *out, ir::value *in); - void visit_cvt_layout_inst(ir::cvt_layout_inst*); - void visit_masked_load_async_inst(ir::masked_load_async_inst*); - void visit_copy_to_shared_inst(ir::copy_to_shared_inst*); - void visit_copy_from_shared_inst(ir::copy_from_shared_inst*); - void visit_barrier_inst(ir::barrier_inst*); - void visit_prefetch_s_inst(ir::prefetch_s_inst*); - void visit_async_wait_inst(ir::async_wait_inst*); -// void visit_make_range_dyn(ir::make_range_dyn*); - void visit_make_range(ir::make_range*); -// void visit_make_range_sta(ir::make_range_sta*); - void visit_undef_value(ir::undef_value*); - void visit_constant_int(ir::constant_int*); - void visit_constant_fp(ir::constant_fp*); - void visit_alloc_const(ir::alloc_const*); - void visit_function(ir::function*); - void visit_basic_block(ir::basic_block*); - void visit_argument(ir::argument*); - void visit(ir::module &, 
llvm::Module &); - - // layouts - void visit_layout_mma(analysis::mma_layout*); - void visit_layout_scanline(analysis::scanline_layout*); - void visit_layout_shared(analysis::shared_layout*); - - -private: - LLVMContext *ctx_; - Builder* builder_; - Module *mod_; - - analysis::axes *a_axes_; - analysis::swizzle *swizzle_; - std::map axes_; - target *tgt_; - analysis::layouts *layouts_; - analysis::align *alignment_; - analysis::allocation *alloc_; - Value *shmem_; - std::set seen_; - - unsigned num_warps_; - - std::map offset_a_m_; - std::map offset_a_k_; - std::map offset_b_k_; - std::map offset_b_n_; - - /// layout -> base ptr - std::map shared_ptr_; - std::map shared_pre_ptr_; - std::map shared_next_ptr_; - /// offset for double-buffered layout - std::map shared_off_; - - /// Base shmem pointer of ir value - std::map shmems_; - std::map shoffs_; - std::map> idxs_; - std::map> vals_; - /// idx for multi-stage pipeline - std::map read_smem_idx_; - std::map write_smem_idx_; - - /// triton bb -> llvm bb - std::map bbs_; - std::map> ords_; - - // helper for creating llvm values - adder add; - multiplier mul; - geper gep; - - /// PHI nodes - std::vector> lazy_phi_incs_; - - /// Record prefetch instrs that needs to be moved - std::map> prefetch_latch_to_bb_; -}; - -} -} - -#endif diff --git a/include/triton/codegen/target.h b/include/triton/codegen/target.h deleted file mode 100644 index 96e4d5c31dde..000000000000 --- a/include/triton/codegen/target.h +++ /dev/null @@ -1,105 +0,0 @@ -#ifndef TDL_INCLUDE_IR_CODEGEN_TARGET_H -#define TDL_INCLUDE_IR_CODEGEN_TARGET_H - -namespace llvm{ - class Type; - class Value; - class Instruction; - class Constant; - class LLVMContext; - class Module; - class ConstantFolder; - class IRBuilderDefaultInserter; - template - class IRBuilder; - class ArrayType; - class Function; -} - -// typedefs -namespace triton{ -namespace codegen{ - typedef llvm::IRBuilder Builder; - typedef llvm::LLVMContext LLVMContext; - typedef llvm::Type Type; - typedef llvm::Value Value; - typedef llvm::Module Module; - typedef llvm::Instruction Instruction; - typedef llvm::Constant Constant; - typedef llvm::ArrayType ArrayType; - typedef llvm::Function Function; -} -} - -namespace triton{ -namespace codegen{ - -class nvidia_cu_target; - -class target { -public: - target(bool is_gpu): is_gpu_(is_gpu){} - virtual ~target() {} - virtual void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn) = 0; - virtual Instruction* add_barrier(Module *module, Builder& builder) = 0; - virtual Instruction* add_memfence(Module *module, Builder& builder) = 0; - virtual Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax) = 0; - virtual Value* get_local_id(Module *module, Builder& builder, unsigned ax) = 0; - virtual Value* get_block_id(Module *module, Builder& builder, unsigned ax) = 0; - virtual Value* get_num_blocks(Module *module, Builder& builder, unsigned ax) = 0; - virtual unsigned guaranteed_alignment() = 0; - nvidia_cu_target* as_nvidia(); - bool is_gpu() const; - -private: - bool is_gpu_; -}; - -class amd_cl_target: public target { -public: - amd_cl_target(): target(true){} - void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn); - Instruction* add_barrier(Module *module, Builder& builder); - Instruction* add_memfence(Module *module, Builder& builder); - Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax); - Value* get_local_id(Module *module, Builder& builder, unsigned 
ax); - Value* get_block_id(Module *module, Builder& builder, unsigned ax); - Value* get_num_blocks(Module *module, Builder& builder, unsigned ax); - unsigned guaranteed_alignment() { return 16; } -}; - -class nvidia_cu_target: public target { -public: - nvidia_cu_target(int sm): target(true), sm_(sm){} - void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn); - Instruction* add_barrier(Module *module, Builder& builder); - Instruction* add_memfence(Module *module, Builder& builder); - Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax); - Value* get_local_id(Module *module, Builder& builder, unsigned ax); - Value* get_block_id(Module *module, Builder& builder, unsigned ax); - Value* get_num_blocks(Module *module, Builder& builder, unsigned ax); - int sm() { return sm_; } - unsigned guaranteed_alignment() { return 16; } - -private: - int sm_; -}; - -class cpu_target: public target { -public: - cpu_target(): target(false){} - void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn); - Instruction* add_barrier(Module *module, Builder& builder); - Instruction* add_memfence(Module *module, Builder& builder); - Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax); - Value* get_local_id(Module *module, Builder& builder, unsigned ax); - Value* get_block_id(Module *module, Builder& builder, unsigned ax); - Value* get_num_blocks(Module *module, Builder& builder, unsigned ax); - unsigned guaranteed_alignment() { return 1; } -}; - -} -} - -#endif diff --git a/include/triton/codegen/transform/coalesce.h b/include/triton/codegen/transform/coalesce.h deleted file mode 100644 index 869ca9975658..000000000000 --- a/include/triton/codegen/transform/coalesce.h +++ /dev/null @@ -1,48 +0,0 @@ -#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_REORDER_H -#define TDL_INCLUDE_CODEGEN_OPTIMIZE_REORDER_H - -#include -#include -#include - -namespace triton { - -namespace ir { - class module; - class value; - class io_inst; - class instruction; - class builder; -} - -namespace codegen{ - -namespace analysis{ - class align; - class layouts; - class cts; -} - -namespace transform{ - -class coalesce { -private: - void extract_io_use(ir::value *v, std::set& result); - void extract_ld(ir::io_inst *i, std::map > &result); - ir::value* rematerialize(ir::value *v, ir::builder& builder, std::map& seen); - -public: - coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts); - triton::ir::value *simplify(ir::instruction* i, triton::ir::builder &builder); - void run(ir::module &mod); - -private: - analysis::align* align_; - analysis::layouts* layout_; -}; - -} -} -} - -#endif diff --git a/include/triton/codegen/transform/cts.h b/include/triton/codegen/transform/cts.h deleted file mode 100644 index 70fbc474baf3..000000000000 --- a/include/triton/codegen/transform/cts.h +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef TDL_INCLUDE_CODEGEN_BUFFER_INFO_PASS_H -#define TDL_INCLUDE_CODEGEN_BUFFER_INFO_PASS_H - -#include -#include - -namespace triton { - -namespace ir { - class module; - class value; - class phi_node; - class instruction; - class builder; -} - -namespace codegen{ -namespace transform{ - -class cts { -private: - void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared); - -public: - cts(bool use_async = false): use_async_(use_async) {} - void run(ir::module &mod); - -private: - bool use_async_; -}; - -} -} -} - -#endif \ No newline at end of file diff --git 
a/include/triton/codegen/transform/dce.h b/include/triton/codegen/transform/dce.h deleted file mode 100644 index 8bed0afef4f6..000000000000 --- a/include/triton/codegen/transform/dce.h +++ /dev/null @@ -1,24 +0,0 @@ -#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_CSE_H -#define TDL_INCLUDE_CODEGEN_OPTIMIZE_CSE_H - - -namespace triton { - -namespace ir { - class module; -} - -namespace codegen{ -namespace transform{ - -class dce { -public: - dce() {} - void run(ir::module &mod); -}; - -} -} -} - -#endif diff --git a/include/triton/codegen/transform/disassociate.h b/include/triton/codegen/transform/disassociate.h deleted file mode 100644 index f2363f3fe2f6..000000000000 --- a/include/triton/codegen/transform/disassociate.h +++ /dev/null @@ -1,22 +0,0 @@ -#ifndef _TRITON_SELECTION_TRANSFORM_DISASSOCIATE_H_ -#define _TRITON_SELECTION_TRANSFORM_DISASSOCIATE_H_ - - -namespace triton { -namespace ir { - class module; -} - -namespace codegen{ -namespace transform{ - -class disassociate { -public: - void run(ir::module &mod); -}; - -} -} -} - -#endif diff --git a/include/triton/codegen/transform/membar.h b/include/triton/codegen/transform/membar.h deleted file mode 100644 index 21145a4fe769..000000000000 --- a/include/triton/codegen/transform/membar.h +++ /dev/null @@ -1,72 +0,0 @@ -#ifndef TDL_INCLUDE_CODEGEN_BARRIERS_H -#define TDL_INCLUDE_CODEGEN_BARRIERS_H - -#include -#include -#include -#include -#include "triton/codegen/target.h" - -namespace triton { - -namespace ir { - class module; - class basic_block; - class instruction; - class masked_load_async_inst; - class value; - class builder; -} - -namespace codegen{ - -namespace analysis{ - -class allocation; -class liveness; -class layouts; -class cts; -class shared_layout; - -} - -namespace transform{ - -class prefetch; - -class membar { -private: - typedef std::pair interval_t; - typedef std::set val_set_t; - typedef std::vector val_vec_t; - -private: - bool intersect(const val_set_t &X, const val_set_t &Y); - bool check_safe_war(ir::instruction* i); - int group_of(triton::ir::value *i, std::vector &async_write); - bool intersect_with(analysis::shared_layout* a_layout, analysis::shared_layout* b_layout); - val_set_t intersect_with(const val_set_t& as, const val_set_t& bs); - void transfer(ir::basic_block *block, val_vec_t &async_write, val_set_t &sync_write, val_set_t &sync_read, - std::set &safe_war, bool &inserted, ir::builder &builder); - -public: - membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc, - transform::prefetch *prefetch, target* tgt): - liveness_(liveness), layouts_(layouts), alloc_(alloc), prefetch_(prefetch), tgt_(tgt) {} - void run(ir::module &mod); - -private: - analysis::liveness *liveness_; - analysis::layouts *layouts_; - analysis::allocation *alloc_; - transform::prefetch *prefetch_; - - target* tgt_; -}; - - -} -} -} - -#endif diff --git a/include/triton/codegen/transform/peephole.h b/include/triton/codegen/transform/peephole.h deleted file mode 100644 index 0e1ed222e783..000000000000 --- a/include/triton/codegen/transform/peephole.h +++ /dev/null @@ -1,53 +0,0 @@ -#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_TRANS_H -#define TDL_INCLUDE_CODEGEN_OPTIMIZE_TRANS_H - -#include "triton/codegen/target.h" - -namespace triton { - -namespace ir { - class module; - class value; - class instruction; - class trans_inst; - class builder; - class constant_int; - class dot_inst; -} - -namespace codegen{ -namespace analysis{ -class layouts; -} - -namespace transform{ - -class peephole { -private: -// bool 
rewrite_cts_cfs(ir::instruction *value, ir::builder &builder); - bool rewrite_trans_phi(ir::instruction* value, ir::builder &builder); - bool rewrite_dot_fp32(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D); - bool rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D); - bool rewrite_dot(ir::instruction *value, ir::builder& builder); - bool rewrite_mult(ir::instruction *value, ir::builder& builder); - bool rewrite_unit_red(ir::instruction *value, ir::builder& builder); - bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder); - bool rewrite_select_masked_load(ir::instruction *value, ir::builder& builder); - bool rewrite_load_to_shared(ir::instruction *value, ir::builder& builder); - bool rewrite_cvt_layout(ir::instruction *value, ir::builder& builder); - -public: - peephole(target* tgt, analysis::layouts* layouts): tgt_(tgt), layouts_(layouts) {} - void run(ir::module &mod); - -private: - target* tgt_; - analysis::layouts* layouts_; -}; - - -} -} -} - -#endif diff --git a/include/triton/codegen/transform/pipeline.h b/include/triton/codegen/transform/pipeline.h deleted file mode 100644 index 35472de040f3..000000000000 --- a/include/triton/codegen/transform/pipeline.h +++ /dev/null @@ -1,30 +0,0 @@ -#ifndef TRITON_INCLUDE_IR_CODEGEN_PIPELINE_H -#define TRITON_INCLUDE_IR_CODEGEN_PIPELINE_H - -// forward declaration -namespace triton { -namespace ir { -class module; -} -} // namespace triton - -namespace triton { -namespace codegen { -namespace transform { - -class pipeline { -public: - pipeline(bool has_copy_async, int num_stages) - : has_copy_async_(has_copy_async), num_stages_(num_stages) {} - void run(ir::module &module); - -private: - bool has_copy_async_; - int num_stages_; -}; - -} // namespace transform -} // namespace codegen -} // namespace triton - -#endif diff --git a/include/triton/codegen/transform/prefetch.h b/include/triton/codegen/transform/prefetch.h deleted file mode 100644 index 6843b54633fc..000000000000 --- a/include/triton/codegen/transform/prefetch.h +++ /dev/null @@ -1,27 +0,0 @@ -#ifndef TRITON_INCLUDE_TRITON_CODEGEN_TRANSFORM_PREFETCH_H -#define TRITON_INCLUDE_TRITON_CODEGEN_TRANSFORM_PREFETCH_H - -#include - -// forward dclaration -namespace triton::ir{ -class module; -class value; -} - -namespace triton::codegen { -class target; -} - -namespace triton::codegen::transform { -class prefetch { - target* tgt_; - std::set prefetched_vals_; -public: - prefetch(target *tgt) : tgt_(tgt) {} - void run(ir::module &module); - bool is_prefetched(ir::value* v) { return prefetched_vals_.find(v) != prefetched_vals_.end(); } -}; -} - -#endif \ No newline at end of file diff --git a/include/triton/codegen/transform/reorder.h b/include/triton/codegen/transform/reorder.h deleted file mode 100644 index 3b48a330ff5c..000000000000 --- a/include/triton/codegen/transform/reorder.h +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef TRITON_INCLUDE_IR_CODEGEN_REORDER_H -#define TRITON_INCLUDE_IR_CODEGEN_REORDER_H - -namespace triton { - -// forward declaration -namespace ir { -class module; -} - -namespace codegen{ - -namespace transform{ - -class reorder { -public: - void run(ir::module& module); -}; - -} - -} - -} - -#endif diff --git a/lib/codegen/CMakeLists.txt b/lib/codegen/CMakeLists.txt deleted file mode 100644 index a91806250732..000000000000 --- a/lib/codegen/CMakeLists.txt +++ /dev/null @@ -1,5 +0,0 @@ -file(GLOB_RECURSE 
CODEGEN_SRC *.cc) - -add_library(TritonCodeGen - ${CODEGEN_SRC} -) diff --git a/lib/codegen/analysis/align.cc b/lib/codegen/analysis/align.cc deleted file mode 100644 index e92d3b6ee73c..000000000000 --- a/lib/codegen/analysis/align.cc +++ /dev/null @@ -1,533 +0,0 @@ -#include "triton/codegen/analysis/align.h" -#include "triton/ir/utils.h" -#include "triton/ir/module.h" -#include "triton/ir/function.h" -#include "triton/ir/basic_block.h" -#include "triton/ir/instructions.h" -#include "triton/ir/type.h" -#include - -namespace triton { -namespace codegen{ -namespace analysis{ - - -// Function for extended Euclidean Algorithm -int gcd_impl(int a, int b, int *x, int *y) -{ - // Base Case - if (a == 0) - { - *x = 0; - *y = 1; - return b; - } - - int x1, y1; // To store results of recursive call - int gcd = gcd_impl(b%a, a, &x1, &y1); - - // Update x and y using results of - // recursive call - *x = y1 - (b/a) * x1; - *y = x1; - - return gcd; -} - -int gcd(int a, int b) { - int x, y; - return gcd_impl(a, b, &x, &y); -} - - -inline int lcm(int a, int b) { - return (a * b) / gcd(a, b); -} - -template -inline T add_to_cache(ir::value *i, T value, std::map &map) { - return map[i] = value; -} - -/* - * is constant - */ - -std::vector align::get_shapes(ir::value *v) { - ir::type *ty = v->get_type(); - if(ty->is_block_ty()) - return ty->get_block_shapes(); - else - return {1}; -} - -std::vector align::populate_is_constant_phi(ir::phi_node* x) { - auto shapes = get_shapes(x); - std::vector result(shapes.size(), cst_info{1, 0}); - for(unsigned n = 0; n < x->get_num_incoming(); n++){ - ir::value* inc = x->get_incoming_value(n); - auto it = is_constant_.find(inc); - if(it != is_constant_.end()) - result = it->second; - } - return add_to_cache(x, result, is_constant_); - // recurse - for(unsigned n = 0; n < x->get_num_incoming(); n++){ - ir::value* inc = x->get_incoming_value(n); - auto cst = populate_is_constant(inc); - for(size_t d = 0; d < cst.size(); d++) - result[d].num_cst = std::min(result[d].num_cst, cst[d].num_cst); - } - return add_to_cache(x, result, is_constant_); -} - -std::vector align::populate_is_constant_splat(ir::splat_inst* x) { - auto shapes = get_shapes(x); - ir::value* op = x->get_operand(0); - std::vector result; - auto op_cst = populate_is_constant(op); - for(auto d: shapes) - result.push_back(cst_info{d, op_cst[0].value}); - return add_to_cache(x, result, is_constant_); -} - -std::vector align::populate_is_constant_reshape(ir::reshape_inst* x) { - auto x_shapes = get_shapes(x); - std::vector result; - ir::value *op = x->get_operand(0); - auto op_shapes = op->get_type()->get_block_shapes(); - auto op_cst = populate_is_constant(op); - unsigned current = 0; - bool is_skewed = false; - for(size_t d = 0; d < x_shapes.size(); d ++){ - cst_info ax ; - if(x_shapes[d] == 1) - ax = {1, op_cst[current].value}; - else if(!is_skewed - && x_shapes[d] == op_shapes[current]) - ax = {x_shapes[d], op_cst[current++].value}; - else { - is_skewed = true; - ax = {x_shapes[d], 0}; - } - result.push_back(ax); - } - return add_to_cache(x, result, is_constant_); -} - -std::vector align::populate_is_constant_broadcast(ir::broadcast_inst* x) { - auto x_shapes = get_shapes(x); - std::vector result; - ir::value *op = x->get_operand(0); - auto op_shapes = op->get_type()->get_block_shapes(); - auto op_cst = populate_is_constant(op); - for(size_t d = 0; d < x_shapes.size(); d++) - if(op_shapes[d] == 1) - result.push_back(cst_info{x_shapes[d], op_cst[d].value}); - else - result.push_back(op_cst[d]); - return 
add_to_cache(x, result, is_constant_); -} - -std::vector align::populate_is_constant_binop(ir::binary_operator* x) { - auto x_shapes = get_shapes(x); - std::vector result; - ir::value* lhs_op = x->get_operand(0); - ir::value* rhs_op = x->get_operand(1); - auto lhs = populate_is_constant(lhs_op); - auto rhs = populate_is_constant(rhs_op); - auto max_contiguous = populate_max_contiguous(lhs_op); - for(size_t d = 0; d < x_shapes.size(); d++) { - cst_info ax; - if(lhs[d].num_cst==0 && rhs[d].value && x->is_int_div()){ - // todo might not be entirely true - unsigned num_constants = gcd(max_contiguous[d], rhs[d].value); - ax = {num_constants, 0}; - } - else - ax = {std::min(lhs[d].num_cst, rhs[d].num_cst), 0}; - result.push_back(ax); - } - return add_to_cache(x, result, is_constant_); -} - -std::vector align::populate_is_constant_gep(ir::getelementptr_inst* x) { - auto x_shapes = get_shapes(x); - ir::value* lhs_op = x->get_operand(0); - ir::value* rhs_op = x->get_operand(1); - auto lhs = populate_is_constant(lhs_op); - auto rhs = populate_is_constant(rhs_op); - std::vector result; - for(size_t d = 0; d < x_shapes.size(); d++) - result.push_back({std::min(lhs[d].num_cst, rhs[d].num_cst), 0}); - return add_to_cache(x, result, is_constant_); -} - -std::vector align::populate_is_constant_default(ir::value *v) { - auto shapes = get_shapes(v); - std::vector result(shapes.size(), {1, 0}); - return add_to_cache(v, result, is_constant_); -} - -std::vector align::populate_is_constant(ir::value *v) { - if(is_constant_.find(v) != is_constant_.end()) - return is_constant_.at(v); - if(auto *x = dynamic_cast(v)) - return add_to_cache(v, {cst_info{true, std::min(x->get_value(), 128)}}, is_constant_); - if(auto *x = dynamic_cast(v)) - return populate_is_constant_phi(x); - if(auto *x = dynamic_cast(v)) - return populate_is_constant_splat(x); - if(auto *x = dynamic_cast(v)) - return populate_is_constant_reshape(x); - if(auto *x = dynamic_cast(v)) - return populate_is_constant_broadcast(x); - if(auto *x = dynamic_cast(v)) - return populate_is_constant_binop(x); - if(auto *x = dynamic_cast(v)) - return populate_is_constant_gep(x); - return populate_is_constant_default(v); -} - - -/* - * max contiguous - */ - -std::vector align::populate_max_contiguous_phi(ir::phi_node* x) { - auto shapes = get_shapes(x); - std::vector result(shapes.size(), 1); - for(unsigned n = 0; n < x->get_num_incoming(); n++){ - ir::value* inc = x->get_incoming_value(n); - auto it = max_contiguous_.find(inc); - if(it != max_contiguous_.end()) - result = it->second; - } - add_to_cache(x, result, max_contiguous_); - // recurse - for(unsigned n = 0; n < x->get_num_incoming(); n++){ - ir::value* inc = x->get_incoming_value(n); - auto contiguous = populate_max_contiguous(inc); - for(size_t d = 0; d < result.size(); d++) - result[d] = std::min(result[d], contiguous[d]); - } - return add_to_cache(x, result, max_contiguous_); - -} - -std::vector align::populate_max_contiguous_splat(ir::splat_inst* x) { - auto x_shapes = get_shapes(x); - std::vector result; - for(size_t d = 0; d < x_shapes.size(); d++) - result.push_back({1}); - return add_to_cache(x, result, max_contiguous_); -} - -std::vector align::populate_max_contiguous_reshape(ir::reshape_inst* x) { - auto shapes = get_shapes(x); - std::vector result; - ir::value *op = x->get_operand(0); - auto op_shapes = op->get_type()->get_block_shapes(); - auto op_mc = populate_max_contiguous(op); - unsigned current = 0; - bool is_skewed = false; - for(size_t d = 0; d < shapes.size(); d ++){ - if(shapes[d] == 
1) - result.push_back(1); - else if(!is_skewed - && shapes[d] == op_shapes[current]) - result.push_back(op_mc[current++]); - else { - is_skewed = true; - result.push_back(1); - } - } - return add_to_cache(x, result, max_contiguous_); -} - -std::vector align::populate_max_contiguous_broadcast(ir::broadcast_inst* x) { - auto shapes = get_shapes(x); - std::vector result; - ir::value *op = x->get_operand(0); - auto op_shapes = op->get_type()->get_block_shapes(); - auto op_mc = populate_max_contiguous(op); - for(size_t d = 0; d < shapes.size(); d++) - if(op_shapes[d] == 1) - result.push_back(1); - else - result.push_back(op_mc[d]); - return add_to_cache(x, result, max_contiguous_); -} - -std::vector align::populate_max_contiguous_binop(ir::binary_operator* x) { - auto shapes = get_shapes(x); - ir::value* lhs = x->get_operand(0); - ir::value* rhs = x->get_operand(1); - auto lhs_max_contiguous = populate_max_contiguous(lhs); - auto rhs_max_contiguous = populate_max_contiguous(rhs); - auto lhs_cst_info = populate_is_constant(lhs); - auto rhs_cst_info = populate_is_constant(rhs); - auto lhs_starting_multiple = populate_starting_multiple(lhs); - auto rhs_starting_multiple = populate_starting_multiple(rhs); - std::vector result; - for(size_t d = 0; d < shapes.size(); d++){ - unsigned value = 1; - if(x->is_int_rem() && rhs_starting_multiple[d] > 0){ - value = std::min(lhs_max_contiguous[d], rhs_starting_multiple[d]); - } - if(x->is_int_mult()){ - unsigned lvalue = 1, rvalue = 1; - if(rhs_cst_info[d].value == 1) - lvalue = lhs_max_contiguous[d]; - if(lhs_cst_info[d].value == 1) - rvalue = rhs_max_contiguous[d]; - value = std::max(lvalue, rvalue); - } - if(x->is_int_add_sub()){ - unsigned lvalue = 1, rvalue = 1; - lvalue = gcd(rhs_max_contiguous[d], lhs_starting_multiple[d]); - rvalue = gcd(lhs_max_contiguous[d], rhs_starting_multiple[d]); - value = std::max(lvalue, rvalue); - } - result.push_back(value); - } - return add_to_cache(x, result, max_contiguous_); -} - -std::vector align::populate_max_contiguous_gep(ir::getelementptr_inst* x) { - auto shapes = get_shapes(x); - ir::value* lhs = x->get_operand(0); - ir::value* rhs = x->get_operand(1); - auto lhs_max_contiguous = populate_max_contiguous(lhs); - auto rhs_max_contiguous = populate_max_contiguous(rhs); - auto lhs_cst_info = populate_is_constant(lhs); - auto rhs_cst_info = populate_is_constant(rhs); - std::vector result(shapes.size(), 1); - for(size_t d = 0; d < shapes.size(); d++){ - unsigned lvalue = 1, rvalue = 1; - if(lhs_cst_info[d].num_cst) - lvalue = rhs_max_contiguous[d]; - if(rhs_cst_info[d].num_cst) - rvalue = lhs_max_contiguous[d]; - result[d] = std::max(lvalue, rvalue); - } - return add_to_cache(x, result, max_contiguous_); -} - -std::vector align::populate_max_contiguous_default(ir::value* v) { - if(!v->get_type()->is_block_ty()) - return add_to_cache(v, {1}, max_contiguous_); - auto shapes = v->get_type()->get_block_shapes(); - if(dynamic_cast(v)) - return add_to_cache(v, {shapes[0]}, max_contiguous_); - return add_to_cache(v, std::vector(shapes.size(), 1), max_contiguous_); -} - -std::vector align::populate_max_contiguous_cast(ir::cast_inst* v){ - auto result = populate_max_contiguous(v->get_operand(0)); - return add_to_cache(v, result, max_contiguous_); -} - -std::vector align::populate_max_contiguous(ir::value *v){ - if(max_contiguous_.find(v) != max_contiguous_.end()) - return max_contiguous_.at(v); - if(auto *x = dynamic_cast(v)){ - unsigned max_contiguous = x->get_metadata(ir::metadata::max_contiguous); - if(max_contiguous > 0) 
- return add_to_cache(x, {max_contiguous}, max_contiguous_); - } - if(auto *x = dynamic_cast(v)) - return populate_max_contiguous_cast(x); - if(auto *x = dynamic_cast(v)) - return populate_max_contiguous_splat(x); - if(auto *x = dynamic_cast(v)) - return populate_max_contiguous_reshape(x); - if(auto *x = dynamic_cast(v)) - return populate_max_contiguous_broadcast(x); - if(auto *x = dynamic_cast(v)) - return populate_max_contiguous_binop(x); - if(auto *x = dynamic_cast(v)) - return populate_max_contiguous_gep(x); - if(auto *x = dynamic_cast(v)) - return populate_max_contiguous_phi(x); - return populate_max_contiguous_default(v); -} - - -/* - * starting multiple - */ - -std::vector align::populate_starting_multiple_splat(ir::splat_inst* x){ - auto shapes = get_shapes(x); - auto op = populate_starting_multiple(x->get_operand(0)); - std::vector result(shapes.size(), op[0]); - return add_to_cache(x, result, starting_multiple_); -} - -std::vector align::populate_starting_multiple_reshape(ir::reshape_inst* x){ - auto op = populate_starting_multiple(x->get_operand(0)); - auto op_shapes = get_shapes(x->get_operand(0)); - auto shapes = get_shapes(x); - std::vector result(shapes.size(), 1); - unsigned current = 0; - bool is_skewed = false; - for(size_t d = 0; d < shapes.size(); d ++){ - if(shapes[d] == 1) - result[d] = 1; - else if(!is_skewed - && shapes[d] == op_shapes[current]) - result[d] = op[current++]; - else { - is_skewed = true; - result[d] = 1; - } - } - return add_to_cache(x, result, starting_multiple_); -} - -std::vector align::populate_starting_multiple_broadcast(ir::broadcast_inst* x){ - auto result = populate_starting_multiple(x->get_operand(0)); - return add_to_cache(x, result, starting_multiple_); -} - -std::vector align::populate_starting_multiple_binop(ir::binary_operator* x){ - auto lhs = populate_starting_multiple(x->get_operand(0)); - auto rhs = populate_starting_multiple(x->get_operand(1)); - std::vector result(lhs.size(), 1); - for(size_t d = 0; d < lhs.size(); d++){ - if(x->is_int_mult()) - result[d] = lhs[d] * rhs[d]; - if(x->is_int_add_sub()) - result[d] = gcd(lhs[d], rhs[d]); - if(x->is_int_div()) - result[d] = 1; - if(x->is_int_rem() && rhs[d] > 1){ - result[d] = gcd(lhs[d], rhs[d]); - } - if(x->is_shl()) - result[d] = lhs[d] << rhs[d]; - if(x->is_shr()) - result[d] = std::max(lhs[d] >> rhs[d], 1); - } - return add_to_cache(x, result, starting_multiple_); -} - -std::vector align::populate_starting_multiple_gep(ir::getelementptr_inst* x){ - auto lhs = populate_starting_multiple(x->get_operand(0)); - auto rhs = populate_starting_multiple(x->get_operand(1)); - std::vector result(lhs.size(), 1); - for(size_t d = 0; d < lhs.size(); d++){ - result[d] = gcd(lhs[d], rhs[d]); -// std::cout << "starting multiple: " << x->get_name() << " " << d << " " << result[d] << std::endl; - } - return add_to_cache(x, result, starting_multiple_); -} - -std::vector align::populate_starting_multiple_phi(ir::phi_node* x){ - auto shape = get_shapes(x); - std::vector result(shape.size(), 1); - for(unsigned n = 0; n < x->get_num_incoming(); n++){ - ir::value* inc = x->get_incoming_value(n); - if(starting_multiple_.find(inc) != starting_multiple_.end()) - result = starting_multiple_.at(inc); - } - add_to_cache(x, result, starting_multiple_); - // recurse - for(unsigned n = 0; n < x->get_num_incoming(); n++){ - ir::value* inc = x->get_incoming_value(n); - auto sm = populate_starting_multiple(inc); - for(size_t d = 0; d < result.size(); d++) - result[d] = gcd(result[d], sm[d]); - } - return 
add_to_cache(x, result, starting_multiple_); -} - - -std::vector align::populate_starting_multiple_cast(ir::cast_inst* x){ - auto result = populate_starting_multiple(x->get_operand(0)); - return add_to_cache(x, result, starting_multiple_); -} - -std::vector align::populate_starting_multiple_default(ir::value* v) { - ir::type* ty = v->get_type(); - if(ty->is_block_ty()) { - return add_to_cache(v, ty->get_block_shapes(), starting_multiple_); - } - if(auto *x = dynamic_cast(v)){ - std::set attributes = x->get_parent()->get_attributes(x); - for(auto attr: attributes){ - if(attr.get_kind() == ir::multiple_of){ - return add_to_cache(x, {attr.get_value()}, starting_multiple_); - } - if(attr.get_kind() == ir::aligned){ - ir::type* ty = x->get_type()->get_pointer_element_ty(); - int nbits = ty->get_primitive_size_in_bits(); - int nbytes = std::max(nbits / 8, 1); - return add_to_cache(x, {attr.get_value() / nbytes}, starting_multiple_); - } - } - } - return add_to_cache(v, {1}, starting_multiple_); -} - -std::vector align::populate_starting_multiple(ir::value *v){ - if(starting_multiple_.find(v) != starting_multiple_.end()) - return starting_multiple_.at(v); - if(auto *x = dynamic_cast(v)){ - unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of); - if(multiple_of > 0) - return add_to_cache(x, {multiple_of}, starting_multiple_); - } - if(auto *x = dynamic_cast(v)) - return populate_starting_multiple_cast(x); - if(auto *x = dynamic_cast(v)) - return populate_starting_multiple_binop(x); - if(auto *x = dynamic_cast(v)) - return add_to_cache(x, {std::min(x->get_value(), 128)}, starting_multiple_); - if(auto *x = dynamic_cast(v)) - return add_to_cache(x, {(unsigned)x->get_first()->get_value()}, starting_multiple_); - if(auto *x = dynamic_cast(v)) - return populate_starting_multiple_gep(x); - if(auto *x = dynamic_cast(v)) - return populate_starting_multiple_splat(x); - if(auto *x = dynamic_cast(v)) - return populate_starting_multiple_reshape(x); - if(auto *x = dynamic_cast(v)) - return populate_starting_multiple_broadcast(x); - if(auto *x = dynamic_cast(v)) - return populate_starting_multiple_phi(x); - return populate_starting_multiple_default(v); -} - - -unsigned align::get(ir::value *v, unsigned ax) const { - unsigned starting_multiple = starting_multiple_.at(v)[ax]; - unsigned max_contiguous = max_contiguous_.at(v)[ax]; - return std::min(starting_multiple, max_contiguous); -} - -std::vector align::contiguous(ir::value* v) const { - return max_contiguous_.at(v); -} - - -void align::populate(ir::value *v) { - populate_is_constant(v); - populate_starting_multiple(v); - populate_max_contiguous(v); - -} - -void align::run(ir::module &mod) { - ir::for_each_value(mod, [this](ir::value* v) { populate(v); } ); -// ir::for_each_value(mod, [this](ir::value* v) { -// if(dynamic_cast(v) || dynamic_cast(v)) -// std::cout << "ALIGN: " << v->get_name() << " " << max_contiguous_.at(v)[0] << " " << max_contiguous_.at(v)[1] << std::endl; -// }); -} - - -} -} -} diff --git a/lib/codegen/analysis/allocation.cc b/lib/codegen/analysis/allocation.cc deleted file mode 100644 index 3af40c2cc7d0..000000000000 --- a/lib/codegen/analysis/allocation.cc +++ /dev/null @@ -1,101 +0,0 @@ -#include -#include -#include "triton/codegen/analysis/layout.h" -#include "triton/codegen/analysis/allocation.h" -#include "triton/codegen/analysis/liveness.h" -#include "triton/ir/utils.h" - -namespace triton{ -namespace codegen{ -namespace analysis{ - - -void allocation::run(ir::module &mod) { - using std::max; - using std::min; - typedef 
std::multimap triples_map_type; - - std::vector I; - for(auto x: liveness_->get()) - I.push_back(x.first); - std::vector J = I; - - triples_map_type H; - H.insert({0, segment{0, INT_MAX}}); - - std::vector V; - std::map starts; - while(!J.empty()){ - auto h_it = H.begin(); - unsigned w = h_it->first; - segment xh = h_it->second; - H.erase(h_it); - auto j_it = std::find_if(J.begin(), J.end(), [&](shared_layout* JJ){ - segment xj = liveness_->get(JJ); - bool res = xj.intersect(xh); - for(auto val: H) - res = res && !val.second.intersect(xj); - return res; - }); - if(j_it != J.end()){ - unsigned size = (*j_it)->get_size(); - segment xj = liveness_->get(*j_it); - starts[*j_it] = w; - H.insert({w + size, segment{max(xh.start, xj.start), min(xh.end, xj.end)}}); - if(xh.start < xj.start) - H.insert({w, segment{xh.start, xj.end}}); - if(xj.end < xh.end) - H.insert({w, segment{xj.start, xh.end}}); - V.push_back(*j_it); - J.erase(j_it); - } - } - // Build interference graph - std::map> interferences; - for(shared_layout* x: V) - for(shared_layout* y: V){ - if(x == y) - continue; - unsigned X0 = starts[x], Y0 = starts[y]; - unsigned NX = x->get_size(); - unsigned NY = y->get_size(); - segment XS = {X0, X0 + NX}; - segment YS = {Y0, Y0 + NY}; - if(liveness_->get(x).intersect(liveness_->get(y)) - && XS.intersect(YS)) - interferences[x].insert(y); - } - // Initialize colors - std::map colors; - for(shared_layout* X: V) - colors[X] = (X==V[0])?0:-1; - // First-fit graph coloring - std::vector available(V.size()); - for(shared_layout* x: V){ - // Non-neighboring colors are available - std::fill(available.begin(), available.end(), true); - for(shared_layout* Y: interferences[x]){ - int color = colors[Y]; - if(color >= 0) - available[color] = false; - } - // Assigns first available color - auto It = std::find(available.begin(), available.end(), true); - colors[x] = std::distance(available.begin(), It); - } - // Finalize allocation - for(shared_layout* x: V){ - unsigned Adj = 0; - for(shared_layout* y: interferences[x]) - Adj = std::max(Adj, starts[y] + y->get_size()); - offsets_[x] = starts[x] + colors[x] * Adj; - } - // Save maximum size of induced memory space - allocated_size_ = 0; - for(shared_layout* x: V) - allocated_size_ = std::max(allocated_size_, starts[x] + x->get_size()); -} - -} -} -} diff --git a/lib/codegen/analysis/axes.cc b/lib/codegen/analysis/axes.cc deleted file mode 100644 index f079d2580ff7..000000000000 --- a/lib/codegen/analysis/axes.cc +++ /dev/null @@ -1,162 +0,0 @@ -#include "triton/codegen/analysis/axes.h" -#include "triton/ir/utils.h" -#include "triton/ir/instructions.h" -#include "triton/ir/type.h" -#include - - -namespace triton{ -namespace codegen{ -namespace analysis{ - -axes::axes() {} - -void axes::update_graph_reduce(ir::instruction *i) { - auto* red = static_cast(i); - unsigned axis = red->get_axis(); - ir::value *arg = red->get_operand(0); - auto in_shapes = arg->get_type()->get_block_shapes(); - unsigned current = 0; - for(unsigned d = 0; d < in_shapes.size(); d++){ - if(d == axis) - continue; - graph_.add_edge({i, current++}, {arg, d}); - } -} - -void axes::update_graph_reshape(ir::instruction *i) { - auto* reshape = static_cast(i); - // operands - ir::value *op = reshape->get_operand(0); - // shapes - auto op_shapes = op->get_type()->get_block_shapes(); - auto res_shapes = reshape->get_type()->get_block_shapes(); - // construct edges - unsigned current = 0; - bool is_skewed = false; - for(unsigned d = 0; d < res_shapes.size(); d ++){ - bool same_shape = res_shapes[d] 
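[Note] The allocation pass removed above first packs shared-memory buffers with a best-fit style scan and then resolves the remaining overlaps by first-fit colouring of an interference graph. A minimal sketch of that colouring step follows; it is not part of the patch, the toy graph and names are invented, and only the first-fit strategy mirrors the deleted code.

#include <cstdio>
#include <map>
#include <set>
#include <vector>

// First-fit colouring: each buffer gets the smallest colour not used by any
// buffer it interferes with (overlapping live range and overlapping offsets).
std::map<int, int> first_fit_colours(const std::map<int, std::set<int>>& interf) {
  std::map<int, int> colour;
  for (const auto& [node, neighbours] : interf) {
    std::vector<bool> taken(interf.size(), false);
    for (int n : neighbours) {
      auto it = colour.find(n);
      if (it != colour.end())
        taken[it->second] = true;
    }
    int c = 0;
    while (taken[c]) ++c;          // first colour still available
    colour[node] = c;
  }
  return colour;
}

int main() {
  // Buffers 0-1 and 1-2 interfere; 0 and 2 can share an offset.
  std::map<int, std::set<int>> interf = {{0, {1}}, {1, {0, 2}}, {2, {1}}};
  for (const auto& [buf, c] : first_fit_colours(interf))
    std::printf("buffer %d -> colour %d\n", buf, c);   // 0->0, 1->1, 2->0
}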
== op_shapes[current]; - // either add edge between axis or just add a node in the graph - if(!is_skewed && same_shape) - graph_.add_edge({i, d}, {op, current++}); - else - graph_.add_edge({i, d}, {i, d}); - // reshaping is skewed - if(res_shapes[d] > 1 && !same_shape) - is_skewed = true; - } -} - -void axes::update_graph_trans(ir::instruction *i) { - auto *trans = static_cast(i); - ir::value *op = trans->get_operand(0); - auto perm = trans->get_perm(); - // add edge between axis perm[d] and axis d - for(unsigned d = 0; d < perm.size(); d++) - graph_.add_edge({i, perm[d]}, {op, d}); -} - -void axes::update_graph_broadcast(ir::instruction *i) { - auto *broadcast = static_cast(i); - auto shapes = broadcast->get_type()->get_block_shapes(); - ir::value *op = broadcast->get_operand(0); - ir::type *op_ty = op->get_type(); - const auto& op_shapes = op_ty->get_block_shapes(); - // add edge between non-broadcast axes - for(unsigned d = 0; d < shapes.size(); d ++) - if(op_shapes[d] == shapes[d]) - graph_.add_edge({i, d}, {op, d}); -} - -void axes::update_graph_dot(ir::instruction *i) { - auto *dot = static_cast(i); - auto shapes = dot->get_type()->get_block_shapes(); - ir::value *A = dot->get_operand(0); - ir::value *B = dot->get_operand(1); - ir::value *D = dot->get_operand(2); - // add edges between result and accumulator - for(unsigned d = 0; d < shapes.size(); d++) - graph_.add_edge({dot, d}, {D, d}); -} - -void axes::update_graph_elementwise(ir::instruction *i, - bool is_masked_load_async) { - if(i->get_num_operands() == 0) - return; - ir::value *op = i->get_operand(0); - if(!op->get_type()->is_block_ty()) - return; - auto rank = op->get_type()->get_tile_rank(); - for(unsigned d = 0; d < rank; d++) { - // If we are dealing with a masked async load we need to attach the - // dimensions so we match the behaviour of the copy_to_shared instruction - // which async masked load replaces. 
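[Note] The axes pass being deleted here operates on (value, dimension) pairs: every add_edge call above ties two dimensions together, and the connected components of the resulting graph become axis ids. A tiny union-find over such pairs (a sketch, not part of the patch; the value names are invented) shows the mechanism for the transpose rule, which adds an edge between result axis perm[d] and operand axis d.

#include <cstdio>
#include <map>
#include <string>
#include <utility>

using node = std::pair<std::string, int>;   // (value name, dimension)
static std::map<node, node> parent;

static node find(const node& x) {
  auto it = parent.find(x);
  if (it == parent.end()) { parent[x] = x; return x; }
  if (it->second == x) return x;
  node root = find(it->second);             // path compression
  parent[x] = root;
  return root;
}

static void unite(const node& a, const node& b) { parent[find(a)] = find(b); }

int main() {
  // %t = trans %a, perm = {1, 0}: tie result axis perm[d] to operand axis d.
  int perm[2] = {1, 0};
  for (int d = 0; d < 2; ++d)
    unite({"t", perm[d]}, {"a", d});
  std::printf("%s\n", find({"t", 0}) == find({"a", 1}) ? "same axis" : "different");
}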
- if (is_masked_load_async) { - graph_.add_edge({i, d}, {i, d}); - } - - for(ir::value* opx: i->ops()) - for(ir::value* opy: i->ops()) { - if(!is_masked_load_async && !i->get_type()->is_void_ty()) - graph_.add_edge({i, d}, {opx, d}); - graph_.add_edge({opx, d}, {opy, d}); - } - } -} - -void axes::update_graph_no_edge(ir::instruction *i) { - if(!i->get_type()->is_block_ty()) - return; - auto rank = i->get_type()->get_tile_rank(); - for(unsigned d = 0; d < rank; d++) - graph_.add_edge({i, d}, {i, d}); -} - -void axes::update_graph(ir::instruction *i) { - switch (i->get_id()) { - case ir::INST_REDUCE: return update_graph_reduce(i); - case ir::INST_RESHAPE: return update_graph_reshape(i); - case ir::INST_SPLAT: return update_graph_no_edge(i); - case ir::INST_CAT: return update_graph_elementwise(i, true); - case ir::INST_TRANS: return update_graph_trans(i); - case ir::INST_BROADCAST: return update_graph_broadcast(i); - case ir::INST_DOT: return update_graph_dot(i); - case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i); - case ir::INST_MASKED_LOAD_ASYNC: return update_graph_elementwise(i, true); - case ir::INST_COPY_FROM_SHARED: return update_graph_no_edge(i); - case ir::INST_CVT_LAYOUT: return update_graph_no_edge(i); - default: return update_graph_elementwise(i); - } - return; -} - - -int axes::get(ir::value *value, unsigned dim) { - return axes_.at({value, dim}); -} - -std::vector axes::get(ir::value *value) { - std::vector result; - for(size_t d = 0; d < value->get_type()->get_tile_rank(); d++) - result.push_back(this->get(value, d)); - return result; -} - -void axes::run(ir::module &mod) { - // make graph - graph_.clear(); - axes_.clear(); - ir::for_each_instruction(mod, [this](ir::instruction *x) { - update_graph(x); - }); - // find connected components - graph_.connected_components(nullptr, &axes_); - std::set uniq; - for(auto x: axes_) - uniq.insert(x.second); -} - -} -} - -} diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc deleted file mode 100644 index 5d30a2f45481..000000000000 --- a/lib/codegen/analysis/layout.cc +++ /dev/null @@ -1,653 +0,0 @@ -#include -#include -#include -#include "triton/codegen/analysis/axes.h" -#include "triton/codegen/analysis/align.h" -#include "triton/codegen/analysis/layout.h" -#include "triton/ir/function.h" -#include "triton/ir/module.h" -#include "triton/ir/utils.h" -// #include "triton/ir/type.h" - -namespace triton{ -namespace codegen{ -namespace analysis{ - -/* -------------------------------- * - * Helper Functions * - * -------------------------------- */ - -inline unsigned clamp(unsigned x, unsigned a, unsigned b) { - unsigned lo = std::min(a, b); - unsigned hi = std::max(a, b); - return std::min(std::max(x, lo), hi); -} - -inline bool is_hmma_c(ir::value *v, int sm){ - bool result = false; - if(auto *x = dynamic_cast(v)){ - ir::value *a = x->get_operand(0); - ir::type *a_ty = a->get_type(); - ir::value *b = x->get_operand(1); - ir::type *b_ty = b->get_type(); - result = (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) || - (a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) || - (a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty() && - x->allow_tf32() && sm >= 80) || - (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8) && - sm >= 80); - } - return result; -} - -static mma_layout::TensorCoreType get_mma_type(ir::value *v) { - mma_layout::TensorCoreType mma_type; - if (auto* dot = dynamic_cast(v)) { - 
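[Note] is_hmma_c above decides whether a dot product is eligible for tensor cores. Restated as a small standalone predicate (a sketch, not part of the patch; the enum is invented), the deleted conditions are: fp16 x fp16 and bf16 x bf16 on any supported GPU, and the TF32 or INT8 paths only from sm_80 (Ampere) onwards.

#include <cstdio>

enum class dtype { fp16, bf16, fp32, int8 };

// Mirrors the deleted is_hmma_c: can a dot with these operand types map to MMA?
bool uses_tensor_cores(dtype a, dtype b, int sm, bool allow_tf32) {
  if (a != b) return false;
  switch (a) {
    case dtype::fp16:
    case dtype::bf16: return true;
    case dtype::fp32: return allow_tf32 && sm >= 80;   // TF32 path
    case dtype::int8: return sm >= 80;                 // INT8 path
  }
  return false;
}

int main() {
  std::printf("%d %d\n",
              uses_tensor_cores(dtype::fp16, dtype::fp16, 70, false),   // 1
              uses_tensor_cores(dtype::fp32, dtype::fp32, 75, true));   // 0
}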
ir::value* a = dot->get_operand(0); - ir::value* b = dot->get_operand(1); - ir::type* a_ty = a->get_type(); - ir::type* b_ty = b->get_type(); - ir::type* c_ty = v->get_type(); - - if (c_ty->get_scalar_ty()->is_fp32_ty()) { - // floating point tensor cores - if (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) { - mma_type = mma_layout::FP32_FP16_FP16_FP32; - return mma_type; - } - if (a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) { - mma_type = mma_layout::FP32_BF16_BF16_FP32; - return mma_type; - } - if (a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty() - && dot->allow_tf32()) { - mma_type = mma_layout::FP32_TF32_TF32_FP32; - return mma_type; - } - } else if (c_ty->get_scalar_ty()->is_integer_ty(32)) { - // throw std::runtime_error("integer tensor cores are not yet supported"); - // // integer tensor cores - // if (a_ty->get_scalar_ty()->is_integer_ty(1) && b_ty->get_scalar_ty()->is_integer_ty(1)) { - // mma_type = mma_layout::INT32_INT1_INT1_INT32; - // return mma_type; - // } - // if (a_ty->get_scalar_ty()->is_integer_ty(4) && b_ty->get_scalar_ty()->is_integer_ty(4)) { - // mma_type = mma_layout::INT32_INT4_INT4_INT32; - // return mma_type; - // } - if (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8)) { - mma_type = mma_layout::INT32_INT8_INT8_INT32; - return mma_type; - } - } - } - return mma_layout::NOT_APPLICABLE; -} - -inline void extract_io_use(ir::value *v, std::set& result) { - for(ir::user* u: v->get_users()){ - auto i = dynamic_cast(u); - if(i && i->get_pointer_operand() == v) - result.insert(v); - } -} - -inline void extract_dot_use(ir::value *v, ir::value*& result, size_t n) { - for(ir::user* u: v->get_users()){ - auto i = dynamic_cast(u); - if(i && i->get_operand(n) == v) - result = v; - } -} - -inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n, int sm) { - for(ir::user* u: v->get_users()){ - auto i = dynamic_cast(u); - if(i && is_hmma_c(i, sm) && i->get_operand(n) == v) { - result = i; - } - } -} - - -inline bool is_trans(ir::value *v) { - if(dynamic_cast(v)) { - return true; - } - if(auto *phi = dynamic_cast(v)) { - bool result = true; - for(ir::value *op: phi->ops()) - result = result && is_trans(op); - return result; - } - return false; -} - - -/* -------------------------------- * - * Layout Visitor * - * -------------------------------- */ - -void layout_visitor::visit_layout(data_layout *layout) { - layout->accept(this); -} - - -/* -------------------------------- * - * Base Data Layout * - * -------------------------------- */ - -data_layout::data_layout(id_t id, - const std::vector &axes, - const std::vector &shape, - const std::vector &values, - analysis::align* align): id_(id), axes_(axes), shape_(shape), values_(values) { - // io pointer - std::set ptr; - for(ir::value* v: values_) - extract_io_use(v, ptr); - order_.resize(axes_.size()); - std::iota(order_.begin(), order_.end(), 0); - std::vector max_contiguous; - for(ir::value* p: ptr){ - std::vector curr = align->contiguous(p); - if(curr.size() > max_contiguous.size()) - max_contiguous = curr; - else if(curr.size() == max_contiguous.size()){ - if(*std::max_element(curr.begin(), curr.end()) > *std::max_element(max_contiguous.begin(), max_contiguous.end())) - max_contiguous = curr; - } - } - if(max_contiguous.size() > 0){ - std::sort(order_.begin(), order_.end(), [&](unsigned a, unsigned b) { - return max_contiguous[a] > max_contiguous[b]; - }); -// std::cout << 
max_contiguous[0] << " " << max_contiguous[1] << std::endl; -// std::cout << order_[0] << " " << order_[1] << std::endl; - } -} - -int data_layout::find_axis(int to_find) const { - auto it = std::find(axes_.begin(), axes_.end(), to_find); - if(it == axes_.end()) - return -1; - return std::distance(axes_.begin(), it); -} - - -distributed_layout::distributed_layout(id_t id, - const std::vector &axes, - const std::vector &shape, - const std::vector &values, - analysis::align* align): data_layout(id, axes, shape, values, align) -{ } - -/* -------------------------------- * - * MMA Layout * - * -------------------------------- */ - -mma_layout::mma_layout(size_t num_warps, - const std::vector& axes, - const std::vector& shape, - const std::vector &values, - analysis::align* align, target* tgt, - shared_layout *layout_a, shared_layout *layout_b, - ir::value *dot): distributed_layout(MMA, axes, shape, values, align) { - tensor_core_type_ = get_mma_type(dot); - /* fragments per warp */ - // try to make things as square as possible to maximize data re-use - if(tgt->as_nvidia()->sm() < 80){ - fpw_ = {2, 2, 1}; - auto ord_a = layout_a->get_order(); - auto ord_b = layout_b->get_order(); - bool is_a_row = ord_a[0] != 0; - bool is_b_row = ord_b[0] != 0; - bool is_a_vec4 = !is_a_row && (layout_a->get_shape()[ord_a[0]] <= 16); - bool is_b_vec4 = is_b_row && (layout_b->get_shape()[ord_b[0]] <= 16); - int pack_size_0 = (is_a_row || is_a_vec4) ? 1 : 2; - int pack_size_1 = (is_b_row && !is_b_vec4) ? 2 : 1; - rep_ = {2*pack_size_0, 2*pack_size_1, 1}; - spw_ = {fpw_[0]*4*rep_[0], fpw_[1]*4*rep_[1], 1}; - contig_per_thread_ = {1, 1}; - } - else{ - // fpw_ = {1, 1, 1}; - spw_ = mma_instr_shape_.at(tensor_core_type_); // e.g., {16, 8, 16} for f32.f16.f16.f32 - contig_per_thread_ = {1, 2}; - // rep_ = {2, 2, 1}; - } - order_ = {0, 1}; - - /* warps per tile */ - wpt_ = {1, 1, 1}; - // try to make warp-level tiles as square as possible to maximize data re-use - if (tgt->as_nvidia()->sm() < 80) { - std::vector wpt_nm1; - do{ - wpt_nm1 = wpt_; - if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps) - wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / spw_[0]); - if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps) - wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / spw_[1]); - }while(wpt_nm1 != wpt_); - } else { - bool changed = false; - do { - changed = false; - if (wpt_[0] * wpt_[1] * wpt_[2] >= num_warps) - break; - if (shape_[0] / spw_[0] / wpt_[0] >= shape_[1] / (spw_[1]*2) / wpt_[1]) { - if (wpt_[0] < shape_[0] / spw_[0]) { - wpt_[0] *= 2; - changed = true; - } - } else { - if (wpt_[1] < shape_[1] / (spw_[1]*2)) { - wpt_[1] *= 2; - changed = true; - } - } - } while (changed); - } - - /* shape per block */ - shape_per_cta_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1}; -} - - -/* -------------------------------- * - * Scanline Layout * - * -------------------------------- */ - -scanline_layout::scanline_layout(size_t num_warps, - const std::vector& axes, - const std::vector& shape, - const std::vector &values, - analysis::align* align, target *tgt): distributed_layout(SCANLINE, axes, shape, values, align){ - unsigned size = std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies()); - unsigned num_threads = tgt->is_gpu() ? 
num_warps * 32 : 1; - nts_.resize(shape_.size()); - mts_.resize(shape_.size()); - bool is_dot = std::any_of(values.begin(), values.end(), - [&](ir::value* v) { return dynamic_cast(v); }); - - std::vector ptrs; - for(ir::value *v: values) - for(ir::user *usr: v->get_users()) - if(auto *io = dynamic_cast(usr)){ - if(ptrs.empty() || ptrs[0]->get_type()->get_tile_rank() <= io->get_pointer_operand()->get_type()->get_tile_rank()) - ptrs.push_back(io->get_pointer_operand()); - } - - unsigned i = order_[0]; - int contiguous = 1; - for(ir::value* ptr: ptrs){ - int nbits = ptr->get_type()->get_pointer_element_ty()->get_scalar_ty()->get_primitive_size_in_bits(); - contiguous = std::max(contiguous, std::min(align->get(ptr, i), 128 / nbits)); - } - - nts_[i] = clamp(size / num_threads, 1, std::min(contiguous, shape_[i])); - mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]); - size /= shape_[i]; - num_threads /= mts_[i]; - if(is_dot) - nts_[order_[1]] = clamp(size / num_threads, 1, std::min(4, shape_[order_[1]])); - for(size_t d = 1; d < shape_.size(); d++){ - i = order_[d]; - if(d > 1 || !is_dot) - nts_[i] = 1; - mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]); - num_threads = num_threads / mts_[i]; - } - - shape_per_cta_.resize(shape_.size()); - for(size_t d = 0; d < shape_.size(); d++) - shape_per_cta_[d] = mts_[d]*nts_[d]; -} - - -/* -------------------------------- * - * Shared Layout * - * -------------------------------- */ - -bool shared_layout::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){ - if(phi->get_parent() != terminator->get_parent()) - return false; - if(auto *br = dynamic_cast(terminator)) - return br->get_true_dest() == phi->get_parent() - || br->get_false_dest() == phi->get_parent(); - else if(dynamic_cast(terminator)) - return false; - else - throw std::runtime_error("unreachable"); -} - - -void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr& res) { - auto* phi = dynamic_cast(v); - if(!phi || phi->get_num_incoming() != 2) - return; - ir::basic_block *block_0 = phi->get_incoming_block(0); - ir::basic_block *block_1 = phi->get_incoming_block(1); - ir::instruction *terminator_0 = block_0->get_inst_list().back(); - ir::instruction *terminator_1 = block_1->get_inst_list().back(); - bool is_latch_0 = is_loop_latch(phi, terminator_0); - bool is_latch_1 = is_loop_latch(phi, terminator_1); - ir::value *value_0 = phi->get_incoming_value(0); - ir::value *value_1 = phi->get_incoming_value(1); - ir::instruction *i_0 = dynamic_cast(value_0); - ir::instruction *i_1 = dynamic_cast(value_1); - if(!(i_0 && !i_1) && - !(dynamic_cast(i_0) && dynamic_cast(i_1)) && - !(dynamic_cast(i_0) && dynamic_cast(i_1))) - return; - if(is_latch_1) - res.reset(new double_buffer_info_t{value_0, value_1, phi}); - if(is_latch_0) - res.reset(new double_buffer_info_t{value_1, value_0, phi}); -} - -static bool is_smem(ir::value* v) { - if (dynamic_cast(v) || - dynamic_cast(v)) - return true; - else - return false; -} - -/// param: -/// value_1: next_value -static bool is_multistage_pipe_phi(ir::phi_node* phi, ir::basic_block* bb0, ir::basic_block* bb1, - std::vector& values_0, ir::value*& value_1) { - ir::value* next = phi; - while (auto cphi = dynamic_cast(next)) { - // smem from previous bb & phi/smem from current bb - ir::value* c0 = cphi->get_incoming_value(0); - ir::value* c1 = cphi->get_incoming_value(1); - ir::basic_block *cbb0 = cphi->get_incoming_block(0); - ir::basic_block *cbb1 = cphi->get_incoming_block(1); - - if (is_smem(c0)) { - assert(cbb0 == bb0); - 
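[Note] The scanline constructor above decides how many elements each thread keeps per axis (nts) and how many threads cover that axis (mts), capping the per-thread vector by the proved pointer alignment and by 128 bits per access. A standalone rerun of that arithmetic with invented inputs (fp16 tile, 4 warps; the is_dot special case and ranks above 2 are omitted, so this is a sketch rather than the full rule):

#include <algorithm>
#include <cstdio>

static unsigned clamp(unsigned x, unsigned lo, unsigned hi) {
  return std::min(std::max(x, lo), hi);
}

int main() {
  unsigned shape[2]    = {128, 64};         // tile shape, axis 0 is fastest-varying
  unsigned num_warps   = 4;
  unsigned num_threads = num_warps * 32;
  unsigned size        = shape[0] * shape[1];
  unsigned align_elts  = 8;                 // contiguity proved by the align pass
  unsigned nbits       = 16;                // fp16 elements
  unsigned contig      = std::min(align_elts, 128 / nbits);  // cap at 128-bit accesses

  // Fastest axis: each thread owns a small contiguous vector.
  unsigned nts0 = clamp(size / num_threads, 1, std::min(contig, shape[0]));
  unsigned mts0 = clamp(num_threads, 1, shape[0] / nts0);
  size        /= shape[0];
  num_threads /= mts0;

  // Remaining axis: one element per thread, fold the remaining threads over it.
  unsigned nts1 = 1;
  unsigned mts1 = clamp(num_threads, 1, shape[1] / nts1);

  std::printf("nts = {%u, %u}, mts = {%u, %u}\n", nts0, nts1, mts0, mts1);
  // -> nts = {8, 1}, mts = {16, 8}: 16*8 = 128 threads, 8-element vectors on axis 0
}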
values_0.push_back(c0); - if (auto phi1 = dynamic_cast(c1)) { - next = phi1; - continue; - } else { - if (is_smem(c1)) { - value_1 = c1; - assert(cbb1 == bb1); - return true; - } else { - return false; - } - } - } else - return false; - } - return false; -} - -void shared_layout::extract_N_bufferable(ir::value *v, std::shared_ptr &res, int &prev_stages) { - auto* phi = dynamic_cast(v); - // if the phi node is nested - if (!phi) - return; - - ir::basic_block *bb0 = phi->get_incoming_block(0); - ir::basic_block *bb1 = phi->get_incoming_block(1); - - std::vector values_0; - ir::value* value_1; - - if (!is_multistage_pipe_phi(phi, bb0, bb1, values_0, value_1)) - return; - - // double-buffer is a special case - if (values_0.size() == 1) - return; - - // compute original values_0 input order - std::map order; - int idx = 0; - for (ir::instruction* instr : *bb0) { - if (std::find(values_0.begin(), values_0.end(), instr) != values_0.end()) - order[static_cast(instr)] = idx++; - } - assert(order.size() == values_0.size() && "order size incorrect"); - - int curr_stages = values_0.size() + 1; - if (curr_stages > prev_stages) { - res.reset(new N_buffer_info_t{values_0, value_1, phi, order}); - prev_stages = curr_stages; - } -} - - -shared_layout::shared_layout(data_layout *arg, - const std::vector& axes, - const std::vector& shape, - const std::vector &values, - ir::type *ty, - analysis::align* align, target *tgt) - : data_layout(SHARED, axes, shape, values, align), ty_(ty), tgt_(tgt) { - - size_ = 0; - arg_layout_ = arg; - - // N-stage buffering - int prev_stages = 0; - for (ir::value *v : values) - extract_N_bufferable(v, N_buffer_, prev_stages); - - // double-buffering - if (!N_buffer_) - for(ir::value *v: values) - extract_double_bufferable(v, double_buffer_); - - // order - std::vector arg_order = arg ? arg->get_order() : std::vector{0}; - order_ = arg_order; - - ir::value* dot_a = nullptr; - ir::value* dot_b = nullptr; - ir::value* hmma_dot_a = nullptr; - ir::value* hmma_dot_b = nullptr; - for(ir::value* v: values){ - extract_dot_use(v, dot_a, 0); - extract_dot_use(v, dot_b, 1); - extract_hmma_dot_use(v, hmma_dot_a, /*op*/0, tgt_->as_nvidia()->sm()); - extract_hmma_dot_use(v, hmma_dot_b, /*op*/1, tgt_->as_nvidia()->sm()); - } - hmma_dot_a_ = hmma_dot_a; - hmma_dot_b_ = hmma_dot_b; - - // Update mma_vec - if (hmma_dot_a_) { - assert(order_.size() == 2); - std::vector mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_a_)); - mma_vec_ = order_[0] == 1 ? mat_shape[2] : mat_shape[0]; // k : m - mma_strided_ = order_[0] == 1 ? mat_shape[0] : mat_shape[2]; - - // for now, disable swizzle when using lds.8 - if (get_mma_type(hmma_dot_a_) == mma_layout::INT32_INT8_INT8_INT32) - if (order_[0] == 0) // need transpose - allow_swizzle_ = false; - } else if (hmma_dot_b_) { - assert(order_.size() == 2); - std::vector mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_b_)); - mma_vec_ = order_[0] == 1 ? mat_shape[1] : mat_shape[2]; // n : k - mma_strided_ = order_[0] == 1 ? 
mat_shape[2] : mat_shape[1]; - - // for now, disable swizzle when using lds.8 - if (get_mma_type(hmma_dot_b_) == mma_layout::INT32_INT8_INT8_INT32) - if (order_[0] == 1) // need transpose - allow_swizzle_ = false; - } - - // size - size_ = ty_->get_primitive_size_in_bits() / 8; - for(auto s: shape_) - size_ *= s; - if(double_buffer_) - size_ *= 2; - if (N_buffer_) { - size_ *= (N_buffer_->firsts.size() + 1); - } -} - -int shared_layout::get_num_stages() const { - if (double_buffer_) - return 2; - if (N_buffer_) - return N_buffer_->firsts.size() + 1; - return 1; -} - -size_t shared_layout::get_per_stage_elements() const { - return get_per_stage_size()/(ty_->get_primitive_size_in_bits()/8); -} - -/* -------------------------------- * - * ---- Layouts Inference Pass ---- * - * -------------------------------- */ - -layouts::layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt) - : axes_(axes), align_(align), num_warps_(num_warps), tgt_(tgt){ } - - -void layouts::connect(ir::value *x, ir::value *y) { - if(x == y) - return; - if(!x->get_type()->is_block_ty()) - return; - if(!y->get_type()->is_block_ty()) - return; - std::vector x_axes = axes_->get(x); - std::vector y_axes = axes_->get(y); - std::set sx_axes(x_axes.begin(), x_axes.end()); - std::set sy_axes(y_axes.begin(), y_axes.end()); - std::set common; - std::set_intersection(sx_axes.begin(), sx_axes.end(), - sy_axes.begin(), sy_axes.end(), - std::inserter(common, common.begin())); - graph_.add_edge(x, x); - graph_.add_edge(y, y); - if(!common.empty()) - graph_.add_edge(x, y); -} - -void layouts::make_graph(ir::instruction *i) { - for(ir::value* opx: i->ops()) - for(ir::value* opy: i->ops()){ - connect(i, opx); - connect(opx, opy); - } -} - -void layouts::create(size_t id, const std::vector& values) { -// if(layouts_.find(id) != layouts_.end()) -// return; - auto it_hmma_c = std::find_if(values.begin(), values.end(), - [&](ir::value* v){ return is_hmma_c(v, tgt_->as_nvidia()->sm()); }); - auto cmp = [](ir::value* x, ir::value *y) { - std::pair xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()}; - std::pair yy = {y->get_type()->get_tile_rank(), y->get_type()->get_tile_num_elements()}; - return xx < yy; - }; - std::vector lvalue = values; - std::remove_if(lvalue.begin(), lvalue.end(), [&](ir::value* v) { return dynamic_cast(v); }); - ir::value *largest = *std::max_element(lvalue.begin(), lvalue.end(), cmp); - const auto& axes = axes_->get(largest); - const auto& shapes = largest->get_type()->get_block_shapes(); - auto it_cts = std::find_if(values.begin(), values.end(), [](ir::value* v) { - return dynamic_cast(v) || - dynamic_cast(v); - }); - // type - if(it_hmma_c != values.end()){ - ir::instruction *dot = (ir::instruction*)*it_hmma_c; - ir::value *a = dot->get_operand(0); - ir::value *b = dot->get_operand(1); - create(groups_.at(a), values_.at(groups_.at(a))); - create(groups_.at(b), values_.at(groups_.at(b))); - layouts_[id] = new mma_layout(num_warps_, axes, shapes, values, align_, tgt_, - (shared_layout*)layouts_.at(groups_.at(a)), - (shared_layout*)layouts_.at(groups_.at(b)), - dot); - } - else if(it_cts != values.end()){ - ir::instruction *cts = (ir::instruction*)*it_cts; - ir::value *arg = cts->get_operand(0); - create(groups_.at(arg), values_.at(groups_.at(arg))); - layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_, tgt_); - } - else{ - layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_, tgt_); 
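[Note] The tail of the shared_layout constructor above sizes the buffer: element size times the tile shape, multiplied again by the number of pipeline stages (2 for double buffering, N+1 for N-stage buffering, matching get_num_stages). A small worked example, not part of the patch, with invented numbers:

#include <cstdio>

// Shared-memory footprint of one buffered tile (illustrative).
static size_t shared_bytes(size_t elem_bits, const int shape[2], int num_stages) {
  size_t bytes = elem_bits / 8;
  for (int d = 0; d < 2; ++d)
    bytes *= static_cast<size_t>(shape[d]);
  return bytes * static_cast<size_t>(num_stages);  // 1: none, 2: double, N+1: N-stage
}

int main() {
  int shape[2] = {64, 32};
  // fp16 operand tile in a 3-stage software pipeline (two prefetched stages + latch):
  std::printf("%zu bytes\n", shared_bytes(16, shape, 3));  // 12288
}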
- } -} - -void layouts::run(ir::module &mod) { - // make graph - graph_.clear(); - layouts_.clear(); - groups_.clear(); - - ir::for_each_instruction(mod, [this](ir::instruction* i) { - make_graph(i); - }); - - - // connected components - graph_.connected_components(&values_, &groups_); - - // create layouts - for(const auto& x: values_) - create(x.first, x.second); - - // create temporaries - size_t id = values_.size(); - ir::for_each_instruction(mod, [this, &id](ir::instruction* i) { - if(auto *red = dynamic_cast(i)) { - id++; - ir::value *arg = red->get_operand(0); - unsigned axis = red->get_axis(); - // shape - auto shapes = arg->get_type()->get_block_shapes(); - scanline_layout *layout = get(arg)->to_scanline(); - shapes[axis] = layout->mts(axis); - // create layout - layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_, tgt_); - tmp_[red] = id; - } - if(auto *val = dynamic_cast(i)){ - distributed_layout* out_layout = dynamic_cast(get(val)); - distributed_layout* in_layout = dynamic_cast(get(i->get_operand(0))); - id++; - size_t dim = val->get_type()->get_tile_rank(); - ir::type::block_shapes_t shape(dim); - for(size_t k = 0; k < dim; k++){ - shape[k] = std::max(in_layout->shape_per_cta(k), - out_layout->shape_per_cta(k)); - } - auto in_ord = in_layout->get_order(); - auto out_ord = out_layout->get_order(); - int in_vec = in_layout->contig_per_thread(in_ord[0]); - int out_vec = out_layout->contig_per_thread(out_ord[0]); - int pad = std::max(in_vec, out_vec); - shape[out_ord[0]] += pad; - layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {val}, val->get_type()->get_scalar_ty(), align_, tgt_); - tmp_[val] = id; - } - if(auto *atom = dynamic_cast(i)){ - id++; - layouts_[id] = new shared_layout(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_, tgt_); - tmp_[atom] = id; - } - }); - -} - -} -} -} diff --git a/lib/codegen/analysis/liveness.cc b/lib/codegen/analysis/liveness.cc deleted file mode 100644 index 7beae21a1247..000000000000 --- a/lib/codegen/analysis/liveness.cc +++ /dev/null @@ -1,59 +0,0 @@ -#include -#include -#include "triton/codegen/analysis/liveness.h" -#include "triton/codegen/analysis/layout.h" -#include "triton/ir/function.h" -#include "triton/ir/module.h" -#include "triton/ir/utils.h" - -namespace triton{ -namespace codegen{ -namespace analysis{ - - -void liveness::run(ir::module &mod) { - intervals_.clear(); - - // Assigns index to each instruction - std::map indices; - for(ir::function *fn: mod.get_function_list()){ - slot_index index = 0; - for(ir::basic_block *block: fn->blocks()) - for(ir::instruction *instr: block->get_inst_list()){ - index += 1; - indices.insert({instr, index}); - } - } - - // create live intervals - for(auto &x: layouts_->get_all()) { - shared_layout* layout = x.second->to_shared(); - if(!layout) - continue; - // users - std::set users; - for(ir::value *v: layout->get_values()){ - for(ir::user *u: v->get_users()) - users.insert(u); - } - // compute intervals - unsigned start = INT32_MAX; - for(ir::value *v: layout->get_values()) - if(indices.find(v) != indices.end()) - start = std::min(start, indices.at(v)); - unsigned end = 0; - for(ir::user *u: users) - if(indices.find(u) != indices.end()) - end = std::max(end, indices.at(u)); - if(end == 0) - end = start + 1; - intervals_[layout] = segment{start, end}; - } - - - -} - -} -} -} diff --git a/lib/codegen/analysis/swizzle.cc b/lib/codegen/analysis/swizzle.cc deleted file mode 100644 index 
5737f80a0768..000000000000 --- a/lib/codegen/analysis/swizzle.cc +++ /dev/null @@ -1,61 +0,0 @@ -#include "triton/codegen/analysis/swizzle.h" -#include "triton/codegen/analysis/layout.h" -#include "triton/codegen/target.h" -#include "triton/ir/type.h" -#include - -namespace triton{ -namespace codegen{ -namespace analysis{ - - -void swizzle::run(ir::module &) { - per_phase_.clear(); - max_phase_.clear(); - - for(auto &x: layouts_->get_all()){ - shared_layout* layout = dynamic_cast(x.second); - if(!layout) - continue; - ir::value* mma_dot_a = layout->hmma_dot_a(); - ir::value* mma_dot_b = layout->hmma_dot_b(); - - if(!mma_dot_a && !mma_dot_b){ - per_phase_[layout] = 1; - max_phase_[layout] = 1; - vec_[layout] = 1; - continue; - } - auto ord = layout->get_order(); - scanline_layout* in_layout = dynamic_cast(layout->get_arg_layout()); - if(!in_layout) - continue; - int dtsize = layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; - if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80){ - int inner = mma_dot_a ? 0 : 1; - per_phase_[layout] = std::max(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1); - max_phase_[layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[layout]; - if(mma_dot_a) - vec_[layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0); - else - vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1); - } - else { - if (!layout->allow_swizzle()) { - per_phase_[layout] = 1; - max_phase_[layout] = 1; - vec_[layout] = 1; - } else { - per_phase_[layout] = std::max(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1); - max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout]; - vec_[layout] = layout->get_mma_vec(); - } - } - } -} - -} -} -} - - diff --git a/lib/codegen/pass.cc b/lib/codegen/pass.cc deleted file mode 100644 index 5c93e10e6d88..000000000000 --- a/lib/codegen/pass.cc +++ /dev/null @@ -1,86 +0,0 @@ -#include "triton/codegen/pass.h" -#include "triton/codegen/analysis/align.h" -#include "triton/codegen/analysis/allocation.h" -#include "triton/codegen/analysis/axes.h" -#include "triton/codegen/analysis/liveness.h" -#include "triton/codegen/analysis/swizzle.h" -#include "triton/codegen/selection/generator.h" -#include "triton/codegen/transform/coalesce.h" -#include "triton/codegen/transform/cts.h" -#include "triton/codegen/transform/dce.h" -#include "triton/codegen/transform/disassociate.h" -#include "triton/codegen/transform/membar.h" -#include "triton/codegen/transform/peephole.h" -#include "triton/codegen/transform/pipeline.h" -#include "triton/codegen/transform/prefetch.h" -#include "triton/ir/function.h" -#include "triton/ir/module.h" -#include "triton/ir/print.h" -#include "llvm/IR/Module.h" -#include "llvm/IR/LegacyPassManager.h" -#include "llvm/IR/Verifier.h" -namespace triton { -namespace codegen { - -// TODO: -// There should be a proper pass manager there! 
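[Note] The swizzle pass deleted just above picks, for each shared layout feeding a tensor-core operand, how many rows share a swizzle phase and how many phases to cycle through so that the matrix loads avoid shared-memory bank conflicts. The sm >= 80 branch boils down to the arithmetic below; the layout parameters are invented for the example and the sketch is not part of the patch.

#include <algorithm>
#include <cstdio>

int main() {
  // Source scanline layout along the contiguous axis of the shared tile.
  int mts = 4, nts = 8;               // threads and per-thread vector on that axis
  int dtsize = 2;                     // bytes per element (fp16)
  int mma_strided = 8, mma_vec = 8;   // from the MMA matrix shape for this operand

  int per_phase = std::max(128 / (mts * nts * dtsize), 1);  // rows sharing one phase
  int max_phase = mma_strided / per_phase;                  // number of distinct phases
  int vec       = mma_vec;                                  // swizzle vector width

  std::printf("per_phase=%d max_phase=%d vec=%d\n", per_phase, max_phase, vec);
  // -> per_phase=2 max_phase=4 vec=8
}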
-std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx, codegen::target* target, - int cc, int num_warps, int num_stages, int& shared_static) { - // generate llvm code - std::string name = ir.get_function_list()[0]->get_name(); - std::unique_ptr llvm(new llvm::Module(name, ctx)); - // optimizations - bool cts_use_async = target->as_nvidia() && target->as_nvidia()->sm() >= 80; - // create passes - codegen::analysis::align align; - codegen::analysis::axes axes; - codegen::transform::cts cts(cts_use_async); - codegen::transform::pipeline pipeline(cts_use_async, num_stages); - codegen::transform::disassociate disassociate; - codegen::analysis::layouts layouts(&axes, &align, num_warps, target); - codegen::analysis::liveness liveness(&layouts); - codegen::analysis::swizzle swizzle(&layouts, target); - codegen::analysis::allocation allocation(&liveness); - codegen::transform::dce dce; - codegen::transform::peephole peephole(target, &layouts); - codegen::transform::coalesce coalesce(&align, &layouts); - codegen::transform::prefetch prefetch_s(target); - codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target); - codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target, num_warps); - // run passes - dce.run(ir); - peephole.run(ir); - dce.run(ir); - pipeline.run(ir); - dce.run(ir); - disassociate.run(ir); - dce.run(ir); - align.run(ir); axes.run(ir); layouts.run(ir); - peephole.run(ir); - dce.run(ir); - if (target->is_gpu()) - cts.run(ir); - align.run(ir); axes.run(ir); layouts.run(ir); - coalesce.run(ir); - dce.run(ir); - align.run(ir); - dce.run(ir); - if (target->is_gpu()) - cts.run(ir); - dce.run(ir); - align.run(ir); axes.run(ir); layouts.run(ir); - peephole.run(ir); - dce.run(ir); - align.run(ir); axes.run(ir); layouts.run(ir); - swizzle.run(ir); - liveness.run(ir); - allocation.run(ir); - prefetch_s.run(ir); - barriers.run(ir); - isel.visit(ir, *llvm); - shared_static = allocation.allocated_size(); - return llvm; -} - -} // namespace codegen -} // namespace triton diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc deleted file mode 100644 index f8cf08cbae35..000000000000 --- a/lib/codegen/selection/generator.cc +++ /dev/null @@ -1,3182 +0,0 @@ -#include -#include -#include -#include -#include "triton/codegen/selection/generator.h" -#include "triton/codegen/target.h" -#include "triton/codegen/analysis/axes.h" -#include "triton/codegen/analysis/allocation.h" -#include "triton/codegen/analysis/align.h" -#include "triton/codegen/analysis/swizzle.h" -#include "triton/codegen/transform/coalesce.h" -#include "triton/ir/context.h" -#include "triton/ir/module.h" -#include "triton/ir/function.h" -#include "triton/ir/type.h" -#include "llvm/IR/Module.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/IntrinsicsNVPTX.h" -#include "llvm/IR/BasicBlock.h" -#include "llvm/IR/Attributes.h" -#include "llvm/IR/InlineAsm.h" -#include "llvm/Transforms/Utils/BasicBlockUtils.h" - -namespace triton{ -namespace codegen{ - -using namespace llvm; - -Value* adder::operator()(Value *x, Value *y, const std::string& name) { - // (x + cst) + y -> (x + y) + cst - if(auto* bin = dyn_cast(x)) - if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add) - if(dyn_cast(bin->getOperand(1))){ - return (*builder_)->CreateAdd((*builder_)->CreateAdd(bin->getOperand(0), y), - bin->getOperand(1)); - } - // (x + (y + cst)) -> (x + y) + cst - if(auto* bin = dyn_cast(y)) - if(bin->getOpcode() == 
llvm::BinaryOperator::BinaryOps::Add) - if(dyn_cast(bin->getOperand(1))){ - return (*builder_)->CreateAdd((*builder_)->CreateAdd(x, bin->getOperand(0)), - bin->getOperand(1)); - } - - // default - return (*builder_)->CreateAdd(x, y, name); -} - -Value* multiplier::operator()(Value *x, Value *y, const std::string &name) { - // (x + cst1) * cst2 -> (x * cst2) + (cst1 * cst2) - if(auto* bin = dyn_cast(x)) - if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add) - if(dyn_cast(bin->getOperand(1))) - if(dyn_cast(y)){ - return (*builder_)->CreateAdd((*builder_)->CreateMul(bin->getOperand(0), y), - (*builder_)->CreateMul(bin->getOperand(1), y)); - } - // default - return (*builder_)->CreateMul(x, y, name); -} - -Value* geper::operator()(Value *ptr, Value* off, const std::string& name){ - // (ptr + cst1) + (cst2) -> ptr + (cst1 + cst2) - if(auto* gep = dyn_cast(ptr)) - if(ConstantInt* cst1 = dyn_cast(gep->idx_begin())) - if(ConstantInt* cst2 = dyn_cast(off)){ - return (*builder_)->CreateGEP(gep->getPointerOperand(), - (*builder_)->CreateAdd(cst1, cst2)); - } - // ptr + (off + cst) -> (ptr + off) + cst - if(auto* bin = dyn_cast(off)) - if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add) - if(ConstantInt* cst = dyn_cast(bin->getOperand(1))){ - return (*builder_)->CreateGEP((*builder_)->CreateGEP(ptr, bin->getOperand(0)), - bin->getOperand(1)); - } - // default - return (*builder_)->CreateGEP(ptr, off, name); -} - -//Value* geper::operator()(Type *ty, Value *ptr, std::vector vals, const std::string &name) { -// return (*builder_)->CreateGEP(ty, ptr, vals, name); -//} - -// types -#define void_ty builder_->getVoidTy() -#define f16_ty builder_->getHalfTy() -#define bf16_ty builder_->getBFloatTy() -#define f32_ty builder_->getFloatTy() -#define i8_ty builder_->getInt8Ty() -#define i16_ty builder_->getInt16Ty() -#define i32_ty builder_->getInt32Ty() -#define vec_ty(type, num_el) VectorType::get(type, num_el, false) -#define ptr_ty(...) PointerType::get(__VA_ARGS__) -// constants -#define i32(...) builder_->getInt32(__VA_ARGS__) -// ops -#define and_(...) builder_->CreateAnd(__VA_ARGS__) -#define atomic_cmp_xchg(...) builder_->CreateAtomicCmpXchg(__VA_ARGS__) -#define atomic_rmw(...) builder_->CreateAtomicRMW(__VA_ARGS__) -#define bin_op(...) builder_->CreateBinOp(__VA_ARGS__) -#define bit_cast(...) builder_->CreateBitCast(__VA_ARGS__) -#define br(...) builder_->CreateBr(__VA_ARGS__) -#define call(...) builder_->CreateCall(__VA_ARGS__) -#define cast(...) builder_->CreateCast(__VA_ARGS__) -#define cond_br(...) builder_->CreateCondBr(__VA_ARGS__) -#define exact_udiv(...) builder_->CreateExactUDiv(__VA_ARGS__) -#define extract_elt(...) builder_->CreateExtractElement(__VA_ARGS__) -#define extract_val(...) builder_->CreateExtractValue(__VA_ARGS__) -#define fadd(...) builder_->CreateFAdd(__VA_ARGS__) -#define fcmp(...) builder_->CreateFCmp(__VA_ARGS__) -#define fmul(...) builder_->CreateFMul(__VA_ARGS__) -#define fpcast(...) builder_->CreateFPCast(__VA_ARGS__) -#define fsub(...) builder_->CreateFSub(__VA_ARGS__) -#define icmp(...) builder_->CreateICmp(__VA_ARGS__) -#define icmp_eq(...) builder_->CreateICmpEQ(__VA_ARGS__) -#define icmp_sge(...) builder_->CreateICmpSGE(__VA_ARGS__) -#define icmp_sle(...) builder_->CreateICmpSLE(__VA_ARGS__) -#define icmp_ult(...) builder_->CreateICmpULT(__VA_ARGS__) -#define insert_elt(...) builder_->CreateInsertElement(__VA_ARGS__) -#define intrinsic(...) 
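[Note] The adder, multiplier and geper helpers above are small re-association peepholes, e.g. "(x + cst) + y" becomes "(x + y) + cst", so that constant offsets bubble to the top of address arithmetic and can later be folded into the immediate offset of a load. Below is a toy model of the add rule on a miniature expression tree; it is plain C++ and deliberately not the LLVM API, and is not part of the patch.

#include <cstdio>
#include <memory>
#include <optional>

// Miniature expression node: either a constant leaf, a symbolic leaf, or an add.
struct Expr {
  std::optional<long> cst;           // set when this node is a constant
  std::shared_ptr<Expr> lhs, rhs;    // set when this node is an addition
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr cst(long c) { return std::make_shared<Expr>(Expr{c, nullptr, nullptr}); }
ExprPtr sym()       { return std::make_shared<Expr>(Expr{std::nullopt, nullptr, nullptr}); }

// Peephole: (x + cst) + y  ->  (x + y) + cst
ExprPtr add(ExprPtr a, ExprPtr b) {
  if (a->lhs && a->rhs && a->rhs->cst)
    return std::make_shared<Expr>(Expr{std::nullopt, add(a->lhs, b), a->rhs});
  return std::make_shared<Expr>(Expr{std::nullopt, a, b});
}

int main() {
  ExprPtr e = add(add(sym(), cst(4)), sym());   // (x + 4) + y
  // After the rewrite the constant sits at the top level: (x + y) + 4.
  std::printf("top-level constant: %ld\n", *e->rhs->cst);
}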
builder_->CreateIntrinsic(__VA_ARGS__) -#define load(ptr) builder_->CreateLoad(ptr->getType()->getPointerElementType(), ptr) -#define lshr(...) builder_->CreateLShr(__VA_ARGS__) -#define max_num(...) builder_->CreateMaxNum(__VA_ARGS__) -#define min_num(...) builder_->CreateMinNum(__VA_ARGS__) -#define neg(...) builder_->CreateNeg(__VA_ARGS__) -#define phi(...) builder_->CreatePHI(__VA_ARGS__) -#define ret(...) builder_->CreateRet(__VA_ARGS__) -#define select(...) builder_->CreateSelect(__VA_ARGS__) -#define store(...) builder_->CreateStore(__VA_ARGS__) -#define sub(...) builder_->CreateSub(__VA_ARGS__) -#define shl(...) builder_->CreateShl(__VA_ARGS__) -#define udiv(...) builder_->CreateUDiv(__VA_ARGS__) -#define urem(...) builder_->CreateURem(__VA_ARGS__) -#define splat(...) builder_->CreateVectorSplat(__VA_ARGS__) -#define xor_(...) builder_->CreateXor(__VA_ARGS__) - -/** - * \brief Convert Triton-IR Type to LLVM-IR Type - */ -Type *generator::cvt(ir::type *ty) { - // function - if(auto* tt = dynamic_cast(ty)){ - Type *ret_ty = cvt(tt->get_return_ty()); - std::vector arg_tys(tt->get_num_params()); - for(size_t i = 0; i < arg_tys.size(); i++) - arg_tys[i] = cvt(tt->get_param_ty(i)); - return FunctionType::get(ret_ty, arg_tys, false); - } - // pointer - if(ty->is_pointer_ty()){ - Type *elt_ty = cvt(ty->get_pointer_element_ty()); - unsigned addr_space = ty->get_pointer_address_space(); - return ptr_ty(elt_ty, addr_space); - } - // integer - if(ty->is_integer_ty()){ - unsigned bitwidth = ty->get_integer_bitwidth(); - return IntegerType::get(*ctx_, bitwidth); - } - // primitive types - switch(ty->get_type_id()){ - case ir::type::VoidTyID: return Type::getVoidTy(*ctx_); - case ir::type::FP8TyID: return Type::getInt8Ty(*ctx_); - case ir::type::FP16TyID: return Type::getHalfTy(*ctx_); - case ir::type::BF16TyID: return Type::getBFloatTy(*ctx_); - case ir::type::FP32TyID: return Type::getFloatTy(*ctx_); - case ir::type::FP64TyID: return Type::getDoubleTy(*ctx_); - case ir::type::LabelTyID: return Type::getLabelTy(*ctx_); - case ir::type::MetadataTyID: return Type::getMetadataTy(*ctx_); - case ir::type::TokenTyID: return Type::getTokenTy(*ctx_); - default: break; - } - // unknown type - throw std::runtime_error("unknown conversion from ir::type to Type"); -} - -/** - * \brief Convert Triton-IR Attribute to LLVM-IR Attribute - */ -llvm::Attribute generator::cvt(ir::attribute attr) { - switch(attr.get_kind()){ - case ir::noalias: return llvm::Attribute::get(*ctx_, llvm::Attribute::NoAlias); - case ir::readonly: return llvm::Attribute::get(*ctx_, llvm::Attribute::ReadOnly); - case ir::writeonly: return llvm::Attribute::get(*ctx_, llvm::Attribute::WriteOnly); - case ir::aligned: return llvm::Attribute::get(*ctx_, llvm::Attribute::Alignment, attr.get_value()); - case ir::retune: return llvm::Attribute::get(*ctx_, llvm::Attribute::None); - default: throw std::runtime_error("cannot convert ir::attribute_t to llvm::Attribute"); - } -} - -/** - * \brief Constructor of LLVM code generator - */ -generator::generator(analysis::axes *a_axes, - analysis::layouts *layouts, - analysis::align *alignment, - analysis::allocation *alloc, - analysis::swizzle *swizzle, - target *tgt, - unsigned num_warps) - : a_axes_(a_axes), layouts_(layouts), alignment_(alignment), alloc_(alloc), swizzle_(swizzle), - tgt_(tgt), num_warps_(num_warps), add(&builder_), mul(&builder_), gep(&builder_) { - -} - -/** - * \brief Code Generation for `value` - */ -void generator::visit_value(ir::value* v) { - if(!seen_.insert(v).second) - 
return; - if(v->get_type()->is_block_ty()){ - if(analysis::shared_layout* layout = layouts_->get(v)->to_shared()){ - analysis::N_buffer_info_t *n_buffer = layout->get_N_buffer(); - analysis::double_buffer_info_t *double_buffer = layout->get_double_buffer(); - - // offset - Value *offset = nullptr; - // base pointer - Value *ptr = shared_ptr_[layout]; - - if (n_buffer) { - // ptr = base (shared_ptr_[layout]) + smem_idx * size - // read_smem_idx - if (v == n_buffer->phi) { - ptr = shared_ptr_[layout]; - } - // write_smem_idx - if (std::find(n_buffer->firsts.begin(), n_buffer->firsts.end(), v) != n_buffer->firsts.end()) { - int write_smem_idx = /*stage_idx*/n_buffer->firsts_idx.at(v); - int elements = write_smem_idx * layout->get_per_stage_elements(); - ptr = gep(shared_pre_ptr_[layout], i32(elements)); - } else if (v == n_buffer->latch) { - Value* write_smem_idx = write_smem_idx_[layout]; - Value* elements = mul(write_smem_idx, i32(layout->get_per_stage_elements())); - ptr = gep(shared_pre_ptr_[layout], elements); - } - } else if (double_buffer) { - if(v == double_buffer->phi) - offset = shared_off_[layout]; - if(v == double_buffer->latch) - ptr = shared_next_ptr_[layout]; - else if(v == double_buffer->first) - ptr = shared_pre_ptr_[layout]; - } // else do nothing - // what visit_dot & vist_cts & ... see - shmems_[v] = ptr; - // now only latches have offset (PHINode), only used by finalize_share_layout() - shoffs_[v] = offset; - } - } - // visit operands - BasicBlock *current = builder_->GetInsertBlock(); - auto *inst = dynamic_cast(v); - if(inst) - for(ir::value *op: inst->ops()){ - if(dynamic_cast(op) || !dynamic_cast(v)) - visit_value(op); - } - init_idx(v); - // change insert point for phi node - builder_->SetInsertPoint(current); - auto *phi = dynamic_cast(v); - if(phi && !current->empty() && current->getFirstNonPHI()) - builder_->SetInsertPoint(&*current->getFirstNonPHI()); - // visit user - if(auto *usr = dynamic_cast(v)){ - usr->accept(this); - } - // revert insert point - if(phi && !current->empty() && current->getFirstNonPHI()) - builder_->SetInsertPoint(current); -} - -/** - * \brief Code Generation for `phi` - */ -void generator::visit_phi_node(ir::phi_node* x) { - Type *ty = cvt(x->get_type()->get_scalar_ty()); - for(indices_t idx: idxs_.at(x)) - vals_[x][idx] = phi(ty, x->get_num_operands()); -} - -/** - * \brief Code Generation for `binary_operator` - */ -void generator::visit_binary_operator(ir::binary_operator*x) { - using ll = llvm::Instruction::BinaryOps; - auto cvt = [](ir::binary_op_t op){ - using tt = ir::binary_op_t; - switch(op) { - case tt::Add: return ll::Add; - case tt::FAdd: return ll::FAdd; - case tt::Sub: return ll::Sub; - case tt::FSub: return ll::FSub; - case tt::Mul: return ll::Mul; - case tt::FMul: return ll::FMul; - case tt::UDiv: return ll::UDiv; - case tt::SDiv: return ll::SDiv; - case tt::FDiv: return ll::FDiv; - case tt::URem: return ll::URem; - case tt::SRem: return ll::SRem; - case tt::FRem: return ll::FRem; - case tt::Shl: return ll::Shl; - case tt::LShr: return ll::LShr; - case tt::AShr: return ll::AShr; - case tt::And: return ll::And; - case tt::Or: return ll::Or; - case tt::Xor: return ll::Xor; - default: throw std::runtime_error("unreachable switch"); - } - }; - for(indices_t idx: idxs_.at(x)){ - Value *lhs = vals_[x->get_operand(0)][idx]; - Value *rhs = vals_[x->get_operand(1)][idx]; - auto op = cvt(x->get_op()); - if(op == ll::Add) - vals_[x][idx] = add(lhs, rhs); - else if(op == ll::Mul) - vals_[x][idx] = mul(lhs, rhs); - else if(op == 
ll::FDiv && !x->get_fdiv_ieee_rounding() && - x->get_type()->get_scalar_ty()->is_fp32_ty()){ - InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {f32_ty, f32_ty}, false), - " div.full.f32 $0, $1, $2;", "=r,r,r", false); - vals_[x][idx] = builder_->CreateCall(ptx, {lhs, rhs}); - - } - else - vals_[x][idx] = bin_op(op, lhs, rhs); - } -} - -/** - * \brief Code Generation for `getelementptr` - */ -void generator::visit_getelementptr_inst(ir::getelementptr_inst* x) { - for(indices_t idx: idxs_.at(x)){ - Value *ptr = vals_[x->get_pointer_operand()][idx]; - std::vector vals; - for(auto it= x->idx_begin(); it != x->idx_end(); it++) - vals.push_back(vals_[*it][idx]); - assert(vals.size() == 1); - vals_[x][idx] = gep(ptr, vals[0]); - } -} - -/** - * \brief Code Generation for `icmp` - */ -void generator::visit_icmp_inst(ir::icmp_inst* x) { - auto cvt = [](ir::cmp_pred_t pred) { - using ll = llvm::CmpInst::Predicate; - using tt = ir::cmp_pred_t; - switch(pred){ - case tt::FIRST_ICMP_PREDICATE: return ll::FIRST_ICMP_PREDICATE; - case tt::ICMP_EQ: return ll::ICMP_EQ; - case tt::ICMP_NE: return ll::ICMP_NE; - case tt::ICMP_UGT: return ll::ICMP_UGT; - case tt::ICMP_UGE: return ll::ICMP_UGE; - case tt::ICMP_ULT: return ll::ICMP_ULT; - case tt::ICMP_ULE: return ll::ICMP_ULE; - case tt::ICMP_SGT: return ll::ICMP_SGT; - case tt::ICMP_SGE: return ll::ICMP_SGE; - case tt::ICMP_SLT: return ll::ICMP_SLT; - case tt::ICMP_SLE: return ll::ICMP_SLE; - case tt::LAST_ICMP_PREDICATE: return ll::LAST_ICMP_PREDICATE; - default: throw std::runtime_error("unreachable switch"); - } - }; - - for(indices_t idx: idxs_.at(x)){ - Value *lhs = vals_[x->get_operand(0)][idx]; - Value *rhs = vals_[x->get_operand(1)][idx]; - vals_[x][idx] = icmp(cvt(x->get_pred()), lhs, rhs); - } -} - -/** - * \brief Code Generation for `fcmp` - */ -void generator::visit_fcmp_inst(ir::fcmp_inst* x) { - auto cvt = [](ir::cmp_pred_t pred) { - using ll = llvm::CmpInst::Predicate; - using tt = ir::cmp_pred_t; - switch(pred){ - case tt::FIRST_FCMP_PREDICATE: return ll::FIRST_FCMP_PREDICATE; - case tt::FCMP_FALSE: return ll::FCMP_FALSE; - case tt::FCMP_OEQ: return ll::FCMP_OEQ; - case tt::FCMP_OGT: return ll::FCMP_OGT; - case tt::FCMP_OGE: return ll::FCMP_OGE; - case tt::FCMP_OLT: return ll::FCMP_OLT; - case tt::FCMP_OLE: return ll::FCMP_OLE; - case tt::FCMP_ONE: return ll::FCMP_ONE; - case tt::FCMP_ORD: return ll::FCMP_ORD; - case tt::FCMP_UNO: return ll::FCMP_UNO; - case tt::FCMP_UEQ: return ll::FCMP_UEQ; - case tt::FCMP_UGT: return ll::FCMP_UGT; - case tt::FCMP_UGE: return ll::FCMP_UGE; - case tt::FCMP_ULT: return ll::FCMP_ULT; - case tt::FCMP_ULE: return ll::FCMP_ULE; - case tt::FCMP_UNE: return ll::FCMP_UNE; - case tt::FCMP_TRUE: return ll::FCMP_TRUE; - case tt::LAST_FCMP_PREDICATE: return ll::LAST_FCMP_PREDICATE; - default: throw std::runtime_error("unreachable switch"); - } - }; - for(indices_t idx: idxs_.at(x)){ - Value *lhs = vals_[x->get_operand(0)][idx]; - Value *rhs = vals_[x->get_operand(1)][idx]; - vals_[x][idx] = fcmp(cvt(x->get_pred()), lhs, rhs); - } -} - - -std::tuple generator::fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3){ - in0 = cast(llvm::Instruction::FPTrunc, in0, f16_ty); - in1 = cast(llvm::Instruction::FPTrunc, in1, f16_ty); - in2 = cast(llvm::Instruction::FPTrunc, in2, f16_ty); - in3 = cast(llvm::Instruction::FPTrunc, in3, f16_ty); - Value *ret0, *ret1, *ret2, *ret3; - std::tie(ret0, ret1, ret2, ret3) = fp16x4_to_fp8x4(in0, in1, in2, in3); - return std::make_tuple(ret0, ret1, ret2, ret3); -} - 
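[Note] The FP8 paths route everything through fp16, as fp32x4_to_fp8x4 above already shows, and the two packed PTX routines defined just below do the actual bit surgery four values at a time. A plain scalar reference of that bit surgery follows; it is not part of the patch, it assumes the 1-bit sign / 4-bit exponent / 3-bit mantissa fp8 layout described in the comments below, and it follows the shl / mask / add / restore-sign sequence of the PTX.

#include <cstdint>
#include <cstdio>

static uint16_t fp8_to_fp16_bits(uint8_t x) {
  // Sign goes to bit 15; the 4 exponent + 3 mantissa bits land just below the
  // fp16 exponent field, i.e. shifted left by 7.
  return (uint16_t)(((x & 0x80u) << 8) | ((x & 0x7fu) << 7));
}

static uint8_t fp16_to_fp8_bits(uint16_t f) {
  // High byte of ((f & 0x8000) | (((f << 1) & 0x7fff) + 0x0080)):
  // shift the fp16 fields up, round to nearest at the first dropped mantissa
  // bit, then restore the sign. Out-of-range exponents give garbage, as the
  // original comment acknowledges.
  uint16_t t = (uint16_t)((f << 1) & 0x7fffu);
  t = (uint16_t)(t + 0x0080u);
  uint16_t r = (uint16_t)((f & 0x8000u) | t);
  return (uint8_t)(r >> 8);
}

int main() {
  uint8_t fp8 = 0xC2;                         // an arbitrary fp8 bit pattern
  uint16_t fp16 = fp8_to_fp16_bits(fp8);
  std::printf("0x%02X -> 0x%04X -> 0x%02X\n",
              (unsigned)fp8, (unsigned)fp16, (unsigned)fp16_to_fp8_bits(fp16));
  // round-trips: 0xC2 -> 0xA100 -> 0xC2
}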
-std::tuple generator::fp8x4_to_fp32x4(Value *in0, Value *in1, Value *in2, Value *in3){ - Value *ret0, *ret1, *ret2, *ret3; - std::tie(ret0, ret1, ret2, ret3) = fp8x4_to_fp16x4(in0, in1, in2, in3); - ret0 = cast(llvm::Instruction::FPExt, ret0, f32_ty); - ret1 = cast(llvm::Instruction::FPExt, ret1, f32_ty); - ret2 = cast(llvm::Instruction::FPExt, ret2, f32_ty); - ret3 = cast(llvm::Instruction::FPExt, ret3, f32_ty); - return std::make_tuple(ret0, ret1, ret2, ret3); -} - - -std::tuple generator::fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3){ - Type *ret_ty = StructType::get(*ctx_, {vec_ty(f16_ty, 2), vec_ty(f16_ty, 2)}); - InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty}, false), - "{" - ".reg .b32 a<2>, b<2>; \n\t" - "prmt.b32 a0, 0, $2, 0x5040; \n\t" // If input is 0xdcba set a0 to 0xb0a0 - "prmt.b32 a1, 0, $2, 0x7060; \n\t" // If input is 0xdcba set a1 to 0xd0c0 - "lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n\t" // b0 = a0 & 0x7fff7fff (strip sign) - "lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n\t" // b1 = a1 & 0x7fff7fff (strip sign) - "shr.b32 b0, b0, 1; \n\t" // b0 <<= 1 (shift into fp16 poistion) - "shr.b32 b1, b1, 1; \n\t" // b1 <<= 1 (shift into fp16 position) - "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n\t" // out0 = b0 | (0x80008000 | a0) (restore sign) - "lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n\t" // out1 = b1 | (0x80008000 | a1) (restore sign) - "}", "=r,=r,r", false); - Value *packed_in = UndefValue::get(vec_ty(i8_ty, 4)); - packed_in = insert_elt(packed_in, in0, (uint64_t)0); - packed_in = insert_elt(packed_in, in1, (uint64_t)1); - packed_in = insert_elt(packed_in, in2, (uint64_t)2); - packed_in = insert_elt(packed_in, in3, (uint64_t)3); - Value *in = bit_cast(packed_in, i32_ty); - Value *ret = call(ptx, {in}); - Value *packed_ret0 = extract_val(ret, {0}); - Value *packed_ret1 = extract_val(ret, {1}); - Value *ret0 = extract_elt(packed_ret0, (uint64_t)0); - Value *ret1 = extract_elt(packed_ret0, (uint64_t)1); - Value *ret2 = extract_elt(packed_ret1, (uint64_t)0); - Value *ret3 = extract_elt(packed_ret1, (uint64_t)1); - return std::make_tuple(ret0, ret1, ret2, ret3); -} - -std::tuple generator::fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3) { - /* fp16 bit representation is seeeeemmmmmmmmmm (s=sign, e=exponent, m=mantissa) - * fp8 bit representation is seeeemmm - * The 4 fp8 exponent bits are the low order 4 exponent bits in fp16. - * The 3 fp8 mantissa bits are the high order 3 mantissa bits in fp16. - * Note that the low order exponent bits and high order mantissa bits in fp16 are contiguous. - * We want to round to nearest fp8 value. To do that add 1 to 4th mantissa bit in fp16 (that's - * one more than the number of mantissa bits in fp8). - * fp8 = (fp16 & 0x8000) | (((f16 << 1) + 0x0080) & 0x7fff) - * - * We compute two fp16s in one uint32. The addition could cause bit flips from one fp16 to the - * other. To avoid this we zero out the most significant exponent bit. If that bit is set then - * the value isn't representable in float8 anyway so we assume it's never set (and give garbage - * output if it is). If we were willing to assume the most significant exponent was never set - * we could save the first two lop3.b32 instructions below. 
- */ - InlineAsm *ptx = InlineAsm::get(FunctionType::get({vec_ty(i8_ty, 4)}, {i32_ty, i32_ty}, false), - "{" - ".reg .b32 a<2>, b<2>; \n\t" - "shl.b32 a0, $1, 1; \n\t" // a0 = input0 << 1 - "shl.b32 a1, $2, 1; \n\t" // a1 = input1 << 1 - "lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n\t" // a0 = (a0 & 0x7fff7fff) - "lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n\t" // a1 = (a1 & 0x7fff7fff) - "add.u32 a0, a0, 0x00800080; \n\t" // a0 += 0x00800080 - "add.u32 a1, a1, 0x00800080; \n\t" // a1 += 0x00800080 - "lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n\t" // b0 = (input0 & 0x80008000) | a0 - "lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n\t" // b1 = (input1 & 0x80008000) | a1 - "prmt.b32 $0, b0, b1, 0x7531; \n\t" // If b0 = 0xabcd and b1=0x0123 sets output to 0xac02 - "}", "=r,r,r", false); - Value *packed_in0 = UndefValue::get(vec_ty(f16_ty, 2)); - Value *packed_in1 = UndefValue::get(vec_ty(f16_ty, 2)); - packed_in0 = insert_elt(packed_in0, in0, (int)0); - packed_in0 = insert_elt(packed_in0, in1, (int)1); - packed_in1 = insert_elt(packed_in1, in2, (int)0); - packed_in1 = insert_elt(packed_in1, in3, (int)1); - Value *in_arg0 = bit_cast(packed_in0, i32_ty); - Value *in_arg1 = bit_cast(packed_in1, i32_ty); - Value *ret = call(ptx, {in_arg0, in_arg1}); - Value *ret0 = extract_elt(ret, (int)0); - Value *ret1 = extract_elt(ret, (int)1); - Value *ret2 = extract_elt(ret, (int)2); - Value *ret3 = extract_elt(ret, (int)3); - return std::make_tuple(ret0, ret1, ret2, ret3); -} - -Value* generator::bf16_to_fp32(Value *in0){ - if (tgt_->as_nvidia()->sm() >= 80) { - InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {bf16_ty}, false), - "cvt.rn.f32.bf16 $0, $1;", "=r,h", false); - return call(ptx, {in0}); - } else { - Value *ret = UndefValue::get(vec_ty(i16_ty, 2)); - ret = insert_elt(ret, bit_cast(in0, i16_ty), (uint64_t)1); - ret = insert_elt(ret, bit_cast(builder_->getInt16(0), i16_ty), (uint64_t)0); - return bit_cast(ret, f32_ty); - } -} - -Value* generator::fp32_to_bf16(Value *in0){ - if(tgt_->as_nvidia()->sm() >= 80){ - InlineAsm *ptx = InlineAsm::get(FunctionType::get(bf16_ty, {f32_ty}, false), - "cvt.rn.bf16.f32 $0, $1;", "=h,r", false); - return call(ptx, {in0}); - } - return extract_elt(bit_cast(in0, vec_ty(i16_ty, 2)), (uint64_t)1); -} - -/** - * \brief Code Generation for `cast` - */ -void generator::visit_cast_inst(ir::cast_inst* x) { - ir::value *op = x->get_operand(0); - ir::type* ret_sca_ty = x->get_type()->get_scalar_ty(); - ir::type* op_sca_ty = op->get_type()->get_scalar_ty(); - auto x_idxs = idxs_.at(x); - auto op_idxs = idxs_.at(op); - - // <> FP8 - if(ret_sca_ty->is_fp8_ty() || op_sca_ty->is_fp8_ty()){ - // ensure that conversions can be vectorized - int ld = layouts_->get(x)->get_order(0); - int contiguous = layouts_->get(x)->to_scanline()->nts(ld); - if(contiguous % 4 != 0) - throw std::runtime_error("unsupported fp32 -> fp8 conversion"); - - // run the conversion - auto cvt = [&](Value* a, Value* b, Value* c, Value* d){ - if(op_sca_ty->is_fp32_ty() && ret_sca_ty->is_fp8_ty()) - return fp32x4_to_fp8x4(a, b, c, d); - if(op_sca_ty->is_fp16_ty() && ret_sca_ty->is_fp8_ty()) - return fp16x4_to_fp8x4(a, b, c, d); - if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_fp16_ty()) - return fp8x4_to_fp16x4(a, b, c, d); - if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_fp32_ty()) - return fp8x4_to_fp32x4(a, b, c, d); - throw std::runtime_error("unsupported conversion"); - }; - for(size_t i = 0; i < x_idxs.size(); i+=4){ - std::tie(vals_[x][x_idxs[i+0]], - vals_[x][x_idxs[i+1]], - vals_[x][x_idxs[i+2]], - 
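[Note] The BF16 helpers above lean on one fact: bf16 shares fp32's sign and exponent widths, so the pre-sm_80 software fallback simply keeps the high 16 bits of the fp32 word (truncating, whereas the sm_80 cvt.rn path rounds to nearest). A standalone scalar restatement, not part of the patch:

#include <cstdint>
#include <cstdio>
#include <cstring>

static uint16_t fp32_to_bf16_bits(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof u);
  return (uint16_t)(u >> 16);          // drop the low 16 mantissa bits (no rounding)
}

static float bf16_to_fp32(uint16_t b) {
  uint32_t u = (uint32_t)b << 16;      // zero-fill the low mantissa bits
  float f;
  std::memcpy(&f, &u, sizeof f);
  return f;
}

int main() {
  float x = 3.14159f;
  float y = bf16_to_fp32(fp32_to_bf16_bits(x));
  std::printf("%f -> %f\n", x, y);     // 3.141590 -> 3.140625
}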
vals_[x][x_idxs[i+3]]) = cvt(vals_[op][op_idxs[i+0]], - vals_[op][op_idxs[i+1]], - vals_[op][op_idxs[i+2]], - vals_[op][op_idxs[i+3]]); - } - return; - } - - // <> BF16 - if(ret_sca_ty->is_bf16_ty() || op_sca_ty->is_bf16_ty()){ - // FP32 -> BF16 - if(op_sca_ty->is_fp32_ty()){ - for (indices_t idx: idxs_.at(x)) { - Value *arg = vals_[x->get_operand(0)][idx]; - vals_[x][idx] = fp32_to_bf16(arg); // cast(cvt(x->get_op()), arg, ty); - } - return; - } - // BF16 -> FP32 - if(ret_sca_ty->is_fp32_ty()){ - for(size_t i = 0; i < x_idxs.size(); i++) - vals_[x][x_idxs[i + 0]] = bf16_to_fp32(vals_[op][op_idxs[i + 0]]); - return; - } - } - - - Type *ty = cvt(x->get_type()->get_scalar_ty()); - auto cvt = [](ir::cast_op_t op){ - using ll = llvm::Instruction::CastOps; - using tt = ir::cast_op_t; - switch(op){ - case tt::Trunc: return ll::Trunc; - case tt::ZExt: return ll::ZExt; - case tt::SExt: return ll::SExt; - case tt::FPTrunc: return ll::FPTrunc; - case tt::FPExt: return ll::FPExt; - case tt::UIToFP: return ll::UIToFP; - case tt::SIToFP: return ll::SIToFP; - case tt::FPToUI: return ll::FPToUI; - case tt::FPToSI: return ll::FPToSI; - case tt::PtrToInt: return ll::PtrToInt; - case tt::IntToPtr: return ll::IntToPtr; - case tt::BitCast: return ll::BitCast; - case tt::AddrSpaceCast: return ll::AddrSpaceCast; - default: throw std::runtime_error("unreachable switch"); - } - }; - for(indices_t idx: idxs_.at(x)){ - Value *arg = vals_[x->get_operand(0)][idx]; - vals_[x][idx] = cast(cvt(x->get_op()), arg, ty); - } -} - -/** - * \brief Code Generation for `return` - */ -void generator::visit_return_inst(ir::return_inst* rr) { - ir::value *ret_val = rr->get_return_value(); - ret(ret_val ? vals_[ret_val][{}] : nullptr); -} - -/** - * \brief Code Generation for `cond_branch` - */ -void generator::visit_cond_branch_inst(ir::cond_branch_inst* br) { - BasicBlock *true_dest = bbs_.at(br->get_true_dest()); - BasicBlock *false_dest = bbs_.at(br->get_false_dest()); - Value *cond = vals_[br->get_cond()][{}]; - cond_br(cond, true_dest, false_dest); -} - -/** - * \brief Code Generation for `uncond_branch` - */ -void generator::visit_uncond_branch_inst(ir::uncond_branch_inst* br) { - BasicBlock *dest = bbs_.at(br->get_dest()); - br(dest); -} - -/** - * \brief Code Generation for a (synchronous) `load` - */ -void generator::visit_load_inst(ir::load_inst* x){ - ir::value *op = x->get_pointer_operand(); - ir::masked_load_inst *mx = dynamic_cast(x); - Type* ty = cvt(op->get_type()->get_scalar_ty()->get_pointer_element_ty()); - // compute vector width - size_t vec = 1; - if(op->get_type()->is_block_ty()){ - auto ord = ords_.at(op); - size_t aln = alignment_->get(op, ord[0]); - auto layout = layouts_->get(x)->to_scanline(); - if(layout){ - size_t nts = layout->nts(ord[0]); - vec = std::min(nts, aln); - } - } - // code generation - auto idxs = idxs_.at(x); - for(size_t i = 0; i < idxs.size(); i += vec){ - indices_t idx = idxs[i]; - // pointer value - Value *ptr = vals_[op][idx]; - // masked load - size_t dtsize = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; - // input ptr info - GetElementPtrInst *in_gep = dyn_cast(ptr); - size_t in_off; - if(in_gep){ - ConstantInt* cst = dyn_cast(in_gep->idx_begin()); - in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0; - ptr = cst ? in_gep->getPointerOperand() : in_gep; - } - else{ - in_off = 0; - } - Value *pred = mx ? vals_[mx->get_mask_operand()][idx] : builder_->getTrue(); - Value *other = mx ? 
vals_[mx->get_false_value_operand()][idx] : nullptr; - size_t nbits = dtsize*8; - // pack sub-words (< 32/64bits) into words - // each load has width min(nbits*vec, 32/64) - // and there are (nbits * vec)/width of them - int max_word_width = std::max(32, nbits); - int tot_width = nbits*vec; - int width = std::min(tot_width, max_word_width); - int n_words = std::max(1, tot_width / width); - // ----- - // create inline asm string - // ----- - std::ostringstream asm_oss; - asm_oss << "@$" << n_words; // predicate - asm_oss << " ld"; - if(x->get_is_volatile()) - asm_oss << ".volatile"; - asm_oss << ".global"; - if (x->get_cache_modifier() == ir::load_inst::CA) asm_oss << ".ca"; - if (x->get_cache_modifier() == ir::load_inst::CG) asm_oss << ".cg"; - if (x->get_eviction_policy() == ir::load_inst::EVICT_LAST) asm_oss << ".L1::evict_last"; - if (x->get_eviction_policy() == ir::load_inst::EVICT_FIRST) asm_oss << ".L1::evict_first"; - if(n_words > 1) - asm_oss << ".v" << n_words; // vector width - asm_oss << ".b" << width; // word size - asm_oss << " {"; - for(int i = 0; i < n_words; i++){ // return values - if(i > 0) asm_oss << ","; - asm_oss << "$" << i; - } - asm_oss << "}"; - asm_oss << ", [ $" << n_words + 1; // load - asm_oss << " + " << in_off << "];"; // constant offset - bool has_other = other && (other != UndefValue::get(other->getType())); - std::vector others; - // handle `other` values for indices where the mask - // is false - if(has_other) - for(size_t ii = 0; ii < n_words; ii++){ - size_t size = width / nbits; - Value *v = UndefValue::get(vec_ty(ty, size)); - for(size_t s = 0; s < size; s++){ - ir::value *false_val = mx->get_false_value_operand(); - v = insert_elt(v, vals_[false_val][idxs[i + ii*size + s]], s); - } - v = bit_cast(v, IntegerType::get(*ctx_, width)); - asm_oss << "\n "; - asm_oss << "@!$" << n_words << " mov.u" << width; - asm_oss << " $" << ii << ", "; - std::ios_base::fmtflags flags(asm_oss.flags()); - if(ConstantInt* cst = dyn_cast(v)) - asm_oss << "0x" << std::hex << cst->getSExtValue(); - else{ - asm_oss << "$" << n_words + 2 + ii; - others.push_back(v); - } - asm_oss.flags(flags); - asm_oss << ";"; - } - // ---- - // create inline ASM signature - // --- - std::vector ret_tys(n_words, IntegerType::get(*ctx_, width)); - Type* ret_ty = ret_tys.size() > 1 ? StructType::get(*ctx_, ret_tys) : ret_tys[0]; - // ret_ty->print(llvm::outs()); - std::vector arg_tys = {pred->getType(), ptr->getType()}; - for(Value *v: others) - arg_tys.push_back(v->getType()); - FunctionType *asm_ty = FunctionType::get(ret_ty, arg_tys, false); - // --- - // create inline ASM constraints - // --- - std::string asm_cstrt; - for(int ii = 0; ii < n_words; ii++){ - if(ii > 0) asm_cstrt += ","; - asm_cstrt += (width == 64) ? "=l" : ((width == 32) ? "=r" : "=c"); - } - asm_cstrt += ",b,l"; - for(size_t ii = 0; ii < others.size(); ii++){ - asm_cstrt += ","; - asm_cstrt += (width == 64) ? "l" : ((width == 32) ? 
"r" : "c"); - } - // --- - // finally call inline ASM - // --- - InlineAsm *inlineAsm = InlineAsm::get(asm_ty, asm_oss.str(), asm_cstrt, true); - std::vector args = {pred, ptr}; - for(Value *v: others) - args.push_back(v); - Value *_ret = call(inlineAsm, args); - // --- - // extract and store return values - // --- - std::vector rets; - for(unsigned int ii = 0; ii < n_words; ii++){ - Value *curr; - if(ret_ty->isStructTy()) - curr = extract_val(_ret, {ii}); - else - curr = _ret; - rets.push_back(bit_cast(curr, vec_ty(ty, width / (dtsize*8)))); - } - int tmp = (width / (dtsize * 8)); - for(size_t ii = 0; ii < vec; ii++) - vals_[x][idxs[i+ii]] = extract_elt(rets[ii/tmp], ii % tmp); - } -} - -void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) { - visit_load_inst(x); -} -void generator::visit_masked_load_inst(ir::masked_load_inst* x) { - visit_load_inst(x); -} - -/** - * \brief Code Generation for a (synchronous) `store` - */ - -void generator::visit_store_inst(ir::store_inst * x){ - ir::masked_store_inst *mx = dynamic_cast(x); - // operands - ir::value *ptr_op = x->get_pointer_operand(); - ir::value *val_op = x->get_value_operand(); - // vector size - size_t vec = 1; - if(val_op->get_type()->is_block_ty()){ - auto ord = ords_.at(x->get_pointer_operand()); - size_t aln = alignment_->get(ptr_op, ord[0]); - size_t nts = axes_.at(a_axes_->get(x->get_pointer_operand(), ord[0])).contiguous; - vec = std::min(nts, aln); - } - auto idxs = idxs_.at(val_op); - Type *ty = cvt(val_op->get_type()->get_scalar_ty()); - if (ty->isBFloatTy()) // llvm11-nvptx cannot select bf16 store - ty = f16_ty; - for(size_t i = 0; i < idxs.size(); i += vec){ - auto idx = idxs[i]; - // pointer - Value *ptr = vals_[ptr_op][idx]; - // vectorize - Type *v_ty = vec_ty(ty, vec); - ptr = bit_cast(ptr, v_ty->getPointerTo(1)); - // value - Value* val = UndefValue::get(v_ty); - for(size_t ii = 0; ii < vec; ii++) - val = insert_elt(val, bit_cast(vals_.at(val_op)[idxs[i + ii]], ty), ii); - if(mx){ - Value *msk = vals_[mx->get_mask_operand()][idx]; - Instruction *no_op = intrinsic(Intrinsic::donothing, {}, {}); - builder_->SetInsertPoint(no_op->getParent()); - Instruction* dummy = builder_->CreateRet(nullptr); - Instruction *term = llvm::SplitBlockAndInsertIfThen(msk, no_op, false); - dummy->removeFromParent(); - builder_->SetInsertPoint(term); - store(val, ptr); - builder_->SetInsertPoint(no_op); - } - else - store(val, ptr); - } -} -void generator::visit_unmasked_store_inst(ir::unmasked_store_inst* x) { - visit_store_inst(x); -} -void generator::visit_masked_store_inst(ir::masked_store_inst* x) { - visit_store_inst(x); -} - -/** - * \brief Code Generation for `cat` - */ -void generator::visit_cat_inst(ir::cat_inst* x) { - auto idxs = idxs_.at(x); - ir::value* lhs = x->get_operand(0); - ir::value* rhs = x->get_operand(1); - int i = 0; - for(size_t j = 0; j < idxs_.at(lhs).size(); j ++){ - vals_[x][idxs_[x][i++]] = vals_[lhs][idxs_[lhs][j]]; - } - for(size_t j = 0; j < idxs_.at(rhs).size(); j ++){ - vals_[x][idxs_[x][i++]] = vals_[rhs][idxs_[rhs][j]]; - } -} - - - -/** - * \brief Code Generation for `reshape` - */ -void generator::visit_reshape_inst(ir::reshape_inst* x) { - auto idxs = idxs_.at(x); - for(size_t i = 0; i < idxs_.at(x).size(); i ++){ - ir::value* op = x->get_operand(0); - vals_[x][idxs_[x][i]] = vals_[op][idxs_[op][i]]; - }; -} - -/** - * \brief Code Generation for `splat` - */ -void generator::visit_splat_inst(ir::splat_inst* x) { - for(auto idx: idxs_.at(x)) - vals_[x][idx] = 
vals_[x->get_operand(0)][{}]; -} - -/** - * \brief Code Generation for `broadcast` - */ -void generator::visit_broadcast_inst(ir::broadcast_inst* x) { - ir::value* op = x->get_operand(0); - const auto& shape = op->get_type()->get_block_shapes(); - for(auto out_idx: idxs_.at(x)){ - indices_t in_idx = out_idx; - for(size_t k = 0; k < in_idx.size(); k++) - in_idx[k] = shape[k] == 1 ? i32(0) : in_idx[k]; - vals_[x][out_idx] = vals_[op][in_idx]; - } -// for(size_t i = 0; i < idxs_.at(x).size(); i++) -// vals_[x][idxs_[x][i]] = vals_[op][idxs_[op][i]]; -} - -/** - * \brief Code Generation for `downcast` - */ -void generator::visit_downcast_inst(ir::downcast_inst* x) { - vals_[x][{}] = vals_[x->get_operand(0)][{i32(0)}]; -} - -/** - * \brief Code Generation for `get_program_id` - */ -void generator::visit_get_program_id_inst(ir::get_program_id_inst* pid) { - Module *module = builder_->GetInsertBlock()->getModule(); - Value *ret = tgt_->get_block_id(module, *builder_, pid->get_axis()); - vals_[pid][{}] = ret; -} - -/** - * \brief Code Generation for `get_num_programs` - */ -void generator::visit_get_num_programs_inst(ir::get_num_programs_inst* np) { - Module *module = builder_->GetInsertBlock()->getModule(); - Value *ret = tgt_->get_num_blocks(module, *builder_, np->get_axis()); - vals_[np][{}] = ret; -} - -/** - * \brief Code Generation for `exp` - */ -void generator::visit_exp_inst(ir::exp_inst* x){ - Constant *log2e = ConstantFP::get(f32_ty, 1.4426950408889634); - std::vector tys = {f32_ty}; - FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false); - InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $0;", "=f,0", false); - for(auto idx: idxs_.at(x)){ - Value *ex2arg = fmul(vals_[x->get_operand(0)][idx], log2e); - vals_[x][idx] = call(ex2, std::vector{ex2arg}); - } -} - -/** - * \brief Code Generation for `cos` - */ -void generator::visit_cos_inst(ir::cos_inst* x){ - std::vector tys = {f32_ty}; - FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false); - InlineAsm *cos = InlineAsm::get(fn_ty, "cos.approx.f32 $0, $0;", "=f,0", false); - for(auto idx: idxs_.at(x)){ - vals_[x][idx] = call(cos, std::vector{vals_[x->get_operand(0)][idx]}); - } - } - -/** - * \brief Code Generation for `umulhi` - */ -void generator::visit_umulhi_inst(ir::umulhi_inst* x){ - std::vector tys = {i32_ty, i32_ty}; - FunctionType *fn_ty = FunctionType::get(i32_ty, tys, false); - InlineAsm *umulhi = InlineAsm::get(fn_ty, "mul.hi.u32 $0, $1, $2;", "=r,r,r", false); - for(auto idx: idxs_.at(x)){ - Value* lhs = vals_[x->get_operand(0)][idx]; - Value* rhs = vals_[x->get_operand(1)][idx]; - vals_[x][idx] = call(umulhi, std::vector{lhs, rhs}); - } - } - -/** - * \brief Code Generation for `sin` - */ -void generator::visit_sin_inst(ir::sin_inst* x){ - std::vector tys = {f32_ty}; - FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false); - InlineAsm *sin = InlineAsm::get(fn_ty, "sin.approx.f32 $0, $0;", "=f,0", false); - for(auto idx: idxs_.at(x)){ - vals_[x][idx] = call(sin, std::vector{vals_[x->get_operand(0)][idx]}); - } - } - -/** - * \brief Code Generation for `log` - */ -void generator::visit_log_inst(ir::log_inst* x){ - Constant *rcplog2e = ConstantFP::get(f32_ty, 0.6931471805599453); - std::vector tys = {f32_ty}; - FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false); - InlineAsm *lg2 = InlineAsm::get(fn_ty, "lg2.approx.f32 $0, $1;", "=f,f", false); - for(auto idx: idxs_.at(x)){ - Value *lg2arg = call(lg2, std::vector{vals_[x->get_operand(0)][idx]}); - vals_[x][idx] = fmul(lg2arg, rcplog2e); - } 
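The two magic constants used by visit_exp_inst and visit_log_inst are just base-change factors between the natural-base operations in Triton IR and the base-2 PTX approximations (ex2.approx, lg2.approx). A host-side sketch of the same identities, illustrative only and using <cmath> rather than PTX:

    #include <cmath>

    // e^x = 2^(x * log2(e)): multiply by 1.4426950408889634 before ex2.approx.
    float exp_via_ex2(float x) { return std::exp2(x * 1.4426950408889634f); }

    // ln(x) = log2(x) * ln(2): multiply by 0.6931471805599453 after lg2.approx.
    float log_via_lg2(float x) { return std::log2(x) * 0.6931471805599453f; }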
-} - -/** - * \brief Code Generation for `atomic_cas` - */ -void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) { - BasicBlock *current = builder_->GetInsertBlock(); - Module *module = current->getModule(); - Value *tid = tgt_->get_local_id(module, *builder_, 0); - Value *pred = icmp_eq(tid, i32(0)); -// BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent()); -// BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent()); - add_barrier(); - tgt_->add_memfence(module, *builder_); - Value *atom_ptr; - atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(cas)))), ""); - atom_ptr = bit_cast(atom_ptr, ptr_ty(cvt(cas->get_type()->get_scalar_ty()), 3)); -// cond_br(pred, tid_0_bb, tid_0_done_bb); -// builder_->SetInsertPoint(tid_0_bb); - Value *cas_ptr = vals_[cas->get_operand(0)][{}]; - Value *cas_cmp = vals_[cas->get_operand(1)][{}]; - Value *cas_val = vals_[cas->get_operand(2)][{}]; - std::string asm_str = "@$1 atom.global.cas.b32 $0, [$2], $3, $4;"; - FunctionType *fn_ty = FunctionType::get(i32_ty, {pred->getType(), cas_ptr->getType(), cas_cmp->getType(), cas_val->getType()}, false); - InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, "=r,b,l,r,r", true); - add_barrier(); - Value *old = call(iasm, {pred, cas_ptr, cas_cmp, cas_val}); - add_barrier(); - - std::string asm2_str = "@$0 st.shared.b32 [$1], $2;"; - FunctionType *fn2_ty = FunctionType::get(void_ty, {pred->getType(), atom_ptr->getType(), old->getType()}, false); - InlineAsm *iasm2 = InlineAsm::get(fn2_ty, asm2_str, "b,r,r", true); - add_barrier(); - call(iasm2, {pred, atom_ptr, old}); - tgt_->add_memfence(module, *builder_); - add_barrier(); - vals_[cas][{}] = load(atom_ptr); - add_barrier(); -} - -/** - * \brief Code Generation for `atomic_rmw` - */ -void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) { - ir::value* ptr = atom->get_operand(0); - ir::value* val = atom->get_operand(1); - ir::value* msk = atom->get_operand(2); - - // vector size - int vec = 1; - if(atom->get_type()->is_block_ty()){ - int ld = ords_.at(ptr)[0]; - unsigned alignment = alignment_->get(ptr, ld); - vec = std::min(layouts_->get(ptr)->to_scanline()->nts(ld), alignment); - vec = std::min(vec, val->get_type()->get_tile_element_ty()->is_fp16_ty() ? 
2 : 1); - } - - for(int i = 0; i < idxs_.at(val).size(); i += vec){ - auto idx = idxs_[val][i]; - Value *rmw_val = UndefValue::get(vec_ty(vals_[val][idx]->getType(), vec)); - for(int ii = 0; ii < vec; ii++) - rmw_val = insert_elt(rmw_val, vals_[val][idxs_[val][i+ii]], ii); - Value *rmw_ptr = vals_[ptr][idx]; - Value *rmw_msk = vals_[msk][idx]; - if(vec == 1) - rmw_val = extract_elt(rmw_val, i32(0)); - Type* ty = rmw_val->getType(); - size_t nbits = ty->getScalarSizeInBits(); - // extract pointer offset - std::string offset = ""; - if(GetElementPtrInst *gep = dyn_cast(rmw_ptr)) - if(gep->getNumIndices() == 1) - if(ConstantInt *cst = dyn_cast(gep->idx_begin())){ - offset = " + " + std::to_string(cst->getValue().getSExtValue()*nbits/8); - rmw_ptr = gep->getPointerOperand(); - } - rmw_ptr = bit_cast(rmw_ptr, ty->getPointerTo(1)); - // asm argument type - std::vector arg_ty = {rmw_msk->getType(), rmw_ptr->getType(), rmw_val->getType()}; - // asm function type - FunctionType *fn_ty = FunctionType::get(ty, arg_ty, false); - // asm string - std::string s_nbits = std::to_string(nbits); - std::string name; - std::string s_ty; - using tt = ir::atomic_rmw_op_t; - switch(atom->get_op()){ - case tt::Or: name = "or"; s_ty = "b"; break; - case tt::And: name = "and"; s_ty = "b"; break; - case tt::Xor: name = "xor", s_ty = "b"; break; - case tt::Add: name = "add" , s_ty = "s"; break; - case tt::Min: name = "min", s_ty = "s"; break; - case tt::Max: name = "max", s_ty = "s"; break; - case tt::UMin: name = "min", s_ty = "u"; break; - case tt::UMax: name = "max", s_ty = "u"; break; - case tt::FAdd: name = "add", s_ty = "f"; break; - case tt::Xchg: name = "exch", s_ty = "b"; break; - } - std::string s_vec = vec == 2 ? "x2" : ""; - std::string mod = nbits == 32 ? "" : ".noftz"; - - std::string asm_str = "@$1 atom.global.gpu." + name + mod + "." + s_ty + s_nbits + s_vec + " $0, [$2" + offset + "], $3;"; - std::string ty_id = nbits*vec == 32 ? 
"r" : "h"; - std::string constraint = "=" + ty_id + ",b,l," + ty_id; - // create inline asm - InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true); - // call asm - if(atom->get_type()->is_block_ty()) - vals_[atom][idx] = call(iasm, (ArrayRef{rmw_msk, rmw_ptr, rmw_val})); - else{ - Module *mod = builder_->GetInsertBlock()->getModule(); - tgt_->add_memfence(mod, *builder_); - add_barrier(); - Value *tid = tgt_->get_local_id(mod, *builder_, 0); - rmw_msk = builder_->CreateAnd(rmw_msk, icmp_eq(tid, i32(0))); - Value *old = call(iasm, (ArrayRef{rmw_msk, rmw_ptr, rmw_val})); - Value *atom_ptr; - atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(atom)))), ""); - atom_ptr = bit_cast(atom_ptr, ptr_ty(old->getType(), 3)); - store(old, atom_ptr); - add_barrier(); - vals_[atom][idx] = load(atom_ptr); - add_barrier(); - } - } -} - -/** - * \brief Code Generation for `mma.884` (V100) - */ -//TODO: clean-up -void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::value *D, unsigned NK) { - // shapes - auto shape_c = C->get_type()->get_block_shapes(); - auto shape_a = A->get_type()->get_block_shapes(); - auto shape_b = B->get_type()->get_block_shapes(); - // order - auto ord_a = layouts_->get(A)->get_order(); - auto ord_b = layouts_->get(B)->get_order(); - // layouts - analysis::mma_layout* layout_c = layouts_->get(C)->to_mma(); - analysis::shared_layout* layout_a = layouts_->get(A)->to_shared(); - analysis::shared_layout* layout_b = layouts_->get(B)->to_shared(); - // vectorization - int vec_a = swizzle_->get_vec(layout_a); - int vec_b = swizzle_->get_vec(layout_b); - // strides - bool is_a_row = ord_a[0] != 0; - bool is_b_row = ord_b[0] != 0; - int stride_am = is_a_row ? shape_a[1] : 1; - int stride_ak = is_a_row ? 1 : shape_a[0]; - int stride_a0 = is_a_row ? stride_ak : stride_am; - int stride_a1 = is_a_row ? stride_am : stride_ak; - int stride_bn = is_b_row ? 1 : shape_b[0]; - int stride_bk = is_b_row ? shape_b[1] : 1; - int stride_b0 = is_b_row ? stride_bn : stride_bk; - int stride_b1 = is_b_row ? stride_bk : stride_bn; - int stride_rep_m = layout_c->wpt(0) * layout_c->fpw(0) * 8; - int stride_rep_n = layout_c->wpt(1) * layout_c->fpw(1) * 8; - int stride_rep_k = 1; - // swizzling - int per_phase_a = swizzle_->get_per_phase(layout_a); - int max_phase_a = swizzle_->get_max_phase(layout_a); - int step_a0 = is_a_row ? stride_rep_k : stride_rep_m; - int num_ptr_a = std::max(2 * per_phase_a * max_phase_a / step_a0, 1); - int per_phase_b = swizzle_->get_per_phase(layout_b); - int max_phase_b = swizzle_->get_max_phase(layout_b); - int step_b0 = is_b_row ? stride_rep_n : stride_rep_k; - int num_ptr_b = std::max(2 * per_phase_b * max_phase_b / step_b0, 1); - - /* --------------------------------- */ - /* --- pre-compute pointer lanes --- */ - /* --------------------------------- */ - BasicBlock* curr_bb = builder_->GetInsertBlock(); - BasicBlock* entry = &curr_bb->getParent()->getEntryBlock(); - if(entry != curr_bb) - builder_->SetInsertPoint(entry->getTerminator()); - Value* off_a0 = is_a_row ? offset_a_k_[layout_c] : offset_a_m_[layout_c]; - Value* off_a1 = is_a_row ? 
offset_a_m_[layout_c] : offset_a_k_[layout_c]; - Value* phase_a = urem(udiv(off_a1, i32(per_phase_a)), i32(max_phase_a)); - std::vector off_a(num_ptr_a); - for(int i = 0; i < num_ptr_a; i++){ - Value* off_a0i = add(off_a0, i32(i*(is_a_row?4:stride_rep_m))); - off_a0i = exact_udiv(off_a0i, i32(vec_a)); - off_a0i = xor_(off_a0i, phase_a); - off_a0i = mul(off_a0i, i32(vec_a)); - off_a[i] = add(mul(off_a0i, i32(stride_a0)), mul(off_a1, i32(stride_a1))); - } - Value* off_b0 = is_b_row ? offset_b_n_[layout_c] : offset_b_k_[layout_c]; - Value* off_b1 = is_b_row ? offset_b_k_[layout_c] : offset_b_n_[layout_c]; - Value* phase_b = urem(udiv(off_b1, i32(per_phase_b)), i32(max_phase_b)); - std::vector off_b(num_ptr_b); - for(int i = 0; i < num_ptr_b; i++){ - Value* off_b0i = add(off_b0, i32(i*(is_b_row?stride_rep_n:4))); - off_b0i = udiv(off_b0i, i32(vec_b)); - off_b0i = xor_(off_b0i, phase_b); - off_b0i = mul(off_b0i, i32(vec_b)); - off_b[i] = add(mul(off_b0i, i32(stride_b0)), mul(off_b1, i32(stride_b1))); - } - builder_->SetInsertPoint(curr_bb); - - /* --------------------------------- */ - /* --- MMA intrinsic --- */ - /* --------------------------------- */ - Type *f16x2_ty = vec_ty(f16_ty, 2); - Type *ret_ty = StructType::get(*ctx_, {f32_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty}); - std::vector arg_ty = {f16x2_ty, f16x2_ty, f16x2_ty, f16x2_ty, - f32_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty}; - InlineAsm *mma = InlineAsm::get(FunctionType::get(ret_ty, arg_ty, false), - " mma.sync.aligned.m8n8k4." - + std::string(is_a_row ? "row" : "col") - + "." - + std::string(is_b_row ? "row" : "col") - + ".f32.f16.f16.f32 " - "{$0, $1, $2, $3, $4, $5, $6, $7}, " - "{$8, $9}, " - "{$10, $11}, " - "{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false); - - - std::vector ptr_a(num_ptr_a); - std::vector ptr_b(num_ptr_b); - std::map, std::pair> has, hbs; - for(int i = 0; i < num_ptr_a; i++) - ptr_a[i] = gep(shmems_[A], off_a[i]); - for(int i = 0; i < num_ptr_b; i++) - ptr_b[i] = gep(shmems_[B], off_b[i]); - - - // initialize accumulators - std::vector acc; - for(indices_t idx: idxs_.at(C)) - acc.push_back(vals_[D][idx]); - - unsigned num_m = layout_c->rep(0) * shape_c[0] / layout_c->shape_per_cta(0); - unsigned num_n = layout_c->rep(1) * shape_c[1] / layout_c->shape_per_cta(1); - - // create mma & unpack result - auto call_mma = [&](unsigned m, unsigned n, unsigned K) { - auto ha = has[{m, K}]; - auto hb = hbs[{n, K}]; - // arguments - std::vector idx = { - (m*2 + 0) + (n*4 + 0)*num_m, (m*2 + 0) + (n*4 + 1)*num_m, - (m*2 + 1) + (n*4 + 0)*num_m, (m*2 + 1) + (n*4 + 1)*num_m, - (m*2 + 0) + (n*4 + 2)*num_m, (m*2 + 0) + (n*4 + 3)*num_m, - (m*2 + 1) + (n*4 + 2)*num_m, (m*2 + 1) + (n*4 + 3)*num_m - }; - std::vector args = {ha.first, ha.second, hb.first, hb.second}; - for(unsigned i = 0; i < 8; i++) - args.push_back(acc[idx[i]]); - // execute mma - Value *nc = call(mma, args); - // unpack - for(unsigned i = 0; i < 8; i++) - acc[idx[i]] = extract_val(nc, {i}); - }; - - ir::phi_node* phiA = dynamic_cast(A); - ir::phi_node* phiB = dynamic_cast(B); - - // Cache lds value. 
If values are prefetched, create phi node - // @param inc: incoming block (0 = header, 1 = loop) - auto register_lds = - [&](decltype(has)& vals, int m, int K, int inc, Value* val0, Value *val1, bool is_prefetch) { - if (K == 0 && is_prefetch) { - ir::basic_block* inc_block = phiA->get_incoming_block(inc); - lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].first, val0, inc_block)); - lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].second, val1, inc_block)); - } else - vals[{m, K}] = {val0, val1}; - }; - - auto load_a = [&](int m, int K, int inc, bool is_prefetch) { - int offidx = (is_a_row ? K/4 : m) % num_ptr_a; - Value* ptra; - if(K==0 && is_prefetch){ - if(inc == 0) - ptra = gep(shared_pre_ptr_[layout_a], off_a[offidx]); - else - ptra = gep(shared_next_ptr_[layout_a], off_a[offidx]); - } - else - ptra = ptr_a[offidx]; - int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a); - int step_ak = is_a_row ? K / (num_ptr_a*vec_a)*(num_ptr_a*vec_a) : K; - Value* pa = gep(ptra, i32(step_am*stride_rep_m*stride_am + step_ak*stride_ak)); - Value* ha = load(bit_cast(pa, ptr_ty(vec_ty(i32_ty, vec_a/2), 3))); - // record lds that needs to be moved - if (K == 0 && inc == 1 && is_prefetch) - prefetch_latch_to_bb_[phiA->get_incoming_value(1)].push_back(ha); - Value *ha00 = bit_cast(extract_elt(ha, i32(0)), f16x2_ty); - Value *ha01 = bit_cast(extract_elt(ha, i32(1)), f16x2_ty); - register_lds(has, m, K, inc, ha00, ha01, is_prefetch); - if(vec_a > 4){ - Value *ha10 = bit_cast(extract_elt(ha, i32(2)), f16x2_ty); - Value *ha11 = bit_cast(extract_elt(ha, i32(3)), f16x2_ty); - if(is_a_row) - register_lds(has, m, K+4, inc, ha10, ha11, is_prefetch); - else - register_lds(has, m+1, K, inc, ha10, ha11, is_prefetch); - } - }; - - auto load_b = [&](int n, int K, int inc, bool is_prefetch) { - int offidx = (is_b_row? n : K/4) % num_ptr_b; - Value* ptrb; - if(K==0 && is_prefetch){ - if(inc == 0) - ptrb = gep(shared_pre_ptr_[layout_b], off_b[offidx]); - else - ptrb = gep(shared_next_ptr_[layout_b], off_b[offidx]); - } else - ptrb = ptr_b[offidx]; - - int stepbn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n; - int stepbk = is_b_row ? 
K : K / (num_ptr_b*vec_b)*(num_ptr_b*vec_b); - Value* pb = gep(ptrb, i32(stepbn*stride_rep_n*stride_bn + stepbk*stride_bk)); - Value* hb = load(bit_cast(pb, ptr_ty(vec_ty(i32_ty, vec_b/2), 3))); - // record lds that needs to be moved - if (K == 0 && inc == 1 && is_prefetch) - prefetch_latch_to_bb_[phiB->get_incoming_value(1)].push_back(hb); - Value *hb00 = bit_cast(extract_elt(hb, i32(0)), f16x2_ty); - Value *hb01 = bit_cast(extract_elt(hb, i32(1)), f16x2_ty); - register_lds(hbs, n, K, inc, hb00, hb01, is_prefetch); - if(vec_b > 4){ - Value *hb10 = bit_cast(extract_elt(hb, i32(2)), f16x2_ty); - Value *hb11 = bit_cast(extract_elt(hb, i32(3)), f16x2_ty); - if(is_b_row) - register_lds(hbs, n+1, K, inc, hb10, hb11, is_prefetch); - else - register_lds(hbs, n, K+4, inc, hb10, hb11, is_prefetch); - } - - }; - - // update accumulators - if (C->is_prefetched()) { - // create phis - builder_->SetInsertPoint(curr_bb->getFirstNonPHI()); - for (unsigned m = 0; m < num_m/2; m += is_a_row?1:2) { - has[{m, 0}].first = phi(f16x2_ty, 2); - has[{m, 0}].second = phi(f16x2_ty, 2); - if (!is_a_row && vec_a>4) { - has[{m+1, 0}].first = phi(f16x2_ty, 2); - has[{m+1, 0}].second = phi(f16x2_ty, 2); - } - } - for (unsigned n = 0; n < num_n/2; n += is_b_row?2:1) { - hbs[{n, 0}].first = phi(f16x2_ty, 2); - hbs[{n, 0}].second = phi(f16x2_ty, 2); - if (is_b_row && vec_b>4) { - hbs[{n+1, 0}].first = phi(f16x2_ty, 2); - hbs[{n+1, 0}].second = phi(f16x2_ty, 2); - } - } - - // insert prefetched lds at the end of loop header - builder_->SetInsertPoint(bbs_[phiA->get_incoming_block(0)]->getTerminator()); - for (unsigned m = 0; m < num_m/2; m += is_a_row?1:2) - load_a(m, 0, 0, true); - for (unsigned n = 0; n < num_n/2; n += is_b_row?2:1) - load_b(n, 0, 0, true); - - // update accumulators - builder_->SetInsertPoint(curr_bb); - for (unsigned K = 0; K < NK; K += 4) { - int NEXTK = (K + 4) % NK; - // prefetch A - for (unsigned m = 0; m < num_m/2; m+=is_a_row?1:2) - load_a(m, NEXTK, 1, true); - // prefetch B - for (unsigned n = 0; n < num_n/2; n+=is_b_row?2:1) - load_b(n, NEXTK, 1, true); - // tensor core ops - for(unsigned m = 0; m < num_m/2; m++) - for(unsigned n = 0; n < num_n/2; n++){ - call_mma(m, n, K); - } - } - } else { // not prefetched - for(unsigned K = 0; K < NK; K += 4) - for(unsigned m = 0; m < num_m/2; m++) - for(unsigned n = 0; n < num_n/2; n++) { - if(has.find({m, K}) == has.end()) - load_a(m, K, /*inc*/0, /*is_prefetch*/false); - if(hbs.find({n, K}) == hbs.end()) - load_b(n, K, /*inc*/0, /*is_prefetch*/false); - call_mma(m, n, K); - } - } - - // write back accumulators - for(size_t i = 0; i < idxs_.at(C).size(); i++) - vals_[C][idxs_[C][i]] = acc[i]; -} - -namespace { -class mma16816_smem_loader { -public: - mma16816_smem_loader(int wpt, std::vector order, int k_order, - std::vector tile_shape, - std::vector instr_shape, std::vector mat_shape, - int per_phase, int max_phase, int dtsize, Builder *builder, - adder add, multiplier mul, geper gep) - : wpt_(wpt), order_(order), k_order_(k_order), tile_shape_(tile_shape), - instr_shape_(instr_shape), mat_shape_(mat_shape), - per_phase_(per_phase), max_phase_(max_phase), dtsize_(dtsize), builder_(builder), - add(add), mul(mul), gep(gep) { - // compute compile-time constant variables & types - c_mat_shape_ = mat_shape[order[0]]; - s_mat_shape_ = mat_shape[order[1]]; - - c_stride_ = tile_shape[order[1]]; - s_stride_ = tile_shape[order[0]]; - - // rule: k must be the fast-changing axis - need_trans_ = k_order_ != order_[0]; - can_use_ldmatrix_ = dtsize == 2 || 
(!need_trans_); - - // we need more pointers at the fast-changing axis, - if (can_use_ldmatrix_) - num_ptr_ = tile_shape[order[0]] / (order[0] == k_order? 1 : wpt) / instr_shape[order[0]]; - else // warning: this only works for tf32 & need transpose - num_ptr_ = tile_shape[order[0]] / wpt / mat_shape[order[0]]; - num_ptr_ = std::max(num_ptr_, 2); - - // special rule for i8/u8, 4 ptrs for each matrix - if (!can_use_ldmatrix_ && dtsize_ == 1) - num_ptr_ *= 4; - - // load_v4 stride (in num of mats) - int load_stride_in_mat[2]; - load_stride_in_mat[k_order] = 2; // instr_shape[k_order] / mat_shape[k_order], always 2 - load_stride_in_mat[k_order^1] = wpt * (instr_shape[k_order^1] / mat_shape[k_order^1]); - p_load_stride_in_mat_ = load_stride_in_mat[order[0]]; - // stride in mat, used by load_v4 - s_mat_stride_ = load_stride_in_mat[order[1]] / (instr_shape[order[1]]/mat_shape[order[1]]); - } - - std::vector compute_offs(Value *warp_off, Value *lane) { - // TODO: this needs to be moved to constructor (and extracted to arr_order) - mat_arr_stride_ = (k_order_ == 1) ? 1 : wpt_; - warp_off_stride_ = instr_shape_[k_order_^1] / mat_shape_[k_order_^1]; - // start matrix logic offset (rename it as base_mat_off?) - Value *mat_off[2] = {nullptr, nullptr}; - - if (can_use_ldmatrix_) { - // c: lane idx inside a group (a group is a collection of 8 contiguous threads) - // s: group idx (0,1,2,3) inside a warp - Value *c = urem(lane, i32(8)); - Value *s = udiv(lane, i32(8)); - // We can decompose s => s_0, s_1... - Value *s0 = urem(s, i32(2)); - Value *s1 = udiv(s, i32(2)); - - // We use different orders for a & b for better performance. - Value *k_mat_arr = (k_order_ == 1) ? s1 : s0; - Value *nk_mat_arr = (k_order_ == 1) ? s0 : s1; - mat_off[k_order_^1] = add(mul(warp_off, i32(warp_off_stride_)), - mul(nk_mat_arr, i32(mat_arr_stride_))); - mat_off[k_order_] = k_mat_arr; - // physical offset (before swizzling) - Value *c_mat_off = mat_off[order_[0]]; - Value *s_mat_off = mat_off[order_[1]]; - // offset inside a matrix - Value *s_off_in_mat = c; - - std::vector offs(num_ptr_); - Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_)); - // pre-compute strided offset - Value *s_off = add(s_off_in_mat, mul(s_mat_off, i32(s_mat_shape_))); - for (int i=0; i < num_ptr_; ++i) { - Value *c_mat_off_i = add(c_mat_off, i32(i*p_load_stride_in_mat_)); - c_mat_off_i = xor_(c_mat_off_i, phase); // smem swizzle - offs[i] = add(mul(c_mat_off_i, i32(c_mat_shape_)), mul(s_off, i32(s_stride_))); - } - return offs; - } else if (dtsize_ == 4 && need_trans_) { - // load tf32 matrices with lds32 - Value *c_off_in_mat = udiv(lane, i32(4)); // 4 = mat_shape[order[1]] - Value *s_off_in_mat = urem(lane, i32(4)); // - - Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_)); - std::vector offs(num_ptr_); - for (int mat = 0; mat < 4; ++mat) { // loads 4 mats each time - int k_mat_arr_int = (k_order_ == 1) ? mat/2 : mat%2; - int nk_mat_arr_int = (k_order_ == 1) ? mat%2 : mat/2; - if (k_mat_arr_int > 0) // we don't need pointers for k - continue; - Value *k_mat_arr = i32(k_mat_arr_int); - Value *nk_mat_arr = i32(nk_mat_arr_int); - // physical offset (before swizzling) - Value *c_mat_off = add(mul(warp_off, i32(warp_off_stride_)), - mul(nk_mat_arr, i32(mat_arr_stride_))); - Value *s_mat_off = k_mat_arr; // always 0? - Value *s_off = add(s_off_in_mat, mul(s_mat_off, i32(s_mat_shape_))); - // FIXME: (k_order_ == 1?) 
is really dirty hack - for (int i = 0; i < num_ptr_/2; ++i) { - Value *c_mat_off_i = add(c_mat_off, i32(i*p_load_stride_in_mat_*(k_order_ == 1?1:2))); - c_mat_off_i = xor_(c_mat_off_i, phase); - Value *c_off = add(c_off_in_mat, mul(c_mat_off_i, i32(c_mat_shape_))); - // TODO: move this out of the loop - c_off = urem(c_off, i32(tile_shape_[order_[0]])); - s_off = urem(s_off, i32(tile_shape_[order_[1]])); - offs[2*i + nk_mat_arr_int] = add(c_off, mul(s_off, i32(s_stride_))); - } - } - return offs; - // throw std::runtime_error("not implemented"); - } else if (dtsize_ == 1 && need_trans_) { - // load i8/u8 matrices with lds8 - Value *c_off_in_mat = udiv(lane, i32(4)); // - Value *s_off_in_mat = mul(urem(lane, i32(4)), i32(4)); // each thread load 4 cols - - // Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_)); - std::vector offs(num_ptr_); - for (int mat = 0; mat < 4; ++mat) { // loads 4 mats each time - int k_mat_arr_int = (k_order_ == 1) ? mat/2 : mat%2; - int nk_mat_arr_int = (k_order_ == 1) ? mat%2 : mat/2; - if (k_mat_arr_int > 0) // we don't need pointers for k - continue; - Value *k_mat_arr = i32(k_mat_arr_int); - Value *nk_mat_arr = i32(nk_mat_arr_int); - // physical offset (before swizzling) - Value *c_mat_off = add(mul(warp_off, i32(warp_off_stride_)), - mul(nk_mat_arr, i32(mat_arr_stride_))); - Value *s_mat_off = k_mat_arr; // always 0? - - for (int loadx4_off = 0; loadx4_off < num_ptr_/8; ++loadx4_off) { - for (int elem_off = 0; elem_off < 4; ++elem_off) { - int ptr_off = loadx4_off*8 + nk_mat_arr_int*4 + elem_off; - - Value *c_mat_off_i = add(c_mat_off, i32(loadx4_off*p_load_stride_in_mat_*(k_order_ == 1?1:2))); - Value *s_off_in_mat_elem = add(s_off_in_mat, i32(elem_off)); - - // disable swizzling ... - // Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_)); - // c_mat_off_i = xor_(c_mat_off_i, phase); - - Value *c_off = add(c_off_in_mat, mul(c_mat_off_i, i32(c_mat_shape_))); - Value *s_off = add(s_off_in_mat_elem, mul(s_mat_off, i32(s_mat_shape_))); - // To prevent out-of-bound access when the tile is too small - c_off = urem(c_off, i32(tile_shape_[order_[0]])); - s_off = urem(s_off, i32(tile_shape_[order_[1]])); - offs[ptr_off] = add(c_off, mul(s_off, i32(s_stride_))); - } - } - } - return offs; - } else - throw std::runtime_error("invalid smem load config"); - } - - std::tuple - load_x4(int mat0, int mat1, int inc, bool is_prefetch, ir::phi_node *pn, - Value *pre_ptr, Value *next_ptr, std::vector &off, std::vector &ptrs, - FunctionType *ldmatrix_ty, Type *smem_ptr_ty, - std::map> &prefetch_latch_to_bb_) { - assert(mat0 % 2 == 0 && mat1 % 2 == 0 && "smem matrix load must be aligned"); - int mat_idx[2] = {mat0, mat1}; - int k = mat_idx[k_order_]; - - int ptr_idx = -1; - if (can_use_ldmatrix_) - ptr_idx = mat_idx[order_[0]] / (instr_shape_[order_[0]] / mat_shape_[order_[0]]); - else if (dtsize_ == 4 && need_trans_) // tf32 & trans - ptr_idx = mat_idx[order_[0]]; - else // i8 & trans - ptr_idx = mat_idx[order_[0]] * 4; - - auto get_ptr = [&](int idx) -> Value* { - Value *ptr = nullptr; - if (k == 0 && is_prefetch) { - if (inc == 0) - ptr = bit_cast(gep(pre_ptr, off.at(idx)), smem_ptr_ty); - else - ptr = bit_cast(gep(next_ptr, off.at(idx)), smem_ptr_ty); - } else - ptr = ptrs.at(idx); - return ptr; - }; - Value *ptr = get_ptr(ptr_idx); - - Value *res_v4 = nullptr; - if (can_use_ldmatrix_) { - std::string trans = need_trans_ ? 
".trans" : ""; - // the offset (in byte) on the strided axis is a constant - int s_offset = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_ * dtsize_; - InlineAsm *ld_fn = InlineAsm::get(ldmatrix_ty, - "ldmatrix.sync.aligned.m8n8.x4" + trans + ".shared.b16 " - "{$0, $1, $2, $3}, " - "[$4 + " + std::to_string(s_offset) + "];", - "=r,=r,=r,=r,r", true); - assert(ptr); - res_v4 = call(ldmatrix_ty, ld_fn, {ptr}); - if (k == 0 && inc == 1 && is_prefetch) - prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(res_v4); - return {extract_val(res_v4, std::vector{0}), - extract_val(res_v4, std::vector{1}), - extract_val(res_v4, std::vector{2}), - extract_val(res_v4, std::vector{3})}; - } else if (dtsize_ == 4 && need_trans_) { // use lds.32 to load tf32 matrices - Value *ptr2 = get_ptr(ptr_idx+1); - assert(s_mat_stride_ == 1); - int s_offset_elem = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_; - int s_offset_arr_elem = 1 * (s_mat_stride_*s_mat_shape_) * s_stride_; - Value *elem0, *elem1, *elem2, *elem3; - if (k_order_ == 1) { - elem0 = load(gep(ptr, i32(s_offset_elem))); - elem1 = load(gep(ptr2, i32(s_offset_elem))); - elem2 = load(gep(ptr, i32(s_offset_elem + s_offset_arr_elem))); - elem3 = load(gep(ptr2, i32(s_offset_elem + s_offset_arr_elem))); - } else { // for b (k first) - elem0 = load(gep(ptr, i32(s_offset_elem))); - elem2 = load(gep(ptr2, i32(s_offset_elem))); - elem1 = load(gep(ptr, i32(s_offset_elem + s_offset_arr_elem))); - elem3 = load(gep(ptr2, i32(s_offset_elem + s_offset_arr_elem))); - } - if (k == 0 && inc == 1 && is_prefetch) { - prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem0); - prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem1); - prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem2); - prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem3); - } - return {elem0, elem1, elem2, elem3}; - } else if (dtsize_ == 1 && need_trans_) { // use lds.8 to load i8/u8 matrices - Value *ptr00 = get_ptr(ptr_idx); - Value *ptr01 = get_ptr(ptr_idx+1); - Value *ptr02 = get_ptr(ptr_idx+2); - Value *ptr03 = get_ptr(ptr_idx+3); - - Value *ptr10 = get_ptr(ptr_idx+4); - Value *ptr11 = get_ptr(ptr_idx+5); - Value *ptr12 = get_ptr(ptr_idx+6); - Value *ptr13 = get_ptr(ptr_idx+7); - - assert(s_mat_stride_ == 1); - int s_offset_elem = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_; - int s_offset_arr_elem = 1 * (s_mat_stride_*s_mat_shape_) * s_stride_; - - Value *i8v4_elems[4]; - Value *i32_elems[4]; - for (int i=0; i<4; ++i) - i8v4_elems[i] = UndefValue::get(vec_ty(i8_ty, 4)); - - Value *elem00, *elem01, *elem02, *elem03; - Value *elem10, *elem11, *elem12, *elem13; - Value *elem20, *elem21, *elem22, *elem23; - Value *elem30, *elem31, *elem32, *elem33; - Value *i8_elems[4*4]; - if (k_order_ == 1) { // - i8_elems[0*4 + 0] = load(gep(ptr00, i32(s_offset_elem))); - i8_elems[0*4 + 1] = load(gep(ptr01, i32(s_offset_elem))); - i8_elems[0*4 + 2] = load(gep(ptr02, i32(s_offset_elem))); - i8_elems[0*4 + 3] = load(gep(ptr03, i32(s_offset_elem))); - - assert(i8_elems[0*4 + 0]->getType()->isIntegerTy(8)); - - i8_elems[1*4 + 0] = load(gep(ptr10, i32(s_offset_elem))); - i8_elems[1*4 + 1] = load(gep(ptr11, i32(s_offset_elem))); - i8_elems[1*4 + 2] = load(gep(ptr12, i32(s_offset_elem))); - i8_elems[1*4 + 3] = load(gep(ptr13, i32(s_offset_elem))); - - i8_elems[2*4 + 0] = load(gep(ptr00, i32(s_offset_elem + s_offset_arr_elem))); - i8_elems[2*4 + 1] = load(gep(ptr01, i32(s_offset_elem + s_offset_arr_elem))); - i8_elems[2*4 
+ 2] = load(gep(ptr02, i32(s_offset_elem + s_offset_arr_elem))); - i8_elems[2*4 + 3] = load(gep(ptr03, i32(s_offset_elem + s_offset_arr_elem))); - - i8_elems[3*4 + 0] = load(gep(ptr10, i32(s_offset_elem + s_offset_arr_elem))); - i8_elems[3*4 + 1] = load(gep(ptr11, i32(s_offset_elem + s_offset_arr_elem))); - i8_elems[3*4 + 2] = load(gep(ptr12, i32(s_offset_elem + s_offset_arr_elem))); - i8_elems[3*4 + 3] = load(gep(ptr13, i32(s_offset_elem + s_offset_arr_elem))); - - for (int m=0; m<4; ++m) { - for (int e=0; e<4; ++e) - i8v4_elems[m] = insert_elt(i8v4_elems[m], i8_elems[m*4 + e], e); - i32_elems[m] = bit_cast(i8v4_elems[m], i32_ty); - } - } else { // for b (k first) - i8_elems[0*4 + 0] = load(gep(ptr00, i32(s_offset_elem))); - i8_elems[0*4 + 1] = load(gep(ptr01, i32(s_offset_elem))); - i8_elems[0*4 + 2] = load(gep(ptr02, i32(s_offset_elem))); - i8_elems[0*4 + 3] = load(gep(ptr03, i32(s_offset_elem))); - - assert(i8_elems[0*4 + 0]->getType()->isIntegerTy(8)); - - i8_elems[2*4 + 0] = load(gep(ptr10, i32(s_offset_elem))); - i8_elems[2*4 + 1] = load(gep(ptr11, i32(s_offset_elem))); - i8_elems[2*4 + 2] = load(gep(ptr12, i32(s_offset_elem))); - i8_elems[2*4 + 3] = load(gep(ptr13, i32(s_offset_elem))); - - i8_elems[1*4 + 0] = load(gep(ptr00, i32(s_offset_elem + s_offset_arr_elem))); - i8_elems[1*4 + 1] = load(gep(ptr01, i32(s_offset_elem + s_offset_arr_elem))); - i8_elems[1*4 + 2] = load(gep(ptr02, i32(s_offset_elem + s_offset_arr_elem))); - i8_elems[1*4 + 3] = load(gep(ptr03, i32(s_offset_elem + s_offset_arr_elem))); - - i8_elems[3*4 + 0] = load(gep(ptr10, i32(s_offset_elem + s_offset_arr_elem))); - i8_elems[3*4 + 1] = load(gep(ptr11, i32(s_offset_elem + s_offset_arr_elem))); - i8_elems[3*4 + 2] = load(gep(ptr12, i32(s_offset_elem + s_offset_arr_elem))); - i8_elems[3*4 + 3] = load(gep(ptr13, i32(s_offset_elem + s_offset_arr_elem))); - - for (int m=0; m<4; ++m) { - for (int e=0; e<4; ++e) - i8v4_elems[m] = insert_elt(i8v4_elems[m], i8_elems[m*4 + e], e); - i32_elems[m] = bit_cast(i8v4_elems[m], i32_ty); - } - } - if (k == 0 && inc == 1 && is_prefetch) { - for (int m = 0; m < 4; ++m) - for (int e = 0; e < 4; ++e) - prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(i8_elems[m*4 + e]); - } - return {i32_elems[0], i32_elems[1], i32_elems[2], i32_elems[3]}; - } else - throw std::runtime_error("invalid smem load"); - } - - int get_num_ptr() const { return num_ptr_; } - -private: - int wpt_; - std::vector order_; - int k_order_; - std::vector tile_shape_; - std::vector instr_shape_; - std::vector mat_shape_; - int per_phase_, max_phase_; - int dtsize_; - - // generated - int c_mat_shape_, s_mat_shape_; - int c_stride_, s_stride_; - // p_: on the pointer axis - int p_load_stride_in_mat_; - int s_mat_stride_; - // stride when moving to next not-k mat - int warp_off_stride_; - int mat_arr_stride_; // matrix arrangement (inside a load) stride - bool need_trans_, can_use_ldmatrix_; - int num_ptr_; - - Builder *builder_; - adder add; - multiplier mul; - geper gep; -}; -} - -/** - * \brief Code Generation for `mma.16816` (A100) - */ -//TODO: clean-up -void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::value *D, unsigned NK) { - const std::vector& shapes = C->get_type()->get_block_shapes(); - std::map, std::vector> fcs; - for(indices_t idx: idxs_.at(C)){ - std::vector key(idx.size() - 2); - std::copy(idx.begin() + 2, idx.end(), key.begin()); - fcs[key].push_back(vals_[D][idx]); - }; - auto shape_a = A->get_type()->get_block_shapes(); - auto shape_b = 
B->get_type()->get_block_shapes(); - auto ord_a = layouts_->get(A)->get_order(); - auto ord_b = layouts_->get(B)->get_order(); - analysis::mma_layout* layout = layouts_->get(C)->to_mma(); - analysis::shared_layout* layout_a = (analysis::shared_layout*)layouts_->get(C->get_operand(0)); - analysis::shared_layout* layout_b = (analysis::shared_layout*)layouts_->get(C->get_operand(1)); - bool is_a_row = ord_a[0] == 1; - bool is_b_row = ord_b[0] == 1; - - std::vector mma_instr_shape = layout->get_mma_instr_shape(); - const int mma_instr_m = mma_instr_shape[0]; - const int mma_instr_n = mma_instr_shape[1]; - const int mma_instr_k = mma_instr_shape[2]; - - std::vector mat_shape = layout->get_mma_mat_shape(); - const int mat_shape_m = mat_shape[0]; - const int mat_shape_n = mat_shape[1]; - const int mat_shape_k = mat_shape[2]; - - const int per_phase_a = swizzle_->get_per_phase(layout_a); - const int max_phase_a = swizzle_->get_max_phase(layout_a); - const int per_phase_b = swizzle_->get_per_phase(layout_b); - const int max_phase_b = swizzle_->get_max_phase(layout_b); - - const int num_rep_m = shapes[0] / layout->shape_per_cta(0); - const int num_rep_n = shapes[1] / layout->shape_per_cta(1); - const int num_rep_k = std::max(NK/mma_instr_k, 1); - - // floating point types - Type *fp32_ty = f32_ty; - Type *fp16x2_ty = vec_ty(f16_ty, 2); - Type *bf16x2_ty = vec_ty(bf16_ty, 2); - Type *fp16x2_pack4_ty = StructType::get(*ctx_, std::vector{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty}); - Type *bf16x2_pack4_ty = StructType::get(*ctx_, std::vector{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty}); - Type *fp32_pack4_ty = StructType::get(*ctx_, std::vector{fp32_ty, fp32_ty, fp32_ty, fp32_ty}); - // integer types - Type *i8x4_ty = vec_ty(i8_ty, 4); - Type *i8x4_pack4_ty = StructType::get(*ctx_, std::vector{i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty}); - Type *i32_pack4_ty = StructType::get(*ctx_, std::vector{i32_ty, i32_ty, i32_ty, i32_ty}); - - - FunctionType *ldmatrix_ty = nullptr; - FunctionType *mma_ty = nullptr; - Type *phi_ty = nullptr; - Type *smem_ptr_ty = nullptr; - - ir::type *A_ir_ty = A->get_type()->get_scalar_ty(); - ir::type *B_ir_ty = B->get_type()->get_scalar_ty(); - if (A_ir_ty->is_fp16_ty() && B_ir_ty->is_fp16_ty()) { - mma_ty = FunctionType::get(fp32_pack4_ty, std::vector{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); - smem_ptr_ty = ptr_ty(f16_ty, 3); - ldmatrix_ty = FunctionType::get(fp16x2_pack4_ty, std::vector{smem_ptr_ty}, false); - phi_ty = fp16x2_ty; - } else if (A_ir_ty->is_bf16_ty() && B_ir_ty->is_bf16_ty()) { - // FIXME: We should use bf16 here. 
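An editorial note on this FIXME, hedged: the operand slots of ldmatrix and the mma inline asm are plain 32-bit registers holding two 16-bit lanes, so carrying bf16 data in the fp16-typed containers chosen below appears to be a register-level no-op; presumably only the opcode string returned by layout->get_ptx_instr() needs to name bf16 (the commented-out bf16x2 variants further down show the intended typing). If honest types were wanted, a single bitcast crosses between the two packed forms:

    #include "llvm/IR/IRBuilder.h"

    // Hypothetical helper (not in the diff): reinterpret a <2 x bfloat> pack as <2 x half>
    // so it can flow through the fp16-typed plumbing below; no bits are changed.
    llvm::Value *as_fp16x2(llvm::IRBuilder<> &b, llvm::Value *bf16x2_val, llvm::Type *fp16x2_ty) {
      return b.CreateBitCast(bf16x2_val, fp16x2_ty);
    }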
- mma_ty = FunctionType::get(fp32_pack4_ty, std::vector{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); - smem_ptr_ty = ptr_ty(f16_ty, 3); - ldmatrix_ty = FunctionType::get(fp16x2_pack4_ty, std::vector{smem_ptr_ty}, false); - phi_ty = fp16x2_ty; - // mma_ty = FunctionType::get(fp32_pack4_ty, std::vector{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); - // smem_ptr_ty = ptr_ty(bf16_ty, 3); - // ldmatrix_ty = FunctionType::get(bf16x2_pack4_ty, std::vector{smem_ptr_ty}, false); - // phi_ty = bf16x2_ty; - } else if (A_ir_ty->is_fp32_ty() && B_ir_ty->is_fp32_ty()) { - mma_ty = FunctionType::get(fp32_pack4_ty, std::vector{fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); - smem_ptr_ty = ptr_ty(fp32_ty, 3); - ldmatrix_ty = FunctionType::get(fp32_pack4_ty, std::vector{smem_ptr_ty}, false); - phi_ty = fp32_ty; - } else if (A_ir_ty->is_integer_ty(8) && B_ir_ty->is_integer_ty(8)) { - // FIXME: We should use i8 here (but nvptx will generate extra casts when using i8) - mma_ty = FunctionType::get(i32_pack4_ty, std::vector{i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty}, false); - smem_ptr_ty = ptr_ty(i8_ty, 3); - ldmatrix_ty = FunctionType::get(i32_pack4_ty, std::vector{smem_ptr_ty}, false); - phi_ty = i32_ty; - // mma_ty = FunctionType::get(i32_pack4_ty, std::vector{i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty, i32_ty, i32_ty, i32_ty, i32_ty}, false); - // smem_ptr_ty = ptr_ty(i8_ty, 3); - // ldmatrix_ty = FunctionType::get(i8x4_pack4_ty, std::vector{smem_ptr_ty}, false); - // phi_ty = i8x4_ty; - } else - throw std::runtime_error("mma16816 data type not supported"); - - // left-hand-side values - std::map, Value*> ha; - std::map, Value*> hb; - - BasicBlock* CurrBB = builder_->GetInsertBlock(); - BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock(); - if(FirstBB != CurrBB) - builder_->SetInsertPoint(FirstBB->getTerminator()); - - Value* thread = tgt_->get_local_id(mod_, *builder_, 0); - Value *lane = urem(thread, i32(32)); - Value *warp = udiv(thread, i32(32)); - Value *warp_mn = udiv(warp, i32(layout->wpt(0))); - Value *warp_m = urem(warp, i32(layout->wpt(0))); - Value *warp_n = urem(warp_mn, i32(layout->wpt(1))); - std::vector& fc = fcs.begin()->second; - - size_t dtsize_a = A->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; - size_t dtsize_b = B->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; - - // | -> k (row-major), since we have ldmatrix.trans, we only need to change stride - // v (s0_0(0), s1_0(2), | *num_rep_k - // m s0_1(1), s1_1(3)) | (stride in num of matrices(mat_stride_ak): 2) - // ----------- - // *num_rep_m (stride in num of matrices(mat_stride_am): 2*layout->wpt(0)) - mma16816_smem_loader a_loader(layout->wpt(0), ord_a, /*k_order*/1, shape_a, - {mma_instr_m, mma_instr_k}, {mat_shape_m, mat_shape_k}, - per_phase_a, max_phase_a, dtsize_a, builder_, add, mul, gep); - std::vector off_a = a_loader.compute_offs(warp_m, lane); - int num_ptr_a = a_loader.get_num_ptr(); - - // | -> n (col-major) - // v (s0_0(0), | (stride: wpt(1)) | s1_0(2) | *num_rep_n - // k s0_1(1), | | s1_1(3)) | (stride in num of matrices(mat_stride_bn): wpt(1)) - // ----------- - // *num_rep_k (stride in num of matrices(mat_stride_bk): 2) - mma16816_smem_loader b_loader(layout->wpt(1), ord_b, /*k_order*/0, shape_b, - {mma_instr_k, mma_instr_n}, {mat_shape_k, mat_shape_n}, - 
per_phase_b, max_phase_b, dtsize_b, builder_, add, mul, gep); - std::vector off_b = b_loader.compute_offs(warp_n, lane); - int num_ptr_b = b_loader.get_num_ptr(); - - builder_->SetInsertPoint(CurrBB); - // A pointer - std::vector ptrs_a(num_ptr_a); - for(int i = 0; i < num_ptr_a; i++) - ptrs_a[i] = bit_cast(gep(shmems_[A], {off_a[i]}), smem_ptr_ty); - // B pointer - std::vector ptrs_b(num_ptr_b); - for(int i = 0; i < num_ptr_b; i++) - ptrs_b[i] = bit_cast(gep(shmems_[B], {off_b[i]}), smem_ptr_ty); - - InlineAsm *mma_fn = InlineAsm::get(mma_ty, layout->get_ptx_instr() + - " {$0, $1, $2, $3}," - " {$4, $5, $6, $7}," - " {$8, $9}," - " {$10, $11, $12, $13};", - "=r,=r,=r,=r,r,r,r,r,r,r,0,1,2,3", true); - - // create mma & unpack result, m, n, k are offsets in mat - auto call_mma = [&](unsigned m, unsigned n, unsigned k) { - unsigned cols_per_thread = num_rep_m * 2; - std::vector idx = { - (m + 0) + (n*2 + 0)*cols_per_thread, - (m + 0) + (n*2 + 1)*cols_per_thread, - (m + 1) + (n*2 + 0)*cols_per_thread, - (m + 1) + (n*2 + 1)*cols_per_thread - }; - Value *nc = call(mma_ty, mma_fn, - {ha[{m, k}], ha[{m+1, k}], ha[{m, k+1}], ha[{m+1, k+1}], - hb[{n, k}], hb[{n, k+1}], - fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]]}); - fc[idx[0]] = extract_val(nc, std::vector{0}); - fc[idx[1]] = extract_val(nc, std::vector{1}); - fc[idx[2]] = extract_val(nc, std::vector{2}); - fc[idx[3]] = extract_val(nc, std::vector{3}); - }; - - ir::phi_node* phiA = dynamic_cast(A); - ir::phi_node* phiB = dynamic_cast(B); - - auto register_lds2 = - [&](std::map, Value*>& vals, int mn, int k, int inc, Value* val, bool is_prefetch) { - if (k < 2 && is_prefetch) { - ir::basic_block* inc_block = phiA->get_incoming_block(inc); - lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{mn, k}], val, inc_block)); - } else - vals[{mn, k}] = val; - }; - - auto load_a = [&](int m, int k, int inc, bool is_prefetch) { - auto [ha0, ha1, ha2, ha3] = a_loader.load_x4(m, k, inc, is_prefetch, phiA, shared_pre_ptr_[layout_a], - shared_next_ptr_[layout_a], off_a, ptrs_a, - ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_); - register_lds2(ha, m, k, inc, ha0, is_prefetch); - register_lds2(ha, m+1, k, inc, ha1, is_prefetch); - register_lds2(ha, m, k+1, inc, ha2, is_prefetch); - register_lds2(ha, m+1, k+1, inc, ha3, is_prefetch); - }; - - auto load_b = [&](int n, int k, int inc, bool is_prefetch) { - auto [hb0, hb1, hb2, hb3] = b_loader.load_x4(k, n, inc, is_prefetch, phiB, shared_pre_ptr_[layout_b], - shared_next_ptr_[layout_b], off_b, ptrs_b, - ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_); - register_lds2(hb, n, k, inc, hb0, is_prefetch); - register_lds2(hb, n+1, k, inc, hb2, is_prefetch); - register_lds2(hb, n, k+1, inc, hb1, is_prefetch); - register_lds2(hb, n+1, k+1, inc, hb3, is_prefetch); - }; - - if (C->is_prefetched()) { - // create phis - builder_->SetInsertPoint(CurrBB->getFirstNonPHI()); - for(unsigned m = 0; m < num_rep_m; m++){ - ha[{2*m, 0}] = phi(phi_ty, 2); - ha[{2*m+1, 0}] = phi(phi_ty, 2); - ha[{2*m, 1}] = phi(phi_ty, 2); - ha[{2*m+1, 1}] = phi(phi_ty, 2); - } - for(unsigned n = 0; n < num_rep_n; n+=2){ - hb[{n, 0}] = phi(phi_ty, 2); - hb[{n+1, 0}] = phi(phi_ty, 2); - hb[{n, 1}] = phi(phi_ty, 2); - hb[{n+1, 1}] = phi(phi_ty, 2); - } - // insert prefetched lds at the end of loop header - builder_->SetInsertPoint(bbs_[phiA->get_incoming_block(0)]->getTerminator()); - for(unsigned m = 0; m < num_rep_m; m++) - load_a(2*m, 0, 0, true); - for(unsigned n = 0; n < num_rep_n; n+=2) - load_b(n, 0, 0, true); - // update accumulators - 
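The loop that follows is a simple software pipeline: the k = 0 fragments were just planted in the loop header (feeding the phi nodes created above), and each iteration issues the shared-memory loads for next_k before the tensor-core ops for k, so load latency overlaps with the math. A stripped-down, self-contained sketch of that schedule (illustrative only; prefetch_a/prefetch_b/mma are hypothetical stand-ins for the load_a/load_b/call_mma lambdas above, and the factor-of-2 matrix offsets are omitted):

    #include <cstdio>

    void prefetch_a(unsigned k) { std::printf("prefetch A slice %u\n", k); }
    void prefetch_b(unsigned k) { std::printf("prefetch B slice %u\n", k); }
    void mma(unsigned m, unsigned n, unsigned k) { std::printf("mma m=%u n=%u k=%u\n", m, n, k); }

    void pipelined_dot(unsigned num_rep_m, unsigned num_rep_n, unsigned num_rep_k) {
      for (unsigned k = 0; k < num_rep_k; ++k) {
        unsigned next_k = (k + 1) % num_rep_k;
        prefetch_a(next_k);                       // issue smem traffic for k+1 ...
        prefetch_b(next_k);
        for (unsigned m = 0; m < num_rep_m; ++m)  // ... while doing the math for k
          for (unsigned n = 0; n < num_rep_n; ++n)
            mma(m, n, k);
      }
    }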
builder_->SetInsertPoint(CurrBB); - for(unsigned k = 0; k < num_rep_k; ++k){ // stride of instr in mat is 2 - int next_k = (k + 1) % num_rep_k; - // prefetch A - for(unsigned m = 0; m < num_rep_m; m++) - load_a(2*m, 2*next_k, 1, true); - // prefetch B - for(unsigned n = 0; n < num_rep_n; n+=2) - load_b(n, 2*next_k, 1, true); - // tensor core ops - for(unsigned m = 0; m < num_rep_m; m++) - for(unsigned n = 0; n < num_rep_n; n++){ - call_mma(2*m, n, 2*k); - } - } - } - else{ - for (unsigned k = 0; k < num_rep_k; k++) { - for (unsigned m = 0; m < num_rep_m; m++) - load_a(2*m, 2*k, 0, /*is_prefetch*/false); - for (unsigned n = 0; n < num_rep_n; n+=2) - load_b(n, 2*k, 0, /*is_prefetch*/false); - for (unsigned m = 0; m < num_rep_m; m++) - for (unsigned n = 0; n < num_rep_n; n++) - call_mma(2*m, n, 2*k); - } - } - // write back - unsigned i = 0; - for(indices_t idx: idxs_.at(C)){ - std::vector key(idx.size() - 2); - std::copy(idx.begin() + 2, idx.end(), key.begin()); - if(i >= fcs.at(key).size()) - i = 0; - vals_[C][idx] = fcs.at(key)[i++]; - }; -} - -/** - * \brief Code Generation for FMA-based `dot` (FP32, FP64, Default) - */ -void generator::visit_fmadot(ir::dot_inst* C, ir::value* A, ir::value* B, ir::value* D, unsigned NK, Type *c_ty, Function *f_mul_add) { - auto shape_c = C->get_type()->get_block_shapes(); - auto shape_a = A->get_type()->get_block_shapes(); - auto shape_b = B->get_type()->get_block_shapes(); - auto ord_a = layouts_->get(A)->get_order(); - auto ord_b = layouts_->get(B)->get_order(); - analysis::scanline_layout* layout_c = layouts_->get(C)->to_scanline(); - analysis::shared_layout* layout_a = (analysis::shared_layout*)layouts_->get(C->get_operand(0)); - analysis::shared_layout* layout_b = (analysis::shared_layout*)layouts_->get(C->get_operand(1)); - bool is_a_row = ord_a[0] == 1; - bool is_b_row = ord_b[0] == 1; - std::string a_trans = is_a_row ? "" : ".trans"; - std::string b_trans = is_b_row ? ".trans" : ""; - int stride_a_m = is_a_row ? shape_a[1] : 1; - int stride_a_k = is_a_row ? 1 : shape_a[0]; - int stride_b_n = is_b_row ? 1 : shape_b[0]; - int stride_b_k = is_b_row ? shape_b[1] : 1; - int stride_a0 = is_a_row ? stride_a_k : stride_a_m; - int stride_a1 = is_a_row ? stride_a_m : stride_a_k; - int stride_b0 = is_b_row ? stride_b_n : stride_b_k; - int stride_b1 = is_b_row ? stride_b_k : stride_b_n; - int lda = is_a_row ? stride_a_m : stride_a_k; - int ldb = is_b_row ? stride_b_k : stride_b_n; - int per_phase_a = swizzle_->get_per_phase(layout_a); - int max_phase_a = swizzle_->get_max_phase(layout_a); - int per_phase_b = swizzle_->get_per_phase(layout_b); - int max_phase_b = swizzle_->get_max_phase(layout_b); - int num_ptr_a = 8; - int num_ptr_b = 8; - int vec_a = 2; - int vec_b = 4; - distributed_axis ax_m = axes_.at(a_axes_->get(C, 0)); - distributed_axis ax_n = axes_.at(a_axes_->get(C, 1)); -// Value* thread = tgt_->get_local_id(mod_, *builder_, 0); - - Value* off_a0 = is_a_row ? i32(0) : mul(ax_m.thread_id, i32(ax_m.contiguous)); - Value* off_a1 = is_a_row ? mul(ax_m.thread_id, i32(ax_m.contiguous)): i32(0); - std::vector off_a(num_ptr_a); - for(int i = 0; i < num_ptr_a; i++){ -// Value* off_a0i = add(off_a0, i32(is_a_row ? vec_a : layout_c->mts(0)*vec_a)); -// off_a0i = exact_udiv(off_a0i, i32(vec_a)); -// off_a0i = xor_(off_a0i, phase_a); -// off_a0i = mul(off_a0i, i32(vec_a)); - off_a[i] = add(mul(off_a0, i32(stride_a0)), mul(off_a1, i32(stride_a1))); - } - Value* off_b0 = is_b_row ? mul(ax_n.thread_id, i32(ax_n.contiguous)): i32(0); - Value* off_b1 = is_b_row ? 
i32(0) : mul(ax_n.thread_id, i32(ax_n.contiguous)); - std::vector off_b(num_ptr_b); - for(int i = 0; i < num_ptr_b; i++){ -// Value* off_b0i = add(off_b0, i32(is_b_row ? layout_c->mts(1)*vec_b : vec_b)); -// off_b0i = exact_udiv(off_b0i, i32(vec_b)); -// off_b0i = xor_(off_b0i, phase_b); -// off_b0i = mul(off_b0i, i32(vec_b)); - off_b[i] = add(mul(off_b0, i32(stride_b0)), mul(off_b1, i32(stride_b1))); - } - std::vector ptrs_a(num_ptr_a); - for(int i = 0; i < num_ptr_a; i++) - ptrs_a[i] = gep(shmems_[A], off_a[i]); - std::vector ptrs_b(num_ptr_b); - for(int i = 0; i < num_ptr_b; i++) - ptrs_b[i] = gep(shmems_[B], off_b[i]); - - std::map ret = vals_[D]; - std::map, Value*> has, hbs; - auto ord = layout_c->get_order(); - for(unsigned k = 0; k < NK; k++){ - int z = 0; - for(unsigned i = 0; i < shape_c[ord[1]]; i += layout_c->shape_per_cta(ord[1])) - for(unsigned j = 0; j < shape_c[ord[0]]; j += layout_c->shape_per_cta(ord[0])) - for(unsigned ii = 0; ii < layout_c->nts(ord[1]); ii++) - for(unsigned jj = 0; jj < layout_c->nts(ord[0]); jj++){ - unsigned m = (ord[0] == 1) ? i : j; - unsigned n = (ord[0] == 1) ? j : i; - unsigned mm = (ord[0] == 1) ? ii : jj; - unsigned nn = (ord[0] == 1) ? jj : ii; - if(has.find({m + mm, k}) == has.end()){ - Value* pa = gep(ptrs_a[0], i32((m + mm)*stride_a_m + k*stride_a_k)); - Value* va = load(pa); - has[{m + mm, k}] = va; - } - if(hbs.find({n + nn, k}) == hbs.end()){ - Value* pb = gep(ptrs_b[0], i32((n + nn)*stride_b_n + k*stride_b_k)); - Value* vb = load(pb); - hbs[{n + nn, k}] = vb; - } - ret[idxs_[C].at(z)] = call(f_mul_add, {has[{m+mm,k}], hbs[{n+nn, k}], ret[idxs_[C].at(z)]}); - z++; - } - } - - for(indices_t idx: idxs_.at(C)){ - vals_[C][idx] = ret[idx]; - } -} - -/** - * \brief Code Generation for `dot` - * Dispatches to appropriate specialized function - */ -void generator::visit_dot_inst(ir::dot_inst* dot) { - Function *fn = builder_->GetInsertBlock()->getParent(); - Module *module = fn->getParent(); - ir::value *A = dot->get_operand(0); - ir::value *B = dot->get_operand(1); - ir::value *D = dot->get_operand(2); - Type *c_ty = cvt(D->get_type()->get_scalar_ty()); - Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, std::vector{c_ty}); - auto A_shapes = A->get_type()->get_block_shapes(); - size_t red_axis = 1; - unsigned NK = A_shapes[red_axis]; - bool is_outer = NK == 1; - bool is_mma = layouts_->get(dot)->to_mma(); - if(!is_outer && is_mma && tgt_->as_nvidia()->sm() < 80) - return visit_mma884(dot, A, B, D, NK); - if(!is_outer && is_mma && tgt_->as_nvidia()->sm() >= 80) - return visit_mma16816(dot, A, B, D, NK); // rename it as visit_mma_v2()? 
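When no MMA layout applies, the FP32 check below dispatches to visit_fmadot, which accumulates each output element with llvm.fmuladd. A scalar sketch of that accumulation pattern (row-major layouts assumed purely for illustration):

#include <cmath>
#include <vector>

// C[m*N + n] accumulates sum_k A[m*K + k] * B[k*N + n], one fused
// multiply-add per step, mirroring the llvm.fmuladd calls of the FMA path.
static void fma_dot(const std::vector<float> &A, const std::vector<float> &B,
                    std::vector<float> &C, int M, int N, int K) {
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n) {
      float acc = C[m * N + n];
      for (int k = 0; k < K; ++k)
        acc = std::fma(A[m * K + k], B[k * N + n], acc);
      C[m * N + n] = acc;
    }
}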
- if (dot->get_type()->get_scalar_ty()->is_fp32_ty() && - A->get_type()->get_scalar_ty()->is_fp32_ty()) - return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add); - throw std::runtime_error("dot has invalid operand type"); -} - -void generator::visit_trans_inst(ir::trans_inst* trans) { - throw std::runtime_error("not supported"); -} - -/** - * \brief Code Generation for `sqrt` - */ -void generator::visit_sqrt_inst(ir::sqrt_inst* x) { - for(indices_t idx: idxs_.at(x)){ - Value *val = vals_[x->get_operand(0)][idx]; - Value *ret = intrinsic(Intrinsic::sqrt, {val->getType()}, {val}); - vals_[x][idx] = ret; - } -} - -Value* generator::shared_off(const std::vector& shapes, const std::vector& order, indices_t idx){ - // strides - std::vector strides(shapes.size(), builder_->getInt32(0)); - strides[order[0]] = builder_->getInt32(1); - for(size_t i = 1; i < idx.size(); i++) - strides[order[i]] = builder_->CreateMul(strides[order[i-1]], builder_->getInt32(shapes[order[i-1]])); - // result - Value *result = builder_->getInt32(0); - for(size_t i = 0; i < idx.size(); i++) - result = builder_->CreateAdd(result, builder_->CreateMul(idx[i], strides[i])); - return result; -} - -inline Value* generator::shfl_sync(Value* acc, int32_t i){ - Type* ty = acc->getType(); - std::string asm_str = "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;"; - InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r", false); - if(ty->getPrimitiveSizeInBits() <= 32) - return call(shfl, {acc, i32(i)}); - acc = bit_cast(acc, vec_ty(f32_ty, 2)); - Value* acc0 = builder_->CreateExtractElement(acc, i32(0)); - Value* acc1 = builder_->CreateExtractElement(acc, i32(1)); - Value* ret = UndefValue::get(vec_ty(f32_ty, 2)); - ret = insert_elt(ret, shfl_sync(acc0, i), i32(0)); - ret = insert_elt(ret, shfl_sync(acc1, i), i32(1)); - return bit_cast(ret, ty); -} - -/** - * \brief Code Generation for `reduce` (1D case) - */ -void generator::visit_reduce1d_inst(ir::reduce_inst* x, std::function do_acc, Value *neutral) { - std::map partial; - ir::value *arg = x->get_operand(0); - Type *ret_ty = cvt(x->get_type()->get_scalar_ty()); - Value *acc = nullptr; - - // reduce within thread - for(indices_t idx: idxs_.at(arg)){ - Value *val = vals_[arg][idx]; - acc = !acc ? 
val : do_acc(acc, val); - } - // reduce within wrap - for(int i = 16; i > 0; i >>= 1) - acc = do_acc(acc, shfl_sync(acc, i)); - // pointers - unsigned addr_space = shmem_->getType()->getPointerAddressSpace(); - Value *base = bit_cast(shmem_, ptr_ty(ret_ty, addr_space)); - Value* thread = tgt_->get_local_id(mod_, *builder_, 0); - Value* warp = udiv(thread, i32(32)); - Value* lane = urem(thread, i32(32)); - // store warp result in shared memory - add_barrier(); - store(neutral, gep(base, lane)); - add_barrier(); - store(acc, gep(base, warp)); - add_barrier(); - - // reduce across warps - Value *cond = icmp_eq(warp, i32(0)); - Instruction *barrier = add_barrier(); - builder_->SetInsertPoint(barrier->getParent()); - Instruction* dummy = builder_->CreateRet(nullptr); - Instruction *term = llvm::SplitBlockAndInsertIfThen(cond, barrier, false); - dummy->removeFromParent(); - builder_->SetInsertPoint(term); - Value* ret = load(gep(base, thread)); - for(int i = (num_warps_+1)/2; i > 0; i >>= 1){ - Value *current = shfl_sync(ret, i); - ret = do_acc(ret, current); - } - store(ret, gep(base, thread)); - - // store first warp done - builder_->SetInsertPoint(barrier->getParent()); - ret = load(base); - for(indices_t idx: idxs_.at(x)) - vals_[x][idx] = ret; -} - -/** - * \brief Code Generation for `reduce` (ND case) - */ -void generator::visit_reducend_inst(ir::reduce_inst* x, std::function do_acc, Value *neutral) { - ir::value *arg = x->get_operand(0); - Type *ty = cvt(x->get_type()->get_scalar_ty()); - unsigned axis = x->get_axis(); - - // reduce within thread - std::map accs; - for(indices_t idx: idxs_.at(arg)){ - indices_t pidx = idx; - pidx[axis] = i32(0); - Value *current = vals_[arg][idx]; - bool is_first = accs.find(pidx) == accs.end(); - accs[pidx] = is_first ? 
current : do_acc(accs[pidx], current); - }; - - // reduce within blocks - analysis::data_layout* layout = layouts_->get(layouts_->tmp(x)); - Value *base = shared_ptr_.at(layout); - auto shape = layout->get_shape(); - auto order = layout->get_order(); - int space = base->getType()->getPointerAddressSpace(); - Value *ptr = bit_cast(base, ptr_ty(ty, space)); - Value *lane = axes_.at(a_axes_->get(arg, axis)).thread_id; - for(auto& x: accs) { - // current element being computed - Value *&acc = x.second; - indices_t write_idx = x.first; - write_idx[axis] = lane; - // shared memory write pointer - Value *write_off = shared_off(shape, order, write_idx); - Value *write_ptr = gep(ptr, write_off); - // initialize shared memory - add_barrier(); - store(acc, write_ptr); - // build result - indices_t idx(write_idx.size(), i32(0)); - for(size_t i = shape[axis]/2; i > 0; i >>= 1){ - idx[axis] = i32(i); - // read pointer - Value *read_msk = icmp_ult(lane, i32(i)); - Value *read_off = select(read_msk, shared_off(shape, order, idx), i32(0)); - Value *read_ptr = gep(write_ptr, read_off); - add_barrier(); - // update accumulator - acc = do_acc(acc, load(read_ptr)); - add_barrier(); - store(acc, write_ptr); - } - } - add_barrier(); - - // write back - for(indices_t idx: idxs_.at(x)){ - indices_t read_idx = idx; - read_idx.insert(read_idx.begin() + axis, i32(0)); - Value *read_off = shared_off(shape, order, read_idx); - Value *read_ptr = gep(ptr, read_off); - vals_[x][idx] = load(read_ptr); - }; -} - -/** - * \brief Code Generation for `reduce` (generic case) - */ -void generator::visit_reduce_inst(ir::reduce_inst* x) { - Type *ty = cvt(x->get_type()->get_scalar_ty()); - // accumulation function - ir::reduce_inst::op_t op = x->get_op(); - auto do_acc = [&](Value *x, Value *y) -> Value* { - switch(op){ - case ir::reduce_inst::ADD: return add(x, y); - case ir::reduce_inst::SUB: return sub(x, y); - case ir::reduce_inst::MAX: return select(icmp_sge(x, y), x, y); - case ir::reduce_inst::MIN: return select(icmp_sle(x, y), x, y); - case ir::reduce_inst::FADD: return fadd(x, y); - case ir::reduce_inst::FSUB: return fsub(x, y); - case ir::reduce_inst::FMAX: return max_num(x, y); - case ir::reduce_inst::FMIN: return min_num(x, y); - case ir::reduce_inst::XOR: return xor_(x, y); - default: throw std::runtime_error("unreachable"); - } - }; - // neutral element - Value *neutral; - switch(op) { - case ir::reduce_inst::ADD: neutral = ConstantInt::get(ty, 0); break; - case ir::reduce_inst::SUB: neutral = ConstantInt::get(ty, 0); break; - case ir::reduce_inst::MAX: neutral = ConstantInt::get(ty, INT32_MIN); break; - case ir::reduce_inst::MIN: neutral = ConstantInt::get(ty, INT32_MAX); break; - case ir::reduce_inst::FADD: neutral = ConstantFP::get(ty, 0); break; - case ir::reduce_inst::FSUB: neutral = ConstantFP::get(ty, 0); break; - case ir::reduce_inst::FMAX: neutral = ConstantFP::get(ty, -INFINITY); break; - case ir::reduce_inst::FMIN: neutral = ConstantFP::get(ty, INFINITY); break; - case ir::reduce_inst::XOR: neutral = neutral = ConstantInt::get(ty, 0); break; - default: throw std::runtime_error("unreachable"); - } - ir::value *arg = x->get_operand(0); - if(arg->get_type()->get_tile_rank() == 1) - visit_reduce1d_inst(x, do_acc, neutral); - else - visit_reducend_inst(x, do_acc, neutral); -} - -/** - * \brief Code Generation for `select` - */ -void generator::visit_select_inst(ir::select_inst* x) { - for(indices_t idx: idxs_.at(x)){ - vals_[x][idx] = select(vals_[x->get_operand(0)][idx], - vals_[x->get_operand(1)][idx], - 
vals_[x->get_operand(2)][idx]); - } -} - - - -void generator::visit_layout_convert(ir::value *out, ir::value *in){ - ir::block_type::block_shapes_t shape = out->get_type()->get_block_shapes(); - // pointer to temporary shared memory - Type *ty = cvt(out->get_type()->get_scalar_ty()); - - if (ty->isBFloatTy()) // llvm11-nvptx cannot select bf16 store - ty = f16_ty; - - // Orders - analysis::distributed_layout* in_layout = dynamic_cast(layouts_->get(in)); - analysis::distributed_layout* out_layout = dynamic_cast(layouts_->get(out)); - auto in_ord = in_layout->get_order(); - auto out_ord = out_layout->get_order(); - Value *base; - base = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(out))))); - base = bit_cast(base, ptr_ty(ty, 3)); - std::vector n_reps; - for(int i = 0; i < shape.size(); i++){ - int in_per_cta = in_layout->shape_per_cta(i); - int out_per_cta = out_layout->shape_per_cta(i); - int max_per_cta = std::max(in_per_cta, out_per_cta); - n_reps.push_back(shape[i]/max_per_cta); - } - std::vector> in_ax; - std::vector> out_ax; - for(int d = 0; d < shape.size(); d++){ - in_ax.push_back(axes_.at(a_axes_->get(in, d)).values); - out_ax.push_back(axes_.at(a_axes_->get(out, d)).values); - } - in_ord = in_layout->to_mma() ? out_ord : in_ord; - out_ord = out_layout->to_mma() ? in_ord : out_ord; - int in_vec = out_ord[0] == 0 ? 1 : in_layout->contig_per_thread(in_ord[0]); - int out_vec = out_ord[0] == 0 ? 1 : out_layout->contig_per_thread(out_ord[0]); - int pad = std::max(in_vec, out_vec); - Value *in_ld = i32(shape[in_ord[0]] + pad); - Value *out_ld = i32(shape[out_ord[0]] + pad); - for(int i = 0; i < n_reps[0]; i++) - for(int j = 0; j < n_reps[1]; j++){ - int max_ii, max_jj; - add_barrier(); - max_ii = in_ax[0].size()/n_reps[0]; - max_jj = in_ax[1].size()/n_reps[1]; - for(int ii = 0; ii < max_ii; ii++) - for(int jj = 0; jj < max_jj; jj+=in_vec){ - // shared mem pointer - indices_t offs = {in_ax[0][ii], in_ax[1][jj]}; - Value *off = add(offs[out_ord[0]], mul(out_ld, offs[out_ord[1]])); - Value *ptr = gep(base, off); - // stash value to shared mem - Value* vals = UndefValue::get(vec_ty(ty, in_vec)); - for(int jjj = 0; jjj < in_vec; jjj++){ - indices_t idxs = {in_ax[0][i*max_ii + ii], - in_ax[1][j*max_jj + jj + jjj]}; - Value* val = bit_cast(vals_[in][idxs], ty); - vals = insert_elt(vals, val, jjj); - } - ptr = bit_cast(ptr, ptr_ty(vals->getType(), ptr->getType()->getPointerAddressSpace())); - store(vals, ptr); - } - add_barrier(); - max_ii = out_ax[0].size()/n_reps[0]; - max_jj = out_ax[1].size()/n_reps[1]; - for(int ii = 0; ii < max_ii; ii++) - for(int jj = 0; jj < max_jj; jj+=out_vec){ - // shared mem pointer - indices_t offs = {out_ax[0][ii], out_ax[1][jj]}; - Value *off = add(offs[out_ord[0]], mul(out_ld, offs[out_ord[1]])); - Value *ptr = gep(base, off); - ptr = bit_cast(ptr, ptr_ty(vec_ty(ty, out_vec), ptr->getType()->getPointerAddressSpace())); - // load value from shared rem - Value* vals = load(ptr); - for(int jjj = 0; jjj < out_vec; jjj++){ - indices_t idxs = {out_ax[0][i*max_ii + ii], - out_ax[1][j*max_jj + jj + jjj]}; - vals_[out][idxs] = extract_elt(vals, jjj); - } - } - - } -} - -void generator::visit_cvt_layout_inst(ir::cvt_layout_inst *rc) { - visit_layout_convert(rc, rc->get_operand(0)); -} - -void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){ - unsigned in_vec = 1; - ir::value *arg = x->get_pointer_operand(); - analysis::shared_layout* out_layout = layouts_->get(x)->to_shared(); - analysis::scanline_layout* in_layout = 
layouts_->get(arg)->to_scanline(); - auto out_order = out_layout->get_order(); - auto in_order = in_layout->get_order(); - // tiles - if(out_order == in_order) - in_vec = in_layout->nts(in_order[0]); - int out_vec = swizzle_->get_vec(out_layout); - int min_vec = std::min(out_vec, in_vec); - int s = std::max(out_vec / in_vec, 1); - // - int per_phase = swizzle_->get_per_phase(out_layout); - int max_phase = swizzle_->get_max_phase(out_layout); - // - int in_ld = in_layout->get_shape()[in_order[0]] / in_layout->mts(in_order[0]); - int n_shared_1 = std::max(per_phase*max_phase / in_layout->mts(in_order[1]), 1); - int n_shared_0 = std::max(in_vec / out_vec, 1); - auto shapes = x->get_type()->get_block_shapes(); - BasicBlock* CurrBB = builder_->GetInsertBlock(); - BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock(); - std::map, Value*> tmp; - std::vector> shared; - for(int i = 0; i < idxs_.at(arg).size(); i++){ - unsigned id = i / min_vec; - // input ptr info - int id_0 = id % (in_ld/min_vec); - int id_1 = id / (in_ld/min_vec); - int off_0 = id_0 / n_shared_0 * n_shared_0 * in_layout->mts(in_order[0]); - int off_1 = id_1 / n_shared_1 * n_shared_1 * in_layout->mts(in_order[1]); - int off = (off_1*shapes[in_order[0]] + off_0); - std::pair key = {id_1 % n_shared_1, id_0 % n_shared_0}; - if(tmp.find(key) == tmp.end()){ - if(CurrBB != FirstBB) - builder_->SetInsertPoint(FirstBB->getTerminator()); - indices_t idx = idxs_.at(arg).at(key.first*in_ld); - Value* phase = udiv(idx[in_order[1]], i32(per_phase)); - phase = urem(phase, i32(max_phase)); - Value* off_1 = mul(idx[in_order[1]], i32(shapes[in_order[0]])); - Value* off_0 = add(idx[in_order[0]], i32(key.second*out_vec)); - off_0 = udiv(off_0, i32(min_vec)); - off_0 = add(mul(xor_(udiv(off_0, i32(s)), phase),i32(s)), urem(off_0, i32(s))); - off_0 = mul(off_0 , i32(min_vec)); - Value* off = add(off_0, off_1); - if(CurrBB != FirstBB) - builder_->SetInsertPoint(CurrBB); - tmp[key] = gep(shmems_[x], {off}); - } - shared.push_back({tmp[key], off}); - } - size_t dtsize = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; - for(size_t i = 0; i < idxs_.at(arg).size(); i += in_vec){ - auto idx = idxs_[arg][i]; - // input ptr info - Value *ptr = vals_[arg][idx]; - size_t in_off = 0; - GetElementPtrInst *in_gep = dyn_cast(vals_[arg][idx]); - if(in_gep){ - ConstantInt* cst = dyn_cast(in_gep->idx_begin()); - in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0; - ptr= cst ? in_gep->getPointerOperand() : in_gep; - } - // output ptr info - Value* out_base = shared[i].first; - int out_off = shared[i].second*dtsize; - // asm - std::string mod = (in_vec*dtsize == 16) ? 
".cg" : ".ca"; -// Value* false_value = vals_[x->get_false_value_operand()][idx]; -// bool is_zero_false_value = false; -// if(Constant* cst = dyn_cast(false_value)) -// is_zero_false_value = cst->isZeroValue(); - Value* src_size = builder_->CreateSelect(vals_[x->get_mask_operand()][idx], i32(in_vec*dtsize), i32(0)); - std::string asm_str = "cp.async" + mod + ".shared.global [$0 + " + std::to_string(out_off) + "], [$1 + " + std::to_string(in_off) + "], " + std::to_string(in_vec*dtsize) + ", $2;"; - FunctionType *ty = FunctionType::get(void_ty, {out_base->getType(), ptr->getType(), builder_->getInt32Ty()}, false); - InlineAsm *iasm = InlineAsm::get(ty, asm_str, "r,l,r", true); - call(iasm, {out_base, ptr, src_size}); - } - - std::string asm_str = "cp.async.commit_group;"; - InlineAsm *iasm = InlineAsm::get(FunctionType::get(void_ty, {}), asm_str, "", true); - call(iasm); -} - -void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) { - unsigned in_vec = 1; - ir::value *arg = cts->get_operand(0); - analysis::shared_layout* out_layout = layouts_->get(cts)->to_shared(); - analysis::scanline_layout* in_layout = layouts_->get(arg)->to_scanline(); - auto out_order = out_layout->get_order(); - auto in_order = in_layout->get_order(); - // tiles - if(out_order == in_order) - in_vec = in_layout->nts(in_order[0]); - int out_vec = swizzle_->get_vec(out_layout); - int min_vec = std::min(out_vec, in_vec); - int s = std::max(out_vec / in_vec, 1); - // - int per_phase = swizzle_->get_per_phase(out_layout); - int max_phase = swizzle_->get_max_phase(out_layout); - // - int in_ld = in_layout->get_shape()[in_order[0]] / in_layout->mts(in_order[0]); - int n_shared_1 = std::max(per_phase*max_phase / in_layout->mts(in_order[1]), 1); - int n_shared_0 = std::max(in_vec / out_vec, 1); - - BasicBlock* CurrBB = builder_->GetInsertBlock(); - BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock(); - auto shapes = cts->get_type()->get_block_shapes(); - - // store to shared - Value *current = nullptr; - std::map, Value*> ptrs; - for(int i = 0; i < idxs_.at(arg).size(); i++){ - auto idx = idxs_[arg][i]; - Value *in_value = vals_[arg][idx]; - if(i % min_vec == 0) - current = UndefValue::get(vec_ty(in_value->getType(), min_vec)); - current = insert_elt(current, in_value, i % min_vec); - if(i % min_vec == min_vec - 1){ - unsigned id = i / min_vec; - // input ptr info - int id_0 = id % (in_ld/min_vec); - int id_1 = id / (in_ld/min_vec); - int off_0 = id_0 / n_shared_0 * n_shared_0 * in_layout->mts(in_order[0]); - int off_1 = id_1 / n_shared_1 * n_shared_1 * in_layout->mts(in_order[1]); - int off = (off_1*shapes[in_order[0]] + off_0); - std::pair key = {id_1 % n_shared_1, id_0 % n_shared_0}; - if(ptrs.find(key) == ptrs.end()){ - if(FirstBB->getTerminator()) - builder_->SetInsertPoint(FirstBB->getTerminator()); - else - builder_->SetInsertPoint(FirstBB); - indices_t idx = idxs_.at(arg).at(key.first*in_ld); - Value* phase = udiv(idx[in_order[1]], i32(per_phase)); - phase = urem(phase, i32(max_phase)); - Value* off_1 = mul(idx[in_order[1]], i32(shapes[in_order[0]])); - Value* off_0 = add(idx[in_order[0]], i32(key.second*out_vec)); - off_0 = udiv(off_0, i32(min_vec)); - off_0 = add(mul(xor_(udiv(off_0, i32(s)), phase),i32(s)), urem(off_0, i32(s))); - off_0 = mul(off_0 , i32(min_vec)); - Value* off = add(off_0, off_1); - builder_->SetInsertPoint(CurrBB); - ptrs[key] = gep(shmems_.at(cts), {off}); - } - Value* ptr = gep(ptrs[key], {i32(off)}); - ptr = bit_cast(ptr, current->getType()->getPointerTo(3)); - // asm - 
store(current, ptr); - } - }; -} - -void generator::visit_copy_from_shared_inst(ir::copy_from_shared_inst*) { - throw std::runtime_error("TODO"); -} - -Instruction* generator::add_barrier() { - Module *module = builder_->GetInsertBlock()->getModule(); - return tgt_->add_barrier(module, *builder_); -} - -void generator::visit_barrier_inst(ir::barrier_inst*) { - add_barrier(); -} - -void generator::visit_prefetch_s_inst(ir::prefetch_s_inst *i) { - ir::value *v = i->get_operand(0); - int inc = i->get_inc(); - if (inc == 0) { - // If dot has not been visitied, do nothing. - } else { - // If dot has been visitied, insert prefetched lds - assert(inc == 1); - assert(prefetch_latch_to_bb_.find(v) != prefetch_latch_to_bb_.end() && - "dot hasn't be visited"); - // sink lds & extract element - // move lds & all uses to current location - std::stack work_stack; - for (Value *value : prefetch_latch_to_bb_[v]) - work_stack.push(value); - std::vector dead_instrs; - while (!work_stack.empty()) { - Value *m = work_stack.top(); - work_stack.pop(); - - for (auto u : m->users()) - work_stack.push(u); - - assert(isa(m)); - auto m_instr = static_cast(m); - - m_instr->removeFromParent(); - m_instr->insertAfter(&*std::prev(builder_->GetInsertBlock()->end())); - assert(m_instr->getParent() == &*builder_->GetInsertBlock()); - builder_->SetInsertPoint(m_instr->getParent()); - } - } -} - -void generator::visit_async_wait_inst(ir::async_wait_inst* i) { - std::string asm_str = "cp.async.wait_group " + std::to_string(i->get_N()) + ";"; - InlineAsm *iasm = InlineAsm::get(FunctionType::get(void_ty, {}), asm_str, "", true); - call(iasm); -} - -//void generator::visit_make_range_dyn(ir::make_range_dyn* x) { -// for(indices_t idx: idxs_.at(x)){ -// assert(idx.size() == 1); -// if(idx[0] == i32(0)) -// vals_[x][idx] = idx[0]; -// else{ -// BinaryOperator *bin_add = dyn_cast(idx[0]); -// assert(bin_add); -// vals_[x][idx] = bin_add->getOperand(0); -// } -// } -//} - -//void generator::visit_make_range_sta(ir::make_range_sta* x) { -// for(indices_t idx: idxs_.at(x)){ -// assert(idx.size() == 1); -// if(idx[0] == i32(0)){ -// vals_[x][idx] = idx[0]; -// } -// else{ -// BinaryOperator *bin_add = dyn_cast(idx[0]); -// assert(bin_add); -// Value *cst = bin_add->getOperand(1); -// assert(isa(cst)); -// vals_[x][idx] = cst; -// } -// }; -//} - -void generator::visit_make_range(ir::make_range* x) { - for(indices_t idx: idxs_.at(x)){ - Value* start = ConstantInt::get(idx[0]->getType(), x->get_first()->get_value()); - vals_[x][idx] = add(start, idx[0]); - } -} - -void generator::visit_undef_value(ir::undef_value *x) { - Type* ty = cvt(x->get_type()->get_scalar_ty()); - for(indices_t idx: idxs_.at(x)) - vals_[x][idx] = llvm::UndefValue::get(ty); -} - -void generator::visit_constant_int(ir::constant_int *x){ - Type *ty = cvt(x->get_type()->get_scalar_ty()); - for(indices_t idx: idxs_.at(x)) - vals_[x][idx] = ConstantInt::get(ty, x->get_value()); -} - -void generator::visit_constant_fp(ir::constant_fp *x){ - Type *ty = cvt(x->get_type()->get_scalar_ty()); - for(indices_t idx: idxs_.at(x)) - vals_[x][idx] = ConstantFP::get(ty, x->get_value()); -} - -void generator::visit_alloc_const(ir::alloc_const *alloc) { - unsigned size = ((ir::constant_int*)alloc->get_operand(0))->get_value(); - Type *element_ty = cvt(alloc->get_type()->get_pointer_element_ty()); - Type *array_ty = llvm::ArrayType::get(element_ty, size); - Value *array = new llvm::GlobalVariable(*mod_, array_ty, false, llvm::GlobalVariable::ExternalLinkage, - nullptr, 
alloc->get_name(), nullptr, llvm::GlobalVariable::NotThreadLocal, 4); - vals_[alloc][{}] = bit_cast(array, element_ty->getPointerTo(4)); -} - - -void generator::visit_function(ir::function* fn) { - LLVMContext &ctx = builder_->getContext(); - FunctionType *fn_ty = (FunctionType*)cvt(fn->get_fn_type()); - if(!tgt_->is_gpu()){ - Type *fn_ret_ty = fn_ty->getReturnType(); - std::vector fn_args_ty; - for(unsigned i = 0; i < fn_ty->getNumParams(); i++) - fn_args_ty.push_back(fn_ty->getParamType(i)); - fn_args_ty.push_back(i32_ty); - fn_args_ty.push_back(i32_ty); - fn_args_ty.push_back(i32_ty); - fn_ty = FunctionType::get(fn_ret_ty, fn_args_ty, false); - } - Function *ret = Function::Create(fn_ty, Function::ExternalLinkage, fn->get_name(), mod_); - // set attributes - for(auto attr_pair: fn->attrs()){ - unsigned id = attr_pair.first; - for(ir::attribute attr: attr_pair.second) - if(attr.is_llvm_attr()){ - llvm::Attribute llattr = cvt(attr); - if(llattr.getKindAsEnum() != llvm::Attribute::None) - ret->addAttribute(id, cvt(attr)); - } - } - // set metadata - if(tgt_->is_gpu()){ - tgt_->set_kernel(*builder_, ctx, mod_, ret); - Metadata *md_args[] = { - ValueAsMetadata::get(ret), - MDString::get(ctx, "maxntidx"), - ValueAsMetadata::get(i32(num_warps_*32)) - }; - mod_->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(ctx, md_args)); - } - // set arguments - for(unsigned i = 0; i < fn->args().size(); i++) - vals_[fn->args()[i]][{}] = &*(ret->arg_begin() + i); - // create blocks - for(ir::basic_block *block: fn->blocks()) { - BasicBlock *dst_block = BasicBlock::Create(ctx, block->get_name(), ret); - bbs_[block] = dst_block; - } - builder_->SetInsertPoint(bbs_[fn->blocks()[0]]); - // initialize layouts - for(auto x: layouts_->get_all()){ - visit_layout(x.second); - } - // generate LLVM-IR code - for(ir::basic_block *block: fn->blocks()) - visit_basic_block(block); - // finalize - finalize_function(fn); -} - - - -void generator::visit_layout_mma(analysis::mma_layout* layout) { - ir::value *a = nullptr; - ir::value *b = nullptr; - for(ir::value* v: layout->get_values()) - if(ir::dot_inst* dot = dynamic_cast(v)){ - a = dot->get_operand(0); - b = dot->get_operand(1); - } - analysis::data_layout* layout_a = layouts_->get(a); - analysis::data_layout* layout_b = layouts_->get(b); - - const auto& shape = layout->get_shape(); - Value *_1 = i32(1); - Value *_2 = i32(2); - Value *_3 = i32(3); - Value *_4 = i32(4); - Value *_8 = i32(8); - Value *_16 = i32(16); - Value *_32 = i32(32); - int cc = tgt_->as_nvidia()->sm(); - std::vector idx_m; - std::vector idx_n; - std::vector idx_z; - // - Value* thread = tgt_->get_local_id(mod_, *builder_, 0); - Value *lane = urem(thread, _32); - Value *warp = udiv(thread, _32); - /* lane offset */ - if(cc < 80){ - auto ord_a = layout_a->get_order(); - auto ord_b = layout_b->get_order(); - bool is_a_row = ord_a[0] != 0; - bool is_b_row = ord_b[0] != 0; - /* warp offset */ - Value *warp_0 = urem(warp, i32(layout->wpt(0))); - Value *warp_12 = udiv(warp, i32(layout->wpt(0))); - Value *warp_1 = urem(warp_12, i32(layout->wpt(1))); - Value *off_warp_m = mul(warp_0, i32(layout->spw(0))); - Value *off_warp_n = mul(warp_1, i32(layout->spw(1))); - // Quad offset - Value *off_quad_m = mul(udiv(and_(lane, _16), _4), i32(layout->fpw(0))); - Value *off_quad_n = mul(udiv(and_(lane, _16), _4), i32(layout->fpw(1))); - // Pair offset - Value *off_pair_m = udiv(urem(lane, _16), _4); - off_pair_m = urem(off_pair_m, i32(layout->fpw(0))); - off_pair_m = mul(off_pair_m, i32(4)); - Value 
*off_pair_n = udiv(urem(lane, _16), _4); - off_pair_n = udiv(off_pair_n, i32(layout->fpw(0))); - off_pair_n = urem(off_pair_n, i32(layout->fpw(1))); - off_pair_n = mul(off_pair_n, i32(4)); - // scale - off_pair_m = mul(off_pair_m, i32(layout->rep(0)/2)); - off_quad_m = mul(off_quad_m, i32(layout->rep(0)/2)); - off_pair_n = mul(off_pair_n, i32(layout->rep(1)/2)); - off_quad_n = mul(off_quad_n, i32(layout->rep(1)/2)); - // Quad pair offset - Value *off_lane_m = add(off_pair_m, off_quad_m); - Value *off_lane_n = add(off_pair_n, off_quad_n); - // a offset - offset_a_m_[layout] = add(off_warp_m, off_lane_m); - offset_a_k_[layout] = and_(lane, _3); - // b offsets - offset_b_n_[layout] = add(off_warp_n, off_lane_n); - offset_b_k_[layout] = and_(lane, _3); - // i indices - Value *offset_c_m = add(and_(lane, _1), offset_a_m_[layout]); - for(unsigned m = 0; m < shape[0]; m+=layout->shape_per_cta(0)) - for(unsigned mm = 0; mm < layout->rep(0); mm++) - idx_m.push_back(add(offset_c_m, i32(m + mm*2))); - // j indices - Value *offset_c_n = add(and_(lane, _2), add(off_warp_n, off_pair_n)); - for(unsigned n = 0; n < shape[1]; n+=layout->shape_per_cta(1)) - for(unsigned nn = 0; nn < layout->rep(1); nn++){ - idx_n.push_back(add(offset_c_n, i32(n + nn/2*4 + (nn%2)*2*layout->fpw(1)*layout->rep(1)))); - idx_n.push_back(add(offset_c_n, i32(n + nn/2*4 + (nn%2)*2*layout->fpw(1)*layout->rep(1) + 1))); - } - if(is_a_row){ - offset_a_m_[layout] = add(offset_a_m_[layout], urem(thread, i32(4))); - offset_a_k_[layout] = i32(0); - } - if(!is_b_row){ - offset_b_n_[layout] = add(offset_b_n_[layout], urem(thread, i32(4))); - offset_b_k_[layout] = i32(0); - } - /* axes */ - axes_[layout->get_axis(0)] = distributed_axis{1, idx_m, warp_0}; - axes_[layout->get_axis(1)] = distributed_axis{1, idx_n, warp_1}; - } - else{ - /* warp offset */ - Value *warp_0 = urem(warp, i32(layout->wpt(0))); - Value *warp_12 = udiv(warp, i32(layout->wpt(0))); - Value *warp_1 = urem(warp_12, i32(layout->wpt(1))); - Value *off_warp_m = mul(warp_0, i32(layout->spw(0))); - Value *off_warp_n = mul(warp_1, i32(layout->spw(1))); - Value *off_lane_m = urem(lane, _16); - Value *off_lane_n = urem(lane, _8); - /* offsets */ - // a offset - offset_a_m_[layout] = add(off_warp_m, off_lane_m); - offset_a_k_[layout] = i32(0); - // b offsets - offset_b_n_[layout] = add(off_warp_n, off_lane_n); - offset_b_k_[layout] = i32(0); - // c offset - Value *off_c_m = add(udiv(lane, _4), off_warp_m); - Value *off_c_n = add(mul(_2, urem(lane, _4)), off_warp_n); - for(unsigned m = 0; m < shape[0]; m+=layout->shape_per_cta(0)){ - idx_m.push_back(add(off_c_m, i32(m))); - idx_m.push_back(add(off_c_m, i32(m + 8))); - } - for(unsigned n = 0; n < shape[1]; n+=layout->shape_per_cta(1)){ - idx_n.push_back(add(off_c_n, i32(n))); - idx_n.push_back(add(off_c_n, i32(n + 1))); - } - /* axes */ - axes_[layout->get_axis(0)] = distributed_axis{1, idx_m, warp_0}; - axes_[layout->get_axis(1)] = distributed_axis{1, idx_n, warp_1}; - } -} - -void generator::visit_layout_scanline(analysis::scanline_layout* layout) { - Value *warp_size = i32(32); - Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0); - Value *u_thread_id = urem(u_thread_id_0, warp_size); - Value *u_warp_id = udiv(u_thread_id_0, warp_size); - - auto order = layout->get_order(); - const auto& shape = layout->get_shape(); - Value* full_thread_id = add(mul(u_warp_id, i32(32)), u_thread_id); - // Delinearize - size_t dim = shape.size(); - std::vector thread_id(dim); - for(unsigned k = 0; k < dim - 1; k++){ - Constant *dim_k = 
i32(layout->mts(order[k])); - Value *rem = urem(full_thread_id, dim_k); - full_thread_id = udiv(full_thread_id, dim_k); - thread_id[order[k]] = rem; - } - thread_id[order[dim - 1]] = full_thread_id; - // Create axes - for(unsigned k = 0; k < dim; k++) { - int nts = layout->nts(k); - int mts = layout->mts(k); - std::string str_k = std::to_string(k); - Value *contiguous_k = i32(nts); - Value *scaled_thread_id = mul(thread_id[k], contiguous_k); - unsigned per_cta = layout->shape_per_cta(k); - unsigned per_thread = nts * shape[k] / per_cta; - std::vector idx_list(per_thread); - for(unsigned n = 0 ; n < per_thread; n++){ - unsigned offset = n / nts * per_cta + n % nts; - idx_list[n] = add(scaled_thread_id, i32(offset), "idx_" + str_k + "_" + std::to_string(n)); - } - axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_id[k]}; - } -} - -void generator::visit_layout_shared(analysis::shared_layout* layout) { - Type* ty = cvt(layout->get_type()); - PointerType *ptr_ty = ty->getPointerTo(shmem_->getType()->getPointerAddressSpace()); - if (layout->get_N_buffer()) { - // create pointers - shared_pre_ptr_[layout] = gep(shmem_, i32(alloc_->offset(layout))); - shared_pre_ptr_[layout] = bit_cast(shared_pre_ptr_[layout], ptr_ty); - - BasicBlock *current = builder_->GetInsertBlock(); - - auto info = *layout->get_N_buffer(); - ir::phi_node *phi = info.phi; - BasicBlock *parent = bbs_.at(phi->get_parent()); - if(parent->empty()) - builder_->SetInsertPoint(parent); - else if (const Instruction *first_non_phi = &*parent->getFirstNonPHI()) { - builder_->SetInsertPoint(&*parent->getFirstNonPHI()); - } else - builder_->SetInsertPoint(parent); - - // create smem_idx - read_smem_idx_[layout] = phi(i32_ty, 2); - write_smem_idx_[layout] = phi(i32_ty, 2); - - // create pointers - // ptr of the current iteration - shared_ptr_[layout] = phi(ptr_ty, 2); - // ptr of the next iteration - shared_next_ptr_[layout] = phi(ptr_ty, 2); - - builder_->SetInsertPoint(current); - } else if(layout->get_double_buffer()) { - BasicBlock *current = builder_->GetInsertBlock(); - auto info = *layout->get_double_buffer(); - ir::phi_node *phi = info.phi; - BasicBlock *parent = bbs_.at(phi->get_parent()); - if(parent->empty()) - builder_->SetInsertPoint(parent); - else - builder_->SetInsertPoint(&*parent->getFirstNonPHI()); - // create pointers - shared_ptr_[layout] = phi(ptr_ty, 2); - shared_pre_ptr_[layout] = gep(shmem_, i32(alloc_->offset(layout))); - shared_pre_ptr_[layout] = bit_cast(shared_pre_ptr_[layout], shared_ptr_[layout]->getType()); - shared_off_[layout] = phi(i32_ty, 2); - shared_next_ptr_[layout] = gep(shared_ptr_[layout], shared_off_[layout], "next_ptr"); - builder_->SetInsertPoint(current); - } else{ - size_t offset = alloc_->offset(layout); - shared_ptr_[layout] = gep(shmem_, i32(offset)); - shared_ptr_[layout] = bit_cast(shared_ptr_[layout], ptr_ty); - } -} - -void generator::visit_basic_block(ir::basic_block * block) { - BasicBlock *parent = bbs_[block]; - builder_->SetInsertPoint(parent); - for(ir::instruction *i: block->get_inst_list()) - visit_value(i); - // Update ir bb -> llvm bb mapping - bbs_[block] = builder_->GetInsertBlock(); -} - -void generator::visit_argument(ir::argument* arg) { - -} - -void generator::init_idx(ir::value *v) { - idxs_[v].clear(); - if(!v->get_type()->is_block_ty()){ - idxs_[v].push_back({}); - return; - } - if(layouts_->get(v)->to_shared()) - return; - const auto &shapes = v->get_type()->get_block_shapes(); - size_t rank = shapes.size(); - std::vector axes(rank); - std::vector 
ord(rank); - // compute axes - for(size_t d = 0; d < shapes.size(); d++){ - if(shapes[d] > 1){ - unsigned x = a_axes_->get(v, d); - axes[d] = axes_.at(x); - } - else{ - axes[d].contiguous = 1; - axes[d].values = {i32(0)}; - } - } - // compute order - analysis::data_layout* layout = layouts_->get(v); - std::iota(ord.begin(), ord.end(), 0); - auto cmp = [&](int x, int y) { - unsigned axx = a_axes_->get(v, x); - unsigned axy = a_axes_->get(v, y); - size_t posx = layout->find_axis(axx); - size_t posy = layout->find_axis(axy); - if(posx < rank && posy < rank) - return layout->get_order(posx) < layout->get_order(posy); - return false; - }; - std::sort(ord.begin(), ord.end(), cmp); - ords_[v] = ord; - // indices - if(axes.size() == 1) - for(Value* x0: axes[ord[0]].values){ - idxs_[v].push_back({x0}); - } - if(axes.size() == 2) - for(Value* x1: axes[ord[1]].values) - for(Value* x0: axes[ord[0]].values){ - indices_t idx(2); - idx[ord[0]] = x0; - idx[ord[1]] = x1; - idxs_[v].push_back(idx); - } - if(axes.size() == 3) - for(Value* x2: axes[ord[2]].values) - for(Value* x1: axes[ord[1]].values) - for(Value* x0: axes[ord[0]].values){ - indices_t idx(3); - idx[ord[0]] = x0; - idx[ord[1]] = x1; - idx[ord[2]] = x2; - idxs_[v].push_back(idx); - } -} - -void generator::finalize_shared_layout(analysis::shared_layout *shared) { - if (auto n_buffer = shared->get_N_buffer()) { - // if (*_smem_idx == #stages-1) { - // *_smem_idx = 0; - // } else *_smem_idx++; - auto finalize_smem_idx = [&](auto &smem_idx, int init_stage) { - // insert point - Value *idx = smem_idx[shared]; - builder_->SetInsertPoint(bbs_.at(n_buffer->phi->get_parent())->getTerminator()); - Value *cond = icmp_eq(idx, i32(shared->get_num_stages()-1)); - PHINode *_ret = phi(i32_ty, 2); - Instruction *then_term = nullptr; - Instruction *else_term = nullptr; - Instruction *dummy = builder_->CreateRet(nullptr); - llvm::SplitBlockAndInsertIfThenElse(cond, _ret, &then_term, &else_term, nullptr); - dummy->removeFromParent(); - builder_->SetInsertPoint(then_term); - Value *zero_smem_idx = i32(0); - builder_->SetInsertPoint(else_term); - Value *inc_smem_idx = add(idx, i32(1)); - builder_->SetInsertPoint(_ret->getParent()); - _ret->addIncoming(zero_smem_idx, then_term->getParent()); - _ret->addIncoming(inc_smem_idx, else_term->getParent()); - // update ir::bb -> llvm::bb mapping - bbs_.at(n_buffer->phi->get_parent()) = builder_->GetInsertBlock(); - // idx = init_stage; - // loop: ... 
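The if/else emitted by finalize_smem_idx simply advances a circular stage index: wrap to zero after the last stage, otherwise increment. The same logic as plain code, for reference:

// Circular stage index for the N-buffered shared memory above:
// with 3 stages the sequence is 0, 1, 2, 0, 1, 2, ...
static int next_stage(int idx, int num_stages) {
  return (idx == num_stages - 1) ? 0 : idx + 1;
}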
- if (auto idx_phi = llvm::dyn_cast(smem_idx[shared])) { - idx_phi->addIncoming(i32(init_stage), bbs_.at(n_buffer->phi->get_incoming_block(0))); - idx_phi->addIncoming(_ret, bbs_.at(n_buffer->phi->get_incoming_block(1))); - } else - throw std::runtime_error("Should be PHINode"); - }; - - // read_smem_idx is used by next_ptr to compute the next iteration value, so init value is 2 - finalize_smem_idx(read_smem_idx_, 2); - finalize_smem_idx(write_smem_idx_, shared->get_num_stages()-1); - - // finalize pointers - ir::phi_node *pn = n_buffer->phi; - BasicBlock *header = bbs_.at(pn->get_incoming_block(0)); - BasicBlock *loop = bbs_.at(pn->get_incoming_block(1)); - // %curr_ptr = phi %shared_pre_ptr, %next_ptr - // %next_ptr = phi %shared_pre_ptr[+1], (gep(%pre_ptr, read_smem_idx*per_stage_size)) - if (auto curr_ptr = dyn_cast(shared_ptr_[shared])) { - curr_ptr->addIncoming(shared_pre_ptr_[shared], header); - curr_ptr->addIncoming(shared_next_ptr_[shared], loop); - } else - throw std::runtime_error("Should be PHINode"); - - BasicBlock *current = builder_->GetInsertBlock(); - builder_->SetInsertPoint(header->getTerminator()); - Value *next_ptr_header = gep(shared_pre_ptr_[shared], i32(shared->get_per_stage_elements())); - builder_->SetInsertPoint(current->getTerminator()); - - assert(isa(shared_next_ptr_[shared])); - static_cast(shared_next_ptr_[shared])->addIncoming(next_ptr_header, header); - - Value *lds_offset = mul(read_smem_idx_[shared], i32(shared->get_per_stage_elements())); - Value *next_ptr = gep(shared_pre_ptr_[shared], lds_offset); - static_cast(shared_next_ptr_[shared])->addIncoming(next_ptr, loop); - } else if(shared->get_double_buffer()) { - auto info = *shared->get_double_buffer(); - ir::phi_node *phi = info.phi; - PHINode *ptr = (PHINode*)shmems_[phi]; - PHINode *offset = (PHINode*)shoffs_[phi]; - for(unsigned n = 0; n < phi->get_num_incoming(); n++){ - ir::basic_block* inc_block = phi->get_incoming_block(n); - ir::value* inc_val = phi->get_incoming_value(n); - BasicBlock *llvm_inc_block = bbs_.at(inc_block); - if(inc_val == info.latch){ - builder_->SetInsertPoint(llvm_inc_block->getTerminator()); - Value *next_offset = neg(offset); - offset->addIncoming(next_offset, llvm_inc_block); - } - else { - unsigned num_bytes = shared->get_type()->get_primitive_size_in_bits() / 8; - offset->addIncoming(i32(shared->get_size() / (2*num_bytes)), llvm_inc_block); - } - ptr->addIncoming(shmems_[inc_val], llvm_inc_block); - } - } -} - -void generator::finalize_function(ir::function *fn) { - // finalize double-buffering - for(const auto& x: layouts_->get_all()) - if(auto *shared = dynamic_cast(x.second)) - finalize_shared_layout(shared); - // finalize phi - for(ir::basic_block *block: fn->blocks()) - for(ir::instruction *inst: block->get_inst_list()) - if(auto *phi = dynamic_cast(inst)) - finalize_phi_node(phi); - for(auto& x: lazy_phi_incs_) - std::get<0>(x)->addIncoming(std::get<1>(x), bbs_[std::get<2>(x)]); -} - -void generator::finalize_phi_node(ir::phi_node *x) { - if(shmems_.find(x) != shmems_.end()) - return; - for(unsigned n = 0; n < x->get_num_incoming(); n++){ - ir::basic_block *_block = x->get_incoming_block(n); - BasicBlock *block = bbs_.at(_block); - for(indices_t idx: idxs_.at(x)){ - PHINode *phi = (PHINode*)vals_[x][idx]; - Value *inc = vals_[x->get_incoming_value(n)][idx]; - phi->addIncoming(inc, block); - } - } -} - -void generator::visit(ir::module &src, llvm::Module &dst) { - mod_ = &dst; - ctx_ = &dst.getContext(); - builder_ = new Builder(*ctx_); - // allocate shared memory - 
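The zero-length address-space-3 global created below is the block's dynamically sized shared memory; a rough CUDA-side counterpart, with made-up byte offsets for illustration (the real offsets come from the allocation pass):

// Rough CUDA-level counterpart of the "__shared_ptr" addrspace(3) global:
// a single dynamically sized shared-memory region per block, sized at launch.
extern __shared__ char shared_storage[];

__global__ void kernel_stub() {
  // The allocator hands out byte offsets into this region; 0 and 4096 are
  // placeholder values for illustration only.
  float *tile_a = reinterpret_cast<float *>(shared_storage);
  float *tile_b = reinterpret_cast<float *>(shared_storage + 4096);
  (void)tile_a;
  (void)tile_b;
}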
if(tgt_->is_gpu()) - if(unsigned alloc_size = alloc_->allocated_size()){ - Type *int_8_ty = Type::getInt8Ty(*ctx_); - Type *int_32_ty = Type::getInt32Ty(*ctx_); - ArrayType *array_ty = ArrayType::get(int_32_ty, 0); - Type *ptr_ty = ptr_ty(int_8_ty, 3); - GlobalVariable *sh_mem_array = - new GlobalVariable(*mod_, array_ty, false, GlobalVariable::ExternalLinkage, - nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3); - shmem_ = bit_cast(sh_mem_array, ptr_ty); - } - // visit functions - for(ir::function *fn: src.get_function_list()) - visit_function(fn); -} - - -} -} diff --git a/lib/codegen/target.cc b/lib/codegen/target.cc deleted file mode 100644 index 82ebbe64986a..000000000000 --- a/lib/codegen/target.cc +++ /dev/null @@ -1,173 +0,0 @@ -#include "triton/codegen/target.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/Intrinsics.h" -#include "llvm/IR/IntrinsicsNVPTX.h" -#include "llvm/IR/IntrinsicsAMDGPU.h" -#include "llvm/IR/Value.h" -#include "llvm/IR/IRBuilder.h" -#include - -using namespace llvm; - -namespace triton{ -namespace codegen{ - -// base - - -nvidia_cu_target* target::as_nvidia() { - return dynamic_cast(this); -} - -bool target::is_gpu() const { - return is_gpu_; -} - -// AMD -void amd_cl_target::set_kernel(IRBuilder<>& builder, LLVMContext &ctx, Module *module, Function* fn) { - fn->setCallingConv(CallingConv::AMDGPU_KERNEL); -} - -Instruction* amd_cl_target::add_barrier(Module *module, IRBuilder<>& builder) { - Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::amdgcn_s_barrier); - return builder.CreateIntrinsic(Intrinsic::amdgcn_s_barrier, {}, {}); -} - -Value* amd_cl_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) { - Value* group_id = get_block_id(module, builder, ax); - Value* result = builder.CreateMul(builder.getInt32(stride), group_id); - return result; -} - -Instruction* amd_cl_target::add_memfence(Module *module, IRBuilder<>& builder) { - throw std::runtime_error("not implemented"); -} - - -Value* amd_cl_target::get_block_id(Module *module, IRBuilder<>& builder, unsigned ax) { - static std::array ids = { - Intrinsic::amdgcn_workgroup_id_x, - Intrinsic::amdgcn_workgroup_id_y, - Intrinsic::amdgcn_workgroup_id_z - }; - Value* group_id = builder.CreateIntrinsic(ids[ax], {}, {}); - return group_id; -} - -Value* amd_cl_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) { - throw std::runtime_error("not implemented on AMD"); -} - -Value* amd_cl_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) { - static std::array ids = { - Intrinsic::amdgcn_workitem_id_x, - Intrinsic::amdgcn_workitem_id_y, - Intrinsic::amdgcn_workitem_id_z - }; - Function *get_local_id = Intrinsic::getDeclaration(module, ids[ax]); - return builder.CreateCall(get_local_id, {}); -} - -// NVIDIA - -void nvidia_cu_target::set_kernel(IRBuilder<>& builder, LLVMContext &ctx, Module *module, Function* fn){ - // set metadata - Metadata *md_args[] = { - ValueAsMetadata::get(fn), - MDString::get(ctx, "kernel"), - ValueAsMetadata::get(builder.getInt32(1)) - }; - module->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(ctx, md_args)); -} - -Instruction* nvidia_cu_target::add_barrier(Module *module, IRBuilder<>& builder) { - Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::nvvm_barrier0); - return builder.CreateCall(barrier, {}); -} - -Instruction* nvidia_cu_target::add_memfence(Module *module, IRBuilder<>& builder) { - 
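The NVIDIA target hooks here wrap NVVM intrinsics; their usual CUDA-level counterparts, following the standard NVVM lowering, are sketched below for orientation:

// CUDA built-ins and the NVVM intrinsics the surrounding target hooks emit:
//   __syncthreads() -> llvm.nvvm.barrier0
//   __threadfence() -> llvm.nvvm.membar.gl
//   threadIdx.x     -> llvm.nvvm.read.ptx.sreg.tid.x
//   blockIdx.x      -> llvm.nvvm.read.ptx.sreg.ctaid.x
//   gridDim.x       -> llvm.nvvm.read.ptx.sreg.nctaid.x
__device__ int global_thread_id() {
  // Flat 1-D index, for illustration only.
  return threadIdx.x + blockIdx.x * blockDim.x;
}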
Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::nvvm_membar_gl); - return builder.CreateCall(barrier, {}); -} - - -Value* nvidia_cu_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) { - Value* group_id = get_block_id(module, builder, ax); - Value* result = builder.CreateMul(builder.getInt32(stride), group_id); - return result; -} - -Value* nvidia_cu_target::get_block_id(Module *module, IRBuilder<>& builder, unsigned ax) { - static std::array cta_ids = { - Intrinsic::nvvm_read_ptx_sreg_ctaid_x, - Intrinsic::nvvm_read_ptx_sreg_ctaid_y, - Intrinsic::nvvm_read_ptx_sreg_ctaid_z - }; - Value* cta_id = builder.CreateIntrinsic(cta_ids[ax], {}, {}); - return cta_id; -} - -Value* nvidia_cu_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) { - static std::array ids = { - Intrinsic::nvvm_read_ptx_sreg_tid_x, - Intrinsic::nvvm_read_ptx_sreg_tid_y, - Intrinsic::nvvm_read_ptx_sreg_tid_z - }; - Function *get_local_id = Intrinsic::getDeclaration(module, ids[ax]); - return builder.CreateCall(get_local_id, {}); -} - -Value* nvidia_cu_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) { - static std::array ids = { - Intrinsic::nvvm_read_ptx_sreg_nctaid_x, - Intrinsic::nvvm_read_ptx_sreg_nctaid_y, - Intrinsic::nvvm_read_ptx_sreg_nctaid_z - }; - return builder.CreateIntrinsic(ids[ax], {}, {}); -} - -// CPU - -void cpu_target::set_kernel(IRBuilder<>& builder, LLVMContext &ctx, Module *module, Function* fn) { - // normal cpu functions can be kernels -} - -Instruction* cpu_target::add_barrier(Module *module, IRBuilder<>& builder) { - // no barrier on CPU - return (Instruction*)builder.CreateAdd(builder.getInt32(0), builder.getInt32(0)); -} - -Instruction* cpu_target::add_memfence(Module *module, IRBuilder<>& builder) { - // no barrier on CPU - return (Instruction*)builder.CreateAdd(builder.getInt32(0), builder.getInt32(0)); -} - - -Value* cpu_target::get_block_id(Module *module, llvm::IRBuilder<> &builder, unsigned ax) { - const Function *fn = builder.GetInsertBlock()->getParent(); - size_t num_params = fn->getFunctionType()->getNumParams(); - static std::array ids = { - fn->arg_begin() + num_params - 3, - fn->arg_begin() + num_params - 2, - fn->arg_begin() + num_params - 1 - }; - return (Argument*)ids[ax]; -} - -Value* cpu_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) { - throw std::runtime_error("not implemented"); -} - - -Value* cpu_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) { - Value* result = builder.CreateMul(builder.getInt32(stride), get_block_id(module, builder, ax)); - return result; -} - -Value* cpu_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) { - return builder.getInt32(0); -} - -} -} diff --git a/lib/codegen/transform/coalesce.cc b/lib/codegen/transform/coalesce.cc deleted file mode 100644 index ae8ce034def6..000000000000 --- a/lib/codegen/transform/coalesce.cc +++ /dev/null @@ -1,133 +0,0 @@ -#include -#include -#include "triton/ir/utils.h" -#include "triton/ir/instructions.h" -#include "triton/ir/function.h" -#include "triton/ir/module.h" -#include "triton/codegen/transform/coalesce.h" -#include "triton/codegen/analysis/align.h" -#include "triton/codegen/analysis/layout.h" - -namespace triton { -namespace codegen{ -namespace transform{ - -coalesce::coalesce(analysis::align* align, analysis::layouts *layouts) - : align_(align), layout_(layouts) { } - - -// simplify layout conversions using the 
following simple rules: -// - cvt_1(cvt_2(x)) if convert1 is the inverse of convert2 -// - cvt_1(elementwise(x, y)) = elementwise(convert(x), convert(y)) -//ir::value* coalesce::simplify(ir::instruction *inst, ir::builder& builder){ -// ir::value* _op = inst->get_operand(0); -// ir::instruction* op = dynamic_cast(_op); -// analysis::mma_layout* mma_in = layout_->get(op) ->to_mma(); -// analysis::mma_layout* mma_out = layout_->get(inst)->to_mma(); -// std::cout << 1 << std::endl; -// // i must be layout conversion instruction -// if(!mma_in && !mma_out) -// return inst; -// // - cvt_1(cvt_2(x)) if convert1 is the inverse of convert2 -// bool is_op_cvt = op->get_id() == ir::INST_CVT_LAYOUT; -// if((mma_in || mma_out) && is_op_cvt && -// (layout_->get(inst) == layout_->get(op->get_operand(0)))) -// return op->get_operand(0); -// // - cvt_1(elementwise(x, y)) = elementwise(cvt_1(x), cvt_2(y)) -// if(op->get_id() != ir::INST_BINOP && op->get_id() != ir::INST_GETELEMENTPTR) -// return inst; -// std::cout << 1 << std::endl; -// for(size_t i = 0; i < op->get_num_operands(); i++){ -// ir::value* arg_i = op->get_operand(i); -// builder.set_insert_point(op); -// // create new layout transform -// ir::instruction* new_arg_i = inst->clone(); -// builder.insert(new_arg_i); -// // set the right args -// new_arg_i->replace_uses_of_with(new_arg_i->get_operand(0), arg_i); -// op->replace_uses_of_with(arg_i, simplify(new_arg_i, builder)); -// } -// std::cout << 2 << std::endl; -// return op; -//} - -void coalesce::run(ir::module &mod) { - ir::builder& builder = mod.get_builder(); - // add layout conversion instructions - for(ir::function *fn: mod.get_function_list()) - for(ir::basic_block *block: fn->blocks()) - for(ir::instruction* i: block->get_inst_list()){ - // coalesce before store - if(dynamic_cast(i) || dynamic_cast(i)) - if(ir::value* op = i->get_operand(1)) - if(op->get_type()->is_block_ty()) - if(layout_->get(op)->to_mma()){ - ir::instruction* new_op = ir::cvt_layout_inst::create(op); - builder.set_insert_point(i); - builder.insert(new_op); - i->replace_uses_of_with(op, new_op); - } - // uncoalesce after load - if(auto x = dynamic_cast(i)) - if(x->get_type()->is_block_ty()) - if(x->get_type()->get_tile_rank()==2) - if(layout_->get(x)->to_mma()){ - builder.set_insert_point_after(x); - ir::instruction* new_x = ir::cvt_layout_inst::create(x); - builder.insert(new_x); - x->replace_all_uses_with(new_x); - new_x->replace_uses_of_with(new_x, x); - } - } - for(ir::function *fn: mod.get_function_list()) - for(ir::basic_block *block: fn->blocks()) - for(ir::instruction* i: block->get_inst_list()){ - // re-arrange scanline to promote memory coalescing - if(auto x = dynamic_cast(i)){ - ir::value* ptr = x->get_pointer_operand(); - ir::value* val = x->get_value_operand(); - auto out_contig = align_->contiguous(ptr); - auto val_inst = dynamic_cast(val); - if(!val_inst) - break; - if(dynamic_cast(val)) - break; - std::vector in_contig; - std::vector queue = {val_inst}; - std::set seen; - std::vector ios; - while(!queue.empty()){ - ir::instruction* curr = queue.back(); - seen.insert(curr); - queue.pop_back(); - if(auto dot_inst = dynamic_cast(curr)) - break; - if(auto io_inst = dynamic_cast(curr)){ - in_contig = align_->contiguous(io_inst->get_pointer_operand()); - break; - } - for(ir::value* op: curr->ops()){ - auto inst_op = dynamic_cast(op); - if(!inst_op || seen.find(inst_op) != seen.end()) - continue; - if(!op->get_type()->is_block_ty() || - !val->get_type()->is_block_ty()) - continue; - 
if(op->get_type()->get_tile_num_elements() == - val->get_type()->get_tile_num_elements()) - queue.push_back(inst_op); - } - } - if(in_contig.size() <= 1 || out_contig==in_contig) - continue; - builder.set_insert_point_after(val_inst); - auto new_val = builder.insert(ir::cvt_layout_inst::create(val_inst)); - x->replace_uses_of_with(val_inst, new_val); - } - } -} - - -} -} -} diff --git a/lib/codegen/transform/cts.cc b/lib/codegen/transform/cts.cc deleted file mode 100644 index c223d241319d..000000000000 --- a/lib/codegen/transform/cts.cc +++ /dev/null @@ -1,97 +0,0 @@ -#include "triton/codegen/transform/cts.h" -#include "triton/ir/module.h" -#include "triton/ir/function.h" -#include "triton/ir/basic_block.h" -#include "triton/ir/instructions.h" -#include - -namespace triton { -namespace codegen{ -namespace transform{ - - -inline bool is_shmem_op(ir::instruction* i, int op) { - if(i->get_id() == ir::INST_DOT) - return op==0 || op==1; - if(i->get_id() == ir::INST_COPY_FROM_SHARED) - return op==0; - if(i->get_id() == ir::INST_TRANS) - return op==0; - return false; -} - -inline bool is_shmem_res(ir::value* v){ - ir::instruction* i = dynamic_cast(v); - if(!i) - return false; - if(i->get_id() == ir::INST_TRANS) - return true; - if(i->get_id() == ir::INST_COPY_TO_SHARED) - return true; - if(i->get_id() == ir::INST_MASKED_LOAD_ASYNC) - return true; - return false; -} - - -// run pass on module -void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) { - auto *i = dynamic_cast(x); - // not an instruction - if(!i) { - builder.set_insert_point(parent); - ir::value *copy; - if(to_shared) - copy = builder.create_copy_to_shared(x); - else - copy = builder.create_copy_from_shared(x); - parent->replace_uses_of_with(x, copy); - return; - } - // phi node - if(auto* phi = dynamic_cast(x)) { - for(unsigned i = 0; i < phi->get_num_incoming(); ++i) - add_copy(phi, phi->get_incoming_value(i), builder, to_shared); - return; - } - // already in shared memory - if(to_shared && is_shmem_res(i)) - return; - // copy - builder.set_insert_point_after(i); - ir::value *copy; - if(to_shared){ - copy = builder.create_copy_to_shared(x); - } - else - copy = builder.create_copy_from_shared(x); - parent->replace_uses_of_with(x, copy); -} - -void cts::run(ir::module &mod) { - // Add shared copies - ir::builder &builder = mod.get_builder(); - for(ir::function* fn: mod.get_function_list()){ - for(ir::basic_block* block: fn->blocks()) - for(ir::instruction* i: block->get_inst_list()){ - size_t num_op = i->get_num_operands(); - // copy to shared operands - for(size_t k = 0; k < num_op; k++) - if(is_shmem_op(i, k)){ - add_copy(i, i->get_operand(k), builder, true); - } - // copy from shared operands - for(size_t k = 0; k < num_op; k++) - if(!dynamic_cast(i) && - !is_shmem_op(i,k) && - is_shmem_res(i->get_operand(k))){ - add_copy(i, i->get_operand(k), builder, false); - } - } - } -} - - -} -} -} \ No newline at end of file diff --git a/lib/codegen/transform/dce.cc b/lib/codegen/transform/dce.cc deleted file mode 100644 index c555290f834b..000000000000 --- a/lib/codegen/transform/dce.cc +++ /dev/null @@ -1,75 +0,0 @@ -#include "triton/codegen/transform/dce.h" -#include "triton/ir/function.h" -#include "triton/ir/basic_block.h" -#include "triton/ir/module.h" -#include "triton/ir/utils.h" - -namespace triton { -namespace codegen{ -namespace transform{ - - -void dce::run(ir::module &mod) { - std::list work_list; - std::set marked; - - // initialize work-list - for(ir::function *fn: 
mod.get_function_list()){ - std::vector rpo = ir::cfg::reverse_post_order(fn); - // iterate through blocks - for(ir::basic_block *block: rpo) - for(ir::instruction *i: block->get_inst_list()){ - switch(i->get_id()){ - case ir::INST_RETURN: - case ir::INST_UNCOND_BRANCH: - case ir::INST_COND_BRANCH: - case ir::INST_UNMASKED_STORE: - case ir::INST_MASKED_STORE: - case ir::INST_ATOMIC_CAS: - case ir::INST_ATOMIC_RMW: - case ir::INST_ATOMIC_EXCH: - case ir::INST_BARRIER: { - work_list.push_back(i); - marked.insert(i); - break; - } - default: - break; - } - } - } - - // mark -- ignore branches - while(!work_list.empty()){ - ir::instruction* current = work_list.back(); - work_list.pop_back(); - // mark instruction operands - for(ir::value* op: current->ops()) { - if(auto *i = dynamic_cast(op)){ - if(marked.insert(i).second) - work_list.push_back(i); - } - } - // TODO: mark last intstruction of current's reverse-dominance frontier - } - - // sweep -- delete non-branch unmarked instructions - std::vector to_delete; - for(ir::function *fn: mod.get_function_list()){ - std::vector rpo = ir::cfg::reverse_post_order(fn); - // iterate through blocks - for(ir::basic_block *block: rpo) - for(ir::instruction *i: block->get_inst_list()){ - if(marked.find(i) == marked.end()) - to_delete.push_back(i); - } - } - - // delete - for(ir::instruction* i: to_delete) - i->erase_from_parent(); -} - -} -} -} diff --git a/lib/codegen/transform/disassociate.cc b/lib/codegen/transform/disassociate.cc deleted file mode 100644 index 2709125f8501..000000000000 --- a/lib/codegen/transform/disassociate.cc +++ /dev/null @@ -1,62 +0,0 @@ -#include "triton/codegen/transform/disassociate.h" -#include "triton/ir/utils.h" -#include "triton/ir/instructions.h" -#include "triton/ir/builder.h" -#include "triton/ir/module.h" -#include - -namespace triton { -namespace codegen{ -namespace transform{ - -ir::instruction* rematerialize(ir::builder& bld, ir::instruction *root, - std::set& seen) { - if (dynamic_cast(root)) - return root; - if(!seen.insert(root).second) - return root; - if(!root->get_type()->is_block_ty()) - return root; - - bld.set_insert_point(root); - ir::instruction *new_root = bld.insert(root->clone()); - for(ir::value *op: root->ops()){ - ir::instruction *i = dynamic_cast(op); - if(!i || i->get_id() == ir::INST_REDUCE) - continue; - ir::instruction* new_op = rematerialize(bld, i, seen); - new_root->replace_uses_of_with(op, new_op); - } - return new_root; -} - -void disassociate::run(ir::module &mod) { - ir::builder &bld = mod.get_builder(); - -// ir::for_each_instruction(mod, [&](ir::instruction *i){ -// bld.set_insert_point(i); -// for(ir::value* op: i->ops()){ -// auto reshape = dynamic_cast(op); -// if(!reshape) -// continue; -// ir::instruction* new_op = bld.insert(reshape->clone()); -// i->replace_uses_of_with(op, new_op); -// } -// }); - - - ir::for_each_instruction(mod, [&](ir::instruction *i){ - if(dynamic_cast(i) || dynamic_cast(i)){ - std::set seen; - ir::instruction* new_i = rematerialize(bld, i, seen); - i->replace_all_uses_with(new_i); - } - }); - - -} - - -} -} -} diff --git a/lib/codegen/transform/membar.cc b/lib/codegen/transform/membar.cc deleted file mode 100644 index 96249bcd53c9..000000000000 --- a/lib/codegen/transform/membar.cc +++ /dev/null @@ -1,244 +0,0 @@ -#include -#include -#include -#include "triton/codegen/analysis/layout.h" -#include "triton/codegen/analysis/allocation.h" -#include "triton/codegen/transform/membar.h" -#include "triton/codegen/transform/prefetch.h" -#include 
"triton/ir/module.h" -#include "triton/ir/function.h" -#include "triton/ir/basic_block.h" -#include "triton/ir/instructions.h" -#include "triton/ir/utils.h" - -namespace triton { - -namespace codegen{ -namespace transform{ - - - -int membar::group_of(ir::value* v, std::vector &async_write) { - if(ir::phi_node* phi = dynamic_cast(v)){ - analysis::shared_layout* layout = layouts_->get(v)->to_shared(); - if (analysis::double_buffer_info_t* info = layout->get_double_buffer()) - return group_of(info->first, async_write); - else if (analysis::N_buffer_info_t* info = layout->get_N_buffer()) { - if (v == info->phi) - return group_of(info->firsts[0], async_write); - else // prefetched value - return group_of(info->firsts[1], async_write); - } - std::vector groups(phi->get_num_operands()); - std::transform(phi->op_begin(), phi->op_end(), groups.begin(), [&](ir::value* v){ return group_of(v, async_write);}); - return *std::max_element(groups.begin(), groups.end()); - } - else{ - if(layouts_->has_tmp(v)) - return async_write.size() - 1; - auto it = std::find(async_write.begin(), async_write.end(), v); - return std::distance(async_write.begin(), it); - } -} - -inline bool membar::intersect_with(analysis::shared_layout* a_layout, analysis::shared_layout* b_layout) { - if(!a_layout || !b_layout) - return false; - int a_start = alloc_->offset(a_layout); - int a_end = a_start + a_layout->get_size(); - int b_start = alloc_->offset(b_layout); - int b_end = b_start + b_layout->get_size(); - if(a_start < b_end || b_start < a_end) - return true; - return false; -} - -membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& bs) { - val_set_t ret; - for(ir::value* a: as){ - if(!a->get_type()->is_block_ty()) - continue; - analysis::shared_layout* a_layout = layouts_->get(a)->to_shared(); - analysis::shared_layout* a_tmp = layouts_->has_tmp(a) ? layouts_->get(layouts_->tmp(a))->to_shared() : nullptr; - for(ir::value* b: bs){ - if(!b->get_type()->is_block_ty()) - continue; - analysis::shared_layout* b_layout = layouts_->get(b)->to_shared(); - analysis::shared_layout* b_tmp = layouts_->has_tmp(b) ? 
layouts_->get(layouts_->tmp(b))->to_shared() : nullptr; - if(intersect_with(a_layout, b_layout) || - intersect_with(a_layout, b_tmp) || - intersect_with(a_tmp, b_layout) || - intersect_with(a_tmp, b_tmp)) - ret.insert(b); - } - } - return ret; -} - -bool membar::check_safe_war(ir::instruction* i) { - bool is_i_shared_block = i->get_type()->is_block_ty() && - layouts_->get(i)->to_shared(); - bool is_i_double_buffered = is_i_shared_block && - layouts_->get(i)->to_shared()->get_double_buffer(); - bool is_i_n_buffered = is_i_shared_block && - layouts_->get(i)->to_shared()->get_N_buffer(); - - if (is_i_double_buffered || is_i_n_buffered) { - // with async copy & prefetch_s disabled, WARs are not safe - if (dynamic_cast(i) && !prefetch_->is_prefetched(i)) - return false; - else - return true; - } - return false; -} - -void membar::transfer(ir::basic_block *block, - val_vec_t& async_write, - val_set_t& sync_write, - val_set_t& sync_read, - std::set& safe_war, - bool& inserted, ir::builder& builder) { - std::vector async_waits; - ir::basic_block::inst_list_t instructions = block->get_inst_list(); - for(ir::instruction *i: instructions){ - if(dynamic_cast(i)) - continue; - if(std::find(async_write.begin(), async_write.end(), i) == async_write.end() && - dynamic_cast(i)){ - async_write.push_back(i); - } - if(dynamic_cast(i)) - sync_write.insert(i); - ir::barrier_inst* barrier = dynamic_cast(i); - ir::async_wait_inst* async_wait = dynamic_cast(i); - // Get shared memory reads - std::set read; - std::copy_if(i->op_begin(), i->op_end(), std::inserter(read, read.begin()), - [&](ir::value* i){ return i->get_type()->is_block_ty() && layouts_->get(i)->to_shared();}); - if(layouts_->has_tmp(i)) - read.insert(i); - // RAW (async) - val_set_t tmp; - std::copy(async_write.begin(), async_write.end(), std::inserter(tmp, tmp.begin())); - if(intersect_with(read, tmp).size()){ - std::vector groups(read.size()); - std::transform(read.begin(), read.end(), groups.begin(), [&](ir::value* v){ return group_of(v, async_write);}); - int N = *std::max_element(groups.begin(), groups.end()); - if(N < async_write.size()){ - builder.set_insert_point(i); - async_wait = (ir::async_wait_inst*)builder.create_async_wait(async_write.size() - 1 - N); - barrier = (ir::barrier_inst*)builder.create_barrier(); - inserted = true; - async_waits.push_back(async_wait); - } - } - // RAW, WAR - bool is_safe_war = check_safe_war(i); - // WAR barrier is not required when data is double-buffered - if(!intersect_with(read, sync_write).empty() || - (!intersect_with({i}, sync_read).empty() && !is_safe_war)) { - builder.set_insert_point(i); - barrier = (ir::barrier_inst*)builder.create_barrier(); - inserted = true; - } - // update state of asynchronous copies - if(async_wait){ - int N = async_write.size() - async_wait->get_N(); - async_write.erase(async_write.begin(), async_write.begin() + N); - } - // all the copy_to_shared and read from shared are synchronized after barrier - if(barrier){ - sync_write.clear(); - sync_read.clear(); - } - sync_read.insert(read.begin(), read.end()); - } - - // coalesce barriers - // fixme: to support more general cases - if (async_waits.size() == 2) { - // (aw N; bar; prefetch; aw N-1; bar; prefetch; => aw N-1; bar; 2*prefetch;) - for (int idx=0; idx to_erase; - ir::basic_block::inst_list_t instructions = block->get_inst_list(); - for(auto iter = instructions.begin(); iter != instructions.end(); ++iter){ - ir::instruction *i = *iter; - if (static_cast(first_async_wait) == i) { - // peak next 5 instructions - auto 
peak_iter = std::next(iter); - if (std::distance(peak_iter, instructions.end()) >= 5) { - auto first_bar = dynamic_cast(*peak_iter++); - auto first_pf = dynamic_cast(*peak_iter++); - auto second_async_wait = dynamic_cast(*peak_iter++); - auto second_bar = dynamic_cast(*peak_iter++); - auto second_pf = dynamic_cast(*peak_iter); - if (first_bar && first_pf && second_async_wait && second_bar && second_pf) { - int first_n = first_async_wait->get_N(); - int second_n = second_async_wait->get_N(); - to_erase.push_back(second_async_wait); - to_erase.push_back(second_bar); - first_async_wait->set_N(second_n); - } - } else - break; - for (ir::instruction *i : to_erase) - block->erase(i); - } - } - } - } -} - -void membar::run(ir::module &mod) { - ir::builder &builder = mod.get_builder(); - // extract phi-node associates with double-buffered - // shared-memory copies. These can be read from and written to - // without needing synchronization - std::set safe_war; - for(const auto& x: layouts_->get_all()){ - analysis::shared_layout* layout = x.second->to_shared(); - if(!layout || !layout->get_double_buffer() || !layout->get_N_buffer()) - continue; - for(ir::value *v: layout->get_values()) - if(v != layout->get_double_buffer()->phi){ - safe_war.insert(v); - } - } - - for(ir::function *fn: mod.get_function_list()){ - std::vector rpo = ir::cfg::reverse_post_order(fn); - std::map async_writes; - std::map sync_writes; - std::map sync_reads; - std::list pipelined; - bool inserted; - do{ - inserted = false; - // find barrier location - for(ir::basic_block *block: rpo){ - // join inputs - val_vec_t async_write; - val_set_t sync_write; - val_set_t sync_read; - val_set_t tmp; - for(ir::basic_block* pred: block->get_predecessors()){ - for(ir::value* v: async_writes[pred]) - if(tmp.insert(v).second) - async_write.push_back(v); - sync_write.insert(sync_writes[pred].begin(), sync_writes[pred].end()); - sync_read.insert(sync_reads[pred].begin(), sync_reads[pred].end()); - } - transfer(block, async_write, sync_write, sync_read, safe_war, inserted, builder); - async_writes[block] = async_write; - sync_writes[block] = sync_write; - sync_reads[block] = sync_read; - } - }while(inserted); - } -} - -} -} -} diff --git a/lib/codegen/transform/peephole.cc b/lib/codegen/transform/peephole.cc deleted file mode 100644 index 0961efc9c66b..000000000000 --- a/lib/codegen/transform/peephole.cc +++ /dev/null @@ -1,309 +0,0 @@ -#include -#include -#include "triton/ir/module.h" -#include "triton/ir/function.h" -#include "triton/codegen/transform/peephole.h" -#include "triton/codegen/analysis/layout.h" - -namespace triton { -namespace codegen{ -namespace transform{ - - -ir::value* rewrite_trans_phi_impl(ir::value *value, ir::builder &builder, - const std::vector& perm) { - if(auto phi = dynamic_cast(value)) { - // transpose operands - std::vector incs; - for(unsigned n = 0; n < phi->get_num_incoming(); n++) - incs.push_back(rewrite_trans_phi_impl(phi->get_incoming_value(n), builder, perm)); - // create phi for transposed values - builder.set_insert_point(phi); - ir::phi_node* result = builder.create_phi(incs[0]->get_type(), incs.size()); - for(unsigned n = 0; n < phi->get_num_incoming(); n++) - result->add_incoming(incs[n], phi->get_incoming_block(n)); - return result; - } - else if(auto i = dynamic_cast(value)){ - ir::basic_block* block = i->get_parent(); - auto it = std::find(block->begin(), block->end(), i); - it++; - builder.set_insert_point(it); - ir::instruction *trans = (ir::instruction*)builder.create_trans(i, perm); - 
trans->set_operand(0, i); - return trans; - } - return nullptr; -} - -bool peephole::rewrite_trans_phi(ir::instruction* value, ir::builder& builder) { - auto trans = dynamic_cast(value); - if(!trans) - return false; - auto users = trans->get_users(); - auto ops = trans->ops(); - if(users.size() > 1 || ops.size() > 1) - return false; - ir::value* op = *ops.begin(); - // trans(phi) -> phi(trans(), trans()...) - auto* phi = dynamic_cast(op); - if(!phi) - return false; - ir::value* new_phi = rewrite_trans_phi_impl(phi, builder, trans->get_perm()); - if(!new_phi) - return false; - trans->replace_all_uses_with(new_phi); - - return true; -} - -bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){ - // dot(a, b, c) + d -> dot(a, b, c + d) - // d + dot(a, b, c) -> dot(a, b, c + d) - auto add = dynamic_cast(value); - if(add && (add->get_op() == ir::binary_op_t::FAdd || add->get_op() == ir::binary_op_t::Add)) { - bool is_int_dot = add->get_op() == ir::binary_op_t::Add; - ir::value *lhs = add->get_operand(0); - ir::value *rhs = add->get_operand(1); - ir::dot_inst *lhs_dot = dynamic_cast(lhs); - ir::dot_inst *rhs_dot = dynamic_cast(rhs); - if(!lhs_dot && !rhs_dot) - return false; - ir::dot_inst *dot = lhs_dot ? lhs_dot : rhs_dot; - ir::value *other = (dot == lhs) ? rhs : lhs; - ir::value *acc = dot->get_operand(2); - ir::splat_inst *splat = dynamic_cast(acc); - ir::constant *_0 = nullptr; - if(splat) - _0 = dynamic_cast(splat->get_operand(0)); - if(!_0) - return false; - if (auto *fp_0 = dynamic_cast(_0)) - if (fp_0->get_value() != 0.0) - return false; - if (auto *int_0 = dynamic_cast(_0)) - if (int_0->get_value() != 0) - return false; - ir::value *a = dot->get_operand(0); - ir::value *b = dot->get_operand(1); - builder.set_insert_point(add); - ir::value * new_dot = builder.insert(ir::dot_inst::create_nn(a, b, other, dot->allow_tf32(), dot->get_name())); - add->replace_all_uses_with(new_dot); - return true; - } - return false; -} - -//bool peephole::rewrite_cts_cfs(ir::instruction *value, ir::builder &builder){ -// auto cfs = dynamic_cast(value); -// if(cfs) { -// ir::value *arg = cfs->get_operand(0); -// ir::copy_to_shared_inst* cts = dynamic_cast(arg); -// if(!cts) -// return false; -// cfs->replace_all_uses_with(cts->get_operand(0)); -// return true; -// } - -//} - -bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& builder){ - auto copy_to_shared = dynamic_cast(value); - if(!copy_to_shared) - return false; - ir::value *arg = copy_to_shared->get_operand(0); - ir::masked_load_inst* ld = dynamic_cast(arg); - if(!ld) - return false; - builder.set_insert_point(copy_to_shared); - ir::value *ptr = ld->get_pointer_operand(); - ir::value *msk = ld->get_mask_operand(); - ir::value *val = ld->get_false_value_operand(); - analysis::scanline_layout* layout = layouts_->get(ptr)->to_scanline(); - int nts = layout->nts(layout->get_order()[0]); - int dtsize = value->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; - if(nts*dtsize >= 4){ - ir::value* new_load = builder.create_masked_load_async(ptr, msk, val, ld->get_cache_modifier(), ld->get_eviction_policy()); - copy_to_shared->replace_all_uses_with(new_load); - return true; - } - return false; -// analysis::scanline_layout* layout = layouts_->get(ptr)->to_scanline(); -// std::cout << layout->nts(layout->get_order(0)) << std::endl; -// return true; - -} - -bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){ - auto x = dynamic_cast(value); - if(!x) - return false; - ir::value *arg = 
x->get_operand(0); - auto shapes = arg->get_type()->get_block_shapes(); - if(shapes[x->get_axis()] == 1){ - builder.set_insert_point(x); - ir::value* new_red = builder.create_reshape(arg, x->get_type()->get_block_shapes()); - x->replace_all_uses_with(new_red); - return true; - } - return false; -} - -bool peephole::rewrite_mult(ir::instruction *value, ir::builder& builder) { - auto binop = dynamic_cast(value); - if(binop && binop->get_op() == ir::binary_op_t::Mul) { - ir::value *lhs = binop->get_operand(0); - ir::value *rhs = binop->get_operand(1); - ir::constant_int *_1_lhs = nullptr; - if(ir::splat_inst *splat = dynamic_cast(lhs)){ - auto *cst = dynamic_cast(splat->get_operand(0)); - if(cst && cst->get_value() == 1) - _1_lhs = cst; - } - ir::constant_int *_1_rhs = nullptr; - if(ir::splat_inst *splat = dynamic_cast(rhs)){ - auto *cst = dynamic_cast(splat->get_operand(0)); - if(cst && cst->get_value() == 1) - _1_rhs = cst; - } - if(_1_lhs){ - binop->replace_all_uses_with(rhs); - return true; - } - else if(_1_rhs){ - binop->replace_all_uses_with(lhs); - return true; - } - } - return false; -} - - -bool peephole::rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder) { - auto x = dynamic_cast(value); - if(!x) - return false; - auto y = dynamic_cast(x->get_pointer_operand()); - if(!y) - return false; - auto idx = *y->idx_begin(); - auto z = dynamic_cast(idx); - if(!z) - return false; - bool is_sub = z->get_op() == ir::binary_op_t::Sub; - auto *lhs = dynamic_cast(z->get_operand(0)); - bool is_lhs_0 = lhs && (lhs->get_value()==0); - bool is_rhs_eq_x_rhs = z->get_operand(1) == *x->idx_begin(); - if(is_sub && is_lhs_0 && is_rhs_eq_x_rhs){ - x->replace_all_uses_with(y->get_pointer_operand()); - return true; - } - return false; -} - -bool peephole::rewrite_select_masked_load(ir::instruction *value, ir::builder& builder){ - auto select = dynamic_cast(value); - if(!select) - return false; - auto if_value = dynamic_cast(select->get_if_value_op()); - if(!if_value) - return false; - if(select->get_pred_op() != if_value->get_mask_operand()) - return false; - builder.set_insert_point(select); - ir::value* new_load = builder.create_masked_load(if_value->get_pointer_operand(), - if_value->get_mask_operand(), - select->get_else_value_op(), - if_value->get_cache_modifier(), - if_value->get_eviction_policy(), - if_value->get_is_volatile()); - select->replace_all_uses_with(new_load); - return true; -} - -bool peephole::rewrite_cvt_layout(ir::instruction *value, ir::builder& builder){ - auto cvt = dynamic_cast(value); - if(!cvt) - return false; - ir::instruction* op = dynamic_cast(cvt->get_operand(0)); - if(!op) - return false; -// // convert(elementwise(x, y)) = elementwise(convert(x), convert(y)) -// if(op->get_id() == ir::INST_BINOP){ -// for(size_t i = 0; i < op->get_num_operands(); i++){ -// ir::value* arg_i = op->get_operand(i); -// builder.set_insert_point(op); -// // create new layout transform -// ir::instruction* new_arg_i = cvt->clone(); -// layouts_->copy(new_arg_i, op); -// builder.insert(new_arg_i); -// // set the right args -// new_arg_i->replace_uses_of_with(new_arg_i->get_operand(0), arg_i); -// op->replace_uses_of_with(arg_i, new_arg_i); -// } -// cvt->replace_all_uses_with(op); -// return true; -// } - auto cvt_op = dynamic_cast(op); - if(!cvt_op) - return false; - // convert1(convert2(x)) if convert1 is the inverse of convert2 - ir::value* op_op = cvt_op->get_operand(0); - if(layouts_->has(cvt) && layouts_->has(op_op) && - layouts_->get(cvt) && layouts_->get(op_op)){ - 
cvt->replace_all_uses_with(op_op); - return true; - } - return false; -} - -void peephole::run(ir::module &mod) { - ir::builder &builder = mod.get_builder(); - // keep track of whether any modification was made - std::set seen; - size_t n_seen; - - // rewrite dots first - do{ - n_seen = seen.size(); - for(ir::function *fn: mod.get_function_list()) - for(ir::basic_block *block: fn->blocks()) - for(ir::instruction* i: block->get_inst_list()){ - if(seen.find(i) != seen.end()) - continue; - bool was_modified = rewrite_dot(i, builder); - if(was_modified){ - seen.insert(i); - } - } - }while(seen.size() != n_seen); - - // rewrite other ops - seen.clear(); - do{ - n_seen = seen.size(); - for(ir::function *fn: mod.get_function_list()) - for(ir::basic_block *block: fn->blocks()) - for(ir::instruction* i: block->get_inst_list()){ - if(seen.find(i) != seen.end()) - continue; - bool was_modified = false; - was_modified = was_modified || rewrite_mult(i, builder); - // was_modified = was_modified || rewrite_cts_cfs(i, builder); -// was_modified = was_modified || rewrite_trans_phi(i, builder); - was_modified = was_modified || rewrite_unit_red(i, builder); - was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder); - // TODO: DOESN'T WORK FOR VECTORIZED MASKED LOAD -// was_modified = was_modified || rewrite_select_masked_load(i, builder); - was_modified = was_modified || rewrite_cvt_layout(i, builder); - if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80) - was_modified = was_modified || rewrite_load_to_shared(i, builder); - if(was_modified) - seen.insert(i); - } - }while(seen.size() != n_seen); -} - -} -} -} diff --git a/lib/codegen/transform/pipeline.cc b/lib/codegen/transform/pipeline.cc deleted file mode 100644 index c85ba43a1c3b..000000000000 --- a/lib/codegen/transform/pipeline.cc +++ /dev/null @@ -1,330 +0,0 @@ -#include -#include -#include "triton/codegen/transform/pipeline.h" -#include "triton/ir/module.h" -#include "triton/ir/function.h" -#include "triton/ir/basic_block.h" -#include "triton/ir/instructions.h" -#include "triton/ir/utils.h" - -namespace triton { -namespace codegen{ -namespace transform{ - - -void recursive_deps(ir::value* v, ir::basic_block* block, std::vector& ret){ - ir::instruction* i = dynamic_cast(v); - if(!i || i->get_parent() != block) - return; - if(i->get_id()==ir::INST_PHI) - return; - ret.push_back(i); - for(ir::user* u: i->get_users()) - recursive_deps(u, block, ret); -} - -void get_induction_vars(ir::value* cond, std::set& phis) { - auto instr = dynamic_cast(cond); - for (auto op : instr->ops()) { - if (auto phi_op = dynamic_cast(op)) { - phis.insert(phi_op); - return; - } - if (dynamic_cast(op)) - get_induction_vars(op, phis); - } -} - -/// assume incoming block is 1 -ir::value* rematerialize_vals(ir::builder& builder, ir::basic_block* block, ir::value* v, - std::map& prev_phi_vals) { - ir::instruction* i = dynamic_cast(v); - if(!i || i->get_parent() != block) - return v; - if(ir::phi_node* phi = dynamic_cast(v)) { - if (prev_phi_vals.find(phi) == prev_phi_vals.end()) - throw std::runtime_error("Don't have that phi node\n"); - return prev_phi_vals.at(phi); - } - - std::vector new_ops; - for(ir::value* op: i->ops()){ - new_ops.push_back(rematerialize_vals(builder, block, op, prev_phi_vals)); - } - ir::instruction* ret = i->clone(); - for(size_t k = 0; k < new_ops.size(); k++) - ret->set_operand(k, new_ops[k]); - builder.insert(ret); - return ret; -} - -ir::value* rematerialize(ir::builder& builder, ir::basic_block* block, - ir::value* v, size_t 
phi_idx){ - ir::instruction* i = dynamic_cast(v); - if(!i || i->get_parent() != block) - return v; - if(ir::phi_node* phi = dynamic_cast(v)) - return phi->get_incoming_value(phi_idx); - - std::vector new_ops; - for(ir::value* op: i->ops()){ - new_ops.push_back(rematerialize(builder, block, op, phi_idx)); - } - ir::instruction* ret = i->clone(); - for(size_t k = 0; k < new_ops.size(); k++) - ret->set_operand(k, new_ops[k]); - builder.insert(ret); - return ret; -} - -/// moving the prev phi vals to the next iteration -std::map update_prev_phi_vals( - ir::builder& builder, ir::basic_block* block, std::map& prev_phi_vals) { - std::map next_phi_vals; - for (auto &[phi, val] : prev_phi_vals) { - next_phi_vals[phi] = rematerialize_vals(builder, block, phi->get_incoming_value(1), prev_phi_vals); - } - return next_phi_vals; -} - -void finalize_iv_vals(ir::builder& builder, ir::basic_block* block, std::map& load_ivs, - std::map& next_load_ivs) { - for (auto& [phi, val] : load_ivs) { - if (auto new_phi = dynamic_cast(val)) { - ir::value* next_k = rematerialize_vals(builder, block, phi->get_incoming_value(1), load_ivs); - assert(new_phi->get_num_operands() == 1 && "should be incomplete phi"); - new_phi->add_incoming(next_k, phi->get_incoming_block(1)); - // cache next_k (to be used by next_mask) - next_load_ivs[phi] = next_k; - } else - throw std::runtime_error("must be phi"); - } -} - -struct pipeline_info_t { - ir::load_inst* load; - ir::phi_node* ptr; - ir::dot_inst* dot; - - pipeline_info_t(ir::load_inst* load, ir::phi_node* ptr, ir::dot_inst* dot) - : load(load), ptr(ptr), dot(dot) {} -}; - -void pipeline::run(ir::module &mod) { - if (num_stages_ <= 1) - return; - // *Very* conservative heuristics for pre-fetching. - // A load instruction can be pipelined if: - // - the pointer is a phi node that references a value - // in its basic block (i.e., pointer induction variable) - // - the load has only a single use in a dot instruction - // As more use cases become apparent, this pass will be improved - std::vector to_pipeline; - ir::for_each_instruction(mod, [&](ir::instruction *i){ - if(auto* load = dynamic_cast(i)){ - ir::phi_node* ptr = dynamic_cast(load->get_pointer_operand()); - auto users = load->get_users(); - auto dot = dynamic_cast(*users.begin()); - if(ptr && ptr->get_incoming_block(1) == ptr->get_parent() - && users.size() == 1 && dot) - to_pipeline.push_back({load, ptr, dot}); - }}); - // do the pipelining - std::vector new_loads; - ir::builder &builder = mod.get_builder(); - const int num_stages = num_stages_; - std::vector>> preheader_loads; // Used to reorder loads - for(auto info: to_pipeline){ - ir::load_inst* load = info.load; - ir::phi_node* ptr = info.ptr; - ir::basic_block* block = load->get_parent(); - ir::basic_block* header = block->get_predecessors()[0]; - auto* block_br = dynamic_cast(block->get_inst_list().back()); - auto* header_br = dynamic_cast(header->get_inst_list().back()); - assert(block_br); - assert(header_br); - ir::type* ty = load->get_type(); - // multi-stage pipe - if (has_copy_async_ && num_stages > 2) { - ir::value* header_cond = header_br->get_cond(); - ir::value* block_cond = block_br->get_cond(); - // 1. collect induction variables - std::set induction_vars; - get_induction_vars(block_cond, induction_vars); - - std::vector first_ptrs(num_stages-1); - std::vector first_loads(num_stages-1); - std::vector first_masks(num_stages-1); - std::vector loop_conds(num_stages-1); - - std::map prev_phi_vals; - // initialize prev_phi_vals - // Add all phi nodes. 
The following DCE pass will delete dead ones. - for (ir::instruction *instr : block->get_inst_list()) - if (auto *phi = dynamic_cast(instr)) - if (phi->get_incoming_block(1) == block) - prev_phi_vals[phi] = phi->get_value_for_block(header); - - builder.set_insert_point(header->get_inst_list().back()); - first_ptrs[0] = ptr->get_value_for_block(header); - loop_conds[0] = header_cond; - first_masks[0] = builder.create_splat(loop_conds[0], ty->get_block_shapes()); - ir::value* false_value = nullptr; - if (auto* masked_load = dynamic_cast(load)) { - ir::value* remat_mask =rematerialize_vals(builder, block, masked_load->get_mask_operand(), prev_phi_vals) ; - ir::value* remat_false_value = - rematerialize_vals(builder, block, masked_load->get_false_value_operand(), prev_phi_vals); - first_masks[0] = builder.create_and(first_masks[0], remat_mask); - false_value = remat_false_value; - } else - false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes()); - first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile()); - - for (int stage = 1; stage < num_stages-1; ++stage) { - // mask is the loop condition of the previous iteration - loop_conds[stage] = rematerialize_vals(builder, block, block_cond, prev_phi_vals); - prev_phi_vals = update_prev_phi_vals(builder, block, prev_phi_vals); - first_ptrs[stage] = rematerialize_vals(builder, block, ptr, prev_phi_vals); - first_masks[stage] = builder.create_splat(loop_conds[stage], ty->get_block_shapes()); - if (auto* masked_load = dynamic_cast(load)) { - ir::value* remat_mask = rematerialize_vals(builder, block, masked_load->get_mask_operand(), prev_phi_vals); - ir::value* remat_false_value = - rematerialize_vals(builder, block, masked_load->get_false_value_operand(), prev_phi_vals); - first_masks[stage] = builder.create_and(first_masks[stage], remat_mask); - false_value = remat_false_value; - } - first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile()); - } - - // create new phis for induction variables - builder.set_insert_point(block->get_first_non_phi()); - std::map load_ivs; - std::map next_load_ivs; - for (auto& [iv, val] : prev_phi_vals) { - ir::phi_node* pn = builder.create_phi(iv->get_type(), 2); - pn->add_incoming(prev_phi_vals[iv], header); - load_ivs[iv] = pn; - } - // add incoming for phis & update next_load_ivs - finalize_iv_vals(builder, block, load_ivs, next_load_ivs); - - // pre-fetch next iteration - builder.set_insert_point(block->get_inst_list().back()); -// ir::value* next_ptr = ptr->get_value_for_block(block); - ir::value* next_ptr = rematerialize_vals(builder, block, ptr->get_value_for_block(block), load_ivs); - ir::value* next_mask = builder.create_splat( - rematerialize_vals(builder, block, block_cond, load_ivs), ty->get_block_shapes()); - if (auto* masked_load = dynamic_cast(load)) { - ir::value* remat_mask = rematerialize_vals(builder, block, masked_load->get_mask_operand(), next_load_ivs); - // TODO: false may depends on some other phi nodes - ir::value* remat_false_value = - rematerialize_vals(builder, block, masked_load->get_false_value_operand(), next_load_ivs); - next_mask = builder.create_and(next_mask, remat_mask); - false_value = remat_false_value; - } - ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, 
load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile()); - - - // phi node - ptr->set_incoming_value(0, first_ptrs.back()); - builder.set_insert_point(block->get_first_non_phi()); - // nested phis for load - std::vector new_load_phis(num_stages-1); - for (auto& pn : new_load_phis) - pn = builder.create_phi(ty, 2); - for (int i=0; iadd_incoming(first_loads[i], header); - new_load_phis[i]->add_incoming(new_load_phis[i+1], block); - } - new_load_phis.back()->add_incoming(first_loads.back(), header); - new_load_phis.back()->add_incoming(next_load, block); - load->replace_all_uses_with(new_load_phis.front()); - new_loads.push_back(new_load_phis.back()); - - // record first_loads to reorder them - preheader_loads.push_back({new_load_phis.front(), first_loads}); - } else { - // pre-fetch first iteration - builder.set_insert_point(header->get_inst_list().back()); - ir::value* first_ptr = ptr->get_value_for_block(header); - ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_block_shapes()); - ir::value* false_value; - if(auto* masked_load = dynamic_cast(load)){ - ir::value* remat_mask = rematerialize(builder, block, masked_load->get_mask_operand(), 0); - ir::value* remat_false_value = rematerialize(builder, block, masked_load->get_false_value_operand(), 0); - first_mask = builder.create_and(first_mask, remat_mask); - false_value = remat_false_value; - } - else - false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes()); - ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile()); - // pre-fetch next iteration - builder.set_insert_point(block->get_inst_list().back()); - ir::value* next_ptr = ptr->get_value_for_block(block); - ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_block_shapes()); - if(auto* masked_load = dynamic_cast(load)){ - ir::value* remat_mask = rematerialize(builder, block, masked_load->get_mask_operand(), 1); - ir::value* remat_false_value = rematerialize(builder, block, masked_load->get_false_value_operand(), 1); - next_mask = builder.create_and(next_mask, remat_mask); - false_value = remat_false_value; - } - ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile()); - // phi node - builder.set_insert_point(block->get_first_non_phi()); - ir::phi_node* new_load = builder.create_phi(ty, 2); - new_load->add_incoming(first_load, header); - new_load->add_incoming(next_load, block); - load->replace_all_uses_with(new_load); - new_loads.push_back(new_load); - } - } - - // try to reorder prefetched value from a0, a1, a2, ..., b0, b1, b2, ... to - // a0, b0, a1, b1, ... 
- if (!preheader_loads.empty()) { - ir::basic_block* header = preheader_loads.begin()->first->get_incoming_block(0); - builder.set_insert_point(header->get_inst_list().back()); - for (int i=1; i(iter->second.at(i)); - ir::instruction* moved_load = original_load->clone(); - builder.insert(moved_load); - original_load->replace_all_uses_with(moved_load); - } - } - } - - // try to move dot_inst after loads - // for better overlap of io and compute - struct move_config_t{ - std::vector insts; - ir::load_inst* dst; - }; - std::vector to_move(to_pipeline.size()); - - if(has_copy_async_){ - for (size_t idx = 0; idx < to_pipeline.size(); ++idx) { - auto info = to_pipeline[idx]; - ir::load_inst* load = info.load; - ir::phi_node* ptr = info.ptr; - ir::dot_inst* dot = info.dot; - ir::basic_block* bb = dot->get_parent(); - recursive_deps(dot, bb, to_move[idx].insts); - to_move[idx].dst = load; - } - - for(auto& move_config: to_move){ - builder.set_insert_point_after(move_config.dst); - for(ir::instruction* i: move_config.insts){ - i->get_parent()->erase(i); - builder.insert(i); - } - } - } - - -} - -} -} -} diff --git a/lib/codegen/transform/prefetch.cc b/lib/codegen/transform/prefetch.cc deleted file mode 100644 index 30b2a10f2718..000000000000 --- a/lib/codegen/transform/prefetch.cc +++ /dev/null @@ -1,133 +0,0 @@ -#include "triton/codegen/transform/prefetch.h" -#include "triton/codegen/target.h" -#include "triton/ir/module.h" -#include "triton/ir/function.h" -#include "triton/ir/basic_block.h" -#include "triton/ir/instructions.h" -#include "triton/ir/utils.h" -#include "triton/ir/print.h" -#include -#include -#include - -namespace triton::codegen::transform { - -/// find defs till phis -static void recursive_defs(ir::value *v, ir::basic_block *bb, std::vector &ret) { - ir::instruction *i = dynamic_cast(v); - if (!i || i->get_parent() != bb) - return; - if (i->get_id() == ir::INST_PHI) - return; - ret.push_back(i); - for (ir::value *op : i->ops()) - recursive_defs(op, bb, ret); -} - -void prefetch::run(ir::module &mod) { - // 1. collect dots that can be prefethced - std::vector to_prefetch; - ir::for_each_instruction(mod, [&](ir::instruction *i) { - if (auto *dot = dynamic_cast(i)) { - // Now only do prefetching when dot is using tensor cores - if (!(dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp16_ty() || - dot->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty() || - (dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp32_ty() && dot->allow_tf32() - && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80) || - (dot->get_operand(0)->get_type()->get_scalar_ty()->is_integer_ty(8) - && dot->get_operand(1)->get_type()->get_scalar_ty()->is_integer_ty(8) - && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80) - ) - ) - return; - auto *a = dynamic_cast(dot->get_operand(0)); - auto *b = dynamic_cast(dot->get_operand(1)); - if (a && a->get_incoming_block(1) == a->get_parent() && - b && b->get_incoming_block(1) == b->get_parent()) - to_prefetch.push_back(dot); - } - }); - - assert(to_prefetch.size() <=1 && "Don't know what to do with multiple dots"); - ir::builder &builder = mod.get_builder(); - // 2. do the prefetching - for (ir::dot_inst* dot : to_prefetch) { - auto *a = dynamic_cast(dot->get_operand(0)); - auto *b = dynamic_cast(dot->get_operand(1)); - assert(a->get_incoming_block(0) == b->get_incoming_block(0)); - ir::basic_block *loop_header = a->get_incoming_block(0); - ir::basic_block *loop_body = a->get_parent(); - - // mark as prefetched - dot->set_prefetched(true); - - // 1. 
in the loop header (first iteration) - builder.set_insert_point(loop_header->get_inst_list().back()); - assert(a && b); - builder.create_prefetch_s(a->get_incoming_value(0), /*inc*/ 0); - builder.create_prefetch_s(b->get_incoming_value(0), /*inc*/ 0); - - // 2. at the end of the loop body (next iteration) - builder.set_insert_point(loop_body->get_inst_list().back()); - builder.create_prefetch_s(a->get_incoming_value(1), /*inc*/ 1); - builder.create_prefetch_s(b->get_incoming_value(1), /*inc*/ 1); - - prefetched_vals_.insert(a->get_incoming_value(0)); - prefetched_vals_.insert(b->get_incoming_value(0)); - // nested phis - ir::value* next_a = a->get_incoming_value(1); - while (auto* next_a_phi = dynamic_cast(next_a)) { - prefetched_vals_.insert(next_a_phi->get_incoming_value(0)); - next_a = next_a_phi->get_incoming_value(1); - } - prefetched_vals_.insert(next_a); - - ir::value* next_b = b->get_incoming_value(1); - while (auto* next_b_phi = dynamic_cast(next_b)) { - prefetched_vals_.insert(next_b_phi->get_incoming_value(0)); - next_b = next_b_phi->get_incoming_value(1); - } - prefetched_vals_.insert(next_b); - } - - // move loads to the beginning of the loop - if (tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80) { - for (ir::function *fn : mod.get_function_list()) - for (ir::basic_block *bb : fn->blocks()) { - // only apply to loop body - if (bb->get_predecessors().size() != 2 || bb->get_predecessors()[1] != bb) - continue; - // record loads (& dependency) to move - std::vector loads; - // record original inst order - std::map idx_map; - size_t idx = 0; - for (ir::instruction *inst : bb->get_inst_list()) { - if (auto *i = dynamic_cast(inst)) - recursive_defs(i, bb, loads); - idx_map[inst] = idx; - idx++; - } - - // remove duplicates & keep the original input order - std::sort(loads.begin(), loads.end()); - loads.erase(std::unique(loads.begin(), loads.end()), loads.end()); - std::sort(loads.begin(), loads.end(), [&idx_map](ir::instruction *a, ir::instruction *b) { - return idx_map[a] < idx_map[b]; - }); - - builder.set_insert_point(bb->get_first_non_phi()); - auto& inst_list = bb->get_inst_list(); - for (ir::instruction *i : loads){ - auto it = std::find(inst_list.begin(), inst_list.end(), i); - // make sure we don't invalidate insert point - // in case instruction already at the top - if(it == builder.get_insert_point()) - continue; - bb->erase(i); - builder.insert(i); - } - } - } -} -} // namespace triton::codegen::transform diff --git a/lib/codegen/transform/reorder.cc b/lib/codegen/transform/reorder.cc deleted file mode 100644 index 47dc47b6c24f..000000000000 --- a/lib/codegen/transform/reorder.cc +++ /dev/null @@ -1,51 +0,0 @@ -#include -#include -#include "triton/ir/module.h" -#include "triton/ir/function.h" -#include "triton/ir/basic_block.h" -#include "triton/ir/instructions.h" -#include "triton/codegen/transform/reorder.h" - -namespace triton { -namespace codegen{ -namespace transform{ - -void reorder::run(ir::module& mod){ -// ir::builder &builder = mod.get_builder(); -// std::vector> to_replace; - -// for(ir::function *fn: mod.get_function_list()) -// for(ir::basic_block *block: fn->blocks()) -// for(ir::instruction* i: block->get_inst_list()){ -// if(auto* ld = dynamic_cast(i)){ -// ir::value* _ptr = ld->get_pointer_operand(); -// ir::value* _msk = ld->get_mask_operand(); -// ir::value* _val = ld->get_false_value_operand(); -// auto ptr = std::find(block->begin(), block->end(), _ptr); -// auto msk = std::find(block->begin(), block->end(), _msk); -// auto val = 
std::find(block->begin(), block->end(), _val); -// if(ptr == block->end() || msk == block->end() || val == block->end()) -// continue; -// auto it = std::find(block->begin(), block->end(), i); -// int dist_ptr = std::distance(ptr, it); -// int dist_msk = std::distance(msk, it); -// int dist_val = std::distance(val, it); -// if(dist_ptr < dist_msk && dist_ptr < dist_val) -// builder.set_insert_point(++ptr); -// if(dist_msk < dist_ptr && dist_msk < dist_val) -// builder.set_insert_point(++msk); -// if(dist_val < dist_ptr && dist_val < dist_msk) -// builder.set_insert_point(++val); -// ir::value* new_ld = builder.create_masked_load(_ptr, _msk, _val); -// to_replace.push_back(std::make_pair(ld, new_ld)); -// } -// } - -// for(auto& x: to_replace) -// x.first->replace_all_uses_with(x.second); - -} - -} -} -} diff --git a/python/bench/README.md b/python/bench/README.md deleted file mode 100644 index 970c3a2a0706..000000000000 --- a/python/bench/README.md +++ /dev/null @@ -1,5 +0,0 @@ -# Run the benchmarks - -Install the required dependencies via `pip install -r requirements-bench.txt` from the triton/python/bench folder. - -Run the benchmarks through `python3 bench/run.py`, this will produce an HTML report in a results folder. diff --git a/python/bench/bench_blocksparse.py b/python/bench/bench_blocksparse.py deleted file mode 100644 index d678f49f807e..000000000000 --- a/python/bench/bench_blocksparse.py +++ /dev/null @@ -1,92 +0,0 @@ -import torch - -import triton - -# ------------------------------- -# Matrix Multiplication -# ------------------------------- - -nt = {False: 'n', True: 't'} -square_confs = [ - triton.testing.Benchmark( - x_names=['M', 'N', 'K'], - x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144], - line_arg='block', - line_vals=[16, 32, 64, 128], - line_names=['Block16', 'Block32', 'Block64', 'Block128'], - ylabel='TFLOPS', - plot_name=f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}', - args={'layout_mode': layout_mode, 'op_mode': op_mode, - 'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'} - ) - for AT in [False] for BT in [False] - for op_mode in ['dsd'] for layout_mode in ['dense'] -] - - -@triton.testing.perf_report(square_confs) -def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, warmup=100, rep=1000): - Z, H = 1, 1 - make_layout = { - 'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)), - 'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64), - }[layout_mode] - # create layout - shape = {'sdd': (M, N), 'dsd': (K, M) if AT else (M, K), 'dds': (N, K) if BT else (K, N)}[op_mode] - layout = make_layout(H, shape[0] // block, shape[1] // block) - # creat inputs - a = torch.randn((Z, H, K, M) if AT else (Z, H, M, K), dtype=dtype, device='cuda') - b = torch.randn((Z, H, N, K) if BT else (Z, H, K, N), dtype=dtype, device='cuda') - # create op - tflops = lambda ms: num_flops / ms * 1e3 - if provider == 'triton': - op = triton.ops.blocksparse.matmul(layout, block, op_mode, trans_a=AT, trans_b=BT) - # inputs - a = triton.testing.sparsify_tensor(a, layout, block) if op_mode == 'dsd' else a - b = triton.testing.sparsify_tensor(b, layout, block) if op_mode == 'dds' else b - mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(a, b), warmup=warmup, rep=rep) - num_flops = { - 'sdd': 2 * Z * K * float(layout.sum()) * block * block, - 'dsd': 2 * Z * N * float(layout.sum()) * block * block, - 'dds': 2 * Z * M * float(layout.sum()) * block * block - }[op_mode] * 1e-12 - return tflops(mean_ms), 
tflops(min_ms), tflops(max_ms) - - -# ------------------------------- -# Softmax -# ------------------------------- - -square_confs = [ - triton.testing.Benchmark( - x_names=['M', 'N'], - x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144], - line_arg='block', - line_vals=[16, 32, 64], - line_names=['Block16', 'Block32', 'Block64'], - ylabel='GBPS', - plot_name=f'{layout_mode}-square', - args={'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'} - ) - for layout_mode in ['dense', 'tril'] -] - - -@triton.testing.perf_report(square_confs) -def bench_softmax(M, N, block, layout_mode, dtype, provider, warmup=10, rep=50): - Z, H = 1, 1 - make_layout = { - 'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)), - 'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64), - }[layout_mode] - layout = make_layout(H, M // block, N // block) - a = torch.randn((Z, H, M, N), dtype=dtype, device='cuda') - if provider == 'triton': - a = triton.testing.sparsify_tensor(a, layout, block) - op = triton.ops.blocksparse.softmax(layout, block) - gbps = lambda ms: (2 * a.numel() * a.element_size() * 1e-9) / (ms * 1e-3) - mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(a), warmup=warmup, rep=rep) - return gbps(mean_ms), gbps(min_ms), gbps(max_ms) - - -bench_matmul.run(print_data=True, show_plots=True) diff --git a/python/bench/bench_cross_entropy.py b/python/bench/bench_cross_entropy.py deleted file mode 100644 index aaa0e28f5423..000000000000 --- a/python/bench/bench_cross_entropy.py +++ /dev/null @@ -1,41 +0,0 @@ -import torch - -import triton - -confs = [ - triton.testing.Benchmark( - x_names=['N'], - x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192], - line_arg='provider', - line_vals=['triton', 'torch'], - line_names=['Triton', 'Torch'], - ylabel='GBPS', - plot_name=f'{mode}-2048', - args={'M': 2048, 'dtype': torch.float16, 'mode': mode} - ) - for mode in ['forward', 'backward'] -] - - -@triton.testing.perf_report(confs) -def bench_op(M, N, dtype, mode, provider): - # create inputs - x = torch.randn(M, N, dtype=dtype, device='cuda', requires_grad=True) - idx = 4 + torch.ones(M, dtype=torch.int64, device='cuda') - num_gb = (2 * x.numel() * x.element_size() * 1e-9) - gbps = lambda ms: num_gb / ms * 1e3 - # forward pass - op = {'torch': torch.nn.CrossEntropyLoss(reduction='none'), - 'triton': triton.ops.cross_entropy}[provider] - if mode == 'forward': - mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(x, idx)) - if mode == 'backward': - y = op(x, idx) - dy = torch.randn_like(y) - fn = lambda: y.backward(dy, retain_graph=True) - mean_ms, min_ms, max_ms = triton.testing.do_bench(fn, grad_to_none=[x]) - return gbps(mean_ms), gbps(min_ms), gbps(max_ms) - - -if __name__ == '__main__': - bench_op.run(print_data=True) diff --git a/python/bench/bench_matmul.py b/python/bench/bench_matmul.py deleted file mode 100644 index b776b3dbff38..000000000000 --- a/python/bench/bench_matmul.py +++ /dev/null @@ -1,67 +0,0 @@ -import torch - -import triton - - -def rounded_linspace(low, high, steps, div): - ret = torch.linspace(low, high, steps) - ret = (ret.int() + div - 1) // div * div - ret = torch.unique(ret) - return list(map(int, ret)) - - -# Square benchmarks -nt = {False: "n", True: "t"} -square_confs = [ - triton.testing.Benchmark( - x_names=["M", "N", "K"], - x_vals=rounded_linspace(512, 8192, 32, 128), - line_arg="provider", - line_vals=["cublas", "triton", "cutlass"], - line_names=["cuBLAS", "Triton", "CUTLASS"], - ylabel="TFLOPS", - 
plot_name=f"matmul-square-{nt[AT]}{nt[BT]}", - args={"AT": AT, "BT": BT, "dtype": torch.float16}, - ) for AT in [False] for BT in [False] -] - -# Transformer training benchmarks -transformer_confs = [ - triton.testing.Benchmark( - x_names=[x], - x_vals=rounded_linspace(NK // 16, NK, 32, 128), - line_arg="provider", - line_vals=["cublas", "triton", "cutlass"], - line_names=["cuBLAS", "Triton", "CUTLASS"], - ylabel="TFLOPS", - plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}", - args={"M": M, 'NK'.replace(x, ''): NK, "AT": False, "BT": False, "dtype": torch.float16} - ) for NK in [12288] - for i, x in enumerate(["N", "K"]) - for M in [2048] -] - - -@triton.testing.perf_report(square_confs) -def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75): - a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype) - b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype) - if AT: - a = a.t() - if BT: - b = b.t() - tflops = lambda ms: 2. * M * N * K / ms * 1e-9 - if provider == "cublas": - ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep) - return tflops(ms), tflops(max_ms), tflops(min_ms) - if provider == "triton": - ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton.ops.matmul(a, b), warmup=warmup, rep=rep) - return tflops(ms), tflops(max_ms), tflops(min_ms) - if provider == "cutlass": - cutlass_matmul = triton.testing.cutlass_matmul - try: - ms, min_ms, max_ms = triton.testing.do_bench(lambda: cutlass_matmul(a, b), warmup=warmup, rep=rep) - return tflops(ms), tflops(max_ms), tflops(min_ms) - except Exception: - return None - return None diff --git a/python/bench/requirements-bench.txt b/python/bench/requirements-bench.txt deleted file mode 100644 index eb69de862c9a..000000000000 --- a/python/bench/requirements-bench.txt +++ /dev/null @@ -1,2 +0,0 @@ -pandas >= 1.3.3 -matplotlib >= 3.4.3 \ No newline at end of file diff --git a/python/bench/run.py b/python/bench/run.py deleted file mode 100644 index 5e6e3b392012..000000000000 --- a/python/bench/run.py +++ /dev/null @@ -1,44 +0,0 @@ -import argparse -import inspect -import os -import sys - -import triton - - -def run_all(result_dir, names): - if not os.path.exists(result_dir): - os.makedirs(result_dir) - for mod in os.listdir(os.path.dirname(os.path.realpath(__file__))): - # skip non python files - if not mod.endswith('.py'): - continue - # skip file not in provided names - if names and names not in mod: - continue - # skip files that don't start with 'bench_' - if not mod.startswith('bench_'): - continue - print(f'running {mod}...') - mod = __import__(os.path.splitext(mod)[0]) - benchmarks = inspect.getmembers(mod, lambda x: isinstance(x, triton.testing.Mark)) - for name, bench in benchmarks: - curr_dir = os.path.join(result_dir, mod.__name__.replace('bench_', '')) - if len(benchmarks) > 1: - curr_dir = os.path.join(curr_dir, name.replace('bench_', '')) - if not os.path.exists(curr_dir): - os.makedirs(curr_dir) - bench.run(save_path=curr_dir) - - -def main(args): - parser = argparse.ArgumentParser(description="Run the benchmark suite.") - parser.add_argument("-r", "--result-dir", type=str, default='results', required=False) - parser.add_argument("-n", "--names", type=str, default='', required=False) - parser.set_defaults(feature=False) - args = parser.parse_args(args) - run_all(args.result_dir, args.names) - - -if __name__ == '__main__': - main(sys.argv[1:]) diff --git a/python/examples/copy_strided.py b/python/examples/copy_strided.py new file mode 100644 
index 000000000000..8e0db3e15beb
--- /dev/null
+++ b/python/examples/copy_strided.py
@@ -0,0 +1,18 @@
+
+import triton
+import triton.language as tl
+
+# triton kernel
+@triton.jit
+def kernel(X, stride_xm, stride_xn,
+           Z, stride_zm, stride_zn,
+           BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
+    off_m = tl.arange(0, BLOCK_M)
+    off_n = tl.arange(0, BLOCK_N)
+    Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn
+    Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
+    tl.store(Zs, tl.load(Xs))
+
+
+ret = triton.compile(kernel, "*fp32,i32,i32,*fp32,i32,i32", constants={"BLOCK_M": 128, "BLOCK_N": 128}, output="ttgir")
+print(ret)
\ No newline at end of file
diff --git a/python/examples/empty.py b/python/examples/empty.py
new file mode 100644
index 000000000000..233aff36e85d
--- /dev/null
+++ b/python/examples/empty.py
@@ -0,0 +1,8 @@
+import triton
+import triton.language as tl
+
+@triton.jit
+def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr):
+    pass
+
+ret = triton.compile(kernel, "*fp32,i32,i32", constants={"BLOCK": 256}, output="ttgir")
\ No newline at end of file
diff --git a/python/src/cutlass.cc b/python/src/cutlass.cc
deleted file mode 100644
index 14da81330b12..000000000000
--- a/python/src/cutlass.cc
+++ /dev/null
@@ -1,202 +0,0 @@
-#include "cutlass/library/handle.h"
-#include "cutlass/library/library.h"
-#include "cutlass/library/operation_table.h"
-#include "cutlass/library/singleton.h"
-#include "pybind11/pybind11.h"
-#include "triton/tools/bench.hpp"
-
-using namespace cutlass;
-using namespace cutlass::library;
-
-std::map<std::vector<size_t>, const Operation *> op_cache_;
-
-static int const kHostWorkspaceSize = (4 << 10);
-static int const kDeviceWorkspaceSize = (4 << 20);
-
-void run(int M, int N, int K,
-         int lda, int ldb, int ldc, int ldd,
-         void const *ptr_A, void const *ptr_B, void const *ptr_C, void *ptr_D,
-         void const *alpha, void const *beta,
-         ScalarPointerMode scalar_mode,
-         const Operation *operation,
-         cudaStream_t stream) {
-
-  GemmUniversalConfiguration configuration{
-      GemmUniversalMode::kGemm,
-      {M, N, K},
-      1,
-      lda,
-      ldb,
-      ldc,
-      ldd};
-
-  // host workspace size
-  uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration);
-  if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed)
-    throw std::runtime_error("Unable to find gemm operation");
-  char host_workspace[kHostWorkspaceSize];
-
-  // device workspace size
-  uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration);
-  if (uint64_t(kDeviceWorkspaceSize) < device_workspace_size_needed)
-    throw std::runtime_error("Unable to find gemm operation");
-  static void *device_workspace;
-
-  // Initialize host and device workspaces
-  Status status = operation->initialize(&configuration, host_workspace, device_workspace, stream);
-  if (status != cutlass::Status::kSuccess)
-    throw std::runtime_error("Unable to initialize workspace");
-
-  // Run the operator
-  GemmArguments arguments{ptr_A, ptr_B, ptr_C, ptr_D, alpha, beta, scalar_mode};
-  operation->run(&arguments, host_workspace, device_workspace, stream);
-}
-
-const Operation *autotune(int M, int N, int K,
-                          NumericTypeID element_compute,
-                          NumericTypeID element_scalar,
-                          void const *alpha,
-                          NumericTypeID element_A,
-                          LayoutTypeID layout_A,
-                          ComplexTransform transform_A,
-                          void const *ptr_A,
-                          int lda,
-                          NumericTypeID element_B,
-                          LayoutTypeID layout_B,
-                          ComplexTransform transform_B,
-                          void const *ptr_B,
-                          int ldb,
-                          void const *beta,
-                          NumericTypeID element_C,
-                          void const
*ptr_C, - int ldc, - void *ptr_D, - int ldd, - ScalarPointerMode scalar_mode, - int device_id, - cudaStream_t stream) { - - // index operation table with functional key - GemmFunctionalKey key( - Provider::kCUTLASS, - GemmKind::kUniversal, - element_compute, - element_scalar, - element_A, - layout_A, - transform_A, - element_B, - layout_B, - transform_B, - element_C); - - auto operators_it = Singleton::get().operation_table.gemm_operations.find(key); - if (operators_it == Singleton::get().operation_table.gemm_operations.end()) - throw std::runtime_error("Unable to find gemm operation"); - if (operators_it->second.empty()) - throw std::runtime_error("Unable to find gemm operation"); - - cudaDeviceProp device_prop; - cudaError_t error = cudaGetDeviceProperties(&device_prop, device_id); - if (error != cudaSuccess) - throw std::runtime_error("Unable to get device properties"); - int cc = device_prop.major * 10 + device_prop.minor; - - // index operation table with preference key - // assume 8-bytes aligned memory pointers - int alignment = 8; - GemmPreferenceKey preference_key(cc, alignment); - auto autotune_it = operators_it->second.find(preference_key); - if (autotune_it == operators_it->second.end()) - throw std::runtime_error("Unable to find gemm operation"); - const std::vector &operations = autotune_it->second; - if (operations.empty()) - throw std::runtime_error("Unable to find gemm operation"); - - // auto-tune - const Operation *best = nullptr; - double best_ms = std::numeric_limits::max(); - for (const Operation *op : operations) { - auto fn = [&]() { run(M, N, K, lda, ldb, ldc, ldd, ptr_A, ptr_B, ptr_C, ptr_D, - alpha, beta, scalar_mode, op, stream); }; - triton::driver::cu_stream tt_stream((CUstream)stream, false); - double ms = triton::tools::bench(fn, &tt_stream, 10, 25); - if (ms < best_ms) { - best_ms = ms; - best = op; - } - } - return best; -} - -// map of torch datatypes to cutlass datatypes -std::map type_map = { - {"float16", NumericTypeID::kF16}, - {"float32", NumericTypeID::kF32}, - {"float64", NumericTypeID::kF64}}; - -void cutlass_matmul(uintptr_t A, uintptr_t B, uintptr_t C, - size_t M, size_t N, size_t K, - size_t stride_a_0, size_t stride_a_1, - size_t stride_b_0, size_t stride_b_1, - size_t stride_c_0, size_t stride_c_1, - std::string type_a, std::string type_b, std::string type_c, - size_t dev_id, uint64_t stream_handle) { - void *ptr_A = (void *)A; - void *ptr_B = (void *)B; - void *ptr_C = (void *)C; - void *ptr_D = ptr_C; - size_t lda = stride_a_0; - size_t ldb = stride_b_0; - size_t ldc = stride_c_1; - size_t ldd = ldc; - float alpha = 1.0f; - float beta = 0.0f; - // layout for A - LayoutTypeID layout_A; - if (stride_a_0 == 1) - layout_A = LayoutTypeID::kColumnMajor; - else if (stride_a_1 == 1) - layout_A = LayoutTypeID::kRowMajor; - else - throw std::runtime_error("A layout is not supported"); - // layout for B - LayoutTypeID layout_B; - if (stride_b_0 == 1) - layout_B = LayoutTypeID::kColumnMajor; - else if (stride_b_1 == 1) - layout_B = LayoutTypeID::kRowMajor; - else - throw std::runtime_error("B layout is not supported"); - // data types - NumericTypeID element_compute = NumericTypeID::kF32; - NumericTypeID element_A = type_map[type_a]; - NumericTypeID element_B = type_map[type_b]; - NumericTypeID element_C = type_map[type_c]; - // misc. 
flags - ScalarPointerMode scalar_mode = ScalarPointerMode::kHost; - NumericTypeID element_scalar = NumericTypeID::kF32; - ComplexTransform transform_A = ComplexTransform::kNone; - ComplexTransform transform_B = ComplexTransform::kNone; - // runtime flags - cudaStream_t stream = (cudaStream_t)stream_handle; - // auto-tune - std::vector tune_key = {M, N, K, (size_t)element_A, (size_t)element_B, (size_t)element_C, - dev_id, (size_t)element_compute, (size_t)scalar_mode}; - auto it = op_cache_.find(tune_key); - if (it == op_cache_.end()) { - const Operation *op = autotune(M, N, K, element_compute, element_scalar, &alpha, - element_A, layout_A, transform_A, ptr_A, lda, - element_B, layout_B, transform_B, ptr_B, ldb, - &beta, element_C, ptr_C, ldc, ptr_D, ldd, scalar_mode, - dev_id, stream); - it = op_cache_.insert({tune_key, op}).first; - } - run(M, N, K, lda, ldb, ldc, ldd, ptr_A, ptr_B, ptr_C, ptr_D, &alpha, &beta, - scalar_mode, it->second, stream); -} - -void init_cutlass(pybind11::module &m) { - pybind11::module subm = m.def_submodule("cutlass"); - subm.def("matmul", &cutlass_matmul, "matrix multiplication"); -} \ No newline at end of file diff --git a/python/src/functions.h b/python/src/functions.h deleted file mode 100644 index 19f7e7eb9bc7..000000000000 --- a/python/src/functions.h +++ /dev/null @@ -1,676 +0,0 @@ -#include "triton/ir/builder.h" -#include -#include -#include - -namespace ir = triton::ir; -namespace py = pybind11; - -static const std::string _builder_doc = R"pbdoc( - :param builder: IR builder to generate code into, optional, set automatically when called inside a @triton.jit function - :type builder: triton.ir.builder -)pbdoc"; - -#define VA_ARGS(...) , ##__VA_ARGS__ -#define DEF_FUNC(MOD, PY_NAME, C_FUNC, ...) \ - MOD.def(PY_NAME, C_FUNC, (C_FUNC##_docstr + _builder_doc).c_str(), \ - ret::reference VA_ARGS(__VA_ARGS__), "builder"_a) - -void throw_not_implemented(std::string key) { - throw std::runtime_error("Encountered unimplemented code path in `" + key + "`. This is likely a bug on our side."); -} - -void throw_not_int_or_float(std::string key) { - throw std::runtime_error("`" + key + "` only supported for integer and floating point types."); -} - -enum type_code { - _bool, - int8, - int16, - int32, - int64, - float16, - float32, - float64 -}; - -ir::type *make_ir(type_code ty, ir::builder *builder) { - switch (ty) { - case float16: - return builder->get_half_ty(); - case float32: - return builder->get_float_ty(); - default: - throw_not_implemented("make_ir"); - } -} - -type_code from_ir(ir::type *ty) { - if (ty->is_half_ty()) - return float16; - if (ty->is_float_ty()) - return float32; - throw_not_implemented("from_ir"); -} - -/*---------------------------------------------- - definition of triton.cast / triton.ir.value.to - ----------------------------------------------*/ -std::string cast_docstr = R"pbdoc( - Tries to cast a block to a new data type. - - :param input: The input block. 
- :type input: triton.ir.value -)pbdoc"; - -ir::value *cast(ir::value *input, type_code _dtype, ir::builder *builder) { - ir::type *src_ty = input->get_type(); - ir::type *dst_ty = make_ir(_dtype, builder); - if (src_ty->is_block_ty()) - dst_ty = ir::block_type::get(dst_ty, input->get_type()->get_block_shapes()); - ir::type *src_sca_ty = src_ty->get_scalar_ty(); - ir::type *dst_sca_ty = dst_ty->get_scalar_ty(); - // FP Truncation - bool truncate_fp = src_sca_ty->is_floating_point_ty() && - dst_sca_ty->is_floating_point_ty() && - src_sca_ty->get_fp_mantissa_width() > dst_sca_ty->get_fp_mantissa_width(); - if (truncate_fp) - return builder->create_fp_trunc(input, dst_ty); - // FP Extension - bool ext_fp = src_sca_ty->is_floating_point_ty() && - dst_sca_ty->is_floating_point_ty() && - src_sca_ty->get_fp_mantissa_width() < dst_sca_ty->get_fp_mantissa_width(); - if (ext_fp) - return builder->create_fp_ext(input, dst_ty); - // Int cast - if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_integer_ty() && - src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth()) - return builder->create_int_cast(input, dst_ty, true); - // Float -> Int - if (src_sca_ty->is_floating_point_ty() && dst_sca_ty->is_integer_ty()) - return builder->create_fp_to_si(input, dst_ty); - // int -> Float - if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_floating_point_ty()) - return builder->create_si_to_fp(input, dst_ty); - // Ptr -> Ptr - if (src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty()) - return builder->create_cast(ir::BitCast, input, dst_ty); - // * -> Bool - if (dst_sca_ty->is_bool_ty()) { - if (src_sca_ty->is_pointer_ty()) - input = cast(input, int64, builder); - ir::value *other = builder->get_int64(0); - if (src_ty->is_bool_ty()) - other = builder->create_splat(other, src_ty->get_block_shapes()); - return builder->create_icmpNE(input, other); - } - throw_not_implemented("cast from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr()); -} - -/*---------------------------------------------- - definition of triton.broadcast_check - ----------------------------------------------*/ -std::string try_broadcast_docstr = R"pbdoc( - Tries to broadcast two blocks to a common compatible shape. - - :param input: The first input block. - :type input: triton.ir.value - :param other: The second input block. 
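
The `try_broadcast` helper whose docstring appears above (implementation follows) applies numpy-style rules: a scalar is splatted to the other operand's block shape, and two blocks must have equal rank with each dimension pair equal or 1. A standalone sketch of just the shape rule; `broadcast_shapes` is illustrative, not a Triton function.

def broadcast_shapes(lhs, rhs):
    if len(lhs) != len(rhs):
        raise ValueError("blocks must have the same rank")
    out = []
    for i, (l, r) in enumerate(zip(lhs, rhs)):
        if l == 1:
            out.append(r)          # size-1 dims stretch to the other operand
        elif r == 1 or l == r:
            out.append(l)
        else:
            raise ValueError(f"incompatible dimensions at index {i}: {l} and {r}")
    return tuple(out)

assert broadcast_shapes((128, 1), (1, 64)) == (128, 64)
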
- :type other: triton.ir.value -)pbdoc"; - -std::tuple try_broadcast(ir::value *lhs, ir::value *rhs, ir::builder *builder) { - ir::type *lhs_ty = lhs->get_type(); - ir::type *rhs_ty = rhs->get_type(); - // make_shape_compatible(block, scalar) - if (lhs_ty->is_block_ty() && !rhs_ty->is_block_ty()) - rhs = builder->create_splat(rhs, lhs_ty->get_block_shapes()); - // make_shape_compatible(scalar, block) - else if (!lhs_ty->is_block_ty() && rhs_ty->is_block_ty()) - lhs = builder->create_splat(lhs, rhs_ty->get_block_shapes()); - // make_shape_compatible(block, block) - else if (lhs_ty->is_block_ty() && rhs_ty->is_block_ty()) { - auto lhs_shape = lhs_ty->get_block_shapes(); - auto rhs_shape = rhs_ty->get_block_shapes(); - if (lhs_shape.size() != rhs_shape.size()) - throw std::runtime_error("Cannot make_shape_compatible: blocks must have the same rank"); - ir::type::block_shapes_t ret_shape; - for (size_t i = 0; i < lhs_shape.size(); ++i) { - unsigned left = lhs_shape[i]; - unsigned right = rhs_shape[i]; - if (left == 1) - ret_shape.push_back(right); - else if (right == 1) - ret_shape.push_back(left); - else if (left == right) - ret_shape.push_back(left); - else - throw std::runtime_error("Cannot make_shape_compatible: incompatible dimensions at index " + std::to_string(i) + - ": " + std::to_string(left) + " and " + std::to_string(right)); - } - if (lhs_shape != ret_shape) - lhs = builder->create_broadcast(lhs, ret_shape); - if (rhs_shape != ret_shape) - rhs = builder->create_broadcast(rhs, ret_shape); - } - return std::make_tuple(lhs, rhs); -} - -/*---------------------------------------------- - definition of triton.broadcast_to - ----------------------------------------------*/ -std::string broadcast_to_docstr = R"pbdoc( - Tries to broadcast a block to a new shape. - - :param input: The input block. - :type input: triton.value - :param shape: The new shape. - :type shape: tuple of int -)pbdoc"; - -ir::value *broadcast_to(ir::value *input, const ir::type::block_shapes_t &shape, ir::builder *builder) { - if (!input->get_type()->is_block_ty()) - return builder->create_splat(input, shape); - auto src_shape = input->get_type()->get_block_shapes(); - if (src_shape.size() != shape.size()) - throw std::runtime_error("Cannot broadcast"); - return builder->create_broadcast(input, shape); -} - -/*---------------------------------------------- - definition of triton.load - ----------------------------------------------*/ -std::string load_docstr = R"pbdoc( - Return a block of data whose values are, elementwise, loaded from memory at location defined by `pointer`. - - :param pointer: Pointer to the data to be loaded. - :type pointer: Block of triton.pointer - :param mask: if mask[idx] is false, do not load the data at `pointer[idx]`. - :type mask: Block of triton.bool, optional - :param other: if mask[idx] is false, return other[idx] instead of 'pointer[idx]` - :type other: Block of triton.value, optional - )pbdoc"; - -ir::value *load(ir::value *pointer, std::optional _mask, std::optional _other, ir::builder *builder) { - if (!_mask.has_value() && !_other.has_value()) - return builder->create_load(pointer); - if (!_mask.has_value()) - throw std::runtime_error("`other` cannot be provided without `mask`"); - ir::value *mask = _mask.value(); - ir::type *elt_ty = pointer->get_type()->get_scalar_ty()->get_pointer_element_ty(); - auto shape = pointer->get_type()->get_block_shapes(); - ir::value *other = _other.has_value() ? 
_other.value() : ir::undef_value::get(elt_ty); - other = cast(other, from_ir(elt_ty), builder); - other = broadcast_to(other, shape, builder); - mask = broadcast_to(mask, shape, builder); - return builder->create_masked_load(pointer, mask, other); -} - -/*---------------------------------------------- - definition of triton.store - ----------------------------------------------*/ -std::string store_docstr = R"pbdoc( - Stores `value` block of elements in memory, element-wise, at the memory locations specified by `pointer`. - - :param pointer: The memory locations where the elements of `value` are stored. - :type pointer: Block of triton.pointer - :param value: The block of elements to be stored. - :type value: Block of triton.value - :param mask: If mask[idx] is false, do not store `value[idx]` at `pointer[idx]`. - :type mask: Block of triton.bool, optional - )pbdoc"; -ir::value *store(ir::value *ptr, ir::value *val, std::optional _mask, ir::builder *builder) { - if (!_mask.has_value()) - return builder->create_store(ptr, val); - ir::value *mask = _mask.value(); - return builder->create_masked_store(ptr, val, mask); -} - -/*---------------------------------------------- - definition of triton.dot - ----------------------------------------------*/ -std::string dot_docstr = R"pbdoc( - Returns the matrix product of two blocks. - The two blocks must be two dimensionals and have compatible inner dimensions. - - :param input: The first block to be multiplied. - :type input: 2D block of scalar-type in {`float16`, `float32`} - :param other: The second block to be multiplied. - :type other: 2D block of scalar-type in {`float16`, `float32`} - )pbdoc"; -ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder) { - ir::value *_0 = builder->get_float32(0); - unsigned M = lhs->get_type()->get_block_shapes()[0]; - unsigned N = rhs->get_type()->get_block_shapes()[1]; - _0 = builder->create_splat(_0, {M, N}); - return builder->create_dot(lhs, rhs, _0); -} - -/*---------------------------------------------- - definition of triton.where - ----------------------------------------------*/ -std::string where_docstr = R"pbdoc( - Returns a block of elements from either `x` or `y`, depending on `condition`. - Note that `x` and `y` are always evaluated regardless of the value of `condition`. - If you want to avoid unintented memory operations, use the `mask` arguments in `triton.load` and `triton.store` instead. - - :param condition: When True (nonzero), yield x, otherwise yield y. - :type condition: Block of triton.bool - :param x: values selected at indices where condition is True. - :param y: values selected at indices where condition is False. - )pbdoc"; -ir::value *where(ir::value *condition, ir::value *x, ir::value *y, ir::builder *builder) { - return builder->create_select(condition, x, y); -}; - -/*---------------------------------------------- - definition of triton.arange - ----------------------------------------------*/ -std::string arange_docstr = R"pbdoc( - Returns contiguous values within the open interval [start, end). - - :param start: Start of the interval. - :type start: int - :param stop: End of the interval. 
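
The load/store semantics documented above are what make bounds handling cheap in user kernels: out-of-range lanes are masked off and `other` supplies their value. A small example in the style of the kernels removed later in this diff; the kernel and variable names are illustrative only.

import torch
import triton
import triton.language as tl

@triton.jit
def masked_copy(src_ptr, dst_ptr, n_elements, BLOCK: tl.constexpr):
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements                           # guard the tail of the buffer
    x = tl.load(src_ptr + offsets, mask=mask, other=0.0)  # masked lanes read `other`
    tl.store(dst_ptr + offsets, x, mask=mask)             # masked lanes store nothing

src = torch.randn(1000, device='cuda')
dst = torch.empty_like(src)
masked_copy[(triton.cdiv(1000, 256),)](src, dst, 1000, BLOCK=256)
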
- :type stop: int - )pbdoc"; -ir::value *arange(int start, int end, ir::builder *builder) { - return builder->get_range(start, end); -}; - -/*---------------------------------------------- - definition of triton.program_id - ----------------------------------------------*/ -std::string program_id_docstr = R"pbdoc( - Returns the id of the current program instance along the given `axis`. - Triton uses an SPMD model in which different @triton.jit functions run in parallel with different `program_id`s. - - :param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2. - :type axis: int - )pbdoc"; -ir::value *program_id(int axis, ir::builder *builder) { - return builder->create_get_program_id(axis); -}; - -/*---------------------------------------------- - definition of triton.num_programs - ----------------------------------------------*/ -std::string num_programs_docstr = R"pbdoc( - Returns the number of program instances launched along the given `axis`. - - :param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2. - :type axis: int - )pbdoc"; -ir::value *num_programs(int axis, ir::builder *builder) { - return builder->create_get_num_programs(axis); -}; - -/*---------------------------------------------- - definition of triton.zeros - ----------------------------------------------*/ -std::string zeros_docstr = R"pbdoc( - Returns a block filled with the scalar value 0 and the given shape. - - :param shape: Shape of the new array, e.g., (8, 16) or (8, ) - :type shape: tuple of ints - :param dtype: Data-type of the new array, e.g., tl.float16 - :type dtype: triton.ir.dtype - )pbdoc"; -ir::value *zeros(ir::type::block_shapes_t shape, type_code _dtype, ir::builder *builder) { - ir::type *dtype = make_ir(_dtype, builder); - ir::value *_0 = ir::constant::get_null_value(dtype); - return builder->create_splat(_0, shape); -}; - -/*---------------------------------------------- - definition of triton.exp - ----------------------------------------------*/ -std::string _exp_docstr = R"pbdoc( - Returns the element-wise exponential of `input`. - )pbdoc"; -ir::value *_exp(ir::value *input, ir::builder *builder) { - return builder->create_exp(input); -}; - -/*---------------------------------------------- - definition of triton.log - ----------------------------------------------*/ -std::string _log_docstr = R"pbdoc( - Returns the element-wise natural logarithm of `input`. - )pbdoc"; -ir::value *_log(ir::value *input, ir::builder *builder) { - return builder->create_log(input); -}; - -/*---------------------------------------------- - definition of triton.sqrt - ----------------------------------------------*/ -std::string sqrt_docstr = R"pbdoc( - Returns the element-wise square root of `input`. - )pbdoc"; -ir::value *sqrt(ir::value *input, ir::builder *builder) { - return builder->create_sqrt(input); -}; - -/*---------------------------------------------- - definition of triton.min - ----------------------------------------------*/ -ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder, const std::string &name, - ir::reduce_inst::op_t FLOAT_OP, ir::reduce_inst::op_t INT_OP) { - ir::type *scalar_ty = input->get_type()->get_scalar_ty(); - if (scalar_ty->is_floating_point_ty()) - return builder->create_reduce(input, FLOAT_OP, axis); - else if (scalar_ty->is_integer_ty()) - return builder->create_reduce(input, INT_OP, axis); - else - throw_not_int_or_float(name); -} - -std::string min_docstr = R"pbdoc( - Returns the minimum value of `input`. 
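
`reduce_impl` above selects a floating-point or integer reduce opcode from the element type; at the language level this surfaces as `tl.sum` / `tl.min` / `tl.max` with an `axis` argument. A minimal single-program reduction in the same style as the reduce tests deleted further down in this diff (names illustrative).

import torch
import triton
import triton.language as tl

@triton.jit
def block_sum(X, Z, BLOCK: tl.constexpr):
    x = tl.load(X + tl.arange(0, BLOCK))
    tl.store(Z, tl.sum(x, axis=0))   # one program reduces one block

x = torch.randn(512, device='cuda')
z = torch.empty(1, device='cuda')
block_sum[(1,)](x, z, BLOCK=512)
assert torch.allclose(z, x.sum(), rtol=1e-2)
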
- )pbdoc"; -ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder) { - return reduce_impl(input, axis, builder, "min", ir::reduce_inst::FMIN, ir::reduce_inst::MIN); -}; - -/*---------------------------------------------- - definition of triton.max - ----------------------------------------------*/ -std::string max_docstr = R"pbdoc( - Returns the maximum value of `input`. - )pbdoc"; -ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder) { - return reduce_impl(input, axis, builder, "max", ir::reduce_inst::FMAX, ir::reduce_inst::MAX); -}; - -/*---------------------------------------------- - definition of triton.sum - ----------------------------------------------*/ -std::string sum_docstr = R"pbdoc( - Returns the sum of `input`. - )pbdoc"; -ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder) { - return reduce_impl(input, axis, builder, "sum", ir::reduce_inst::FADD, ir::reduce_inst::ADD); -}; - -/*---------------------------------------------- - definition of triton.atomic_cas - ----------------------------------------------*/ -std::string atomic_cas_docstr = R"pbdoc( - Atomic compare-and-swap. - )pbdoc"; -ir::value *atomic_cas(ir::value *ptr, ir::value *cmp, ir::value *val, ir::builder *builder) { - return builder->create_atomic_cas(ptr, cmp, val); -}; - -/*---------------------------------------------- - definition of triton.atomic_xchg - ----------------------------------------------*/ -std::string atomic_xchg_docstr = R"pbdoc( - Atomic exchange. - )pbdoc"; -ir::value *atomic_xchg(ir::value *ptr, ir::value *val, ir::builder *builder) { - return builder->create_atomic_exch(ptr, val); -}; - -/*---------------------------------------------- - debug barrier - ----------------------------------------------*/ -std::string debug_barrier_docstr = R"pbdoc( - Temporary hacky fixup for when the compiler forgets to insert sync barriers -)pbdoc"; -ir::value *debug_barrier(ir::builder *builder) { - return builder->create_barrier(); -} - -#define DEF_BINARY_OP(MOD, PY_NAME, C_FUNC, ...) \ - MOD.def(PY_NAME, binary_op(C_FUNC), (C_FUNC##_docstr + _builder_doc).c_str(), \ - ret::reference VA_ARGS(__VA_ARGS__), "builder"_a) - -template -std::function -binary_op(const FN &fn) { - auto ret = [&fn](ir::value *self, ir::value *other, ir::builder *builder) { - //std::tie(self, other) = try_broadcast(self, other, builder); - return fn(self, other, builder); - }; - return ret; -} - -/*---------------------------------------------- - definition of self + other - ----------------------------------------------*/ -std::string add_docstr = R"pbdoc( - Returns self + other, element-wise. -)pbdoc"; -ir::value *add(ir::value *self, ir::value *other, ir::builder *builder) { - ir::type *scalar_ty = self->get_type()->get_scalar_ty(); - // ptr + offset - if (scalar_ty->is_pointer_ty()) - return builder->create_gep(self, {other}); - // float + float - else if (scalar_ty->is_floating_point_ty()) - return builder->create_fadd(self, other); - // int + int - else if (scalar_ty->is_integer_ty()) - return builder->create_add(self, other); - throw_not_implemented("add"); -} - -/*---------------------------------------------- - definition of self - other - ----------------------------------------------*/ -std::string sub_docstr = R"pbdoc( - Returns self - other, element-wise. 
-)pbdoc"; -ir::value *sub(ir::value *self, ir::value *other, ir::builder *builder) { - ir::type *scalar_ty = self->get_type()->get_scalar_ty(); - // ptr + offset - if (scalar_ty->is_pointer_ty()) - return builder->create_gep(self, {other}); - // float + float - if (scalar_ty->is_floating_point_ty()) - return builder->create_fsub(self, other); - // int + int - else if (scalar_ty->is_integer_ty()) - return builder->create_sub(self, other); - throw_not_implemented("sub"); -} - -/*---------------------------------------------- - definition of self * other - ----------------------------------------------*/ -std::string mul_docstr = R"pbdoc( - Returns self * other, element-wise. -)pbdoc"; -ir::value *mul(ir::value *self, ir::value *other, ir::builder *builder) { - ir::type *scalar_ty = self->get_type()->get_scalar_ty(); - // float * float - if (scalar_ty->is_floating_point_ty()) - return builder->create_fmul(self, other); - // int * int - else if (scalar_ty->is_integer_ty()) - return builder->create_mul(self, other); - throw_not_implemented("mul"); -} - -/*---------------------------------------------- - definition of self > other - ----------------------------------------------*/ -std::string greater_than_docstr = R"pbdoc( - Returns self > other, element-wise. -)pbdoc"; -ir::value *greater_than(ir::value *self, ir::value *other, ir::builder *builder) { - ir::type *scalar_ty = self->get_type()->get_scalar_ty(); - // float > float - if (scalar_ty->is_floating_point_ty()) - return builder->create_fcmpOGT(self, other); - // int > int - else if (scalar_ty->is_integer_ty()) - return builder->create_icmpSGT(self, other); - throw_not_implemented("greater_than"); -} - -/*---------------------------------------------- - definition of self >= other - ----------------------------------------------*/ -std::string greater_equal_docstr = R"pbdoc( - Returns self >= other, element-wise. -)pbdoc"; -ir::value *greater_equal(ir::value *self, ir::value *other, ir::builder *builder) { - ir::type *scalar_ty = self->get_type()->get_scalar_ty(); - // float >= float - if (scalar_ty->is_floating_point_ty()) - return builder->create_fcmpOGE(self, other); - // int >= int - else if (scalar_ty->is_integer_ty()) - return builder->create_icmpSGE(self, other); - throw_not_implemented("greater_equal"); -} - -/*---------------------------------------------- - definition of self < other - ----------------------------------------------*/ -std::string less_than_docstr = R"pbdoc( - Returns self < other, element-wise. -)pbdoc"; -ir::value *less_than(ir::value *self, ir::value *other, ir::builder *builder) { - ir::type *scalar_ty = self->get_type()->get_scalar_ty(); - // float < float - if (scalar_ty->is_floating_point_ty()) - return builder->create_fcmpOLT(self, other); - // int < int - else if (scalar_ty->is_integer_ty()) - return builder->create_icmpSLT(self, other); - throw_not_implemented("less_than"); -} - -/*---------------------------------------------- - definition of self <= other - ----------------------------------------------*/ -std::string less_equal_docstr = R"pbdoc( - Returns self <= other, element-wise. 
-)pbdoc"; -ir::value *less_equal(ir::value *self, ir::value *other, ir::builder *builder) { - ir::type *scalar_ty = self->get_type()->get_scalar_ty(); - // float < float - if (scalar_ty->is_floating_point_ty()) - return builder->create_fcmpOLE(self, other); - // int < int - else if (scalar_ty->is_integer_ty()) - return builder->create_icmpSLE(self, other); - throw_not_implemented("less_equal"); -} - -/*---------------------------------------------- - definition of self == other - ----------------------------------------------*/ -std::string equal_docstr = R"pbdoc( - Returns self == other, element-wise. -)pbdoc"; -ir::value *equal(ir::value *self, ir::value *other, ir::builder *builder) { - ir::type *scalar_ty = self->get_type()->get_scalar_ty(); - // float == float - if (scalar_ty->is_floating_point_ty()) - return builder->create_fcmpOEQ(self, other); - // int == int - else if (scalar_ty->is_integer_ty()) - return builder->create_icmpEQ(self, other); - throw_not_implemented("equal"); -} - -/*---------------------------------------------- - definition of self / other - ----------------------------------------------*/ -std::string _div_docstr = R"pbdoc( - Returns self / other, element-wise. -)pbdoc"; -ir::value *_div(ir::value *self, ir::value *other, ir::builder *builder) { - ir::type *scalar_ty = self->get_type()->get_scalar_ty(); - // float / float - if (scalar_ty->is_floating_point_ty()) - return builder->create_fdiv(self, other); - // int / int - else if (scalar_ty->is_integer_ty()) - return builder->create_sdiv(self, other); - throw_not_implemented("div"); -} - -/*---------------------------------------------- - definition of self % other - ----------------------------------------------*/ -std::string mod_docstr = R"pbdoc( - Returns self % other, element-wise. -)pbdoc"; -ir::value *mod(ir::value *self, ir::value *other, ir::builder *builder) { - ir::type *scalar_ty = self->get_type()->get_scalar_ty(); - // float % int - if (scalar_ty->is_floating_point_ty()) - return builder->create_frem(self, other); - // int % int - else if (scalar_ty->is_integer_ty()) - return builder->create_srem(self, other); - throw_not_implemented("mod"); -} - -/*---------------------------------------------- - definition of self & other - ----------------------------------------------*/ -std::string _and_docstr = R"pbdoc( - Returns self & other, element-wise. -)pbdoc"; -ir::value *_and(ir::value *self, ir::value *other, ir::builder *builder) { - return builder->create_and(self, other); -} - -/*---------------------------------------------- - definition of minimum(self, other) - ----------------------------------------------*/ -std::string minimum_docstr = R"pbdoc( - Returns element-wise minimum of self and other -)pbdoc"; -ir::value *minimum(ir::value *self, ir::value *other, ir::builder *builder) { - return where(less_than(self, other, builder), self, other, builder); -} - -/*---------------------------------------------- - definition of self[slices] - ----------------------------------------------*/ - -enum slice_mode_t { - NEWAXIS, - ALL -}; - -std::string subscript_docstr = R"pbdoc( - returns self[slices]. - - :param slices: The slices to subscript with. - :type slices: List of `None` or `:` slices. 
-)pbdoc"; -ir::value *subscript(ir::value *self, std::vector slices, ir::builder *builder) { - std::vector modes; - for (py::object slice : slices) { - py::object none = py::none(); - py::object all = py::make_tuple(none, none, none); - if (slice.is(none)) - modes.push_back(NEWAXIS); - else if (all.attr("__eq__")(slice)) - modes.push_back(ALL); - else - throw std::runtime_error("slice must be None or (None, None, None)"); - } - - ir::type::block_shapes_t shape; - size_t curr = 0; - for (slice_mode_t mode : modes) { - if (mode == NEWAXIS) - shape.push_back(1); - else { - assert(mode == ALL); - shape.push_back(self->get_type()->get_block_shapes()[curr++]); - } - } - return builder->create_reshape(self, shape); -} diff --git a/python/src/main.cc b/python/src/main.cc index 48fc69e0de1e..d09679727030 100644 --- a/python/src/main.cc +++ b/python/src/main.cc @@ -8,8 +8,4 @@ void init_cutlass(pybind11::module &m); PYBIND11_MODULE(libtriton, m) { m.doc() = "Python bindings to the C++ Triton API"; init_triton(m); - init_superblocking(m); -#ifdef WITH_CUTLASS_BINDINGS - init_cutlass(m); -#endif } diff --git a/python/src/superblock.cc b/python/src/superblock.cc deleted file mode 100644 index 1420521a608b..000000000000 --- a/python/src/superblock.cc +++ /dev/null @@ -1,119 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#ifdef _OPENMP -#include -#endif - -// row-major 3d tensor -class tensor_3d { -public: - tensor_3d(int size_0, int size_1, int size_2, int *data = nullptr) : data_(size_0 * size_1 * size_2, 0) { - if (data) - std::copy(data, data + data_.size(), data_.begin()); - stride_0_ = size_1 * size_2; - stride_1_ = size_2; - stride_2_ = 1; - } - - int &operator()(int i, int j, int k) { - return data_[i * stride_0_ + j * stride_1_ + k]; - } - -private: - std::vector data_; - int stride_0_; - int stride_1_; - int stride_2_; -}; - -std::vector segment_blocks(tensor_3d &layout, tensor_3d &idx, int max_width, int H, int M, int N) { - tensor_3d tmp(H, M, N); - std::vector current(H, 0); - int num = 0; - std::vector lut(H * M * N * 4); - for (ssize_t h = 0; h < H; h++) { - // surrounding indices - std::vector ii_left(max_width, -1); - std::vector> ii_top(max_width, std::vector(N, -1)); - // start the dynamic programming algorithm - for (ssize_t m = 0; m < M; m++) { - for (ssize_t n = 0; n < N; n++) { - int v = layout(h, m, n); - if (v == 0) - continue; - int n_left = ii_left[max_width - 1]; - int m_top = ii_top[max_width - 1][n]; - int top = (m_top >= 0) ? tmp(h, m_top, n) : 0; - int left = (n_left >= 0) ? tmp(h, m, n_left) : 0; - int topleft = (m_top >= 0 && n_left >= 0) ? 
tmp(h, m_top, n_left) : 0; - int width = std::min(left, std::min(top, topleft)) + 1; - // reset width if blocks cannot be - // packed together (i.e., there's a 1 "in the middle") - for (int nn = n_left + 1; nn < n; nn++) - if (ii_top[max_width - 1][nn] > ii_top[max_width - 1][n]) - width = 1; - tmp(h, m, n) = width; - // update n_left ring buffer - for (int k = 0; k < max_width - 1; k++) - ii_left[k] = ii_left[k + 1]; - ii_left[max_width - 1] = n; - // update ii_top ring buffer - for (int k = 0; k < max_width - 1; k++) - ii_top[k][n] = ii_top[k + 1][n]; - ii_top[max_width - 1][n] = m; - // block is too small -- skip - if (width != max_width) - continue; - // retained blocks are set to zeros - for (ssize_t km = 0; km < max_width; km++) - for (ssize_t kn = 0; kn < max_width; kn++) { - int mm = ii_top[km][n]; - int nn = ii_left[kn]; - if (mm < 0 || nn < 0) - continue; - layout(h, mm, nn) = 0; - tmp(h, mm, nn) = 0; - lut[num++] = (int)h; - lut[num++] = (int)mm; - lut[num++] = (int)nn; - lut[num++] = idx(h, mm, nn); - } - } - } - } - lut.resize(num); - return lut; -} - -typedef std::pair> lut_t; - -std::vector superblock(uintptr_t LAYOUT, int H, int M, int N, int start_width) { - std::vector ret; - int current = 0; - tensor_3d layout(H, M, N, (int *)LAYOUT); - tensor_3d idx(H, M, N); - for (int64_t h = 0; h < H; h++) - for (int64_t m = 0; m < M; m++) - for (int64_t n = 0; n < N; n++) { - if (layout(h, m, n) == 0) - continue; - idx(h, m, n) = current++; - } - // create lut - for (int max_width = start_width; max_width > 0; max_width /= 2) { - auto lut = segment_blocks(layout, idx, max_width, H, M, N); - if (lut.size() == 0) - continue; - ret.push_back(std::make_pair(max_width, pybind11::array_t(lut.size(), lut.data()))); - } - return ret; -} - -void init_superblocking(pybind11::module &m) { - m.def("superblock", &superblock, "super-blocking for block-sparse matrix multiplication"); -} diff --git a/python/src/triton.cc b/python/src/triton.cc index ba008cac4ec0..d26c4faf6127 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -764,19 +764,26 @@ void init_triton_ir(py::module &&m) { .def("get_after", &mlir::scf::WhileOp::getAfter, ret::reference); py::class_(m, "CondtionOp"); - py::class_(m, "module") - .def("dump", &mlir::ModuleOp::dump) - .def("push_back", [](mlir::ModuleOp &self, mlir::FuncOp &funcOp) -> void { - self.push_back(funcOp); - }) - .def("has_function", [](mlir::ModuleOp &self, std::string &funcName) -> bool { - if (self.lookupSymbol(funcName)) - return true; - return false; - }) - .def("get_function", [](mlir::ModuleOp &self, std::string &funcName) -> mlir::FuncOp { - return self.lookupSymbol(funcName); - }) + // dynamic_attr is used to transfer ownership of the MLIR context to the module + py::class_(m, "module", py::dynamic_attr()) + .def("dump", &mlir::ModuleOp::dump) + .def("str", [](mlir::ModuleOp &self) -> std::string { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return str; + }) + .def("push_back", [](mlir::ModuleOp &self, mlir::FuncOp &funcOp) -> void { + self.push_back(funcOp); + }) + .def("has_function", [](mlir::ModuleOp &self, std::string &funcName) -> bool { + if (self.lookupSymbol(funcName)) + return true; + return false; + }) + .def("get_function", [](mlir::ModuleOp &self, std::string &funcName) -> mlir::FuncOp { + return self.lookupSymbol(funcName); + }) ; py::class_(m, "function") diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py deleted file mode 100644 index 
f30b203bb94e..000000000000 --- a/python/test/regression/test_performance.py +++ /dev/null @@ -1,164 +0,0 @@ -import subprocess -import sys - -import pytest -import torch - -import triton -import triton.language as tl -from triton.testing import get_dram_gbps, get_max_tensorcore_tflops - -DEVICE_NAME = 'v100' - -####################### -# Utilities -####################### - - -def nvsmi(attrs): - attrs = ','.join(attrs) - cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits'] - out = subprocess.check_output(cmd) - ret = out.decode(sys.stdout.encoding).split(',') - ret = [int(x) for x in ret] - return ret - - -####################### -# Matrix Multiplication -####################### - -sm_clocks = {'v100': 1350, 'a100': 1350} -mem_clocks = {'v100': 877, 'a100': 1215} - -matmul_data = { - 'v100': { - # square - (256, 256, 256): {'float16': 0.027}, - (512, 512, 512): {'float16': 0.158}, - (1024, 1024, 1024): {'float16': 0.466}, - (2048, 2048, 2048): {'float16': 0.695}, - (4096, 4096, 4096): {'float16': 0.831}, - (8192, 8192, 8192): {'float16': 0.849}, - # tall-skinny - (16, 1024, 1024): {'float16': 0.0128}, - (16, 4096, 4096): {'float16': 0.0883}, - (16, 8192, 8192): {'float16': 0.101}, - (64, 1024, 1024): {'float16': 0.073}, - (64, 4096, 4096): {'float16': 0.270}, - (64, 8192, 8192): {'float16': 0.459}, - (1024, 64, 1024): {'float16': 0.0692}, - (4096, 64, 4096): {'float16': 0.264}, - (8192, 64, 8192): {'float16': 0.452}, - }, - 'a100': { - (256, 256, 256): {'float16': 0.010, 'float32': 0.0214, 'int8': 0.006}, - (512, 512, 512): {'float16': 0.061, 'float32': 0.109, 'int8': 0.030}, - (1024, 1024, 1024): {'float16': 0.287, 'float32': 0.331, 'int8': 0.169}, - (2048, 2048, 2048): {'float16': 0.604, 'float32': 0.599, 'int8': 0.385}, - (4096, 4096, 4096): {'float16': 0.842, 'float32': 0.862, 'int8': 0.711}, - (8192, 8192, 8192): {'float16': 0.896, 'float32': 0.932, 'int8': 0.860}, - # tall-skinny - (16, 1024, 1024): {'float16': 0.0077, 'float32': 0.0127, 'int8': 0.005}, - (16, 4096, 4096): {'float16': 0.0363, 'float32': 0.0457, 'int8': 0.0259}, - (16, 8192, 8192): {'float16': 0.0564, 'float32': 0.0648, 'int8': 0.0431}, - (64, 1024, 1024): {'float16': 0.0271, 'float32': 0.0509, 'int8': 0.0169}, - (64, 4096, 4096): {'float16': 0.141, 'float32': 0.162, 'int8': 0.097}, - (64, 8192, 8192): {'float16': 0.244, 'float32': 0.257, 'int8': 0.174}, - (1024, 64, 1024): {'float16': 0.0263, 'float32': 0.0458, 'int8': 0.017}, - (4096, 64, 4096): {'float16': 0.135, 'float32': 0.177, 'int8': 0.102}, - (8192, 64, 8192): {'float16': 0.216, 'float32': 0.230, 'int8': 0.177}, - } - # # deep reductions - # (64 , 64 , 16384) : {'a100': 0.}, - # (64 , 64 , 65536) : {'a100': 0.}, - # (256 , 256 , 8192 ) : {'a100': 0.}, - # (256 , 256 , 32768) : {'a100': 0.}, -} - - -@pytest.mark.parametrize('M, N, K, dtype_str', - [(M, N, K, dtype_str) - for M, N, K in matmul_data[DEVICE_NAME].keys() - for dtype_str in ['float16']]) -def test_matmul(M, N, K, dtype_str): - if dtype_str in ['float32', 'int8'] and DEVICE_NAME != 'a100': - pytest.skip('Only test float32 & int8 on a100') - dtype = {'float16': torch.float16, 'float32': torch.float32, 'int8': torch.int8}[dtype_str] - torch.manual_seed(0) - ref_gpu_util = matmul_data[DEVICE_NAME][(M, N, K)][dtype_str] - cur_sm_clock = nvsmi(['clocks.current.sm'])[0] - ref_sm_clock = sm_clocks[DEVICE_NAME] - max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3) - assert abs(cur_sm_clock - ref_sm_clock) < 10, f'GPU SMs must run at 
{ref_sm_clock} MHz' - if dtype == torch.int8: - a = torch.randint(-128, 127, (M, K), dtype=dtype, device='cuda') - b = torch.randint(-128, 127, (N, K), dtype=dtype, device='cuda') - b = b.t() # only test row-col layout - else: - a = torch.randn((M, K), dtype=dtype, device='cuda') - b = torch.randn((K, N), dtype=dtype, device='cuda') - fn = lambda: triton.ops.matmul(a, b) - ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=1000) - cur_gpu_perf = 2. * M * N * K / ms * 1e-9 - cur_gpu_util = cur_gpu_perf / max_gpu_perf - triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2) - - -####################### -# Element-Wise -####################### - - -@triton.jit -def _add(x_ptr, y_ptr, output_ptr, n_elements, - BLOCK_SIZE: tl.constexpr): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - x = tl.load(x_ptr + offsets, mask=mask) - y = tl.load(y_ptr + offsets, mask=mask) - output = x + y - tl.store(output_ptr + offsets, output, mask=mask) - - -elementwise_data = { - 'v100': { - 1024 * 16: 0.0219, - 1024 * 64: 0.0791, - 1024 * 256: 0.243, - 1024 * 1024: 0.534, - 1024 * 4096: 0.796, - 1024 * 16384: 0.905, - 1024 * 65536: 0.939, - }, - 'a100': { - 1024 * 16: 0.008, - 1024 * 64: 0.034, - 1024 * 256: 0.114, - 1024 * 1024: 0.315, - 1024 * 4096: 0.580, - 1024 * 16384: 0.782, - 1024 * 65536: 0.850, - } -} - - -@pytest.mark.parametrize('N', elementwise_data[DEVICE_NAME].keys()) -def test_elementwise(N): - torch.manual_seed(0) - ref_gpu_util = elementwise_data[DEVICE_NAME][N] - cur_mem_clock = nvsmi(['clocks.current.memory'])[0] - ref_mem_clock = mem_clocks[DEVICE_NAME] - max_gpu_perf = get_dram_gbps() - assert abs(cur_mem_clock - ref_mem_clock) < 10, f'GPU memmory must run at {ref_mem_clock} MHz' - z = torch.empty((N, ), dtype=torch.float16, device='cuda') - x = torch.randn_like(z) - y = torch.randn_like(z) - grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), ) - fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024) - ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=250) - cur_gpu_perf = 3. * N * z.element_size() / ms * 1e-6 - cur_gpu_util = cur_gpu_perf / max_gpu_perf - triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py deleted file mode 100644 index 3561f7af4b01..000000000000 --- a/python/test/unit/language/test_core.py +++ /dev/null @@ -1,1010 +0,0 @@ -# flake8: noqa: F821,F841 -import itertools -import re -from typing import Optional, Union - -import numpy as np -import pytest -import torch -from numpy.random import RandomState - -import triton -import triton._C.libtriton.triton as _triton -import triton.language as tl -from triton.code_gen import TensorWrapper, reinterpret - -int_dtypes = ['int8', 'int16', 'int32', 'int64'] -uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64'] -float_dtypes = ['float16', 'float32', 'float64'] -dtypes = int_dtypes + uint_dtypes + float_dtypes - - -def _bitwidth(dtype: str) -> int: - # ex.: "int64" -> 64 - return int(re.search(r'(\d+)$', dtype).group(1)) - - -def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, high=None): - """ - Override `rs` if you're calling this function twice and don't want the same - result for both calls. 
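
For reference, the utilization figures asserted in the removed performance tests come from simple roofline arithmetic: a matmul performs 2*M*N*K flops, and an element-wise add moves 3*N*element_size bytes (two reads plus one write). The numbers below are illustrative, not measurements, and the 312 TFLOP/s figure is only an assumed fp16 tensor-core peak.

M = N = K = 4096
ms = 1.0                                          # hypothetical measured runtime (milliseconds)
tflops = 2.0 * M * N * K / (ms * 1e-3) / 1e12     # ~137.4 TFLOP/s achieved
util = tflops / 312.0                             # fraction of an assumed 312 TFLOP/s peak

n = 1024 * 1024
gbps = 3.0 * n * 2 / (ms * 1e-3) / 1e9            # float16 add: ~6.3 GB/s at 1 ms
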
- """ - if isinstance(shape, int): - shape = (shape, ) - if rs is None: - rs = RandomState(seed=17) - dtype = getattr(np, dtype_str) - if dtype_str in int_dtypes + uint_dtypes: - iinfo = np.iinfo(getattr(np, dtype_str)) - low = iinfo.min if low is None else max(low, iinfo.min) - high = iinfo.max if high is None else min(high, iinfo.max) - x = rs.randint(low, high, shape, dtype=dtype) - x[x == 0] = 1 # Hack. Never return zero so tests of division don't error out. - return x - elif dtype_str in float_dtypes: - return rs.normal(0, 1, shape).astype(dtype) - else: - raise RuntimeError(f'Unknown dtype {dtype_str}') - - -def to_triton(x: np.ndarray, device='cuda') -> Union[TensorWrapper, torch.Tensor]: - t = x.dtype.name - if t in uint_dtypes: - signed_type_name = t.lstrip('u') # e.g. "uint16" -> "int16" - x_signed = x.astype(getattr(np, signed_type_name)) - return reinterpret(torch.tensor(x_signed, device=device), getattr(tl, t)) - else: - return torch.tensor(x, device=device) - - -def torch_dtype_name(dtype) -> str: - if isinstance(dtype, triton.language.dtype): - return dtype.name - elif isinstance(dtype, torch.dtype): - # 'torch.int64' -> 'int64' - m = re.match(r'^torch\.(\w+)$', str(dtype)) - return m.group(1) - else: - raise TypeError(f'not a triton or torch dtype: {type(dtype)}') - - -def to_numpy(x): - if isinstance(x, TensorWrapper): - return x.base.cpu().numpy().astype(getattr(np, torch_dtype_name(x.dtype))) - elif isinstance(x, torch.Tensor): - return x.cpu().numpy() - else: - raise ValueError(f"Not a triton-compatible tensor: {x}") - - -def patch_kernel(template, to_replace): - kernel = triton.JITFunction(template.fn) - for key, value in to_replace.items(): - kernel.src = kernel.src.replace(key, value) - return kernel - - -@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes]) -def test_empty_kernel(dtype_x, device='cuda'): - SIZE = 128 - - @triton.jit - def kernel(X, SIZE: tl.constexpr): - pass - x = to_triton(numpy_random(SIZE, dtype_str=dtype_x), device=device) - kernel[(1, )](x, SIZE=SIZE, num_warps=4) - - -# generic test functions -def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'): - SIZE = 128 - # define the kernel / launch-grid - - @triton.jit - def kernel(Z, X, SIZE: tl.constexpr): - off = tl.arange(0, SIZE) - x = tl.load(X + off) - z = GENERATE_TEST_HERE - tl.store(Z + off, z) - - kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr}) - # inputs - x = numpy_random(SIZE, dtype_str=dtype_x) - if 'log' in expr: - x = np.abs(x) + 0.01 - # reference result - z_ref = eval(expr if numpy_expr is None else numpy_expr) - # triton result - x_tri = to_triton(x, device=device) - z_tri = to_triton(np.empty_like(z_ref), device=device) - kernel[(1, )](z_tri, x_tri, SIZE=SIZE, num_warps=4) - # compare - np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) - - -def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]: - """ - Given two dtype strings, returns the numpy dtype Triton thinks binary - operations on the two types should return. Returns None if the return value - matches numpy. This is generally needed because Triton and pytorch return - narrower floating point types than numpy in mixed operations, and because - Triton follows C/C++ semantics around mixed signed/unsigned operations, and - numpy/pytorch do not. 
- """ - overrides = { - ('float16', 'int16'): np.float16, - ('float16', 'int32'): np.float16, - ('float16', 'int64'): np.float16, - ('float16', 'uint16'): np.float16, - ('float16', 'uint32'): np.float16, - ('float16', 'uint64'): np.float16, - ('int8', 'uint8'): np.uint8, - ('int8', 'uint16'): np.uint16, - ('int8', 'uint32'): np.uint32, - ('int8', 'uint64'): np.uint64, - ('int16', 'uint16'): np.uint16, - ('int16', 'uint32'): np.uint32, - ('int16', 'uint64'): np.uint64, - ('int32', 'uint32'): np.uint32, - ('int32', 'uint64'): np.uint64, - ('int64', 'uint64'): np.uint64, - } - key = (a, b) if a < b else (b, a) - return overrides.get(key) - - -def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', y_low=None, y_high=None): - SIZE = 128 - # define the kernel / launch-grid - - @triton.jit - def kernel(Z, X, Y, SIZE: tl.constexpr): - off = tl.arange(0, SIZE) - x = tl.load(X + off) - y = tl.load(Y + off) - z = GENERATE_TEST_HERE - tl.store(Z + off, z) - - kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr}) - # inputs - rs = RandomState(17) - x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs) - y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs, low=y_low, high=y_high) - if mode_x == 'nan': - x[:] = float('nan') - if mode_y == 'nan': - y[:] = float('nan') - # reference result - z_ref = eval(expr if numpy_expr is None else numpy_expr) - dtype_z = _binary_op_dtype_override(dtype_x, dtype_y) - if dtype_z is not None: - z_ref = z_ref.astype(dtype_z) - # triton result - x_tri = to_triton(x, device=device) - y_tri = to_triton(y, device=device) - z_tri = to_triton(np.empty(SIZE, dtype=z_ref.dtype), device=device) - kernel[(1, )](z_tri, x_tri, y_tri, SIZE=SIZE, num_warps=4) - np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=expr, rtol=0.01) - - -def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool: - # The result of x % y is ill-conditioned if x % y is much smaller than x. - # pytorch/CUDA has slightly different (probably better) rounding on - # remainders than stock LLVM. We currently don't expect to match it - # bit-for-bit. - return (dtype_x, dtype_y) in [ - ('int32', 'float16'), - ('int32', 'float32'), - ('int64', 'float16'), - ('int64', 'float32'), - ('int64', 'float64'), - ('uint16', 'float16'), - ('uint16', 'float32'), - ('uint32', 'float16'), - ('uint32', 'float32'), - ('uint64', 'float16'), - ('uint64', 'float32'), - ('uint64', 'float64'), - ] - -# --------------- -# test binary ops -# --------------- - - -@pytest.mark.parametrize("dtype_x, dtype_y, op", [ - (dtype_x, dtype_y, op) - for op in ['+', '-', '*', '/', '%'] - for dtype_x in dtypes - for dtype_y in dtypes -]) -def test_bin_op(dtype_x, dtype_y, op, device='cuda'): - expr = f' x {op} y' - if op == '%' and dtype_x in int_dtypes + uint_dtypes and dtype_y in int_dtypes + uint_dtypes: - # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders. - numpy_expr = 'np.fmod(x, y)' - elif op in ('/', '%') and dtype_x in ('int16', 'float16') and dtype_y in ('int16', 'float16'): - # Triton promotes 16-bit floating-point / and % to 32-bit because there - # are no native div or FRem operations on float16. Since we have to - # convert anyway, we may as well take the accuracy bump. 
- numpy_expr = f'x.astype(np.float32) {op} y.astype(np.float32)' - elif (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): - numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' - elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): - numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})' - else: - numpy_expr = None - if op == '%' and _mod_operation_ill_conditioned(dtype_x, dtype_y): - with pytest.raises(AssertionError, match='Not equal to tolerance'): - _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device) - elif (op in ('%', '/') and - ((dtype_x in int_dtypes and dtype_y in uint_dtypes) or - (dtype_x in uint_dtypes and dtype_y in int_dtypes))): - with pytest.raises(triton.code_gen.CompilationError) as exc_info: - _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device) - assert re.match('Cannot use .* because they have different signedness', str(exc_info.value.__cause__)) - else: - _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device) - - -@pytest.mark.parametrize("dtype_x, dtype_y", - [(dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes] + - [(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes] - ) -def test_floordiv(dtype_x, dtype_y, device='cuda'): - # Triton has IEEE, not numpy/torch, semantics for %, and those carry - # through to //, so we have to use a nonstandard expression to get a - # reference result for //. - expr = 'x // y' - numpy_expr = '((x - np.fmod(x, y)) / y)' - _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device) - - -# --------------- -# test bitwise ops -# --------------- -@pytest.mark.parametrize("dtype_x, dtype_y, op", [ - (dtype_x, dtype_y, op) - for op in ['&', '|', '^'] - for dtype_x in dtypes - for dtype_y in dtypes -]) -def test_bitwise_op(dtype_x, dtype_y, op, device='cuda'): - expr = f'x {op} y' - if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): - numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' - elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): - numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})' - else: - numpy_expr = None - if 'float' in dtype_x + dtype_y: - with pytest.raises(triton.code_gen.CompilationError) as exc_info: - _test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device) - # The CompilationError must have been caused by a C++ exception with this text. 
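
The `//` reference expression above, ((x - np.fmod(x, y)) / y), is the identity x == y*(x // y) + (x % y) solved for x // y under fmod-style remainders, so it truncates toward zero instead of flooring like numpy's own `//`.

import numpy as np

x, y = np.int32(-7), np.int32(3)
print((x - np.fmod(x, y)) / y)   # -2.0  (truncating division, the semantics being tested)
print(x // y)                    # -3    (numpy floors toward -inf)
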
- assert re.match('invalid operands of type', str(exc_info.value.__cause__)) - else: - _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device) - - -@pytest.mark.parametrize("dtype_x, dtype_y, op", [ - (dtype_x, dtype_y, op) - for op in ['<<', '>>'] - for dtype_x in int_dtypes + uint_dtypes - for dtype_y in int_dtypes + uint_dtypes -]) -def test_shift_op(dtype_x, dtype_y, op, device='cuda'): - expr = f'x {op} y' - bw = max(_bitwidth(dtype_x), _bitwidth(dtype_y)) - dtype_z = f'uint{bw}' - numpy_expr = f'x.astype(np.{dtype_z}) {op} y.astype(np.{dtype_z})' - _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, y_low=0, y_high=65) - - -# --------------- -# test compare ops -# --------------- -ops = ['==', '!=', '>', '<', '>=', '<='] - - -@pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y", - # real - [ - (dtype_x, dtype_y, op, 'real', 'real') - for op in ops - for dtype_x in dtypes - for dtype_y in dtypes - ] + - # NaNs - [('float32', 'float32', op, mode_x, mode_y) - for op in ops - for mode_x, mode_y in [('nan', 'real'), - ('real', 'nan'), - ('nan', 'nan')] - - ]) -def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, device='cuda'): - expr = f'x {op} y' - if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): - numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' - elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): - numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})' - else: - numpy_expr = None - _test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device) - - -# --------------- -# test unary ops -# --------------- -@pytest.mark.parametrize("dtype_x, expr", [ - (dtype_x, ' -x') for dtype_x in dtypes -] + [ - (dtype_x, ' ~x') for dtype_x in int_dtypes -]) -def test_unary_op(dtype_x, expr, device='cuda'): - _test_unary(dtype_x, expr, device=device) - -# ---------------- -# test math ops -# ---------------- -# @pytest.mark.paramterize("expr", [ -# 'exp', 'log', 'cos', 'sin' -# ]) - - -@pytest.mark.parametrize("expr", [ - 'exp', 'log', 'cos', 'sin' -]) -def test_math_op(expr, device='cuda'): - _test_unary('float32', f'tl.{expr}(x)', f'np.{expr}(x) ', device=device) - - -# ---------------- -# test indexing -# ---------------- - - -def make_ptr_str(name, shape): - rank = len(shape) - offsets = [] - stride = 1 - for i in reversed(range(rank)): - idx = ', '.join([':' if ii == i else 'None' for ii in range(rank)]) - offsets += [f'tl.arange(0, {shape[i]})[{idx}]*{stride}'] - stride *= shape[i] - return f"{name} + {' + '.join(offsets)}" - - -@pytest.mark.parametrize("expr, dtype_str", [ - (f'x[{s}]', d) - for s in ['None, :', ':, None', 'None, :, :', ':, :, None'] - for d in ['int32', 'uint32', 'uint16'] -]) -def test_index1d(expr, dtype_str, device='cuda'): - rank_x = expr.count(':') - rank_y = expr.count(',') + 1 - shape_x = [32 for _ in range(rank_x)] - shape_z = [32 for _ in range(rank_y)] - - # Triton kernel - @triton.jit - def kernel(Z, X, SIZE: tl.constexpr): - m = tl.arange(0, SIZE) - n = tl.arange(0, SIZE) - x = tl.load(X_PTR_EXPR) - z = GENERATE_TEST_HERE - tl.store(Z_PTR_EXPR, z) - - to_replace = { - 'X_PTR_EXPR': make_ptr_str('X', shape_x), - 'Z_PTR_EXPR': make_ptr_str('Z', shape_z), - 'GENERATE_TEST_HERE': expr, - } - kernel = patch_kernel(kernel, to_replace) - - # torch result - x = numpy_random(shape_x, dtype_str=dtype_str) - y = np.zeros(shape_z, dtype=getattr(np, dtype_str)) - z_ref = eval(expr) 
+ y - # triton result - z_tri = to_triton(np.empty_like(z_ref), device=device) - x_tri = to_triton(x) - kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0]) - # compare - assert (z_ref == to_numpy(z_tri)).all() - - -# --------------- -# test tuples -# --------------- - - -@triton.jit -def fn(a, b): - return a + b, \ - a - b, \ - a * b - - -def test_tuples(): - device = 'cuda' - - @triton.jit - def with_fn(X, Y, A, B, C): - x = tl.load(X) - y = tl.load(Y) - a, b, c = fn(x, y) - tl.store(A, a) - tl.store(B, b) - tl.store(C, c) - - @triton.jit - def without_fn(X, Y, A, B, C): - x = tl.load(X) - y = tl.load(Y) - a, b, c = x + y, x - y, x * y - tl.store(A, a) - tl.store(B, b) - tl.store(C, c) - - x = torch.tensor([1.3], device=device, dtype=torch.float32) - y = torch.tensor([1.9], device=device, dtype=torch.float32) - a_tri = torch.tensor([0], device=device, dtype=torch.float32) - b_tri = torch.tensor([0], device=device, dtype=torch.float32) - c_tri = torch.tensor([0], device=device, dtype=torch.float32) - for kernel in [with_fn, without_fn]: - kernel[(1, )](x, y, a_tri, b_tri, c_tri, num_warps=1) - a_ref, b_ref, c_ref = x + y, x - y, x * y - assert a_tri == a_ref - assert b_tri == b_ref - assert c_tri == c_ref - - -# --------------- -# test atomics -# --------------- -@pytest.mark.parametrize("op, dtype_x_str, mode", itertools.chain.from_iterable([ - [ - ('add', 'float16', mode), - ('add', 'uint32', mode), ('add', 'int32', mode), ('add', 'float32', mode), - ('max', 'uint32', mode), ('max', 'int32', mode), ('max', 'float32', mode), - ('min', 'uint32', mode), ('min', 'int32', mode), ('min', 'float32', mode), - ] - for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']])) -def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'): - n_programs = 5 - - # triton kernel - @triton.jit - def kernel(X, Z): - pid = tl.program_id(0) - x = tl.load(X + pid) - old = GENERATE_TEST_HERE - - kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x)'}) - numpy_op = {'add': np.sum, 'max': np.max, 'min': np.min}[op] - max_neutral = float('-inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).min - min_neutral = float('inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).max - neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op] - - # triton result - rs = RandomState(17) - x = numpy_random((n_programs, ), dtype_str=dtype_x_str, rs=rs) - if mode == 'all_neg': - x = -np.abs(x) - if mode == 'all_pos': - x = np.abs(x) - if mode == 'min_neg': - idx = rs.randint(n_programs, size=(1, )).item() - x[idx] = -np.max(np.abs(x)) - 1 - if mode == 'max_pos': - idx = rs.randint(n_programs, size=(1, )).item() - x[idx] = np.max(np.abs(x)) + 1 - x_tri = to_triton(x, device=device) - - z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device) - kernel[(n_programs, )](x_tri, z_tri) - # torch result - z_ref = numpy_op(x).astype(getattr(np, dtype_x_str)) - # compare - exact = op not in ['add'] - if exact: - assert z_ref.item() == to_numpy(z_tri).item() - else: - np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) - - -# --------------- -# test cast -# --------------- -@pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [ - (dtype_x, dtype_z, False) - for dtype_x in dtypes - for dtype_z in dtypes -] + [ - ('float32', 'bfloat16', False), - ('bfloat16', 'float32', False), - ('float32', 'int32', True), -] + [ - (f'uint{x}', f'int{x}', True) for x in [8, 16, 32, 64] -] + [ - (f'int{x}', f'uint{x}', True) for x in [8, 
16, 32, 64] -]) -def test_cast(dtype_x, dtype_z, bitcast, device='cuda'): - # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints. - x0 = 43 if dtype_x in int_dtypes else 43.5 - if dtype_x.startswith('bfloat'): - x_tri = torch.tensor([x0], dtype=getattr(torch, dtype_x), device=device) - else: - x = np.array([x0], dtype=getattr(np, dtype_x)) - x_tri = to_triton(x) - - # triton kernel - @triton.jit - def kernel(X, Z, BITCAST: tl.constexpr): - x = tl.load(X) - z = x.to(Z.dtype.element_ty, bitcast=BITCAST) - tl.store(Z, z) - - # triton result - if dtype_z.startswith('bfloat'): - z_tri = torch.empty((1,), dtype=getattr(torch, dtype_z), device=device) - else: - z_tri = to_triton(np.empty((1, ), dtype=getattr(np, dtype_z)), device=device) - kernel[(1, )](x_tri, z_tri, BITCAST=bitcast) - # torch result - if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat'): - assert bitcast is False - z_ref = x_tri.to(z_tri.dtype) - assert z_tri == z_ref - else: - if bitcast: - z_ref = x.view(getattr(np, dtype_z)) - else: - z_ref = x.astype(getattr(np, dtype_z)) - assert to_numpy(z_tri) == z_ref - - -def test_f8_f16_roundtrip(): - """Tests that converting an f8 to f16 and back to f8 doesn't change its value""" - @triton.jit - def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): - offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - input = tl.load(input_ptr + offsets, mask=mask) - output = input - tl.store(output_ptr + offsets, output, mask=mask) - - f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda') - f8 = triton.reinterpret(f8_tensor, tl.float8) - n_elements = f8_tensor.numel() - f16 = torch.empty_like(f8_tensor, dtype=torch.float16) - grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) - copy_kernel[grid](f8, f16, n_elements, BLOCK_SIZE=1024) - - f8_output_tensor = torch.empty_like(f16, dtype=torch.int8) - f8_output = triton.reinterpret(f8_output_tensor, tl.float8) - copy_kernel[grid](f16, f8_output, n_elements, BLOCK_SIZE=1024) - - assert torch.all(f8_tensor == f8_output_tensor) - - -def test_f16_to_f8_rounding(): - """Takes all float16s, converts them to float8 and back to float16. Checks that the absolute - error is the minimum over all float8. 
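
The rounding test below needs every representable float16; since this era's torch cannot view int16 storage as float16, the removed test enumerates all 2^16 bit patterns in numpy and reinterprets them. The same trick in isolation:

import numpy as np

all_f16 = np.arange(-2**15, 2**15, dtype=np.int16).view(np.float16)
print(all_f16.size)                 # 65536 bit patterns
print(np.isfinite(all_f16).sum())   # 63488 finite values (the rest are NaN/inf)
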
- - Or the same explanation a bit mathier: - for all f16 |f16 - fromf8(tof8(f16))| == min over all f8 |f16 - fromf8(f8)|""" - @triton.jit - def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): - offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - input = tl.load(input_ptr + offsets, mask=mask) - output = input - tl.store(output_ptr + offsets, output, mask=mask) - - # torch.view with a dtype isn't supported in triton's torch yet so use numpy's view - f16_input_np = ( - np.array( - range(-int(2 ** (16 - 1)), int(2 ** (16 - 1))), dtype=np.int16, - ) - .view(np.float16) - ) - f16_input = torch.tensor(f16_input_np, dtype=torch.float16, device='cuda') - n_elements = f16_input.numel() - f8_output_tensor = torch.empty_like(f16_input, dtype=torch.int8) - f8_output = triton.reinterpret(f8_output_tensor, tl.float8) - grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) - copy_kernel[grid](f16_input, f8_output, n_elements, BLOCK_SIZE=1024) - - f16_output = torch.empty_like(f16_input, dtype=torch.float16) - copy_kernel[grid](f8_output, f16_output, n_elements, BLOCK_SIZE=1024) - - abs_error = torch.abs(f16_input - f16_output) - - all_f8_vals_tensor = torch.tensor(range(2 ** 8), dtype=torch.uint8, device='cuda') - all_f8_vals = triton.reinterpret(all_f8_vals_tensor, tl.float8) - all_f8_vals_in_f16 = torch.empty_like(all_f8_vals_tensor, dtype=torch.float16) - copy_kernel[grid](all_f8_vals, all_f8_vals_in_f16, n_elements=256, BLOCK_SIZE=1024) - - all_finite_f8_vals_in_f16 = all_f8_vals_in_f16[ - torch.isfinite(all_f8_vals_in_f16) - ] - - min_error = torch.min( - torch.abs( - f16_input.reshape((-1, 1)) - - all_finite_f8_vals_in_f16.reshape((1, -1)) - ), - dim=1, - )[0] - # 1.9375 is float8 max - mismatch = torch.logical_and( - abs_error != min_error, torch.logical_and(torch.isfinite(f16_input), torch.abs(f16_input) < 1.9375) - ) - assert torch.all( - torch.logical_not(mismatch) - ), f"{f16_input[mismatch]=} {f16_output[mismatch]=} {abs_error[mismatch]=} {min_error[mismatch]=}" - - -# --------------- -# test reduce -# --------------- - - -@pytest.mark.parametrize("dtype_str, shape", - [(dtype, shape) - for dtype in dtypes - for shape in [128, 512]]) -def test_reduce1d(dtype_str, shape, device='cuda'): - - # triton kernel - @triton.jit - def kernel(X, Z, BLOCK: tl.constexpr): - x = tl.load(X + tl.arange(0, BLOCK)) - tl.store(Z, tl.sum(x, axis=0)) - - rs = RandomState(17) - x = numpy_random((shape,), dtype_str=dtype_str, rs=rs) - # numpy result - z_ref = np.sum(x).astype(getattr(np, dtype_str)) - # triton result - x_tri = to_triton(x, device=device) - z_tri = to_triton(numpy_random((1,), dtype_str=dtype_str, rs=rs), device=device) - kernel[(1,)](x_tri, z_tri, BLOCK=shape) - # compare - np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) - - -@pytest.mark.parametrize("dtype_str, shape, axis", [ - (dtype, (1, 1024), 1) for dtype in ['float32', 'uint32'] -]) -def test_reduce2d(dtype_str, shape, axis, device='cuda'): - # triton kernel - @triton.jit - def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr): - range_m = tl.arange(0, BLOCK_M) - range_n = tl.arange(0, BLOCK_N) - x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :]) - z = tl.sum(x, axis=AXIS) - tl.store(Z + range_m, z) - # input - x = numpy_random(shape, dtype_str=dtype_str) - # triton result - x_tri = to_triton(x) - z_tri = to_triton(np.empty((shape[0],), dtype=getattr(np, dtype_str)), device=device) - 
kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis) - # numpy reference result - z_ref = np.sum(x, axis=axis).astype(x.dtype) - # compare - np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) - -# --------------- -# test permute -# --------------- - - -@pytest.mark.parametrize("dtype_str, shape, perm", - [(dtype, shape, perm) - for dtype in ['float32'] - for shape in [(128, 128)] - for perm in [(1, 0)]]) -def test_permute(dtype_str, shape, perm, device='cuda'): - - # triton kernel - @triton.jit - def kernel(X, stride_xm, stride_xn, - Z, stride_zm, stride_zn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): - off_m = tl.arange(0, BLOCK_M) - off_n = tl.arange(0, BLOCK_N) - Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn - Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn - tl.store(Zs, tl.load(Xs)) - # input - x = numpy_random(shape, dtype_str=dtype_str) - # triton result - z_tri = to_triton(np.empty_like(x), device=device) - x_tri = to_triton(x, device=device) - pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), - z_tri, z_tri.stride(1), z_tri.stride(0), - BLOCK_M=shape[0], BLOCK_N=shape[1]) - # torch result - z_ref = x.transpose(*perm) - # compare - triton.testing.assert_almost_equal(z_tri, z_ref) - # parse ptx to make sure ld/st are vectorized - ptx = pgm.asm['ptx'] - assert 'ld.global.v4' in ptx - assert 'st.global.v4' in ptx - -# --------------- -# test dot -# --------------- - - -@pytest.mark.parametrize("epilogue, allow_tf32, dtype", - [(epilogue, allow_tf32, dtype) - for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols'] - for allow_tf32 in [True, False] - for dtype in ['float32', 'int8'] - if not (allow_tf32 and (dtype == 'int8'))]) -def test_dot(epilogue, allow_tf32, dtype, device='cuda'): - cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device()) - if cc < 80: - if dtype == 'int8': - pytest.skip("Only test int8 on devices with sm >= 80") - elif dtype == 'float32' and allow_tf32: - pytest.skip("Only test tf32 on devices with sm >= 80") - - # triton kernel - @triton.jit - def kernel(X, stride_xm, stride_xk, - Y, stride_yk, stride_yn, - Z, stride_zm, stride_zn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, - ALLOW_TF32: tl.constexpr): - off_m = tl.arange(0, BLOCK_M) - off_n = tl.arange(0, BLOCK_N) - off_k = tl.arange(0, BLOCK_K) - Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk - Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn - Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn - z = tl.dot(tl.load(Xs), tl.load(Ys), allow_tf32=ALLOW_TF32) - if ADD_MATRIX: - z += tl.load(Zs) - if ADD_ROWS: - ZRs = Z + off_m * stride_zm - z += tl.load(ZRs)[:, None] - if ADD_COLS: - ZCs = Z + off_n * stride_zn - z += tl.load(ZCs)[None, :] - tl.store(Zs, z) - # input - M, N, K = 64, 64, 32 - rs = RandomState(17) - x = numpy_random((M, K), dtype_str=dtype, rs=rs) - y = numpy_random((K, N), dtype_str=dtype, rs=rs) - if allow_tf32: - x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32') - y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32') - x_tri = to_triton(x, device=device) - y_tri = to_triton(y, device=device) - # triton result - z = numpy_random((M, N), dtype_str=dtype, rs=rs) - z_tri = to_triton(z, device=device) - if epilogue == 'trans': - z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1]) - pgm = 
kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), - y_tri, y_tri.stride(0), y_tri.stride(1), - z_tri, z_tri.stride(0), z_tri.stride(1), - BLOCK_M=M, BLOCK_K=K, BLOCK_N=N, - ADD_MATRIX=epilogue == 'add-matrix', - ADD_ROWS=epilogue == 'add-rows', - ADD_COLS=epilogue == 'add-cols', - ALLOW_TF32=allow_tf32) - # torch result - z_ref = np.matmul(x, y) - if epilogue == 'add-matrix': - z_ref += z - if epilogue == 'add-rows': - z_ref += z[:, 0][:, None] - if epilogue == 'add-cols': - z_ref += z[0, :][None, :] - # compare - np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) - # make sure ld/st are vectorized - ptx = pgm.asm['ptx'] - assert 'ld.global.v4' in ptx - assert 'st.global.v4' in ptx - if allow_tf32: - assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx - elif dtype == 'float32': - assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx - elif dtype == 'int8': - assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx - - -def test_dot_without_load(): - @triton.jit - def kernel(out): - pid = tl.program_id(axis=0) - a = tl.zeros((32, 32), tl.float32) - b = tl.zeros((32, 32), tl.float32) - c = tl.zeros((32, 32), tl.float32) - c = tl.dot(a, b) - pout = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :] - tl.store(pout, c) - - out = torch.ones((32, 32), dtype=torch.float32, device="cuda") - kernel[(1,)](out) - -# --------------- -# test arange -# --------------- - - -@pytest.mark.parametrize("start", [0, 1, 7, 16]) -def test_arange(start, device='cuda'): - BLOCK = 128 - z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device) - - @triton.jit - def _kernel(z, BLOCK: tl.constexpr, - START: tl.constexpr, END: tl.constexpr): - off = tl.arange(0, BLOCK) - val = tl.arange(START, END) - tl.store(z + off, val) - _kernel[(1,)](z_tri, START=start, END=start + BLOCK, BLOCK=BLOCK) - z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device=device) - triton.testing.assert_almost_equal(z_tri, z_ref) - -# --------------- -# test load -# --------------- -# 'bfloat16': torch.bfloat16, -# Testing masked loads with an intermate copy to shared memory run. - - -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) -def test_masked_load_shared_memory(dtype, device='cuda'): - M = 32 - N = 32 - K = 8 - - in1 = torch.rand((M, K), dtype=dtype, device=device) - in2 = torch.rand((K, N), dtype=dtype, device=device) - out = torch.zeros((M, N), dtype=dtype, device=device) - - @triton.jit - def _kernel(in1_ptr, in2_ptr, output_ptr, - in_stride, in2_stride, out_stride, - in_numel, in2_numel, out_numel, - M: tl.constexpr, N: tl.constexpr, K: tl.constexpr): - - M_offsets = tl.arange(0, M) - N_offsets = tl.arange(0, N) - K_offsets = tl.arange(0, K) - - in_offsets = M_offsets[:, None] * in_stride + K_offsets[None, :] - in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None, :] - - # Load inputs. - x = tl.load(in1_ptr + in_offsets, mask=in_offsets < in_numel) - w = tl.load(in2_ptr + in2_offsets, mask=in2_offsets < in2_numel) - - # Without a dot product the memory doesn't get promoted to shared. 
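-        # The tl.dot below forces both loaded tiles through shared memory,
-        # which is the intermediate copy this test is meant to exercise.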
- o = tl.dot(x, w) - - # Store output - output_offsets = M_offsets[:, None] * out_stride + N_offsets[None, :] - tl.store(output_ptr + output_offsets, o, mask=output_offsets < in2_numel) - - pgm = _kernel[(1,)](in1, in2, out, - in1.stride()[0], - in2.stride()[0], - out.stride()[0], - in1.numel(), - in2.numel(), - out.numel(), - M=M, N=N, K=K) - - reference_out = torch.matmul(in1, in2) - triton.testing.allclose(out, reference_out) - - -@pytest.mark.parametrize("cache", ["", ".ca", ".cg"]) -def test_load_cache_modifier(cache): - src = torch.empty(128, device='cuda') - dst = torch.empty(128, device='cuda') - - @triton.jit - def _kernel(dst, src, CACHE: tl.constexpr): - offsets = tl.arange(0, 128) - x = tl.load(src + offsets, cache_modifier=CACHE) - tl.store(dst + offsets, x) - - pgm = _kernel[(1,)](dst, src, CACHE=cache) - ptx = pgm.asm['ptx'] - if cache == '': - assert 'ld.global.ca' not in ptx - assert 'ld.global.cg' not in ptx - if cache == '.cg': - assert 'ld.global.cg' in ptx - assert 'ld.global.ca' not in ptx - if cache == '.ca': - assert 'ld.global.ca' in ptx - assert 'ld.global.cg' not in ptx - -# --------------- -# test store -# --------------- - -# --------------- -# test if -# --------------- - -# --------------- -# test for -# --------------- - -# --------------- -# test while -# --------------- - -# --------------- -# test default -# --------------- -# TODO: can't be local to test_default - - -@triton.jit -def _impl(value=10): - return value - - -def test_default(): - value = 5 - ret0 = torch.zeros(1, dtype=torch.int32, device='cuda') - ret1 = torch.zeros(1, dtype=torch.int32, device='cuda') - - @triton.jit - def _kernel(ret0, ret1, value): - tl.store(ret0, _impl()) - tl.store(ret1, _impl(value)) - - _kernel[(1,)](ret0, ret1, value) - assert ret0.item() == 10 - assert ret1.item() == value - -# --------------- -# test noop -# ---------------- - - -def test_noop(device='cuda'): - @triton.jit - def kernel(x): - pass - x = to_triton(numpy_random((1,), dtype_str='int32'), device=device) - kernel[(1, )](x) - - -@pytest.mark.parametrize( - "value, overflow", - [(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)] -) -def test_value_specialization_overflow(value: int, overflow: bool, device='cuda') -> None: - - @triton.jit - def kernel(VALUE, X): - pass - - x = torch.tensor([3.14159], device='cuda') - - if overflow: - with pytest.raises(RuntimeError, match='integer overflow'): - kernel[(1, )](value, x) - else: - kernel[(1, )](value, x) diff --git a/python/test/unit/language/test_random.py b/python/test/unit/language/test_random.py deleted file mode 100644 index 042065403e8b..000000000000 --- a/python/test/unit/language/test_random.py +++ /dev/null @@ -1,177 +0,0 @@ -import numpy as np -import pytest -import scipy.stats -import torch - -import triton -import triton.language as tl - -##################################### -# Reference Philox Implementation -##################################### - - -class PhiloxConfig: - def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE): - self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE) - self.PHILOX_ROUND_B = np.array(PHILOX_ROUND_B, dtype=DTYPE) - self.PHILOX_KEY_A = np.array(PHILOX_KEY_A, dtype=DTYPE) - self.PHILOX_KEY_B = np.array(PHILOX_KEY_B, dtype=DTYPE) - self.DTYPE = DTYPE - - -# This is better for GPU -PHILOX_32 = PhiloxConfig( - PHILOX_KEY_A=0x9E3779B9, - PHILOX_KEY_B=0xBB67AE85, - PHILOX_ROUND_A=0xD2511F53, - PHILOX_ROUND_B=0xCD9E8D57, - DTYPE=np.uint32, -) - -# This is what numpy 
implements -PHILOX_64 = PhiloxConfig( - PHILOX_KEY_A=0x9E3779B97F4A7C15, - PHILOX_KEY_B=0xBB67AE8584CAA73B, - PHILOX_ROUND_A=0xD2E7470EE14C6C93, - PHILOX_ROUND_B=0xCA5A826395121157, - DTYPE=np.uint64, -) - - -class CustomPhilox4x: - def __init__(self, seed, config): - self._config = config - seed = self._into_pieces(seed) - self._key = np.array(seed[:2], dtype=self._dtype) - self._counter = np.array((0, 0) + seed[2:], dtype=self._dtype) - - @property - def _dtype(self): - return self._config.DTYPE - - def _into_pieces(self, n, pad=4): - res = [] - while len(res) < pad: - res.append(np.array(n, dtype=self._dtype)) - n >>= (np.dtype(self._dtype).itemsize * 8) - assert n == 0 - return tuple(res) - - def _multiply_low_high(self, a, b): - low = a * b - high = int(a) * int(b) - high = np.array(high >> (np.dtype(self._dtype).itemsize * 8), dtype=self._dtype) - return low, high - - def _single_round(self, counter, key): - lo0, hi0 = self._multiply_low_high(self._config.PHILOX_ROUND_A, counter[0]) - lo1, hi1 = self._multiply_low_high(self._config.PHILOX_ROUND_B, counter[2]) - ret0 = hi1 ^ counter[1] ^ key[0] - ret1 = lo1 - ret2 = hi0 ^ counter[3] ^ key[1] - ret3 = lo0 - return np.array([ret0, ret1, ret2, ret3], dtype=self._dtype) - - def _raise_key(self, key): - pk = [self._config.PHILOX_KEY_A, self._config.PHILOX_KEY_B] - return key + np.array(pk, dtype=self._dtype) - - def random_raw(self): - counter = self._counter - key = self._key - for _ in range(10): - counter = self._single_round(counter, key) - key = self._raise_key(key) - self.advance(1) - return counter - - def advance(self, n_steps): - self._counter[0] += n_steps - assert self._counter[0] < 2**32, "FIXME: doesn't work for large offsets" - - -class CustomPhilox(CustomPhilox4x): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.buffer = [] - - def random_raw(self): - if len(self.buffer) == 0: - self.buffer = list(super().random_raw())[::-1] - return int(self.buffer.pop()) - - -##################################### -# Unit Tests -##################################### - -BLOCK = 1024 - -# test generation of random uint32 - - -@pytest.mark.parametrize('size, seed', - [(size, seed) for size in ['10', '4,53', '10000'] - for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]] - ) -def test_randint(size, seed, device='cuda'): - size = list(map(int, size.split(','))) - - @triton.jit - def kernel(X, N, seed): - offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) - rand = tl.randint(seed, offset) - tl.store(X + offset, rand, mask=offset < N) - # triton result - x = torch.empty(size, dtype=torch.int32, device=device) - N = x.numel() - grid = (triton.cdiv(N, BLOCK),) - kernel[grid](x, N, seed) - out_tri = x.cpu().numpy().astype(np.uint32).flatten().tolist() - # reference result - gen = CustomPhilox4x(seed, config=PHILOX_32) - out_ref = [gen.random_raw()[0] for _ in out_tri] - assert out_tri == out_ref - -# test uniform PRNG - - -@pytest.mark.parametrize('size, seed', - [(size, seed) for size in [1000000] - for seed in [0, 42, 124, 54]] - ) -def test_rand(size, seed, device='cuda'): - @triton.jit - def kernel(X, N, seed): - offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) - rand = tl.rand(seed, offset) - tl.store(X + offset, rand, mask=offset < N) - # triton result - x = torch.empty(size, dtype=torch.float32, device=device) - N = x.numel() - grid = (triton.cdiv(N, BLOCK),) - kernel[grid](x, N, seed) - assert all((x >= 0) & (x <= 1)) - assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 
1)).statistic < 0.01 - -# test normal PRNG - - -@pytest.mark.parametrize('size, seed', - [(size, seed) for size in [1000000] - for seed in [0, 42, 124, 54]] - ) -def test_randn(size, seed, device='cuda'): - @triton.jit - def kernel(X, N, seed): - offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) - rand = tl.randn(seed, offset) - tl.store(X + offset, rand, mask=offset < N) - # triton result - x = torch.empty(size, dtype=torch.float32, device=device) - N = x.numel() - grid = (triton.cdiv(N, BLOCK),) - kernel[grid](x, N, seed) - assert abs(x.mean()) < 1e-2 - assert abs(x.std() - 1) < 1e-2 diff --git a/python/test/unit/operators/test_blocksparse.py b/python/test/unit/operators/test_blocksparse.py deleted file mode 100644 index 9e0c72de9d66..000000000000 --- a/python/test/unit/operators/test_blocksparse.py +++ /dev/null @@ -1,187 +0,0 @@ -import pytest -import torch - -import triton - - -@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"]) -@pytest.mark.parametrize("TRANS_A", [False, True]) -@pytest.mark.parametrize("TRANS_B", [False, True]) -@pytest.mark.parametrize("BLOCK", [16, 32, 64]) -@pytest.mark.parametrize("DTYPE", [torch.float16]) -def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=256): - seed = 0 - torch.manual_seed(seed) - is_sdd = MODE == "sdd" - is_dsd = MODE == "dsd" - is_dds = MODE == "dds" - do_sparsify = lambda x: triton.testing.sparsify_tensor(x, layout, BLOCK) - do_mask = lambda x: triton.testing.mask_tensor(x, layout, BLOCK) - # create inputs - # create op - a_shape = (Z, H, K, M) if TRANS_A else (Z, H, M, K) - b_shape = (Z, H, N, K) if TRANS_B else (Z, H, K, N) - c_shape = (Z, H, M, N) - shape = { - "sdd": (M, N), - "dsd": (a_shape[2], a_shape[3]), - "dds": (b_shape[2], b_shape[3]), - }[MODE] - layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK)) - layout[1, 2, :] = 0 - layout[1, :, 1] = 0 - # create data - a_ref, a_tri = triton.testing.make_pair(a_shape, alpha=.1) - b_ref, b_tri = triton.testing.make_pair(b_shape, alpha=.1) - dc_ref, dc_tri = triton.testing.make_pair(c_shape) - # compute [torch] - dc_ref = do_mask(dc_ref) if is_sdd else dc_ref - a_ref = do_mask(a_ref) if is_dsd else a_ref - b_ref = do_mask(b_ref) if is_dds else b_ref - a_ref.retain_grad() - b_ref.retain_grad() - c_ref = torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref, - b_ref.transpose(2, 3) if TRANS_B else b_ref) - c_ref.backward(dc_ref) - c_ref = do_sparsify(c_ref) if is_sdd else c_ref - da_ref = do_sparsify(a_ref.grad) if is_dsd else a_ref.grad - db_ref = do_sparsify(b_ref.grad) if is_dds else b_ref.grad - # triton result - dc_tri = do_sparsify(dc_tri) if is_sdd else dc_tri - a_tri = do_sparsify(a_tri) if is_dsd else a_tri - b_tri = do_sparsify(b_tri) if is_dds else b_tri - a_tri.retain_grad() - b_tri.retain_grad() - op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device="cuda") - c_tri = triton.testing.catch_oor(lambda: op(a_tri, b_tri), pytest) - triton.testing.catch_oor(lambda: c_tri.backward(dc_tri), pytest) - da_tri = a_tri.grad - db_tri = b_tri.grad - # compare - triton.testing.assert_almost_equal(c_ref, c_tri) - triton.testing.assert_almost_equal(da_ref, da_tri) - triton.testing.assert_almost_equal(db_ref, db_tri) - - -configs = [ - (16, 256), - (32, 576), - (64, 1871), - (128, 2511), -] - - -@pytest.mark.parametrize("is_dense", [False, True]) -@pytest.mark.parametrize("BLOCK, WIDTH", configs) -def test_softmax(BLOCK, WIDTH, is_dense, Z=2, H=2, is_causal=True, scale=0.4): - # set seed - 
torch.random.manual_seed(0) - Z, H, M, N = 2, 3, WIDTH, WIDTH - # initialize layout - # make sure each row has at least one non-zero element - layout = torch.randint(2, (H, M // BLOCK, N // BLOCK)) - if is_dense: - layout[:] = 1 - else: - layout[1, 2, :] = 0 - layout[1, :, 1] = 0 - # initialize data - a_shape = (Z, H, M, N) - a_ref, a_tri = triton.testing.make_pair(a_shape) - dout_ref, dout_tri = triton.testing.make_pair(a_shape) - # compute [torch] - a_ref = triton.testing.mask_tensor(a_ref, layout, BLOCK, value=float("-inf")) - a_ref.retain_grad() - at_mask = torch.ones((M, N), device="cuda") - if is_causal: - at_mask = torch.tril(at_mask) - M = at_mask[None, None, :, :] + torch.zeros_like(a_ref) - a_ref[M == 0] = float("-inf") - out_ref = torch.softmax(a_ref * scale, -1) - out_ref.backward(dout_ref) - out_ref = triton.testing.sparsify_tensor(out_ref, layout, BLOCK) - da_ref = triton.testing.sparsify_tensor(a_ref.grad, layout, BLOCK) - # compute [triton] - a_tri = triton.testing.sparsify_tensor(a_tri, layout, BLOCK) - a_tri.retain_grad() - dout_tri = triton.testing.sparsify_tensor(dout_tri, layout, BLOCK) - op = triton.ops.blocksparse.softmax(layout, BLOCK, device="cuda", is_dense=is_dense) - out_tri = op(a_tri, scale=scale, is_causal=is_causal) - out_tri.backward(dout_tri) - da_tri = a_tri.grad - # compare - triton.testing.assert_almost_equal(out_tri, out_ref) - triton.testing.assert_almost_equal(da_tri, da_ref) - - -@pytest.mark.parametrize("block", [16, 32, 64]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) -def test_attention_fwd_bwd( - block, - dtype, - input_scale=1.0, - scale=1 / 8.0, - n_ctx=256, - batch_size=2, - n_heads=2, -): - # inputs - qkv_shape = (batch_size, n_heads, n_ctx, 64) - qkvs = [ - torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3) - ] - - # Triton: - n_blocks = n_ctx // block - layout = torch.tril(torch.ones([n_heads, n_blocks, n_blocks], dtype=torch.long)) - query, key, value = [x.clone() for x in qkvs] - query.retain_grad() - key.retain_grad() - value.retain_grad() - attn_out = triton_attention(layout, block, query=query, key=key, value=value, scale=scale) - # ad hoc loss - loss = (attn_out ** 2).mean() - loss.backward() - grads = [query.grad, key.grad, value.grad] - - # Torch version: - torch_q, torch_k, torch_v = [x.clone() for x in qkvs] - attn_mask = torch.ones([n_ctx, n_ctx], device="cuda", dtype=dtype) - attn_mask = torch.tril(attn_mask, diagonal=0) - attn_mask = 1e6 * (-1 + (attn_mask.reshape((1, 1, n_ctx, n_ctx)).cuda())) - torch_q.retain_grad() - torch_k.retain_grad() - torch_v.retain_grad() - scores = scale * torch.einsum("bhsd,bhtd->bhst", torch_q, torch_k) - scores = scores + attn_mask - probs = torch.softmax(scores, dim=-1) - torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v) - # ad hoc loss - torch_loss = (torch_attn_out ** 2).mean() - torch_loss.backward() - torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad] - - # comparison - # print(f"Triton loss {loss} and torch loss {torch_loss}. 
Also checking grads...") - triton.testing.assert_almost_equal(loss, torch_loss) - for g1, g2 in zip(grads, torch_grads): - triton.testing.assert_almost_equal(g1, g2) - - -@pytest.mark.parametrize("block", [16, 32, 64]) -def triton_attention( - layout, - block: int, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - scale: float, -): - sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True, device=value.device) - sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False, device=value.device) - sparse_softmax = triton.ops.blocksparse.softmax(layout, block, device=value.device) - - w = sparse_dot_sdd_nt(query, key) - w = sparse_softmax(w, scale=scale, is_causal=True) - a = sparse_dot_dsd_nn(w, value) - return a diff --git a/python/test/unit/operators/test_cross_entropy.py b/python/test/unit/operators/test_cross_entropy.py deleted file mode 100644 index 08516257be7c..000000000000 --- a/python/test/unit/operators/test_cross_entropy.py +++ /dev/null @@ -1,35 +0,0 @@ -import pytest -import torch - -import triton - - -@pytest.mark.parametrize("M, N, dtype, mode", - [ - (M, N, dtype, mode) for M in [1024, 821] - for N in [512, 857, 1871, 2089, 8573, 31000] - for dtype in ['float16', 'float32'] - for mode in ['forward', 'backward'] - ] - ) -def test_op(M, N, dtype, mode): - dtype = {'float16': torch.float16, 'float32': torch.float32}[dtype] - # create inputs - x = torch.randn(M, N, dtype=dtype, device='cuda', requires_grad=True) - idx = 4 + torch.ones(M, dtype=torch.int64, device='cuda') - # forward pass - tt_y = triton.ops.cross_entropy(x, idx) - th_y = torch.nn.CrossEntropyLoss(reduction="none")(x, idx) - if mode == 'forward': - triton.testing.assert_almost_equal(th_y, tt_y) - # backward pass - elif mode == 'backward': - dy = torch.randn_like(tt_y) - # triton backward - tt_y.backward(dy) - tt_dx = x.grad.clone() - # torch backward - x.grad.zero_() - th_y.backward(dy) - th_dx = x.grad.clone() - triton.testing.assert_almost_equal(th_dx, tt_dx) diff --git a/python/test/unit/operators/test_matmul.py b/python/test/unit/operators/test_matmul.py deleted file mode 100644 index 514fbab7bb98..000000000000 --- a/python/test/unit/operators/test_matmul.py +++ /dev/null @@ -1,98 +0,0 @@ -import itertools - -import pytest -import torch - -import triton -import triton._C.libtriton.triton as _triton - - -@pytest.mark.parametrize( - "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE", - itertools.chain( - *[ - [ - # 1 warp - (16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE), - (32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE), - (16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE), - (16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE), - (32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE), - (16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE), - (16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE), - (64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE), - (16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE), - # 2 warp - (64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE), - (32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE), - (64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE), - (32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE), - (128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE), - (32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE), - # 4 warp - (128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE), - (64, 128, 16, 1, 4, 2, None, 
None, None, AT, BT, DTYPE), - (128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE), - (32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE), - (128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE), - (32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE), - # 8 warp - (128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE), - (256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE), - (256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE), - # split-k - (64, 64, 16, 2, 4, 2, None, None, None, AT, BT, DTYPE), - (64, 64, 16, 4, 4, 2, None, None, None, AT, BT, DTYPE), - (64, 64, 16, 8, 4, 2, None, None, None, AT, BT, DTYPE), - # variable input - (128, 128, 32, 1, 4, 2, 1024, 1024, 1024, AT, BT, DTYPE), - (128, 128, 32, 1, 4, 2, 384, 128, 640, AT, BT, DTYPE), - (128, 128, 32, 1, 4, 2, 107, 233, 256, AT, BT, DTYPE), - (128, 128, 32, 1, 4, 2, 107, 233, 311, AT, BT, DTYPE), - ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] - ], - # n-stage - *[ - [ - (16, 16, 16, 1, 1, STAGES, 1024, 1024, 1024, AT, BT, DTYPE), - (64, 32, 64, 1, 2, STAGES, 1024, 1024, 1024, AT, BT, DTYPE), - (128, 64, 16, 1, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE), - (256, 128, 32, 1, 8, STAGES, 1024, 1024, 1024, AT, BT, DTYPE), - (128, 128, 32, 1, 4, STAGES, 384, 128, 640, AT, BT, DTYPE), - # split-k - (64, 64, 16, 8, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE), - (64, 64, 16, 8, 4, STAGES, 1024, 1024, 32, AT, BT, DTYPE), - ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [2, 3, 4] - ] - ), -) -def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE): - cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device()) - if cc < 80 and DTYPE == "bfloat16": - pytest.skip("Only test bfloat16 on devices with sm >= 80") - if DTYPE == "bfloat16" and SPLIT_K != 1: - pytest.skip("bfloat16 matmuls don't allow split_k for now") - torch.manual_seed(0) - # nuke kernel decorators -- will set meta-parameters manually - kwargs = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K} - pre_hook = None if SPLIT_K == 1 else lambda nargs: nargs['C'].zero_() - configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE, pre_hook=pre_hook)] - kernel = triton.ops._matmul.kernel - decorators = kernel.kernel_decorators - kernel.kernel_decorators = [] - triton.autotune(configs, [])(kernel) - kernel.kernel_decorators += decorators[1:] - # get matrix shape - M = BLOCK_M if M is None else M - N = BLOCK_N if N is None else N - K = BLOCK_K * SPLIT_K if K is None else K - # allocate/transpose inputs - DTYPE = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[DTYPE] - a = .1 * torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE) - b = .1 * torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE) - a = a.t() if AT else a - b = b.t() if BT else b - # run test - th_c = torch.matmul(a, b) - tt_c = triton.testing.catch_oor(lambda: triton.ops.matmul(a, b), pytest) - triton.testing.assert_almost_equal(th_c, tt_c) diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py deleted file mode 100644 index d866d698375b..000000000000 --- a/python/test/unit/runtime/test_cache.py +++ /dev/null @@ -1,132 +0,0 @@ -import os -import re -import shutil - -import pytest -import torch - -import triton -import triton.language as tl -from triton.code_gen import JITFunction - -tmpdir = 
".tmp" - - -@triton.jit -def function_1(i): - i = i + 1 - i = function_2(i) - return i - - -@triton.jit -def function_2(i): - i = i + 1 - return i - - -@triton.jit -def kernel(X, i, BLOCK: tl.constexpr): - i = i + 1 - i = function_1(i) - tl.store(X, i) - - -@triton.jit(do_not_specialize=["i"]) -def kernel_nospec(X, i, BLOCK: tl.constexpr): - i = i + 1 - i = function_1(i) - tl.store(X, i) - - -def apply_src_change(target, old, new): - kernel.hash = None - function_1.hash = None - function_2.hash = None - function_1.src = function_1.src.replace(old, new) - target.src = target.src.replace(old, new) - ret = target.cache_key - target.src = target.src.replace(new, old) - return ret - - -def test_nochange(): - baseline = kernel.cache_key - updated = apply_src_change(kernel, 'i + 1', 'i + 1') - assert baseline == updated - - -def test_toplevel_change(): - baseline = kernel.cache_key - updated = apply_src_change(kernel, 'i + 1', 'i + 2') - assert baseline != updated - - -def test_nested1_change(): - baseline = kernel.cache_key - updated = apply_src_change(function_1, 'i + 1', 'i + 2') - assert baseline != updated - - -def reset_tmp_dir(): - os.environ["TRITON_CACHE_DIR"] = tmpdir - if os.path.exists(tmpdir): - shutil.rmtree(tmpdir) - - -def test_reuse(): - counter = 0 - - def inc_counter(*args, **kwargs): - nonlocal counter - counter += 1 - JITFunction.cache_hook = inc_counter - reset_tmp_dir() - x = torch.empty(1, dtype=torch.int32, device='cuda') - for i in range(10): - kernel[(1,)](x, 1, BLOCK=1024) - assert counter == 1 - - -@pytest.mark.parametrize('mode', ['enable', 'disable']) -def test_specialize(mode): - counter = 0 - - def inc_counter(*args, **kwargs): - nonlocal counter - counter += 1 - JITFunction.cache_hook = inc_counter - reset_tmp_dir() - x = torch.empty(1, dtype=torch.int32, device='cuda') - function = {'enable': kernel, 'disable': kernel_nospec}[mode] - target = {'enable': 5, 'disable': 1}[mode] - for i in [1, 2, 4, 8, 16, 32]: - function[(1,)](x, i, BLOCK=512) - assert counter == target - - -@pytest.mark.parametrize("value, value_type", [ - (-1, 'int32'), (0, 'int32'), (1, None), (-2**31, 'int32'), (2**31 - 1, 'int32'), - (2**32, 'int64'), (2**63 - 1, 'int64'), (-2**63, 'int64'), - (2**31, 'uint32'), (2**32 - 1, 'uint32'), (2**63, 'uint64'), (2**64 - 1, 'uint64') -]) -def test_value_specialization(value: int, value_type: str, device='cuda') -> None: - - @triton.jit - def kernel(VALUE, X): - pass - - cache_str = None - - def get_cache_str(*args, **kwargs): - nonlocal cache_str - cache_str = kwargs['key'].split('-') - triton.code_gen.JITFunction.cache_hook = get_cache_str - reset_tmp_dir() - x = torch.tensor([3.14159], device='cuda') - kernel[(1, )](value, x) - triton.code_gen.JITFunction.cache_hook = None - - cache_str_match = re.match(r'_(\w+)\[multipleof\(\d+\)]_float32\*\[multipleof\(16\)\]', cache_str[-1]) - spec_type = None if cache_str_match is None else cache_str_match.group(1) - assert spec_type == value_type diff --git a/python/test/unit/runtime/test_comm.py b/python/test/unit/runtime/test_comm.py deleted file mode 100644 index ae3fb69d7466..000000000000 --- a/python/test/unit/runtime/test_comm.py +++ /dev/null @@ -1,98 +0,0 @@ -import subprocess - -import numpy as np -import pytest -import torch - -import triton -import triton.language as tl - - -def get_p2p_matrix(): - try: - stdout = subprocess.check_output(["nvidia-smi", "topo", "-p2p", "n"]).decode("ascii") - except subprocess.CalledProcessError: - return pytest.skip("No multi-GPU topology", allow_module_level=True) - 
- lines = stdout.split("Legend")[0].split('\n')[1:] - matrix = np.array([line.split('\t')[1:-1] for line in lines][:-2]) - if matrix.size <= 1: - return pytest.skip("No multi-GPU topology", allow_module_level=True) - else: - return matrix - - -def get_p2p_devices(): - matrix = get_p2p_matrix() - idx = np.where(matrix == "OK") - return [f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"] if len(idx[0]) > 0 else [] - - -def get_non_p2p_devices(): - matrix = get_p2p_matrix() - idx = np.where(matrix == "NS") - return [f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"] if len(idx[0]) > 0 else [] - - -p2p_devices = get_p2p_devices() -non_p2p_devices = get_non_p2p_devices() - - -@triton.jit -def _copy(from_ptr, to_ptr, N, **meta): - pid = tl.program_id(0) - offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK']) - values = tl.load(from_ptr + offsets, mask=offsets < N) - tl.store(to_ptr + offsets, values, mask=offsets < N) - - -@pytest.mark.skipif(not p2p_devices, reason="No pair of device with P2P support") -@pytest.mark.parametrize("device_kernel, device_from, device_to, stream_from, stream_to", - [(device_kernel, device_from, device_to, stream_from, stream_to) - for device_kernel in p2p_devices - for device_from in p2p_devices - for device_to in p2p_devices - for stream_from in ['default', 'custom'] - for stream_to in ['default', 'custom'] - ]) -def test_p2p(device_kernel, device_from, device_to, stream_from, stream_to): - if device_to == device_from: - return pytest.skip() - - torch.cuda.set_device(device_kernel) - N = 512 - grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),) - - with torch.cuda.stream(None if stream_from == 'default' else torch.cuda.Stream(device_from)): - x_from = torch.randn(N, dtype=torch.float32, device=device_from) - with torch.cuda.stream(None if stream_to == 'default' else torch.cuda.Stream(device_to)): - x_to = torch.empty(N, dtype=torch.float32, device=device_to) - - _copy[grid](x_from, x_to, N, BLOCK=1024) - assert torch.allclose(x_from, x_to.to(device_from)) - - -@pytest.mark.skipif(not non_p2p_devices, reason="No pair of device with no P2P support") -@pytest.mark.parametrize("device_kernel, device_from, device_to, stream_from, stream_to", - [(device_kernel, device_from, device_to, stream_from, stream_to) - for device_kernel in non_p2p_devices - for device_from in non_p2p_devices - for device_to in non_p2p_devices - for stream_from in ['default', 'custom'] - for stream_to in ['default', 'custom'] - ]) -def test_non_p2p(device_kernel, device_from, device_to, stream_from, stream_to): - if device_to == device_from: - return pytest.skip() - - with pytest.raises(RuntimeError): - torch.cuda.set_device(device_kernel) - N = 512 - grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),) - - with torch.cuda.stream(None if stream_from == 'default' else torch.cuda.Stream(device_from)): - x_from = torch.randn(N, dtype=torch.float32, device=device_from) - with torch.cuda.stream(None if stream_to == 'default' else torch.cuda.Stream(device_to)): - x_to = torch.empty(N, dtype=torch.float32, device=device_to) - - _copy[grid](x_from, x_to, N, BLOCK=1024) diff --git a/python/triton/__init__.py b/python/triton/__init__.py index 37ba46efc86a..6963bf25cc1d 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -6,9 +6,9 @@ # or pybind11 shows `munmap_chunk(): invalid pointer` import torch # submodules -from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, \ - JITFunction, Config, Autotuner, reinterpret +from .utils import * +from .runtime import jit, Config, autotune, 
heuristics +from .compiler import compile from . import language -from . import code_gen from . import testing from . import ops diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py deleted file mode 100644 index 11ed31381eb9..000000000000 --- a/python/triton/code_gen.py +++ /dev/null @@ -1,1516 +0,0 @@ -from __future__ import annotations - -import ast -import builtins -import functools -import hashlib -import inspect -import os -import pickle -import subprocess -import sys -import tempfile -import textwrap -import time -import warnings -from typing import Dict, Optional, Set, Tuple, Union -from numpy import isin - -import torch -from filelock import FileLock - -import triton -import triton._C.libtriton.triton as _triton -from .tools.disasm import extract - - -def mangle_ty(ty): - if ty.is_ptr(): - return 'P' + mangle_ty(ty.element_ty) - if ty.is_int(): - return 'i' + str(ty.int_bitwidth) - if ty.is_fp8(): - return 'fp8' - if ty.is_fp16(): - return 'fp16' - if ty.is_bf16(): - return 'bf16' - if ty.is_fp32(): - return 'fp32' - if ty.is_fp64(): - return 'fp64' - if ty.is_void(): - return 'V' - if ty.is_block(): - elt = mangle_ty(ty.scalar) - shape = '_'.join(map(str, ty.shape)) - return f'{elt}S{shape}S' - assert False, "Unsupport type" - - -def mangle_fn(name, arg_tys, constants): - # doesn't mangle ret type, which must be a function of arg tys - mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys]) - key = lambda x: x.__name__ if isinstance(x, JITFunction) else repr(x) - mangled_constants = '_'.join([f'{i}c{key(constants[i])}' for i in sorted(constants)]) - mangled_constants = mangled_constants.replace('.', '_d_') - mangled_constants = mangled_constants.replace("'", '_sq_') - ret = f'{name}__{mangled_arg_names}__{mangled_constants}' - return ret - -class enter_sub_region: - def __init__(self, generator: CodeGenerator): - self.generator = generator - - def __enter__(self): - # record lscope & local_defs in the parent scope - self.liveins = self.generator.lscope.copy() - self.prev_defs = self.generator.local_defs.copy() - self.generator.local_defs = {} - self.insert_block = self.generator.builder.get_insertion_block() - return self.liveins, self.insert_block - - def __exit__(self, *args, **kwargs): - self.generator.builder.set_insertion_point_to_end(self.insert_block) - self.generator.lscope = self.liveins - self.generator.local_defs = self.prev_defs - -class CodeGenerator(ast.NodeVisitor): - def __init__(self, context, prototype, gscope, attributes, constants, module=None, is_kernel=False, function_types=dict()): - self.builder = _triton.ir.builder(context) - self.module = self.builder.create_module() if module is None else module - self.function_ret_types = function_types - self.prototype = prototype - self.gscope = gscope - self.lscope = dict() - self.attributes = attributes - self.constants = constants - self.is_kernel = is_kernel - self.last_node = None - self.builtins = { - 'range': range, - 'min': triton.language.minimum, - 'float': float, - 'int': int, - 'print': print, - 'isinstance': isinstance, - 'getattr': getattr, - } - # SSA-construction - # name => triton.language.tensor - self.local_defs: Dict[str, triton.language.tensor] = {} - self.global_uses: Dict[str, triton.language.tensor] = {} - - def get_value(self, name): - ''' This function: - 1. make sure `name` is defined - 2. 
if `name` is triton.language.tensor, get stored tensor by calling - `self._get_tensor()` - ''' - # search node.id in local scope - ret = None - if name in self.lscope: - ret = self.lscope[name] - if name not in self.local_defs: - self.global_uses[name] = ret - # search node.id in global scope - elif name in self.gscope: - ret = self.gscope[name] - # search node.id in builtins - elif name in self.builtins: - ret = self.builtins[name] - else: - raise ValueError(f'{name} is not defined') - return ret - - def set_value(self, name: str, - value: Union[triton.language.tensor, triton.language.constexpr]) -> None: - ''' This function: - called by visit_Assign() & visit_FuncDef() to store left value (lvalue) - 1. record local defined name (FIXME: should consider control flow) - 2. store tensor in self.lvalue - ''' - self.lscope[name] = value - self.local_defs[name] = value - - def is_triton_tensor(self, value): - return isinstance(value, triton.language.tensor) - - # - # AST visitor - # - def visit_compound_statement(self, stmts): - for stmt in stmts: - self.last_ret_type = self.visit(stmt) - if isinstance(stmt, ast.Return): - break - return stmts and isinstance(stmt, ast.Return) - - def visit_Module(self, node): - ast.NodeVisitor.generic_visit(self, node) - - def visit_List(self, node): - ctx = self.visit(node.ctx) - assert ctx is None - elts = [self.visit(elt) for elt in node.elts] - return elts - - # By design, only non-kernel functions can return - def visit_Return(self, node): - ret_value = self.visit(node.value) - if ret_value is None: - self.builder.ret([]) - return None - if isinstance(ret_value, tuple): - ret_values = [triton.language.core._to_tensor(v, self.builder) for v in ret_value] - ret_types = [v.type for v in ret_values] - self.builder.ret([v.handle for v in ret_values]) - return tuple(ret_types) - else: - ret = triton.language.core._to_tensor(ret_value, self.builder) - self.builder.ret([ret_value.handle]) - return ret.type - - def visit_FunctionDef(self, node): - arg_names, kwarg_names = self.visit(node.args) - # initialize defaults - for i, default_value in enumerate(node.args.defaults): - arg_node = node.args.args[-i - 1] - annotation = arg_node.annotation - name = arg_node.arg - st_target = ast.Name(id=name, ctx=ast.Store()) - if annotation is None: - init_node = ast.Assign(targets=[st_target], value=default_value) - else: - init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation) - self.visit(init_node) - # initialize function - fn_name = mangle_fn(node.name, self.prototype.param_types, self.constants) - fn = self.builder.get_or_insert_function(self.module, fn_name, self.prototype.to_ir(self.builder)) - self.module.push_back(fn) - entry = fn.add_entry_block() - arg_values = [] - idx = 0 - for i, arg_name in enumerate(arg_names): - if i in self.constants: - cst = self.constants[i] - if not isinstance(cst, triton.language.constexpr): - cst = triton.language.constexpr(self.constants[i]) - arg_values.append(cst) - else: - pass - if i in self.attributes: - fn.set_arg_attr(idx, "tt.divisibility", self.attributes[i]) - arg_values.append(triton.language.tensor(fn.args(idx), self.prototype.param_types[idx])) - idx += 1 - - insert_pt = self.builder.get_insertion_block() - for arg_name, arg_value in zip(arg_names, arg_values): - self.set_value(arg_name, arg_value) - self.builder.set_insertion_point_to_start(entry) - # visit function body - has_ret = self.visit_compound_statement(node.body) - # finalize function - if not has_ret: - self.builder.ret([]) - 
else: - # update return type - if isinstance(self.last_ret_type, tuple): - self.prototype.ret_types = list(self.last_ret_type) - fn.reset_type(self.prototype.to_ir(self.builder)) - else: - self.prototype.ret_types = [self.last_ret_type] - fn.reset_type(self.prototype.to_ir(self.builder)) - if insert_pt: - self.builder.set_insertion_point_to_end(insert_pt) - - def visit_arguments(self, node): - arg_names = [] - for arg in node.args: - arg_names += [self.visit(arg)] - kwarg_names = self.visit(node.kwarg) - return arg_names, kwarg_names - - def visit_arg(self, node): - ast.NodeVisitor.generic_visit(self, node) - return node.arg - - def visit_AnnAssign(self, node): - # extract attributes - annotation = self.visit(node.annotation) - target = self.visit(node.target) - value = self.visit(node.value) - # constexpr - if annotation == triton.language.constexpr: - if target in self.lscope: - raise ValueError(f'{target} is already defined.' - f' constexpr cannot be reassigned.') - if not isinstance(value, triton.language.constexpr): - value = triton.language.constexpr(value) - self.lscope[target] = value - return self.lscope[target] - # default: call visit_Assign - return self.visit_Assign(node) - - def visit_Assign(self, node): - _names = [] - for target in node.targets: - _names += [self.visit(target)] - assert len(_names) == 1 - names = _names[0] - values = self.visit(node.value) - if not isinstance(names, tuple): - names = [names] - if not isinstance(values, tuple): - values = [values] - for name, value in zip(names, values): - # by default, constexpr are assigned into python variable - if isinstance(value, triton.language.constexpr): - value = value.value - if not isinstance(value, triton.language.tensor): - value = triton.language.core._to_tensor(value, self.builder) - self.set_value(name, value) - - def visit_AugAssign(self, node): - name = node.target.id - lhs = ast.Name(id=name, ctx=ast.Load()) - rhs = ast.BinOp(lhs, node.op, node.value) - assign = ast.Assign(targets=[node.target], value=rhs) - self.visit(assign) - return self.get_value(name) - - def visit_Name(self, node): - if type(node.ctx) == ast.Store: - return node.id - return self.get_value(node.id) - - def visit_Store(self, node): - ast.NodeVisitor.generic_visit(self, node) - - def visit_Load(self, node): - ast.NodeVisitor.generic_visit(self, node) - - def visit_Tuple(self, node): - args = [self.visit(x) for x in node.elts] - return tuple(args) - - def visit_BinOp(self, node): - lhs = self.visit(node.left) - rhs = self.visit(node.right) - if isinstance(lhs, triton.language.constexpr): - lhs = lhs.value - if isinstance(rhs, triton.language.constexpr): - rhs = rhs.value - fn = { - ast.Add: '__add__', - ast.Sub: '__sub__', - ast.Mult: '__mul__', - ast.Div: '__truediv__', - ast.FloorDiv: '__floordiv__', - ast.Mod: '__mod__', - ast.Pow: '__pow__', - ast.LShift: '__lshift__', - ast.RShift: '__rshift__', - ast.BitAnd: '__and__', - ast.BitOr: '__or__', - ast.BitXor: '__xor__', - }[type(node.op)] - if self.is_triton_tensor(lhs): - return getattr(lhs, fn)(rhs, _builder=self.builder) - elif self.is_triton_tensor(rhs): - fn = fn[:2] + 'r' + fn[2:] - return getattr(rhs, fn)(lhs, _builder=self.builder) - else: - return getattr(lhs, fn)(rhs) - - def visit_If(self, node): - cond = self.visit(node.test) - if isinstance(cond, triton.language.tensor): - cond = cond.to(triton.language.int1, _builder=self.builder) - with enter_sub_region(self) as sr: - liveins, ip_block = sr - - then_block = self.builder.create_block() - 
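-            # Build the branch bodies in detached blocks first; they are merged
-            # into the IfOp below once we know which values each branch defines
-            # and therefore which results the IfOp must yield.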
self.builder.set_insertion_point_to_start(then_block) - self.visit_compound_statement(node.body) - then_defs = self.local_defs.copy() - - # when need an else block when: - # 1. we have an orelse node - # or - # 2. the then block defines new variable - if then_defs or node.orelse: - if node.orelse: - self.lscope = liveins - self.local_defs = {} - else_block = self.builder.create_block() - self.builder.set_insertion_point_to_end(else_block) - self.visit_compound_statement(node.orelse) - else_defs = self.local_defs.copy() - else: - # collect else_defs - else_defs = {} - for name in then_defs: - if name in liveins: - assert self.is_triton_tensor(then_defs[name]) - assert self.is_triton_tensor(liveins[name]) - else_defs[name] = liveins[name] - # collect yields - names = [] - ret_types = [] - for then_name in then_defs: - for else_name in else_defs: - if then_name == else_name: - if then_defs[then_name].type == else_defs[else_name].type: - names.append(then_name) - ret_types.append(then_defs[then_name].type) - - self.builder.set_insertion_point_to_end(ip_block) - - if then_defs or node.orelse: # with else block - if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True) - then_block.merge_block_before(if_op.get_then_block()) - self.builder.set_insertion_point_to_end(if_op.get_then_block()) - self.builder.create_yield_op([then_defs[n].handle for n in names]) - if not node.orelse: - else_block = if_op.get_else_block() - else: - else_block.merge_block_before(if_op.get_else_block()) - self.builder.set_insertion_point_to_end(if_op.get_else_block()) - self.builder.create_yield_op([else_defs[n].handle for n in names]) - else: # no else block - if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False) - then_block.merge_block_before(if_op.get_then_block()) - - # update values yielded by IfOp - for i, name in enumerate(names): - new_tensor = triton.language.core.tensor(if_op.get_result(i), ret_types[i]) - self.lscope[name] = new_tensor - self.local_defs[name] = new_tensor - - else: - if isinstance(cond, triton.language.constexpr): - cond = cond.value - if cond: - self.visit_compound_statement(node.body) - else: - self.visit_compound_statement(node.orelse) - - def visit_IfExp(self, node): - cond = self.visit(node.test) - if cond.value: - return self.visit(node.body) - else: - return self.visit(node.orelse) - - def visit_Pass(self, node): - pass - - def visit_Compare(self, node): - assert len(node.comparators) == 1 - assert len(node.ops) == 1 - lhs = self.visit(node.left) - rhs = self.visit(node.comparators[0]) - if isinstance(lhs, triton.language.constexpr): - lhs = lhs.value - if isinstance(rhs, triton.language.constexpr): - rhs = rhs.value - if type(node.ops[0]) == ast.Is: - return triton.language.constexpr(lhs is rhs) - if type(node.ops[0]) == ast.IsNot: - return triton.language.constexpr(lhs is not rhs) - fn = { - ast.Eq: '__eq__', - ast.NotEq: '__ne__', - ast.Lt: '__lt__', - ast.LtE: '__le__', - ast.Gt: '__gt__', - ast.GtE: '__ge__', - }[type(node.ops[0])] - if self.is_triton_tensor(lhs): - return getattr(lhs, fn)(rhs, _builder=self.builder) - elif self.is_triton_tensor(rhs): - fn = fn[:2] + 'r' + fn[2:] - return getattr(rhs, fn)(lhs, _builder=self.builder) - else: - return getattr(lhs, fn)(rhs) - - def visit_UnaryOp(self, node): - op = self.visit(node.operand) - if type(node.op) == ast.Not: - assert isinstance(op, triton.language.constexpr), "`not` only supported for constexpr at the moment" - return 
triton.language.constexpr(not op) - if isinstance(op, triton.language.constexpr): - op = op.value - fn = { - ast.USub: '__neg__', - ast.UAdd: '__pos__', - ast.Invert: '__invert__', - }[type(node.op)] - if self.is_triton_tensor(op): - return getattr(op, fn)(_builder=self.builder) - return getattr(op, fn)() - - def visit_While(self, node): - with enter_sub_region(self) as sr: - liveins, insert_block = sr - - # condtion (the before region) - cond_block = self.builder.create_block() - self.builder.set_insertion_point_to_start(cond_block) - cond = self.visit(node.test) - - # loop body (the after region) - loop_block = self.builder.create_block() - self.builder.set_insertion_point_to_start(loop_block) - self.visit_compound_statement(node.body) - loop_defs = self.local_defs - - # collect loop-carried values - names = [] - ret_types = [] - init_args = [] - yields = [] - for name in loop_defs: - if name in liveins: - # We should not def new constexpr - assert self.is_triton_tensor(loop_defs[name]) - assert self.is_triton_tensor(liveins[name]) - if loop_defs[name].type == liveins[name].type: - # these are loop-carried values - names.append(name) - ret_types.append(loop_defs[name].type) - init_args.append(liveins[name]) - yields.append(loop_defs[name]) - - self.builder.set_insertion_point_to_end(insert_block) - while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types], - [arg.handle for arg in init_args]) - # merge the condition region - before_block = self.builder.create_block_with_parent(while_op.get_before(), - [ty.to_ir(self.builder) for ty in ret_types]) - cond_block.merge_block_before(before_block) - self.builder.set_insertion_point_to_end(before_block) - # create CondtionOp: e.g., scf.condition(%cond) %arg0, %arg1, ... - self.builder.create_condtion_op(cond.handle, [before_block.arg(i) for i in range(len(init_args))]) - # merge the loop body - after_block = self.builder.create_block_with_parent(while_op.get_after(), - [ty.to_ir(self.builder) for ty in ret_types]) - loop_block.merge_block_before(after_block) - self.builder.set_insertion_point_to_end(after_block) - self.builder.create_yield_op([y.handle for y in yields]) - - # update global uses in while_op - for i, name in enumerate(names): - before_block.replace_use_in_block_with(init_args[i].handle, before_block.arg(i)) - after_block.replace_use_in_block_with(init_args[i].handle, after_block.arg(i)) - - # WhileOp defines new values, update the symbol table (lscope, local_defs) - for i, name in enumerate(names): - new_def = triton.language.core.tensor(while_op.get_result(i), ret_types[i]) - self.lscope[name] = new_def - self.local_defs[name] = new_def - - for stmt in node.orelse: - assert False, "Not implemented" - ast.NodeVisitor.generic_visit(self, stmt) - - def visit_Subscript(self, node): - assert node.ctx.__class__.__name__ == "Load" - lhs = self.visit(node.value) - slices = self.visit(node.slice) - if self.is_triton_tensor(lhs): - return lhs.__getitem__(slices, _builder=self.builder) - return lhs[slices] - - def visit_ExtSlice(self, node): - return [self.visit(dim) for dim in node.dims] - - def visit_For(self, node): - iterator = self.visit(node.iter.func) - if iterator != self.builtins['range']: - raise RuntimeError('Only `range` iterator currently supported') - # static for loops: all iterator arguments are constexpr - iter_args = [self.visit(arg) for arg in node.iter.args] - is_static = all([isinstance(x, triton.language.constexpr) for x in iter_args]) - if is_static: - st_target = 
ast.Name(id=node.target.id, ctx=ast.Store()) - iter_args = [arg.value for arg in iter_args] - range = iterator(*iter_args) - if len(range) <= 10: - for i in iterator(*iter_args): - self.lscope[node.target.id] = triton.language.constexpr(i) - self.visit_compound_statement(node.body) - for stmt in node.orelse: - ast.NodeVisitor.generic_visit(self, stmt) - return - - # collect lower bound (lb), upper bound (ub), and step - lb = self.visit(node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0)) - ub = self.visit(node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0]) - step = self.visit(node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1)) - # lb/ub/step might be constexpr, we need to cast them to tensor - lb = triton.language.core._to_tensor(lb, self.builder).handle - ub = triton.language.core._to_tensor(ub, self.builder).handle - step = triton.language.core._to_tensor(step, self.builder).handle - # ForOp can only accept IndexType as lb/ub/step. Cast integer to Index - lb = self.builder.create_to_index(lb) - ub = self.builder.create_to_index(ub) - step = self.builder.create_to_index(step) - - with enter_sub_region(self) as sr: - liveins, insert_block = sr - - # create loop body block - block = self.builder.create_block() - self.builder.set_insertion_point_to_start(block) - - # visit loop body - self.visit_compound_statement(node.body) - - # If a variable (name) is defined in both its parent & itself, then it's - # a loop-carried variable. (They must be of the same type) - init_args = [] - yields = [] - names = [] - for name in self.local_defs: - if name in liveins: - assert self.is_triton_tensor(self.local_defs[name]), f'{name} is not tensor' - assert self.is_triton_tensor(liveins[name]) - if self.local_defs[name].type == liveins[name].type: - names.append(name) - init_args.append(triton.language.core._to_tensor(liveins[name], self.builder)) - yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder)) - # create ForOp - self.builder.set_insertion_point_to_end(insert_block) - for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args]) - block.merge_block_before(for_op.get_body(0)) - # create YieldOp - self.builder.set_insertion_point_to_end(for_op.get_body(0)) - self.builder.create_yield_op([y.handle for y in yields]) - for_op_region = for_op.get_body(0).get_parent() - assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block" - # replace global uses with block arguments - for i, name in enumerate(names): - # arg0 is the induction variable - for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i+1)) - - # update lscope & local_defs (ForOp defines new values) - for i, name in enumerate(names): - self.lscope[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type) - self.local_defs[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type) - - for stmt in node.orelse: - assert False, "Don't know what to do with else after for" - ast.NodeVisitor.generic_visit(self, stmt) - - def visit_Slice(self, node): - lower = self.visit(node.lower) - upper = self.visit(node.upper) - step = self.visit(node.step) - return slice(lower, upper, step) - - def visit_Index(self, node): - return self.visit(node.value) - - def visit_keyword(self, node): - return {node.arg: self.visit(node.value)} - - def visit_Call(self, node): - fn = self.visit(node.func) - if isinstance(fn, triton.language.constexpr): - fn = fn.value - kws = dict() - for 
keyword in node.keywords: - kws.update(self.visit(keyword)) - args = [self.visit(arg) for arg in node.args] - if isinstance(fn, JITFunction): - from inspect import getcallargs - args = getcallargs(fn.fn, *args, **kws) - args = [args[name] for name in fn.arg_names] - args = [arg if isinstance(arg, triton.language.tensor) - else triton.language.constexpr(arg) for arg in args] - # generate function def - attributes = dict() - constexprs = [i for i, arg in enumerate(args) if isinstance(arg, triton.language.constexpr)] - constants = {i: args[i] for i in constexprs} - # generate call - args = [None if i in constexprs else arg for i, arg in enumerate(args)] - arg_vals = [arg.handle for arg in args if arg is not None] - arg_types = [arg.type for arg in args if arg is not None] - fn_name = mangle_fn(fn.__name__, arg_types, constants) - # generate function def if necessary - if not self.module.has_function(fn_name): - ret_type = triton.language.void - prototype = triton.language.function_type([ret_type], arg_types) - gscope = sys.modules[fn.fn.__module__].__dict__ - generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_types=self.function_ret_types) - generator.visit(fn.parse()) - callee_ret_type = generator.last_ret_type - self.function_ret_types[fn_name] = callee_ret_type - else: - callee_ret_type = self.function_ret_types[fn_name] - symbol = self.module.get_function(fn_name) - call_op = self.builder.call(symbol, arg_vals) - if call_op.get_num_results() == 0: - return None - elif call_op.get_num_results() == 1: - return triton.language.tensor(call_op.get_result(0), callee_ret_type) - else: - # should return a tuple of tl.tensor - results = [] - for i in range(call_op.get_num_results()): - results.append(triton.language.tensor(call_op.get_result(i), callee_ret_type[i])) - return tuple(results) - if hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__) or \ - sys.modules[fn.__module__] is triton.language.core: - return fn(*args, _builder=self.builder, **kws) - if fn in self.builtins.values(): - args = [arg.value if isinstance(arg, triton.language.constexpr) else arg - for arg in args] - return fn(*args, **kws) - - def visit_Constant(self, node): - return triton.language.constexpr(node.value) - - if sys.version_info < (3, 8): - def visit_NameConstant(self, node): - return triton.language.constexpr(node.value) - - def visit_Num(self, node): - return triton.language.constexpr(node.n) - - def visit_Str(self, node): - return triton.language.constexpr(ast.literal_eval(node)) - - def visit_Attribute(self, node): - lhs = self.visit(node.value) - return getattr(lhs, node.attr) - - def visit_Expr(self, node): - ast.NodeVisitor.generic_visit(self, node) - - def visit_NoneType(self, node): - return None - - def visit(self, node): - if node is not None: - self.last_node = node - with warnings.catch_warnings(): - # The ast library added visit_Constant and deprecated some other - # methods but we can't move to that without breaking Python 3.6 and 3.7. 
- warnings.simplefilter("ignore", DeprecationWarning) # python 3.9 - warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8 - return super().visit(node) - - def generic_visit(self, node): - typename = type(node).__name__ - raise NotImplementedError("Unsupported node: {}".format(typename)) - - -class Binary: - def __init__(self, backend, name, asm, shared_mem, num_warps): - self.backend = backend - self.name = name - self.asm = asm - self.shared_mem = shared_mem - self.num_warps = num_warps - - -class LoadedBinary: - def __init__(self, device: int, bin: Binary): - module, kernel = _triton.code_gen.load_binary(bin.backend, - bin.name, - bin.asm, - bin.shared_mem, - device) - self.bin = bin - self.asm = bin.asm - self.sass = '' - self.module = module - self.kernel = kernel - self.device = device - self.shared_mem = bin.shared_mem - - def __call__(self, stream, args, grid_0, grid_1=1, grid_2=1): - _triton.runtime.enqueue(self.bin.backend, stream, self.kernel, - grid_0, grid_1, grid_2, - self.bin.num_warps * 32, 1, 1, - args, self.bin.shared_mem) - - def get_sass(self, fun=None): - if self.sass: - return self.sass - fd, path = tempfile.mkstemp() - try: - with open(fd, 'wb') as cubin: - cubin.write(self.asm['cubin']) - self.sass = extract(path, fun) - finally: - os.remove(path) - self.asm['sass'] = self.sass - return self.sass - - -class CompilationError(Exception): - def __init__(self, src, node): - self.message = f'at {node.lineno}:{node.col_offset}:\n' - self.message += '\n'.join(src.split('\n')[:node.lineno]) - self.message += '\n' + ' ' * node.col_offset + '^' - self.src = src - self.node = node - super().__init__(self.message) - - def __reduce__(self): - # this is necessary to make CompilationError picklable - return (type(self), (self.src, self.node)) - - -class OutOfResources(Exception): - def __init__(self, required, limit, name): - self.message = f'out of resource: {name}, '\ - f'Required: {required}, '\ - f'Hardware limit: {limit}' - self.required = required - self.limit = limit - self.name = name - super().__init__(self.message) - - def __reduce__(self): - # this is necessary to make CompilationError picklable - return (type(self), (self.required, self.limit, self.name)) - - -class Kernel: - @staticmethod - def _type_name(obj): - type_names = { - triton.language.float8: 'f8', - torch.bfloat16: 'bf16', - torch.float16: 'f16', - torch.float32: 'f32', - torch.float64: 'f64', - torch.bool: 'i1', - torch.int8: 'i8', - torch.int16: 'i16', - torch.int32: 'i32', - torch.int64: 'i64', - triton.language.uint8: 'u8', - triton.language.uint16: 'u16', - triton.language.uint32: 'u32', - triton.language.uint64: 'u64', - } - if hasattr(obj, 'data_ptr'): - return type_names[obj.dtype] - if isinstance(obj, triton.language.constexpr): - obj = obj.value - if isinstance(obj, int): - if -2**31 <= obj < 2**31: - return 'i32' - elif 2**31 <= obj < 2**32: - return 'u32' - elif -2**63 <= obj < 2**63: - return 'i64' - elif 2**63 <= obj < 2**64: - return 'u64' - else: - raise ValueError(f'integer overflow representing {obj}') - if isinstance(obj, float): - return 'f' - if isinstance(obj, bool): - return 'B' - if isinstance(obj, str): - return 'str' - raise NotImplementedError(f'could not compute type name for {obj}') - - @staticmethod - def _to_python_ir(obj): - # convert torch.Tensor to Triton IR pointers - if hasattr(obj, 'data_ptr'): - name = Kernel._type_name(obj) - return 'ptr', name - # default path returns triton.ir.type directly - name = Kernel._type_name(obj) - return 'scalar', name - 
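For illustration, a minimal self-contained sketch of the ('kind', 'name') convention that the `_type_name`/`_to_python_ir` helpers above produce and `_to_triton_ir` below consumes. `to_python_ir_sketch` is a hypothetical stand-in written for this note, not part of the patch, and only needs `torch` to run; it covers just two representative cases.

import torch

def to_python_ir_sketch(obj):
    """Hypothetical stand-in for Kernel._to_python_ir: tag each kernel argument."""
    # anything exposing .data_ptr() (e.g. a torch.Tensor) is passed as a pointer
    # to its element type, mirroring the dtype table in Kernel._type_name above
    if hasattr(obj, "data_ptr"):
        names = {torch.float16: "f16", torch.float32: "f32", torch.int32: "i32"}
        return "ptr", names[obj.dtype]
    # small Python ints fall into the 32-bit signed bucket, as in _type_name
    if isinstance(obj, int) and -2**31 <= obj < 2**31:
        return "scalar", "i32"
    raise NotImplementedError(f"unsupported argument: {obj!r}")

print(to_python_ir_sketch(torch.empty(4, dtype=torch.float16)))  # ('ptr', 'f16')
print(to_python_ir_sketch(7))                                    # ('scalar', 'i32')

`_to_triton_ir` below then maps these tags back to `triton.language` types, e.g. ('ptr', 'f16') becomes a pointer to `triton.language.float16`.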
- @staticmethod - def _to_triton_ir(obj): - which, name = obj - type_map = { - 'I': triton.language.int32, - 'L': triton.language.int64, - 'f': triton.language.float32, - 'B': triton.language.int1, - 'f8': triton.language.float8, - 'f16': triton.language.float16, - 'bf16': triton.language.bfloat16, - 'f32': triton.language.float32, - 'f64': triton.language.float64, - 'i1': triton.language.int1, - 'i8': triton.language.int8, - 'i16': triton.language.int16, - 'i32': triton.language.int32, - 'i64': triton.language.int64, - 'u8': triton.language.uint8, - 'u16': triton.language.uint16, - 'u32': triton.language.uint32, - 'u64': triton.language.uint64, - } - # convert torch.Tensor to Triton IR pointers - if which == 'ptr': - elt_ty = type_map[name] - return triton.language.pointer_type(elt_ty, 1) - # default path returns triton.ir.type directly - return type_map[name] - - @staticmethod - def pow2_divisor(N): - if N % 16 == 0: - return 16 - if N % 8 == 0: - return 8 - if N % 4 == 0: - return 4 - if N % 2 == 0: - return 2 - return 1 - - def __init__(self, fn): - self.fn = fn - - def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages): - tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] - # attributes - attributes = dict() - for i, arg in enumerate(wargs): - if i in self.fn.do_not_specialize: - continue - if isinstance(arg, int): - attributes[i] = Kernel.pow2_divisor(arg) - elif i in tensor_idxs: - addr = arg.data_ptr() - range_size = _triton.runtime.get_pointer_range_size(addr) - attributes[i] = min(Kernel.pow2_divisor(addr), - Kernel.pow2_divisor(range_size)) - # transforms ints whose value is one into constants for just-in-time compilation - constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 and i not in self.fn.do_not_specialize} - constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)}) - constants.update({i: None for i, arg in enumerate(wargs) if arg is None}) - arg_types = [Kernel._to_python_ir(arg) for i, arg in enumerate(wargs) if i not in constants] - return self.fn._warmup(key, arg_types=arg_types, device=device_idx, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages, is_manual_warmup=False) - - def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs): - # handle arguments passed by name - kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()} - wargs = list(wargs) - for i, pos in enumerate(sorted(kwargs)): - wargs.insert(pos + i, kwargs[pos]) - if len(wargs) != len(self.fn.arg_names): - raise TypeError(f"Function takes {len(self.fn.arg_names)} positional arguments but {len(wargs)} were given") - # handle annotations - for pos, _type in self.fn.annotations.items(): - wargs[pos] = _type(wargs[pos]) - # check that tensors are on GPU. - for arg in wargs: - if hasattr(arg, 'data_ptr'): - assert arg.is_cuda, "All tensors must be on GPU!" - # query device index and cuda stream - device = torch.cuda.current_device() - torch.cuda.set_device(device) - cc = torch.cuda.get_device_capability(device) - cc = str(cc[0]) + '-' + str(cc[1]) - # # query stream - # # this is hacky but much faster than `torch.cuda.current_stream(device).cuda_stream` - # # https://github.com/pytorch/pytorch/blob/master/c10/core/Stream.h#L154 - # # building a C wrapper to re-use the unpack function would add a build-time torch dependency - # # and require different wheels for different torch versions -- undesirable! 
- # bits = torch._C._cuda_getCurrentStream(device) - # mask = 1 << 47 - # stream = ((bits & 0xFFFFFFFFFFFF) ^ mask) - mask - stream = torch.cuda.current_stream(device).cuda_stream - # make key for cache - return _triton.runtime.launch(wargs, self.fn.do_not_specialize, self.fn.cache_key + cc, self.fn.arg_names, device, stream, - self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid) - - -class Launcher: - def __init__(self, kernel, grid): - self.kernel = kernel - self.grid = grid - - def __call__(self, *wargs, **kwargs): - return self.kernel(*wargs, **kwargs, grid=self.grid) - - -class Autotuner: - def __init__(self, kernel, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None): - ''' - :param prune_configs_by: a dict of functions that are used to prune configs, fields: - 'perf_model': performance model used to predicate running time with different configs, returns running time - 'top_k': number of configs to bench - 'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs. - ''' - if not configs: - self.configs = [Config(dict(), num_warps=4, num_stages=2)] - else: - self.configs = configs - self.key_idx = [arg_names.index(k) for k in key] - self.cache = dict() - self.kernel = kernel - # hook to reset all required tensor to zeros before relaunching a kernel - self.hook = lambda args: 0 - if reset_to_zero is not None: - self.reset_idx = [arg_names.index(k) for k in reset_to_zero] - - def _hook(args): - for i in self.reset_idx: - args[i].zero_() - self.hook = _hook - self.arg_names = arg_names - # prune configs - if prune_configs_by: - perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k'] - if 'early_config_prune' in prune_configs_by: - early_config_prune = prune_configs_by['early_config_prune'] - else: - perf_model, top_k, early_config_prune = None, None, None - self.perf_model, self.configs_top_k = perf_model, top_k - self.early_config_prune = early_config_prune - - def _bench(self, *args, config, **meta): - # check for conflicts, i.e. meta-parameters both provided - # as kwargs and by the autotuner - conflicts = meta.keys() & config.kwargs.keys() - if conflicts: - raise ValueError( - f"Conflicting meta-parameters: {', '.join(conflicts)}." - " Make sure that you don't re-define auto-tuned symbols." 
- ) - # augment meta-parameters with tunable ones - current = dict(meta, **config.kwargs) - - def kernel_call(): - if config.pre_hook: - config.pre_hook(self.nargs) - self.hook(args) - self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current) - return triton.testing.do_bench(kernel_call) - - def __call__(self, *args, **kwargs): - self.nargs = dict(zip(self.arg_names, args)) - if len(self.configs) > 1: - key = tuple([args[i] for i in self.key_idx]) - if key not in self.cache: - # prune configs - pruned_configs = self.configs - if self.early_config_prune: - pruned_configs = self.early_config_prune(self.configs, self.nargs) - if self.perf_model: - top_k = self.configs_top_k - if isinstance(top_k, float) and top_k <= 1.0: - top_k = int(len(self.configs) * top_k) - if len(pruned_configs) > top_k: - est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs} - pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] - bench_start = time.time() - timings = {config: self._bench(*args, config=config, **kwargs) - for config in pruned_configs} - bench_end = time.time() - self.bench_time = bench_end - bench_start - self.cache[key] = builtins.min(timings, key=timings.get) - self.hook(args) - self.configs_timings = timings - config = self.cache[key] - else: - config = self.configs[0] - self.best_config = config - if config.pre_hook is not None: - config.pre_hook(self.nargs) - return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) - - -@functools.lru_cache() -def version_key(): - import pkgutil - contents = [] - # frontend - with open(triton.code_gen.__file__, "rb") as f: - contents += [hashlib.md5(f.read()).hexdigest()] - # backend - with open(triton._C.libtriton.__file__, "rb") as f: - contents += [hashlib.md5(f.read()).hexdigest()] - # language - language_path = os.path.join(*triton.__path__, 'language') - for lib in pkgutil.iter_modules([language_path]): - with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: - contents += [hashlib.md5(f.read()).hexdigest()] - # ptxas version - try: - ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest() - except Exception: - ptxas_version = '' - return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents) - - -class DependenciesFinder(ast.NodeVisitor): - - def __init__(self, globals, src) -> None: - super().__init__() - self.ret = hashlib.md5(src.encode("utf-8")).hexdigest() - self.globals = globals - - def visit_Name(self, node): - return self.globals.get(node.id, None) - - def visit_Attribute(self, node): - lhs = self.visit(node.value) - while isinstance(lhs, ast.Attribute): - lhs = self.visit(lhs.value) - if lhs is None or lhs is triton: - return None - return getattr(lhs, node.attr) - - def visit_Call(self, node): - func = self.visit(node.func) - if func is None: - return - if inspect.isbuiltin(func): - return - if func.__module__ and func.__module__.startswith('triton.'): - return - assert isinstance(func, triton.JITFunction) - if func.hash is None: - tree = ast.parse(func.src) - finder = DependenciesFinder(func.__globals__, func.src) - finder.visit(tree) - func.hash = finder.ret - self.ret = (self.ret + func.hash).encode("utf-8") - self.ret = hashlib.md5(self.ret).hexdigest() - - -class JITFunction: - - cache_hook = None - - def __init__(self, fn, version=None, inline=True, 
do_not_specialize=None): - # information of wrapped function - self.fn = fn - self.module = fn.__module__ - signature = inspect.signature(fn) - self.arg_names = [v.name for v in signature.parameters.values()] - self.arg_defaults = [v.default for v in signature.parameters.values()] - - self.version = version - self.inline = inline - self.src = textwrap.dedent(inspect.getsource(fn)) - self.src = self.src[self.src.find("def"):] - self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize - self.do_not_specialize = [self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize] - # cache for callable driver objects (e.g. CUkernel) - self.bin_cache = dict() - self.hash = None - # JITFunction can be instantiated as kernel - # when called with a grid using __getitem__ - self.kernel_decorators = [] - self.kernel = None - # annotations - self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()} - self.__annotations__ = fn.__annotations__ - # constexprs - self.constexprs = [self.arg_names.index(ann) for ann in self.__annotations__.keys()] - # forward docs - self.__doc__ = fn.__doc__ - self.__name__ = fn.__name__ - self.__globals__ = fn.__globals__ - self.__module__ = fn.__module__ - - @property - @functools.lru_cache() - def cache_key(self): - # TODO : hash should be attribute of `self` - if self.hash is None: - dependencies_finder = DependenciesFinder(globals=self.__globals__, src=self.src) - dependencies_finder.visit(self.parse()) - self.hash = dependencies_finder.ret + version_key() - return self.hash - - # we do not parse `src` in the constructor because - # the user might want to monkey-patch self.src dynamically. - # Some unit tests do this, for example. - def parse(self): - tree = ast.parse(self.src) - assert isinstance(tree, ast.Module) - assert len(tree.body) == 1 - assert isinstance(tree.body[0], ast.FunctionDef) - return tree - - def __call__(self, *args, **kwargs): - raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel") - - # - when `.src` attribute is set, cache path needs - # to be reinitialized - # - when kernel decorators change, cached kernel - # needs to be cleared - def __setattr__(self, name, value): - if name == 'kernel_decorators': - self.kernel = None - super(JITFunction, self).__setattr__(name, value) - if name == 'src': - self.hash = None - JITFunction.cache_key.fget.cache_clear() - - def _init_kernel(self): - if self.kernel is None: - self.kernel = Kernel(self) - for decorator in reversed(self.kernel_decorators): - self.kernel = decorator(self.kernel) - return self.kernel - - def warmup(self, compile): - return self._warmup(**compile, is_manual_warmup=True) - - def _warmup(self, key, arg_types, device, attributes, constants, num_warps, num_stages, is_manual_warmup): - hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest() - - # create cache directory - cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/') - if cache_dir: - os.makedirs(cache_dir, exist_ok=True) - - if cache_dir: - bin_cache_path = os.path.join(cache_dir, hashed_key) - bin_lock_path = bin_cache_path + ".lock" - else: - bin_cache_path = None - bin_lock_path = None - - binary = None - if bin_cache_path and os.path.exists(bin_cache_path): - assert bin_lock_path is not None - with FileLock(bin_lock_path): - with open(bin_cache_path, 'rb') as f: - binary = pickle.load(f)["binary"] - - compile = dict(arg_types=arg_types, device=device, attributes=attributes, constants=constants, 
num_warps=num_warps, num_stages=num_stages) - if JITFunction.cache_hook is not None: - name = self.__name__ - info = key.split('-')[-3:] - num_warps, num_stages, sig = info[0], info[1], info[2].split('_')[1:] - # make signature human-readable - arg_reprs = [] - for arg_name, arg_sig in zip(self.arg_names, sig): - arg_reprs.append(f'{arg_name}: {arg_sig}') - # assemble the repr - arg_reprs = ", ".join(arg_reprs) - repr = f"{name}[num_warps={num_warps}, num_stages={num_stages}]({arg_reprs})" - noop = JITFunction.cache_hook(key=key, repr=repr, fn=self, compile={"key": key, **compile}, is_manual_warmup=is_manual_warmup, already_compiled=binary is not None) - if noop: - return True - - if binary is None: - binary = self._compile(**compile) - - if bin_cache_path: - assert bin_lock_path is not None - with FileLock(bin_lock_path): - with open(bin_cache_path + ".tmp", "wb") as f: - pickle.dump({"binary": binary, "key": key}, f) - os.rename(bin_cache_path + ".tmp", bin_cache_path) - - self.bin_cache[key] = LoadedBinary(device, binary) - return False - - def _compile(self, arg_types, device, attributes, constants, num_warps, num_stages): - # create IR module - context = _triton.ir.context() - context.load_triton() - # get just-in-time proto-type of kernel - arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types] - ret_type = triton.language.void - prototype = triton.language.function_type([ret_type], arg_types) - # generate Triton-IR - # export symbols visible from self into code-generator object - gscope = self.__globals__ - generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, is_kernel=True) - try: - generator.visit(self.parse()) - except Exception as e: - node = generator.last_node - if node is None or isinstance(e, (NotImplementedError, CompilationError)): - raise e - raise CompilationError(self.src, node) from e - # Compile to machine code - if torch.version.hip is None: - backend = _triton.runtime.backend.CUDA - else: - backend = _triton.runtime.backend.ROCM - name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages) - max_shared_memory = _triton.runtime.max_shared_memory(backend, device) - if shared_mem > max_shared_memory: - raise OutOfResources(shared_mem, max_shared_memory, "shared memory") - return Binary(backend, name, asm, shared_mem, num_warps) - - # Compile to ttir, for the propose of testing MLIR rewriting - def compile_to_ttir(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs): - # TODO: share code with _compile & __call__ - # handle arguments passed by name - kwargs = {self.arg_names.index(name): value for name, value in kwargs.items()} - wargs = list(wargs) - for i, pos in enumerate(sorted(kwargs)): - wargs.insert(pos + i, kwargs[pos]) - if len(wargs) != len(self.arg_names): - raise TypeError(f"Function takes {len(self.arg_names)} positional arguments but {len(wargs)} were given") - # handle annotations - for pos, _type in self.annotations.items(): - wargs[pos] = _type(wargs[pos]) - # preparing args - tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] - # attributes - attributes = dict() - for i, arg in enumerate(wargs): - if isinstance(arg, int): - attributes[i] = Kernel.pow2_divisor(arg) - elif i in tensor_idxs: - addr = arg.data_ptr() - range_size = _triton.runtime.get_pointer_range_size(addr) - attributes[i] = min(Kernel.pow2_divisor(addr), - Kernel.pow2_divisor(range_size)) - # transforms ints whose value is one into constants for 
just-in-time compilation - constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize} - constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)}) - constants.update({i: None for i, arg in enumerate(wargs) if arg is None}) - arg_types = [Kernel._to_python_ir(arg) for i, arg in enumerate(wargs) if i not in constants] - - # create IR module - context = _triton.ir.context() - context.load_triton() - # get just-in-time proto-type of kernel - arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types] - ret_types = [] - prototype = triton.language.function_type(ret_types, arg_types) - # generate Triton-IR - # export symbols visible from self into code-generator object - gscope = self.__globals__ - generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, is_kernel=True) - try: - generator.visit(self.parse()) - except Exception as e: - node = generator.last_node - if node is None or isinstance(e, (NotImplementedError, CompilationError)): - raise e - raise CompilationError(self.src, node) from e - # cache num_warps & num_stages - self.num_warps, self.num_stages = num_warps, num_stages - # run simple SCCP and DCE here to clean-up the generated IR - mod = generator.module - pm = _triton.ir.pass_manager(context) - pm.add_canonicalizer_pass() - pm.run(mod) - # FIXME: now we need to return context, otherwise it will be deleted - return mod, context - - def compile_ttir_to_llir(self, mod, ctx): - num_warps, num_stages = self.num_warps, self.num_stages - pm = _triton.ir.pass_manager(ctx) - pm.add_inliner_pass() - pm.add_triton_combine_pass() - pm.add_canonicalizer_pass() - pm.add_cse_pass() - pm.add_convert_triton_to_tritongpu_pass(num_warps) - pm.add_tritongpu_pipeline_pass(num_stages) - pm.add_canonicalizer_pass() - pm.add_cse_pass() - pm.add_triton_gpu_combine_pass() - pm.add_triton_gpu_verifier_pass() - return pm.run(mod) - - - def __getitem__(self, grid): - return Launcher(self._init_kernel(), grid) - - def __repr__(self): - return f"JITFunction({self.module}:{self.fn.__name__})" - - -class Config: - """ - An object that represents a possible kernel configuration for the auto-tuner to try. - - :ivar meta: a dictionary of meta-parameters to pass to the kernel as keyword arguments. - :type meta: dict[Str, Any] - :ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if - `num_warps=8`, then each kernel instance will be automatically parallelized to - cooperatively execute using `8 * 32 = 256` threads. - :type num_warps: int - :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops. - Mostly useful for matrix multiplication workloads on SM80+ GPUs. - :type num_stages: int - :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this - function are args. - """ - - def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None): - self.kwargs = kwargs - self.num_warps = num_warps - self.num_stages = num_stages - self.pre_hook = pre_hook - - def __str__(self): - res = [] - for k, v in self.kwargs.items(): - res.append(f'{k}: {v}') - res.append(f'num_warps: {self.num_warps}') - res.append(f'num_stages: {self.num_stages}') - return ', '.join(res) - - -def autotune(configs, key, prune_configs_by=None, reset_to_zero=None): - """ - Decorator for auto-tuning a :code:`triton.jit`'d function. - - .. highlight:: python - .. 
code-block:: python - - @triton.autotune(configs=[ - triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4), - triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8), - ], - key=['x_size'] # the two above configs will be evaluated anytime - # the value of x_size changes - ) - @triton.jit - def kernel(x_ptr, x_size, **META): - BLOCK_SIZE = META['BLOCK_SIZE'] - - :note: When all the configurations are evaluated, the kernel will run multiple time. - This means that whatever value the kernel updates will be updated multiple times. - To avoid this undesired behavior, you can use the `reset_to_zero` argument, which - reset the value of the provided tensor to `zero` before running any configuration. - - :param configs: a list of :code:`triton.Config` objects - :type configs: list[triton.Config] - :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. - :type key: list[str] - :param prune_configs_by: a dict of functions that are used to prune configs, fields: - 'perf_model': performance model used to predicate running time with different configs, returns running time - 'top_k': number of configs to bench - 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs. - :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. - :type reset_to_zero: list[str] - """ - def decorator(fn): - def wrapper(kernel): - return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero, prune_configs_by) - - fn.kernel_decorators.append(wrapper) - return fn - - return decorator - - -def heuristics(values): - """ - Decorator for specifying how the values of certain meta-parameters may be computed. - This is useful for cases where auto-tuning is prohibitevely expensive, or just not applicable. - - .. highlight:: python - .. code-block:: python - - @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))}) - @triton.jit - def kernel(x_ptr, x_size, **META): - BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size - - - .param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter. - each such function takes a list of positional arguments as input. - .type values: dict[str, Callable[[list[Any]], Any]] - """ - def decorator(fn): - def wrapper(kernel): - def fun(*args, **meta): - for v, heur in values.items(): - assert v not in meta - meta[v] = heur({**dict(zip(fn.arg_names, args)), **meta}) - return kernel(*args, **meta) - return fun - - fn.kernel_decorators.append(wrapper) - return fn - - return decorator - - -def jit(*args, **kwargs): - """ - Decorator for JIT-compiling a function using the Triton compiler. - - :note: When a jit'd function is called, :code:`torch.tensor` arguments are implicitly converted to pointers using the :code:`.data_ptr()` method. - - :note: This function will be compiled and run on the GPU. 
It will only have access to: - - * python primitives, - * objects within the triton.language package, - * arguments to this function, - * other jit'd functions - - :param fn: the function to be jit-compiled - :type fn: Callable - """ - if args: - assert len(args) == 1 - assert callable(args[0]) - return JITFunction(args[0], **kwargs) - else: - def decorator(fn): - return JITFunction(fn, **kwargs) - return decorator - - -###### - -def cdiv(x, y): - return (x + y - 1) // y - - -def next_power_of_2(n): - """Return the smallest power of 2 greater than or equal to n""" - n -= 1 - n |= n >> 1 - n |= n >> 2 - n |= n >> 4 - n |= n >> 8 - n |= n >> 16 - n += 1 - return n - -###### - - -class TensorWrapper: - def __init__(self, base, dtype): - self.dtype = dtype - self.base = base - self.is_cuda = base.is_cuda - self.device = base.device - - def data_ptr(self): - return self.base.data_ptr() - - def __str__(self) -> str: - return f'TensorWrapper[{self.dtype}]({self.base})' - - -def reinterpret(tensor, dtype): - if isinstance(tensor, TensorWrapper): - if dtype == tensor.base.dtype: - # Reinterpreting to the original interpretation; return the base. - return tensor.base - else: - # Reinterpreting a wrapped tensor to a different type. - return TensorWrapper(tensor.base, dtype) - elif isinstance(tensor, torch.Tensor): - # A new wrapper is needed around an unwrapped tensor. - return TensorWrapper(tensor, dtype) - else: - raise TypeError(f'Cannot reinterpret a {type(tensor)}.') diff --git a/python/triton/compiler.py b/python/triton/compiler.py new file mode 100644 index 000000000000..6449de4211e4 --- /dev/null +++ b/python/triton/compiler.py @@ -0,0 +1,806 @@ +from __future__ import annotations +import ast +import sys +import warnings +from typing import Dict, Union +import triton +import triton._C.libtriton.triton as _triton + + +def str_to_ty(name): + if name[0] == "*": + ty = str_to_ty(name[1:]) + return triton.language.pointer_type(ty) + tys = { + "fp8": triton.language.float8, + "fp16": triton.language.float16, + "bf16": triton.language.bfloat16, + "fp32": triton.language.float32, + "fp64": triton.language.float64, + "i8": triton.language.int8, + "i16": triton.language.int16, + "i32": triton.language.int32, + "i64": triton.language.int64, + "u8": triton.language.uint8, + "u16": triton.language.uint16, + "u32": triton.language.uint32, + "u64": triton.language.uint64, + "B": triton.language.int1, + } + return tys[name] + +def mangle_ty(ty): + if ty.is_ptr(): + return 'P' + mangle_ty(ty.element_ty) + if ty.is_int(): + return 'i' + str(ty.int_bitwidth) + if ty.is_fp8(): + return 'fp8' + if ty.is_fp16(): + return 'fp16' + if ty.is_bf16(): + return 'bf16' + if ty.is_fp32(): + return 'fp32' + if ty.is_fp64(): + return 'fp64' + if ty.is_void(): + return 'V' + if ty.is_block(): + elt = mangle_ty(ty.scalar) + shape = '_'.join(map(str, ty.shape)) + return f'{elt}S{shape}S' + assert False, "Unsupport type" + + +def mangle_fn(name, arg_tys, constants): + # doesn't mangle ret type, which must be a function of arg tys + mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys]) + mangled_constants = '_'.join([f'{i}c{repr(constants[i])}' for i in sorted(constants)]) + mangled_constants = mangled_constants.replace('.', '_d_') + mangled_constants = mangled_constants.replace("'", '_sq_') + ret = f'{name}__{mangled_arg_names}__{mangled_constants}' + return ret + +class enter_sub_region: + def __init__(self, generator: CodeGenerator): + self.generator = generator + + def __enter__(self): + # record lscope & 
local_defs in the parent scope + self.liveins = self.generator.lscope.copy() + self.prev_defs = self.generator.local_defs.copy() + self.generator.local_defs = {} + self.insert_block = self.generator.builder.get_insertion_block() + return self.liveins, self.insert_block + + def __exit__(self, *args, **kwargs): + self.generator.builder.set_insertion_point_to_end(self.insert_block) + self.generator.lscope = self.liveins + self.generator.local_defs = self.prev_defs + +class CodeGenerator(ast.NodeVisitor): + def __init__(self, context, prototype, gscope, attributes, constants, module=None, is_kernel=False, function_types=dict()): + self.builder = _triton.ir.builder(context) + self.module = self.builder.create_module() if module is None else module + self.function_ret_types = function_types + self.prototype = prototype + self.gscope = gscope + self.lscope = dict() + self.attributes = attributes + self.constants = constants + self.is_kernel = is_kernel + self.last_node = None + self.builtins = { + 'range': range, + 'min': triton.language.minimum, + 'float': float, + 'int': int, + 'print': print, + 'isinstance': isinstance, + 'getattr': getattr, + } + # SSA-construction + # name => triton.language.tensor + self.local_defs: Dict[str, triton.language.tensor] = {} + self.global_uses: Dict[str, triton.language.tensor] = {} + + def get_value(self, name): + ''' This function: + 1. make sure `name` is defined + 2. if `name` is triton.language.tensor, get stored tensor by calling + `self._get_tensor()` + ''' + # search node.id in local scope + ret = None + if name in self.lscope: + ret = self.lscope[name] + if name not in self.local_defs: + self.global_uses[name] = ret + # search node.id in global scope + elif name in self.gscope: + ret = self.gscope[name] + # search node.id in builtins + elif name in self.builtins: + ret = self.builtins[name] + else: + raise ValueError(f'{name} is not defined') + return ret + + def set_value(self, name: str, + value: Union[triton.language.tensor, triton.language.constexpr]) -> None: + ''' This function: + called by visit_Assign() & visit_FuncDef() to store left value (lvalue) + 1. record local defined name (FIXME: should consider control flow) + 2. 
store tensor in self.lvalue + ''' + self.lscope[name] = value + self.local_defs[name] = value + + def is_triton_tensor(self, value): + return isinstance(value, triton.language.tensor) + + # + # AST visitor + # + def visit_compound_statement(self, stmts): + for stmt in stmts: + self.last_ret_type = self.visit(stmt) + if isinstance(stmt, ast.Return): + break + return stmts and isinstance(stmt, ast.Return) + + def visit_Module(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_List(self, node): + ctx = self.visit(node.ctx) + assert ctx is None + elts = [self.visit(elt) for elt in node.elts] + return elts + + # By design, only non-kernel functions can return + def visit_Return(self, node): + ret_value = self.visit(node.value) + if ret_value is None: + self.builder.ret([]) + return None + if isinstance(ret_value, tuple): + ret_values = [triton.language.core._to_tensor(v, self.builder) for v in ret_value] + ret_types = [v.type for v in ret_values] + self.builder.ret([v.handle for v in ret_values]) + return tuple(ret_types) + else: + ret = triton.language.core._to_tensor(ret_value, self.builder) + self.builder.ret([ret_value.handle]) + return ret.type + + def visit_FunctionDef(self, node): + arg_names, kwarg_names = self.visit(node.args) + # initialize defaults + for i, default_value in enumerate(node.args.defaults): + arg_node = node.args.args[-i - 1] + annotation = arg_node.annotation + name = arg_node.arg + st_target = ast.Name(id=name, ctx=ast.Store()) + if annotation is None: + init_node = ast.Assign(targets=[st_target], value=default_value) + else: + init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation) + self.visit(init_node) + # initialize function + fn_name = mangle_fn(node.name, self.prototype.param_types, self.constants) + fn = self.builder.get_or_insert_function(self.module, fn_name, self.prototype.to_ir(self.builder)) + self.module.push_back(fn) + entry = fn.add_entry_block() + arg_values = [] + idx = 0 + for i, arg_name in enumerate(arg_names): + if i in self.constants: + cst = self.constants[i] + if not isinstance(cst, triton.language.constexpr): + cst = triton.language.constexpr(self.constants[i]) + arg_values.append(cst) + else: + pass + if i in self.attributes: + fn.set_arg_attr(idx, "tt.divisibility", self.attributes[i]) + arg_values.append(triton.language.tensor(fn.args(idx), self.prototype.param_types[idx])) + idx += 1 + + insert_pt = self.builder.get_insertion_block() + for arg_name, arg_value in zip(arg_names, arg_values): + self.set_value(arg_name, arg_value) + self.builder.set_insertion_point_to_start(entry) + # visit function body + has_ret = self.visit_compound_statement(node.body) + # finalize function + if not has_ret: + self.builder.ret([]) + else: + # update return type + if isinstance(self.last_ret_type, tuple): + self.prototype.ret_types = list(self.last_ret_type) + fn.reset_type(self.prototype.to_ir(self.builder)) + else: + self.prototype.ret_types = [self.last_ret_type] + fn.reset_type(self.prototype.to_ir(self.builder)) + if insert_pt: + self.builder.set_insertion_point_to_end(insert_pt) + + def visit_arguments(self, node): + arg_names = [] + for arg in node.args: + arg_names += [self.visit(arg)] + kwarg_names = self.visit(node.kwarg) + return arg_names, kwarg_names + + def visit_arg(self, node): + ast.NodeVisitor.generic_visit(self, node) + return node.arg + + def visit_AnnAssign(self, node): + # extract attributes + annotation = self.visit(node.annotation) + target = self.visit(node.target) + value = 
self.visit(node.value) + # constexpr + if annotation == triton.language.constexpr: + if target in self.lscope: + raise ValueError(f'{target} is already defined.' + f' constexpr cannot be reassigned.') + if not isinstance(value, triton.language.constexpr): + value = triton.language.constexpr(value) + self.lscope[target] = value + return self.lscope[target] + # default: call visit_Assign + return self.visit_Assign(node) + + def visit_Assign(self, node): + _names = [] + for target in node.targets: + _names += [self.visit(target)] + assert len(_names) == 1 + names = _names[0] + values = self.visit(node.value) + if not isinstance(names, tuple): + names = [names] + if not isinstance(values, tuple): + values = [values] + for name, value in zip(names, values): + # by default, constexpr are assigned into python variable + if isinstance(value, triton.language.constexpr): + value = value.value + if not isinstance(value, triton.language.tensor): + value = triton.language.core._to_tensor(value, self.builder) + self.set_value(name, value) + + def visit_AugAssign(self, node): + name = node.target.id + lhs = ast.Name(id=name, ctx=ast.Load()) + rhs = ast.BinOp(lhs, node.op, node.value) + assign = ast.Assign(targets=[node.target], value=rhs) + self.visit(assign) + return self.get_value(name) + + def visit_Name(self, node): + if type(node.ctx) == ast.Store: + return node.id + return self.get_value(node.id) + + def visit_Store(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_Load(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_Tuple(self, node): + args = [self.visit(x) for x in node.elts] + return tuple(args) + + def visit_BinOp(self, node): + lhs = self.visit(node.left) + rhs = self.visit(node.right) + if isinstance(lhs, triton.language.constexpr): + lhs = lhs.value + if isinstance(rhs, triton.language.constexpr): + rhs = rhs.value + fn = { + ast.Add: '__add__', + ast.Sub: '__sub__', + ast.Mult: '__mul__', + ast.Div: '__truediv__', + ast.FloorDiv: '__floordiv__', + ast.Mod: '__mod__', + ast.Pow: '__pow__', + ast.LShift: '__lshift__', + ast.RShift: '__rshift__', + ast.BitAnd: '__and__', + ast.BitOr: '__or__', + ast.BitXor: '__xor__', + }[type(node.op)] + if self.is_triton_tensor(lhs): + return getattr(lhs, fn)(rhs, _builder=self.builder) + elif self.is_triton_tensor(rhs): + fn = fn[:2] + 'r' + fn[2:] + return getattr(rhs, fn)(lhs, _builder=self.builder) + else: + return getattr(lhs, fn)(rhs) + + def visit_If(self, node): + cond = self.visit(node.test) + if isinstance(cond, triton.language.tensor): + cond = cond.to(triton.language.int1, _builder=self.builder) + with enter_sub_region(self) as sr: + liveins, ip_block = sr + + then_block = self.builder.create_block() + self.builder.set_insertion_point_to_start(then_block) + self.visit_compound_statement(node.body) + then_defs = self.local_defs.copy() + + # when need an else block when: + # 1. we have an orelse node + # or + # 2. 
the then block defines new variable + if then_defs or node.orelse: + if node.orelse: + self.lscope = liveins + self.local_defs = {} + else_block = self.builder.create_block() + self.builder.set_insertion_point_to_end(else_block) + self.visit_compound_statement(node.orelse) + else_defs = self.local_defs.copy() + else: + # collect else_defs + else_defs = {} + for name in then_defs: + if name in liveins: + assert self.is_triton_tensor(then_defs[name]) + assert self.is_triton_tensor(liveins[name]) + else_defs[name] = liveins[name] + # collect yields + names = [] + ret_types = [] + for then_name in then_defs: + for else_name in else_defs: + if then_name == else_name: + if then_defs[then_name].type == else_defs[else_name].type: + names.append(then_name) + ret_types.append(then_defs[then_name].type) + + self.builder.set_insertion_point_to_end(ip_block) + + if then_defs or node.orelse: # with else block + if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True) + then_block.merge_block_before(if_op.get_then_block()) + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + self.builder.create_yield_op([then_defs[n].handle for n in names]) + if not node.orelse: + else_block = if_op.get_else_block() + else: + else_block.merge_block_before(if_op.get_else_block()) + self.builder.set_insertion_point_to_end(if_op.get_else_block()) + self.builder.create_yield_op([else_defs[n].handle for n in names]) + else: # no else block + if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False) + then_block.merge_block_before(if_op.get_then_block()) + + # update values yielded by IfOp + for i, name in enumerate(names): + new_tensor = triton.language.core.tensor(if_op.get_result(i), ret_types[i]) + self.lscope[name] = new_tensor + self.local_defs[name] = new_tensor + + else: + if isinstance(cond, triton.language.constexpr): + cond = cond.value + if cond: + self.visit_compound_statement(node.body) + else: + self.visit_compound_statement(node.orelse) + + def visit_IfExp(self, node): + cond = self.visit(node.test) + if cond.value: + return self.visit(node.body) + else: + return self.visit(node.orelse) + + def visit_Pass(self, node): + pass + + def visit_Compare(self, node): + assert len(node.comparators) == 1 + assert len(node.ops) == 1 + lhs = self.visit(node.left) + rhs = self.visit(node.comparators[0]) + if isinstance(lhs, triton.language.constexpr): + lhs = lhs.value + if isinstance(rhs, triton.language.constexpr): + rhs = rhs.value + if type(node.ops[0]) == ast.Is: + return triton.language.constexpr(lhs is rhs) + if type(node.ops[0]) == ast.IsNot: + return triton.language.constexpr(lhs is not rhs) + fn = { + ast.Eq: '__eq__', + ast.NotEq: '__ne__', + ast.Lt: '__lt__', + ast.LtE: '__le__', + ast.Gt: '__gt__', + ast.GtE: '__ge__', + }[type(node.ops[0])] + if self.is_triton_tensor(lhs): + return getattr(lhs, fn)(rhs, _builder=self.builder) + elif self.is_triton_tensor(rhs): + fn = fn[:2] + 'r' + fn[2:] + return getattr(rhs, fn)(lhs, _builder=self.builder) + else: + return getattr(lhs, fn)(rhs) + + def visit_UnaryOp(self, node): + op = self.visit(node.operand) + if type(node.op) == ast.Not: + assert isinstance(op, triton.language.constexpr), "`not` only supported for constexpr at the moment" + return triton.language.constexpr(not op) + if isinstance(op, triton.language.constexpr): + op = op.value + fn = { + ast.USub: '__neg__', + ast.UAdd: '__pos__', + ast.Invert: '__invert__', + }[type(node.op)] + if self.is_triton_tensor(op): 
+ return getattr(op, fn)(_builder=self.builder) + return getattr(op, fn)() + + def visit_While(self, node): + with enter_sub_region(self) as sr: + liveins, insert_block = sr + + # condtion (the before region) + cond_block = self.builder.create_block() + self.builder.set_insertion_point_to_start(cond_block) + cond = self.visit(node.test) + + # loop body (the after region) + loop_block = self.builder.create_block() + self.builder.set_insertion_point_to_start(loop_block) + self.visit_compound_statement(node.body) + loop_defs = self.local_defs + + # collect loop-carried values + names = [] + ret_types = [] + init_args = [] + yields = [] + for name in loop_defs: + if name in liveins: + # We should not def new constexpr + assert self.is_triton_tensor(loop_defs[name]) + assert self.is_triton_tensor(liveins[name]) + if loop_defs[name].type == liveins[name].type: + # these are loop-carried values + names.append(name) + ret_types.append(loop_defs[name].type) + init_args.append(liveins[name]) + yields.append(loop_defs[name]) + + self.builder.set_insertion_point_to_end(insert_block) + while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types], + [arg.handle for arg in init_args]) + # merge the condition region + before_block = self.builder.create_block_with_parent(while_op.get_before(), + [ty.to_ir(self.builder) for ty in ret_types]) + cond_block.merge_block_before(before_block) + self.builder.set_insertion_point_to_end(before_block) + # create CondtionOp: e.g., scf.condition(%cond) %arg0, %arg1, ... + self.builder.create_condtion_op(cond.handle, [before_block.arg(i) for i in range(len(init_args))]) + # merge the loop body + after_block = self.builder.create_block_with_parent(while_op.get_after(), + [ty.to_ir(self.builder) for ty in ret_types]) + loop_block.merge_block_before(after_block) + self.builder.set_insertion_point_to_end(after_block) + self.builder.create_yield_op([y.handle for y in yields]) + + # update global uses in while_op + for i, name in enumerate(names): + before_block.replace_use_in_block_with(init_args[i].handle, before_block.arg(i)) + after_block.replace_use_in_block_with(init_args[i].handle, after_block.arg(i)) + + # WhileOp defines new values, update the symbol table (lscope, local_defs) + for i, name in enumerate(names): + new_def = triton.language.core.tensor(while_op.get_result(i), ret_types[i]) + self.lscope[name] = new_def + self.local_defs[name] = new_def + + for stmt in node.orelse: + assert False, "Not implemented" + ast.NodeVisitor.generic_visit(self, stmt) + + def visit_Subscript(self, node): + assert node.ctx.__class__.__name__ == "Load" + lhs = self.visit(node.value) + slices = self.visit(node.slice) + if self.is_triton_tensor(lhs): + return lhs.__getitem__(slices, _builder=self.builder) + return lhs[slices] + + def visit_ExtSlice(self, node): + return [self.visit(dim) for dim in node.dims] + + def visit_For(self, node): + iterator = self.visit(node.iter.func) + if iterator != self.builtins['range']: + raise RuntimeError('Only `range` iterator currently supported') + # static for loops: all iterator arguments are constexpr + iter_args = [self.visit(arg) for arg in node.iter.args] + is_static = all([isinstance(x, triton.language.constexpr) for x in iter_args]) + if is_static: + st_target = ast.Name(id=node.target.id, ctx=ast.Store()) + iter_args = [arg.value for arg in iter_args] + range = iterator(*iter_args) + if len(range) <= 10: + for i in iterator(*iter_args): + self.lscope[node.target.id] = triton.language.constexpr(i) + 
self.visit_compound_statement(node.body) + for stmt in node.orelse: + ast.NodeVisitor.generic_visit(self, stmt) + return + + # collect lower bound (lb), upper bound (ub), and step + lb = self.visit(node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0)) + ub = self.visit(node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0]) + step = self.visit(node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1)) + # lb/ub/step might be constexpr, we need to cast them to tensor + lb = triton.language.core._to_tensor(lb, self.builder).handle + ub = triton.language.core._to_tensor(ub, self.builder).handle + step = triton.language.core._to_tensor(step, self.builder).handle + # ForOp can only accept IndexType as lb/ub/step. Cast integer to Index + lb = self.builder.create_to_index(lb) + ub = self.builder.create_to_index(ub) + step = self.builder.create_to_index(step) + + with enter_sub_region(self) as sr: + liveins, insert_block = sr + + # create loop body block + block = self.builder.create_block() + self.builder.set_insertion_point_to_start(block) + + # visit loop body + self.visit_compound_statement(node.body) + + # If a variable (name) is defined in both its parent & itself, then it's + # a loop-carried variable. (They must be of the same type) + init_args = [] + yields = [] + names = [] + for name in self.local_defs: + if name in liveins: + assert self.is_triton_tensor(self.local_defs[name]), f'{name} is not tensor' + assert self.is_triton_tensor(liveins[name]) + if self.local_defs[name].type == liveins[name].type: + names.append(name) + init_args.append(triton.language.core._to_tensor(liveins[name], self.builder)) + yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder)) + # create ForOp + self.builder.set_insertion_point_to_end(insert_block) + for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args]) + block.merge_block_before(for_op.get_body(0)) + # create YieldOp + self.builder.set_insertion_point_to_end(for_op.get_body(0)) + self.builder.create_yield_op([y.handle for y in yields]) + for_op_region = for_op.get_body(0).get_parent() + assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block" + # replace global uses with block arguments + for i, name in enumerate(names): + # arg0 is the induction variable + for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i+1)) + + # update lscope & local_defs (ForOp defines new values) + for i, name in enumerate(names): + self.lscope[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type) + self.local_defs[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type) + + for stmt in node.orelse: + assert False, "Don't know what to do with else after for" + ast.NodeVisitor.generic_visit(self, stmt) + + def visit_Slice(self, node): + lower = self.visit(node.lower) + upper = self.visit(node.upper) + step = self.visit(node.step) + return slice(lower, upper, step) + + def visit_Index(self, node): + return self.visit(node.value) + + def visit_keyword(self, node): + return {node.arg: self.visit(node.value)} + + def visit_Call(self, node): + fn = self.visit(node.func) + if isinstance(fn, triton.language.constexpr): + fn = fn.value + kws = dict() + for keyword in node.keywords: + kws.update(self.visit(keyword)) + args = [self.visit(arg) for arg in node.args] + if isinstance(fn, triton.runtime.JITFunction): + from inspect import getcallargs + args = getcallargs(fn.fn, *args, **kws) + args = 
[args[name] for name in fn.arg_names] + args = [arg if isinstance(arg, triton.language.tensor) + else triton.language.constexpr(arg) for arg in args] + # generate function def + attributes = dict() + constexprs = [i for i, arg in enumerate(args) if isinstance(arg, triton.language.constexpr)] + constants = {i: args[i] for i in constexprs} + # generate call + args = [None if i in constexprs else arg for i, arg in enumerate(args)] + arg_vals = [arg.handle for arg in args if arg is not None] + arg_types = [arg.type for arg in args if arg is not None] + fn_name = mangle_fn(fn.__name__, arg_types, constants) + # generate function def if necessary + if not self.module.has_function(fn_name): + ret_type = triton.language.void + prototype = triton.language.function_type([ret_type], arg_types) + gscope = sys.modules[fn.fn.__module__].__dict__ + generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_types=self.function_ret_types) + generator.visit(fn.parse()) + callee_ret_type = generator.last_ret_type + self.function_ret_types[fn_name] = callee_ret_type + else: + callee_ret_type = self.function_ret_types[fn_name] + symbol = self.module.get_function(fn_name) + call_op = self.builder.call(symbol, arg_vals) + if call_op.get_num_results() == 0: + return None + elif call_op.get_num_results() == 1: + return triton.language.tensor(call_op.get_result(0), callee_ret_type) + else: + # should return a tuple of tl.tensor + results = [] + for i in range(call_op.get_num_results()): + results.append(triton.language.tensor(call_op.get_result(i), callee_ret_type[i])) + return tuple(results) + if hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__) or \ + sys.modules[fn.__module__] is triton.language.core: + return fn(*args, _builder=self.builder, **kws) + if fn in self.builtins.values(): + args = [arg.value if isinstance(arg, triton.language.constexpr) else arg + for arg in args] + return fn(*args, **kws) + + def visit_Constant(self, node): + return triton.language.constexpr(node.value) + + if sys.version_info < (3, 8): + def visit_NameConstant(self, node): + return triton.language.constexpr(node.value) + + def visit_Num(self, node): + return triton.language.constexpr(node.n) + + def visit_Str(self, node): + return triton.language.constexpr(ast.literal_eval(node)) + + def visit_Attribute(self, node): + lhs = self.visit(node.value) + return getattr(lhs, node.attr) + + def visit_Expr(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_NoneType(self, node): + return None + + def visit(self, node): + if node is not None: + self.last_node = node + with warnings.catch_warnings(): + # The ast library added visit_Constant and deprecated some other + # methods but we can't move to that without breaking Python 3.6 and 3.7. 
+ warnings.simplefilter("ignore", DeprecationWarning) # python 3.9 + warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8 + return super().visit(node) + + def generic_visit(self, node): + typename = type(node).__name__ + raise NotImplementedError("Unsupported node: {}".format(typename)) + + + +class CompilationError(Exception): + def __init__(self, src, node): + self.message = f'at {node.lineno}:{node.col_offset}:\n' + self.message += '\n'.join(src.split('\n')[:node.lineno]) + self.message += '\n' + ' ' * node.col_offset + '^' + self.src = src + self.node = node + super().__init__(self.message) + + def __reduce__(self): + # this is necessary to make CompilationError picklable + return (type(self), (self.src, self.node)) + + +class OutOfResources(Exception): + def __init__(self, required, limit, name): + self.message = f'out of resource: {name}, '\ + f'Required: {required}, '\ + f'Hardware limit: {limit}' + self.required = required + self.limit = limit + self.name = name + super().__init__(self.message) + + def __reduce__(self): + # this is necessary to make CompilationError picklable + return (type(self), (self.required, self.limit, self.name)) + + +def make_triton_ir(fn, signature, constants = dict(), attributes = dict()): + context = _triton.ir.context() + context.load_triton() + # create kernel prototype + arg_types = signature.replace(' ','').split(',') + constants = {fn.arg_names.index(name): value for name, value in constants.items()} + arg_types = [str_to_ty(x) for x in arg_types] + prototype = triton.language.function_type([], arg_types) + # visit kernel AST + gscope = fn.__globals__.copy() + generator = CodeGenerator(context, prototype, gscope=gscope, constants=constants, attributes=attributes, is_kernel=True) + try: + generator.visit(fn.parse()) + except Exception as e: + node = generator.last_node + if node is None or isinstance(e, (NotImplementedError, CompilationError)): + raise e + raise CompilationError(fn.src, node) from e + ret = generator.module + # module takes ownership of the MLIR context + ret.context = context + return ret + +def make_tritongpu_ir(mod, num_warps): + pm = _triton.ir.pass_manager(mod.context) + pm.add_inliner_pass() + pm.add_triton_combine_pass() + pm.add_canonicalizer_pass() + pm.add_cse_pass() + pm.add_convert_triton_to_tritongpu_pass(num_warps) + pm.run(mod) + return mod + +def optimize_tritongpu_ir(mod, num_stages): + pm = _triton.ir.pass_manager(mod.context) + pm.add_tritongpu_pipeline_pass(num_stages) + pm.add_canonicalizer_pass() + pm.add_cse_pass() + pm.add_triton_gpu_combine_pass() + pm.add_triton_gpu_verifier_pass() + pm.run(mod) + return mod + +def make_ptx(mod): + # TODO + return mod + +def compile(fn, signature, constants = dict(), attributes = dict(), num_warps=4, num_stages=3, output = "ttgir"): + assert output in ["ttir", "ttgir", "ptx"] + # triton-ir + module = make_triton_ir(fn, signature, constants, attributes) + if output == "ttir": + return module.str() + # tritongpu-ir + module = make_tritongpu_ir(module, num_warps) + module = optimize_tritongpu_ir(module, num_stages) + if output == "ttgir": + return module.str() + # ptx + if output == "ptx": + return make_ptx(module) + assert False \ No newline at end of file diff --git a/python/triton/runtime/__init__.py b/python/triton/runtime/__init__.py new file mode 100644 index 000000000000..47b5d5a1ece3 --- /dev/null +++ b/python/triton/runtime/__init__.py @@ -0,0 +1,2 @@ +from .jit import JITFunction, jit +from .autotuner import Config, autotune, heuristics \ No newline 
at end of file diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py new file mode 100644 index 000000000000..63bf2bbf76e6 --- /dev/null +++ b/python/triton/runtime/autotuner.py @@ -0,0 +1,202 @@ +from __future__ import annotations +import builtins +import time +from typing import Dict + + + +class Autotuner: + def __init__(self, kernel, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None): + ''' + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs. + ''' + if not configs: + self.configs = [Config(dict(), num_warps=4, num_stages=2)] + else: + self.configs = configs + self.key_idx = [arg_names.index(k) for k in key] + self.cache = dict() + self.kernel = kernel + # hook to reset all required tensor to zeros before relaunching a kernel + self.hook = lambda args: 0 + if reset_to_zero is not None: + self.reset_idx = [arg_names.index(k) for k in reset_to_zero] + + def _hook(args): + for i in self.reset_idx: + args[i].zero_() + self.hook = _hook + self.arg_names = arg_names + # prune configs + if prune_configs_by: + perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k'] + if 'early_config_prune' in prune_configs_by: + early_config_prune = prune_configs_by['early_config_prune'] + else: + perf_model, top_k, early_config_prune = None, None, None + self.perf_model, self.configs_top_k = perf_model, top_k + self.early_config_prune = early_config_prune + + def _bench(self, *args, config, **meta): + # check for conflicts, i.e. meta-parameters both provided + # as kwargs and by the autotuner + conflicts = meta.keys() & config.kwargs.keys() + if conflicts: + raise ValueError( + f"Conflicting meta-parameters: {', '.join(conflicts)}." + " Make sure that you don't re-define auto-tuned symbols." 
+            )
+        # augment meta-parameters with tunable ones
+        current = dict(meta, **config.kwargs)
+
+        def kernel_call():
+            if config.pre_hook:
+                config.pre_hook(self.nargs)
+            self.hook(args)
+            self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
+        return triton.testing.do_bench(kernel_call)
+
+    def __call__(self, *args, **kwargs):
+        self.nargs = dict(zip(self.arg_names, args))
+        if len(self.configs) > 1:
+            key = tuple([args[i] for i in self.key_idx])
+            if key not in self.cache:
+                # prune configs
+                pruned_configs = self.configs
+                if self.early_config_prune:
+                    pruned_configs = self.early_config_prune(self.configs, self.nargs)
+                if self.perf_model:
+                    top_k = self.configs_top_k
+                    if isinstance(top_k, float) and top_k <= 1.0:
+                        top_k = int(len(self.configs) * top_k)
+                    if len(pruned_configs) > top_k:
+                        est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs}
+                        pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
+                bench_start = time.time()
+                timings = {config: self._bench(*args, config=config, **kwargs)
+                           for config in pruned_configs}
+                bench_end = time.time()
+                self.bench_time = bench_end - bench_start
+                self.cache[key] = builtins.min(timings, key=timings.get)
+                self.hook(args)
+                self.configs_timings = timings
+            config = self.cache[key]
+        else:
+            config = self.configs[0]
+        self.best_config = config
+        if config.pre_hook is not None:
+            config.pre_hook(self.nargs)
+        return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
+
+
+class Config:
+    """
+    An object that represents a possible kernel configuration for the auto-tuner to try.
+
+    :ivar kwargs: a dictionary of meta-parameters to pass to the kernel as keyword arguments.
+    :type kwargs: dict[str, Any]
+    :ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if
+                     `num_warps=8`, then each kernel instance will be automatically parallelized to
+                     cooperatively execute using `8 * 32 = 256` threads.
+    :type num_warps: int
+    :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops.
+                      Mostly useful for matrix multiplication workloads on SM80+ GPUs.
+    :type num_stages: int
+    :ivar pre_hook: a function that will be called before the kernel is called. It takes a dict of the
+                    kernel's arguments as its only parameter.
+    """
+
+    def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None):
+        self.kwargs = kwargs
+        self.num_warps = num_warps
+        self.num_stages = num_stages
+        self.pre_hook = pre_hook
+
+    def __str__(self):
+        res = []
+        for k, v in self.kwargs.items():
+            res.append(f'{k}: {v}')
+        res.append(f'num_warps: {self.num_warps}')
+        res.append(f'num_stages: {self.num_stages}')
+        return ', '.join(res)
+
+
+def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
+    """
+    Decorator for auto-tuning a :code:`triton.jit`'d function.
+
+    .. highlight:: python
+    .. code-block:: python
+
+        @triton.autotune(configs=[
+            triton.Config({'BLOCK_SIZE': 128}, num_warps=4),
+            triton.Config({'BLOCK_SIZE': 1024}, num_warps=8),
+          ],
+          key=['x_size']  # the two above configs will be evaluated anytime
+                          # the value of x_size changes
+        )
+        @triton.jit
+        def kernel(x_ptr, x_size, **META):
+            BLOCK_SIZE = META['BLOCK_SIZE']
+
+    :note: When all the configurations are evaluated, the kernel will run multiple times.
+           This means that whatever value the kernel updates will be updated multiple times.
+           To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
+           resets the value of the provided tensors to zero before running any configuration.
+
+    :param configs: a list of :code:`triton.Config` objects
+    :type configs: list[triton.Config]
+    :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
+    :type key: list[str]
+    :param prune_configs_by: a dict of functions that are used to prune configs, fields:
+        'perf_model': performance model used to predict running time with different configs; returns the running time
+        'top_k': number of configs to bench
+        'early_config_prune' (optional): a function used to prune configs early (e.g. based on num_stages). It takes configs: List[Config] as its input and returns the pruned configs.
+    :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
+    :type reset_to_zero: list[str]
+    """
+    def decorator(fn):
+        def wrapper(kernel):
+            return Autotuner(kernel, fn.arg_names, configs, key, reset_to_zero, prune_configs_by)
+
+        fn.kernel_decorators.append(wrapper)
+        return fn
+
+    return decorator
+
+
+def heuristics(values):
+    """
+    Decorator for specifying how the values of certain meta-parameters may be computed.
+    This is useful for cases where auto-tuning is prohibitively expensive, or just not applicable.
+
+    .. highlight:: python
+    .. code-block:: python
+
+        @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args['x_size'])))})
+        @triton.jit
+        def kernel(x_ptr, x_size, **META):
+            BLOCK_SIZE = META['BLOCK_SIZE']  # smallest power-of-two >= x_size
+
+
+    :param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter.
+                   each such function takes a dict mapping the kernel's argument names (and any already-computed meta-parameters) to their values as its input.
+ .type values: dict[str, Callable[[list[Any]], Any]] + """ + def decorator(fn): + def wrapper(kernel): + def fun(*args, **meta): + for v, heur in values.items(): + assert v not in meta + meta[v] = heur({**dict(zip(fn.arg_names, args)), **meta}) + return kernel(*args, **meta) + return fun + + fn.kernel_decorators.append(wrapper) + return fn + + return decorator + diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py new file mode 100644 index 000000000000..758b053413d8 --- /dev/null +++ b/python/triton/runtime/jit.py @@ -0,0 +1,268 @@ +from __future__ import annotations + +import ast +import functools +import hashlib +import inspect +import os +import subprocess +import tempfile +import textwrap +import triton +import triton._C.libtriton.triton as _triton +from ..tools.disasm import extract + +# ----------------------------------------------------------------------------- +# Binary +# ----------------------------------------------------------------------------- + +class Binary: + def __init__(self, backend, name, asm, shared_mem, num_warps): + self.backend = backend + self.name = name + self.asm = asm + self.shared_mem = shared_mem + self.num_warps = num_warps + + +class LoadedBinary: + def __init__(self, device: int, bin: Binary): + module, kernel = _triton.code_gen.load_binary(bin.backend, + bin.name, + bin.asm, + bin.shared_mem, + device) + self.bin = bin + self.asm = bin.asm + self.sass = '' + self.module = module + self.kernel = kernel + self.device = device + self.shared_mem = bin.shared_mem + + def __call__(self, stream, args, grid_0, grid_1=1, grid_2=1): + _triton.runtime.enqueue(self.bin.backend, stream, self.kernel, + grid_0, grid_1, grid_2, + self.bin.num_warps * 32, 1, 1, + args, self.bin.shared_mem) + + def get_sass(self, fun=None): + if self.sass: + return self.sass + fd, path = tempfile.mkstemp() + try: + with open(fd, 'wb') as cubin: + cubin.write(self.asm['cubin']) + self.sass = extract(path, fun) + finally: + os.remove(path) + self.asm['sass'] = self.sass + return self.sass + +# ----------------------------------------------------------------------------- +# Kernel +# ----------------------------------------------------------------------------- + +class Kernel: + + def __call__(self, *args, grid, num_warps=4, num_stages=3, **kwargs): + raise RuntimeError("Not implemented. Public repo implementation will be rewritten to reduce latency.") + + + +# ----------------------------------------------------------------------------- +# Dependencies Finder +# ----------------------------------------------------------------------------- + +class DependenciesFinder(ast.NodeVisitor): + """ + This AST visitor is used to find dependencies of a JITFunction. This can + be used to invalidate a JITFunction's hash when its source code -- or + that of its dependencies -- changes. 
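+    The resulting hash is an MD5 digest of the function's own source code, folded together
+    with the hashes of every other JITFunction it calls.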
+ """ + + def __init__(self, globals, src) -> None: + super().__init__() + self.ret = hashlib.md5(src.encode("utf-8")).hexdigest() + self.globals = globals + + def visit_Name(self, node): + return self.globals.get(node.id, None) + + def visit_Attribute(self, node): + lhs = self.visit(node.value) + while isinstance(lhs, ast.Attribute): + lhs = self.visit(lhs.value) + if lhs is None or lhs is triton: + return None + return getattr(lhs, node.attr) + + def visit_Call(self, node): + func = self.visit(node.func) + if func is None: + return + if inspect.isbuiltin(func): + return + if func.__module__ and func.__module__.startswith('triton.'): + return + assert isinstance(func, JITFunction) + if func.hash is None: + tree = ast.parse(func.src) + finder = DependenciesFinder(func.__globals__, func.src) + finder.visit(tree) + func.hash = finder.ret + self.ret = (self.ret + func.hash).encode("utf-8") + self.ret = hashlib.md5(self.ret).hexdigest() + +# ----------------------------------------------------------------------------- +# JITFunction +# ----------------------------------------------------------------------------- + +@functools.lru_cache() +def version_key(): + import pkgutil + contents = [] + # frontend + with open(__file__, "rb") as f: + contents += [hashlib.md5(f.read()).hexdigest()] + # backend + with open(triton._C.libtriton.__file__, "rb") as f: + contents += [hashlib.md5(f.read()).hexdigest()] + # language + language_path = os.path.join(*triton.__path__, 'language') + for lib in pkgutil.iter_modules([language_path]): + with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + contents += [hashlib.md5(f.read()).hexdigest()] + # ptxas version + try: + ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest() + except Exception: + ptxas_version = '' + return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents) + + +class JITFunction: + + cache_hook = None + + def __init__(self, fn, version=None, inline=True, do_not_specialize=None): + # information of wrapped function + self.fn = fn + self.module = fn.__module__ + signature = inspect.signature(fn) + self.arg_names = [v.name for v in signature.parameters.values()] + self.arg_defaults = [v.default for v in signature.parameters.values()] + + self.version = version + self.inline = inline + self.src = textwrap.dedent(inspect.getsource(fn)) + self.src = self.src[self.src.find("def"):] + self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize + self.do_not_specialize = [self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize] + # cache for callable driver objects (e.g. 
CUkernel) + self.bin_cache = dict() + self.hash = None + # JITFunction can be instantiated as kernel + # when called with a grid using __getitem__ + self.kernel_decorators = [] + self.kernel = None + # annotations + self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()} + self.__annotations__ = fn.__annotations__ + # constexprs + self.constexprs = [self.arg_names.index(ann) for ann in self.__annotations__.keys()] + # forward docs + self.__doc__ = fn.__doc__ + self.__name__ = fn.__name__ + self.__globals__ = fn.__globals__ + self.__module__ = fn.__module__ + + @property + @functools.lru_cache() + def cache_key(self): + # TODO : hash should be attribute of `self` + if self.hash is None: + dependencies_finder = DependenciesFinder(globals=self.__globals__, src=self.src) + dependencies_finder.visit(self.parse()) + self.hash = dependencies_finder.ret + version_key() + return self.hash + + # we do not parse `src` in the constructor because + # the user might want to monkey-patch self.src dynamically. + # Some unit tests do this, for example. + def parse(self): + tree = ast.parse(self.src) + assert isinstance(tree, ast.Module) + assert len(tree.body) == 1 + assert isinstance(tree.body[0], ast.FunctionDef) + return tree + + def __call__(self, *args, **kwargs): + raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel") + + # - when `.src` attribute is set, cache path needs + # to be reinitialized + # - when kernel decorators change, cached kernel + # needs to be cleared + def __setattr__(self, name, value): + if name == 'kernel_decorators': + self.kernel = None + super(JITFunction, self).__setattr__(name, value) + if name == 'src': + self.hash = None + JITFunction.cache_key.fget.cache_clear() + + def _init_kernel(self): + if self.kernel is None: + self.kernel = Kernel(self) + for decorator in reversed(self.kernel_decorators): + self.kernel = decorator(self.kernel) + return self.kernel + + def __getitem__(self, grid): + """ + A JIT function is launched with: fn[grid](*args, **kwargs). + Hence JITFunction.__getitem__ returns a callable proxy that + memorizes the grid. + """ + class Launcher: + def __init__(self, kernel, grid): + self.kernel = kernel + self.grid = grid + + def __call__(self, *wargs, **kwargs): + return self.kernel(*wargs, **kwargs, grid=self.grid) + + return Launcher(self._init_kernel(), grid) + + def __repr__(self): + return f"JITFunction({self.module}:{self.fn.__name__})" + +# ----------------------------------------------------------------------------- +# `jit` decorator +# ----------------------------------------------------------------------------- + +def jit(*args, **kwargs): + """ + Decorator for JIT-compiling a function using the Triton compiler. + + :note: When a jit'd function is called, :code:`torch.tensor` arguments are implicitly converted to pointers using the :code:`.data_ptr()` method. + + :note: This function will be compiled and run on the GPU. 
It will only have access to: + + * python primitives, + * objects within the triton.language package, + * arguments to this function, + * other jit'd functions + + :param fn: the function to be jit-compiled + :type fn: Callable + """ + if args: + assert len(args) == 1 + assert callable(args[0]) + return JITFunction(args[0], **kwargs) + else: + def decorator(fn): + return JITFunction(fn, **kwargs) + return decorator \ No newline at end of file diff --git a/python/triton/testing.py b/python/triton/testing.py index fbca719ff0d1..f42f38b9f74e 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -5,7 +5,7 @@ import torch import triton._C.libtriton.triton as _triton -from .code_gen import OutOfResources +from .compiler import OutOfResources try: import triton._C.libtriton.cutlass as _cutlass diff --git a/python/triton/utils.py b/python/triton/utils.py new file mode 100644 index 000000000000..b9db92dfbfe5 --- /dev/null +++ b/python/triton/utils.py @@ -0,0 +1,46 @@ +from __future__ import annotations +import torch + + +def cdiv(x, y): + return (x + y - 1) // y + + +def next_power_of_2(n): + """Return the smallest power of 2 greater than or equal to n""" + n -= 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n += 1 + return n + +class TensorWrapper: + def __init__(self, base, dtype): + self.dtype = dtype + self.base = base + self.is_cuda = base.is_cuda + self.device = base.device + + def data_ptr(self): + return self.base.data_ptr() + + def __str__(self) -> str: + return f'TensorWrapper[{self.dtype}]({self.base})' + + +def reinterpret(tensor, dtype): + if isinstance(tensor, TensorWrapper): + if dtype == tensor.base.dtype: + # Reinterpreting to the original interpretation; return the base. + return tensor.base + else: + # Reinterpreting a wrapped tensor to a different type. + return TensorWrapper(tensor.base, dtype) + elif isinstance(tensor, torch.Tensor): + # A new wrapper is needed around an unwrapped tensor. 
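+        # Note: the wrapper shares storage with `tensor`; only the dtype reported to Triton changes.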
+ return TensorWrapper(tensor, dtype) + else: + raise TypeError(f'Cannot reinterpret a {type(tensor)}.') diff --git a/rewrite-test/inline.mlir b/rewrite-test/inline.mlir deleted file mode 100644 index e1bd07e8fd27..000000000000 --- a/rewrite-test/inline.mlir +++ /dev/null @@ -1,261 +0,0 @@ -============================= test session starts ============================== -platform linux -- Python 3.7.7, pytest-7.1.1, pluggy-1.0.0 -rootdir: /home/da/codes/triton-mlir-rewrite/triton/rewrite-test -collected 6 items - -scf_tests.py .....module { - func @matmul_kernel__Pfp16_Pfp16_Pfp16_i32_i32_i32_i32_i32_i32__7c1_9c1_11c1_12c64_13c64_14c32_15c8(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) { - %0 = tt.get_program_id {axis = 0 : i32} : i32 - %1 = call @"cdiv__i32__1cconstexpr[64]"(%arg3) : (i32) -> i32 - %2 = call @"cdiv__i32__1cconstexpr[64]"(%arg4) : (i32) -> i32 - %c8_i32 = arith.constant 8 : i32 - %3 = arith.muli %2, %c8_i32 : i32 - %4 = arith.divsi %0, %3 : i32 - %c8_i32_0 = arith.constant 8 : i32 - %5 = arith.muli %4, %c8_i32_0 : i32 - %6 = arith.subi %1, %5 : i32 - %7 = call @"minimum__i32__1cconstexpr[8]"(%6) : (i32) -> i32 - %8 = arith.remsi %0, %7 : i32 - %9 = arith.addi %5, %8 : i32 - %10 = arith.remsi %0, %3 : i32 - %11 = arith.divsi %10, %7 : i32 - %c64_i32 = arith.constant 64 : i32 - %12 = arith.muli %9, %c64_i32 : i32 - %13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> - %14 = tt.broadcast %12 : (i32) -> tensor<64xi32> - %15 = arith.addi %14, %13 : tensor<64xi32> - %c64_i32_1 = arith.constant 64 : i32 - %16 = arith.muli %11, %c64_i32_1 : i32 - %17 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> - %18 = tt.broadcast %16 : (i32) -> tensor<64xi32> - %19 = arith.addi %18, %17 : tensor<64xi32> - %20 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %21 = tt.reshape %15 : (tensor<64xi32>) -> tensor<64x1xi32> - %22 = tt.broadcast %arg6 : (i32) -> tensor<64x1xi32> - %23 = arith.muli %21, %22 : tensor<64x1xi32> - %24 = tt.reshape %20 : (tensor<32xi32>) -> tensor<1x32xi32> - %c1_i32 = arith.constant 1 : i32 - %25 = tt.broadcast %c1_i32 : (i32) -> tensor<1x32xi32> - %26 = arith.muli %24, %25 : tensor<1x32xi32> - %27 = tt.broadcast %23 : (tensor<64x1xi32>) -> tensor<64x32xi32> - %28 = tt.broadcast %26 : (tensor<1x32xi32>) -> tensor<64x32xi32> - %29 = arith.addi %27, %28 : tensor<64x32xi32> - %30 = tt.broadcast %arg0 : (!tt.ptr) -> tensor<64x32x!tt.ptr> - %31 = tt.getelementptr %30, %29, : tensor<64x32x!tt.ptr> - %32 = tt.reshape %20 : (tensor<32xi32>) -> tensor<32x1xi32> - %33 = tt.broadcast %arg7 : (i32) -> tensor<32x1xi32> - %34 = arith.muli %32, %33 : tensor<32x1xi32> - %35 = tt.reshape %19 : (tensor<64xi32>) -> tensor<1x64xi32> - %c1_i32_2 = arith.constant 1 : i32 - %36 = tt.broadcast %c1_i32_2 : (i32) -> tensor<1x64xi32> - %37 = arith.muli %35, %36 : tensor<1x64xi32> - %38 = tt.broadcast %34 : (tensor<32x1xi32>) -> tensor<32x64xi32> - %39 = tt.broadcast %37 : (tensor<1x64xi32>) -> tensor<32x64xi32> - %40 = arith.addi %38, %39 : tensor<32x64xi32> - %41 = tt.broadcast %arg1 : (!tt.ptr) -> tensor<32x64x!tt.ptr> - %42 = tt.getelementptr %41, %40, : tensor<32x64x!tt.ptr> - %cst = arith.constant 0.000000e+00 : f32 - %43 = tt.broadcast %cst : (f32) -> tensor<64x64xf32> - %c0_i32 = arith.constant 0 : i32 - %c32_i32 = arith.constant 32 : i32 - %44 = arith.index_cast %c0_i32 : i32 to index - %45 = arith.index_cast %arg5 : i32 to index - %46 = arith.index_cast 
%c32_i32 : i32 to index - %47:3 = scf.for %arg9 = %44 to %45 step %46 iter_args(%arg10 = %43, %arg11 = %31, %arg12 = %42) -> (tensor<64x64xf32>, tensor<64x32x!tt.ptr>, tensor<32x64x!tt.ptr>) { - %cst_6 = arith.constant dense : tensor<64x32xi1> - %cst_7 = arith.constant dense<0.000000e+00> : tensor<64x32xf16> - %77 = tt.load %arg11, %cst_6, %cst_7 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16> - %cst_8 = arith.constant dense : tensor<32x64xi1> - %cst_9 = arith.constant dense<0.000000e+00> : tensor<32x64xf16> - %78 = tt.load %arg12, %cst_8, %cst_9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16> - %cst_10 = arith.constant 0.000000e+00 : f32 - %79 = tt.broadcast %cst_10 : (f32) -> tensor<64x64xf32> - %80 = tt.dot %77, %78, %79 {allowTF32 = true} : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32> - %81 = arith.addf %arg10, %80 : tensor<64x64xf32> - %c32_i32_11 = arith.constant 32 : i32 - %82 = tt.broadcast %c32_i32_11 : (i32) -> tensor<64x32xi32> - %83 = tt.getelementptr %arg11, %82, : tensor<64x32x!tt.ptr> - %c32_i32_12 = arith.constant 32 : i32 - %84 = arith.muli %arg7, %c32_i32_12 : i32 - %85 = tt.broadcast %84 : (i32) -> tensor<32x64xi32> - %86 = tt.getelementptr %arg12, %85, : tensor<32x64x!tt.ptr> - scf.yield %81, %83, %86 : tensor<64x64xf32>, tensor<64x32x!tt.ptr>, tensor<32x64x!tt.ptr> - } - %48 = arith.truncf %47#0 : tensor<64x64xf32> to tensor<64x64xf16> - %c64_i32_3 = arith.constant 64 : i32 - %49 = arith.muli %9, %c64_i32_3 : i32 - %50 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> - %51 = tt.broadcast %49 : (i32) -> tensor<64xi32> - %52 = arith.addi %51, %50 : tensor<64xi32> - %c64_i32_4 = arith.constant 64 : i32 - %53 = arith.muli %11, %c64_i32_4 : i32 - %54 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> - %55 = tt.broadcast %53 : (i32) -> tensor<64xi32> - %56 = arith.addi %55, %54 : tensor<64xi32> - %57 = tt.reshape %52 : (tensor<64xi32>) -> tensor<64x1xi32> - %58 = tt.broadcast %arg8 : (i32) -> tensor<64x1xi32> - %59 = arith.muli %58, %57 : tensor<64x1xi32> - %60 = tt.broadcast %arg2 : (!tt.ptr) -> tensor<64x1x!tt.ptr> - %61 = tt.getelementptr %60, %59, : tensor<64x1x!tt.ptr> - %62 = tt.reshape %56 : (tensor<64xi32>) -> tensor<1x64xi32> - %c1_i32_5 = arith.constant 1 : i32 - %63 = tt.broadcast %c1_i32_5 : (i32) -> tensor<1x64xi32> - %64 = arith.muli %62, %63 : tensor<1x64xi32> - %65 = tt.broadcast %61 : (tensor<64x1x!tt.ptr>) -> tensor<64x64x!tt.ptr> - %66 = tt.broadcast %64 : (tensor<1x64xi32>) -> tensor<64x64xi32> - %67 = tt.getelementptr %65, %66, : tensor<64x64x!tt.ptr> - %68 = tt.reshape %52 : (tensor<64xi32>) -> tensor<64x1xi32> - %69 = tt.broadcast %arg3 : (i32) -> tensor<64x1xi32> - %70 = arith.cmpi slt, %68, %69 : tensor<64x1xi32> - %71 = tt.reshape %56 : (tensor<64xi32>) -> tensor<1x64xi32> - %72 = tt.broadcast %arg4 : (i32) -> tensor<1x64xi32> - %73 = arith.cmpi slt, %71, %72 : tensor<1x64xi32> - %74 = tt.broadcast %70 : (tensor<64x1xi1>) -> tensor<64x64xi1> - %75 = tt.broadcast %73 : (tensor<1x64xi1>) -> tensor<64x64xi1> - %76 = arith.andi %74, %75 : tensor<64x64xi1> - tt.store %67, %48, %76, : tensor<64x64xf16> - return - } - func @"cdiv__i32__1cconstexpr[64]"(%arg0: i32) -> i32 { - %c64_i32 = arith.constant 64 : i32 - %0 = arith.addi %arg0, %c64_i32 : i32 - %c1_i32 = arith.constant 1 : i32 - %1 = arith.subi %0, %c1_i32 : i32 - %c64_i32_0 = arith.constant 64 : i32 - %2 = arith.divsi %1, %c64_i32_0 : i32 - return %2 : i32 - } - func 
@"minimum__i32__1cconstexpr[8]"(%arg0: i32) -> i32 { - %c8_i32 = arith.constant 8 : i32 - %0 = arith.cmpi slt, %arg0, %c8_i32 : i32 - %c8_i32_0 = arith.constant 8 : i32 - %1 = select %0, %arg0, %c8_i32_0 : i32 - return %1 : i32 - } -} - -module { - func @matmul_kernel__Pfp16_Pfp16_Pfp16_i32_i32_i32_i32_i32_i32__7c1_9c1_11c1_12c64_13c64_14c32_15c8(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) { - %c8_i32 = arith.constant 8 : i32 - %c63_i32 = arith.constant 63 : i32 - %c64_i32 = arith.constant 64 : i32 - %cst = arith.constant 0.000000e+00 : f32 - %c0 = arith.constant 0 : index - %c32 = arith.constant 32 : index - %cst_0 = arith.constant dense : tensor<64x32xi1> - %cst_1 = arith.constant dense<0.000000e+00> : tensor<64x32xf16> - %cst_2 = arith.constant dense : tensor<32x64xi1> - %cst_3 = arith.constant dense<0.000000e+00> : tensor<32x64xf16> - %c32_i32 = arith.constant 32 : i32 - %c1_i32 = arith.constant 1 : i32 - %0 = tt.get_program_id {axis = 0 : i32} : i32 - %1 = arith.addi %arg3, %c63_i32 : i32 - %2 = arith.divsi %1, %c64_i32 : i32 - %3 = arith.addi %arg4, %c63_i32 : i32 - %4 = arith.divsi %3, %c64_i32 : i32 - %5 = arith.muli %4, %c8_i32 : i32 - %6 = arith.divsi %0, %5 : i32 - %7 = arith.muli %6, %c8_i32 : i32 - %8 = arith.subi %2, %7 : i32 - %9 = arith.cmpi slt, %8, %c8_i32 : i32 - %10 = select %9, %8, %c8_i32 : i32 - %11 = arith.remsi %0, %10 : i32 - %12 = arith.addi %7, %11 : i32 - %13 = arith.remsi %0, %5 : i32 - %14 = arith.divsi %13, %10 : i32 - %15 = arith.muli %12, %c64_i32 : i32 - %16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> - %17 = tt.broadcast %15 : (i32) -> tensor<64xi32> - %18 = arith.addi %17, %16 : tensor<64xi32> - %19 = arith.muli %14, %c64_i32 : i32 - %20 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> - %21 = tt.broadcast %19 : (i32) -> tensor<64xi32> - %22 = arith.addi %21, %20 : tensor<64xi32> - %23 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %24 = tt.reshape %18 : (tensor<64xi32>) -> tensor<64x1xi32> - %25 = tt.broadcast %arg6 : (i32) -> tensor<64x1xi32> - %26 = arith.muli %24, %25 : tensor<64x1xi32> - %27 = tt.reshape %23 : (tensor<32xi32>) -> tensor<1x32xi32> - %28 = tt.broadcast %c1_i32 : (i32) -> tensor<1x32xi32> - %29 = arith.muli %27, %28 : tensor<1x32xi32> - %30 = tt.broadcast %26 : (tensor<64x1xi32>) -> tensor<64x32xi32> - %31 = tt.broadcast %29 : (tensor<1x32xi32>) -> tensor<64x32xi32> - %32 = arith.addi %30, %31 : tensor<64x32xi32> - %33 = tt.broadcast %arg0 : (!tt.ptr) -> tensor<64x32x!tt.ptr> - %34 = tt.getelementptr %33, %32, : tensor<64x32x!tt.ptr> - %35 = tt.reshape %23 : (tensor<32xi32>) -> tensor<32x1xi32> - %36 = tt.broadcast %arg7 : (i32) -> tensor<32x1xi32> - %37 = arith.muli %35, %36 : tensor<32x1xi32> - %38 = tt.reshape %22 : (tensor<64xi32>) -> tensor<1x64xi32> - %39 = tt.broadcast %c1_i32 : (i32) -> tensor<1x64xi32> - %40 = arith.muli %38, %39 : tensor<1x64xi32> - %41 = tt.broadcast %37 : (tensor<32x1xi32>) -> tensor<32x64xi32> - %42 = tt.broadcast %40 : (tensor<1x64xi32>) -> tensor<32x64xi32> - %43 = arith.addi %41, %42 : tensor<32x64xi32> - %44 = tt.broadcast %arg1 : (!tt.ptr) -> tensor<32x64x!tt.ptr> - %45 = tt.getelementptr %44, %43, : tensor<32x64x!tt.ptr> - %46 = tt.broadcast %cst : (f32) -> tensor<64x64xf32> - %47 = arith.index_cast %arg5 : i32 to index - %48:3 = scf.for %arg9 = %c0 to %47 step %c32 iter_args(%arg10 = %46, %arg11 = %34, %arg12 = %45) -> (tensor<64x64xf32>, tensor<64x32x!tt.ptr>, 
tensor<32x64x!tt.ptr>) { - %78 = tt.load %arg11, %cst_0, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16> - %79 = tt.load %arg12, %cst_2, %cst_3 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16> - %80 = tt.broadcast %cst : (f32) -> tensor<64x64xf32> - %81 = tt.dot %78, %79, %80 {allowTF32 = true} : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32> - %82 = arith.addf %arg10, %81 : tensor<64x64xf32> - %83 = tt.broadcast %c32_i32 : (i32) -> tensor<64x32xi32> - %84 = tt.getelementptr %arg11, %83, : tensor<64x32x!tt.ptr> - %85 = arith.muli %arg7, %c32_i32 : i32 - %86 = tt.broadcast %85 : (i32) -> tensor<32x64xi32> - %87 = tt.getelementptr %arg12, %86, : tensor<32x64x!tt.ptr> - scf.yield %82, %84, %87 : tensor<64x64xf32>, tensor<64x32x!tt.ptr>, tensor<32x64x!tt.ptr> - } - %49 = arith.truncf %48#0 : tensor<64x64xf32> to tensor<64x64xf16> - %50 = arith.muli %12, %c64_i32 : i32 - %51 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> - %52 = tt.broadcast %50 : (i32) -> tensor<64xi32> - %53 = arith.addi %52, %51 : tensor<64xi32> - %54 = arith.muli %14, %c64_i32 : i32 - %55 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> - %56 = tt.broadcast %54 : (i32) -> tensor<64xi32> - %57 = arith.addi %56, %55 : tensor<64xi32> - %58 = tt.reshape %53 : (tensor<64xi32>) -> tensor<64x1xi32> - %59 = tt.broadcast %arg8 : (i32) -> tensor<64x1xi32> - %60 = arith.muli %59, %58 : tensor<64x1xi32> - %61 = tt.broadcast %arg2 : (!tt.ptr) -> tensor<64x1x!tt.ptr> - %62 = tt.getelementptr %61, %60, : tensor<64x1x!tt.ptr> - %63 = tt.reshape %57 : (tensor<64xi32>) -> tensor<1x64xi32> - %64 = tt.broadcast %c1_i32 : (i32) -> tensor<1x64xi32> - %65 = arith.muli %63, %64 : tensor<1x64xi32> - %66 = tt.broadcast %62 : (tensor<64x1x!tt.ptr>) -> tensor<64x64x!tt.ptr> - %67 = tt.broadcast %65 : (tensor<1x64xi32>) -> tensor<64x64xi32> - %68 = tt.getelementptr %66, %67, : tensor<64x64x!tt.ptr> - %69 = tt.reshape %53 : (tensor<64xi32>) -> tensor<64x1xi32> - %70 = tt.broadcast %arg3 : (i32) -> tensor<64x1xi32> - %71 = arith.cmpi slt, %69, %70 : tensor<64x1xi32> - %72 = tt.reshape %57 : (tensor<64xi32>) -> tensor<1x64xi32> - %73 = tt.broadcast %arg4 : (i32) -> tensor<1x64xi32> - %74 = arith.cmpi slt, %72, %73 : tensor<1x64xi32> - %75 = tt.broadcast %71 : (tensor<64x1xi1>) -> tensor<64x64xi1> - %76 = tt.broadcast %74 : (tensor<1x64xi1>) -> tensor<64x64xi1> - %77 = arith.andi %75, %76 : tensor<64x64xi1> - tt.store %68, %49, %77, : tensor<64x64xf16> - return - } - func @"cdiv__i32__1cconstexpr[64]"(%arg0: i32) -> i32 { - %c64_i32 = arith.constant 64 : i32 - %c63_i32 = arith.constant 63 : i32 - %0 = arith.addi %arg0, %c63_i32 : i32 - %1 = arith.divsi %0, %c64_i32 : i32 - return %1 : i32 - } - func @"minimum__i32__1cconstexpr[8]"(%arg0: i32) -> i32 { - %c8_i32 = arith.constant 8 : i32 - %0 = arith.cmpi slt, %arg0, %c8_i32 : i32 - %1 = select %0, %arg0, %c8_i32 : i32 - return %1 : i32 - } -} - -. 
- -============================== 6 passed in 1.21s =============================== diff --git a/rewrite-test/jit/if-else/if-else.py b/rewrite-test/jit/if-else/if-else.py deleted file mode 100644 index bcf549fedfe3..000000000000 --- a/rewrite-test/jit/if-else/if-else.py +++ /dev/null @@ -1,51 +0,0 @@ -import triton - -@triton.jit -def if_else(lb, ub, value): - if value > lb: - a = 0.0 - else: - a = 1.0 - c = a + a - -@triton.jit -def only_if(lb, ub, value): - a = -1.0 - if value > lb: - a = 0.0 - c = a + a - -@triton.jit -def only_if_invalid(lb, ub, value): - if value > lb: - a = 0.0 - c = a + a - -@triton.jit -def nested_if(lb, ub, value): - if value > lb: - if value < ub: - a = 2.0 - else: - a = 1.0 - else: - a = 0.0 - c = a + a - - -mod_if_else, ctx_if_else = if_else.compile_to_ttir(2, 4, 3, grid=(1,)) -mod_if_else.dump() - -mod_only_if, ctx_only_if = only_if.compile_to_ttir(2, 4, 3, grid=(1,)) -mod_only_if.dump() - -try: - mod_only_if_invalid, ctx_only_if = only_if_invalid.compile_to_ttir(2, 4, 3, grid=(1,)) - mod_only_if_invalid.dump() -except: - print('value error') - -mod_nested_if, ctx_nested_if = nested_if.compile_to_ttir(2, 4, 3, grid=(1,)) -mod_nested_if.dump() - -print(mod_nested_if.str()) diff --git a/rewrite-test/jit/matmul/matmul.mlir b/rewrite-test/jit/matmul/matmul.mlir deleted file mode 100644 index a70dfd334c6c..000000000000 --- a/rewrite-test/jit/matmul/matmul.mlir +++ /dev/null @@ -1,261 +0,0 @@ -module { - func @matmul_kernel__Pfp16_Pfp16_Pfp16_i32_i32_i32_i32_i32_i32__7c1_9c1_11c1_12c128_13c128_14c128_15c8(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) { - %0 = tt.get_program_id {axis = 0 : i32} : i32 - %1 = call @"cdiv__i32__1cconstexpr[128]"(%arg3) : (i32) -> i32 - %2 = call @"cdiv__i32__1cconstexpr[128]"(%arg4) : (i32) -> i32 - %c8_i32 = arith.constant 8 : i32 - %3 = arith.muli %2, %c8_i32 : i32 - %4 = arith.divsi %0, %3 : i32 - %c8_i32_0 = arith.constant 8 : i32 - %5 = arith.muli %4, %c8_i32_0 : i32 - %6 = arith.subi %1, %5 : i32 - %7 = call @"minimum__i32__1cconstexpr[8]"(%6) : (i32) -> i32 - %8 = arith.remsi %0, %7 : i32 - %9 = arith.addi %5, %8 : i32 - %10 = arith.remsi %0, %3 : i32 - %11 = arith.divsi %10, %7 : i32 - %c128_i32 = arith.constant 128 : i32 - %12 = arith.muli %9, %c128_i32 : i32 - %13 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - %14 = tt.broadcast %12 : (i32) -> tensor<128xi32> - %15 = arith.addi %14, %13 : tensor<128xi32> - %c128_i32_1 = arith.constant 128 : i32 - %16 = arith.muli %11, %c128_i32_1 : i32 - %17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - %18 = tt.broadcast %16 : (i32) -> tensor<128xi32> - %19 = arith.addi %18, %17 : tensor<128xi32> - %20 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - %21 = tt.reshape %15 : (tensor<128xi32>) -> tensor<128x1xi32> - %22 = tt.broadcast %arg6 : (i32) -> tensor<128x1xi32> - %23 = arith.muli %21, %22 : tensor<128x1xi32> - %24 = tt.reshape %20 : (tensor<128xi32>) -> tensor<1x128xi32> - %c1_i32 = arith.constant 1 : i32 - %25 = tt.broadcast %c1_i32 : (i32) -> tensor<1x128xi32> - %26 = arith.muli %24, %25 : tensor<1x128xi32> - %27 = tt.broadcast %23 : (tensor<128x1xi32>) -> tensor<128x128xi32> - %28 = tt.broadcast %26 : (tensor<1x128xi32>) -> tensor<128x128xi32> - %29 = arith.addi %27, %28 : tensor<128x128xi32> - %30 = tt.broadcast %arg0 : (!tt.ptr) -> tensor<128x128x!tt.ptr> - %31 = tt.getelementptr %30, %29, : tensor<128x128x!tt.ptr> - %32 = tt.reshape %20 : 
(tensor<128xi32>) -> tensor<128x1xi32> - %33 = tt.broadcast %arg7 : (i32) -> tensor<128x1xi32> - %34 = arith.muli %32, %33 : tensor<128x1xi32> - %35 = tt.reshape %19 : (tensor<128xi32>) -> tensor<1x128xi32> - %c1_i32_2 = arith.constant 1 : i32 - %36 = tt.broadcast %c1_i32_2 : (i32) -> tensor<1x128xi32> - %37 = arith.muli %35, %36 : tensor<1x128xi32> - %38 = tt.broadcast %34 : (tensor<128x1xi32>) -> tensor<128x128xi32> - %39 = tt.broadcast %37 : (tensor<1x128xi32>) -> tensor<128x128xi32> - %40 = arith.addi %38, %39 : tensor<128x128xi32> - %41 = tt.broadcast %arg1 : (!tt.ptr) -> tensor<128x128x!tt.ptr> - %42 = tt.getelementptr %41, %40, : tensor<128x128x!tt.ptr> - %cst = arith.constant 0.000000e+00 : f32 - %43 = tt.broadcast %cst : (f32) -> tensor<128x128xf32> - %c0_i32 = arith.constant 0 : i32 - %c128_i32_3 = arith.constant 128 : i32 - %44 = arith.index_cast %c0_i32 : i32 to index - %45 = arith.index_cast %arg5 : i32 to index - %46 = arith.index_cast %c128_i32_3 : i32 to index - %47:3 = scf.for %arg9 = %44 to %45 step %46 iter_args(%arg10 = %43, %arg11 = %31, %arg12 = %42) -> (tensor<128x128xf32>, tensor<128x128x!tt.ptr>, tensor<128x128x!tt.ptr>) { - %cst_7 = arith.constant dense : tensor<128x128xi1> - %cst_8 = arith.constant dense<0.000000e+00> : tensor<128x128xf16> - %77 = tt.load %arg11, %cst_7, %cst_8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16> - %cst_9 = arith.constant dense : tensor<128x128xi1> - %cst_10 = arith.constant dense<0.000000e+00> : tensor<128x128xf16> - %78 = tt.load %arg12, %cst_9, %cst_10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16> - %cst_11 = arith.constant 0.000000e+00 : f32 - %79 = tt.broadcast %cst_11 : (f32) -> tensor<128x128xf32> - %80 = tt.dot %77, %78, %79 {allowTF32 = true} : tensor<128x128xf16> * tensor<128x128xf16> -> tensor<128x128xf32> - %81 = arith.addf %arg10, %80 : tensor<128x128xf32> - %c128_i32_12 = arith.constant 128 : i32 - %82 = tt.broadcast %c128_i32_12 : (i32) -> tensor<128x128xi32> - %83 = tt.getelementptr %arg11, %82, : tensor<128x128x!tt.ptr> - %c128_i32_13 = arith.constant 128 : i32 - %84 = arith.muli %arg7, %c128_i32_13 : i32 - %85 = tt.broadcast %84 : (i32) -> tensor<128x128xi32> - %86 = tt.getelementptr %arg12, %85, : tensor<128x128x!tt.ptr> - scf.yield %81, %83, %86 : tensor<128x128xf32>, tensor<128x128x!tt.ptr>, tensor<128x128x!tt.ptr> - } - %48 = arith.truncf %47#0 : tensor<128x128xf32> to tensor<128x128xf16> - %c128_i32_4 = arith.constant 128 : i32 - %49 = arith.muli %9, %c128_i32_4 : i32 - %50 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - %51 = tt.broadcast %49 : (i32) -> tensor<128xi32> - %52 = arith.addi %51, %50 : tensor<128xi32> - %c128_i32_5 = arith.constant 128 : i32 - %53 = arith.muli %11, %c128_i32_5 : i32 - %54 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - %55 = tt.broadcast %53 : (i32) -> tensor<128xi32> - %56 = arith.addi %55, %54 : tensor<128xi32> - %57 = tt.reshape %52 : (tensor<128xi32>) -> tensor<128x1xi32> - %58 = tt.broadcast %arg8 : (i32) -> tensor<128x1xi32> - %59 = arith.muli %58, %57 : tensor<128x1xi32> - %60 = tt.broadcast %arg2 : (!tt.ptr) -> tensor<128x1x!tt.ptr> - %61 = tt.getelementptr %60, %59, : tensor<128x1x!tt.ptr> - %62 = tt.reshape %56 : (tensor<128xi32>) -> tensor<1x128xi32> - %c1_i32_6 = arith.constant 1 : i32 - %63 = tt.broadcast %c1_i32_6 : (i32) -> tensor<1x128xi32> - %64 = arith.muli %62, %63 : tensor<1x128xi32> - %65 = tt.broadcast %61 : (tensor<128x1x!tt.ptr>) -> 
tensor<128x128x!tt.ptr> - %66 = tt.broadcast %64 : (tensor<1x128xi32>) -> tensor<128x128xi32> - %67 = tt.getelementptr %65, %66, : tensor<128x128x!tt.ptr> - %68 = tt.reshape %52 : (tensor<128xi32>) -> tensor<128x1xi32> - %69 = tt.broadcast %arg3 : (i32) -> tensor<128x1xi32> - %70 = arith.cmpi slt, %68, %69 : tensor<128x1xi32> - %71 = tt.reshape %56 : (tensor<128xi32>) -> tensor<1x128xi32> - %72 = tt.broadcast %arg4 : (i32) -> tensor<1x128xi32> - %73 = arith.cmpi slt, %71, %72 : tensor<1x128xi32> - %74 = tt.broadcast %70 : (tensor<128x1xi1>) -> tensor<128x128xi1> - %75 = tt.broadcast %73 : (tensor<1x128xi1>) -> tensor<128x128xi1> - %76 = arith.andi %74, %75 : tensor<128x128xi1> - tt.store %67, %48, %76, : tensor<128x128xf16> - return - } - func @"cdiv__i32__1cconstexpr[128]"(%arg0: i32) -> i32 { - %c128_i32 = arith.constant 128 : i32 - %0 = arith.addi %arg0, %c128_i32 : i32 - %c1_i32 = arith.constant 1 : i32 - %1 = arith.subi %0, %c1_i32 : i32 - %c128_i32_0 = arith.constant 128 : i32 - %2 = arith.divsi %1, %c128_i32_0 : i32 - return %2 : i32 - } - func @"minimum__i32__1cconstexpr[8]"(%arg0: i32) -> i32 { - %c8_i32 = arith.constant 8 : i32 - %0 = arith.cmpi slt, %arg0, %c8_i32 : i32 - %c8_i32_0 = arith.constant 8 : i32 - %1 = select %0, %arg0, %c8_i32_0 : i32 - return %1 : i32 - } -} -module { - func @matmul_kernel__Pfp16_Pfp16_Pfp16_i32_i32_i32_i32_i32_i32__7c1_9c1_11c1_12c128_13c128_14c128_15c8(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) { - %c8_i32 = arith.constant 8 : i32 - %c127_i32 = arith.constant 127 : i32 - %c128_i32 = arith.constant 128 : i32 - %cst = arith.constant 0.000000e+00 : f32 - %c0 = arith.constant 0 : index - %c128 = arith.constant 128 : index - %cst_0 = arith.constant dense : tensor<128x128xi1, #triton_gpu<"coalesced encoding">> - %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #triton_gpu<"coalesced encoding">> - %c1_i32 = arith.constant 1 : i32 - %0 = tt.get_program_id {axis = 0 : i32} : i32 - %1 = arith.addi %arg3, %c127_i32 : i32 - %2 = arith.divsi %1, %c128_i32 : i32 - %3 = arith.addi %arg4, %c127_i32 : i32 - %4 = arith.divsi %3, %c128_i32 : i32 - %5 = arith.muli %4, %c8_i32 : i32 - %6 = arith.divsi %0, %5 : i32 - %7 = arith.muli %6, %c8_i32 : i32 - %8 = arith.subi %2, %7 : i32 - %9 = arith.cmpi slt, %8, %c8_i32 : i32 - %10 = select %9, %8, %c8_i32 : i32 - %11 = arith.remsi %0, %10 : i32 - %12 = arith.addi %7, %11 : i32 - %13 = arith.remsi %0, %5 : i32 - %14 = arith.divsi %13, %10 : i32 - %15 = arith.muli %12, %c128_i32 : i32 - %16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu<"coalesced encoding">> - %17 = tt.broadcast %15 : (i32) -> tensor<128xi32, #triton_gpu<"coalesced encoding">> - %18 = arith.addi %17, %16 : tensor<128xi32, #triton_gpu<"coalesced encoding">> - %19 = arith.muli %14, %c128_i32 : i32 - %20 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu<"coalesced encoding">> - %21 = tt.broadcast %19 : (i32) -> tensor<128xi32, #triton_gpu<"coalesced encoding">> - %22 = arith.addi %21, %20 : tensor<128xi32, #triton_gpu<"coalesced encoding">> - %23 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu<"coalesced encoding">> - %24 = tt.reshape %18 : (tensor<128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding">> - %25 = tt.broadcast %arg6 : (i32) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding">> - %26 = arith.muli %24, %25 : 
tensor<128x1xi32, #triton_gpu<"coalesced encoding">> - %27 = tt.reshape %23 : (tensor<128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding">> - %28 = tt.broadcast %c1_i32 : (i32) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding">> - %29 = arith.muli %27, %28 : tensor<1x128xi32, #triton_gpu<"coalesced encoding">> - %30 = tt.broadcast %26 : (tensor<128x1xi32, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">> - %31 = tt.broadcast %29 : (tensor<1x128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">> - %32 = arith.addi %30, %31 : tensor<128x128xi32, #triton_gpu<"coalesced encoding">> - %33 = tt.broadcast %arg0 : (!tt.ptr) -> tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> - %34 = tt.getelementptr %33, %32, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> - %35 = tt.reshape %23 : (tensor<128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding">> - %36 = tt.broadcast %arg7 : (i32) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding">> - %37 = arith.muli %35, %36 : tensor<128x1xi32, #triton_gpu<"coalesced encoding">> - %38 = tt.reshape %22 : (tensor<128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding">> - %39 = tt.broadcast %c1_i32 : (i32) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding">> - %40 = arith.muli %38, %39 : tensor<1x128xi32, #triton_gpu<"coalesced encoding">> - %41 = tt.broadcast %37 : (tensor<128x1xi32, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">> - %42 = tt.broadcast %40 : (tensor<1x128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">> - %43 = arith.addi %41, %42 : tensor<128x128xi32, #triton_gpu<"coalesced encoding">> - %44 = tt.broadcast %arg1 : (!tt.ptr) -> tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> - %45 = tt.getelementptr %44, %43, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> - %46 = tt.broadcast %cst : (f32) -> tensor<128x128xf32, #triton_gpu<"coalesced encoding">> - %47 = arith.index_cast %arg5 : i32 to index - %48 = "triton_gpu.copy_async"(%34, %cst_0, %cst_1) {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : (tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128xi1, #triton_gpu<"coalesced encoding">>, tensor<128x128xf16, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">> - %49 = "triton_gpu.copy_async"(%45, %cst_0, %cst_1) {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : (tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128xi1, #triton_gpu<"coalesced encoding">>, tensor<128x128xf16, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">> - %50 = tt.broadcast %c128_i32 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">> - %51 = tt.getelementptr %34, %50, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> - %52 = arith.muli %arg7, %c128_i32 : i32 - %53 = tt.broadcast %52 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">> - %54 = tt.getelementptr %45, %53, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> - %55:8 = scf.for %arg9 = %c0 to %47 step %c128 iter_args(%arg10 = %46, %arg11 = %34, %arg12 = %45, %arg13 = %48, %arg14 = %49, %arg15 = %51, %arg16 = %54, %arg17 = %c0) -> (tensor<128x128xf32, 
#triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>, tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, index) { - %85 = tt.dot %arg13, %arg14, %arg10 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">> * tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">> -> tensor<128x128xf32, #triton_gpu<"coalesced encoding">> - %86 = tt.broadcast %c128_i32 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">> - %87 = tt.getelementptr %arg11, %86, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> - %88 = arith.muli %arg7, %c128_i32 : i32 - %89 = tt.broadcast %88 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">> - %90 = tt.getelementptr %arg12, %89, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> - %91 = arith.addi %arg17, %c128 : index - %92 = arith.cmpi slt, %91, %47 : index - %93 = tt.broadcast %92 : (i1) -> tensor<128x128xi1, #triton_gpu<"coalesced encoding">> - %94 = "triton_gpu.copy_async"(%arg15, %93, %cst_1) {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : (tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128xi1, #triton_gpu<"coalesced encoding">>, tensor<128x128xf16, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">> - %95 = "triton_gpu.copy_async"(%arg16, %93, %cst_1) {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : (tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128xi1, #triton_gpu<"coalesced encoding">>, tensor<128x128xf16, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">> - %96 = tt.broadcast %c128_i32 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">> - %97 = tt.getelementptr %arg15, %96, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> - %98 = arith.muli %arg7, %c128_i32 : i32 - %99 = tt.broadcast %98 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">> - %100 = tt.getelementptr %arg16, %99, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> - scf.yield %85, %87, %90, %94, %95, %97, %100, %91 : tensor<128x128xf32, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>, tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, index - } - %56 = arith.truncf %55#0 : tensor<128x128xf32, #triton_gpu<"coalesced encoding">> to tensor<128x128xf16, #triton_gpu<"coalesced encoding">> - %57 = arith.muli %12, %c128_i32 : i32 - %58 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu<"coalesced encoding">> - %59 = tt.broadcast %57 : (i32) -> tensor<128xi32, #triton_gpu<"coalesced encoding">> - %60 = arith.addi %59, %58 : tensor<128xi32, #triton_gpu<"coalesced encoding">> - %61 = arith.muli %14, %c128_i32 : i32 - %62 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu<"coalesced encoding">> - %63 = tt.broadcast %61 : (i32) -> tensor<128xi32, #triton_gpu<"coalesced 
encoding">> - %64 = arith.addi %63, %62 : tensor<128xi32, #triton_gpu<"coalesced encoding">> - %65 = tt.reshape %60 : (tensor<128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding">> - %66 = tt.broadcast %arg8 : (i32) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding">> - %67 = arith.muli %66, %65 : tensor<128x1xi32, #triton_gpu<"coalesced encoding">> - %68 = tt.broadcast %arg2 : (!tt.ptr) -> tensor<128x1x!tt.ptr, #triton_gpu<"coalesced encoding">> - %69 = tt.getelementptr %68, %67, : tensor<128x1x!tt.ptr, #triton_gpu<"coalesced encoding">> - %70 = tt.reshape %64 : (tensor<128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding">> - %71 = tt.broadcast %c1_i32 : (i32) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding">> - %72 = arith.muli %70, %71 : tensor<1x128xi32, #triton_gpu<"coalesced encoding">> - %73 = tt.broadcast %69 : (tensor<128x1x!tt.ptr, #triton_gpu<"coalesced encoding">>) -> tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> - %74 = tt.broadcast %72 : (tensor<1x128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">> - %75 = tt.getelementptr %73, %74, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> - %76 = tt.reshape %60 : (tensor<128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding">> - %77 = tt.broadcast %arg3 : (i32) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding">> - %78 = "triton_gpu.cmpi"(%76, %77) {predicate = 2 : i64} : (tensor<128x1xi32, #triton_gpu<"coalesced encoding">>, tensor<128x1xi32, #triton_gpu<"coalesced encoding">>) -> tensor<128x1xi1, #triton_gpu<"coalesced encoding">> - %79 = tt.reshape %64 : (tensor<128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding">> - %80 = tt.broadcast %arg4 : (i32) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding">> - %81 = "triton_gpu.cmpi"(%79, %80) {predicate = 2 : i64} : (tensor<1x128xi32, #triton_gpu<"coalesced encoding">>, tensor<1x128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<1x128xi1, #triton_gpu<"coalesced encoding">> - %82 = tt.broadcast %78 : (tensor<128x1xi1, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xi1, #triton_gpu<"coalesced encoding">> - %83 = tt.broadcast %81 : (tensor<1x128xi1, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xi1, #triton_gpu<"coalesced encoding">> - %84 = arith.andi %82, %83 : tensor<128x128xi1, #triton_gpu<"coalesced encoding">> - tt.store %75, %56, %84, : tensor<128x128xf16, #triton_gpu<"coalesced encoding">> - return - } - func @"cdiv__i32__1cconstexpr[128]"(%arg0: i32) -> i32 { - %c128_i32 = arith.constant 128 : i32 - %c127_i32 = arith.constant 127 : i32 - %0 = arith.addi %arg0, %c127_i32 : i32 - %1 = arith.divsi %0, %c128_i32 : i32 - return %1 : i32 - } - func @"minimum__i32__1cconstexpr[8]"(%arg0: i32) -> i32 { - %c8_i32 = arith.constant 8 : i32 - %0 = arith.cmpi slt, %arg0, %c8_i32 : i32 - %1 = select %0, %arg0, %c8_i32 : i32 - return %1 : i32 - } -} diff --git a/rewrite-test/jit/matmul/matmul.py b/rewrite-test/jit/matmul/matmul.py deleted file mode 100644 index f16cf8b26851..000000000000 --- a/rewrite-test/jit/matmul/matmul.py +++ /dev/null @@ -1,105 +0,0 @@ -import triton -import triton.language as tl -import triton._C.libtriton.triton as _triton - - -import torch - -@triton.jit -def matmul_kernel( - # Pointers to matrices - a_ptr, b_ptr, c_ptr, - # Matrix dimensions - M, N, K, - # The stride variables 
represent how much to increase the ptr by when moving by 1 - # element in a particular dimension. E.g. stride_am is how much to increase a_ptr - # by to get the element one row down (A has M rows) - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - """Kernel for computing the matmul C = A x B. - A has shape (M, K), B has shape (K, N) and C has shape (M, N) - """ - # ----------------------------------------------------------- - # Map program ids `pid` to the block of C it should compute. - # This is done in a grouped ordering to promote L2 data reuse - # See above `L2 Cache Optimizations` section for details - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - # ---------------------------------------------------------- - # Create pointers for the first blocks of A and B. - # We will advance this pointer as we move in the K direction - # and accumulate - # a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers - # b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_n] pointers - # see above `Pointer Arithmetics` section for details - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) - b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - - # ----------------------------------------------------------- - # Iterate to compute a block of the C matrix - # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block - # of fp32 values for higher accuracy. - # `accumulator` will be converted back to fp16 after the loop - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, K, BLOCK_SIZE_K): - # Note that for simplicity, we don't apply a mask here. - # This means that if K is not a multiple of BLOCK_SIZE_K, - # this will access out-of-bounds memory and produce an - # error or (worse!) incorrect results. 
- a = tl.load(a_ptrs) - b = tl.load(b_ptrs) - # We accumulate along the K dimension - accumulator += tl.dot(a, b) - # Advance the ptrs to the next K block - a_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += BLOCK_SIZE_K * stride_bk - c = accumulator.to(tl.float16) - - # ----------------------------------------------------------- - # Write back the block of the output matrix C - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - tl.store(c_ptrs, c, mask=c_mask) - -a = torch.randn((512, 512), device='cuda', dtype=torch.float16) -b = torch.randn((512, 512), device='cuda', dtype=torch.float16) -c = torch.empty((512, 512), device='cuda', dtype=torch.float16) - - -mod, ctx = matmul_kernel.compile_to_ttir( - a, b, c, - 512, 512, 512, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), - 128, 128, 128, - 8, grid=(2,), - num_stages=4 -) - -assert mod.verify() -# mod.dump() - -res = matmul_kernel.compile_ttir_to_llir(mod, ctx) - -assert mod.verify() -# assert res -mod.dump() diff --git a/rewrite-test/jit/multi-return.py b/rewrite-test/jit/multi-return.py deleted file mode 100644 index 00588bf0b721..000000000000 --- a/rewrite-test/jit/multi-return.py +++ /dev/null @@ -1,27 +0,0 @@ -import triton -import triton.language as tl -import triton._C.libtriton.triton as _triton - - -@triton.jit -def foo(a, b): - max, min = maxmin(a, b) - return max, min - -@triton.jit -def maxmin(a, b): - max = tl.maximum(a, b) - min = tl.minimum(a, b) - return max, min - - -mod, ctx = foo.compile_to_ttir(3, 4, grid=(1,)) -assert mod.verify() -mod.dump() - - -pm = _triton.ir.pass_manager(ctx) -pm.add_inliner_pass() -pm.run(mod) -assert mod.verify() -mod.dump() diff --git a/rewrite-test/jit/vecadd.py b/rewrite-test/jit/vecadd.py deleted file mode 100644 index 11a99517f93d..000000000000 --- a/rewrite-test/jit/vecadd.py +++ /dev/null @@ -1,47 +0,0 @@ -import torch -import triton -import triton.language as tl -import triton._C.libtriton.triton as _triton - - - -@triton.jit -def add_kernel( - x_ptr, # *Pointer* to first input vector - y_ptr, # *Pointer* to second input vector - output_ptr, # *Pointer* to output vector - n_elements, # Size of the vector - BLOCK_SIZE: tl.constexpr, # Number of elements each program should process - # NOTE: `constexpr` so it can be used as a shape value -): - # There are multiple 'program's processing different data. We identify which program - # we are here - pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0 - # This program will process inputs that are offset from the initial data. - # for instance, if you had a vector of length 256 and block_size of 64, the programs - # would each access the elements [0:64, 64:128, 128:192, 192:256]. 
- # Note that offsets is a list of pointers - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - # Create a mask to guard memory operations against out-of-bounds accesses - mask = offsets < n_elements - # Load x and y from DRAM, masking out any extra elements in case the input is not a - # multiple of the block size - x = tl.load(x_ptr + offsets, mask=mask, other=0.0) - y = tl.load(y_ptr + offsets, mask=mask, other=0.0) - output = x + y - # Write x + y back to DRAM - tl.store(output_ptr + offsets, output, mask=mask) - -size = 1024 -x = torch.rand(size, device='cuda') -y = torch.rand(size, device='cuda') -z = torch.empty_like(x) -# add_kernel[(1,)](x, y, z, size, 256) -# print(add_kernel[(1,)].kernel.compile_to_ttir()) -# print(add_kernel.annotations) -mod, ctx = add_kernel.compile_to_ttir(x, y, z, size, BLOCK_SIZE=256, grid=(1,)) -assert mod.verify() -mod.dump() -add_kernel.compile_ttir_to_llir(mod, ctx) -mod.dump() diff --git a/rewrite-test/jit/vecadd/vecadd-loop.py b/rewrite-test/jit/vecadd/vecadd-loop.py deleted file mode 100644 index 643af7360dd8..000000000000 --- a/rewrite-test/jit/vecadd/vecadd-loop.py +++ /dev/null @@ -1,56 +0,0 @@ -import torch -import triton -import triton.language as tl - - -@triton.jit -def add_kernel( - x_ptr, # *Pointer* to first input vector - y_ptr, # *Pointer* to second input vector - output_ptr, # *Pointer* to output vector - n_elements, # Size of the vector - K, - stride - # BLOCK_SIZE: tl.constexpr, # Number of elements each program should process - # # NOTE: `constexpr` so it can be used as a shape value -): - # There are multiple 'program's processing different data. We identify which program - # we are here - pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0 - # This program will process inputs that are offset from the initial data. - # for instance, if you had a vector of length 256 and block_size of 64, the programs - # would each access the elements [0:64, 64:128, 128:192, 192:256]. 
- # Note that offsets is a list of pointers - block_start = pid * 256 - offsets = block_start + tl.arange(0, 256) - # Create a mask to guard memory operations against out-of-bounds accesses - mask = offsets < n_elements - - x_ptrs = x_ptr + offsets - y_ptrs = y_ptr + offsets - output = tl.zeros((256,), dtype=tl.float32) - for k in range(0, K, 32): - x = tl.load(x_ptrs, mask=mask, other=0.0) - y = tl.load(y_ptrs, mask=mask, other=0.0) - output += x + y - - x_ptrs += stride - y_ptrs += stride - - # Write x + y back to DRAM - tl.store(output_ptr + offsets, output, mask=mask) - -size = 1024 -x = torch.rand(size, device='cuda') -y = torch.rand(size, device='cuda') -z = torch.empty_like(x) -# add_kernel[(1,)](x, y, z, size, 256) -# print(add_kernel[(1,)].kernel.compile_to_ttir()) -mod, ctx = add_kernel.compile_to_ttir( - x, y, z, size, 128, 8, grid=(1,), num_stages=1) -mod.dump() -# print(mod) - -res = add_kernel.compile_ttir_to_llir(mod, ctx) - -mod.dump() diff --git a/rewrite-test/jit/vecadd/vecadd.mlir b/rewrite-test/jit/vecadd/vecadd.mlir deleted file mode 100644 index 27db9032369a..000000000000 --- a/rewrite-test/jit/vecadd/vecadd.mlir +++ /dev/null @@ -1,82 +0,0 @@ -module { - func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32) { - %0 = tt.get_program_id {axis = 0 : i32} : i32 - %c256_i32 = arith.constant 256 : i32 - %1 = arith.muli %0, %c256_i32 : i32 - %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> - %3 = tt.broadcast %1 : (i32) -> tensor<256xi32> - %4 = arith.addi %3, %2 : tensor<256xi32> - %5 = tt.broadcast %arg3 : (i32) -> tensor<256xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<256xi32> - %7 = tt.broadcast %arg0 : (!tt.ptr) -> tensor<256x!tt.ptr> - %8 = tt.getelementptr %7, %4 : tensor<256x!tt.ptr> - %9 = tt.broadcast %arg1 : (!tt.ptr) -> tensor<256x!tt.ptr> - %10 = tt.getelementptr %9, %4 : tensor<256x!tt.ptr> - %cst = arith.constant 0.000000e+00 : f32 - %11 = tt.broadcast %cst : (f32) -> tensor<256xf32> - %c0_i32 = arith.constant 0 : i32 - %c32_i32 = arith.constant 32 : i32 - %12 = arith.index_cast %c0_i32 : i32 to index - %13 = arith.index_cast %arg4 : i32 to index - %14 = arith.index_cast %c32_i32 : i32 to index - %15:3 = scf.for %arg6 = %12 to %13 step %14 iter_args(%arg7 = %11, %arg8 = %8, %arg9 = %10) -> (tensor<256xf32>, tensor<256x!tt.ptr>, tensor<256x!tt.ptr>) { - %cst_0 = arith.constant 0.000000e+00 : f32 - %18 = tt.broadcast %cst_0 : (f32) -> tensor<256xf32> - %19 = tt.load %arg8, %6, %18 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32> - %cst_1 = arith.constant 0.000000e+00 : f32 - %20 = tt.broadcast %cst_1 : (f32) -> tensor<256xf32> - %21 = tt.load %arg9, %6, %20 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32> - %22 = arith.addf %19, %21 : tensor<256xf32> - %23 = arith.addf %arg7, %22 : tensor<256xf32> - %24 = tt.broadcast %arg5 : (i32) -> tensor<256xi32> - %25 = tt.getelementptr %arg8, %24 : tensor<256x!tt.ptr> - %26 = tt.broadcast %arg5 : (i32) -> tensor<256xi32> - %27 = tt.getelementptr %arg9, %26 : tensor<256x!tt.ptr> - scf.yield %23, %25, %27 : tensor<256xf32>, tensor<256x!tt.ptr>, tensor<256x!tt.ptr> - } - %16 = tt.broadcast %arg2 : (!tt.ptr) -> tensor<256x!tt.ptr> - %17 = tt.getelementptr %16, %4 : tensor<256x!tt.ptr> - tt.store %17, %15#0, %6, : tensor<256xf32> - return - } -} -module { - func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, 
%arg5: i32) { - %c32 = arith.constant 32 : index - %c0 = arith.constant 0 : index - %c256_i32 = arith.constant 256 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> - %0 = tt.get_program_id {axis = 0 : i32} : i32 - %1 = arith.muli %0, %c256_i32 : i32 - %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> - %3 = tt.broadcast %1 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>> - %4 = triton_gpu.convert_layout %3 : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> - %5 = arith.addi %4, %2 : tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> - %6 = tt.broadcast %arg3 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>> - %7 = triton_gpu.convert_layout %6 : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> - %8 = "triton_gpu.cmpi"(%5, %7) {predicate = 2 : i64} : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>) -> tensor<256xi1, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> - %9 = tt.broadcast %arg0 : (!tt.ptr) -> tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>> - %10 = triton_gpu.convert_layout %9 : (tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> - %11 = tt.getelementptr %10, %5 : tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> - %12 = tt.broadcast %arg1 : (!tt.ptr) -> tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>> - %13 = triton_gpu.convert_layout %12 : (tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> 
- %14 = tt.getelementptr %13, %5 : tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> - %15 = arith.index_cast %arg4 : i32 to index - %16:3 = scf.for %arg6 = %c0 to %15 step %c32 iter_args(%arg7 = %cst, %arg8 = %11, %arg9 = %14) -> (tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>) { - %20 = tt.load %arg8, %8, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> - %21 = tt.load %arg9, %8, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> - %22 = arith.addf %20, %21 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> - %23 = arith.addf %arg7, %22 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> - %24 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>> - %25 = triton_gpu.convert_layout %24 : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> - %26 = tt.getelementptr %arg8, %25 : tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> - %27 = tt.getelementptr %arg9, %25 : tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> - scf.yield %23, %26, %27 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> - } - %17 = tt.broadcast %arg2 : (!tt.ptr) -> tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>> - %18 = triton_gpu.convert_layout %17 : (tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> - %19 = tt.getelementptr %18, %5 : 
tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> - tt.store %19, %16#0, %8, : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> - return - } -} diff --git a/rewrite-test/jit/while.py b/rewrite-test/jit/while.py deleted file mode 100644 index 2c114cd4bb47..000000000000 --- a/rewrite-test/jit/while.py +++ /dev/null @@ -1,38 +0,0 @@ -import triton -import triton.language as tl -import torch - -@triton.jit -def atomic(lock): - while tl.atomic_cas(lock, 0, 1) == 1: - pass - -@triton.jit -def generic_while(lb, value): - c = -1 - while c <= 0: - c += 1 - -# locks = torch.zeros(32, dtype=torch.int32, device='cuda') -# mod_atomic, ctx_atomic = atomic.compile_to_ttir(locks, grid=(1,)) -# mod_atomic.dump() - -# mod_generic_while, ctx_generic_while = generic_while.compile_to_ttir(8, 9, grid=(1,)) -# mod_generic_while.dump() - -@triton.jit -def nested_cf(X, lb, ub, Z): - a = 0.0 - if lb < ub: - for z in range(0, Z): - a += 2.0 - else: - while a < 1.2: - a *= 2.0 - for _ in range(0, Z, 2): - a *= -3.3 - a -= 1.0 - -mod, _ = nested_cf.compile_to_ttir(3, 4, 5, 6, grid=(1,)) -assert mod.verify(), mod.str() -mod.dump() diff --git a/rewrite-test/test_ir.py b/rewrite-test/test_ir.py deleted file mode 100644 index 3ecdbdfd26c0..000000000000 --- a/rewrite-test/test_ir.py +++ /dev/null @@ -1,58 +0,0 @@ -import triton._C.libtriton.triton.ir as ir - -ctx = ir.context() -ctx.load_triton() - -# TODO -builder = ir.builder(ctx) - -module = builder.create_module() - - -i1_ty = builder.get_int1_ty() -i8_ty = builder.get_int8_ty() -i16_ty = builder.get_int16_ty() -i32_ty = builder.get_int32_ty() -i64_ty = builder.get_int64_ty() - -f16_ty = builder.get_half_ty() - -f16_ptr_ty = builder.get_ptr_ty(f16_ty, 1) - -func_ty = builder.get_function_ty([f16_ptr_ty, f16_ptr_ty, f16_ptr_ty], []) -func = builder.create_function('foo', func_ty) - -module.push_back(func) -module.set_attr("num_warps", builder.get_int32_attr(4)) - -# ... 
-entry = func.add_entry_block() -builder.set_insertion_point_to_start(entry) -offsets = builder.create_make_range(0, 128) -pid = builder.create_get_program_id(0) -_128 = builder.get_int32(128) -offset = builder.create_add(pid, _128) -offset = builder.create_splat(offset, [128]) -offsets = builder.create_add(offset, offsets) - - -a_ptrs = builder.create_splat(entry.arg(0), [128]) -b_ptrs = builder.create_splat(entry.arg(1), [128]) - -a_ptrs = builder.create_gep(a_ptrs, offsets) -b_ptrs = builder.create_gep(b_ptrs, offsets) - -a = builder.create_load(a_ptrs, ir.CACHE_MODIFIER.NONE, ir.EVICTION_POLICY.NORMAL, False) -b = builder.create_load(b_ptrs, ir.CACHE_MODIFIER.NONE, ir.EVICTION_POLICY.NORMAL, False) - -c = builder.create_fadd(a, b) -c.set_attr("ieee_rounding", builder.get_bool_attr(True)) - -c_ptrs = builder.create_splat(entry.arg(2), [128]) -c_ptrs = builder.create_gep(c_ptrs, offsets) -builder.create_store(c_ptrs, c) - -# func.dump() - - -module.dump() diff --git a/rewrite-test/test_scf.py b/rewrite-test/test_scf.py deleted file mode 100644 index c4bb8584072c..000000000000 --- a/rewrite-test/test_scf.py +++ /dev/null @@ -1,417 +0,0 @@ -import pytest - -import triton -import triton.language as tl -import triton._C.libtriton.triton as _triton - - -import torch - -def test_if(): - ref_ir = """module { - func @only_if__i32_i32_i32__(%arg0: i32, %arg1: i32, %arg2: i32) { - %cst = arith.constant -1.000000e+00 : f32 - %0 = arith.cmpi sgt, %arg2, %arg0 : i32 - %1 = scf.if %0 -> (f32) { - %cst_0 = arith.constant 0.000000e+00 : f32 - scf.yield %cst_0 : f32 - } else { - scf.yield %cst : f32 - } - %2 = arith.addf %1, %1 : f32 - return - } -} -""" - - @triton.jit - def only_if(lb, ub, value): - a = -1.0 - if value > lb: - a = 0.0 - c = a + a - - mod, _ = only_if.compile_to_ttir(2, 3, 4, grid=(1,)) - generated_ir = mod.str() - assert mod.verify() - assert ref_ir == generated_ir - -def test_if_else(): - ref_ir = """module { - func @if_else__i32_i32_i32__(%arg0: i32, %arg1: i32, %arg2: i32) { - %0 = arith.cmpi sgt, %arg2, %arg0 : i32 - %1 = scf.if %0 -> (f32) { - %cst = arith.constant 0.000000e+00 : f32 - scf.yield %cst : f32 - } else { - %cst = arith.constant 1.000000e+00 : f32 - scf.yield %cst : f32 - } - %2 = arith.addf %1, %1 : f32 - return - } -} -""" - @triton.jit - def if_else(lb, ub, value): - if value > lb: - a = 0.0 - else: - a = 1.0 - c = a + a - - mod, _ = if_else.compile_to_ttir(2, 3, 4, grid=(1,)) - generated_ir = mod.str() - assert mod.verify() - assert ref_ir == generated_ir - -def test_for(): - ref_ir = """module { - func @for_loop__i32__(%arg0: i32) { - %cst = arith.constant 1.000000e+00 : f32 - %c0_i32 = arith.constant 0 : i32 - %c1_i32 = arith.constant 1 : i32 - %0 = arith.index_cast %c0_i32 : i32 to index - %1 = arith.index_cast %arg0 : i32 to index - %2 = arith.index_cast %c1_i32 : i32 to index - %3 = scf.for %arg1 = %0 to %1 step %2 iter_args(%arg2 = %cst) -> (f32) { - %cst_0 = arith.constant 1.000000e+00 : f32 - %4 = arith.addf %arg2, %cst_0 : f32 - scf.yield %4 : f32 - } - return - } -} -""" - - @triton.jit - def for_loop(K): - a = 1.0 - for k in range(0, K): - a += 1.0 - - mod, _ = for_loop.compile_to_ttir(2, grid=(1,)) - generated_ir = mod.str() - assert mod.verify() - assert ref_ir == generated_ir - -def test_while(): - ref_ir = """module { - func @generic_while__i32__(%arg0: i32) { - %c-1_i32 = arith.constant -1 : i32 - %0 = scf.while (%arg1 = %c-1_i32) : (i32) -> i32 { - %c0_i32 = arith.constant 0 : i32 - %1 = arith.cmpi sle, %arg1, %c0_i32 : i32 - scf.condition(%1) 
%arg1 : i32 - } do { - ^bb0(%arg1: i32): - %c1_i32 = arith.constant 1 : i32 - %1 = arith.addi %arg1, %c1_i32 : i32 - scf.yield %1 : i32 - } - return - } -} -""" - @triton.jit - def generic_while(x): - c = -1 - while c <= 0: - c += 1 - - mod, _ = generic_while.compile_to_ttir(2, grid=(1,)) - generated_ir = mod.str() - assert mod.verify() - assert ref_ir == generated_ir - -def test_nested(): - ref_ir = """module { - func @nested_cf__i32_i32_i32_i32__(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) { - %cst = arith.constant 0.000000e+00 : f32 - %c0_i32 = arith.constant 0 : i32 - %c1_i32 = arith.constant 1 : i32 - %0 = arith.index_cast %c0_i32 : i32 to index - %1 = arith.index_cast %arg0 : i32 to index - %2 = arith.index_cast %c1_i32 : i32 to index - %3 = scf.for %arg4 = %0 to %1 step %2 iter_args(%arg5 = %cst) -> (f32) { - %5 = arith.cmpi slt, %arg1, %arg2 : i32 - %6 = scf.if %5 -> (f32) { - %c0_i32_1 = arith.constant 0 : i32 - %c1_i32_2 = arith.constant 1 : i32 - %7 = arith.index_cast %c0_i32_1 : i32 to index - %8 = arith.index_cast %arg3 : i32 to index - %9 = arith.index_cast %c1_i32_2 : i32 to index - %10 = scf.for %arg6 = %7 to %8 step %9 iter_args(%arg7 = %arg5) -> (f32) { - %cst_3 = arith.constant 2.000000e+00 : f32 - %11 = arith.addf %arg7, %cst_3 : f32 - scf.yield %11 : f32 - } - scf.yield %10 : f32 - } else { - %7 = scf.while (%arg6 = %arg5) : (f32) -> f32 { - %cst_1 = arith.constant 1.200000e+00 : f32 - %8 = arith.cmpf olt, %arg6, %cst_1 : f32 - scf.condition(%8) %arg6 : f32 - } do { - ^bb0(%arg6: f32): - %cst_1 = arith.constant 2.000000e+00 : f32 - %8 = arith.mulf %arg6, %cst_1 : f32 - scf.yield %8 : f32 - } - scf.yield %7 : f32 - } - scf.yield %6 : f32 - } - %cst_0 = arith.constant 1.000000e+00 : f32 - %4 = arith.subf %3, %cst_0 : f32 - return - } -} -""" - @triton.jit - def nested_cf(X, lb, ub, Z): - a = 0.0 - for x in range(0, X): - if lb < ub: - for z in range(0, Z): - a += 2.0 - else: - while a < 1.2: - a *= 2.0 - a -= 1.0 - - mod, _ = nested_cf.compile_to_ttir(3, 4, 5, 6, grid=(1,)) - generated_ir = mod.str() - assert mod.verify(), generated_ir - assert ref_ir == generated_ir - -def test_matmul(): - ref_ir = """module { - func @matmul_kernel(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) { - %0 = tt.get_program_id {axis = 0 : i32} : i32 - %c64_i32 = arith.constant 64 : i32 - %1 = arith.addi %arg3, %c64_i32 : i32 - %c1_i32 = arith.constant 1 : i32 - %2 = arith.subi %1, %c1_i32 : i32 - %c64_i32_0 = arith.constant 64 : i32 - %3 = arith.divsi %2, %c64_i32_0 : i32 - %c64_i32_1 = arith.constant 64 : i32 - %4 = arith.addi %arg4, %c64_i32_1 : i32 - %c1_i32_2 = arith.constant 1 : i32 - %5 = arith.subi %4, %c1_i32_2 : i32 - %c64_i32_3 = arith.constant 64 : i32 - %6 = arith.divsi %5, %c64_i32_3 : i32 - %c8_i32 = arith.constant 8 : i32 - %7 = arith.muli %6, %c8_i32 : i32 - %8 = arith.divsi %0, %7 : i32 - %c8_i32_4 = arith.constant 8 : i32 - %9 = arith.muli %8, %c8_i32_4 : i32 - %10 = arith.subi %3, %9 : i32 - %c8_i32_5 = arith.constant 8 : i32 - %11 = arith.cmpi slt, %10, %c8_i32_5 : i32 - %c8_i32_6 = arith.constant 8 : i32 - %12 = select %11, %10, %c8_i32_6 : i32 - %13 = arith.remsi %0, %12 : i32 - %14 = arith.addi %9, %13 : i32 - %15 = arith.remsi %0, %7 : i32 - %16 = arith.divsi %15, %12 : i32 - %c64_i32_7 = arith.constant 64 : i32 - %17 = arith.muli %14, %c64_i32_7 : i32 - %18 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> - %19 = tt.broadcast %17 : (i32) -> tensor<64xi32> - %20 = arith.addi 
%19, %18 : tensor<64xi32> - %c64_i32_8 = arith.constant 64 : i32 - %21 = arith.muli %16, %c64_i32_8 : i32 - %22 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> - %23 = tt.broadcast %21 : (i32) -> tensor<64xi32> - %24 = arith.addi %23, %22 : tensor<64xi32> - %25 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %26 = tt.reshape %20 : (tensor<64xi32>) -> tensor<64x1xi32> - %27 = tt.broadcast %arg6 : (i32) -> tensor<64x1xi32> - %28 = arith.muli %26, %27 : tensor<64x1xi32> - %29 = tt.reshape %25 : (tensor<32xi32>) -> tensor<1x32xi32> - %c1_i32_9 = arith.constant 1 : i32 - %30 = tt.broadcast %c1_i32_9 : (i32) -> tensor<1x32xi32> - %31 = arith.muli %29, %30 : tensor<1x32xi32> - %32 = tt.broadcast %28 : (tensor<64x1xi32>) -> tensor<64x32xi32> - %33 = tt.broadcast %31 : (tensor<1x32xi32>) -> tensor<64x32xi32> - %34 = arith.addi %32, %33 : tensor<64x32xi32> - %35 = tt.broadcast %arg0 : (!tt.ptr) -> tensor<64x32x!tt.ptr> - %36 = tt.getelementptr %35, %34, : tensor<64x32x!tt.ptr> - %37 = tt.reshape %25 : (tensor<32xi32>) -> tensor<32x1xi32> - %38 = tt.broadcast %arg7 : (i32) -> tensor<32x1xi32> - %39 = arith.muli %37, %38 : tensor<32x1xi32> - %40 = tt.reshape %24 : (tensor<64xi32>) -> tensor<1x64xi32> - %c1_i32_10 = arith.constant 1 : i32 - %41 = tt.broadcast %c1_i32_10 : (i32) -> tensor<1x64xi32> - %42 = arith.muli %40, %41 : tensor<1x64xi32> - %43 = tt.broadcast %39 : (tensor<32x1xi32>) -> tensor<32x64xi32> - %44 = tt.broadcast %42 : (tensor<1x64xi32>) -> tensor<32x64xi32> - %45 = arith.addi %43, %44 : tensor<32x64xi32> - %46 = tt.broadcast %arg1 : (!tt.ptr) -> tensor<32x64x!tt.ptr> - %47 = tt.getelementptr %46, %45, : tensor<32x64x!tt.ptr> - %cst = arith.constant 0.000000e+00 : f32 - %48 = tt.broadcast %cst : (f32) -> tensor<64x64xf32> - %c0_i32 = arith.constant 0 : i32 - %c32_i32 = arith.constant 32 : i32 - %49 = arith.index_cast %c0_i32 : i32 to index - %50 = arith.index_cast %arg5 : i32 to index - %51 = arith.index_cast %c32_i32 : i32 to index - %52:3 = scf.for %arg9 = %49 to %50 step %51 iter_args(%arg10 = %48, %arg11 = %36, %arg12 = %47) -> (tensor<64x64xf32>, tensor<64x32x!tt.ptr>, tensor<32x64x!tt.ptr>) { - %cst_14 = arith.constant dense : tensor<64x32xi1> - %cst_15 = arith.constant dense<0.000000e+00> : tensor<64x32xf16> - %82 = tt.load %arg11, %cst_14, %cst_15 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16> - %cst_16 = arith.constant dense : tensor<32x64xi1> - %cst_17 = arith.constant dense<0.000000e+00> : tensor<32x64xf16> - %83 = tt.load %arg12, %cst_16, %cst_17 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16> - %cst_18 = arith.constant 0.000000e+00 : f32 - %84 = tt.broadcast %cst_18 : (f32) -> tensor<64x64xf32> - %85 = tt.dot %82, %83, %84 {allowTF32 = true} : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32> - %86 = arith.addf %arg10, %85 : tensor<64x64xf32> - %c32_i32_19 = arith.constant 32 : i32 - %87 = tt.broadcast %c32_i32_19 : (i32) -> tensor<64x32xi32> - %88 = tt.getelementptr %arg11, %87, : tensor<64x32x!tt.ptr> - %c32_i32_20 = arith.constant 32 : i32 - %89 = arith.muli %arg7, %c32_i32_20 : i32 - %90 = tt.broadcast %89 : (i32) -> tensor<32x64xi32> - %91 = tt.getelementptr %arg12, %90, : tensor<32x64x!tt.ptr> - scf.yield %86, %88, %91 : tensor<64x64xf32>, tensor<64x32x!tt.ptr>, tensor<32x64x!tt.ptr> - } - %53 = arith.truncf %52#0 : tensor<64x64xf32> to tensor<64x64xf16> - %c64_i32_11 = arith.constant 64 : i32 - %54 = arith.muli %14, %c64_i32_11 : i32 - %55 = tt.make_range {end = 
64 : i32, start = 0 : i32} : tensor<64xi32> - %56 = tt.broadcast %54 : (i32) -> tensor<64xi32> - %57 = arith.addi %56, %55 : tensor<64xi32> - %c64_i32_12 = arith.constant 64 : i32 - %58 = arith.muli %16, %c64_i32_12 : i32 - %59 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> - %60 = tt.broadcast %58 : (i32) -> tensor<64xi32> - %61 = arith.addi %60, %59 : tensor<64xi32> - %62 = tt.reshape %57 : (tensor<64xi32>) -> tensor<64x1xi32> - %63 = tt.broadcast %arg8 : (i32) -> tensor<64x1xi32> - %64 = arith.muli %63, %62 : tensor<64x1xi32> - %65 = tt.broadcast %arg2 : (!tt.ptr) -> tensor<64x1x!tt.ptr> - %66 = tt.getelementptr %65, %64, : tensor<64x1x!tt.ptr> - %67 = tt.reshape %61 : (tensor<64xi32>) -> tensor<1x64xi32> - %c1_i32_13 = arith.constant 1 : i32 - %68 = tt.broadcast %c1_i32_13 : (i32) -> tensor<1x64xi32> - %69 = arith.muli %67, %68 : tensor<1x64xi32> - %70 = tt.broadcast %66 : (tensor<64x1x!tt.ptr>) -> tensor<64x64x!tt.ptr> - %71 = tt.broadcast %69 : (tensor<1x64xi32>) -> tensor<64x64xi32> - %72 = tt.getelementptr %70, %71, : tensor<64x64x!tt.ptr> - %73 = tt.reshape %57 : (tensor<64xi32>) -> tensor<64x1xi32> - %74 = tt.broadcast %arg3 : (i32) -> tensor<64x1xi32> - %75 = arith.cmpi slt, %73, %74 : tensor<64x1xi32> - %76 = tt.reshape %61 : (tensor<64xi32>) -> tensor<1x64xi32> - %77 = tt.broadcast %arg4 : (i32) -> tensor<1x64xi32> - %78 = arith.cmpi slt, %76, %77 : tensor<1x64xi32> - %79 = tt.broadcast %75 : (tensor<64x1xi1>) -> tensor<64x64xi1> - %80 = tt.broadcast %78 : (tensor<1x64xi1>) -> tensor<64x64xi1> - %81 = arith.andi %79, %80 : tensor<64x64xi1> - tt.store %72, %53, %81, : tensor<64x64xf16> - return - } -} -""" - @triton.jit - def matmul_kernel( - # Pointers to matrices - a_ptr, b_ptr, c_ptr, - # Matrix dimensions - M, N, K, - # The stride variables represent how much to increase the ptr by when moving by 1 - # element in a particular dimension. E.g. stride_am is how much to increase a_ptr - # by to get the element one row down (A has M rows) - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - ): - """Kernel for computing the matmul C = A x B. - A has shape (M, K), B has shape (K, N) and C has shape (M, N) - """ - # ----------------------------------------------------------- - # Map program ids `pid` to the block of C it should compute. - # This is done in a grouped ordering to promote L2 data reuse - # See above `L2 Cache Optimizations` section for details - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - # ---------------------------------------------------------- - # Create pointers for the first blocks of A and B. 
- # We will advance this pointer as we move in the K direction - # and accumulate - # a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers - # b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_n] pointers - # see above `Pointer Arithmetics` section for details - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) - b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - - # ----------------------------------------------------------- - # Iterate to compute a block of the C matrix - # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block - # of fp32 values for higher accuracy. - # `accumulator` will be converted back to fp16 after the loop - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, K, BLOCK_SIZE_K): - # Note that for simplicity, we don't apply a mask here. - # This means that if K is not a multiple of BLOCK_SIZE_K, - # this will access out-of-bounds memory and produce an - # error or (worse!) incorrect results. - a = tl.load(a_ptrs) - b = tl.load(b_ptrs) - # We accumulate along the K dimension - accumulator += tl.dot(a, b) - # Advance the ptrs to the next K block - a_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += BLOCK_SIZE_K * stride_bk - c = accumulator.to(tl.float16) - - # ----------------------------------------------------------- - # Write back the block of the output matrix C - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - tl.store(c_ptrs, c, mask=c_mask) - - a = torch.randn((512, 512), device='cuda', dtype=torch.float16) - b = torch.randn((512, 512), device='cuda', dtype=torch.float16) - c = torch.empty((512, 512), device='cuda', dtype=torch.float16) - - - mod, ctx = matmul_kernel.compile_to_ttir( - a, b, c, - 512, 512, 512, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), - 64, 64, 32, - 8, grid=(2,) - ) - verify = mod.verify() - assert verify - # assert ref_ir == mod.str() - print(mod.str()) - - pm = _triton.ir.pass_manager(ctx) - pm.add_inliner_pass() - pm.run(mod) - - verify = mod.verify() - assert verify - # assert ref_ir == mod.str() - print(mod.str())
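Taken together, the deleted rewrite-test scripts all drive the same JIT round-trip: lower a @triton.jit function to TTIR with compile_to_ttir, verify the module, optionally run MLIR passes over it through _triton.ir.pass_manager, and lower the result to LLIR with compile_ttir_to_llir. The sketch below condenses that pattern from the files above; compile_to_ttir and compile_ttir_to_llir are internal hooks of this branch rather than a stable public API, and the kernel is a trimmed copy of the vecadd example, so treat it as an illustration rather than a supported recipe.

import torch
import triton
import triton.language as tl
import triton._C.libtriton.triton as _triton


@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    y = tl.load(y_ptr + offsets, mask=mask, other=0.0)
    tl.store(output_ptr + offsets, x + y, mask=mask)


size = 1024
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
z = torch.empty_like(x)

# Lower the Python kernel to Triton IR (TTIR) and check that the module is well formed.
mod, ctx = add_kernel.compile_to_ttir(x, y, z, size, BLOCK_SIZE=256, grid=(1,))
assert mod.verify(), mod.str()

# Drive the MLIR pass manager the way test_scf.py does; the inliner is a no-op for a
# leaf kernel like this one, but the module must still verify afterwards.
pm = _triton.ir.pass_manager(ctx)
pm.add_inliner_pass()
pm.run(mod)
assert mod.verify(), mod.str()

# Lower the verified TTIR down to LLVM IR and print both stages, as vecadd.py does.
add_kernel.compile_ttir_to_llir(mod, ctx)
mod.dump()

Checking mod.verify() before and after each stage is what the removed tests rely on to catch malformed IR as soon as it is produced.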