Skip to content

Commit

Permalink
[ROCM] Enable ROCM Backend #1: Empty Kernel (triton-lang#1312)
Browse files Browse the repository at this point in the history
This PR is a first in a series of PRs to import the changes that we have
made to enable ROCM on [our
fork](https://github.com/ROCmSoftwarePlatform/triton) of triton.

The PR contains the major changes to the python frontend and enough
changes to the c++ backend to allow compilation and running of the empty
kernel. We use the ROCM ci added a few weeks ago to verify things.

---------

Co-authored-by: Ronan Keryell <ronan@keryell.fr>
  • Loading branch information
micmelesse and keryell authored Mar 25, 2023
1 parent b1284ab commit 2ddc73f
Show file tree
Hide file tree
Showing 33 changed files with 1,602 additions and 131 deletions.
33 changes: 30 additions & 3 deletions .github/workflows/integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
id: set-matrix
run: |
if [ x"${{ github.repository }}" == x"openai/triton" ]; then
echo '::set-output name=matrix::[["self-hosted", "A100"], ["self-hosted", "V100"], "macos-10.15"]'
echo '::set-output name=matrix::[["self-hosted", "A100"], ["self-hosted", "V100"], ["self-hosted", "gfx908"], "macos-10.15"]'
else
echo '::set-output name=matrix::["ubuntu-latest", "macos-10.15"]'
fi
Expand All @@ -40,6 +40,16 @@ jobs:
- name: Checkout
uses: actions/checkout@v2

- name: Set CUDA ENV
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100')}}
run: |
echo "BACKEND=CUDA" >> $GITHUB_ENV
- name: Set ROCM ENV
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'gfx908')}}
run: |
echo "BACKEND=ROCM" >> $GITHUB_ENV
- name: Clear cache
run: |
rm -rf ~/.triton/
Expand Down Expand Up @@ -74,12 +84,22 @@ jobs:
flake8 --config ./python/setup.cfg ./python || ( echo '::error::Flake8 failed; see logs for errors.' ; exit 1 )
- name: Install Triton
if: ${{ env.BACKEND != 'ROCM'}}
run: |
cd python
pip3 install cmake==3.24
TRITON_USE_ASSERT_ENABLED_LLVM=TRUE pip3 install -e '.[tests]'
- name: Install Triton on ROCM
if: ${{ env.BACKEND == 'ROCM'}}
run: |
cd python
pip3 uninstall --yes torch torchvision torchaudio
pip3 install --no-cache-dir --force-reinstall torch==1.13.1 --extra-index-url https://download.pytorch.org/whl/rocm5.2
TRITON_USE_ASSERT_ENABLED_LLVM=TRUE pip3 install -e '.[tests]'
- name: Run lit tests
if: ${{ env.BACKEND != 'ROCM'}}
run: |
pip3 install lit
cd python
Expand All @@ -89,13 +109,20 @@ jobs:
fi
lit -v "$LIT_TEST_DIR"
- name: Run python tests
if: ${{matrix.runner[0] == 'self-hosted'}}
- name: Run python tests on CUDA
if: ${{ env.BACKEND == 'CUDA'}}
run: |
cd python/test/unit/
pytest
- name: Run python tests on ROCM
if: ${{ env.BACKEND == 'ROCM'}}
run: |
cd python/test/unit/language/
pytest --capture=tee-sys -rfs --verbose "test_core.py::test_empty_kernel"
- name: Run CXX unittests
if: ${{ env.BACKEND != 'ROCM'}}
run: |
cd python/
cd "build/$(ls build)"
Expand Down
Empty file modified .gitignore
100644 → 100755
Empty file.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
TritonGPUTransforms
TritonLLVMIR
TritonPTX
TritonHSACO
${dialect_libs}
${conversion_libs}

Expand All @@ -228,6 +229,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
MLIRExecutionEngine
MLIRMathToLLVM
MLIRNVVMToLLVMIRTranslation
MLIRROCDLToLLVMIRTranslation
MLIRIR
)

Expand Down
2 changes: 2 additions & 0 deletions bin/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ llvm_update_compile_flags(triton-translate)
TritonGPUTransforms
TritonLLVMIR
TritonPTX
TritonHSACO
${dialect_libs}
${conversion_libs}
# tests
Expand All @@ -53,5 +54,6 @@ llvm_update_compile_flags(triton-translate)
MLIRTransformUtils
MLIRLLVMToLLVMIRTranslation
MLIRNVVMToLLVMIRTranslation
MLIRROCDLToLLVMIRTranslation
)
mlir_check_all_link_libraries(triton-translate)
3 changes: 2 additions & 1 deletion bin/triton-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
#include "triton/Dialect/Triton/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"

#include "triton/Conversion/Passes.h"
#include "triton/Conversion/TritonGPUToLLVM/Passes.h"
#include "triton/Conversion/TritonToTritonGPU/Passes.h"

#include "mlir/IR/Dialect.h"
#include "mlir/InitAllPasses.h"
Expand Down
26 changes: 24 additions & 2 deletions bin/triton-translate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Target/HSACO/HSACOTranslation.h"
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
#include "triton/Target/PTX/PTXTranslation.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
#include <iostream>

namespace mlir {
namespace triton {
Expand Down Expand Up @@ -79,7 +79,8 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
llvm::cl::init("-"));

static llvm::cl::opt<std::string> targetKind(
"target", llvm::cl::desc("<translation target, options: llvmir/ptx>"),
"target",
llvm::cl::desc("<translation target, options: llvmir/ptx/hsaco>"),
llvm::cl::value_desc("target"), llvm::cl::init("llvmir"));

static llvm::cl::opt<int> SMArch("sm", llvm::cl::desc("sm arch"),
Expand All @@ -88,6 +89,18 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
static llvm::cl::opt<int> ptxVersion(
"ptx-version", llvm::cl::desc("PTX version"), llvm::cl::init(10000));

static llvm::cl::opt<std::string> GCNArch(
"gfx", llvm::cl::desc("AMDGCN target. e.g. '90a'"),
llvm::cl::value_desc("architecture"), llvm::cl::init("90a"));

static llvm::cl::opt<std::string> GCNTriple(
"amdgcn", llvm::cl::desc("AMDGCN triple. e.g. '-amd-amdhsa'"),
llvm::cl::value_desc("target triple"), llvm::cl::init("-amd-amdhsa"));

static llvm::cl::opt<std::string> GCNFeatures(
"", llvm::cl::desc("AMDGCN features. e.g. '+sramecc,-xnack'"),
llvm::cl::value_desc("features"), llvm::cl::init("+sramecc,-xnack"));

llvm::InitLLVM y(argc, argv);

registerAsmPrinterCLOptions();
Expand Down Expand Up @@ -119,6 +132,15 @@ LogicalResult tritonTranslateMain(int argc, char **argv,
else if (targetKind == "ptx")
llvm::outs() << ::triton::translateLLVMIRToPTX(*llvmir, SMArch.getValue(),
ptxVersion.getValue());
else if (targetKind == "hsaco") {
auto [module, hsaco] = ::triton::translateLLVMIRToHSACO(
*llvmir, GCNArch.getValue(), GCNTriple.getValue(),
GCNFeatures.getValue());
llvm::outs() << hsaco;
} else {
llvm::errs() << "Error: Unknown target specified: " << targetKind << "\n";
return failure();
}

return success();
}
Expand Down
5 changes: 2 additions & 3 deletions include/triton/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls)
add_public_tablegen_target(TritonConversionPassIncGen)
add_subdirectory(TritonToTritonGPU)
add_subdirectory(TritonGPUToLLVM)
33 changes: 33 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/AsmFormat.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#ifndef TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_
#define TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_

#include "mlir/IR/Value.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include <memory>
#include <string>

namespace mlir {
class ConversionPatternRewriter;
class Location;

namespace triton {
using llvm::StringRef;

inline std::string strJoin(llvm::ArrayRef<std::string> strs,
llvm::StringRef delimiter) {
std::string osStr;
llvm::raw_string_ostream os(osStr);
for (size_t i = 0; !strs.empty() && i < strs.size() - 1; ++i)
os << strs[i] << delimiter;
if (!strs.empty())
os << strs.back();
os.flush();
return osStr;
}

} // namespace triton
} // namespace mlir

#endif // TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_
3 changes: 3 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonGPUToLLVM)
add_public_tablegen_target(TritonGPUConversionPassIncGen)
Loading

0 comments on commit 2ddc73f

Please sign in to comment.