Skip to content

Commit

Permalink
Support atomic ops for CPU. (triton-lang#20)
Browse files Browse the repository at this point in the history
Signed-off-by: Ilya Enkovich <ilya.enkovich@intel.com>
  • Loading branch information
ienkovich committed Dec 6, 2024
1 parent aea6125 commit 7b6ed89
Show file tree
Hide file tree
Showing 11 changed files with 409 additions and 0 deletions.
6 changes: 6 additions & 0 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1422,6 +1422,7 @@ def kernel(X, Y, Z):
# ---------------
# test atomics
# ---------------
@pytest.mark.cpu
@pytest.mark.interpreter
@pytest.mark.parametrize(
"op, dtype_x_str, mode, sem",
Expand Down Expand Up @@ -1502,6 +1503,7 @@ def kernel(X, Z):
assert f"atom.global.gpu.{sem_str}" in h.asm["ptx"]


@pytest.mark.cpu
@pytest.mark.interpreter
@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_atomic_rmw_predicate(num_ctas, device):
Expand All @@ -1517,6 +1519,7 @@ def kernel(X):
assert x.item() == 63


@pytest.mark.cpu
@pytest.mark.interpreter
@pytest.mark.parametrize("shape, axis, num_ctas, dtype_x_str",
[(shape, axis, num_ctas, dtype_x_str)
Expand Down Expand Up @@ -1578,6 +1581,7 @@ def torch_to_triton_dtype(t):
np.testing.assert_equal(old_ref, to_numpy(old_tri))


@pytest.mark.cpu
@pytest.mark.interpreter
@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_tensor_atomic_rmw_block(num_ctas, device):
Expand All @@ -1597,6 +1601,7 @@ def kernel(X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
assert torch.min(x).item() == 0.0


@pytest.mark.cpu
@pytest.mark.interpreter
@pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed'])
@pytest.mark.parametrize("num_ctas", num_ctas_list)
Expand Down Expand Up @@ -1638,6 +1643,7 @@ def serialized_add(data, Lock, SEM: tl.constexpr):
assert f"atom.global.{sem_str}" in h.asm["ptx"]


@pytest.mark.cpu
@pytest.mark.interpreter
@pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed'])
@pytest.mark.parametrize("num_ctas", num_ctas_list)
Expand Down
1 change: 1 addition & 0 deletions third_party/cpu/include/TritonCPUToLLVM/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ std::unique_ptr<OperationPass<ModuleOp>> createFuncOpToLLVMPass();
std::unique_ptr<OperationPass<ModuleOp>> createMemoryOpToLLVMPass();
std::unique_ptr<OperationPass<ModuleOp>> createGetProgramIdOpToLLVMPass();
std::unique_ptr<OperationPass<triton::FuncOp>> createLowerMultiReductionPass();
/// Creates the module pass that lowers Triton atomic operations
/// (AtomicRMWOp / AtomicCASOp) to LLVM dialect operations.
std::unique_ptr<OperationPass<ModuleOp>> createAtomicOpsToLLVMPass();

void tritonCPUToLLVMPipelineBuilder(OpPassManager &pm);
void registerTritonCPUToLLVMPipeline();
Expand Down
11 changes: 11 additions & 0 deletions third_party/cpu/include/TritonCPUToLLVM/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,15 @@ def LowerMultiReduction : Pass<"triton-cpu-lower-multi-reduction", "mlir::triton
"mlir::triton::TritonDialect"];
}

def AtomicOpsToLLVM : Pass<"triton-cpu-atomic-ops-to-llvm", "mlir::ModuleOp"> {
    let summary = "Convert Triton atomic operations to LLVM.";
    let description = [{
        Rewrites `triton::AtomicRMWOp` into `LLVM::AtomicRMWOp` and
        `triton::AtomicCASOp` into `LLVM::AtomicCmpXchgOp` (followed by an
        extraction of the previously stored value). The Triton memory semantic
        of each operation is mapped to the matching LLVM atomic ordering.
    }];
    let constructor = "mlir::triton::cpu::createAtomicOpsToLLVMPass()";

    // The LLVM dialect must be declared here because the pass creates LLVM
    // dialect operations; every dialect whose ops a pass may build has to be
    // loaded before the (possibly multi-threaded) pass pipeline starts.
    let dependentDialects = ["mlir::LLVM::LLVMDialect",
                             "mlir::vector::VectorDialect",
                             "mlir::triton::cpu::TritonCPUDialect",
                             "mlir::triton::TritonDialect"];
}

#endif
1 change: 1 addition & 0 deletions third_party/cpu/include/TritonToTritonCPU/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ std::unique_ptr<OperationPass<ModuleOp>> createConvertControlFlowOps();
std::unique_ptr<OperationPass<ModuleOp>> createConvertHistogramOp();
std::unique_ptr<OperationPass<ModuleOp>> createConvertReductionOp();
std::unique_ptr<OperationPass<ModuleOp>> createConvertScanOp();
/// Creates the module pass registered as `triton-cpu-convert-atomic-ops`
/// (the ConvertAtomicOps entry in Passes.td).
std::unique_ptr<OperationPass<ModuleOp>> createConvertAtomicOps();

void tritonToTritonCPUPipelineBuilder(OpPassManager &pm);
void registerTritonToTritonCPUPipeline();
Expand Down
14 changes: 14 additions & 0 deletions third_party/cpu/include/TritonToTritonCPU/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -114,4 +114,18 @@ def ConvertScanOp : Pass<"triton-cpu-convert-scan", "mlir::ModuleOp"> {
"mlir::triton::cpu::TritonCPUDialect"];
}

// Declarative spec of the `triton-cpu-convert-atomic-ops` pass, which runs on
// a whole module and is instantiated through createConvertAtomicOps().
// NOTE(review): the implementation is not visible from here; judging by the
// dependent dialects it presumably emits arith/vector/scf ops around the
// Triton atomics — confirm against ConvertAtomicOps.cpp.
def ConvertAtomicOps : Pass<"triton-cpu-convert-atomic-ops", "mlir::ModuleOp"> {
let summary = "Convert Triton atomic operations.";
// NOTE(review): description is intentionally left untouched but is empty;
// consider documenting what the conversion produces.
let description = [{

}];
let constructor = "mlir::triton::cpu::createConvertAtomicOps()";

// Dialects that must be loaded before this pass runs.
let dependentDialects = ["mlir::arith::ArithDialect",
"mlir::vector::VectorDialect",
"mlir::scf::SCFDialect",
"mlir::triton::TritonDialect",
"mlir::triton::cpu::TritonCPUDialect"];
}

#endif
154 changes: 154 additions & 0 deletions third_party/cpu/lib/TritonCPUToLLVM/AtomicOpsToLLVM.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
#include "TypeConverter.h"

#include "cpu/include/TritonCPUToLLVM/Passes.h"

#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/Pass.h"

#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonCPU/IR/Dialect.h"

namespace mlir {
namespace triton {
#define GEN_PASS_DEF_ATOMICOPSTOLLVM
#include "cpu/include/TritonCPUToLLVM/Passes.h.inc"
} // namespace triton
} // namespace mlir

using namespace mlir;
using namespace mlir::triton;
using namespace mlir::triton::cpu;

namespace {

/// Conversion target used by the atomic lowering: the whole LLVM dialect is
/// legal, and so is `UnrealizedConversionCastOp`.
class TritonLLVMConversionTarget : public ConversionTarget {
public:
  explicit TritonLLVMConversionTarget(MLIRContext &context)
      : ConversionTarget(context) {
    addLegalOp<mlir::UnrealizedConversionCastOp>();
    addLegalDialect<LLVM::LLVMDialect>();
  }
};

/// Maps a Triton memory semantic onto the LLVM atomic ordering used by the
/// lowered instruction.
///
/// RELAXED -> monotonic, ACQUIRE -> acquire, RELEASE -> release,
/// ACQUIRE_RELEASE -> acq_rel. Any other value is treated as unreachable.
LLVM::AtomicOrdering getOrdering(MemSemantic sem) {
  if (sem == MemSemantic::RELAXED)
    return LLVM::AtomicOrdering::monotonic;
  if (sem == MemSemantic::ACQUIRE)
    return LLVM::AtomicOrdering::acquire;
  if (sem == MemSemantic::RELEASE)
    return LLVM::AtomicOrdering::release;
  if (sem == MemSemantic::ACQUIRE_RELEASE)
    return LLVM::AtomicOrdering::acq_rel;
  llvm_unreachable("Unexpected atomic mem semantic");
}

// TODO: use enums instead of raw indices to access struct fields (e.g. field 0
// of the cmpxchg result extracted in AtomicCASOpConversion).
/// Rewrites a Triton `AtomicRMWOp` into a single `LLVM::AtomicRMWOp`.
///
/// Only the pointer, the value and the memory semantic are read from the
/// original op; the pointer and value are taken in their remapped (already
/// converted) form.
struct AtomicRMWOpConversion : public OpConversionPattern<AtomicRMWOp> {
  using OpConversionPattern::OpConversionPattern;

  /// Selects the LLVM binary operation matching the Triton RMW kind.
  ///
  /// MAX and MIN depend on the result type: integer/index types use the
  /// signed integer variants, everything else uses the floating-point ones.
  LLVM::AtomicBinOp getAtomicBinOp(RMWOp op, Type type) const {
    const bool isInteger = type.isIntOrIndex();
    switch (op) {
    case RMWOp::XCHG:
      return LLVM::AtomicBinOp::xchg;
    case RMWOp::ADD:
      return LLVM::AtomicBinOp::add;
    case RMWOp::FADD:
      return LLVM::AtomicBinOp::fadd;
    case RMWOp::AND:
      return LLVM::AtomicBinOp::_and;
    case RMWOp::OR:
      return LLVM::AtomicBinOp::_or;
    case RMWOp::XOR:
      return LLVM::AtomicBinOp::_xor;
    case RMWOp::MAX:
      if (isInteger)
        return LLVM::AtomicBinOp::max;
      return LLVM::AtomicBinOp::fmax;
    case RMWOp::MIN:
      if (isInteger)
        return LLVM::AtomicBinOp::min;
      return LLVM::AtomicBinOp::fmin;
    case RMWOp::UMAX:
      return LLVM::AtomicBinOp::umax;
    case RMWOp::UMIN:
      return LLVM::AtomicBinOp::umin;
    default:
      llvm_unreachable("Unexpected atomic op");
    }
  }

  LogicalResult
  matchAndRewrite(AtomicRMWOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    LLVM::AtomicBinOp binOp = getAtomicBinOp(op.getAtomicRmwOp(), op.getType());
    // Operands are fetched through the rewriter so that they carry their
    // converted types.
    Value llvmPtr = rewriter.getRemappedValue(op.getPtr());
    Value llvmVal = rewriter.getRemappedValue(op.getVal());
    LLVM::AtomicOrdering memOrder = getOrdering(op.getSem());
    rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(op, binOp, llvmPtr, llvmVal,
                                                   memOrder);
    return success();
  }
};

/// Rewrites a Triton `AtomicCASOp` into `LLVM::AtomicCmpXchgOp` and replaces
/// the original result with the value extracted from field 0 of the cmpxchg
/// result.
struct AtomicCASOpConversion : public OpConversionPattern<AtomicCASOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AtomicCASOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value llvmPtr = rewriter.getRemappedValue(op.getPtr());
    Value expected = rewriter.getRemappedValue(op.getCmp());
    Value desired = rewriter.getRemappedValue(op.getVal());

    LLVM::AtomicOrdering successOrdering = getOrdering(op.getSem());
    // A monotonic exchange stays monotonic on failure; every stronger
    // success ordering falls back to `acquire` on the failure path.
    LLVM::AtomicOrdering failureOrdering = LLVM::AtomicOrdering::acquire;
    if (successOrdering == LLVM::AtomicOrdering::monotonic)
      failureOrdering = successOrdering;

    Value exchange = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, llvmPtr, expected, desired, successOrdering, failureOrdering);
    // Field 0 of the cmpxchg result is what replaces the Triton op's result.
    Value previous = rewriter.create<LLVM::ExtractValueOp>(loc, exchange, 0);
    rewriter.replaceOp(op, previous);
    return success();
  }
};

/// Module pass that lowers Triton atomic RMW and CAS operations to the LLVM
/// dialect using a partial dialect conversion.
struct AtomicOpsToLLVM
    : public triton::impl::AtomicOpsToLLVMBase<AtomicOpsToLLVM> {
  using AtomicOpsToLLVMBase::AtomicOpsToLLVMBase;

  AtomicOpsToLLVM() : AtomicOpsToLLVMBase() {}

  /// Builds the type converter, conversion target and pattern set, then runs
  /// the conversion over the module; a conversion failure fails the pass.
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    ModuleOp module = getOperation();

    mlir::LowerToLLVMOptions llvmOptions(ctx);
    TritonCPUToLLVMTypeConverter converter(ctx, llvmOptions);
    TritonLLVMConversionTarget target(*ctx);

    RewritePatternSet atomicPatterns(ctx);
    atomicPatterns.add<AtomicRMWOpConversion, AtomicCASOpConversion>(converter,
                                                                     ctx);

    if (failed(
            applyPartialConversion(module, target, std::move(atomicPatterns)))) {
      signalPassFailure();
      return;
    }
  }
};

} // anonymous namespace

namespace mlir {
namespace triton {
namespace cpu {

/// Factory returning a fresh instance of the AtomicOpsToLLVM pass.
std::unique_ptr<OperationPass<ModuleOp>> createAtomicOpsToLLVMPass() {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<AtomicOpsToLLVM>();
  return pass;
}

} // namespace cpu
} // namespace triton
} // namespace mlir
1 change: 1 addition & 0 deletions third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_triton_library(TritonCPUToLLVM
AtomicOpsToLLVM.cpp
FuncOpToLLVM.cpp
GetProgramIdOpToLLVM.cpp
LowerMultiReduction.cpp
Expand Down
1 change: 1 addition & 0 deletions third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ void tritonCPUToLLVMPipelineBuilder(OpPassManager &pm) {
pm.addPass(mlir::triton::cpu::createFuncOpToLLVMPass());
pm.addPass(mlir::triton::cpu::createGetProgramIdOpToLLVMPass());
pm.addPass(mlir::triton::cpu::createMemoryOpToLLVMPass());
pm.addPass(mlir::triton::cpu::createAtomicOpsToLLVMPass());
// pm.addPass(mlir::createReconcileUnrealizedCastsPass());
}

Expand Down
1 change: 1 addition & 0 deletions third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_triton_library(TritonToTritonCPU
ConvertAtomicOps.cpp
ConvertControlFlowOps.cpp
ConvertDotOp.cpp
ConvertElementwiseOps.cpp
Expand Down
Loading

0 comments on commit 7b6ed89

Please sign in to comment.