From 05f08428c9619a92987dc9d8b42c20e616166656 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Sat, 2 Dec 2023 03:44:27 +0900 Subject: [PATCH 1/8] based on Windows support PR #2465 by @andreigh * based on https://github.com/openai/triton/pull/2465 * manually applied, rebased, fix lint errors * use set_target_properties(), cleanup for windows * remove '/A' platform option to use windows ninja * remove unknown option '/m' * use sysconfig.get_config_var() to get the path of python*.lib * clang fix for windows * remove '-fPIC' for windows clang * fix download_and_copy() to support windows * add "exe" extension for windows * use "pyd" extension for windows to make importlib work Original-author-by: Andrei Gheorghe Signed-off-by: Won-Kyu Park --- .gitignore | 1 + CMakeLists.txt | 50 ++++++++++---- bin/CMakeLists.txt | 1 + lib/Conversion/TritonGPUToLLVM/CMakeLists.txt | 1 + python/setup.py | 68 +++++++++++-------- python/triton/common/backend.py | 9 ++- python/triton/common/build.py | 33 +++++++-- python/triton/compiler/make_launcher.py | 2 +- python/triton/runtime/driver.py | 2 +- 9 files changed, 113 insertions(+), 54 deletions(-) diff --git a/.gitignore b/.gitignore index 0180cd911245..05f922a11698 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ python/build/ python/triton.egg-info/ python/triton/_C/libtriton.pyd python/triton/_C/libtriton.so +python/triton/_C/triton.dll # Python caches __pycache__/ diff --git a/CMakeLists.txt b/CMakeLists.txt index 9622a0f0074d..d2c74bdbfefd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -28,8 +28,17 @@ set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends") # used conditionally in this file and by lit tests # Customized release build type with assertions: TritonRelBuildWithAsserts -set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g") -set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g") +if(NOT MSVC) + set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g") + set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g") +else() + set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "/Zi /Ob0 /Od /RTC1") + set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "/Zi /Ob0 /Od /RTC1") + set(CMAKE_EXE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL") + set(CMAKE_MODULE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL") + set(CMAKE_SHARED_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL") + set(CMAKE_STATIC_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL") +endif() # Default build type if(NOT CMAKE_BUILD_TYPE) @@ -47,7 +56,15 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) # Third-party include_directories(${PYBIND11_INCLUDE_DIR}) -set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17 -fvisibility=hidden -fvisibility-inlines-hidden") +if(NOT MSVC) + if(NOT WIN32) + set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17 -fvisibility=hidden -fvisibility-inlines-hidden") + else() + set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17 -Wno-deprecated -fvisibility=hidden -fvisibility-inlines-hidden") + endif() +else() + set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS /wd4244 /wd4624 /wd4715 /wd4530") +endif() if(APPLE) set(CMAKE_OSX_DEPLOYMENT_TARGET 11.6) @@ -59,7 +76,7 @@ endif() if(NOT MLIR_DIR) if(NOT LLVM_LIBRARY_DIR) if(WIN32) - find_package(LLVM 13 REQUIRED COMPONENTS nvptx amdgpu) + find_package(LLVM 17 REQUIRED COMPONENTS nvptx amdgpu) include_directories(${LLVM_INCLUDE_DIRS}) separate_arguments(LLVM_DEFINITIONS_LIST NATIVE_COMMAND ${LLVM_DEFINITIONS}) @@ -158,6 +175,8 @@ if(TRITON_BUILD_PYTHON_MODULE) if(PYTHON_INCLUDE_DIRS) include_directories(${PYTHON_INCLUDE_DIRS}) + message(STATUS "PYTHON_LIB_DIRS ${PYTHON_LIB_DIRS}") + link_directories(${PYTHON_LIB_DIRS}) else() find_package(Python3 REQUIRED COMPONENTS Development Interpreter) include_directories(${Python3_INCLUDE_DIRS}) @@ -167,16 +186,6 @@ if(TRITON_BUILD_PYTHON_MODULE) endif() endif() -# # Triton -# file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc) -# if (WIN32 AND TRITON_BUILD_PYTHON_MODULE) -# Python3_add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC}) -# set_target_properties(triton PROPERTIES SUFFIX ".pyd") -# set_target_properties(triton PROPERTIES PREFIX "lib") -# else() -# add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC}) -# endif() - # MLIR find_package(MLIR REQUIRED CONFIG PATHS ${MLIR_DIR}) @@ -188,7 +197,11 @@ include(AddLLVM) include(AddMLIR) # Disable warnings that show up in external code (gtest;pybind11) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default") +if(NOT MSVC) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default") +else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /WX-") +endif() include_directories(${MLIR_INCLUDE_DIRS}) include_directories(${LLVM_INCLUDE_DIRS}) @@ -241,6 +254,8 @@ if(TRITON_BUILD_PYTHON_MODULE) target_link_libraries(triton PRIVATE ${LLVM_LIBRARIES} ${CMAKE_DL_LIBS} ${TRITON_LIBRARIES} ) + set_target_properties(triton PROPERTIES SUFFIX ".pyd") + set_target_properties(triton PROPERTIES PREFIX "lib") elseif(APPLE) target_link_libraries(triton ${LLVM_LIBRARIES} z ${TRITON_LIBRARIES} @@ -277,6 +292,11 @@ if (${CODEGEN_BACKENDS_LEN} GREATER 0) endforeach() endif() +if(WIN32) + option(CMAKE_USE_WIN32_THREADS_INIT "using WIN32 threads" ON) + option(gtest_disable_pthreads "Disable uses of pthreads in gtest." ON) +endif() + add_subdirectory(bin) add_subdirectory(test) add_subdirectory(unittest) diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt index 4fb3396d8021..fbbd2f0ef190 100644 --- a/bin/CMakeLists.txt +++ b/bin/CMakeLists.txt @@ -44,6 +44,7 @@ target_link_libraries(triton-reduce PRIVATE mlir_check_all_link_libraries(triton-reduce) add_llvm_executable(triton-llvm-opt + PARTIAL_SOURCES_INTENDED triton-llvm-opt.cpp DEPENDS diff --git a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt index 590bc6b99da0..c4d336366e80 100644 --- a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -56,6 +56,7 @@ add_mlir_conversion_library(TritonGPUToLLVM ) add_mlir_library(ASMBuilder + PARTIAL_SOURCES_INTENDED GCNAsmFormat.cpp PTXAsmFormat.cpp diff --git a/python/setup.py b/python/setup.py index 32eb9eb8b24a..180ce68f7676 100644 --- a/python/setup.py +++ b/python/setup.py @@ -115,7 +115,7 @@ def get_thirdparty_packages(triton_cache_path): if p.syspath_var_name in os.environ: package_dir = os.environ[p.syspath_var_name] version_file_path = os.path.join(package_dir, "version.txt") - if p.syspath_var_name not in os.environ and\ + if p.syspath_var_name not in os.environ and p.url and\ (not os.path.exists(version_file_path) or Path(version_file_path).read_text() != p.url): try: shutil.rmtree(package_root_dir) @@ -128,6 +128,9 @@ def get_thirdparty_packages(triton_cache_path): # write version url to package_dir with open(os.path.join(package_dir, "version.txt"), "w") as f: f.write(p.url) + elif p.syspath_var_name not in os.environ and not p.url: + raise RuntimeError( + f'{p.syspath_var_name} not set ! Please install {p.package} manually and set {p.syspath_var_name}.') if p.include_flag: thirdparty_cmake_args.append(f"-D{p.include_flag}={package_dir}/include") if p.lib_flag: @@ -143,15 +146,18 @@ def download_and_copy(src_path, variable, version, url_func): return base_dir = os.path.dirname(__file__) arch = platform.machine() - if arch == "x86_64": + if arch in ["x86_64", "AMD64"]: arch = "64" - url = url_func(arch, version) + supported = {"Linux": "linux", "Windows": "win"} + is_supported = platform.system() in supported + if is_supported: + url = url_func(supported[platform.system()], arch, version) dst_prefix = os.path.join(base_dir, "triton") dst_suffix = os.path.join("third_party", "cuda", src_path) dst_path = os.path.join(dst_prefix, dst_suffix) - is_linux = platform.system() == "Linux" + dst_path += ".exe" if os.name == "nt" else "" download = False - if is_linux: + if is_supported: download = True if os.path.exists(dst_path): curr_version = subprocess.check_output([dst_path, "--version"]).decode("utf-8").strip() @@ -163,6 +169,7 @@ def download_and_copy(src_path, variable, version, url_func): with tempfile.TemporaryDirectory() as temp_dir: file.extractall(path=temp_dir) src_path = os.path.join(temp_dir, src_path) + src_path += ".exe" if os.name == "nt" else "" os.makedirs(os.path.split(dst_path)[0], exist_ok=True) shutil.copy(src_path, dst_path) @@ -262,6 +269,10 @@ def build_extension(self, ext): "-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON", "-DPYTHON_INCLUDE_DIRS=" + python_include_dir, ] + if platform.system() == "Windows": + installed_base = sysconfig.get_config_var('installed_base') + py_lib_dirs = os.getenv("PYTHON_LIB_DIRS", os.path.join(installed_base, "libs")) + cmake_args.append("-DPYTHON_LIB_DIRS=" + py_lib_dirs) if lit_dir is not None: cmake_args.append("-DLLVM_EXTERNAL_LIT=" + lit_dir) cmake_args.extend(thirdparty_cmake_args) @@ -276,10 +287,8 @@ def build_extension(self, ext): cmake_args += ["-DTRITON_CODEGEN_BACKENDS=" + all_codegen_backends] if platform.system() == "Windows": + cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg] cmake_args += [f"-DCMAKE_RUNTIME_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}"] - if sys.maxsize > 2**32: - cmake_args += ["-A", "x64"] - build_args += ["--", "/m"] else: cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg] max_jobs = os.getenv("MAX_JOBS", str(2 * os.cpu_count())) @@ -321,27 +330,28 @@ def build_extension(self, ext): subprocess.check_call(["cmake", "--build", ".", "--target", "mlir-doc"], cwd=cmake_dir) -download_and_copy( - src_path="bin/ptxas", - variable="TRITON_PTXAS_PATH", - version="12.3.52", - url_func=lambda arch, version: - f"https://anaconda.org/nvidia/cuda-nvcc/12.3.52/download/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2", -) -download_and_copy( - src_path="bin/cuobjdump", - variable="TRITON_CUOBJDUMP_PATH", - version="12.3.52", - url_func=lambda arch, version: - f"https://anaconda.org/nvidia/cuda-cuobjdump/12.3.52/download/linux-{arch}/cuda-cuobjdump-{version}-0.tar.bz2", -) -download_and_copy( - src_path="bin/nvdisasm", - variable="TRITON_NVDISASM_PATH", - version="12.3.52", - url_func=lambda arch, version: - f"https://anaconda.org/nvidia/cuda-nvdisasm/12.3.52/download/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2", -) +if platform.system() in ["Linux", "Windows"]: + download_and_copy( + src_path="bin/ptxas", + variable="TRITON_PTXAS_PATH", + version="12.3.52", + url_func=lambda system, arch, version: + f"https://anaconda.org/nvidia/cuda-nvcc/{version}/download/{system}-{arch}/cuda-nvcc-{version}-0.tar.bz2", + ) + download_and_copy( + src_path="bin/cuobjdump", + variable="TRITON_CUOBJDUMP_PATH", + version="12.3.52", + url_func=lambda system, arch, version: + f"https://anaconda.org/nvidia/cuda-cuobjdump/{version}/download/{system}-{arch}/cuda-cuobjdump-{version}-0.tar.bz2", + ) + download_and_copy( + src_path="bin/nvdisasm", + variable="TRITON_NVDISASM_PATH", + version="12.3.52", + url_func=lambda system, arch, version: + f"https://anaconda.org/nvidia/cuda-nvdisasm/{version}/download/{system}-{arch}/cuda-nvdisasm-{version}-0.tar.bz2", + ) setup( name=os.environ.get("TRITON_WHEEL_NAME", "triton"), diff --git a/python/triton/common/backend.py b/python/triton/common/backend.py index 899b9510bc6b..ada18f51816e 100644 --- a/python/triton/common/backend.py +++ b/python/triton/common/backend.py @@ -106,6 +106,7 @@ def get_backend(device_type: str): def _path_to_binary(binary: str): + binary += ".exe" if os.name == "nt" else "" base_dir = os.path.join(os.path.dirname(__file__), os.pardir) paths = [ os.environ.get(f"TRITON_{binary.upper()}_PATH", ""), @@ -113,7 +114,10 @@ def _path_to_binary(binary: str): ] for p in paths: - bin = p.split(" ")[0] + if os.name != "nt": + bin = p.split(" ")[0] + else: + bin = p if os.path.exists(bin) and os.path.isfile(bin): result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT) if result is not None: @@ -152,7 +156,8 @@ def compute_core_version_key(): contents += [hashlib.sha1(f.read()).hexdigest()] # backend libtriton_hash = hashlib.sha1() - with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f: + ext = "so" if os.name != "nt" else "pyd" + with open(os.path.join(TRITON_PATH, "_C", "libtriton." + ext), "rb") as f: while True: chunk = f.read(1024**2) if not chunk: diff --git a/python/triton/common/build.py b/python/triton/common/build.py index bd8395d4af2d..c39564899c69 100644 --- a/python/triton/common/build.py +++ b/python/triton/common/build.py @@ -22,6 +22,9 @@ def libcuda_dirs(): if env_libcuda_path: return [env_libcuda_path] + if os.name == "nt": + return [os.environ.get("CUDA_PATH") + "\\lib\\x64"] + libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode() # each line looks like the following: # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1 @@ -60,6 +63,24 @@ def cuda_include_dir(): return os.path.join(cuda_path, "include") +def _cc_cmd(cc, src, out, include_dirs, library_dirs): + if cc == "cl": + cc_cmd = [cc, src, "/nologo", "/O2", "/LD"] + cc_cmd += [f"/I{dir}" for dir in include_dirs] + cc_cmd += ["/link"] + cc_cmd += [f"/LIBPATH:{dir}" for dir in library_dirs] + cc_cmd += ["cuda.lib", f"/OUT:{out}"] + else: + cc_cmd = [cc, src, "-O3", "-shared", "-fPIC"] + cc_cmd += [f"-I{dir}" for dir in include_dirs] + cc_cmd += [f"-L{dir}" for dir in library_dirs] + cc_cmd += ["-lcuda", "-o", out] + + if os.name == "nt": cc_cmd.pop(cc_cmd.index("-fPIC")) + + return cc_cmd + + def _build(name, src, srcdir): if is_hip(): hip_lib_dir = os.path.join(rocm_path_dir(), "lib") @@ -88,6 +109,10 @@ def _build(name, src, srcdir): if scheme == 'posix_local': scheme = 'posix_prefix' py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] + py_lib_dirs = [] + if os.name == "nt": + installed_base = sysconfig.get_config_var('installed_base') + py_lib_dirs = [os.getenv("PYTHON_LIB_DIRS", os.path.join(installed_base, "libs"))] if is_hip(): ret = subprocess.check_call([ @@ -95,18 +120,14 @@ def _build(name, src, srcdir): f"-L{hip_lib_dir}", "-lamdhip64", "-o", so ]) else: - cc_cmd = [ - cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda", - "-o", so - ] - cc_cmd += [f"-L{dir}" for dir in cuda_lib_dirs] + cc_cmd = _cc_cmd(cc, src, so, [cu_include_dir, py_include_dir, srcdir], [*cuda_lib_dirs, *py_lib_dirs]) ret = subprocess.check_call(cc_cmd) if ret == 0: return so # fallback on setuptools extra_compile_args = [] - library_dirs = cuda_lib_dirs + library_dirs = [*cuda_lib_dirs, *py_lib_dirs] include_dirs = [srcdir, cu_include_dir] libraries = ['cuda'] # extra arguments diff --git a/python/triton/compiler/make_launcher.py b/python/triton/compiler/make_launcher.py index 52a8f74a11eb..84752c942ce6 100644 --- a/python/triton/compiler/make_launcher.py +++ b/python/triton/compiler/make_launcher.py @@ -25,7 +25,7 @@ def make_stub(name, signature, constants, ids, **kwargs): # name of files that are cached so_cache_key = make_so_cache_key(get_cuda_version_key(), signature, constants, ids, **kwargs) so_cache_manager = get_cache_manager(so_cache_key) - so_name = f"{name}.so" + so_name = f'{name}.{"so" if os.name != "nt" else "pyd"}' # retrieve stub from cache if it exists cache_path = so_cache_manager.get_file(so_name) if cache_path is None: diff --git a/python/triton/runtime/driver.py b/python/triton/runtime/driver.py index 1abcb3bafcf7..f7664e57b8e6 100644 --- a/python/triton/runtime/driver.py +++ b/python/triton/runtime/driver.py @@ -59,7 +59,7 @@ def __init__(self): src = Path(os.path.join(dirname, "backends", "cuda.c")).read_text() key = hashlib.md5(src.encode("utf-8")).hexdigest() cache = get_cache_manager(key) - fname = "cuda_utils.so" + fname = "cuda_utils." + ("so" if os.name != "nt" else "pyd") cache_path = cache.get_file(fname) if cache_path is None: with tempfile.TemporaryDirectory() as tmpdir: From 7219ea698d602023e84b8b99f2603142d733fba9 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Sat, 2 Dec 2023 03:45:53 +0900 Subject: [PATCH 2/8] dlopen fix for win32 * based on Windows support PR #2456 by @andreigh * DISPATCH_ARGS fix by @andreigh * WIN32 fix using LoadLibrary --- python/triton/compiler/make_launcher.py | 25 ++++++++++++++++++ python/triton/runtime/backends/cuda.c | 34 +++++++++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/python/triton/compiler/make_launcher.py b/python/triton/compiler/make_launcher.py index 84752c942ce6..e716e70253fe 100644 --- a/python/triton/compiler/make_launcher.py +++ b/python/triton/compiler/make_launcher.py @@ -109,7 +109,12 @@ def format_of(ty): #include \"cuda.h\" #include #include +#ifndef _WIN32 #include +#else +#define WIN32_LEAN_AND_MEAN +#include +#endif static inline void gpuAssert(CUresult code, const char *file, int line) {{ @@ -132,6 +137,7 @@ def format_of(ty): typedef CUresult (*cuLaunchKernelEx_t)(const CUlaunchConfig* config, CUfunction f, void** kernelParams, void** extra); +#ifndef _WIN32 static cuLaunchKernelEx_t getLaunchKernelExHandle() {{ // Open the shared library void* handle = dlopen("libcuda.so", RTLD_LAZY); @@ -150,6 +156,25 @@ def format_of(ty): }} return cuLaunchKernelExHandle; }} +#else +static cuLaunchKernelEx_t getLaunchKernelExHandle() {{ + // Open the shared library + HMODULE handle = LoadLibraryA("nvcuda.dll"); + if (!handle) {{ + PyErr_SetString(PyExc_RuntimeError, "Failed to open nvcuda.dll"); + return NULL; + }} + cuLaunchKernelEx_t cuLaunchKernelExHandle = + (cuLaunchKernelEx_t)GetProcAddress((HMODULE)handle, "cuLaunchKernelEx"); + // Check for errors + long error = GetLastError(); + if (error) {{ + PyErr_SetString(PyExc_RuntimeError, "Failed to retrieve cuLaunchKernelEx from nvcuda.dll"); + return NULL; + }} + return cuLaunchKernelExHandle; +}} +#endif static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ void *params[] = {{ {', '.join(f"&arg{i}" for i in params)} }}; diff --git a/python/triton/runtime/backends/cuda.c b/python/triton/runtime/backends/cuda.c index 928c8fc06e52..8cbb4fca6606 100644 --- a/python/triton/runtime/backends/cuda.c +++ b/python/triton/runtime/backends/cuda.c @@ -1,5 +1,10 @@ #include "cuda.h" +#ifndef _WIN32 #include +#else +#define WIN32_LEAN_AND_MEAN +#include +#endif #include #define PY_SSIZE_T_CLEAN #include @@ -94,10 +99,17 @@ static bool gpuAssert(CUresult code, const char *file, int line) { #define DISPATCH_ARGS_N(_14, _13, _12, _11, _10, _9, _8, _7, _6, _5, _4, _3, \ _2, _1, N, ...) \ ADD_ENUM_ITEM_##N +#if !defined(_MSC_VER) || defined(__clang__) #define DISPATCH_ARGS(...) \ DISPATCH_ARGS_N(__VA_ARGS__, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, \ 0) \ (__VA_ARGS__) +#else +#define EXPAND_ARGS(args) args +#define DISPATCH_ARGS(...) \ + DISPATCH_ARGS_N EXPAND_ARGS((__VA_ARGS__, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, \ + 4, 3, 2, 1, 0))(__VA_ARGS__) +#endif #define ADD_ENUM_TO_MODULE(module, enum_name, ...) \ do { \ @@ -380,6 +392,7 @@ typedef CUresult (*cuTensorMapEncodeTiled_t)( typedef CUresult (*cuOccupancyMaxActiveClusters_t)( int *numClusters, CUfunction func, const CUlaunchConfig *config); +#ifndef _WIN32 #define defineGetFunctionHandle(name, symbolName) \ static symbolName##_t name() { \ /* Open the shared library */ \ @@ -401,6 +414,27 @@ typedef CUresult (*cuOccupancyMaxActiveClusters_t)( } \ return funcHandle; \ } +#else +#define defineGetFunctionHandle(name, symbolName) \ + static symbolName##_t name() { \ + /* Open the shared library */ \ + HMODULE handle = LoadLibraryA("nvcuda.dll"); \ + if (!handle) { \ + PyErr_SetString(PyExc_RuntimeError, "Failed to open nvcuda.dll"); \ + return NULL; \ + } \ + symbolName##_t funcHandle = \ + (symbolName##_t)GetProcAddress((HMODULE)handle, #symbolName); \ + /* Check for errors */ \ + long err = GetLastError(); \ + if (err) { \ + PyErr_SetString(PyExc_RuntimeError, \ + "Failed to retrieve " #symbolName " from nvcuda.dll"); \ + return NULL; \ + } \ + return funcHandle; \ + } +#endif defineGetFunctionHandle(getCuTensorMapEncodeTiledHandle, cuTensorMapEncodeTiled); From 9a2a84af1695a55b21638ffa251d16554ad209b3 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Fri, 1 Dec 2023 14:53:43 +0900 Subject: [PATCH 3/8] fix compile error clang error "(aka 'long long') must match previous return type 'long' when lambda expression has unspecified explicit return typ" --- .../TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index 69ce2bd54cf9..92bc49c7f724 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -717,8 +717,8 @@ void mlir::triton::asyncLaunchDots(scf::ForOp forOp) { lastOp = op; op = op->getBlock()->getParentOp(); } - return std::distance(lastOp->getBlock()->getParent()->begin(), - lastOp->getBlock()->getIterator()); + return (long)std::distance(lastOp->getBlock()->getParent()->begin(), + lastOp->getBlock()->getIterator()); }; /// XXX(Keren): Clean up the following duplicate code with checkDotOp /// dots to be pipelined From 78fa02a66992b2d23183081df6e27f86e4c7fccb Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Mon, 4 Dec 2023 00:35:00 +0900 Subject: [PATCH 4/8] unit/runtime/*.py fix for windows --- python/test/unit/runtime/test_cache.py | 2 +- python/test/unit/runtime/test_subproc.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index cd589fa920f5..339dc25e617a 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -70,7 +70,7 @@ def test_nested1_change(): def write_and_load_module(code, num_extra_lines): - with tempfile.NamedTemporaryFile(mode='w+', suffix='.py') as f: + with tempfile.NamedTemporaryFile(mode='w+', suffix='.py', delete=False) as f: f.write(('# extra line\n' * num_extra_lines) + code) f.flush() spec = importlib.util.spec_from_file_location("module.name", f.name) diff --git a/python/test/unit/runtime/test_subproc.py b/python/test/unit/runtime/test_subproc.py index 63401f28e42b..d0ecd771384f 100644 --- a/python/test/unit/runtime/test_subproc.py +++ b/python/test/unit/runtime/test_subproc.py @@ -34,11 +34,15 @@ def kernel_sub(a, b, o, N: tl.constexpr): def test_compile_in_subproc() -> None: + import os major, minor = torch.cuda.get_device_capability(0) cc = major * 10 + minor config = triton.compiler.AttrsDescriptor(tuple(range(4)), (), (), ()) - multiprocessing.set_start_method('fork') + if os.name == "nt": + multiprocessing.set_start_method('spawn') + else: + multiprocessing.set_start_method('fork') proc = multiprocessing.Process(target=compile_fn, args=(config, cc)) proc.start() proc.join() @@ -64,7 +68,7 @@ def test_compile_in_forked_subproc() -> None: capability = major * 10 + minor config = triton.compiler.AttrsDescriptor(tuple(range(1)), (), (), ()) - assert multiprocessing.get_start_method() == 'fork' + assert multiprocessing.get_start_method() in ['fork', 'spawn'] proc = multiprocessing.Process(target=compile_fn_dot, args=(config, capability)) proc.start() proc.join() From d9bc2cc4a1fbd6300fbd898d49870f2a4920f52e Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Mon, 4 Dec 2023 12:47:09 +0900 Subject: [PATCH 5/8] fix MANIFEST.in * fix warning "warning: manifest_maker: MANIFEST.in, line 4: path 'triton/runtime/backends/' cannot end with '/'" --- python/MANIFEST.in | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/MANIFEST.in b/python/MANIFEST.in index c8c818919820..f46c7ce67dee 100644 --- a/python/MANIFEST.in +++ b/python/MANIFEST.in @@ -1,5 +1,5 @@ graft src graft triton/third_party graft triton/tools -graft triton/runtime/backends/ +graft triton/runtime/backends graft triton/language/extra From 474ea6403d6b021a0469cb204cf7708f21a83356 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Mon, 18 Dec 2023 18:48:49 +0900 Subject: [PATCH 6/8] call subprocess.run() without shell option --- python/triton/compiler/backends/cuda.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/python/triton/compiler/backends/cuda.py b/python/triton/compiler/backends/cuda.py index e6576469e74a..794ef5357507 100644 --- a/python/triton/compiler/backends/cuda.py +++ b/python/triton/compiler/backends/cuda.py @@ -195,13 +195,16 @@ def make_cubin(src, metadata, opt, capability): fsrc.flush() fbin = fsrc.name + '.o' - line_info = '' if os.environ.get('TRITON_DISABLE_LINE_INFO') else ' -lineinfo' - fmad = '' if opt.enable_fp_fusion else ' --fmad=false' - suffix = 'a ' if capability == 90 else ' ' - cmd = f'{ptxas}{line_info}{fmad} -v --gpu-name=sm_{capability}{suffix}{fsrc.name} -o {fbin} 2> {flog.name}' + line_info = '' if os.environ.get('TRITON_DISABLE_LINE_INFO') else '-lineinfo' + fmad = '' if opt.enable_fp_fusion else '--fmad=false' + cmd = [ptxas] + cmd += [line_info] if line_info != '' else [] + cmd += [fmad] if fmad != '' else [] + suffix = 'a' if capability == 90 else '' + cmd += ['-v', f'--gpu-name=sm_{capability}{suffix}', fsrc.name, '-o', fbin] try: - subprocess.run(cmd, shell=True, check=True) + subprocess.run(cmd, check=True, stderr=flog) except subprocess.CalledProcessError as e: with open(flog.name) as log_file: log = log_file.read() @@ -212,16 +215,16 @@ def make_cubin(src, metadata, opt, capability): f'Please run `ptxas {fsrc.name}` to confirm that this is a bug in `ptxas`\n{log}') else: raise RuntimeError(f'`ptxas` failed with error code {e.returncode}: \n{log}') - finally: - if os.path.exists(fsrc.name): - os.remove(fsrc.name) - if os.path.exists(flog.name): - os.remove(flog.name) - with open(fbin, 'rb') as f: cubin = f.read() if os.path.exists(fbin): os.remove(fbin) + + if os.path.exists(fsrc.name): + os.remove(fsrc.name) + if os.path.exists(flog.name): + os.remove(flog.name) + return cubin def add_stages(self, stages, options): From 9739b8df049c7989cb7e9b9bfdf11fa083a9d133 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Thu, 30 Nov 2023 08:10:20 +0900 Subject: [PATCH 7/8] support windows-latest * use conda for ubuntu-latest * enable windows-latest build * disable artifact check for non self-hosted * build wheels and upload dist artifacts * update build matrix with python-version, etc. * fix deprecated ::set-output, etc. --- .github/workflows/integration-tests.yml | 137 ++++++++++++++++++++---- environment.yml | 18 ++++ 2 files changed, 137 insertions(+), 18 deletions(-) create mode 100644 environment.yml diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index b13c74801ee3..7ecd0b0b1c1e 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -23,26 +23,25 @@ jobs: matrix-required: ${{ steps.set-matrix.outputs.matrix-required }} matrix-optional: ${{ steps.set-matrix.outputs.matrix-optional }} steps: - - name: Prepare runner matrix + - name: Prepare matrix id: set-matrix run: | if [ x"${{ github.repository }}" == x"openai/triton" ]; then - echo '::set-output name=matrix-required::[["self-hosted", "A100"], ["self-hosted", "H100"]]' - echo '::set-output name=matrix-optional::[["self-hosted", "gfx908"], ["self-hosted", "arc770"]]' + echo 'matrix-required={"runner": [["self-hosted", "A100"], ["self-hosted", "H100"]], "python-version": ["3.11"], "cuda-version": ["12.1"], "cc": ["clang"]}' >> "$GITHUB_OUTPUT" + echo 'matrix-optional={"runner": [["self-hosted", "gfx908"], ["self-hosted", "arc770"]], "python-version": ["3.11"], "cuda-version": ["12.1"], "cc": ["clang"]}' >> "$GITHUB_OUTPUT" else - echo '::set-output name=matrix-required::["ubuntu-latest"]' - echo '::set-output name=matrix-optional::["ubuntu-latest"]' + echo 'matrix-required={"runner":["ubuntu-latest", "windows-latest"], "python-version": ["3.10", "3.11"], "cuda-version": ["12.1"], "cc": ["clang"]}' >> "$GITHUB_OUTPUT" + echo 'matrix-optional={"runner":["ubuntu-latest", "windows-latest"], "python-version": ["3.10", "3.11"], "cuda-version": ["12.1"], "cc": ["clang"]}' >> "$GITHUB_OUTPUT" fi Integration-Tests: needs: Runner-Preparation runs-on: ${{ matrix.runner }} - timeout-minutes: 20 + timeout-minutes: 60 strategy: - matrix: - runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix-required)}} + matrix: ${{fromJson(needs.Runner-Preparation.outputs.matrix-required)}} steps: - name: Checkout @@ -56,11 +55,90 @@ jobs: echo "ENABLE_TMA=0" >> "${GITHUB_ENV}" echo "TRITON_DISABLE_LINE_INFO=1" >> "${GITHUB_ENV}" + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + + - name: Set up MSVC + if: matrix.runner == 'windows-latest' + uses: ilammy/msvc-dev-cmd@v1.12.1 + with: + arch: amd64 + + - name: Setup Mambaforge (Windows) + if: matrix.runner == 'windows-latest' + uses: conda-incubator/setup-miniconda@v3 + with: + miniforge-variant: Mambaforge + miniforge-version: latest + activate-environment: triton-env + use-mamba: true + + - uses: conda-incubator/setup-miniconda@v3 + if: matrix.runner == 'windows-latest' + with: + activate-environment: triton-env + environment-file: environment.yml + auto-activate-base: true + python-version: ${{ matrix.python-version }} + + - name: set Environment Variables (Windows) + if: matrix.runner == 'windows-latest' + shell: bash -el {0} + run: | + LLVM_SHORTHASH="$(cat cmake/llvm-hash.txt | cut -c1-8)" + # prepare LLVM prebuilt path. will be downloaded and extracted by setup.py step + echo "~/.triton/llvm/llvm-$LLVM_SHORTHASH-windows-x64/bin" >> "$GITHUB_PATH" + #echo "LLVM_SYSPATH=~/.triton/llvm/llvm-$LLVM_SHORTHASH-windows-x64" >> "$GITHUB_ENV" + # compile with a selected matrix.cc + if [ "${{matrix.cc}}" = "cl" ]; then + echo "CC=cl" >> "${GITHUB_ENV}" + echo "CXX=cl" >> "${GITHUB_ENV}" + elif [ "${{matrix.cc}}" = "clang" ]; then + echo "CC=clang" >> "${GITHUB_ENV}" + echo "CXX=clang++" >> "${GITHUB_ENV}" + fi + + - name: CUDA toolkit ${{ matrix.cuda-version }} + shell: bash -el {0} + if: matrix.runner[0] != 'self-hosted' + run: | + if [ "${{ matrix.runner }}" = "ubuntu-latest" ]; then + # prepare space for ubuntu + sudo rm -rf /usr/share/dotnet + sudo rm -rf /opt/ghc + fi + + addon="" + cuda_version=${{ matrix.cuda-version }} + [ "$cuda_version" = "12.1" ] && cuda_version="12.1.1" && addon="cuda-cudart-static cuda-nvrtc" + [ "$cuda_version" = "11.8" ] && cuda_version="11.8.0" + + conda install python=${{ matrix.python-version }} cuda-libraries-dev cuda-nvcc cuda-nvtx cuda-cupti cuda-cudart cuda-cudart-dev cuda-runtime cuda-libraries $addon -c "nvidia/label/cuda-$cuda_version" + + - name: Update environment + if: matrix.runner[0] != 'self-hosted' + shell: bash + run: | + echo "BACKEND=CUDA" >> "${GITHUB_ENV}" + echo "ENABLE_TMA=0" >> "${GITHUB_ENV}" + echo "TRITON_DISABLE_LINE_INFO=1" >> "${GITHUB_ENV}" + + - name: Set reusable strings + # Turn repeated input strings (such as the build output directory) into step outputs. These step outputs can be used throughout the workflow file. + id: strings + shell: bash + run: | + echo "build-output-dir=${{ github.workspace }}/build" >> "$GITHUB_OUTPUT" + - name: Clear cache + shell: bash run: | rm -rf ~/.triton - name: Update PATH + if: matrix.runner[0] == 'self-hosted' run: | echo "PATH=${HOME}/.local/bin:${PATH}" >> "${GITHUB_ENV}" @@ -70,17 +148,29 @@ jobs: python3 -m pre_commit run --all-files --verbose - name: Install Triton - if: ${{ env.BACKEND == 'CUDA'}} + if: matrix.runner != 'windows-latest' run: | cd python python3 -m pip install --upgrade pip - python3 -m pip install cmake==3.24 ninja pytest-xdist + python3 -m pip install cmake==3.24 ninja pytest-xdist wheel sudo apt-get update -y sudo apt-get install -y ccache clang lld TRITON_BUILD_WITH_CLANG_LLD=true TRITON_BUILD_WITH_CCACHE=true python3 -m pip install --no-build-isolation -vvv '.[tests]' + if [ "${{ matrix.runner }}" = 'ubuntu-latest' ]; then + python3 setup.py bdist_wheel + fi + + - name: Install Triton (Windows) + if: matrix.runner == 'windows-latest' + run: | + cd python + python -m pip install --upgrade pip + python -m pip install cmake==3.24 ninja pytest-xdist wheel + python -m pip install --no-build-isolation -vvv . + python setup.py bdist_wheel - name: Run lit tests - if: ${{ env.BACKEND == 'CUDA'}} + if: matrix.runner[0] == 'self-hosted' && env.BACKEND == 'CUDA' run: | python3 -m pip install lit cd python @@ -96,7 +186,7 @@ jobs: echo "ENABLE_TMA=1" >> "${GITHUB_ENV}" - name: Run python tests on CUDA with ENABLE_TMA=1 - if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '1'}} + if: ${{(matrix.runner[0] == 'self-hosted') && env.BACKEND == 'CUDA' && env.ENABLE_TMA == '1'}} run: | cd python/test/unit python3 -m pytest -n 8 --ignore=runtime --ignore=operators --ignore=language/test_line_info.py --ignore=language/test_subprocess.py @@ -109,7 +199,7 @@ jobs: python3 -m pytest hopper/test_flashattention.py - name: Run python tests on CUDA with ENABLE_TMA=0 - if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '0'}} + if: ${{(matrix.runner[0] == 'self-hosted') && env.BACKEND == 'CUDA' && env.ENABLE_TMA == '0'}} run: | cd python/test/unit python3 -m pytest -n 8 --ignore=runtime --ignore=hopper --ignore=operators --ignore=language/test_line_info.py @@ -119,10 +209,12 @@ jobs: TRITON_DISABLE_LINE_INFO=0 python3 -m pytest language/test_line_info.py - name: Clear cache + shell: bash run: | rm -rf ~/.triton - name: Run interpreter tests + if: matrix.runner[0] == 'self-hosted' env: # TRITON_INTERPRET: "1" CUA_VISIBLE_DEVICES: "" @@ -131,17 +223,25 @@ jobs: python3 -m pytest -vs operators/test_flash_attention.py - name: Run partial tests on CUDA with ENABLE_TMA=1 - if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '1'}} + if: ${{(matrix.runner[0] == 'self-hosted') && env.BACKEND == 'CUDA' && env.ENABLE_TMA == '1'}} run: | cd python/test/unit python3 -m pytest -n 8 operators - name: Run partial tests on CUDA with ENABLE_TMA=0 - if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '0'}} + if: ${{(matrix.runner[0] == 'self-hosted') && env.BACKEND == 'CUDA' && env.ENABLE_TMA == '0'}} run: | cd python/test/unit python3 -m pytest -n 8 operators + - name: Upload Build artifacts + if: matrix.runner[0] != 'self-hosted' + uses: actions/upload-artifact@v3 + with: + name: triton-dist ${{ matrix.runner }} + path: | + ${{ github.workspace }}/python/dist/ + - name: Create artifacts archive if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100' || matrix.runner[1] == 'H100')}} run: | @@ -150,20 +250,20 @@ jobs: - name: Upload artifacts archive if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100' || matrix.runner[1] == 'H100')}} - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: artifacts ${{ matrix.runner[1] }} path: ~/.triton/artifacts.tar.gz - name: Run CXX unittests - if: ${{ env.BACKEND == 'CUDA'}} + if: ${{(matrix.runner[0] == 'self-hosted') && env.BACKEND == 'CUDA'}} run: | cd python cd "build/$(ls build | grep -i cmake)" ctest - name: Regression tests - if: ${{ contains(matrix.runner, 'A100') }} + if: ${{ (matrix.runner[0] == 'self-hosted') && contains(matrix.runner, 'A100') }} run: | python3 -m pip install pytest-rerunfailures cd python/test/regression @@ -173,6 +273,7 @@ jobs: sudo nvidia-smi -i 0 -rgc Compare-artifacts: + if: ${{(github.repository == 'openai/triton')}} needs: Integration-Tests timeout-minutes: 5 diff --git a/environment.yml b/environment.yml new file mode 100644 index 000000000000..031020b365bb --- /dev/null +++ b/environment.yml @@ -0,0 +1,18 @@ +name: triton +channels: + - conda-forge + - pytorch +dependencies: + - python + - pytest + - pytorch + - torchaudio + - torchvision + - typer + - ca-certificates + - certifi + - openssl + - zlib + - zstd + - llvm>=17.0 + - mlir>=17.0 From 216f92858c12a4a03f502b3ea6d3e78f5320c579 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Mon, 4 Dec 2023 23:05:23 +0900 Subject: [PATCH 8/8] use windows llvm build --- python/setup.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/setup.py b/python/setup.py index 180ce68f7676..3ba6748e849d 100644 --- a/python/setup.py +++ b/python/setup.py @@ -86,6 +86,10 @@ def get_llvm_package_info(): vglibc = tuple(map(int, platform.libc_ver()[1].split('.'))) vglibc = vglibc[0] * 100 + vglibc[1] system_suffix = 'ubuntu-x64' if vglibc > 217 else 'centos-x64' + elif system == "Windows": + if arch == "AMD64": + arch = "x64" + system_suffix = f"windows-{arch}" else: return Package("llvm", "LLVM-C.lib", "", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH") # use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False")