FIX: Multitask prompt tuning with other tuning init (#1144)
Resolves #1082.

Also adds tests for prompt_tuning_init != RANDOM.

---------

Co-authored-by: Mayank Mishra <32954280+mayank31398@users.noreply.github.com>
BenjaminBossan and mayank31398 authored Feb 19, 2024
1 parent 8a0dce2 commit 043d5c0
Showing 3 changed files with 68 additions and 5 deletions.
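
For context, the source-task init options (AVERAGE_SOURCE_TASKS, EXACT_SOURCE_TASK, ONLY_SOURCE_SHARED) load previously trained prompt weights from prompt_tuning_init_state_dict_path, which is the code path this commit repairs. Below is a minimal sketch of that two-stage flow, mirroring the imports, config values, and save convention used in the tests in this commit; the base model name and output directory are placeholders and not part of the commit:

import os

from transformers import AutoModelForCausalLM

from peft.mapping import get_peft_model
from peft.tuners.multitask_prompt_tuning import MultitaskPromptTuningConfig, MultitaskPromptTuningInit
from peft.utils.other import WEIGHTS_NAME

# Stage 1: train shared prompts on several source tasks (training loop omitted), then save them.
base = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder base model
source_config = MultitaskPromptTuningConfig(task_type="CAUSAL_LM", num_virtual_tokens=50, num_tasks=3)
source_model = get_peft_model(base, source_config)
source_model.save_pretrained("mpt_source", safe_serialization=False)  # the init path uses torch.load

# Stage 2: initialize target-task prompts from the saved source weights instead of randomly.
target_config = MultitaskPromptTuningConfig(
    task_type="CAUSAL_LM",
    num_virtual_tokens=50,
    num_tasks=1,
    prompt_tuning_init=MultitaskPromptTuningInit.AVERAGE_SOURCE_TASKS,
    prompt_tuning_init_state_dict_path=os.path.join("mpt_source", WEIGHTS_NAME),
)
target_model = get_peft_model(AutoModelForCausalLM.from_pretrained("gpt2"), target_config)

# Generation with multitask prompt tuning requires task_ids, e.g.:
# target_model.generate(input_ids=input_ids, attention_mask=attention_mask, task_ids=torch.LongTensor([0]))
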
src/peft/peft_model.py: 2 changes (1 addition, 1 deletion)
@@ -1153,7 +1153,7 @@ def generate(self, *args, **kwargs):
             self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
             return outputs

-    def prepare_inputs_for_generation(self, *args, task_ids: torch.Tensor = None, **kwargs):
+    def prepare_inputs_for_generation(self, *args, task_ids: Optional[torch.Tensor] = None, **kwargs):
         peft_config = self.active_peft_config
         model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)
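
An aside on the annotation change above: the parameter's default is None, which a bare torch.Tensor annotation does not admit, so Optional is the accurate type. A tiny illustration with a hypothetical function, not code from the repository:

from typing import Optional

import torch


def wants_task_ids(task_ids: Optional[torch.Tensor] = None) -> bool:
    # Optional[torch.Tensor] tells type checkers that both the None default and a real tensor are valid.
    return task_ids is not None


print(wants_task_ids())                        # False
print(wants_task_ids(torch.LongTensor([0])))   # True
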
src/peft/tuners/multitask_prompt_tuning/model.py: 3 changes (2 additions, 1 deletion)
@@ -66,9 +66,10 @@ def __init__(self, config: MultitaskPromptTuningConfig, word_embeddings):
                     "init method"
                 )

+            # TODO: There should be an option for safetensors
             state_dict: dict = torch.load(
                 config.prompt_tuning_init_state_dict_path,
-                map_location=word_embeddings.device,
+                map_location=word_embeddings.weight.device,
             )

         if config.prompt_tuning_init in [
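
The map_location change above matters because word_embeddings is presumably a plain nn.Embedding here, and nn.Module exposes no .device attribute; the device has to be read from the weight tensor. A quick standalone check, not code from the repository:

from torch import nn

emb = nn.Embedding(10, 4)      # stand-in for the model's input embeddings
print(hasattr(emb, "device"))  # False: plain modules carry no .device attribute
print(emb.weight.device)       # cpu: the device belongs to the parameter tensor

# Hence mapping a checkpoint onto the embedding's device goes through the weight:
# torch.load(path, map_location=emb.weight.device)
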
tests/test_multitask_prompt_tuning.py: 68 changes (65 additions, 3 deletions)
@@ -17,13 +17,15 @@
 import tempfile
 from unittest import TestCase

+import pytest
 import torch
 from parameterized import parameterized
 from torch.testing import assert_close

 from peft.mapping import get_peft_model
 from peft.peft_model import PeftModel
-from peft.tuners.multitask_prompt_tuning import MultitaskPromptTuningConfig
-from peft.utils.other import prepare_model_for_int8_training
+from peft.tuners.multitask_prompt_tuning import MultitaskPromptTuningConfig, MultitaskPromptTuningInit
+from peft.utils.other import WEIGHTS_NAME, prepare_model_for_int8_training
 from peft.utils.save_and_load import get_peft_model_state_dict
 from tests.testing_common import PeftCommonTester

@@ -73,7 +75,9 @@ def _create_multitask_prompt_tuning_config(cls) -> MultitaskPromptTuningConfig:
             task_type="CAUSAL_LM",
             num_virtual_tokens=50,
             num_tasks=3,
-            prompt_tuning_init_text="classify the following into either positive or negative, or entailment, neutral or contradiction:",
+            prompt_tuning_init_text=(
+                "classify the following into either positive or negative, or entailment, neutral or contradiction:"
+            ),
         )

     def test_prepare_for_training(self) -> None:
@@ -240,3 +244,61 @@ def test_bf16_inference(self) -> None:
         mpt = get_peft_model(original, self._create_multitask_prompt_tuning_config())
         mpt = mpt.to(self.torch_device)
         _ = mpt.generate(input_ids=input_ids, task_ids=task_ids)
+
+    def test_generate_text_with_random_init(self) -> None:
+        model = LlamaForCausalLM(self._create_test_llama_config())
+
+        config = self._create_multitask_prompt_tuning_config()
+        config.prompt_tuning_init = MultitaskPromptTuningInit.RANDOM
+
+        model = get_peft_model(model, config)
+        model = model.to(self.torch_device)
+
+        input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device)
+        attention_mask = torch.LongTensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device)
+        task_ids = torch.LongTensor([0]).to(self.torch_device)
+
+        # check if `generate` works
+        _ = model.generate(input_ids=input_ids, attention_mask=attention_mask, task_ids=task_ids)
+
+        with pytest.raises(ValueError):
+            # check if `generate` raises an error if task_ids are not passed
+            _ = model.generate(input_ids, attention_mask=attention_mask)
+
+    @parameterized.expand(
+        [
+            MultitaskPromptTuningInit.AVERAGE_SOURCE_TASKS,
+            MultitaskPromptTuningInit.EXACT_SOURCE_TASK,
+            MultitaskPromptTuningInit.ONLY_SOURCE_SHARED,
+        ],
+    )
+    def test_generate_text_with_other_init(self, prompt_tuning_init) -> None:
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            model = LlamaForCausalLM(self._create_test_llama_config())
+            model = get_peft_model(model, self._create_multitask_prompt_tuning_config())
+            model.save_pretrained(tmp_dirname, safe_serialization=False)  # bc torch.load is used
+
+            config = MultitaskPromptTuningConfig(
+                task_type="CAUSAL_LM",
+                num_virtual_tokens=50,
+                num_tasks=1,
+                prompt_tuning_init_text=(
+                    "classify the following into either positive or negative, or entailment, neutral or contradiction:"
+                ),
+                prompt_tuning_init=prompt_tuning_init,
+                prompt_tuning_init_state_dict_path=os.path.join(tmp_dirname, WEIGHTS_NAME),
+            )
+            model = LlamaForCausalLM(self._create_test_llama_config())
+            model = get_peft_model(model, config)
+            model = model.to(self.torch_device)
+
+            input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device)
+            attention_mask = torch.LongTensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device)
+            task_ids = torch.LongTensor([0]).to(self.torch_device)
+
+            # check if `generate` works
+            _ = model.generate(input_ids=input_ids, attention_mask=attention_mask, task_ids=task_ids)
+
+            with pytest.raises(ValueError):
+                # check if `generate` raises an error if task_ids are not passed
+                _ = model.generate(input_ids, attention_mask=attention_mask)
