Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

win32+clang support #18

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,11 @@ endif()

# Compiler flags
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17")
# Extend the global C++ flags. Both variants define __STDC_FORMAT_MACROS and
# select gnu++17; they differ only in -fPIC versus -Wno-deprecated.
if(WIN32)
  # Windows: no -fPIC, and deprecation warnings are silenced.
  string(APPEND CMAKE_CXX_FLAGS " -D__STDC_FORMAT_MACROS -std=gnu++17 -Wno-deprecated")
else()
  # Everything else: request position-independent code.
  string(APPEND CMAKE_CXX_FLAGS " -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17")
endif()


# #########
Expand Down Expand Up @@ -146,6 +150,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
# using pip install.
include_directories(${PYTHON_INCLUDE_DIRS})
include_directories(${PYBIND11_INCLUDE_DIR})
link_directories(${PYTHON_LIB_DIRS})
else()
# Otherwise, we might be building from top CMakeLists.txt directly.
# Try to find Python and pybind11 packages.
Expand Down Expand Up @@ -227,7 +232,8 @@ if(TRITON_BUILD_PYTHON_MODULE)
LLVMAArch64CodeGen
LLVMAArch64AsmParser
)
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64")
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64" OR
CMAKE_SYSTEM_PROCESSOR MATCHES "AMD64") # windows x64
list(APPEND TRITON_LIBRARIES
LLVMX86CodeGen
LLVMX86AsmParser
Expand Down Expand Up @@ -262,6 +268,8 @@ if(TRITON_BUILD_PYTHON_MODULE)
target_link_libraries(triton PUBLIC ${TRITON_LIBRARIES})
# Platform-specific linking and output naming for the `triton` target.
if(NOT WIN32)
  # Non-Windows: link against zlib by library name.
  target_link_libraries(triton PRIVATE z)
else()
  target_link_libraries(triton PRIVATE ${CMAKE_DL_LIBS})
  # Give the output file a "lib" prefix and a ".pyd" suffix.
  set_target_properties(triton PROPERTIES
    PREFIX "lib"
    SUFFIX ".pyd")
endif()
Expand Down Expand Up @@ -289,6 +297,11 @@ if(NOT TRITON_BUILD_PYTHON_MODULE)
endforeach()
endif()

# Windows only: define two boolean cache options, both defaulting to ON,
# ahead of the add_subdirectory() calls that follow.
if(WIN32)
# NOTE(review): CMAKE_USE_WIN32_THREADS_INIT is normally a result variable of
# CMake's FindThreads module -- confirm that presetting it has the intended effect.
option(CMAKE_USE_WIN32_THREADS_INIT "using WIN32 threads" ON)
# Presumably consumed by a bundled googletest build -- TODO confirm.
option(gtest_disable_pthreads "Disable uses of pthreads in gtest." ON)
endif()

add_subdirectory(third_party/f2reduce)
add_subdirectory(bin)
add_subdirectory(test)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <memory>
#include <optional>
#include <string>

namespace mlir {

Expand Down
38 changes: 38 additions & 0 deletions lib/Tools/LinearLayout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,44 @@
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

#if defined(_MSC_VER) && !defined(__clang__)
// MSVC (without clang) does not provide the GCC-style count-trailing-zeros
// builtins, so define shims with the same names on top of MSVC intrinsics.
// Adapted from https://gist.github.com/pps83/3210a2f980fd02bb2ba2e5a1fc4a2ef0
#include <intrin.h>

// Returns the number of trailing zero bits in `x`.
// The GCC builtin of the same name is undefined for x == 0; the bit-scan
// fallback below returns 32 in that case (the value TZCNT produces) instead
// of an uninitialized result.
static __forceinline int __builtin_ctz(unsigned x) {
#if defined(_M_ARM) || defined(_M_ARM64) || defined(_M_HYBRID_X86_ARM64) ||    \
    defined(_M_ARM64EC)
  return (int)_CountTrailingZeros(x);
#elif defined(__AVX2__) || defined(__BMI__)
  return (int)_tzcnt_u32(x);
#else
  unsigned long r;
  // _BitScanForward returns 0 and leaves `r` unwritten when no bit is set.
  return _BitScanForward(&r, x) ? (int)r : 32;
#endif
}

// Returns the number of trailing zero bits in the 64-bit value `x`.
// For x == 0 the bit-scan fallbacks return 64 (the value TZCNT produces).
static __forceinline int __builtin_ctzll(unsigned long long x) {
#if defined(_M_ARM) || defined(_M_ARM64) || defined(_M_HYBRID_X86_ARM64) ||    \
    defined(_M_ARM64EC)
  return (int)_CountTrailingZeros64(x);
#elif defined(_WIN64)
#if defined(__AVX2__) || defined(__BMI__)
  return (int)_tzcnt_u64(x);
#else
  unsigned long r;
  // _BitScanForward64 returns 0 and leaves `r` unwritten when no bit is set.
  return _BitScanForward64(&r, x) ? (int)r : 64;
#endif
#else
  // 32-bit x86: no 64-bit scan intrinsic is available. Scan only the half
  // that decides the answer, so a zero half is never passed to the scanner.
  const unsigned lo = (unsigned)x;
  if (lo != 0)
    return __builtin_ctz(lo);
  return __builtin_ctz((unsigned)(x >> 32)) + 32;
#endif
}

#endif

namespace mlir::triton {

namespace {
Expand Down
71 changes: 41 additions & 30 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def get_json_package_info():
def get_llvm_package_info():
system = platform.system()
try:
arch = {"x86_64": "x64", "arm64": "arm64", "aarch64": "arm64"}[platform.machine()]
arch = {"x86_64": "x64", "AMD64": "x64", "arm64": "arm64", "aarch64": "arm64"}[platform.machine()]
except KeyError:
arch = platform.machine()
if system == "Darwin":
Expand Down Expand Up @@ -196,6 +196,8 @@ def get_llvm_package_info():
f"LLVM pre-compiled image is not available for {system}-{arch}. Proceeding with user-configured LLVM from source build."
)
return Package("llvm", "LLVM-C.lib", "", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
elif system == "Windows":
system_suffix = f"windows-{arch}"
else:
print(
f"LLVM pre-compiled image is not available for {system}-{arch}. Proceeding with user-configured LLVM from source build."
Expand Down Expand Up @@ -281,17 +283,20 @@ def download_and_copy(name, src_path, dst_path, variable, version, url_func):
base_dir = os.path.dirname(__file__)
system = platform.system()
try:
arch = {"x86_64": "64", "arm64": "aarch64", "aarch64": "aarch64"}[platform.machine()]
arch = {"x86_64": "64", "AMD64": "64", "arm64": "aarch64", "aarch64": "aarch64"}[platform.machine()]
except KeyError:
arch = platform.machine()
url = url_func(arch, version)
supported = {"Linux": "linux", "Windows": "win"}
is_supported = system in supported
if is_supported:
url = url_func(supported[system], arch, version)
tmp_path = os.path.join(triton_cache_path, "nvidia", name) # path to cache the download
dst_path = os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", dst_path) # final binary path
platform_name = "sbsa-linux" if arch == "aarch64" else "x86_64-linux"
src_path = src_path(platform_name, version) if callable(src_path) else src_path
src_path = os.path.join(tmp_path, src_path)
download = not os.path.exists(src_path)
if os.path.exists(dst_path) and system == "Linux" and shutil.which(dst_path) is not None:
if os.path.exists(dst_path) and is_supported and shutil.which(dst_path) is not None:
curr_version = subprocess.check_output([dst_path, "--version"]).decode("utf-8").strip()
curr_version = re.search(r"V([.|\d]+)", curr_version).group(1)
download = download or curr_version != version
Expand Down Expand Up @@ -420,6 +425,10 @@ def build_extension(self, ext):
"-DTRITON_CODEGEN_BACKENDS=" + ';'.join([b.name for b in backends if not b.is_external]),
"-DTRITON_PLUGIN_DIRS=" + ';'.join([b.src_dir for b in backends if b.is_external])
]
if platform.system() == "Windows":
installed_base = sysconfig.get_config_var('installed_base')
py_lib_dirs = os.getenv("PYTHON_LIB_DIRS", os.path.join(installed_base, "libs"))
cmake_args.append("-DPYTHON_LIB_DIRS=" + py_lib_dirs)
if lit_dir is not None:
cmake_args.append("-DLLVM_EXTERNAL_LIT=" + lit_dir)
cmake_args.extend(thirdparty_cmake_args)
Expand All @@ -429,9 +438,8 @@ def build_extension(self, ext):
build_args = ["--config", cfg]

if platform.system() == "Windows":
cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg]
cmake_args += [f"-DCMAKE_RUNTIME_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}"]
if sys.maxsize > 2**32:
cmake_args += ["-A", "x64"]
else:
cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg]
max_jobs = os.getenv("MAX_JOBS", str(2 * os.cpu_count()))
Expand Down Expand Up @@ -498,63 +506,66 @@ def get_platform_dependent_src_path(subdir):
if int(version_major) >= 12 and int(version_minor1) >= 5 else subdir)(*version.split('.')))


