-
Notifications
You must be signed in to change notification settings - Fork 27
/
memory_management.py
67 lines (48 loc) · 1.7 KB
/
memory_management.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
from contextlib import contextmanager
# Global switch read by load_models_to_gpu: when True, its unload-to-CPU step
# is skipped, so previously loaded models stay on the GPU.
high_vram = False

# Device handles used by the loaders below.
cpu = torch.device('cpu')
gpu = torch.device('cuda')

# Touch the CUDA device once at import time by moving a tiny float32 tensor
# onto it, then release cached blocks. NOTE(review): this presumably fails at
# import on a machine without a usable CUDA device.
torch.zeros(1, 1).to(device=gpu, dtype=torch.float32)
torch.cuda.empty_cache()

# Models currently tracked as resident on the GPU; rebound by the loader
# functions in this module.
models_in_gpu = []
@contextmanager
def movable_bnb_model(m):
    """Temporarily remove ``m.quantization_method`` for the duration of a ``with`` block.

    If ``m`` has a ``quantization_method`` attribute, it is deleted on entry
    and restored on exit, including when the block raises. Objects without
    the attribute are left untouched and nothing is added to them.

    This module wraps every ``m.to(device)`` call in this manager. The name
    suggests the attribute otherwise stops a bitsandbytes-quantized model from
    being moved by ``.to`` -- TODO confirm against the library version in use.

    The saved value lives in this generator's frame, not on ``m``. As a
    result, nested or re-entrant use on the same object only restores the
    attribute when the block that actually removed it exits.

    Args:
        m: Object (typically a model) whose ``quantization_method`` attribute
            should be hidden while the block runs.

    Yields:
        None.
    """
    had_attr = hasattr(m, 'quantization_method')
    saved = None
    if had_attr:
        saved = m.quantization_method
        del m.quantization_method
    try:
        yield None
    finally:
        # Only the invocation that removed the attribute puts it back.
        if had_attr:
            m.quantization_method = saved
def load_models_to_gpu(models):
    """Make ``models`` the set of models tracked on the GPU.

    Models already listed in the global ``models_in_gpu`` are left where
    they are, and requested models not yet listed are moved to ``gpu``.
    Unless ``high_vram`` is true, every tracked model that is not requested
    is first moved to ``cpu`` and dropped from the tracking list. Each move
    is wrapped in ``movable_bnb_model(m)``. Afterwards ``models_in_gpu`` is
    updated and ``torch.cuda.empty_cache()`` is called.

    Args:
        models: A single model, or a list/tuple of models. Models must be
            hashable, since they are deduplicated.

    Returns:
        None.
    """
    global models_in_gpu

    if not isinstance(models, (tuple, list)):
        models = [models]

    # Deduplicate while keeping the caller's order (dict preserves insertion
    # order, set does not), so the move order and printed log are stable.
    # Always a list, so the concatenation below also works for tuple input
    # (list + tuple would raise TypeError).
    requested = list(dict.fromkeys(models))

    models_to_remain = [m for m in requested if m in models_in_gpu]
    models_to_load = [m for m in requested if m not in models_in_gpu]
    models_to_unload = [m for m in dict.fromkeys(models_in_gpu) if m not in models_to_remain]

    if not high_vram:
        for m in models_to_unload:
            with movable_bnb_model(m):
                m.to(cpu)
            print('Unload to CPU:', m.__class__.__name__)
        models_in_gpu = models_to_remain

    for m in models_to_load:
        with movable_bnb_model(m):
            m.to(gpu)
        print('Load to GPU:', m.__class__.__name__)

    models_in_gpu = list(dict.fromkeys(models_in_gpu + requested))
    torch.cuda.empty_cache()
def unload_all_models(extra_models=None):
    """Request that every tracked model, plus ``extra_models``, leave the GPU.

    ``extra_models`` are first added to the global ``models_in_gpu`` list, so
    they count as tracked. Then ``load_models_to_gpu([])`` is called with an
    empty working set. That function skips its move-to-CPU step when
    ``high_vram`` is true.

    Args:
        extra_models: None, a single model, or a list/tuple of models that
            should also be unloaded even if they are not currently tracked.

    Returns:
        Whatever ``load_models_to_gpu([])`` returns (None in this module).
    """
    global models_in_gpu

    if extra_models is None:
        extra_models = []
    if not isinstance(extra_models, (tuple, list)):
        extra_models = [extra_models]

    # list(...) so a tuple argument does not make `list + tuple` raise
    # TypeError; dict.fromkeys dedupes while keeping a stable order.
    models_in_gpu = list(dict.fromkeys(models_in_gpu + list(extra_models)))
    return load_models_to_gpu([])