[tests] make cuda-only cases in TestModelAndLayerStatus device-agnostic #2026

Merged · 3 commits · Aug 21, 2024 · Changes from all commits

58 changes: 30 additions & 28 deletions tests/test_tuners_utils.py
@@ -42,9 +42,9 @@
     check_target_module_exists,
     inspect_matched_modules,
 )
-from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND
+from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND, infer_device

-from .testing_utils import require_bitsandbytes, require_torch_gpu
+from .testing_utils import require_bitsandbytes, require_non_cpu, require_torch_gpu


 # Implements tests for regex matching logic common for all BaseTuner subclasses, and
@@ -395,6 +395,8 @@ class TestModelAndLayerStatus:

     """

+    torch_device = infer_device()
+
     @pytest.fixture
     def small_model(self):
         class SmallModel(nn.Module):
@@ -591,27 +593,27 @@ def test_devices_all_cpu_large(self, large_model):
         ]
         assert result == expected

-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
-    def test_devices_all_cuda_large(self, large_model):
-        large_model.to("cuda")
+    @require_non_cpu
+    def test_devices_all_gpu_large(self, large_model):
+        large_model.to(self.torch_device)
         layer_status = large_model.get_layer_status()
         result = [status.devices for status in layer_status]
         expected = [
-            {"default": ["cuda"], "other": ["cuda"]},
-            {"default": ["cuda"]},
-            {"other": ["cuda"]},
-            {"default": ["cuda"]},
+            {"default": [self.torch_device], "other": [self.torch_device]},
+            {"default": [self.torch_device]},
+            {"other": [self.torch_device]},
+            {"default": [self.torch_device]},
         ]
         assert result == expected

-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
-    def test_devices_cpu_and_cuda_large(self, large_model):
-        # move the embedding layer to CUDA
-        large_model.model.lin0.lora_A["default"] = large_model.model.lin0.lora_A["default"].to("cuda")
+    @require_non_cpu
+    def test_devices_cpu_and_gpu_large(self, large_model):
+        # move the embedding layer to GPU
+        large_model.model.lin0.lora_A["default"] = large_model.model.lin0.lora_A["default"].to(self.torch_device)
         layer_status = large_model.get_layer_status()
         result = [status.devices for status in layer_status]
         expected = [
-            {"default": ["cpu", "cuda"], "other": ["cpu"]},
+            {"default": ["cpu", self.torch_device], "other": ["cpu"]},
             {"default": ["cpu"]},
             {"other": ["cpu"]},
             {"default": ["cpu"]},
@@ -806,18 +808,18 @@ def test_model_devices_all_cpu_large(self, large_model):
         model_status = large_model.get_model_status()
         assert model_status.devices == {"default": ["cpu"], "other": ["cpu"]}

-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
-    def test_model_devices_all_cuda_large(self, large_model):
-        large_model.to("cuda")
+    @require_non_cpu
+    def test_model_devices_all_gpu_large(self, large_model):
+        large_model.to(self.torch_device)
         model_status = large_model.get_model_status()
-        assert model_status.devices == {"default": ["cuda"], "other": ["cuda"]}
+        assert model_status.devices == {"default": [self.torch_device], "other": [self.torch_device]}

-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
-    def test_model_devices_cpu_and_cuda_large(self, large_model):
-        # move the embedding layer to CUDA
-        large_model.model.lin0.lora_A["default"] = large_model.model.lin0.lora_A["default"].to("cuda")
+    @require_non_cpu
+    def test_model_devices_cpu_and_gpu_large(self, large_model):
+        # move the embedding layer to GPU
+        large_model.model.lin0.lora_A["default"] = large_model.model.lin0.lora_A["default"].to(self.torch_device)
         model_status = large_model.get_model_status()
-        assert model_status.devices == {"default": ["cpu", "cuda"], "other": ["cpu"]}
+        assert model_status.devices == {"default": ["cpu", self.torch_device], "other": ["cpu"]}

     def test_loha_model(self):
         # ensure that this also works with non-LoRA, it's not necessary to test all tuners
@@ -858,7 +860,7 @@ def __init__(self):
         assert layer_status0.available_adapters == ["default"]
         assert layer_status0.devices == {"default": ["cpu"]}

-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
+    @require_non_cpu
     def test_vera_model(self):
         # let's also test VeRA because it uses BufferDict
         class SmallModel(nn.Module):
@@ -871,8 +873,8 @@ def __init__(self):
         config = VeraConfig(target_modules=["lin0", "lin1"], init_weights=False)
         model = get_peft_model(base_model, config)

-        # move the buffer dict to CUDA
-        model.lin0.vera_A["default"] = model.lin0.vera_A["default"].to("cuda")
+        # move the buffer dict to GPU
+        model.lin0.vera_A["default"] = model.lin0.vera_A["default"].to(self.torch_device)

         model_status = model.get_model_status()
         layer_status = model.get_layer_status()
@@ -888,7 +890,7 @@ def __init__(self):
         assert model_status.merged_adapters == []
         assert model_status.requires_grad == {"default": True}
         assert model_status.available_adapters == ["default"]
-        assert model_status.devices == {"default": ["cpu", "cuda"]}
+        assert model_status.devices == {"default": ["cpu", self.torch_device]}

         layer_status0 = layer_status[0]
         assert len(layer_status) == 2
@@ -899,7 +901,7 @@ def __init__(self):
         assert layer_status0.merged_adapters == []
         assert layer_status0.requires_grad == {"default": True}
         assert layer_status0.available_adapters == ["default"]
-        assert layer_status0.devices == {"default": ["cpu", "cuda"]}
+        assert layer_status0.devices == {"default": ["cpu", self.torch_device]}

     ###################
     # non-PEFT models #
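
For orientation, below is a minimal, self-contained sketch of the device-agnostic pattern the diff above adopts: read the device once via infer_device() and build the expected values from it instead of hard-coding "cuda". The TinyModel class and the LoraConfig used here are illustrative assumptions, not code taken from this PR.

```python
# Minimal sketch of the device-agnostic test pattern (illustrative only).
# Assumption: peft.utils.infer_device() returns a device string such as
# "cuda", "xpu", "mps", or "cpu", depending on the available backend.
import torch.nn as nn

from peft import LoraConfig, get_peft_model
from peft.utils import infer_device


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(10, 10)


torch_device = infer_device()

model = get_peft_model(TinyModel(), LoraConfig(target_modules=["lin0"]))
model.to(torch_device)

# The expected device list is built from torch_device, so the same check
# holds on CUDA, XPU, MPS, or a plain CPU run.
layer_status = model.get_layer_status()
assert [status.devices for status in layer_status] == [{"default": [torch_device]}]
```

On a CPU-only machine the same check passes with torch_device == "cpu", which is the point of the change.
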
12 changes: 12 additions & 0 deletions tests/testing_utils.py
@@ -17,6 +17,7 @@
 import numpy as np
 import pytest
 import torch
+from accelerate.test_utils.testing import get_backend

 from peft.import_utils import (
     is_aqlm_available,
@@ -28,6 +29,17 @@
 )


+torch_device, device_count, memory_allocated_func = get_backend()
+
+
+def require_non_cpu(test_case):
+    """
+    Decorator marking a test that requires a hardware accelerator backend. These tests are skipped when no hardware
+    accelerator is available.
+    """
+    return unittest.skipUnless(torch_device != "cpu", "test requires a hardware accelerator")(test_case)
+
+
 def require_torch_gpu(test_case):
     """
     Decorator marking a test that requires a GPU. Will be skipped when no GPU is available.
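
To make the skip logic concrete, here is a hedged sketch of what the new helper boils down to. The _detect_device function is a hypothetical stand-in for the first value returned by accelerate's get_backend() (which, as the unpacking in the diff shows, also yields a device count and a memory-query function); only the final decorator mirrors the code added here.

```python
# Simplified sketch of the accelerator detection behind require_non_cpu.
# _detect_device is a hypothetical stand-in for
# accelerate.test_utils.testing.get_backend()[0]; it is not the real API.
import unittest

import torch


def _detect_device() -> str:
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


torch_device = _detect_device()


def require_non_cpu(test_case):
    """Skip the test unless a hardware accelerator backend is available."""
    return unittest.skipUnless(torch_device != "cpu", "test requires a hardware accelerator")(test_case)
```

A test decorated with @require_non_cpu therefore runs on any detected accelerator and is skipped automatically on CPU-only machines, which is how the reworked tests above opt out without referencing CUDA directly.
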