post rebase fixes and refactoring
binarman committed Apr 8, 2024
1 parent 12df59a commit 352331c
Showing 7 changed files with 66 additions and 47 deletions.
11 changes: 9 additions & 2 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -20,6 +20,7 @@ using namespace mlir::triton;

// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
// Operators
#define inttofloat(...) rewriter.create<LLVM::SIToFPOp>(loc, __VA_ARGS__)
#define inttoptr(...) rewriter.create<LLVM::IntToPtrOp>(loc, __VA_ARGS__)
#define ptrtoint(...) rewriter.create<LLVM::PtrToIntOp>(loc, __VA_ARGS__)
#define zext(...) rewriter.create<LLVM::ZExtOp>(loc, __VA_ARGS__)
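The new inttofloat shortcut wraps LLVM::SIToFPOp, a signed-integer-to-float conversion. A hedged usage sketch (f32_ty and i32Val are illustrative names, not from this diff; loc and rewriter are captured implicitly by the macro):

```cpp
// Convert a signed i32 SSA value to f32 inside a conversion pattern.
Value asFloat = inttofloat(f32_ty, i32Val);
```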
@@ -198,6 +199,9 @@ using namespace mlir::triton;

Value createConstantI32(Location loc, OpBuilder &rewriter, int32_t v);

/// Create a 16-bit float constant.
Value createConstantF16(Location loc, OpBuilder &rewriter, float v);

/// Create a 32-bit float constant.
Value createConstantF32(Location loc, OpBuilder &rewriter, float v);
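Only the declaration lands in this header; a plausible definition, assuming it mirrors the existing createConstantF32 helper (a sketch, not the commit's actual Utility.cpp code):

```cpp
Value createConstantF16(Location loc, OpBuilder &rewriter, float v) {
  // Assumed to follow the F32 helper: build an f16 LLVM constant op.
  auto type = type::f16Ty(rewriter.getContext());
  return rewriter.create<LLVM::ConstantOp>(loc, type,
                                           rewriter.getF16FloatAttr(v));
}
```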

@@ -800,15 +804,18 @@ emitBaseIndexForMfmaLayout(Location loc, RewriterBase &rewriter,
Value warpId = udiv(threadId, warpSize);
SmallVector<Value> multiDimWarpId = delinearize(
rewriter, loc, warpId, _warpsPerCTA, triton::gpu::getOrder(mfmaLayout));
if (shape[rank - 2] >= mDim) {
assert(shape[rank - 2] % mDim == 0);
multiDimWarpId[rank - 2] =
urem(multiDimWarpId[rank - 2], i32_val(ceil<unsigned>(shape[rank - 2], mDim)));
urem(multiDimWarpId[rank - 2],
i32_val(ceil<unsigned>(shape[rank - 2], mDim)));
}
if (shape[rank - 1] >= nDim) {
assert(shape[rank - 1] % nDim == 0);
multiDimWarpId[rank - 1] =
urem(multiDimWarpId[rank - 1], i32_val(ceil<unsigned>(shape[rank - 1], nDim)));
urem(multiDimWarpId[rank - 1],
i32_val(ceil<unsigned>(shape[rank - 1], nDim)));
}
Value offWarp0 = mul(multiDimWarpId[rank - 2], i32_val(mDim));
Value offWarp1 = mul(multiDimWarpId[rank - 1], i32_val(nDim));
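These urem calls (only re-wrapped by this hunk) keep a warp's tile index in range when more warps than MFMA tiles are launched along a dimension. The same arithmetic as a scalar model (hypothetical helper, names not from this commit):

```cpp
#include <cassert>

// Warp `w` along a dimension of `shape` elements tiled by `dim`-element
// MFMA tiles: extra warps wrap around to the first tiles.
unsigned wrapWarpIndex(unsigned w, unsigned shape, unsigned dim) {
  assert(shape % dim == 0);     // mirrors the assert above
  unsigned tiles = shape / dim; // what ceil<unsigned>(shape, dim) yields here
  return w % tiles;             // the urem(...)
}
```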
2 changes: 1 addition & 1 deletion lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -253,7 +253,7 @@ SmallVector<unsigned> getOrder(Attribute layout) {
order[i] = rank - 1 - i;
if (auto mfmaLayout = layout.dyn_cast<AMDMfmaEncodingAttr>()) {
if (mfmaLayout.getIsTransposed()) {
std::reverse(order.begin(), order.end());
std::swap(order[0], order[1]);
}
}
return order;
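Since order lists dimensions fastest-first, a transposed MFMA layout should exchange only the two fastest-varying entries; reversing the whole vector would also move the batch dimension of a rank-3 layout. A minimal standalone illustration (not repository code):

```cpp
#include <algorithm>
#include <vector>

int main() {
  std::vector<unsigned> order{2, 1, 0}; // rank-3 default: dim 2 fastest
  // std::reverse(order.begin(), order.end()) would give {0, 1, 2},
  // incorrectly making the batch dimension the fastest.
  std::swap(order[0], order[1]);        // {1, 2, 0}: batch dim 0 stays slowest
  return 0;
}
```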
2 changes: 1 addition & 1 deletion third_party/amd/backend/compiler.py
@@ -23,7 +23,7 @@ class HIPOptions:
arch: str = None
allow_fp8e4nv: bool = False
default_dot_input_precision: str = "ieee"
allowed_dot_input_precisions: Tuple[str] = ("ieee", )
allowed_dot_input_precisions: Tuple[str] = ("tf32", "ieee")
enable_fp_fusion: bool = True
capability: int = None
matrix_inst_shape: int = 0
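Adding "tf32" here presumably lets HIP kernels request tl.dot(..., input_precision="tf32") without tripping the backend's allowed-precision validation; "ieee" stays the default via default_dot_input_precision above.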
@@ -22,7 +22,9 @@ std::pair<mlir::Value, mlir::Value>
swizzleIndexes(ConversionPatternRewriter &rewriter, Location loc, Value row,
Value col, SharedMemoryObject smemObj, SharedEncodingAttr attr) {
(void)smemObj; // unused in current pattern
bool transposed = (attr.getOrder()[0] != 1);
const auto &order = attr.getOrder();
auto rank = order.size();
bool transposed = (order[rank - 2] != 1);
if (transposed) {
// tensor is column-wise, so swapping col and row in computations
std::swap(row, col);
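The rank-aware test generalizes the old order[0] != 1 check so a leading batch dimension is skipped. A sketch of the convention (isTransposedOrder is a hypothetical name):

```cpp
#include <vector>

// `order` lists dimensions fastest-first.
// rank 2: {1, 0} -> not transposed; {0, 1} -> transposed (column-wise)
// rank 3: {2, 1, 0} -> not transposed; {1, 2, 0} -> transposed
static bool isTransposedOrder(const std::vector<unsigned> &order) {
  return order[order.size() - 2] != 1;
}
```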
@@ -54,9 +56,11 @@ Value computeOffset(ConversionPatternRewriter &rewriter, Location loc,
SharedEncodingAttr srcLayout) {
auto [swizzledRow, swizzledCol] =
swizzleIndexes(rewriter, loc, row, col, smemObj, srcLayout);
auto &strides = smemObj.strides;
Value rowOffset = mul(swizzledRow, strides[0]);
Value colOffset = mul(swizzledCol, strides[1]);
const auto &strides = smemObj.getStrides();
auto rank = strides.size();
assert(rank == 2 || rank == 3);
Value rowOffset = mul(swizzledRow, strides[rank - 2]);
Value colOffset = mul(swizzledCol, strides[rank - 1]);
return add(rowOffset, colOffset);
}
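With an optional leading batch dimension, the row and column strides are now the last two entries, so the math is still plain row-major addressing of the trailing 2-D tile. Scalar form (illustrative, non-swizzled case):

```cpp
#include <cstddef>

// A tile whose rows are `W` elements apart has strides {W, 1}:
// offset = row * strides[rank - 2] + col * strides[rank - 1].
size_t flatOffset(size_t row, size_t col, size_t W) { return row * W + col; }
```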

@@ -77,11 +81,12 @@ llvm::SmallVector<Value> computeOffsetsAType(
Value warpId, Value laneId, int warpsPerBlock, int numOfElems,
ArrayRef<int64_t> reps, SharedMemoryObject smemObj,
SharedEncodingAttr srcLayout, unsigned nonKDim, unsigned kDim) {
SmallVector<Value> strides{smemObj.strides[0], smemObj.strides[1]};
SmallVector<Value> offsets{smemObj.offsets[0], smemObj.offsets[1]};
SmallVector<Value> strides = smemObj.getStrides();
SmallVector<Value> offsets = smemObj.getOffsets();
auto rank = offsets.size();

int vectorSize = 1;
if (srcLayout.getOrder()[0] == 1) {
if (srcLayout.getOrder()[0] == rank - 1) {
if (isSwizzled(srcLayout))
vectorSize = std::min(static_cast<int>(srcLayout.getVec()), numOfElems);
else
@@ -90,14 +95,14 @@

auto mapping = fn(rewriter, loc, elemsPerInstr, warpId, laneId, numOfElems,
reps, offsets, vectorSize, nonKDim, kDim);
const auto numBlocks = reps[0];
const auto numBlocks = reps[1];
const auto blockSize = mapping.size();
auto order = srcLayout.getOrder();
llvm::SmallVector<Value> aOffsets(blockSize * numBlocks);

for (int block = 0; block < numBlocks; ++block) {
int blockNonKOffset = block * nonKDim * warpsPerBlock;
Value offAdjust = mul(i32_val(blockNonKOffset), strides[0]);
Value offAdjust = mul(i32_val(blockNonKOffset), strides[rank - 2]);
for (int i = 0; i < blockSize; ++i) {
Value row = mapping[i][0];
Value col = mapping[i][1];
@@ -109,6 +114,17 @@
return aOffsets;
}
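Each wave-block is shifted along the non-K axis by nonKDim * warpsPerBlock rows, scaled by that axis's stride (now strides[rank - 2]). As a scalar model (hypothetical helper):

```cpp
#include <cstddef>

// Base offset of wave-block `block`: blocks are nonKDim * warpsPerBlock
// rows apart along the non-K axis, and each row spans nonKStride elements.
size_t blockBaseOffset(int block, unsigned nonKDim, int warpsPerBlock,
                       size_t nonKStride) {
  return static_cast<size_t>(block) * nonKDim * warpsPerBlock * nonKStride;
}
```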

template <typename Container>
static SmallVector<typename Container::value_type>
transposeSpatialDims(const Container &vec) {
auto rank = vec.size();
assert(rank == 2 || rank == 3);
SmallVector<typename Container::value_type> res(rank, vec[0]);
res[rank - 2] = vec[rank - 1];
res[rank - 1] = vec[rank - 2];
return res;
}
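A usage sketch for the new helper (values illustrative): only the two trailing spatial entries swap, and a leading batch entry stays put.

```cpp
// rank 2: {M, K}    -> {K, M}
// rank 3: {B, M, K} -> {B, K, M}
SmallVector<int64_t> t2 = transposeSpatialDims(ArrayRef<int64_t>{4, 8});    // {8, 4}
SmallVector<int64_t> t3 = transposeSpatialDims(ArrayRef<int64_t>{2, 4, 8}); // {2, 8, 4}
```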

llvm::SmallVector<Value> computeOffsetsBType(
ConversionPatternRewriter &rewriter, Location loc,
computeTensorElemMappingInBlockT fn, const ArrayRef<int64_t> &elemsPerInstr,
@@ -118,13 +134,14 @@ llvm::SmallVector<Value> computeOffsetsBType(
// transpose reps and offsets, because operand B has layout equal to
// transposed operand A layout
// this unifies axis order, so non-K dim is 0, k dim is 1
auto rank = smemObj.getOffsets().size();
SmallVector<int64_t> tElemsPerInstr{elemsPerInstr[1], elemsPerInstr[0]};
SmallVector<int64_t> tReps{reps[1], reps[0]};
SmallVector<Value> tOffsets{smemObj.offsets[1], smemObj.offsets[0]};
SmallVector<Value> tStrides{smemObj.strides[1], smemObj.strides[0]};
SmallVector<int64_t> tReps = transposeSpatialDims(reps);
SmallVector<Value> tOffsets = transposeSpatialDims(smemObj.getOffsets());
SmallVector<Value> tStrides = transposeSpatialDims(smemObj.getStrides());

int vectorSize = 1;
if (srcLayout.getOrder()[0] == 0) {
if (srcLayout.getOrder()[0] == rank - 2) {
if (isSwizzled(srcLayout))
vectorSize = std::min(static_cast<int>(srcLayout.getVec()), numOfElems);
else
@@ -133,13 +150,13 @@

auto mapping = fn(rewriter, loc, tElemsPerInstr, warpId, laneId, numOfElems,
tReps, tOffsets, vectorSize, nonKDim, kDim);
const auto numBlocks = tReps[0];
const auto numBlocks = tReps[1];
const auto blockSize = mapping.size();
llvm::SmallVector<Value> bOffsets(blockSize * numBlocks);

for (int block = 0; block < numBlocks; ++block) {
int blockNonKOffset = block * nonKDim * warpsPerBlock;
Value offAdjust = mul(i32_val(blockNonKOffset), tStrides[0]);
Value offAdjust = mul(i32_val(blockNonKOffset), tStrides[rank - 2]);
for (int i = 0; i < mapping.size(); ++i) {
// swap row and col, because operand B layout is a transposed operand A
// layout
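After this transposition, operand B reuses the operand-A mapping callback unchanged: tReps keeps the batch repetition count in slot 0, so numBlocks = tReps[1] selects the non-K repetitions exactly as reps[1] does on the A side, and tStrides[rank - 2] is again the non-K stride.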
@@ -384,35 +384,20 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
Value waveBlockOffAdjust = i32_val(blockNonKOffset * shape[order[0]]);
for (int k = 0; k < numRepK; ++k) {
auto vecTy = vec_ty(resElemTy, numOfElems);
Value valVec = undef(vecTy);
for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) {
auto loadVecTy = vec_ty(elemTy, elemsPerLoad);
Value loadOffset;
if (isFastPath)
loadOffset = offsets[nonK * loadsPerThread * numRepK +
k * loadsPerThread + loadId];
else
// In the normal path, we only computed the offsets of elements
// in the first wave-block. Therefore, we update the offsets
// of elements in later wave-blocks by adding a constant stride
loadOffset =
add(waveBlockOffAdjust, offsets[k * loadsPerThread + loadId]);
loadOffset = offsets[nonK * loadsPerThread * numRepK +
k * loadsPerThread + loadId];
loadOffset = add(loadOffset, batchOffset);
Value loadAddress = gep(smemPtrTy, elemTy, smemBase, loadOffset);
Value loadedValue = load(loadVecTy, loadAddress);
if (loadsPerThread > 1) {
for (int elemId = 0; elemId < elemsPerLoad; ++elemId) {
Value elemVal =
extract_element(elemTy, loadedValue, i32_val(elemId));
elemVal = bitcast(elemVal, resElemTy);
valVec = insert_element(vecTy, valVec, elemVal,
i32_val(loadId * elemsPerLoad + elemId));
}
} else {
valVec = loadedValue;
for (int elemId = 0; elemId < elemsPerLoad; ++elemId) {
Value elemVal =
extract_element(elemTy, loadedValue, i32_val(elemId));
loadedValues.push_back(elemVal);
}
}
loadedValues.push_back(valVec);
}
}
}
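Two simplifications in this hunk: the isFastPath branch is gone because offsets are now precomputed for every wave-block, so one indexing expression covers both cases; and loaded elements are pushed into loadedValues as individual scalars instead of one packed vector per (nonK, k) pair. The per-element bitcast to resElemTy also disappears, which appears to rely on the consumer, getValuesFromDotOperandLayoutStruct in MFMA.cpp below, repacking the scalars and converting types itself.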
18 changes: 14 additions & 4 deletions third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp
@@ -264,9 +264,10 @@ struct DotOpMFMAConversionHelper {
return success();
}

/**
* @brief Converts dot operand structure to value table and converts types appropriate for mfma instructions
*/
/**
* @brief Converts dot operand structure to value table and converts types
* appropriate for mfma instructions
*/
ValueTable getValuesFromDotOperandLayoutStruct(Value value, int batch, int n0,
int n1, int kWidth,
Type type) const {
@@ -275,7 +276,16 @@ struct DotOpMFMAConversionHelper {
for (int b = 0; b < batch; ++b) {
for (int i = 0; i < n0; i++) {
for (int j = 0; j < n1; j++) {
auto rawElems = elems[b * n0 * n1 + n1 * i + j];
Type elemTy = typeConverter->convertType(type);
Type ty = vec_ty(elemTy, kWidth);
Value rawElems = undef(ty);
for (int k = 0; k < kWidth; ++k) {
rawElems = insert_element(
ty, rawElems,
elems[kWidth * n1 * n0 * b + kWidth * n1 * i + kWidth * j + k],
i32_val(k));
}

Value convertedElems;
if (type.isF32()) {
convertedElems = extract_element(type, rawElems, i32_val(0));
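Since elems now holds one scalar per element, the flat index expands to kWidth * n1 * n0 * b + kWidth * n1 * i + kWidth * j + k, i.e. the factored form kWidth * (n1 * (n0 * b + i) + j) + k. A standalone check of that equivalence (values illustrative):

```cpp
#include <cassert>

int main() {
  const int b = 1, i = 2, j = 3, k = 1;
  const int n0 = 4, n1 = 5, kWidth = 8;
  int expanded = kWidth * n1 * n0 * b + kWidth * n1 * i + kWidth * j + k;
  int factored = kWidth * (n1 * (n0 * b + i) + j) + k;
  assert(expanded == factored); // same flat element index
  return 0;
}
```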
@@ -55,9 +55,9 @@ int getMfmaVersion(MatrixCoreVersion matrixCoreVer) {
}

SmallVector<unsigned, 2> warpsPerTile(tt::DotOp dotOp,
const ArrayRef<int64_t> shape,
int numWarps,
SmallVector<int64_t, 2> shapePerWarp) {
const ArrayRef<int64_t> shape,
int numWarps,
SmallVector<int64_t, 2> shapePerWarp) {
auto rank = shape.size();
// Early exit for batched matmul
if (rank == 3)