exe = ".exe" if os.name == "nt" else ""

download_and_copy(
name="ptxas", src_path="bin/ptxas", dst_path="bin/ptxas", variable="TRITON_PTXAS_PATH",
version=NVIDIA_TOOLCHAIN_VERSION["ptxas"], url_func=lambda arch, version:
name="ptxas", src_path=f"bin/ptxas{exe}", dst_path=f"bin/ptxas{exe}", variable="TRITON_PTXAS_PATH",
version=NVIDIA_TOOLCHAIN_VERSION["ptxas"], url_func=lambda system, arch, version:
((lambda version_major, version_minor1, version_minor2:
f"https://anaconda.org/nvidia/cuda-nvcc-tools/{version}/download/linux-{arch}/cuda-nvcc-tools-{version}-0.tar.bz2"
f"https://anaconda.org/nvidia/cuda-nvcc-tools/{version}/download/{system}-{arch}/cuda-nvcc-tools-{version}-0.tar.bz2"
if int(version_major) >= 12 and int(version_minor1) >= 5 else
f"https://anaconda.org/nvidia/cuda-nvcc/{version}/download/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2")
f"https://anaconda.org/nvidia/cuda-nvcc/{version}/download/{system}-{arch}/cuda-nvcc-{version}-0.tar.bz2")
(*version.split('.'))))
download_and_copy(
name="cuobjdump",
src_path="bin/cuobjdump",
dst_path="bin/cuobjdump",
src_path=f"bin/cuobjdump{exe}",
dst_path=f"bin/cuobjdump{exe}",
variable="TRITON_CUOBJDUMP_PATH",
version=NVIDIA_TOOLCHAIN_VERSION["cuobjdump"],
url_func=lambda arch, version:
f"https://anaconda.org/nvidia/cuda-cuobjdump/{version}/download/linux-{arch}/cuda-cuobjdump-{version}-0.tar.bz2",
url_func=lambda system, arch, version:
f"https://anaconda.org/nvidia/cuda-cuobjdump/{version}/download/{system}-{arch}/cuda-cuobjdump-{version}-0.tar.bz2",
)
download_and_copy(
name="nvdisasm",
src_path="bin/nvdisasm",
dst_path="bin/nvdisasm",
src_path=f"bin/nvdisasm{exe}",
dst_path=f"bin/nvdisasm{exe}",
variable="TRITON_NVDISASM_PATH",
version=NVIDIA_TOOLCHAIN_VERSION["nvdisasm"],
url_func=lambda arch, version:
f"https://anaconda.org/nvidia/cuda-nvdisasm/{version}/download/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2",
url_func=lambda system, arch, version:
f"https://anaconda.org/nvidia/cuda-nvdisasm/{version}/download/{system}-{arch}/cuda-nvdisasm-{version}-0.tar.bz2",
)
download_and_copy(
name="cudacrt", src_path=get_platform_dependent_src_path("include"), dst_path="include",
variable="TRITON_CUDACRT_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cudacrt"], url_func=lambda arch, version:
variable="TRITON_CUDACRT_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cudacrt"], url_func=lambda system, arch, version:
((lambda version_major, version_minor1, version_minor2:
f"https://anaconda.org/nvidia/cuda-crt-dev_linux-{arch}/{version}/download/noarch/cuda-crt-dev_linux-{arch}-{version}-0.tar.bz2"
f"https://anaconda.org/nvidia/cuda-crt-dev_{system}-{arch}/{version}/download/noarch/cuda-crt-dev_{system}-{arch}-{version}-0.tar.bz2"
if int(version_major) >= 12 and int(version_minor1) >= 5 else
f"https://anaconda.org/nvidia/cuda-nvcc/{version}/download/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2")
f"https://anaconda.org/nvidia/cuda-nvcc/{version}/download/{system}-{arch}/cuda-nvcc-{version}-0.tar.bz2")
(*version.split('.'))))
download_and_copy(
name="cudart", src_path=get_platform_dependent_src_path("include"), dst_path="include",
variable="TRITON_CUDART_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cudart"], url_func=lambda arch, version:
variable="TRITON_CUDART_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cudart"], url_func=lambda system, arch, version:
((lambda version_major, version_minor1, version_minor2:
f"https://anaconda.org/nvidia/cuda-cudart-dev_linux-{arch}/{version}/download/noarch/cuda-cudart-dev_linux-{arch}-{version}-0.tar.bz2"
f"https://anaconda.org/nvidia/cuda-cudart-dev_{system}-{arch}/{version}/download/noarch/cuda-cudart-dev_{system}-{arch}-{version}-0.tar.bz2"
if int(version_major) >= 12 and int(version_minor1) >= 5 else
f"https://anaconda.org/nvidia/cuda-cudart-dev/{version}/download/linux-{arch}/cuda-cudart-dev-{version}-0.tar.bz2"
f"https://anaconda.org/nvidia/cuda-cudart-dev/{version}/download/{system}-{arch}/cuda-cudart-dev-{version}-0.tar.bz2"
)(*version.split('.'))))
download_and_copy(
name="cupti", src_path=get_platform_dependent_src_path("include"), dst_path="include",
variable="TRITON_CUPTI_INCLUDE_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cupti"], url_func=lambda arch, version:
variable="TRITON_CUPTI_INCLUDE_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cupti"],
url_func=lambda system, arch, version:
((lambda version_major, version_minor1, version_minor2:
f"https://anaconda.org/nvidia/cuda-cupti-dev/{version}/download/linux-{arch}/cuda-cupti-dev-{version}-0.tar.bz2"
f"https://anaconda.org/nvidia/cuda-cupti-dev/{version}/download/{system}-{arch}/cuda-cupti-dev-{version}-0.tar.bz2"
if int(version_major) >= 12 and int(version_minor1) >= 5 else
f"https://anaconda.org/nvidia/cuda-cupti/{version}/download/linux-{arch}/cuda-cupti-{version}-0.tar.bz2")
f"https://anaconda.org/nvidia/cuda-cupti/{version}/download/{system}-{arch}/cuda-cupti-{version}-0.tar.bz2")
(*version.split('.'))))
download_and_copy(
name="cupti", src_path=get_platform_dependent_src_path("lib"), dst_path="lib/cupti",
variable="TRITON_CUPTI_LIB_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cupti"], url_func=lambda arch, version:
variable="TRITON_CUPTI_LIB_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cupti"], url_func=lambda system, arch, version:
((lambda version_major, version_minor1, version_minor2:
f"https://anaconda.org/nvidia/cuda-cupti-dev/{version}/download/linux-{arch}/cuda-cupti-dev-{version}-0.tar.bz2"
f"https://anaconda.org/nvidia/cuda-cupti-dev/{version}/download/{system}-{arch}/cuda-cupti-dev-{version}-0.tar.bz2"
if int(version_major) >= 12 and int(version_minor1) >= 5 else
f"https://anaconda.org/nvidia/cuda-cupti/{version}/download/linux-{arch}/cuda-cupti-{version}-0.tar.bz2")
f"https://anaconda.org/nvidia/cuda-cupti/{version}/download/{system}-{arch}/cuda-cupti-{version}-0.tar.bz2")
(*version.split('.'))))

