Fix bug w.r.t. using gradient_checkpoint without tuning embed_tokens and Fix typo of template in langchain_qa.py #175

Merged 3 commits on Aug 28, 2023
scripts/langchain/langchain_qa.py (2 changes: 1 addition & 1 deletion)
@@ -37,7 +37,7 @@
     "现在还有一些文字,(如果有需要)你可以根据它们完善现有的回答。"
     "\n\n"
     "{context_str}\n"
-    "\\nn"
+    "\n\n"
     "请根据新的文段,进一步完善你的回答。"
     " [/INST]"
 )
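For context on the template fix above (an illustration, not part of the PR): in Python source, "\\nn" is a literal backslash followed by the letters "nn" rather than a line break, so the rendered prompt contained the stray text "\nn" instead of a blank line. A minimal standalone check:

s_typo  = "\\nn"    # backslash + "n" + "n": three characters, no newline
s_fixed = "\n\n"    # two real newline characters, i.e. a blank line in the prompt
print(len(s_typo), repr(s_typo))    # 3 '\\nn'
print(len(s_fixed), repr(s_fixed))  # 2 '\n\n'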
scripts/training/run_clm_pt_with_peft.py (11 changes: 11 additions & 0 deletions)
@@ -575,6 +575,17 @@ def group_texts(examples):
         lora_dropout=lora_dropout,
         modules_to_save=modules_to_save)
     model = get_peft_model(model, peft_config)
+
+    if training_args.gradient_checkpointing and \
+        (not model.modules_to_save or 'embed_tokens' not in model.modules_to_save):
+        # enable requires_grad to avoid exception during backward pass when using gradient_checkpoint without tuning embed.
+        if hasattr(model.base_model, "enable_input_require_grads"):
+            model.base_model.enable_input_require_grads()
+        elif hasattr(model.base_model, "get_input_embeddings"):
+            def make_inputs_require_grad(_module, _input, _output):
+                _output.requires_grad_(True)
+            model.base_model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
     model.print_trainable_parameters()
     old_state_dict = model.state_dict
     model.state_dict = (
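For context on the block added above: with gradient checkpointing enabled and embed_tokens left untuned, the hidden states entering the first checkpointed layer come from a frozen embedding and do not require grad, so the backward pass can fail (typically a RuntimeError about tensors not requiring grad under reentrant checkpointing). The guard calls enable_input_require_grads(), or installs an equivalent forward hook, so the embedding output requires grad. A minimal sketch of the same idea outside this repo; the model name, LoRA settings, and dummy batch are assumptions for illustration only:

# Hedged sketch -- "gpt2" and the LoRA config below are illustrative assumptions,
# not this project's actual model or hyperparameters.
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("gpt2")
model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM", target_modules=["c_attn"]))
model.train()
model.gradient_checkpointing_enable()   # embeddings stay frozen (not in modules_to_save)

# The fix from this PR: make the frozen input embeddings produce outputs that
# require grad, so the checkpointed blocks have something to backpropagate into.
model.base_model.enable_input_require_grads()

batch = torch.tensor([[1, 2, 3, 4]])
out = model(input_ids=batch, labels=batch)
out.loss.backward()   # succeeds; without the call above, reentrant checkpointing
                      # may raise "element 0 of tensors does not require grad ..."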
scripts/training/run_clm_sft_with_peft.py (10 changes: 10 additions & 0 deletions)
@@ -382,6 +382,16 @@ def main():
         modules_to_save=modules_to_save)
     model = get_peft_model(model, peft_config)
+
+    if training_args.gradient_checkpointing and \
+        (not model.modules_to_save or 'embed_tokens' not in model.modules_to_save):
+        # enable requires_grad to avoid exception during backward pass when using gradient_checkpoint without tuning embed.
+        if hasattr(model.base_model, "enable_input_require_grads"):
+            model.base_model.enable_input_require_grads()
+        elif hasattr(model.base_model, "get_input_embeddings"):
+            def make_inputs_require_grad(_module, _input, _output):
+                _output.requires_grad_(True)
+            model.base_model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
     #model.base_model.tie_weights()
     model.print_trainable_parameters()
     logger.info(f"model.modules_to_save: {model.modules_to_save}")
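The elif branch in both training scripts is a fallback for models that lack enable_input_require_grads: it registers a forward hook on the input embeddings that flips requires_grad on the embedding output. A tiny self-contained sketch of that pattern (the nn.Embedding here stands in for the model's embed_tokens and is an assumption, not project code):

import torch
import torch.nn as nn

embed = nn.Embedding(100, 16)
embed.weight.requires_grad_(False)       # embeddings frozen, as when embed_tokens is not tuned

def make_inputs_require_grad(_module, _input, _output):
    # Mark the embedding output as requiring grad so downstream (checkpointed)
    # layers receive a differentiable input.
    _output.requires_grad_(True)

embed.register_forward_hook(make_inputs_require_grad)

out = embed(torch.tensor([[1, 2, 3]]))
print(out.requires_grad)                 # True, even though the embedding weights stay frozen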