diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 52f6ae9a56c0..b1d6e614ea4a 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -23,26 +23,25 @@ jobs: matrix-required: ${{ steps.set-matrix.outputs.matrix-required }} matrix-optional: ${{ steps.set-matrix.outputs.matrix-optional }} steps: - - name: Prepare runner matrix + - name: Prepare matrix id: set-matrix run: | if [ x"${{ github.repository }}" == x"openai/triton" ]; then - echo '::set-output name=matrix-required::[["self-hosted", "A100"], ["self-hosted", "H100"]]' - echo '::set-output name=matrix-optional::[["self-hosted", "gfx908"], ["self-hosted", "arc770"]]' + echo 'matrix-required={"runner": [["self-hosted", "A100"], ["self-hosted", "H100"]], "python-version": ["3.11"], "cuda-version": ["12.1"], "cc": ["clang"]}' >> "$GITHUB_OUTPUT" + echo 'matrix-optional={"runner": [["self-hosted", "gfx908"], ["self-hosted", "arc770"]], "python-version": ["3.11"], "cuda-version": ["12.1"], "cc": ["clang"]}' >> "$GITHUB_OUTPUT" else - echo '::set-output name=matrix-required::["ubuntu-latest"]' - echo '::set-output name=matrix-optional::["ubuntu-latest"]' + echo 'matrix-required={"runner":["ubuntu-latest", "windows-latest"], "python-version": ["3.8", "3.9", "3.10", "3.11"], "cuda-version": ["11.8.89", "12.1.1"], "cc": ["clang"]}' >> "$GITHUB_OUTPUT" + echo 'matrix-optional={"runner":["ubuntu-latest", "windows-latest"], "python-version": ["3.8", "3.9", "3.10", "3.11"], "cuda-version": ["11.8.89", "12.1.1"], "cc": ["clang"]}' >> "$GITHUB_OUTPUT" fi Integration-Tests: needs: Runner-Preparation runs-on: ${{ matrix.runner }} - timeout-minutes: 20 + timeout-minutes: 60 strategy: - matrix: - runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix-required)}} + matrix: ${{fromJson(needs.Runner-Preparation.outputs.matrix-required)}} steps: - name: Checkout @@ -55,15 +54,99 @@ jobs: echo "BACKEND=CUDA" >> "${GITHUB_ENV}" echo 
"TRITON_DISABLE_LINE_INFO=1" >> "${GITHUB_ENV}" + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + + - name: Set up MSVC + if: matrix.runner == 'windows-latest' + uses: ilammy/msvc-dev-cmd@v1.12.1 + with: + arch: amd64 + + - name: Setup Micromamba + uses: mamba-org/setup-micromamba@v1 + if: matrix.runner[0] != 'self-hosted' + with: + environment-name: triton-env + init-shell: bash + create-args: >- + typer + ca-certificates + certifi + openssl + zlib + zstd + llvm>=17.0 + condarc: | + channels: + - nvidia/label/cuda-${{ matrix.cuda-version }} + - conda-forge + - pytorch + channel_priority: strict + + - name: set Environment Variables (Windows) + if: matrix.runner == 'windows-latest' + shell: bash -el {0} + run: | + ver=4017f04e + curl -L -O https://github.com/wkpark/triton/releases/download/llvm-$ver-windows/llvm-$ver-windows-x64.tar.gz + curl -L -O https://github.com/wkpark/triton/releases/download/llvm-$ver-windows/llvm-fix.patch + tar xvf llvm-$ver-windows-x64.tar.gz + mv llvm-$ver-windows-x64 LLVM + patch -p0 < llvm-fix.patch + echo "LLVM_SYSPATH=${{ github.workspace }}\\LLVM" >> "$GITHUB_ENV" + rm -f llvm-$ver-windows-x64.tar.gz + + ### LLVM_SHORTHASH="$(cat cmake/llvm-hash.txt | cut -c1-8)" + # prepare LLVM prebuilt path. 
will be downloaded and extracted by setup.py step + ### echo "~/.triton/llvm/llvm-$LLVM_SHORTHASH-windows-x64/bin" >> "$GITHUB_PATH" + #echo "LLVM_SYSPATH=~/.triton/llvm/llvm-$LLVM_SHORTHASH-windows-x64" >> "$GITHUB_ENV" + # compile with a selected matrix.cc + if [ "${{matrix.cc}}" = "cl" ]; then + echo "CC=cl" >> "${GITHUB_ENV}" + echo "CXX=cl" >> "${GITHUB_ENV}" + elif [ "${{matrix.cc}}" = "clang" ]; then + echo "CC=clang" >> "${GITHUB_ENV}" + echo "CXX=clang++" >> "${GITHUB_ENV}" + fi + + - name: CUDA Setup ${{ matrix.cuda-version }} + if: matrix.runner[0] != 'self-hosted' + shell: bash -el {0} + run: | + CUDA_HOME="${{ env.MAMBA_ROOT_PREFIX }}/envs/triton-env" + echo CUDA_HOME=$CUDA_HOME >> "$GITHUB_ENV" + echo CUDA_PATH=$CUDA_HOME >> "$GITHUB_ENV" + + - name: Update environment + if: matrix.runner[0] != 'self-hosted' + shell: bash + run: | + echo "BACKEND=CUDA" >> "${GITHUB_ENV}" + echo "ENABLE_TMA=0" >> "${GITHUB_ENV}" + echo "TRITON_DISABLE_LINE_INFO=1" >> "${GITHUB_ENV}" + + - name: Set reusable strings + # Turn repeated input strings (such as the build output directory) into step outputs. These step outputs can be used throughout the workflow file. 
+ id: strings + shell: bash + run: | + echo "build-output-dir=${{ github.workspace }}/build" >> "$GITHUB_OUTPUT" + - name: Clear cache + shell: bash run: | rm -rf ~/.triton - name: Update PATH + if: matrix.runner[0] == 'self-hosted' run: | echo "PATH=${HOME}/.local/bin:${PATH}" >> "${GITHUB_ENV}" - name: Check pre-commit + shell: bash run: | python3 -m pip install --upgrade pre-commit # TODO: ignore the first yapf failure until https://github.com/google/yapf/issues/1164 is fixed @@ -73,17 +156,29 @@ jobs: python3 -m pre_commit run --all-files --verbose - name: Install Triton - if: ${{ env.BACKEND == 'CUDA'}} + if: matrix.runner != 'windows-latest' run: | cd python python3 -m pip install --upgrade pip - python3 -m pip install cmake==3.24 ninja pytest-xdist + python3 -m pip install cmake==3.24 ninja pytest-xdist wheel sudo apt-get update -y sudo apt-get install -y ccache clang lld TRITON_BUILD_WITH_CLANG_LLD=true TRITON_BUILD_WITH_CCACHE=true python3 -m pip install --no-build-isolation -vvv '.[tests]' + if [ "${{ matrix.runner }}" = 'ubuntu-latest' ]; then + python3 setup.py bdist_wheel + fi + + - name: Install Triton (Windows) + if: matrix.runner == 'windows-latest' + run: | + cd python + python -m pip install --upgrade pip + python -m pip install cmake==3.24 ninja pytest-xdist wheel + python -m pip install --no-build-isolation -vvv . 
+ python setup.py bdist_wheel - name: Run lit tests - if: ${{ env.BACKEND == 'CUDA'}} + if: matrix.runner[0] == 'self-hosted' && env.BACKEND == 'CUDA' run: | python3 -m pip install lit cd python @@ -94,7 +189,7 @@ jobs: lit -v "${LIT_TEST_DIR}" - name: Run python tests on CUDA - if: ${{ env.BACKEND == 'CUDA' }} + if: ${{ (matrix.runner[0] == 'self-hosted') && env.BACKEND == 'CUDA' }} run: | cd python/test/unit python3 -m pytest -vvv -n 8 --ignore=runtime --ignore=operators --ignore=language/test_line_info.py --ignore=language/test_subprocess.py @@ -107,10 +202,12 @@ jobs: python3 -m pytest -vvv hopper/test_flashattention.py - name: Clear cache + shell: bash run: | rm -rf ~/.triton - name: Run interpreter tests + if: matrix.runner[0] == 'self-hosted' env: # TRITON_INTERPRET: "1" CUA_VISIBLE_DEVICES: "" @@ -119,11 +216,19 @@ jobs: python3 -m pytest -vvv -s operators/test_flash_attention.py - name: Run partial tests on CUDA - if: ${{ env.BACKEND == 'CUDA' }} + if: ${{ (matrix.runner[0] == 'self-hosted') && env.BACKEND == 'CUDA' }} run: | cd python/test/unit python3 -m pytest -vvv -n 8 operators + - name: Upload Build artifacts + if: matrix.runner[0] != 'self-hosted' + uses: actions/upload-artifact@v3 + with: + name: triton-dist ${{ matrix.runner }} python-${{ matrix.python-version }} cuda-${{ matrix.cuda-version }} + path: | + ${{ github.workspace }}/python/dist/ + - name: Create artifacts archive if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100' || matrix.runner[1] == 'H100')}} run: | @@ -132,13 +237,13 @@ jobs: - name: Upload artifacts archive if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100' || matrix.runner[1] == 'H100')}} - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: artifacts ${{ matrix.runner[1] }} path: ~/.triton/artifacts.tar.gz - name: Run CXX unittests - if: ${{ env.BACKEND == 'CUDA'}} + if: ${{(matrix.runner[0] == 
'self-hosted') && env.BACKEND == 'CUDA'}} run: | cd python cd "build/$(ls build | grep -i cmake)" @@ -146,6 +251,7 @@ jobs: Compare-artifacts: + if: ${{(github.repository == 'openai/triton')}} needs: Integration-Tests timeout-minutes: 5 diff --git a/.gitignore b/.gitignore index 7c5081621fc3..c89206f736e3 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ python/build/ python/triton.egg-info/ python/triton/_C/libtriton.pyd python/triton/_C/libtriton.so +python/triton/_C/triton.dll # Backends copied from submodules python/triton/backends/ diff --git a/CMakeLists.txt b/CMakeLists.txt index 309855e0c4f3..529e29a2bf3e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -30,8 +30,17 @@ set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends") # used conditionally in this file and by lit tests # Customized release build type with assertions: TritonRelBuildWithAsserts -set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g") -set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g") +if(NOT MSVC) + set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g") + set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g") +else() + set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "/Zi /Ob0 /Od /RTC1") + set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "/Zi /Ob0 /Od /RTC1") + set(CMAKE_EXE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL") + set(CMAKE_MODULE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL") + set(CMAKE_SHARED_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL") + set(CMAKE_STATIC_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL") +endif() # Default build type if(NOT CMAKE_BUILD_TYPE) @@ -45,7 +54,15 @@ endif() # Compiler flags include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) -set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17") +if(NOT MSVC) + if(NOT WIN32) + set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17") 
+ else() + set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17 -Wno-deprecated") + endif() +else() + set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS /wd4244 /wd4624 /wd4715 /wd4530") +endif() # Third-party include_directories(${PYBIND11_INCLUDE_DIR}) @@ -103,7 +120,11 @@ endfunction() # Disable warnings that show up in external code (gtest;pybind11) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default -fvisibility=hidden") +if(NOT MSVC) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default -fvisibility=hidden") +else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /WX-") +endif() include_directories(".") include_directories(${MLIR_INCLUDE_DIRS}) @@ -137,6 +158,8 @@ if(TRITON_BUILD_PYTHON_MODULE) if(PYTHON_INCLUDE_DIRS) include_directories(${PYTHON_INCLUDE_DIRS}) + message(STATUS "PYTHON_LIB_DIRS ${PYTHON_LIB_DIRS}") + link_directories(${PYTHON_LIB_DIRS}) else() find_package(Python3 REQUIRED COMPONENTS Development Interpreter) include_directories(${Python3_INCLUDE_DIRS}) @@ -203,6 +226,8 @@ if(TRITON_BUILD_PYTHON_MODULE) target_link_libraries(triton PUBLIC ${TRITON_LIBRARIES}) if(WIN32) target_link_libraries(triton PRIVATE ${CMAKE_DL_LIBS}) + set_target_properties(triton PROPERTIES SUFFIX ".pyd") + set_target_properties(triton PROPERTIES PREFIX "lib") else() target_link_libraries(triton PRIVATE z) endif() @@ -220,6 +245,11 @@ if(TRITON_BUILD_PYTHON_MODULE AND NOT WIN32) target_link_libraries(triton PRIVATE ${PYTHON_LDFLAGS}) endif() +if(WIN32) + option(CMAKE_USE_WIN32_THREADS_INIT "using WIN32 threads" ON) + option(gtest_disable_pthreads "Disable uses of pthreads in gtest." 
ON) +endif() + add_subdirectory(bin) add_subdirectory(test) add_subdirectory(unittest) diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt index 9acab3da1f9b..2f5880e0c5e7 100644 --- a/bin/CMakeLists.txt +++ b/bin/CMakeLists.txt @@ -71,6 +71,7 @@ mlir_check_all_link_libraries(triton-lsp) add_llvm_executable(triton-llvm-opt + PARTIAL_SOURCES_INTENDED triton-llvm-opt.cpp DEPENDS diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 96e91dbfb0c5..34886d3902b9 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1433,6 +1433,7 @@ MfmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef shape, return {32, parentShapePerCTA[1]}; } else { assert(0 && "DotOperandEncodingAttr opIdx must be 0 or 1"); + return {}; } } diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index b37d148a8238..27046de8b70e 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -634,8 +634,8 @@ void mlir::triton::asyncLaunchDots(scf::ForOp forOp) { lastOp = op; op = op->getBlock()->getParentOp(); } - return std::distance(lastOp->getBlock()->getParent()->begin(), - lastOp->getBlock()->getIterator()); + return (long)std::distance(lastOp->getBlock()->getParent()->begin(), + lastOp->getBlock()->getIterator()); }; /// XXX(Keren): Clean up the following duplicate code with checkDotOp /// dots to be pipelined diff --git a/python/setup.py b/python/setup.py index 9a8cb82455a4..fc1d8cedb419 100644 --- a/python/setup.py +++ b/python/setup.py @@ -169,7 +169,7 @@ def get_thirdparty_packages(): if p.syspath_var_name in os.environ: package_dir = os.environ[p.syspath_var_name] version_file_path = os.path.join(package_dir, "version.txt") - if p.syspath_var_name not in os.environ and\ + if p.syspath_var_name not in os.environ and p.url 
and\ (not os.path.exists(version_file_path) or Path(version_file_path).read_text() != p.url): try: shutil.rmtree(package_root_dir) @@ -182,6 +182,9 @@ def get_thirdparty_packages(): # write version url to package_dir with open(os.path.join(package_dir, "version.txt"), "w") as f: f.write(p.url) + elif p.syspath_var_name not in os.environ and not p.url: + raise RuntimeError( + f'{p.syspath_var_name} not set ! Please install {p.package} manually and set {p.syspath_var_name}.') if p.include_flag: thirdparty_cmake_args.append(f"-D{p.include_flag}={package_dir}/include") if p.lib_flag: @@ -195,13 +198,17 @@ def download_and_copy(src_path, variable, version, url_func): return base_dir = os.path.dirname(__file__) system = platform.system() - arch = {"x86_64": "64", "arm64": "aarch64"}[platform.machine()] - url = url_func(arch, version) + arch = {"x86_64": "64", "AMD64": "64", "arm64": "aarch64"}[platform.machine()] + supported = {"Linux": "linux", "Windows": "win"} + is_supported = system in supported + if is_supported: + url = url_func(supported[system], arch, version) tmp_path = os.path.join(triton_cache_path, "nvidia") # path to cache the download dst_path = os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", src_path) # final binary path src_path = os.path.join(tmp_path, src_path) + src_path += ".exe" if os.name == "nt" else "" download = not os.path.exists(src_path) - if os.path.exists(dst_path) and system == "Linux": + if os.path.exists(dst_path) and is_supported: curr_version = subprocess.check_output([dst_path, "--version"]).decode("utf-8").strip() curr_version = re.search(r"V([.|\d]+)", curr_version).group(1) download = download or curr_version != version @@ -300,6 +307,10 @@ def build_extension(self, ext): "-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON", "-DPYTHON_INCLUDE_DIRS=" + python_include_dir, "-DTRITON_CODEGEN_BACKENDS=" + ';'.join([b.name for b in backends]) ] + if platform.system() == "Windows": + installed_base = 
sysconfig.get_config_var('installed_base') + py_lib_dirs = os.getenv("PYTHON_LIB_DIRS", os.path.join(installed_base, "libs")) + cmake_args.append("-DPYTHON_LIB_DIRS=" + py_lib_dirs) if lit_dir is not None: cmake_args.append("-DLLVM_EXTERNAL_LIT=" + lit_dir) cmake_args.extend(thirdparty_cmake_args) @@ -316,9 +327,8 @@ def build_extension(self, ext): # cmake_args += ["-DTRITON_CODEGEN_BACKENDS=" + all_codegen_backends] if platform.system() == "Windows": + cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg] cmake_args += [f"-DCMAKE_RUNTIME_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}"] - if sys.maxsize > 2**32: - cmake_args += ["-A", "x64"] else: cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg] max_jobs = os.getenv("MAX_JOBS", str(2 * os.cpu_count())) @@ -360,28 +370,32 @@ def build_extension(self, ext): subprocess.check_call(["cmake", "--build", ".", "--target", "mlir-doc"], cwd=cmake_dir) -download_and_copy( - src_path="bin/ptxas", - variable="TRITON_PTXAS_PATH", - version="12.3.52", - url_func=lambda arch, version: - f"https://anaconda.org/nvidia/cuda-nvcc/12.3.52/download/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2", -) -download_and_copy( - src_path="bin/cuobjdump", - variable="TRITON_CUOBJDUMP_PATH", - version="12.3.52", - url_func=lambda arch, version: - f"https://anaconda.org/nvidia/cuda-cuobjdump/12.3.52/download/linux-{arch}/cuda-cuobjdump-{version}-0.tar.bz2", -) -download_and_copy( - src_path="bin/nvdisasm", - variable="TRITON_NVDISASM_PATH", - version="12.3.52", - url_func=lambda arch, version: - f"https://anaconda.org/nvidia/cuda-nvdisasm/12.3.52/download/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2", -) -backends = _copy_backends(["nvidia", "amd"]) +if platform.system() in ["Linux", "Windows"]: + download_and_copy( + src_path="bin/ptxas", + variable="TRITON_PTXAS_PATH", + version="12.3.52", + url_func=lambda system, arch, version: + f"https://anaconda.org/nvidia/cuda-nvcc/{version}/download/{system}-{arch}/cuda-nvcc-{version}-0.tar.bz2", + ) + download_and_copy( + 
src_path="bin/cuobjdump", + variable="TRITON_CUOBJDUMP_PATH", + version="12.3.52", + url_func=lambda system, arch, version: + f"https://anaconda.org/nvidia/cuda-cuobjdump/{version}/download/{system}-{arch}/cuda-cuobjdump-{version}-0.tar.bz2", + ) + download_and_copy( + src_path="bin/nvdisasm", + variable="TRITON_NVDISASM_PATH", + version="12.3.52", + url_func=lambda system, arch, version: + f"https://anaconda.org/nvidia/cuda-nvdisasm/{version}/download/{system}-{arch}/cuda-nvdisasm-{version}-0.tar.bz2", + ) +backends = ["nvidia", "amd"] +if os.name == "nt": + backends = ["nvidia"] +backends = _copy_backends(backends) def add_link_to_backends(): diff --git a/python/test/unit/language/test_annotations.py b/python/test/unit/language/test_annotations.py index 26bb40664904..0c1f065a10aa 100644 --- a/python/test/unit/language/test_annotations.py +++ b/python/test/unit/language/test_annotations.py @@ -1,12 +1,41 @@ from __future__ import annotations - import torch - import triton import triton.language as tl +import pytest + + +def annotated_function(return_type=None, **arg_types): + """A decorator to add annotations to a function.""" + + def decorator(func): + func.__annotations__ = {**arg_types, 'return': return_type} + return func + + return decorator + + +# Test integer annotations +@pytest.mark.parametrize(("signed", "width"), [ + (signed, width) for signed in [False, True]\ + for width in [8, 16, 32, 64] +] + [(False, 1)] + ) +def test_int_annotation(signed, width, device): + + @triton.jit + @annotated_function(X=torch.tensor, v=f"tl.{'' if signed else 'u'}int{width}") + def _kernel(X, v): + tl.store(X, v) + + h = _kernel[(1, )](torch.empty(1, device=device), 3) + pfx = 'si' if signed else 'ui' + assert f'%arg1: i{width}' in h.asm["ttir"] + assert f'arith.{pfx}tofp' in h.asm["ttir"] -def test_annotations(device): +# Test that unknown annotations do not emit an error +def test_unknown_annotation(device): @triton.jit def _kernel(X: torch.Tensor, N: int, BLOCK_SIZE: 
tl.constexpr): diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index cd589fa920f5..339dc25e617a 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -70,7 +70,7 @@ def test_nested1_change(): def write_and_load_module(code, num_extra_lines): - with tempfile.NamedTemporaryFile(mode='w+', suffix='.py') as f: + with tempfile.NamedTemporaryFile(mode='w+', suffix='.py', delete=False) as f: f.write(('# extra line\n' * num_extra_lines) + code) f.flush() spec = importlib.util.spec_from_file_location("module.name", f.name) diff --git a/python/test/unit/runtime/test_subproc.py b/python/test/unit/runtime/test_subproc.py index b9b9c7d3cab8..71ffa6d253e7 100644 --- a/python/test/unit/runtime/test_subproc.py +++ b/python/test/unit/runtime/test_subproc.py @@ -34,11 +34,15 @@ def kernel_sub(a, b, o, N: tl.constexpr): def test_compile_in_subproc() -> None: + import os major, minor = torch.cuda.get_device_capability(0) cc = major * 10 + minor config = triton.compiler.AttrsDescriptor(tuple(range(4)), ()) - multiprocessing.set_start_method('fork') + if os.name == "nt": + multiprocessing.set_start_method('spawn') + else: + multiprocessing.set_start_method('fork') proc = multiprocessing.Process(target=compile_fn, args=(config, cc)) proc.start() proc.join() @@ -64,7 +68,7 @@ def test_compile_in_forked_subproc() -> None: capability = major * 10 + minor config = triton.compiler.AttrsDescriptor(tuple(range(1)), ()) - assert multiprocessing.get_start_method() == 'fork' + assert multiprocessing.get_start_method() in ['fork', 'spawn'] proc = multiprocessing.Process(target=compile_fn_dot, args=(config, capability)) proc.start() proc.join() diff --git a/python/triton/backends/compiler.py b/python/triton/backends/compiler.py index 0655b3fa5ade..12a9ba0064d1 100644 --- a/python/triton/backends/compiler.py +++ b/python/triton/backends/compiler.py @@ -12,13 +12,17 @@ def __init__(self, target: tuple) -> None: 
@staticmethod def _path_to_binary(binary: str): + binary += ".exe" if os.name == "nt" else "" base_dir = os.path.join(os.path.dirname(__file__), os.pardir) paths = [ os.environ.get(f"TRITON_{binary.upper()}_PATH", ""), os.path.join(base_dir, "third_party", "cuda", "bin", binary), ] for p in paths: - bin = p.split(" ")[0] + if os.name != "nt": + bin = p.split(" ")[0] + else: + bin = p if os.path.exists(bin) and os.path.isfile(bin): result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT) if result is not None: diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 10911ec1911b..39df99d5b198 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -1177,6 +1177,7 @@ def str_to_ty(name): "i16": language.int16, "i32": language.int32, "i64": language.int64, + "u1": language.int1, "u8": language.uint8, "u16": language.uint16, "u32": language.uint32, diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 52080bf29540..39ad248a13fc 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -159,7 +159,8 @@ def triton_key(): contents += [hashlib.sha1(f.read()).hexdigest()] # backend libtriton_hash = hashlib.sha1() - with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f: + ext = "so" if os.name != "nt" else "pyd" + with open(os.path.join(TRITON_PATH, "_C", "libtriton." 
+ ext), "rb") as f: while True: chunk = f.read(1024**2) if not chunk: diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py index d7baeb2868b0..9726836a2939 100644 --- a/python/triton/runtime/build.py +++ b/python/triton/runtime/build.py @@ -18,6 +18,26 @@ def quiet(): sys.stdout, sys.stderr = old_stdout, old_stderr +def _cc_cmd(cc, src, out, include_dirs, library_dirs, libraries): + if cc in ["cl", "clang-cl"]: + cc_cmd = [cc, src, "/nologo", "/O2", "/LD"] + cc_cmd += [f"/I{dir}" for dir in include_dirs] + cc_cmd += ["/link"] + cc_cmd += [f"/LIBPATH:{dir}" for dir in library_dirs] + cc_cmd += [f'{lib}.lib' for lib in libraries] + cc_cmd += [f"/OUT:{out}"] + else: + cc_cmd = [cc, src, "-O3", "-shared", "-fPIC"] + cc_cmd += [f'-l{lib}' for lib in libraries] + cc_cmd += [f"-L{dir}" for dir in library_dirs] + cc_cmd += [f"-I{dir}" for dir in include_dirs] + cc_cmd += ["-o", out] + + if os.name == "nt": cc_cmd.pop(cc_cmd.index("-fPIC")) + + return cc_cmd + + def _build(name, src, srcdir, library_dirs, include_dirs, libraries): suffix = sysconfig.get_config_var('EXT_SUFFIX') so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix)) @@ -41,10 +61,7 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries): scheme = 'posix_prefix' py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] include_dirs = include_dirs + [srcdir, py_include_dir] - cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-o", so] - cc_cmd += [f'-l{lib}' for lib in libraries] - cc_cmd += [f"-L{dir}" for dir in library_dirs] - cc_cmd += [f"-I{dir}" for dir in include_dirs] + cc_cmd = _cc_cmd(cc, src, so, include_dirs, library_dirs, libraries) ret = subprocess.check_call(cc_cmd) if ret == 0: return so @@ -58,7 +75,7 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries): language='c', sources=[src], include_dirs=include_dirs, - extra_compile_args=extra_compile_args + ['-O3'], + extra_compile_args=extra_compile_args + ['-O3' if 
"-O3" in cc_cmd else "/O2"], extra_link_args=extra_link_args, library_dirs=library_dirs, libraries=libraries, diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 6ec60a358cce..286d5c509d0a 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -124,12 +124,13 @@ def signature_key(self): annotation = self.param.annotation if "Tensor" in annotation: return self.value.dtype - elif annotation == "bool": - return "i1" - elif annotation == "float": - return "fp32" - else: - return JITFunction._key_of(self.value) + for ty1, ty2 in [("uint", 'u'), ("int", 'i')]: + width = annotation[annotation.find(ty1) + len(ty1):] + if width and ty1 in annotation: + return f"{ty2}{width}" + if annotation == "bool": + return "u1" + return JITFunction._key_of(self.value) def specialization_key(self): assert not self.param.do_not_specialize @@ -375,7 +376,7 @@ def run(self, *args, grid, warmup, **kwargs): # Build kernel signature -- doesn't include constexpr arguments. 
signature = { - arg.param.num: self._type_of(self._key_of(arg.value)) + arg.param.num: self._type_of(arg.signature_key()) for arg in args if not arg.param.is_constexpr } diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index bfed650edaf7..4497da9db52f 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -14,13 +14,17 @@ from pathlib import Path def _path_to_binary(binary: str): + binary += ".exe" if os.name == "nt" else "" paths = [ os.environ.get(f"TRITON_{binary.upper()}_PATH", ""), os.path.join(os.path.dirname(__file__), "bin", binary), ] for p in paths: - bin = p.split(" ")[0] + if os.name != "nt": + bin = p.split(" ")[0] + else: + bin = p if os.path.exists(bin) and os.path.isfile(bin): result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT) if result is not None: @@ -229,16 +233,19 @@ def make_cubin(src, metadata, opt, capability): fsrc.flush() fbin = fsrc.name + '.o' - line_info = '' if os.environ.get('TRITON_DISABLE_LINE_INFO') else ' -lineinfo' - fmad = '' if opt.enable_fp_fusion else ' --fmad=false' - suffix = 'a ' if capability == 90 else ' ' + line_info = '' if os.environ.get('TRITON_DISABLE_LINE_INFO') else '-lineinfo' + fmad = '' if opt.enable_fp_fusion else '--fmad=false' + suffix = 'a' if capability == 90 else '' + cmd = [ptxas] + cmd += [line_info] if line_info != '' else [] + cmd += [fmad] if fmad != '' else [] + cmd += ['-v'] if os.environ.get("DISABLE_PTXAS_OPT", "0") == "1": - cmd = f'{ptxas}{line_info}{fmad} -v --opt-level 0 --gpu-name=sm_{capability}{suffix}{fsrc.name} -o {fbin} 2> {flog.name}' - else: - cmd = f'{ptxas}{line_info}{fmad} -v --gpu-name=sm_{capability}{suffix}{fsrc.name} -o {fbin} 2> {flog.name}' + cmd += ["-opt-level", "0"] + cmd += [f'--gpu-name=sm_{capability}{suffix}', fsrc.name, '-o', fbin] try: - subprocess.run(cmd, shell=True, check=True) + subprocess.run(cmd, check=True, stderr=flog) except 
subprocess.CalledProcessError as e: with open(flog.name) as log_file: log = log_file.read() @@ -249,16 +256,17 @@ def make_cubin(src, metadata, opt, capability): f'Please run `ptxas {fsrc.name}` to confirm that this is a bug in `ptxas`\n{log}') else: raise RuntimeError(f'`ptxas` failed with error code {e.returncode}: \n{log}') - finally: - if os.path.exists(fsrc.name): - os.remove(fsrc.name) - if os.path.exists(flog.name): - os.remove(flog.name) with open(fbin, 'rb') as f: cubin = f.read() if os.path.exists(fbin): os.remove(fbin) + + if os.path.exists(fsrc.name): + os.remove(fsrc.name) + if os.path.exists(flog.name): + os.remove(flog.name) + return cubin def add_stages(self, stages, options): diff --git a/third_party/nvidia/backend/driver.c b/third_party/nvidia/backend/driver.c index 037cfca1d0e4..64eaf0771862 100644 --- a/third_party/nvidia/backend/driver.c +++ b/third_party/nvidia/backend/driver.c @@ -1,5 +1,10 @@ #include "cuda.h" +#ifndef _WIN32 #include +#else +#define WIN32_LEAN_AND_MEAN +#include +#endif #include #define PY_SSIZE_T_CLEAN #include @@ -135,6 +140,7 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) { typedef CUresult (*cuOccupancyMaxActiveClusters_t)( int *numClusters, CUfunction func, const CUlaunchConfig *config); +#ifndef _WIN32 #define defineGetFunctionHandle(name, symbolName) \ static symbolName##_t name() { \ /* Open the shared library */ \ @@ -156,6 +162,27 @@ typedef CUresult (*cuOccupancyMaxActiveClusters_t)( } \ return funcHandle; \ } +#else +#define defineGetFunctionHandle(name, symbolName) \ + static symbolName##_t name() { \ + /* Open the shared library */ \ + HMODULE handle = LoadLibraryA("nvcuda.dll"); \ + if (!handle) { \ + PyErr_SetString(PyExc_RuntimeError, "Failed to open nvcuda.dll"); \ + return NULL; \ + } \ + symbolName##_t funcHandle = \ + (symbolName##_t)GetProcAddress((HMODULE)handle, #symbolName); \ + /* Check for errors */ \ + long err = GetLastError(); \ + if (err) { \ + 
PyErr_SetString(PyExc_RuntimeError, \ + "Failed to retrieve " #symbolName " from nvcuda.dll"); \ + return NULL; \ + } \ + return funcHandle; \ + } +#endif defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle, cuOccupancyMaxActiveClusters); diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py index 6e69d0d0e4db..e694469dd04d 100644 --- a/third_party/nvidia/backend/driver.py +++ b/third_party/nvidia/backend/driver.py @@ -1,6 +1,7 @@ import functools import os import hashlib +import sysconfig import subprocess import tempfile from pathlib import Path @@ -13,12 +14,20 @@ libdevice_dir = os.path.join(dirname, "lib") libraries = ['cuda'] +if os.name == "nt": + include_dir += [os.path.join(os.environ.get("CUDA_PATH"), "include")] + @functools.lru_cache() def libcuda_dirs(): env_libcuda_path = os.getenv("TRITON_LIBCUDA_PATH") if env_libcuda_path: return [env_libcuda_path] + if os.name == "nt": + installed_base = sysconfig.get_config_var('installed_base') + dirs = [os.path.join(os.environ.get("CUDA_PATH"), "lib", "x64")] + dirs += [os.getenv("PYTHON_LIB_DIRS", os.path.join(installed_base, "libs"))] + return dirs libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode() # each line looks like the following: @@ -47,7 +56,8 @@ def library_dirs(): def compile_module_from_src(src, name): key = hashlib.md5(src.encode("utf-8")).hexdigest() cache = get_cache_manager(key) - cache_path = cache.get_file(f"{name}.so") + so_name = f'{name}.{"so" if os.name != "nt" else "pyd"}' + cache_path = cache.get_file(so_name) if cache_path is None: with tempfile.TemporaryDirectory() as tmpdir: src_path = os.path.join(tmpdir, "main.c") @@ -55,7 +65,7 @@ def compile_module_from_src(src, name): f.write(src) so = _build(name, src_path, tmpdir, library_dirs(), include_dir, libraries) with open(so, "rb") as f: - cache_path = cache.put(f.read(), f"{name}.so", binary=True) + cache_path = cache.put(f.read(), so_name, binary=True) import importlib.util 
spec = importlib.util.spec_from_file_location(name, cache_path) mod = importlib.util.module_from_spec(spec) @@ -96,6 +106,9 @@ def ty_to_cpp(ty): "i16": "int16_t", "i32": "int32_t", "i64": "int64_t", + "u1": "uint32_t", + "u8": "uint8_t", + "u16": "uint16_t", "u32": "uint32_t", "u64": "uint64_t", "fp16": "float", @@ -115,18 +128,7 @@ def make_launcher(constants, signature, ids): def _extracted_type(ty): if ty[0] == '*': return "PyObject*" - return { - 'i1': 'int32_t', - 'i32': 'int32_t', - 'i64': 'int64_t', - 'u32': 'uint32_t', - 'u64': 'uint64_t', - 'fp16': 'float', - 'bf16': 'float', - 'fp32': 'float', - 'f32': 'float', - 'fp64': 'double', - }[ty] + return ty_to_cpp(ty) def format_of(ty): return { @@ -134,10 +136,14 @@ def format_of(ty): "float": "f", "double": "d", "long": "l", - "uint32_t": "I", + "int8_t": "b", + "int16_t": "h", "int32_t": "i", + "int64_t": "l", + "uint8_t": "B", + "uint16_t": "H", + "uint32_t": "I", "uint64_t": "K", - "int64_t": "L", }[ty] format = "iiiiiiiiiKKOOO" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) @@ -151,7 +157,12 @@ def format_of(ty): #include \"cuda.h\" #include #include +#ifndef _WIN32 #include +#else +#define WIN32_LEAN_AND_MEAN +#include +#endif static inline void gpuAssert(CUresult code, const char *file, int line) {{ @@ -174,6 +185,7 @@ def format_of(ty): typedef CUresult (*cuLaunchKernelEx_t)(const CUlaunchConfig* config, CUfunction f, void** kernelParams, void** extra); +#ifndef _WIN32 static cuLaunchKernelEx_t getLaunchKernelExHandle() {{ // Open the shared library void* handle = dlopen("libcuda.so", RTLD_LAZY); @@ -192,6 +204,25 @@ def format_of(ty): }} return cuLaunchKernelExHandle; }} +#else +static cuLaunchKernelEx_t getLaunchKernelExHandle() {{ + // Open the shared library + HMODULE handle = LoadLibraryA("nvcuda.dll"); + if (!handle) {{ + PyErr_SetString(PyExc_RuntimeError, "Failed to open nvcuda.dll"); + return NULL; + }} + cuLaunchKernelEx_t cuLaunchKernelExHandle = + 
(cuLaunchKernelEx_t)GetProcAddress((HMODULE)handle, "cuLaunchKernelEx"); + // GetProcAddress returns NULL on failure; GetLastError() may hold a + // stale value even on success, so test the returned handle itself. + if (!cuLaunchKernelExHandle) {{ + PyErr_SetString(PyExc_RuntimeError, "Failed to retrieve cuLaunchKernelEx from nvcuda.dll"); + return NULL; + }} + return cuLaunchKernelExHandle; +}} +#endif static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ void *params[] = {{ {', '.join(f"&arg{i}" for i in params)} }};