backends = [*BackendInstaller.copy(["nvidia", "amd"]), *BackendInstaller.copy_externals()]
Expand Down
6 changes: 3 additions & 3 deletions python/src/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,22 @@

namespace py = pybind11;

#define EXPAND(x) x
#define FOR_EACH_1(MACRO, X) MACRO(X)
#define FOR_EACH_2(MACRO, X, ...) MACRO(X) FOR_EACH_1(MACRO, __VA_ARGS__)
#define FOR_EACH_3(MACRO, X, ...) MACRO(X) FOR_EACH_2(MACRO, __VA_ARGS__)
#define FOR_EACH_4(MACRO, X, ...) MACRO(X) FOR_EACH_3(MACRO, __VA_ARGS__)

#define FOR_EACH_NARG(...) FOR_EACH_NARG_(__VA_ARGS__, FOR_EACH_RSEQ_N())
#define FOR_EACH_NARG_(...) FOR_EACH_ARG_N(__VA_ARGS__)
#define FOR_EACH_NARG_(...) EXPAND(FOR_EACH_ARG_N(__VA_ARGS__))
#define FOR_EACH_ARG_N(_1, _2, _3, _4, N, ...) N
#define FOR_EACH_RSEQ_N() 4, 3, 2, 1, 0

#define CONCATENATE(x, y) CONCATENATE1(x, y)
#define CONCATENATE1(x, y) x##y

#define FOR_EACH(MACRO, ...) \
CONCATENATE(FOR_EACH_, FOR_EACH_NARG_HELPER(__VA_ARGS__))(MACRO, __VA_ARGS__)
#define FOR_EACH_NARG_HELPER(...) FOR_EACH_NARG(__VA_ARGS__)
EXPAND(CONCATENATE(FOR_EACH_, FOR_EACH_NARG(__VA_ARGS__))(MACRO, __VA_ARGS__))

// New macro to remove parentheses
#define REMOVE_PARENS(...) __VA_ARGS__
Expand Down
10 changes: 7 additions & 3 deletions python/triton/backends/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,20 +204,24 @@ def __init__(self, target: GPUTarget) -> None:

@staticmethod
def _path_to_binary(binary: str):
exe = ".exe" if os.name == "nt" else ""
base_dir = os.path.join(os.path.dirname(__file__), os.pardir)
paths = [
os.environ.get(f"TRITON_{binary.upper()}_PATH", ""),
os.path.join(base_dir, "third_party", "cuda", "bin", binary),
os.path.join(base_dir, "third_party", "cuda", "bin", f"{binary}{exe}"),
]
for p in paths:
bin = p.split(" ")[0]
if os.name != "nt":
bin = p.split(" ")[0]
else:
bin = p
if os.path.exists(bin) and os.path.isfile(bin):
result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT)
if result is not None:
version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
if version is not None:
return p, version.group(1)
raise RuntimeError(f"Cannot find {binary}")
raise RuntimeError(f"Cannot find {binary}{exe}")

