update peft.
shibing624 committed May 9, 2023
1 parent d8c81ae commit 968a3f8
Showing 3 changed files with 21 additions and 15 deletions.
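All three diffs apply the same fix: self.args.peft_type is upper-cased once and logged before dispatch, so lowercase or mixed-case values such as 'lora' select the intended PEFT config.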
12 changes: 7 additions & 5 deletions textgen/bloom/bloom_model.py
@@ -206,8 +206,10 @@ def train_model(
 
         # setup peft
         if self.args.use_peft:
+            peft_type = self.args.peft_type.upper()
+            logger.info(f"Using PEFT type: {peft_type}")
             # add peft config
-            if self.args.peft_type == 'LORA':
+            if peft_type == 'LORA':
                 peft_config = LoraConfig(
                     task_type=TaskType.CAUSAL_LM,
                     inference_mode=False,
@@ -217,7 +219,7 @@ def train_model(
                     target_modules=self.args.lora_target_modules,
                     bias=self.args.lora_bias,
                 )
-            elif self.args.peft_type == 'ADALORA':
+            elif peft_type == 'ADALORA':
                 from peft import AdaLoraConfig
 
                 peft_config = AdaLoraConfig(
@@ -234,22 +236,22 @@ def train_model(
                     task_type=TaskType.CAUSAL_LM,
                     inference_mode=False,
                 )
-            elif self.args.peft_type == 'PROMPT_TUNING':
+            elif peft_type == 'PROMPT_TUNING':
                 from peft import PromptTuningConfig
 
                 peft_config = PromptTuningConfig(
                     task_type=TaskType.CAUSAL_LM,
                     num_virtual_tokens=self.args.num_virtual_tokens,
                 )
-            elif self.args.peft_type == 'P_TUNING':
+            elif peft_type == 'P_TUNING':
                 from peft import PromptEncoderConfig
 
                 peft_config = PromptEncoderConfig(
                     task_type=TaskType.CAUSAL_LM,
                     num_virtual_tokens=self.args.num_virtual_tokens,
                     encoder_hidden_size=self.args.prompt_encoder_hidden_size
                 )
-            elif self.args.peft_type == 'PREFIX_TUNING':
+            elif peft_type == 'PREFIX_TUNING':
                 from peft import PrefixTuningConfig
 
                 peft_config = PrefixTuningConfig(
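The pattern repeated in all three files is small enough to isolate. Below is a minimal, self-contained sketch of the case-insensitive dispatch this commit introduces; the function name build_peft_config and the hyperparameter defaults are illustrative assumptions, not part of textgen's API.

from peft import LoraConfig, PromptTuningConfig, PrefixTuningConfig, TaskType

def build_peft_config(peft_type: str, num_virtual_tokens: int = 20):
    # Normalize once, as the commit does, so the branches below
    # match regardless of how the caller cased the value.
    peft_type = peft_type.upper()
    if peft_type == 'LORA':
        return LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False,
                          r=8, lora_alpha=32, lora_dropout=0.05)
    elif peft_type == 'PROMPT_TUNING':
        return PromptTuningConfig(task_type=TaskType.CAUSAL_LM,
                                  num_virtual_tokens=num_virtual_tokens)
    elif peft_type == 'PREFIX_TUNING':
        return PrefixTuningConfig(task_type=TaskType.CAUSAL_LM,
                                  num_virtual_tokens=num_virtual_tokens)
    raise ValueError(f"Unsupported peft_type: {peft_type}")

# 'lora', 'Lora', and 'LORA' now resolve to the same branch:
assert type(build_peft_config('lora')) is type(build_peft_config('LORA'))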
12 changes: 7 additions & 5 deletions textgen/chatglm/chatglm_model.py
@@ -215,8 +215,10 @@ def train_model(
 
         # setup peft
         if self.args.use_peft:
+            peft_type = self.args.peft_type.upper()
+            logger.info(f"Using PEFT type: {peft_type}")
             # add peft config
-            if self.args.peft_type == 'LORA':
+            if peft_type == 'LORA':
                 peft_config = LoraConfig(
                     task_type=TaskType.CAUSAL_LM,
                     inference_mode=False,
@@ -226,7 +228,7 @@ def train_model(
                     target_modules=self.args.lora_target_modules,
                     bias=self.args.lora_bias,
                 )
-            elif self.args.peft_type == 'ADALORA':
+            elif peft_type == 'ADALORA':
                 from peft import AdaLoraConfig
 
                 peft_config = AdaLoraConfig(
@@ -243,22 +245,22 @@ def train_model(
                     task_type=TaskType.CAUSAL_LM,
                     inference_mode=False,
                 )
-            elif self.args.peft_type == 'PROMPT_TUNING':
+            elif peft_type == 'PROMPT_TUNING':
                 from peft import PromptTuningConfig
 
                 peft_config = PromptTuningConfig(
                     task_type=TaskType.CAUSAL_LM,
                     num_virtual_tokens=self.args.num_virtual_tokens,
                 )
-            elif self.args.peft_type == 'P_TUNING':
+            elif peft_type == 'P_TUNING':
                 from peft import PromptEncoderConfig
 
                 peft_config = PromptEncoderConfig(
                     task_type=TaskType.CAUSAL_LM,
                     num_virtual_tokens=self.args.num_virtual_tokens,
                     encoder_hidden_size=self.args.prompt_encoder_hidden_size
                 )
-            elif self.args.peft_type == 'PREFIX_TUNING':
+            elif peft_type == 'PREFIX_TUNING':
                 from peft import PrefixTuningConfig
 
                 peft_config = PrefixTuningConfig(
12 changes: 7 additions & 5 deletions textgen/llama/llama_model.py
@@ -212,8 +212,10 @@ def train_model(
 
         # setup peft
         if self.args.use_peft:
+            peft_type = self.args.peft_type.upper()
+            logger.info(f"Using PEFT type: {peft_type}")
             # add peft config
-            if self.args.peft_type == 'LORA':
+            if peft_type == 'LORA':
                 peft_config = LoraConfig(
                     task_type=TaskType.CAUSAL_LM,
                     inference_mode=False,
@@ -223,7 +225,7 @@ def train_model(
                     target_modules=self.args.lora_target_modules,
                     bias=self.args.lora_bias,
                 )
-            elif self.args.peft_type == 'ADALORA':
+            elif peft_type == 'ADALORA':
                 from peft import AdaLoraConfig
 
                 peft_config = AdaLoraConfig(
@@ -240,22 +242,22 @@ def train_model(
                     task_type=TaskType.CAUSAL_LM,
                     inference_mode=False,
                 )
-            elif self.args.peft_type == 'PROMPT_TUNING':
+            elif peft_type == 'PROMPT_TUNING':
                 from peft import PromptTuningConfig
 
                 peft_config = PromptTuningConfig(
                     task_type=TaskType.CAUSAL_LM,
                     num_virtual_tokens=self.args.num_virtual_tokens,
                 )
-            elif self.args.peft_type == 'P_TUNING':
+            elif peft_type == 'P_TUNING':
                 from peft import PromptEncoderConfig
 
                 peft_config = PromptEncoderConfig(
                     task_type=TaskType.CAUSAL_LM,
                     num_virtual_tokens=self.args.num_virtual_tokens,
                     encoder_hidden_size=self.args.prompt_encoder_hidden_size
                 )
-            elif self.args.peft_type == 'PREFIX_TUNING':
+            elif peft_type == 'PREFIX_TUNING':
                 from peft import PrefixTuningConfig
 
                 peft_config = PrefixTuningConfig(
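For context, here is a hedged sketch of how a config built this way is typically attached to a model through peft's public API. The checkpoint name and LoRA hyperparameters are example assumptions; this commit does not pin any of them down.

from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

# Example checkpoint (assumption); any causal LM works the same way.
model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["query_key_value"],  # BLOOM's fused attention projection
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()  # only the LoRA adapters are trainable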
