-
Notifications
You must be signed in to change notification settings - Fork 2
/
utils.py
executable file
·187 lines (163 loc) · 8.35 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import functools
import torch
import functools
import warnings
import logging
import os
import shutil
import json
import torch
from torch.distributed.fsdp.wrap import (
transformer_auto_wrap_policy,
)
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardedStateDictConfig, MixedPrecision
from torch.distributed.fsdp.api import StateDictType
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
import transformers
import torch.distributed._shard.checkpoint as dist_cp
warnings.filterwarnings("ignore", message="TypedStorage is deprecated.*")
def get_gpu_memory_usage(rank):
return {
'total': round(torch.cuda.get_device_properties(rank).total_memory / (1024**3), 2),
'max': round(torch.cuda.max_memory_allocated(rank) / (1024**3), 2),
'reserved': round(torch.cuda.memory_reserved(rank) / (1024**3), 2),
'allocated': round(torch.cuda.memory_allocated(rank) / (1024**3), 2)
}
def get_fsdp_wrapped_empty_model(model_config, wrapped_cls, hack=False):
with init_empty_weights():
if hack:
model = transformers.AutoModelForCausalLM.from_config(model_config).bfloat16()
else:
model = transformers.AutoModelForCausalLM.from_config(model_config)
make_nonpersistent_buffer_persistent(model)
model.reset_parameters = lambda: None
wrapped_cls.reset_parameters = lambda x: None
torch.nn.Embedding.reset_parameters = lambda x: None
my_auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls=set([wrapped_cls, torch.nn.Embedding]),
)
bf16 = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.bfloat16)
model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy, device_id=torch.cuda.current_device(), mixed_precision=bf16)
return model
def load_fsdp_ckpt_with_accelerate(fsdp_path, model_config, hf_dummy_path, wrapped_class):
with init_empty_weights():
model_empty = transformers.AutoModelForCausalLM.from_config(model_config)
model_empty = model_empty.bfloat16()
model = load_checkpoint_and_dispatch(model_empty, hf_dummy_path, device_map="auto", no_split_module_classes=[wrapped_class, torch.nn.Embedding])
current_vocab_size_based_on_weight = model.get_output_embeddings().weight.shape[0]
if model.config.vocab_size != current_vocab_size_based_on_weight:
model.resize_token_embeddings(model.config.vocab_size)
print("device map used by accelerate:\n", json.dumps(model.hf_device_map, indent=4))
model = load_state_dict_fsdp(model, fsdp_path, offload_to_cpu=False, no_dist=True)
return model
def load_state_dict_fsdp(model, load_path, offload_to_cpu=True, no_dist=False):
if no_dist:
checkpoint = model.state_dict()
dist_cp.load_state_dict(
state_dict=checkpoint,
storage_reader=dist_cp.FileSystemReader(load_path),
no_dist=no_dist
)
model.load_state_dict(checkpoint)
else:
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT, ShardedStateDictConfig(offload_to_cpu=offload_to_cpu)):
checkpoint = model.state_dict()
dist_cp.load_state_dict(
state_dict=checkpoint,
storage_reader=dist_cp.FileSystemReader(load_path),
no_dist=no_dist
)
model.load_state_dict(checkpoint)
return model
def save_state_dict_fsdp(model, save_path, offload_to_cpu=True):
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT, ShardedStateDictConfig(offload_to_cpu=offload_to_cpu)):
checkpoint = model.state_dict()
dist_cp.save_state_dict(
state_dict=checkpoint,
storage_writer=dist_cp.FileSystemWriter(save_path),
)
return model
def save_model_to_fsdp_format(model, save_path):
with LogLevelContext(logging.ERROR):
warnings.filterwarnings("ignore", message="TypedStorage is deprecated.*")
dist_cp.save_state_dict(
state_dict=model.state_dict(),
storage_writer=dist_cp.FileSystemWriter(save_path),
no_dist=True
)
def save_opt_or_scheduler_fsdp(model, opt, save_path, rank, offload_to_cpu=True):
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT, ShardedStateDictConfig(offload_to_cpu=offload_to_cpu)):
torch.save(opt.state_dict(), os.path.join(save_path, f"shard{rank}.pt"))
def load_opt_or_scheduler_fsdp(model, opt, save_path, rank, offload_to_cpu=True):
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT, ShardedStateDictConfig(offload_to_cpu=offload_to_cpu)):
state_dict = torch.load(os.path.join(save_path, f"shard{rank}.pt"))
opt.load_state_dict(state_dict)
def save_model_opt_scheduler_states_fsdp(model, opt, scheduler, step_count, checkpoint_path, rank, dont_save_opt=False, keep_old_checkpoints=False):
path = os.path.join(checkpoint_path, str(step_count), "model")
save_state_dict_fsdp(model, path)
if not dont_save_opt:
path = os.path.join(checkpoint_path, str(step_count), "opt")
os.makedirs(path, exist_ok=True)
save_opt_or_scheduler_fsdp(model, opt, path, rank)
path = os.path.join(checkpoint_path, str(step_count), "scheduler")
os.makedirs(path, exist_ok=True)
save_opt_or_scheduler_fsdp(model, scheduler, path, rank)
if rank == 0:
if not keep_old_checkpoints:
remove_all_folders_except_the_latest_step_count(checkpoint_path, step_count)
def load_model_opt_scheduler_states_fsdp(model, opt, scheduler, checkpoint_path):
last_checkpoint_path, start_step_count = get_last_checkpoint_path_and_last_step_count(checkpoint_path)
path = os.path.join(last_checkpoint_path, "model")
load_state_dict_fsdp(model, path)
rank = torch.cuda.current_device()
path = os.path.join(last_checkpoint_path, "opt")
load_opt_or_scheduler_fsdp(model, opt, path, rank)
path = os.path.join(last_checkpoint_path, "scheduler")
load_opt_or_scheduler_fsdp(model, scheduler, path, rank)
return model, opt, scheduler, start_step_count
def remove_all_folders_except_the_latest_step_count(checkpoint_path, cur_step_count):
if os.path.exists(checkpoint_path):
for folder in os.listdir(checkpoint_path):
if int(folder) != cur_step_count:
shutil.rmtree(os.path.join(checkpoint_path, folder))
def get_last_checkpoint_path_and_last_step_count(checkpoint_path):
if os.path.exists(checkpoint_path):
last_step_count = max(int(x) for x in os.listdir(checkpoint_path))
last_checkpoint_path = os.path.join(checkpoint_path, str(last_step_count))
return last_checkpoint_path, last_step_count+1
return None, None
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
dist.init_process_group("nccl", rank=rank, world_size=world_size)
def cleanup():
dist.destroy_process_group()
def get_all_existing_loggers():
return logging.Logger.manager.loggerDict.values()
def add_padding_token(tokenizer):
print("attempt to add padding token if no padding token exists")
print("Special tokens before adding padding token: ", tokenizer.special_tokens_map)
if not tokenizer.pad_token:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
print("Special tokens after adding padding token: ", tokenizer.special_tokens_map)
return tokenizer
def make_nonpersistent_buffer_persistent(model):
for name, module in model.named_modules():
if hasattr(module, "_non_persistent_buffers_set") and len(module._non_persistent_buffers_set) > 0:
print(f"moving non-persistent buffers to persistent buffers for module {name}")
module._persistent_buffers_set = module._non_persistent_buffers_set
module._non_persistent_buffers_set = set()
class LogLevelContext:
def __init__(self, level):
self.level = level
self.original_levels = {}
def __enter__(self):
for logger in get_all_existing_loggers():
if isinstance(logger, logging.Logger):
self.original_levels[logger] = logger.level
logger.setLevel(self.level)
def __exit__(self, exc_type, exc_value, traceback):
for logger, original_level in self.original_levels.items():
logger.setLevel(original_level)