@abstractclassmethod
def supports_target(target: GPUTarget):
Expand Down
3 changes: 2 additions & 1 deletion python/triton/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ def triton_key():

# backend
libtriton_hash = hashlib.sha256()
with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f:
ext = "so" if os.name != "nt" else "pyd"
with open(os.path.join(TRITON_PATH, "_C", "libtriton." + ext), "rb") as f:
while True:
chunk = f.read(1024**2)
if not chunk:
Expand Down
1 change: 1 addition & 0 deletions python/triton/runtime/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
cc_cmd += [f'-l{lib}' for lib in libraries]
cc_cmd += [f"-L{dir}" for dir in library_dirs]
cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None]
if os.name == "nt": cc_cmd.pop(cc_cmd.index("-fPIC"))
ret = subprocess.check_call(cc_cmd)
if ret == 0:
return so
Expand Down
31 changes: 31 additions & 0 deletions third_party/amd/backend/driver.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
// clang-format on
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#ifndef _WIN32
#include <dlfcn.h>
#else
#define WIN32_LEAN_AND_MEAN
#include <windows.h>
#endif
#include <stdio.h>
#include <stdlib.h>

Expand Down Expand Up @@ -58,11 +63,16 @@ static struct HIPSymbolTable hipSymbolTable;

bool initSymbolTable() {
// Use the HIP runtime library loaded into the existing process if it exists.
#ifndef _WIN32
void *lib = dlopen("libamdhip64.so", RTLD_NOLOAD);
#else
HMODULE lib = LoadLibraryA("amdhip64.dll");
#endif
if (lib) {
// printf("[triton] chosen loaded libamdhip64.so in the process\n");
}

#ifndef _WIN32
// Otherwise, go through the list of search paths to dlopen the first HIP
// driver library.
if (!lib) {
Expand All @@ -79,8 +89,15 @@ bool initSymbolTable() {
PyErr_SetString(PyExc_RuntimeError, "cannot open libamdhip64.so");
return false;
}
#else
if (!lib) {
PyErr_SetString(PyExc_RuntimeError, "cannot open amdhip64.dll");
return false;
}
#endif

// Resolve all symbols we are interested in.
#ifndef _WIN32
dlerror(); // Clear existing errors
const char *error = NULL;
#define QUERY_EACH_FN(hipSymbolName, ...) \
Expand All @@ -92,6 +109,20 @@ bool initSymbolTable() {
dlclose(lib); \
return false; \
}
#else
// Last-error code of the most recent failed symbol lookup (0 if none failed).
long error = 0;
// Resolve one HIP entry point from the loaded module into hipSymbolTable.
// GetProcAddress signals failure by returning NULL. The thread's last-error
// code is not reset on success, so it is read only after a NULL result rather
// than being used as the failure test itself. On failure a Python
// RuntimeError is set, the module is released and the enclosing function
// returns false.
#define QUERY_EACH_FN(hipSymbolName, ...)                                      \
  *(void **)&hipSymbolTable.hipSymbolName =                                    \
      (void *)GetProcAddress((HMODULE)lib, #hipSymbolName);                    \
  if (hipSymbolTable.hipSymbolName == NULL) {                                  \
    error = (long)GetLastError();                                              \
    PyErr_SetString(PyExc_RuntimeError,                                        \
                    "cannot query " #hipSymbolName " from amdhip64.dll");      \
    FreeLibrary(lib);                                                          \
    return false;                                                              \
  }

#endif

HIP_SYMBOL_LIST(QUERY_EACH_FN, QUERY_EACH_FN)

Expand Down
Loading