From dd3f0441eeee0bf6623d9d4270302ae45fc91bee Mon Sep 17 00:00:00 2001
From: Maximilian Maag
Date: Fri, 18 Oct 2024 15:26:11 +0200
Subject: [PATCH] fix(installer): pytorch and ROCm versions are incompatible

Each version of torch is only available for specific versions of CUDA and
ROCm. The Invoke installer and Dockerfile try to install torch 2.4.1 with
ROCm 5.6 support, which does not exist. As a result, the installation falls
back to the default CUDA version, so AMD GPUs aren't detected.

This commit fixes that by bumping the ROCm version to 6.1, as suggested by
the PyTorch documentation. [1]

The specified CUDA version of 12.4 is still correct according to [1], so it
does not need to be changed.

Closes #7006
Closes #7146

[1]: https://pytorch.org/get-started/previous-versions/#v241
---
 docker/Dockerfile          | 2 +-
 installer/lib/installer.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/docker/Dockerfile b/docker/Dockerfile
index 9dac0dcefe0..2e9c22e5b2e 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -38,7 +38,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \
     if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then \
         extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cpu"; \
     elif [ "$GPU_DRIVER" = "rocm" ]; then \
-        extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/rocm5.6"; \
+        extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/rocm6.1"; \
     else \
         extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cu124"; \
     fi &&\
diff --git a/installer/lib/installer.py b/installer/lib/installer.py
index bff97eca934..df90ba8c616 100644
--- a/installer/lib/installer.py
+++ b/installer/lib/installer.py
@@ -410,7 +410,7 @@ def get_torch_source() -> Tuple[str | None, str | None]:
     optional_modules: str | None = None
     if OS == "Linux":
         if device == GpuType.ROCM:
-            url = "https://download.pytorch.org/whl/rocm5.6"
+            url = "https://download.pytorch.org/whl/rocm6.1"
         elif device == GpuType.CPU:
             url = "https://download.pytorch.org/whl/cpu"
         elif device == GpuType.CUDA: