diff --git a/.gitignore b/.gitignore index 89f29ff753d0..af59b15e3023 100644 --- a/.gitignore +++ b/.gitignore @@ -10,10 +10,9 @@ python/triton*.egg-info/ python/triton/_C/*.pyd python/triton/_C/*.so python/triton/_C/*.dylib -python/triton.egg-info/ -python/triton/_C/libtriton.pyd -python/triton/_C/libtriton.so -python/triton/_C/triton.dll +python/triton/_C/*.pdb +python/triton/_C/*.exe +python/triton/_C/*.ilk # Backends copied from submodules python/triton/backends/ @@ -47,6 +46,10 @@ cuobjdump nvdisasm ptxas +cuobjdump.exe +nvdisasm.exe +ptxas.exe + # Third-party include third_party/nvidia/backend/include diff --git a/CMakeLists.txt b/CMakeLists.txt index 1bdf05384133..01157b023643 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,9 +22,16 @@ endif() # Options +if(WIN32) + set(DEFAULT_BUILD_PROTON OFF) +else() + set(DEFAULT_BUILD_PROTON ON) +endif() + +# Define the option with the determined default value +option(TRITON_BUILD_PROTON "Build the Triton Proton profiler" ${DEFAULT_BUILD_PROTON}) option(TRITON_BUILD_TUTORIALS "Build C++ Triton tutorials" ON) option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF) -option(TRITON_BUILD_PROTON "Build the Triton Proton profiler" ON) option(TRITON_BUILD_UT "Build C++ Triton Unit Tests" ON) set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends") @@ -134,7 +141,8 @@ include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files include_directories(${PROJECT_SOURCE_DIR}/third_party) include_directories(${PROJECT_BINARY_DIR}/third_party) # Tablegen'd files -# link_directories(${LLVM_LIBRARY_DIR}) +link_directories(${LLVM_LIBRARY_DIR}) + add_subdirectory(include) add_subdirectory(lib) @@ -244,7 +252,7 @@ if(TRITON_BUILD_PYTHON_MODULE) LLVMAArch64CodeGen LLVMAArch64AsmParser ) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64") + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64" OR CMAKE_SYSTEM_PROCESSOR MATCHES "AMD64") list(APPEND TRITON_LIBRARIES LLVMX86CodeGen LLVMX86AsmParser @@ -272,7 +280,7 @@ if(TRITON_BUILD_PYTHON_MODULE) # Link triton with its dependencies target_link_libraries(triton PUBLIC ${TRITON_LIBRARIES}) - if(WIN32) + if(WIN32) target_link_libraries(triton PRIVATE ${CMAKE_DL_LIBS}) set_target_properties(triton PROPERTIES SUFFIX ".pyd") set_target_properties(triton PROPERTIES PREFIX "lib") @@ -301,6 +309,7 @@ if(NOT TRITON_BUILD_PYTHON_MODULE) foreach(CODEGEN_BACKEND ${TRITON_CODEGEN_BACKENDS}) add_subdirectory(third_party/${CODEGEN_BACKEND}) endforeach() +endif() if(WIN32) option(CMAKE_USE_WIN32_THREADS_INIT "using WIN32 threads" ON) option(gtest_disable_pthreads "Disable uses of pthreads in gtest." 
ON) diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index 1bd1db9496ea..02be6d456e4a 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -4,10 +4,11 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" -// Below headers will allow registration to ROCm passes +#ifndef WIN32 #include "TritonAMDGPUToLLVM/Passes.h" #include "TritonAMDGPUTransforms/Passes.h" #include "TritonAMDGPUTransforms/TritonGPUConversion.h" +#endif #include "triton/Dialect/Triton/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" @@ -48,6 +49,7 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::triton::registerDecomposeUnsupportedNVIDIAConversions(); mlir::registerLLVMDIScope(); +#ifndef WIN32 // TritonAMDGPUToLLVM passes mlir::triton::registerConvertTritonAMDGPUToLLVM(); mlir::triton::registerConvertBuiltinFuncToLLVM(); @@ -58,7 +60,7 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::registerTritonAMDGPUOptimizeEpilogue(); mlir::registerTritonAMDGPUReorderInstructions(); mlir::registerTritonAMDGPUStreamPipeline(); - +#endif // TODO: register Triton & TritonGPU passes registry.insert #include +#include namespace mlir { diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 32cc43c9d5d2..83a40f606249 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -19,11 +19,8 @@ namespace mlir { namespace { -using namespace triton; -using namespace triton::gpu; - int getParentAxis(Attribute layout, int axis) { - if (auto sliceEncoding = dyn_cast(layout)) { + if (auto sliceEncoding = dyn_cast(layout)) { axis = axis < sliceEncoding.getDim() ? axis : axis + 1; return getParentAxis(sliceEncoding.getParent(), axis); } @@ -31,10 +28,11 @@ int getParentAxis(Attribute layout, int axis) { } SmallVector getParentOrder(Attribute layout) { - if (auto sliceEncoding = mlir::dyn_cast(layout)) { + if (auto sliceEncoding = + mlir::dyn_cast(layout)) { return getParentOrder(sliceEncoding.getParent()); } - return getOrder(layout); + return triton::gpu::getOrder(layout); } } // namespace @@ -47,7 +45,7 @@ bool ReduceOpHelper::isReductionOnLayoutFastAxis() { SmallVector ReduceOpHelper::getOrderWithAxisAtBeginning() { auto srcLayout = getSrcLayout(); - auto order = getOrder(srcLayout); + auto order = triton::gpu::getOrder(srcLayout); auto it = std::find(order.begin(), order.end(), axis); // delete the axis from order order.erase(it); @@ -67,13 +65,14 @@ unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() { } unsigned threadOffset = 1; - if (auto sliceLayout = mlir::dyn_cast(srcLayout)) { + if (auto sliceLayout = + mlir::dyn_cast(srcLayout)) { auto parentLayout = sliceLayout.getParent(); - auto threadsPerWarp = getThreadsPerWarp(parentLayout); + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(parentLayout); threadOffset = threadsPerWarp[sliceLayout.getDim()]; } else { - auto threadsPerWarp = getThreadsPerWarp(srcLayout); - auto order = getOrder(srcLayout); + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout); + auto order = triton::gpu::getOrder(srcLayout); for (unsigned i = 0; i < order.size(); i++) { if (order[i] == axis) break; @@ -89,8 +88,8 @@ unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() { // TODO: Case with SliceLayout as srcLayout and numCTAs > 1 is to be implemented // in the future bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout) { - unsigned numCTAs = getNumCTAs(srcLayout); - 
assert(numCTAs == getNumCTAs(dstLayout) && + unsigned numCTAs = triton::gpu::getNumCTAs(srcLayout); + assert(numCTAs == triton::gpu::getNumCTAs(dstLayout) && "Invalid layout conversion: the numbers of CTAs of src and dst " "layouts are different"); @@ -100,17 +99,19 @@ bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout) { // Case where CTAsPerCGA of srcLayout in the sliced dim is not 1 is not // implemented yet - if (auto sliceLayout = mlir::dyn_cast(srcLayout)) { + if (auto sliceLayout = + mlir::dyn_cast(srcLayout)) { auto dim = sliceLayout.getDim(); - auto CTAsPerCGA = getCTAsPerCGA(sliceLayout.getParent()); + auto CTAsPerCGA = triton::gpu::getCTAsPerCGA(sliceLayout.getParent()); if (CTAsPerCGA[dim] != 1) llvm::report_fatal_error("Layout conversion to be implemented"); } // Case where CTAsPerCGA of dstLayout in the sliced dim is not 1 is supported - if (auto sliceLayout = mlir::dyn_cast(dstLayout)) { + if (auto sliceLayout = + mlir::dyn_cast(dstLayout)) { auto dim = sliceLayout.getDim(); - auto CTAsPerCGA = getCTAsPerCGA(sliceLayout.getParent()); + auto CTAsPerCGA = triton::gpu::getCTAsPerCGA(sliceLayout.getParent()); if (CTAsPerCGA[dim] != 1) return true; } @@ -119,8 +120,8 @@ bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout) { // srcLayout and dstLayout // Case (2): Do not use dsmem when srcCTALayout == dstCTALayout - auto srcCTALayout = getCTALayout(srcLayout); - auto dstCTALayout = getCTALayout(dstLayout); + auto srcCTALayout = triton::gpu::getCTALayout(srcLayout); + auto dstCTALayout = triton::gpu::getCTALayout(dstLayout); if (srcCTALayout == dstCTALayout) return false; @@ -132,42 +133,45 @@ unsigned ReduceOpHelper::getInterWarpSize() { auto srcReduceDimSize = static_cast(srcShape[axis]); unsigned sizeIntraWarps = getIntraWarpSize(); return std::min(srcReduceDimSize / sizeIntraWarps, - getWarpsPerCTA(getSrcLayout())[axis]); + triton::gpu::getWarpsPerCTA(getSrcLayout())[axis]); } unsigned ReduceOpHelper::getIntraWarpSize() { auto srcReduceDimSize = static_cast(srcShape[axis]); - return std::min(srcReduceDimSize, getThreadsPerWarp(getSrcLayout())[axis]); + return std::min(srcReduceDimSize, + triton::gpu::getThreadsPerWarp(getSrcLayout())[axis]); } unsigned ReduceOpHelper::getInterWarpSizeWithUniqueData() { auto srcReduceDimSize = static_cast(srcShape[axis]); unsigned sizeIntraWarps = getIntraWarpSizeWithUniqueData(); - return std::min( - srcReduceDimSize / sizeIntraWarps, - getWarpsPerCTAWithUniqueData(getSrcLayout(), getSrcShape())[axis]); + return std::min(srcReduceDimSize / sizeIntraWarps, + triton::gpu::getWarpsPerCTAWithUniqueData( + getSrcLayout(), getSrcShape())[axis]); } unsigned ReduceOpHelper::getIntraWarpSizeWithUniqueData() { auto srcReduceDimSize = static_cast(srcShape[axis]); - unsigned elementPerThreads = - getUniqueContigPerThread(getSrcLayout(), getSrcShape())[axis]; - return std::min( - srcReduceDimSize / elementPerThreads, - getThreadsPerWarpWithUniqueData(getSrcLayout(), getSrcShape())[axis]); + unsigned elementPerThreads = triton::gpu::getUniqueContigPerThread( + getSrcLayout(), getSrcShape())[axis]; + return std::min(srcReduceDimSize / elementPerThreads, + triton::gpu::getThreadsPerWarpWithUniqueData( + getSrcLayout(), getSrcShape())[axis]); } unsigned ReduceOpHelper::getThreadsReductionAxis() { auto srcLayout = getSrcLayout(); auto srcShape = getSrcShape(); - return getThreadsPerWarpWithUniqueData(srcLayout, srcShape)[axis] * - getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis]; + return 
triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout, + srcShape)[axis] * + triton::gpu::getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis]; } bool ReduceOpHelper::isWarpSynchronous() { auto srcLayout = getSrcLayout(); auto srcShape = getSrcShape(); - return getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis] == 1; + return triton::gpu::getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis] == + 1; } SmallVector ReduceOpHelper::getScratchConfig() { @@ -196,7 +200,7 @@ unsigned ReduceOpHelper::getScratchSizeInBytes() { bool ReduceOpHelper::isReduceWithinCTA() { auto axis = getAxis(); auto srcLayout = getSrcLayout(); - auto CTASplitNum = getCTASplitNum(srcLayout); + auto CTASplitNum = triton::gpu::getCTASplitNum(srcLayout); assert(axis < CTASplitNum.size()); return CTASplitNum[axis] == 1; } @@ -209,13 +213,13 @@ bool ReduceOpHelper::isSupportedLayout() { } auto srcLayout = getSrcLayout(); - if (isa(srcLayout)) { + if (isa(srcLayout)) { return true; } - if (auto mmaLayout = dyn_cast(srcLayout)) { + if (auto mmaLayout = dyn_cast(srcLayout)) { return mmaLayout.supportReduction(); } - if (auto sliceLayout = dyn_cast(srcLayout)) { + if (auto sliceLayout = dyn_cast(srcLayout)) { return true; } return false; @@ -226,7 +230,7 @@ unsigned ScanLoweringHelper::getAxisNumElementsPerThread() { } unsigned ScanLoweringHelper::getNonAxisNumElementsPerThread() { - SmallVector sizePerThreads = getContigPerThread(getEncoding()); + SmallVector sizePerThreads = triton::gpu::getContigPerThread(getEncoding()); sizePerThreads[getAxis()] = 1; return product(sizePerThreads); } @@ -293,14 +297,14 @@ unsigned ScanLoweringHelper::getNonAxisNumBlocks() { bool ScanLoweringHelper::isSupported() { // TODO: Support the following cases: // 1. Scan on non-blocking encodings - if (!isa(getEncoding())) + if (!isa(getEncoding())) return false; return true; } unsigned ScanLoweringHelper::getScratchSizeInElems() { auto mod = scanOp->getParentOfType(); - unsigned numWarps = TritonGPUDialect::getNumWarps(mod); + unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); unsigned numNonAxisElementsPerWarp = getNonAxisNumThreadsPerWarp() * getNonAxisNumElementsPerThread(); unsigned numElements = numWarps * numNonAxisElementsPerWarp * @@ -357,8 +361,8 @@ getReshapeDecomposition(ArrayRef srcShape, return ret; } -BlockedEncodingAttr ScanLoweringHelper::getEncoding() { - return cast(srcEncoding); +triton::gpu::BlockedEncodingAttr ScanLoweringHelper::getEncoding() { + return cast(srcEncoding); } unsigned ScanLoweringHelper::getAxisElementStride() { @@ -367,7 +371,7 @@ unsigned ScanLoweringHelper::getAxisElementStride() { for (unsigned dim : order) { if (dim == getAxis()) return stride; - stride *= getContigPerThread(getEncoding())[dim]; + stride *= triton::gpu::getContigPerThread(getEncoding())[dim]; } llvm_unreachable("Axis not found in order"); } @@ -405,7 +409,8 @@ bool maybeSharedAllocationOp(Operation *op) { // query the memory effects of the op. 
auto *dialect = op->getDialect(); return dialect && - (dialect->getTypeID() == TypeID::get() || + (dialect->getTypeID() == + TypeID::get() || dialect->getTypeID() == TypeID::get() || dialect->getTypeID() == TypeID::get() || @@ -542,10 +547,10 @@ bool supportMMA(triton::DotOp op, int version) { if (triton::tools::getBoolEnv("DISABLE_MMA_V3")) return false; auto retType = op.getType(); - auto retShapePerCTA = getShapePerCTA(retType); + auto retShapePerCTA = triton::gpu::getShapePerCTA(retType); auto rank = retShapePerCTA.size(); auto mod = op->getParentOfType(); - int numWarps = TritonGPUDialect::getNumWarps(mod); + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 && retShapePerCTA[rank - 1] % 8 == 0 && (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ() || @@ -561,7 +566,8 @@ bool supportMMA(triton::DotOp op, int version) { } } if (aElemTy.isF32() && bElemTy.isF32()) { - return op.getInputPrecision() == InputPrecision::TF32 && version >= 2; + return op.getInputPrecision() == triton::InputPrecision::TF32 && + version >= 2; } return supportMMA(op.getA(), version) && supportMMA(op.getB(), version); } @@ -585,8 +591,8 @@ bool supportMMA(Value value, int version) { bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { auto srcLayout = srcTy.getEncoding(); auto dstLayout = dstTy.getEncoding(); - auto mfmaLayout = dyn_cast(srcLayout); - auto dotOperandLayout = dyn_cast(dstLayout); + auto mfmaLayout = dyn_cast(srcLayout); + auto dotOperandLayout = dyn_cast(dstLayout); if (mfmaLayout == nullptr || dotOperandLayout == nullptr) return false; // TODO: Remove the restriction on the warpsPerCTA once chain dot testing is @@ -594,15 +600,15 @@ bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { // layout when opIdx == 1. 
return mfmaLayout.getWarpsPerCTA()[1] == 1 && dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() && - dotOperandLayout.getKWidth() == getContigPerThread(mfmaLayout)[1] && + dotOperandLayout.getKWidth() == triton::gpu::getContigPerThread(mfmaLayout)[1] && dotOperandLayout.getParent() == mfmaLayout && (mfmaLayout.getMDim() == 32 || mfmaLayout.getMDim() == 16) && (srcTy.getElementType().isF16() || srcTy.getElementType().isBF16()); } static bool isMmaToMmaShortcut(Attribute srcEncoding, Attribute dstEncoding) { - auto src = dyn_cast(srcEncoding); - auto dst = dyn_cast(dstEncoding); + auto src = dyn_cast(srcEncoding); + auto dst = dyn_cast(dstEncoding); if (!src || !dst) return false; // when #mma = MmaEncoding @@ -620,8 +626,8 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, RankedTensorType dstTy) { auto srcLayout = srcTy.getEncoding(); auto dstLayout = dstTy.getEncoding(); - auto mmaLayout = cast(srcLayout); - auto dotOperandLayout = cast(dstLayout); + auto mmaLayout = cast(srcLayout); + auto dotOperandLayout = cast(dstLayout); int elementTypeSize = srcTy.getElementType().getIntOrFloatBitWidth(); auto ans = mmaLayout.getVersionMajor() == 3 && dotOperandLayout.getOpIdx() == 0 && @@ -637,8 +643,9 @@ bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) { // when #mma = MmaEncoding auto srcLayout = srcTy.getEncoding(); auto dstLayout = dstTy.getEncoding(); - auto mmaLayout = mlir::cast(srcLayout); - auto dotOperandLayout = mlir::cast(dstLayout); + auto mmaLayout = mlir::cast(srcLayout); + auto dotOperandLayout = + mlir::cast(dstLayout); return mmaLayout.getVersionMajor() == 2 && mmaLayout.getWarpsPerCTA()[1] == 1 && dotOperandLayout.getOpIdx() == 0 && @@ -865,13 +872,13 @@ std::unique_ptr createDataFlowSolver() { return solver; } -static MakeTensorPtrOp getMakeTensorPtrOpImpl(Operation *op, Value v) { +static triton::MakeTensorPtrOp getMakeTensorPtrOpImpl(Operation *op, Value v) { - if (auto makeTensorPtrOp = dyn_cast(op)) { + if (auto makeTensorPtrOp = dyn_cast(op)) { return makeTensorPtrOp; } - if (auto advanceOp = dyn_cast(op)) { + if (auto advanceOp = dyn_cast(op)) { return getMakeTensorPtrOp(advanceOp.getPtr()); } @@ -891,7 +898,7 @@ static MakeTensorPtrOp getMakeTensorPtrOpImpl(Operation *op, Value v) { llvm_unreachable("Unable to getMakeTensorPtr()"); } -MakeTensorPtrOp getMakeTensorPtrOp(Value v) { +triton::MakeTensorPtrOp getMakeTensorPtrOp(Value v) { using BranchOps = llvm::SetVector>; llvm::DenseMap blockToCFOps; auto moduleOp = diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 57c535959ae2..8990bd736a2f 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -21,8 +21,6 @@ namespace mlir { -using namespace triton; - SmallVector mmaVersionToInstrShape(int version, const ArrayRef &shape, RankedTensorType type, @@ -103,18 +101,18 @@ Value getMemAccessPtr(Operation *op) { unsigned getElementBitWidth(RankedTensorType type) { auto typeForMem = - isa(type.getElementType()) - ? cast(type.getElementType()).getPointeeType() + isa(type.getElementType()) + ? 
cast(type.getElementType()).getPointeeType() : type.getElementType(); return typeForMem.getIntOrFloatBitWidth(); } unsigned getNumElementsPerThread(Operation *op, SmallVector order, - ModuleAxisInfoAnalysis &axisInfoAnalysis) { + triton::ModuleAxisInfoAnalysis &axisInfoAnalysis) { Value val = getMemAccessPtr(op); auto ty = cast(val.getType()); auto shapePerCTA = triton::gpu::getShapePerCTA(ty); - AxisInfo &valInfo = *axisInfoAnalysis.getAxisInfo(val); + triton::AxisInfo &valInfo = *axisInfoAnalysis.getAxisInfo(val); unsigned elemNumBits = getElementBitWidth(ty); unsigned elemNumBytes = std::max(elemNumBits / 8, 1u); unsigned maxMultipleBytes = valInfo.getDivisibility(order[0]); @@ -302,10 +300,11 @@ static std::optional inferDstEncoding(triton::ExpandDimsOp op, return sliceEncoding.getParent(); } -static std::optional inferDstEncoding(JoinOp op, Attribute srcEnc) { +static std::optional inferDstEncoding(triton::JoinOp op, + Attribute srcEnc) { Attribute dstEnc; if (srcEnc.getDialect() - .getRegisteredInterface() + .getRegisteredInterface() ->inferJoinOpEncoding(srcEnc, dstEnc, /*loc=*/std::nullopt) .succeeded()) { @@ -314,10 +313,11 @@ static std::optional inferDstEncoding(JoinOp op, Attribute srcEnc) { return std::nullopt; } -static std::optional inferDstEncoding(SplitOp op, Attribute srcEnc) { +static std::optional inferDstEncoding(triton::SplitOp op, + Attribute srcEnc) { Attribute dstEnc; if (srcEnc.getDialect() - .getRegisteredInterface() + .getRegisteredInterface() ->inferSplitOpEncoding(srcEnc, dstEnc, /*loc=*/std::nullopt) .succeeded()) { @@ -342,11 +342,12 @@ static std::optional inferSrcEncoding(triton::ExpandDimsOp op, encoding); } -static std::optional inferSrcEncoding(JoinOp op, Attribute dstEnc) { +static std::optional inferSrcEncoding(triton::JoinOp op, + Attribute dstEnc) { // Split is the inverse of join. Attribute srcEnc; if (dstEnc.getDialect() - .getRegisteredInterface() + .getRegisteredInterface() ->inferSplitOpEncoding(dstEnc, srcEnc, /*loc=*/std::nullopt) .succeeded()) { return srcEnc; @@ -354,11 +355,12 @@ static std::optional inferSrcEncoding(JoinOp op, Attribute dstEnc) { return std::nullopt; } -static std::optional inferSrcEncoding(SplitOp op, Attribute dstEnc) { +static std::optional inferSrcEncoding(triton::SplitOp op, + Attribute dstEnc) { // Join is the inverse of split. 
Attribute srcEnc; if (dstEnc.getDialect() - .getRegisteredInterface() + .getRegisteredInterface() ->inferJoinOpEncoding(dstEnc, srcEnc, /*loc=*/std::nullopt) .succeeded()) { return srcEnc; @@ -812,7 +814,7 @@ Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, } bool isPureUnaryInlineAsm(Operation *op) { - auto inlineAsmOp = dyn_cast(op); + auto inlineAsmOp = dyn_cast(op); if (!inlineAsmOp) return false; return op->getNumOperands() == 1 && op->getNumResults() == 1 && diff --git a/python/setup.py b/python/setup.py index a6574d459c28..a6b81a9b17a1 100644 --- a/python/setup.py +++ b/python/setup.py @@ -87,6 +87,49 @@ def copy_externals(): ] +def find_vswhere(): + program_files = os.environ.get("ProgramFiles(x86)", "C:\\Program Files (x86)") + vswhere_path = Path(program_files) / "Microsoft Visual Studio" / "Installer" / "vswhere.exe" + if vswhere_path.exists(): + return vswhere_path + return None + +def find_visual_studio(version_ranges): + vswhere = find_vswhere() + if not vswhere: + raise FileNotFoundError("vswhere.exe not found.") + + for version_range in version_ranges: + command = [ + str(vswhere), + "-version", version_range, + "-requires", "Microsoft.VisualStudio.Component.VC.Tools.x86.x64", + "-property", "installationPath", + "-prerelease" + ] + + try: + output = subprocess.check_output(command, text=True).strip() + if output: + return output + except subprocess.CalledProcessError: + continue + + return None + +def set_env_vars(vs_path, arch="x64"): + vcvarsall_path = Path(vs_path) / "VC" / "Auxiliary" / "Build" / "vcvarsall.bat" + if not vcvarsall_path.exists(): + raise FileNotFoundError(f"vcvarsall.bat not found in expected path: {vcvarsall_path}") + + command = f'call "{vcvarsall_path}" {arch} && set' + output = subprocess.check_output(command, shell=True, text=True) + + for line in output.splitlines(): + if '=' in line: + var, value = line.split('=', 1) + os.environ[var] = value + # Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py def check_env_flag(name: str, default: str = "") -> bool: return os.getenv(name, default).upper() in ["ON", "1", "YES", "TRUE", "Y"] @@ -144,10 +187,7 @@ def get_json_package_info(): # llvm def get_llvm_package_info(): system = platform.system() - try: - arch = {"x86_64": "x64", "arm64": "arm64", "aarch64": "arm64"}[platform.machine()] - except KeyError: - arch = platform.machine() + arch = {"x86_64": "x64", "AMD64": "64" , "arm64": "arm64", "aarch64": "arm64"}[platform.machine()] if system == "Darwin": system_suffix = f"macos-{arch}" elif system == "Linux": @@ -251,17 +291,16 @@ def download_and_copy(name, src_path, variable, version, url_func): return base_dir = os.path.dirname(__file__) system = platform.system() - try: - arch = {"x86_64": "64", "arm64": "aarch64", "aarch64": "aarch64"}[platform.machine()] - except KeyError: - arch = platform.machine() - url = url_func(arch, version) + arch = {"x86_64": "64","AMD64": "64", "arm64": "aarch64", "aarch64": "aarch64"}[platform.machine()] + supported = {"Linux": "linux", "Windows": "win"} + is_supported = system in supported + if is_supported: + url = url_func(supported[system], arch, version) tmp_path = os.path.join(triton_cache_path, "nvidia", name) # path to cache the download dst_path = os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", src_path) # final binary path src_path = os.path.join(tmp_path, src_path) - src_path += ".exe" if os.name == "nt" else "" download = not os.path.exists(src_path) - if os.path.exists(dst_path) and system 
== "Linux" and shutil.which(dst_path) is not None: + if os.path.exists(dst_path) and system == "Linux" and shutil.which(dst_path) is not None and is_supported: curr_version = subprocess.check_output([dst_path, "--version"]).decode("utf-8").strip() curr_version = re.search(r"V([.|\d]+)", curr_version).group(1) download = download or curr_version != version @@ -357,6 +396,12 @@ def get_proton_cmake_args(self): def build_extension(self, ext): lit_dir = shutil.which('lit') ninja_dir = shutil.which('ninja') + if platform.system() == "Windows": + vs_path = find_visual_studio(["[17.0,18.0)", "[16.0,17.0)"]) + env = set_env_vars(vs_path) + print(vs_path) + if not vs_path: + raise EnvironmentError("Visual Studio 2019 or 2022 not found.") # lit is used by the test suite thirdparty_cmake_args = get_thirdparty_packages([get_pybind11_package_info(), get_llvm_package_info()]) extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path))) @@ -441,9 +486,10 @@ def build_extension(self, ext): with open(nvidia_version_path, "r") as nvidia_version_file: NVIDIA_TOOLCHAIN_VERSION = nvidia_version_file.read().strip() +extension = ".exe" if os.name == "nt" else "" download_and_copy( name="ptxas", - src_path="bin/ptxas", + src_path=f"bin/ptxas{extension}", variable="TRITON_PTXAS_PATH", version=NVIDIA_TOOLCHAIN_VERSION, url_func=lambda system, arch, version: @@ -451,7 +497,7 @@ def build_extension(self, ext): ) download_and_copy( name="cuobjdump", - src_path="bin/cuobjdump", + src_path=f"bin/cuobjdump{extension}", variable="TRITON_CUOBJDUMP_PATH", version=NVIDIA_TOOLCHAIN_VERSION, url_func=lambda system, arch, version: @@ -459,7 +505,7 @@ def build_extension(self, ext): ) download_and_copy( name="nvdisasm", - src_path="bin/nvdisasm", + src_path=f"bin/nvdisasm{extension}", variable="TRITON_NVDISASM_PATH", version=NVIDIA_TOOLCHAIN_VERSION, url_func=lambda system, arch, version: @@ -490,7 +536,10 @@ def build_extension(self, ext): f"https://anaconda.org/nvidia/cuda-cupti/{version}/download/{system}-{arch}/cuda-cupti-{version}-0.tar.bz2", ) -backends = [*BackendInstaller.copy(["nvidia", "amd"]), *BackendInstaller.copy_externals()] +backends = ["nvidia", "amd"] +if os.name == "nt": + backends = ["nvidia"] +backends = [*BackendInstaller.copy(backends), *BackendInstaller.copy_externals()] def add_link_to_backends(): diff --git a/python/src/interpreter.cc b/python/src/interpreter.cc index 6ab7c6c75c70..7c825ce06732 100644 --- a/python/src/interpreter.cc +++ b/python/src/interpreter.cc @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -14,16 +15,15 @@ enum class MemSemantic { ACQUIRE_RELEASE, ACQUIRE, RELEASE, RELAXED }; enum class RMWOp { ADD, FADD, AND, OR, XOR, XCHG, MAX, MIN, UMIN, UMAX }; std::map mem_semantic_map = { - {MemSemantic::ACQUIRE_RELEASE, __ATOMIC_ACQ_REL}, - {MemSemantic::ACQUIRE, __ATOMIC_ACQUIRE}, - {MemSemantic::RELEASE, __ATOMIC_RELEASE}, - {MemSemantic::RELAXED, __ATOMIC_RELAXED}, + {MemSemantic::ACQUIRE_RELEASE, static_cast(std::memory_order_acq_rel)}, + {MemSemantic::ACQUIRE, static_cast(std::memory_order_acquire)}, + {MemSemantic::RELEASE, static_cast(std::memory_order_release)}, + {MemSemantic::RELAXED, static_cast(std::memory_order_relaxed)}, }; - // Use compiler builtin atomics instead of std::atomic which requires // each variable to be declared as atomic. // Currently work for clang and gcc. 
-template T atomic_cmp(T *ptr, T val, int order) { +template T atomic_cmp(std::atomic *ptr, T val, std::memory_order order) { auto cmp = [](T old, T val) { if constexpr (is_min) { return old > val; @@ -31,38 +31,29 @@ template T atomic_cmp(T *ptr, T val, int order) { return old < val; } }; + // First load - T old_val = __atomic_load_n(ptr, order); + T old_val = ptr->load(order); while (cmp(old_val, val)) { - if (__atomic_compare_exchange(ptr, &old_val, &val, false, order, order)) { + if (ptr->compare_exchange_weak(old_val, val, order, order)) { break; } } return old_val; } -template T atomic_fadd(T *ptr, T val, int order) { - T old_val; - T new_val; - // First load - // Load ptr as if uint32_t or uint64_t and then memcpy to T - if constexpr (sizeof(T) == 4) { - uint32_t tmp = __atomic_load_n(reinterpret_cast(ptr), order); - std::memcpy(&old_val, &tmp, sizeof(T)); - } else if constexpr (sizeof(T) == 8) { - uint64_t tmp = __atomic_load_n(reinterpret_cast(ptr), order); - std::memcpy(&old_val, &tmp, sizeof(T)); - } else { - throw std::invalid_argument("Unsupported data type"); - } - while (true) { - new_val = old_val + val; - if (__atomic_compare_exchange(ptr, &old_val, &new_val, false, order, - order)) { - break; - } - } - return old_val; +template +T atomic_fadd(std::atomic *loc, T value, std::memory_order order) { + static_assert(std::is_floating_point::value, + "T must be a floating-point type"); + + T old_value = loc->load(order); + T new_value; + do { + new_value = old_value + value; + } while (!loc->compare_exchange_weak(old_value, new_value, order, order)); + + return old_value; } class AtomicOp { @@ -95,13 +86,14 @@ template class AtomicRMWOpBase : public AtomicOp { protected: void applyAt(void *loc, size_t i) override final { if (mask[i]) { - *(static_cast(ret) + i) = - applyAtMasked(static_cast(loc), - *(static_cast(val) + i), order); + std::atomic *atomic_ptr = static_cast *>(loc); + *(static_cast(ret) + i) = applyAtMasked(atomic_ptr, + *(static_cast(val) + i), std::memory_order(order)); } } - virtual DType applyAtMasked(DType *loc, const DType value, int order) = 0; + virtual DType applyAtMasked(std::atomic* loc, const DType value, + std::memory_order order) = 0; const void *val; void *ret; @@ -121,8 +113,8 @@ class AtomicRMWOp> using AtomicRMWOpBase::AtomicRMWOpBase; protected: - DType applyAtMasked(DType *loc, const DType value, int order) override { - return __atomic_fetch_add(loc, value, order); + DType applyAtMasked(std::atomic* loc, const DType value, std::memory_order order) override { + return std::atomic_fetch_add(loc, value); } }; @@ -133,7 +125,9 @@ class AtomicRMWOp> using AtomicRMWOpBase::AtomicRMWOpBase; protected: - DType applyAtMasked(DType *loc, const DType value, int order) override { + DType applyAtMasked(std::atomic* loc, const DType value, + std::memory_order order) override { + return atomic_fadd(loc, value, order); } }; @@ -145,8 +139,9 @@ class AtomicRMWOp> using AtomicRMWOpBase::AtomicRMWOpBase; protected: - DType applyAtMasked(DType *loc, const DType value, int order) override { - return __atomic_fetch_and(loc, value, order); + DType applyAtMasked(std::atomic* loc, const DType value, + std::memory_order order) override { + return std::atomic_fetch_and(loc, value); } }; @@ -157,8 +152,9 @@ class AtomicRMWOp> using AtomicRMWOpBase::AtomicRMWOpBase; protected: - DType applyAtMasked(DType *loc, const DType value, int order) override { - return __atomic_fetch_or(loc, value, order); + DType applyAtMasked(std::atomic* loc, const DType value, + std::memory_order 
order) override { + return std::atomic_fetch_or(loc, value); } }; @@ -169,8 +165,9 @@ class AtomicRMWOp> using AtomicRMWOpBase::AtomicRMWOpBase; protected: - DType applyAtMasked(DType *loc, const DType value, int order) override { - return __atomic_fetch_xor(loc, value, order); + DType applyAtMasked(std::atomic* loc, const DType value, + std::memory_order order) override { + return std::atomic_fetch_xor(loc, value); } }; @@ -182,7 +179,8 @@ class AtomicRMWOp::AtomicRMWOpBase; protected: - DType applyAtMasked(DType *loc, const DType value, int order) override { + DType applyAtMasked(std::atomic* loc, const DType value, + std::memory_order order) override { return atomic_cmp(loc, value, order); } }; @@ -195,8 +193,9 @@ class AtomicRMWOp::AtomicRMWOpBase; protected: - DType applyAtMasked(DType *loc, const DType value, int order) override { - return atomic_cmp(loc, value, order); + DType applyAtMasked(std::atomic* loc, const DType value, + std::memory_order order) override { + return atomic_cmp(loc, value, order); } }; @@ -207,8 +206,9 @@ class AtomicRMWOp> using AtomicRMWOpBase::AtomicRMWOpBase; protected: - DType applyAtMasked(DType *loc, const DType value, int order) override { - return __atomic_exchange_n(loc, value, order); + DType applyAtMasked(std::atomic* loc, const DType value, + std::memory_order order) override { + return loc->exchange(value, order); } }; @@ -224,25 +224,39 @@ class AtomicCASOp : public AtomicOp { // Atomic operations perform bitwise comparison, so it's safe to // use number of bytes (itemsize) to determine the type of pointers if (itemsize == 1) { + std::atomic *atomic_loc = + reinterpret_cast *>(loc); uint8_t desired_val = *(static_cast(desired) + i); - __atomic_compare_exchange_n(static_cast(loc), - static_cast(expected) + i, - desired_val, false, order, order); + uint8_t *expected_uint = static_cast(expected); + // Perform the compare and exchange operation + atomic_loc->compare_exchange_strong(*(expected_uint + i), desired_val, + std::memory_order(order), + std::memory_order(order)); + } else if (itemsize == 2) { + std::atomic *atomic_loc = reinterpret_cast *>(loc); uint16_t desired_val = *(static_cast(desired) + i); - __atomic_compare_exchange_n(static_cast(loc), - static_cast(expected) + i, - desired_val, false, order, order); + uint16_t *expected_uint = static_cast(expected); + atomic_loc->compare_exchange_strong(*(expected_uint + i), desired_val, + std::memory_order(order), + std::memory_order(order)); } else if (itemsize == 4) { + std::atomic *atomic_loc = reinterpret_cast *>(loc); uint32_t desired_val = *(static_cast(desired) + i); - __atomic_compare_exchange_n(static_cast(loc), - static_cast(expected) + i, - desired_val, false, order, order); + uint32_t *expected_uint = static_cast(expected); + atomic_loc->compare_exchange_strong(*(expected_uint + i), desired_val, + std::memory_order(order), + std::memory_order(order)); } else if (itemsize == 8) { uint64_t desired_val = *(static_cast(desired) + i); - __atomic_compare_exchange_n(static_cast(loc), - static_cast(expected) + i, - desired_val, false, order, order); + std::atomic *atomic_loc = + static_cast *>(loc); + uint64_t *expected_uint = static_cast(expected); + atomic_loc->compare_exchange_strong(*(expected_uint + i), desired_val, + std::memory_order(order), + std::memory_order(order)); + + } else { // The ‘__atomic’ builtins can be used with any integral scalar or pointer // type that is 1, 2, 4, or 8 bytes in length. 
16-byte integral types are diff --git a/python/src/ir.cc b/python/src/ir.cc index 129daccd1bba..8a13e0a999db 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -215,9 +215,8 @@ void init_triton_ir(py::module &&m) { context.loadAllAvailableDialects(); }); - py::class_(m, "type", py::module_local()) - .def("is_integer", - [](Type &self, unsigned width) { return self.isInteger(width); }) + py::class_(m, "type", py::module_local()) + .def("is_integer", static_cast(&Type::isInteger)) .def("is_fp16", &Type::isF16) .def("__str__", [](Type &self) { std::string str; @@ -225,7 +224,7 @@ void init_triton_ir(py::module &&m) { self.print(os); return os.str(); }); - + py::class_(m, "function_type", py::module_local()) .def("param_types", [](FunctionType &self) { return std::vector(self.getInputs().begin(), @@ -1617,7 +1616,7 @@ void init_triton_ir(py::module &&m) { }); ::llvm::DebugFlag = true; - ::llvm::setCurrentDebugTypes(debugTypes.data(), debugTypes.size()); + llvm::setCurrentDebugTypes(debugTypes.data(), debugTypes.size()); } bool haveTiming = ::triton::tools::getBoolEnv("MLIR_ENABLE_TIMING"); diff --git a/python/src/main.cc b/python/src/main.cc index 5ad4be7d5584..600510ad770e 100644 --- a/python/src/main.cc +++ b/python/src/main.cc @@ -1,13 +1,14 @@ #include namespace py = pybind11; +#define EXPAND(x) x #define FOR_EACH_1(MACRO, X) MACRO(X) #define FOR_EACH_2(MACRO, X, ...) MACRO(X) FOR_EACH_1(MACRO, __VA_ARGS__) #define FOR_EACH_3(MACRO, X, ...) MACRO(X) FOR_EACH_2(MACRO, __VA_ARGS__) #define FOR_EACH_4(MACRO, X, ...) MACRO(X) FOR_EACH_3(MACRO, __VA_ARGS__) #define FOR_EACH_NARG(...) FOR_EACH_NARG_(__VA_ARGS__, FOR_EACH_RSEQ_N()) -#define FOR_EACH_NARG_(...) FOR_EACH_ARG_N(__VA_ARGS__) +#define FOR_EACH_NARG_(...) EXPAND(FOR_EACH_ARG_N(__VA_ARGS__)) #define FOR_EACH_ARG_N(_1, _2, _3, _4, N, ...) N #define FOR_EACH_RSEQ_N() 4, 3, 2, 1, 0 @@ -15,8 +16,7 @@ namespace py = pybind11; #define CONCATENATE1(x, y) x##y #define FOR_EACH(MACRO, ...) \ - CONCATENATE(FOR_EACH_, FOR_EACH_NARG_HELPER(__VA_ARGS__))(MACRO, __VA_ARGS__) -#define FOR_EACH_NARG_HELPER(...) FOR_EACH_NARG(__VA_ARGS__) + EXPAND(CONCATENATE(FOR_EACH_, FOR_EACH_NARG(__VA_ARGS__))(MACRO, __VA_ARGS__)) // New macro to remove parentheses #define REMOVE_PARENS(...) 
__VA_ARGS__ @@ -37,8 +37,8 @@ void init_triton_ir(pybind11::module &&m); void init_triton_llvm(pybind11::module &&m); void init_triton_interpreter(pybind11::module &&m); void init_triton_passes(pybind11::module &&m); -FOR_EACH_P(DECLARE_BACKEND, TRITON_BACKENDS_TUPLE) +FOR_EACH_P(DECLARE_BACKEND, TRITON_BACKENDS_TUPLE) PYBIND11_MODULE(libtriton, m) { m.doc() = "Python bindings to the C++ Triton API"; init_triton_env_vars(m); diff --git a/python/test/unit/language/print_helper.py b/python/test/unit/language/print_helper.py index e032792f3a55..d2461e5cbeb7 100644 --- a/python/test/unit/language/print_helper.py +++ b/python/test/unit/language/print_helper.py @@ -119,7 +119,7 @@ def test_print(func: str, data_type: str): func != "print_multiple_args" and func != "device_print_multiple_args" and \ func != "device_print_pointer": assert_close(y, x) - + torch.cuda.synchronize() if __name__ == "__main__": test_print(sys.argv[1], sys.argv[2]) diff --git a/python/test/unit/language/test_compile_errors.py b/python/test/unit/language/test_compile_errors.py index 0531f8ebc33e..2bb5cbcebf85 100644 --- a/python/test/unit/language/test_compile_errors.py +++ b/python/test/unit/language/test_compile_errors.py @@ -4,6 +4,7 @@ import triton.language as tl from triton.compiler.errors import CompilationError, CompileTimeAssertionFailure import traceback +import platform def test_err_undefined_variable(): @@ -131,7 +132,8 @@ def kernel(): try: inner = e.value.__cause__ outer = e.value - assert "/core.py" in '\n'.join(traceback.format_tb(inner.__traceback__)), "error should point inside core.py" + target = "\\core.py" if platform.system() == 'Windows' else "/core.py" + assert target in '\n'.join(traceback.format_tb(inner.__traceback__)), "error should point inside core.py" assert "at 2:4:" in str(outer), "error should point to expand_dims call" assert "" not in str(outer) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 6e9aca48aef9..9e518bf642e7 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -2500,7 +2500,7 @@ def test_scan_layouts(M, N, src_layout, axis, device): }} """ - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir', delete=False) as f: f.write(ir) f.flush() kernel = triton.compile(f.name) @@ -2648,9 +2648,10 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce }}) {{axis = {axis} : i32}} : (tensor<{M}x{N}x{ty}, #src>) -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>> """ + epilogue - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir',delete=False) as f: f.write(ir) f.flush() + f.close() kernel = triton.compile(f.name) rs = RandomState(17) @@ -2702,7 +2703,7 @@ def test_store_op(M, src_layout, device): }} """ - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir',delete=False) as f: f.write(ir) f.flush() store_kernel = triton.compile(f.name) @@ -2752,7 +2753,7 @@ def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device): }} }} """ - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir',delete=False) as f: f.write(ir) f.flush() kernel = triton.compile(f.name) @@ -2834,7 +2835,7 @@ def test_chain_reduce(M, N, src_layout, op, device, 
first_axis): }} }} """ - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir',delete=False) as f: f.write(ir) f.flush() kernel = triton.compile(f.name) @@ -4891,7 +4892,7 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device): x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device) z = torch.empty_like(x, device=device) - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir',delete=False) as f: f.write(ir) f.flush() kernel = triton.compile(f.name) @@ -4985,7 +4986,7 @@ def do_test(src_layout, dst_layout): x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device) z = torch.empty_like(x) - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir',delete=False) as f: f.write(ir) f.flush() kernel = triton.compile(f.name) @@ -5330,72 +5331,3 @@ def test_tl_range(device): ptx = pgm.asm['ptx'] # check that the loop got pipelined with the right number of stages. assert 'cp.async.wait_group 0x6' in ptx - - -@triton.jit(noinline=True) -def maxnreg_noinline1(X): - tl.store(X, 0) - - -@triton.jit(noinline=True) -def maxnreg_noinline2(X): - tl.store(X, 0) - - -def test_maxnreg(device): - assert not is_interpreter(), "this test won't work with the interpreter" - if is_hip(): - pytest.skip('maxnreg only works on CUDA') - - # triton kernel - @triton.jit - def kernel(X): - maxnreg_noinline1(X) - tl.store(X, 0) - maxnreg_noinline2(X) - - X = torch.empty(1, dtype=torch.int32, device=device) - k = kernel[(1, )](X, maxnreg=42) - - # Ensure that .maxnreg is set on the kernel function (marked with .entry) - # and not on either of the noinline functions (marked with .func). - try: - assert re.search(r'\.visible \.entry [^{;]*\.maxnreg 42', k.asm["ptx"]) - assert not re.search(r'\.visible \.func [^{;]*\.maxnreg', k.asm["ptx"]) - except AssertionError: - print("Failing ptx:\n", k.asm["ptx"]) - raise - - -@pytest.mark.interpreter -def test_temp_var_in_loop(device): - - @triton.jit - def temp_in_loop(Z, N: tl.constexpr, BLOCK: tl.constexpr): - acc = tl.full((BLOCK, ), 0, dtype=tl.int32) - for i in range(N): - if i == 0: - temp = tl.full((BLOCK, ), 2, dtype=tl.int32) - acc = temp - else: - acc += tl.full((BLOCK, ), 1, dtype=tl.int32) - # re-use the temp variable and make sure to check that it isn't creating incorrect IR. 
- temp = tl.full((BLOCK, ), 1, dtype=tl.int32) - acc += temp - z = Z + tl.arange(0, BLOCK) - tl.store(z, acc) - - N = 10 - BLOCK = 32 - out = torch.empty((BLOCK, ), dtype=torch.int32, device=device) - temp_in_loop[(1, )](out, N, BLOCK) - acc = torch.full((BLOCK, ), 0, dtype=torch.int32, device=device) - for i in range(N): - if i == 0: - temp = torch.full((BLOCK, ), 2, dtype=torch.int32, device=device) - acc = temp - else: - acc += torch.full((BLOCK, ), 1, dtype=torch.int32, device=device) - temp = torch.full((BLOCK, ), 1, dtype=torch.int32, device=device) - acc += temp - assert (acc == out).all() diff --git a/python/test/unit/language/test_reproducer.py b/python/test/unit/language/test_reproducer.py index a045e8f30e2c..ab19c3a6a1e3 100644 --- a/python/test/unit/language/test_reproducer.py +++ b/python/test/unit/language/test_reproducer.py @@ -15,7 +15,7 @@ def triton_(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda") def test_reproducer(): - tmpdir = ".tmp" + tmpdir = os.path.abspath(".tmp") reproducer = 'triton-reproducer.mlir' if os.path.exists(tmpdir): shutil.rmtree(tmpdir, ignore_errors=True) diff --git a/python/test/unit/language/test_subprocess.py b/python/test/unit/language/test_subprocess.py index 683a02a56e78..65632ff09399 100644 --- a/python/test/unit/language/test_subprocess.py +++ b/python/test/unit/language/test_subprocess.py @@ -48,8 +48,7 @@ def test_print(func_type: str, data_type: str): # Only check if there's no error assert err == b'' return - - outs = [line for line in outs.decode("UTF-8").split("\n") if line] + outs = [line for line in outs.decode("UTF-8").replace('\r', '').split("\n") if line] # The total number of elements in the 1-D tensor to print. N = 128 diff --git a/python/test/unit/operators/test_flash_attention.py b/python/test/unit/operators/test_flash_attention.py index f5cff538eb5f..ada4dbd66779 100644 --- a/python/test/unit/operators/test_flash_attention.py +++ b/python/test/unit/operators/test_flash_attention.py @@ -112,6 +112,16 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, casual, seq_par, provid ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) return ms - + # Hardcode the parameters +Z = 2 +H = 4 +N_CTX = 512 +D_HEAD = 128 +dtype = torch.float16 # or torch.bfloat16 if you want to test with bfloat16 +causal = True +seq_par = True +device = 'cuda' # or 'cpu' if you want to test on CPU +# Call the test function with hardcoded parameters +#test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par, device) # only works on post-Ampere GPUs right now # bench_flash_attention.run(save_path='.', print_data=True) diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index 44632fd1d677..278225213a68 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -11,7 +11,7 @@ import triton.language as tl from triton.runtime.jit import JITFunction -tmpdir = ".tmp" +tmpdir = os.path.abspath(".tmp") @triton.jit diff --git a/python/test/unit/runtime/test_subproc.py b/python/test/unit/runtime/test_subproc.py index eb3704cf2b4f..1884cb187bbe 100644 --- a/python/test/unit/runtime/test_subproc.py +++ b/python/test/unit/runtime/test_subproc.py @@ -8,7 +8,7 @@ import triton.language as tl from triton.compiler import ASTSource -tmpdir = ".tmp" +tmpdir = os.path.abspath(".tmp") target = triton.runtime.driver.active.get_current_target() diff --git a/python/triton/runtime/CLFinder.py b/python/triton/runtime/CLFinder.py new file mode 100644 index 
000000000000..d2e77e0b4cce --- /dev/null +++ b/python/triton/runtime/CLFinder.py @@ -0,0 +1,54 @@ +import os +import subprocess +from pathlib import Path + +def find_vswhere(): + program_files = os.environ.get("ProgramFiles(x86)", "C:\\Program Files (x86)") + vswhere_path = Path(program_files) / "Microsoft Visual Studio" / "Installer" / "vswhere.exe" + if vswhere_path.exists(): + return vswhere_path + return None + +def find_visual_studio(version_ranges): + vswhere = find_vswhere() + if not vswhere: + raise FileNotFoundError("vswhere.exe not found.") + + for version_range in version_ranges: + command = [ + str(vswhere), + "-version", version_range, + "-requires", "Microsoft.VisualStudio.Component.VC.Tools.x86.x64", + "-property", "installationPath", + "-prerelease" + ] + + try: + output = subprocess.check_output(command, text=True).strip() + if output: + return output + except subprocess.CalledProcessError: + continue + + return None + +def set_env_vars(vs_path, arch="x64"): + vcvarsall_path = Path(vs_path) / "VC" / "Auxiliary" / "Build" / "vcvarsall.bat" + if not vcvarsall_path.exists(): + raise FileNotFoundError(f"vcvarsall.bat not found in expected path: {vcvarsall_path}") + + command = f'call "{vcvarsall_path}" {arch} && set' + output = subprocess.check_output(command, shell=True, text=True) + + for line in output.splitlines(): + if '=' in line: + var, value = line.split('=', 1) + os.environ[var] = value +def initialize_visual_studio_env(version_ranges, arch="x64"): + # Check if the environment variable that vcvarsall.bat sets is present + if os.environ.get('VSCMD_ARG_TGT_ARCH') != arch: + vs_path = find_visual_studio(version_ranges) + print(vs_path) + if not vs_path: + raise EnvironmentError("Visual Studio not found in specified version ranges.") + set_env_vars(vs_path, arch) diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py index 9726836a2939..e757b8d3a117 100644 --- a/python/triton/runtime/build.py +++ b/python/triton/runtime/build.py @@ -6,7 +6,8 @@ import shutil import subprocess import setuptools - +import platform +from .CLFinder import initialize_visual_studio_env @contextlib.contextmanager def quiet(): @@ -22,10 +23,13 @@ def _cc_cmd(cc, src, out, include_dirs, library_dirs, libraries): if cc in ["cl", "clang-cl"]: cc_cmd = [cc, src, "/nologo", "/O2", "/LD"] cc_cmd += [f"/I{dir}" for dir in include_dirs] + cc_cmd += [f"/Fo{os.path.join(os.path.dirname(out), 'main.obj')}"] cc_cmd += ["/link"] + cc_cmd += [f"/OUT:{out}"] + cc_cmd += [f"/IMPLIB:{os.path.join(os.path.dirname(out), 'main.lib')}"] + cc_cmd += [f"/PDB:{os.path.join(os.path.dirname(out), 'main.pdb')}"] cc_cmd += [f"/LIBPATH:{dir}" for dir in library_dirs] cc_cmd += [f'{lib}.lib' for lib in libraries] - cc_cmd += [f"/OUT:{out}"] else: cc_cmd = [cc, src, "-O3", "-shared", "-fPIC"] cc_cmd += [f'-l{lib}' for lib in libraries] @@ -48,6 +52,9 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries): clang = shutil.which("clang") gcc = shutil.which("gcc") cc = gcc if gcc is not None else clang + if platform.system() == "Windows": + cc = "cl" + initialize_visual_studio_env(["[17.0,18.0)", "[16.0,17.0)"]) if cc is None: raise RuntimeError("Failed to find C compiler. 
Please specify via CC environment variable.") # This function was renamed and made public in Python 3.10 diff --git a/third_party/f2reduce/f2reduce.cpp b/third_party/f2reduce/f2reduce.cpp index e3aa7dfe9f8c..2f099f473495 100644 --- a/third_party/f2reduce/f2reduce.cpp +++ b/third_party/f2reduce/f2reduce.cpp @@ -2,6 +2,14 @@ #include #include "f2reduce.h" +#ifdef _WIN32 +#include <intrin.h> // For _mm_prefetch +#define __restrict__ __restrict +#define __builtin_prefetch(addr) _mm_prefetch(reinterpret_cast<const char *>(addr), _MM_HINT_T0) +#define ATTRIBUTE_NOINLINE __declspec(noinline) +#else +#define ATTRIBUTE_NOINLINE __attribute__((noinline)) +#endif namespace { void swap_rows(uint64_t* __restrict__ x, uint64_t* __restrict__ y, uint64_t n) { @@ -12,7 +20,7 @@ void swap_rows(uint64_t* __restrict__ x, uint64_t* __restrict__ y, uint64_t n) { // the noinline attribute is necessary for gcc to properly vectorise this: template -__attribute__ ((noinline)) void memxor_lop7(uint64_t* __restrict__ dst, +ATTRIBUTE_NOINLINE void memxor_lop7(uint64_t* __restrict__ dst, const uint64_t* __restrict__ src1, const uint64_t* __restrict__ src2, const uint64_t* __restrict__ src3, @@ -25,7 +33,7 @@ __attribute__ ((noinline)) void memxor_lop7(uint64_t* __restrict__ dst, } template -__attribute__ ((noinline)) void memxor_lop5(uint64_t* __restrict__ dst, +ATTRIBUTE_NOINLINE void memxor_lop5(uint64_t *__restrict__ dst, const uint64_t* __restrict__ src1, const uint64_t* __restrict__ src2, const uint64_t* __restrict__ src3, @@ -36,7 +44,7 @@ __attribute__ ((noinline)) void memxor_lop5(uint64_t* __restrict__ dst, } template -__attribute__ ((noinline)) void memxor_lop3(uint64_t* __restrict__ dst, +ATTRIBUTE_NOINLINE void memxor_lop3(uint64_t *__restrict__ dst, const uint64_t* __restrict__ src1, const uint64_t* __restrict__ src2) { for (uint64_t i = 0; i < N; i++) { diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index ef6297cae612..9427c16f7889 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -297,7 +297,7 @@ def make_cubin(src, metadata, opt, capability): cmd += [f'--gpu-name=sm_{capability}{suffix}', fsrc.name, '-o', fbin] try: - subprocess.run(cmd, check=True, stderr=flog) + subprocess.run(cmd, check=True, stdout=flog, stderr=flog) except subprocess.CalledProcessError as e: with open(flog.name) as log_file: log = log_file.read() diff --git a/third_party/nvidia/backend/driver.c b/third_party/nvidia/backend/driver.c index 9a61942d1ef8..97b06fe6da73 100644 --- a/third_party/nvidia/backend/driver.c +++ b/third_party/nvidia/backend/driver.c @@ -8,7 +8,9 @@ #include #define PY_SSIZE_T_CLEAN #include +#ifndef _WIN32 #include <dlfcn.h> +#endif // Raises a Python exception and returns false if code is not CUDA_SUCCESS.
static bool gpuAssert(CUresult code, const char *file, int line) { diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py index 374e7582f095..0fe207b16e95 100644 --- a/third_party/nvidia/backend/driver.py +++ b/third_party/nvidia/backend/driver.py @@ -4,6 +4,7 @@ import sysconfig import subprocess import tempfile +import sys from pathlib import Path from triton.runtime.build import _build from triton.runtime.cache import get_cache_manager @@ -15,7 +16,16 @@ libdevice_dir = os.path.join(dirname, "lib") libraries = ['cuda'] +# Extract major and minor version of Python +major, minor = sys.version_info[:2] + +# Generate the library name by concatenating 'python' with the major and minor version +python_version_str = f"python{major}{minor}" + +# Append the generated library name to the libraries list + if os.name == "nt": + libraries.append(python_version_str) include_dir += [os.path.join(os.environ.get("CUDA_PATH"), "include")] @@ -142,7 +152,7 @@ def format_of(ty): "int8_t": "b", "int16_t": "h", "int32_t": "i", - "int64_t": "l", + "int64_t": "L", "uint8_t": "B", "uint16_t": "H", "uint32_t": "I", @@ -227,7 +237,7 @@ def format_of(ty): #endif static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ - void *params[] = {{ {', '.join(f"&arg{i}" for i in params)} }}; + void *params[] = {{{', '.join(f'&arg{i}' for i in params) if params else 'NULL'}}}; if (gridX*gridY*gridZ > 0) {{ if (num_ctas == 1) {{ CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0)); diff --git a/unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt b/unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt index 3c5692a62639..98854afd9219 100644 --- a/unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -10,10 +10,15 @@ add_triton_ut( LIBS TritonGPUIR TritonNvidiaGPUIR TritonNVIDIAGPUToLLVM DEFS NVIDIA_TARGET=1 ) - +if(WIN32) +# Set C++20 standard for TestEmitIndicesNvidia target +target_compile_features(TestEmitIndicesNvidia PRIVATE cxx_std_20) +endif() +if(NOT WIN32) add_triton_ut( NAME TestEmitIndicesAMD SRCS EmitIndicesTest.cpp DumpLayout.cpp LIBS TritonGPUIR TritonAMDGPUToLLVM DEFS AMD_TARGET=1 ) +endif()
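For reference, a minimal usage sketch of the MSVC environment bootstrap added in python/triton/runtime/CLFinder.py (a hedged illustration, not part of the patch itself: it assumes the patch is applied so that the new module is importable as triton.runtime.CLFinder, and it reuses the Visual Studio version ranges hard-coded in python/triton/runtime/build.py):

import platform

from triton.runtime.CLFinder import initialize_visual_studio_env

if platform.system() == "Windows":
    # Look for Visual Studio 2022 ([17.0,18.0)) first, then 2019 ([16.0,17.0)).
    # vswhere.exe locates the installation, and the environment produced by
    # vcvarsall.bat is copied into os.environ, so the "cl" compiler used by
    # triton.runtime.build._build is resolvable afterwards.
    initialize_visual_studio_env(["[17.0,18.0)", "[16.0,17.0)"], arch="x64")

The helper is a no-op when VSCMD_ARG_TGT_ARCH already matches the requested architecture, so calling it repeatedly (for example once per runtime compilation) only pays the vswhere/vcvarsall cost on the first invocation.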