fix loraga
greycooker committed Dec 16, 2024
1 parent 4d7603f commit 86c5b33
Showing 3 changed files with 28 additions and 16 deletions.
11 changes: 6 additions & 5 deletions paddlenlp/peft/lora/lora_model.py
@@ -165,6 +165,7 @@ def __init__(self, model, lora_config: LoRAConfig) -> None:
         self.forward = self.model.forward
         if lora_config.loraga:
             self.loraga_init_dict = {}
+            self.reinit_base_model = False

         logger.info("Mark only lora and trainable_module as trainable.")
         self.mark_only_lora_as_trainable()
@@ -349,11 +350,11 @@ def process_split_and_assign(name, concat_tensor, axis, init_dict, state_dict):
                 )

                 base_name = name.replace("lora_A", "weight")
-
-                # Reinit base model
-                offset = init_loraA.cuda() @ init_loraB.cuda()
-                ori_weight = model_state_dict[base_name]
-                model_state_dict[base_name].set_value(ori_weight - self.lora_config.scaling * offset)
+                if not self.reinit_base_model:
+                    # Reinit base model
+                    offset = init_loraA.cuda() @ init_loraB.cuda()
+                    ori_weight = model_state_dict[base_name]
+                    model_state_dict[base_name].set_value(ori_weight - self.lora_config.scaling * offset)
         del model_state_dict
         gc.collect()
         self.model.set_state_dict(state_dict)
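
Context for the hunk above: the base weights are shifted by the LoRA-GA initialization offset exactly once. When a trained LoRA-GA checkpoint is restored, reinit_base_model is set to True first, so set_state_dict skips the subtraction instead of shifting the already-offset weights a second time. A minimal sketch of that guard (the standalone helper and its name are illustrative, not part of the codebase):

import paddle

def offset_base_weight_once(base_weight: paddle.Tensor, init_loraA, init_loraB, scaling, reinit_base_model):
    # Subtract the LoRA-GA init offset from a base weight at most once.
    # reinit_base_model is True when resuming from a checkpoint, in which case
    # the base weight was already shifted before training and is left untouched.
    if not reinit_base_model:
        offset = init_loraA @ init_loraB  # (in_features, rank) @ (rank, out_features)
        base_weight.set_value(base_weight - scaling * offset)
    return base_weight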
29 changes: 18 additions & 11 deletions paddlenlp/peft/lora/loraga_utils.py
@@ -74,10 +74,9 @@ def estimate_gradient(self, model: PretrainedModel):
         ):
             for batch in dataloader:
                 iters += 1
                 batch = {k: paddle.to_tensor(v) for k, v in batch.items()}
-
-                # Pipeline parallel not supported currently
-                loss, logits = model(**batch)
+                with paddle.amp.auto_cast(enable=True, custom_black_list=self.args.amp_custom_black_list):
+                    loss, logits = model(**batch)
                 loss.backward()

                 if iters == self.loraga_init_iters:
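
For reference, the AMP change runs only the forward pass under autocast; the backward pass stays outside the context. A self-contained sketch of that loop (the free function, its arguments, and the explicit break are illustrative simplifications of the estimator method shown above):

import paddle

def estimate_gradients(model, dataloader, init_iters, amp_custom_black_list=None):
    # Accumulate gradients over a fixed number of batches, wrapping the
    # forward pass in paddle.amp.auto_cast as in the change above.
    iters = 0
    for batch in dataloader:
        iters += 1
        batch = {k: paddle.to_tensor(v) for k, v in batch.items()}
        with paddle.amp.auto_cast(enable=True, custom_black_list=amp_custom_black_list):
            loss, _ = model(**batch)
        loss.backward()  # backward runs outside the autocast context
        if iters == init_iters:
            break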
@@ -160,13 +159,21 @@ def get_module_gradient(
     rank_suffix = "_" + str(local_rank)
     local_grad_name = ".".join(grad_name.split(".")[1:]) + ".weight" + rank_suffix
     gradient = gradient_dict.pop(local_grad_name).cuda()

+    is_fleet_init = True
+    try:
+        hcg = fleet.get_hybrid_communicate_group()
+        model_parallel_group = hcg.get_model_parallel_group()
+        sharding_parallel_group = hcg.get_sharding_parallel_group()
+        data_parallel_group = hcg.get_data_parallel_group()
+    except:
+        is_fleet_init = False
+
     if tp_degree > 1:
         # remove prefix and suffix in name
         model_split_key = local_grad_name.split(base_model_prefix)[-1].rsplit(rank_suffix, 1)[0]
         if model_split_key in base_model_split_mappings:
             merge_func = base_model_split_mappings[model_split_key]
-            hcg = fleet.get_hybrid_communicate_group()
-            model_parallel_group = hcg.get_model_parallel_group()
             output_tensors = []
             dist.all_gather(output_tensors, gradient, group=model_parallel_group)
@@ -175,18 +182,17 @@

     # sharding
     if sharding_degree > 1:
-        hcg = fleet.get_hybrid_communicate_group()
-        sharding_parallel_group = hcg.get_sharding_parallel_group()
         if sharding_parallel_group.nranks > 1:
             dist.all_reduce(gradient, op=dist.ReduceOp.SUM, group=sharding_parallel_group)
             gradient /= sharding_parallel_group.nranks

     # dp
     if dp_degree > 1:
-        hcg = fleet.get_hybrid_communicate_group()
-        data_parallel_group = hcg.get_data_parallel_group()
         if data_parallel_group.nranks > 1:
-            dist.all_reduce(gradient, op=dist.ReduceOp.SUM, group=data_parallel_group)
+            if is_fleet_init:
+                dist.all_reduce(gradient, op=dist.ReduceOp.SUM, group=data_parallel_group)
+            else:
+                dist.all_reduce(gradient, op=dist.ReduceOp.SUM)
             gradient /= data_parallel_group.nranks
     return gradient
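
The restructuring above fetches the hybrid communication groups once, up front, and tolerates fleet never having been initialized (single-card or plain data-parallel launches). A condensed sketch of the fallback pattern for the data-parallel reduction (the standalone function and the division by dp_degree are our simplification; the original divides by the group's rank count):

import paddle.distributed as dist
from paddle.distributed import fleet

def reduce_gradient_dp(gradient, dp_degree):
    # Average a gradient across data-parallel ranks, falling back to the
    # default global group when fleet has not been initialized.
    is_fleet_init = True
    try:
        hcg = fleet.get_hybrid_communicate_group()
        data_parallel_group = hcg.get_data_parallel_group()
    except Exception:
        is_fleet_init = False

    if dp_degree > 1:
        if is_fleet_init and data_parallel_group.nranks > 1:
            dist.all_reduce(gradient, op=dist.ReduceOp.SUM, group=data_parallel_group)
        else:
            dist.all_reduce(gradient, op=dist.ReduceOp.SUM)
        gradient /= dp_degree
    return gradient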


@@ -250,6 +256,7 @@ def loraga_svd_reinit(
             lora_split_mapping,
             **kwargs,
         )
+    model.reinit_base_model = True
     model.loraga_init_dict = loraga_init_dict



4 changes: 4 additions & 0 deletions paddlenlp/trainer/trainer.py
@@ -623,6 +623,8 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
                     self.model,
                     resume_from_checkpoint,
                 )
+                if isinstance(self.model, LoRAModel) and self.model.lora_config.loraga:
+                    self.model.reinit_base_model = True
                 logger.info(f"Loading model from {resume_from_checkpoint} using unified checkpoint.")
                 self.runtime_timer.stop()
                 return
@@ -635,6 +637,8 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
                 or isinstance(self.model, ReFTModel)
             ):
                 self._load_from_peft_checkpoint(resume_from_checkpoint)
+                if isinstance(self.model, LoRAModel) and self.model.lora_config.loraga:
+                    self.model.reinit_base_model = True
                 self.runtime_timer.stop()
                 return
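
Both resume paths set the same flag, so the intent can be read as one small helper (hypothetical name; the real code inlines the check in _load_from_checkpoint):

from paddlenlp.peft import LoRAModel

def mark_loraga_checkpoint_loaded(model):
    # After restoring weights from a checkpoint, tell a LoRA-GA model not to
    # re-apply the base-weight offset on the next set_state_dict call.
    if isinstance(model, LoRAModel) and model.lora_config.loraga:
        model.reinit_base_model = True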

