diff --git a/llm/config/llama/vera_argument.json b/llm/config/llama/vera_argument.json
new file mode 100644
index 000000000000..157c47dde505
--- /dev/null
+++ b/llm/config/llama/vera_argument.json
@@ -0,0 +1,32 @@
+{
+    "model_name_or_path": "facebook/llama-7b",
+    "dataset_name_or_path": "./data",
+    "output_dir": "./checkpoints/vera_ckpts",
+    "per_device_train_batch_size": 4,
+    "gradient_accumulation_steps": 4,
+    "per_device_eval_batch_size": 8,
+    "eval_accumulation_steps": 16,
+    "num_train_epochs": 1,
+    "learning_rate": 3e-04,
+    "warmup_steps": 30,
+    "logging_steps": 1,
+    "evaluation_strategy": "epoch",
+    "save_strategy": "epoch",
+    "src_length": 1024,
+    "max_length": 2048,
+    "fp16": true,
+    "fp16_opt_level": "O2",
+    "do_train": true,
+    "do_eval": true,
+    "disable_tqdm": true,
+    "load_best_model_at_end": true,
+    "eval_with_do_generation": false,
+    "metric_for_best_model": "accuracy",
+    "recompute": true,
+    "save_total_limit": 10,
+    "tensor_parallel_degree": 1,
+    "pipeline_parallel_degree": 1,
+    "vera": true,
+    "zero_padding": false,
+    "use_flash_attention": false
+}
\ No newline at end of file
diff --git a/llm/run_finetune.py b/llm/run_finetune.py
index de31240d2ae3..47e323296f98 100644
--- a/llm/run_finetune.py
+++ b/llm/run_finetune.py
@@ -14,7 +14,6 @@
 import json
 import os
 import sys
-import inspect
 from functools import partial

 import paddle
@@ -42,7 +41,14 @@
     load_dataset,
 )
 from paddlenlp.metrics import BLEU, Rouge1, Rouge2, RougeL
-from paddlenlp.peft import LoRAConfig, LoRAModel, PrefixConfig, PrefixModelForCausalLM
+from paddlenlp.peft import (
+    LoRAConfig,
+    LoRAModel,
+    PrefixConfig,
+    PrefixModelForCausalLM,
+    VeRAConfig,
+    VeRAModel,
+)
 from paddlenlp.trainer import PdArgumentParser, get_last_checkpoint
 from paddlenlp.trainer.trainer_callback import TrainerState
 from paddlenlp.transformers import (
@@ -51,9 +57,9 @@
     AutoModelForCausalLMPipe,
     AutoTokenizer,
     Llama3Tokenizer,
-    LlamaTokenizer,
     LlamaForCausalLM,
     LlamaForCausalLMPipe,
+    LlamaTokenizer,
 )
 from paddlenlp.transformers.configuration_utils import LlmMetaConfig
 from paddlenlp.utils.log import logger
@@ -82,7 +88,6 @@ def main():
         raise ValueError(
             "--do_train, --do_ptq, --do_gptq and --do_qat cannot work at the same time. Please choose only one at a time"
         )
-
     # Setup GPU & distributed training
     paddle.set_device(training_args.device)
@@ -167,9 +172,7 @@ def main():
             model = model_class.from_config(model_config, dtype=dtype)

         if model_args.flash_mask and (not data_args.zero_padding or not model.config.use_flash_attention):
-            logger.warning(
-                "`flash_mask` must use with zero padding and flash attention."
-            )
+            logger.warning("`flash_mask` must be used with zero padding and flash attention.")
             data_args.zero_padding = True
             model.config.use_flash_attention = True
@@ -345,12 +348,16 @@ def neft_post_hook(module, input, output):
                 "Zero Padding data stream is only implemented for LLaMA, Bloom, ChatGLM, QWen and Mistral so far."
            )
         train_ds = (
-            train_ds.map(partial(trans_func, is_test=False, zero_padding=data_args.zero_padding, flash_mask=model_args.flash_mask))
+            train_ds.map(
+                partial(trans_func, is_test=False, zero_padding=data_args.zero_padding, flash_mask=model_args.flash_mask)
+            )
             if train_ds is not None
             else None
         )
         ptq_ds = (
-            ptq_ds.map(partial(trans_func, is_test=False, zero_padding=data_args.zero_padding, flash_mask=model_args.flash_mask))
+            ptq_ds.map(
+                partial(trans_func, is_test=False, zero_padding=data_args.zero_padding, flash_mask=model_args.flash_mask)
+            )
             if ptq_ds is not None
             else None
         )
@@ -361,7 +368,14 @@ def neft_post_hook(module, input, output):
         )
         eval_zero_padding = False
         dev_ds = (
-            dev_ds.map(partial(trans_func, is_test=data_args.eval_with_do_generation, zero_padding=eval_zero_padding, flash_mask=model_args.flash_mask))
+            dev_ds.map(
+                partial(
+                    trans_func,
+                    is_test=data_args.eval_with_do_generation,
+                    zero_padding=eval_zero_padding,
+                    flash_mask=model_args.flash_mask,
+                )
+            )
             if dev_ds is not None
             else None
         )
@@ -485,6 +499,20 @@ def compute_metrics_do_generation(eval_preds):
             "bleu4": bleu4.score(),
         }

+    if model_args.vera:
+        target_modules = get_lora_target_modules(model)
+        vera_config = VeRAConfig(
+            target_modules=target_modules,
+            r=model_args.vera_rank,
+            vera_alpha=model_args.vera_rank,
+            dtype=dtype,
+            base_model_name_or_path=model_args.model_name_or_path,
+            pissa_init=True,
+        )
+        model = VeRAModel(model, vera_config)
+        model.mark_only_vera_as_trainable(notfreezeB=True)
+        model.print_trainable_parameters()
+
     # Create trainer
     max_length = (
         data_args.max_length
diff --git a/llm/tools/merge_vera_params.py b/llm/tools/merge_vera_params.py
new file mode 100644
index 000000000000..a1ddc49ac0a4
--- /dev/null
+++ b/llm/tools/merge_vera_params.py
@@ -0,0 +1,103 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+
+import paddle
+
+from paddlenlp.peft import VeRAConfig, VeRAModel
+from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from paddlenlp.utils.env import CONFIG_NAME
+
+
+def parse_arguments():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_name_or_path", default=None, help="The directory of the pretrained model.")
+    parser.add_argument("--vera_path", default="", help="The directory of the VeRA parameters.")
+    parser.add_argument(
+        "--merge_vera_model_path",
+        default="",
+        help="The directory for the merged parameters.",
+    )
+    parser.add_argument("--device", type=str, default="gpu", help="Device")
+    parser.add_argument(
+        "--low_gpu_mem", type=bool, default=True, help="Whether to use low gpu memory. Default to True"
+    )
+    return parser.parse_args()
+
+
+def weight_process(name, vera_config, state_dict):
+    weight = state_dict.pop(name + ".weight").cuda()
+    vera_A = state_dict.pop(name + ".vera_A").cuda()
+    vera_B = state_dict.pop(name + ".vera_B").cuda()
+    vera_b = state_dict.pop(name + ".vera_b").cuda()
+    vera_d = state_dict.pop(name + ".vera_d").cuda()
+    diag_b = paddle.diag(vera_b)
+    diag_d = paddle.diag(vera_d)
+
+    scaling = vera_config.vera_alpha / vera_config.r
+    state_dict[name + ".weight"] = (weight + vera_A @ diag_d @ vera_B @ diag_b * scaling).cpu()
+
+
+def merge():
+    args = parse_arguments()
+    paddle.set_device(args.device)
+
+    vera_config = VeRAConfig.from_pretrained(args.vera_path)
+    if vera_config.base_model_name_or_path is None:
+        if args.model_name_or_path is None:
+            raise ValueError("We cannot find a valid model_name_or_path.")
+        else:
+            vera_config.base_model_name_or_path = args.model_name_or_path
+
+    if os.path.isfile(os.path.join(args.vera_path, CONFIG_NAME)):
+        config = AutoConfig.from_pretrained(args.vera_path)
+    elif args.model_name_or_path is not None:
+        config = AutoConfig.from_pretrained(args.model_name_or_path)
+    else:
+        raise ValueError(
+            f"We cannot find config.json in vera_path: {args.vera_path} or find a valid model_name_or_path."
+        )
+    config.dtype = vera_config.dtype
+    if (
+        vera_config.dtype == "bfloat16" or config.quantization_config.weight_quantize_algo in ["nf4", "fp4"]
+    ) and args.device == "cpu":
+        raise ValueError("We cannot apply a bfloat16 or nf4/fp4 VeRA merge on cpu.")
+
+    # loading under device_guard() would make the SVD decomposition fail
+    model = AutoModelForCausalLM.from_pretrained(
+        vera_config.base_model_name_or_path,
+        config=config,
+        low_cpu_mem_usage=True,
+    )
+    model = VeRAModel.from_pretrained(model=model, vera_path=args.vera_path, vera_config=vera_config)
+
+    model.eval()
+    model_state_dict = model.model.state_dict()
+    vera_name_list = []
+    for key in model_state_dict.keys():
+        if "vera_A" in key:
+            vera_name_list.append(key[:-7])  # strip the trailing ".vera_A"
+
+    for name in vera_name_list:
+        weight_process(name, vera_config, model_state_dict)
+
+    model.model.save_pretrained(args.merge_vera_model_path, state_dict=model_state_dict)
+    tokenizer = AutoTokenizer.from_pretrained(vera_config.base_model_name_or_path)
+    tokenizer.save_pretrained(args.merge_vera_model_path)
+
+
+if __name__ == "__main__":
+    merge()
diff --git a/llm/utils/argument.py b/llm/utils/argument.py
index 63a14e4126ef..69c7770b40ed 100644
--- a/llm/utils/argument.py
+++ b/llm/utils/argument.py
@@ -196,6 +196,10 @@ class ModelArgument:
     lora_plus_scale: float = field(default=1.0, metadata={"help": "Lora B scale in LoRA+ technique"})
     pissa: bool = field(default=False, metadata={"help": "Whether to use Pissa: https://arxiv.org/pdf/2404.02948.pdf"})

+    # vera related parameters
+    vera: bool = field(default=False, metadata={"help": "Whether to use the VeRA technique"})
+    vera_rank: int = field(default=8, metadata={"help": "VeRA rank (attention dimension)"})
+
     # prefix tuning related parameters
     prefix_tuning: bool = field(default=False, metadata={"help": "Whether to use Prefix technique"})
     prefix_path: str = field(default=None, metadata={"help": "Initialize prefix state dict."})
@@ -209,9 +213,7 @@ class ModelArgument:
     aistudio_token: str = field(default=None, metadata={"help": "The token of aistudio"})
     neftune: bool = field(default=False, metadata={"help": "Whether to apply NEFT"})
     neftune_noise_alpha: float = field(default=5.0, metadata={"help": "NEFT noise alpha"})
-    flash_mask: bool = field(
default=False, metadata={"help": "Whether to use flash_mask in flash attention."} - ) + flash_mask: bool = field(default=False, metadata={"help": "Whether to use flash_mask in flash attention."}) @dataclass diff --git a/paddlenlp/peft/__init__.py b/paddlenlp/peft/__init__.py index 4ffaec00a498..bf290397ec2e 100644 --- a/paddlenlp/peft/__init__.py +++ b/paddlenlp/peft/__init__.py @@ -14,3 +14,4 @@ from .lora import LoRAConfig, LoRAModel from .prefix import PrefixConfig, PrefixModelForCausalLM +from .vera import VeRAConfig, VeRAModel diff --git a/paddlenlp/peft/vera/__init__.py b/paddlenlp/peft/vera/__init__.py new file mode 100644 index 000000000000..2ba6e86f9002 --- /dev/null +++ b/paddlenlp/peft/vera/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .vera_config import VeRAConfig +from .vera_layers import VeRALinear +from .vera_model import VeRAModel diff --git a/paddlenlp/peft/vera/vera_config.py b/paddlenlp/peft/vera/vera_config.py new file mode 100644 index 000000000000..76f0d3a73bb7 --- /dev/null +++ b/paddlenlp/peft/vera/vera_config.py @@ -0,0 +1,131 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from dataclasses import asdict, dataclass, field +from typing import List, Optional, Union + +from ...utils.env import VERA_CONFIG_NAME + + +@dataclass +class VeRAConfig: + """ + This is the configuration class to store the configuration of a [`VeRAModel`]. + Args: + r (`int`): vera attention dimension + target_modules (`Union[List[str],str]`): The names of the modules to apply vera to. + trainable_modules (`List[str]`): The names of the modules to train when applying vera. + vera_alpha (`float`): The alpha parameter for vera scaling. + vera_dropout (`float`): The dropout probability for vera layers. + """ + + r: int = field(default=8, metadata={"help": "vera attention dimension"}) + target_modules: Optional[Union[List[str], str]] = field( + default=None, + metadata={ + "help": "List of module names or regex expression of the module names to replace with vera." + "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' " + }, + ) + trainable_modules: Optional[List[str]] = field( + default=None, + metadata={ + "help": "List of module names or regex expression of the module names to train when applying with vera." 
+ "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' " + }, + ) + vera_alpha: int = field(default=8, metadata={"help": "vera alpha"}) + vera_dropout: float = field(default=0.0, metadata={"help": "vera dropout"}) + trainable_bias: Optional[str] = field( + default=None, metadata={"help": "Define trainable bias parameters for the vera model."} + ) + tensor_parallel_degree: int = field(default=-1, metadata={"help": "1 for not use tensor parallel"}) + dtype: Optional[str] = field(default=None, metadata={"help": "The data type of tensor"}) + head_dim: Optional[int] = field( + default=None, + metadata={ + "help": "The model multi head dimension.Only for veraMergedLinear and ColumnParallelveraMergedLinear." + }, + ) + do_qat: bool = field(default=False, metadata={"help": "Whether the vera model would do quant-aware training"}) + base_model_name_or_path: Optional[str] = field( + default=None, metadata={"help": "The name of the base model to use."} + ) + pissa_init: bool = field(default=False, metadata={"help": "Whether the vera weight initialized by pissa"}) + + @property + def __dict__(self): + return asdict(self) + + def to_dict(self): + return self.__dict__ + + def save_pretrained(self, save_directory): + r""" + This method saves the configuration of your adapter model in a directory. + Args: + save_directory (`str`): + The directory where the configuration will be saved. + """ + if os.path.isfile(save_directory): + raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") + + os.makedirs(save_directory, exist_ok=True) + + output_dict = self.__dict__ + output_path = os.path.join(save_directory, VERA_CONFIG_NAME) + + # save it + with open(output_path, "w") as writer: + writer.write(json.dumps(output_dict, indent=2, sort_keys=True)) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + r""" + This method loads the configuration of your adapter model from a directory. + Args: + pretrained_model_name_or_path (`str`): + The directory or the hub-id where the configuration is saved. + **kwargs: + Additional keyword arguments passed along to the child class initialization. + """ + if os.path.isfile(os.path.join(pretrained_model_name_or_path, VERA_CONFIG_NAME)): + config_file = os.path.join(pretrained_model_name_or_path, VERA_CONFIG_NAME) + else: + raise ValueError(f"Can't find vera_config.json at '{pretrained_model_name_or_path}'") + + loaded_attributes = cls.from_json_file(config_file) + + config = cls(**kwargs) + + for key, value in loaded_attributes.items(): + if hasattr(config, key): + setattr(config, key, value) + + return config + + @classmethod + def from_json_file(cls, path_json_file): + r""" + Loads a configuration file from a json file. + Args: + path_json_file (`str`): + The path to the json file. + """ + with open(path_json_file, "r") as file: + json_object = json.load(file) + + return json_object diff --git a/paddlenlp/peft/vera/vera_layers.py b/paddlenlp/peft/vera/vera_layers.py new file mode 100644 index 000000000000..8bf478503ba3 --- /dev/null +++ b/paddlenlp/peft/vera/vera_layers.py @@ -0,0 +1,149 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+class VeRALinear(nn.Linear):
+    # VeRA implemented in a dense layer
+    def __init__(
+        self,
+        base_linear_module: paddle.nn.layer.common.Linear,
+        in_features: int,
+        out_features: int,
+        r: int = 0,
+        vera_alpha: int = 1,
+        vera_dropout: float = 0.0,
+        pissa_init: bool = False,
+        **kwargs
+    ):
+        nn.Linear.__init__(self, in_features, out_features, **kwargs)
+        self.weight.set_value(base_linear_module.weight)
+
+        if not isinstance(r, int) or r <= 0:
+            raise ValueError("VeRA rank r should be a positive integer")
+        self.r = r
+        self.vera_alpha = vera_alpha
+        # Optional dropout
+        if vera_dropout > 0.0:
+            self.vera_dropout = nn.Dropout(p=vera_dropout)
+        else:
+            self.vera_dropout = lambda x: x
+        # Mark the weight as unmerged
+        self.merged = False
+
+        if pissa_init:
+            assert self.vera_alpha == self.r, "PiSSA initialization requires vera_alpha == r, i.e. scaling == 1"
+            self.scaling = 1.0
+            self.vera_A = self.create_parameter(
+                shape=[in_features, r],
+                dtype=self._dtype,
+                is_bias=False,
+            )
+            self.vera_B = self.create_parameter(
+                shape=[r, out_features],
+                dtype=self._dtype,
+                is_bias=False,
+            )
+            self.pissa_init(r)
+
+        else:
+            # Actual trainable parameters
+            self.vera_A = self.create_parameter(
+                shape=[in_features, r],
+                dtype=self._dtype,
+                is_bias=False,
+                default_initializer=nn.initializer.KaimingUniform(
+                    negative_slope=math.sqrt(5), nonlinearity="leaky_relu"
+                ),
+            )
+            self.vera_B = self.create_parameter(
+                shape=[r, out_features],
+                dtype=self._dtype,
+                is_bias=False,
+                default_initializer=nn.initializer.Constant(value=0.0),
+            )
+            self.scaling = self.vera_alpha / self.r
+
+        self.vera_b = self.create_parameter(
+            shape=[out_features],
+            dtype=self._dtype,
+            is_bias=False,
+            default_initializer=nn.initializer.Constant(value=1.0),
+        )
+
+        self.vera_d = self.create_parameter(
+            shape=[r],
+            dtype=self._dtype,
+            is_bias=False,
+            default_initializer=nn.initializer.Constant(value=1.0),
+        )
+
+        # Freeze the pre-trained weight matrix; the bias (if any) stays trainable
+        self.weight.stop_gradient = True
+
+    def pissa_init(self, r):
+        weight = self.weight
+        dtype = weight.dtype
+
+        if dtype != paddle.float32:
+            weight = weight.astype(paddle.float32)
+
+        U, S, Vh = paddle.linalg.svd(weight.data, full_matrices=False)
+
+        Ur = U[:, :r]
+        Sr = S[:r]
+        Vhr = Vh[:r]
+
+        vera_A = Ur @ paddle.diag(paddle.sqrt(Sr))
+        vera_B = paddle.diag(paddle.sqrt(Sr)) @ Vhr
+
+        self.vera_A.set_value(vera_A.astype(dtype))
+        self.vera_B.set_value(vera_B.astype(dtype))
+        res = weight.data - vera_A @ vera_B
+        weight = res.astype(dtype)
+        self.weight.set_value(weight)
+
+    def merge(self):
+        if not self.merged:
+            diag_b = paddle.diag(self.vera_b)
+            diag_d = paddle.diag(self.vera_d)
+            new_weight = self.weight + self.vera_A @ diag_d @ self.vera_B @ diag_b * self.scaling
+            self.weight.set_value(new_weight)
+            self.merged = True
+
+    def unmerge(self):
+        if self.merged:
+            diag_b = paddle.diag(self.vera_b)
+            diag_d = paddle.diag(self.vera_d)
+            new_weight = self.weight - self.vera_A @ diag_d @ self.vera_B @ diag_b * self.scaling
+            self.weight.set_value(new_weight)
+            self.merged = False
+
+    def forward(self, input: paddle.Tensor, *args, **kwargs):
+        result = F.linear(x=input, weight=self.weight, bias=self.bias, name=self.name)
+        if not self.merged:
+            diag_b = paddle.diag(self.vera_b)
+            diag_d = paddle.diag(self.vera_d)
+            result += (self.vera_dropout(input) @ self.vera_A @ diag_d @ self.vera_B @ diag_b) * self.scaling
+        return result
+
+    def extra_repr(self):
+        name = f", name={self.name}" if self.name else ""
+        return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"
diff --git a/paddlenlp/peft/vera/vera_model.py b/paddlenlp/peft/vera/vera_model.py
new file mode 100644
index 000000000000..bfd00f07b94e
--- /dev/null
+++ b/paddlenlp/peft/vera/vera_model.py
@@ -0,0 +1,284 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import os
+import re
+from collections import OrderedDict
+from typing import Dict, Union
+
+import numpy as np
+import paddle
+import paddle.nn as nn
+from paddle.distributed.fleet.meta_parallel import PipelineLayer
+
+from ...transformers.model_utils import PretrainedModel, _add_variant, dtype_guard
+from ...utils.env import VERA_WEIGHTS_NAME
+from ...utils.log import logger
+from .vera_config import VeRAConfig
+from .vera_layers import VeRALinear
+
+
+class VeRAModel(nn.Layer):
+    restore_layer_map: Dict[nn.Layer, nn.Layer] = {
+        VeRALinear: nn.Linear,
+    }
+
+    def __init__(self, model, vera_config: VeRAConfig) -> None:
+        super().__init__()
+        self.quantized = False
+        self.vera_config = vera_config
+        if self.vera_config.dtype is None:
+            self.vera_config.dtype = paddle.get_default_dtype()
+        with dtype_guard(self.vera_config.dtype):
+            self.model = self.get_vera_model(model, vera_config)
+        self.is_pipelinemodel = False
+        if issubclass(type(self.model), PipelineLayer):
+            raise NotImplementedError("VeRA does not support pipeline parallelism yet")
+        if vera_config.tensor_parallel_degree > 1:
+            raise NotImplementedError("VeRA does not support tensor parallelism yet")
+        self.forward = self.model.forward
+
+    @classmethod
+    def from_pretrained(cls, model, vera_path, **kwargs):
+        vera_config = kwargs.pop("vera_config", None)
+        # init vera config & vera model
+        if not isinstance(vera_config, VeRAConfig):
+            vera_config = VeRAConfig.from_pretrained(vera_path)
+        # preserve the original vera_config.tensor_parallel_degree, which is updated while initializing the VeRA model
+        vera_config_tensor_parallel_degree = vera_config.tensor_parallel_degree
+        vera_model = cls(model, vera_config)
+
+        vera_weight_name = VERA_WEIGHTS_NAME
+
+        # load and set vera weight parameter
+        vera_weight_path = os.path.join(vera_path, vera_weight_name)
+        logger.info(f"vera weight path is {vera_weight_path}")
+        if os.path.exists(vera_weight_path):
+            # load vera weight parameter
+            logger.info("vera weight path exists, loading VeRA weight parameters")
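+            # Note: the adapter file contains only the tensors that were trainable at save
+            # time (vera_b, vera_d, and vera_B when trained with notfreezeB=True); the frozen
+            # vera_A is re-created by VeRAModel's constructor, deterministically from the base
+            # weights when pissa_init=True, so it does not need to be stored on disk.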
+            vera_state_dict = paddle.load(vera_weight_path, return_numpy=True)
+            logger.info(f"Loading the VeRA weights from {vera_weight_path}")
+
+            if (
+                vera_config_tensor_parallel_degree > 1
+                and vera_config_tensor_parallel_degree != model.config.tensor_parallel_degree
+            ):
+                raise NotImplementedError(
+                    f"vera_config.tensor_parallel_degree ({vera_config_tensor_parallel_degree}) does not match model.config.tensor_parallel_degree ({model.config.tensor_parallel_degree}). Please merge the VeRA weights first."
+                )
+
+            # set vera state dict
+            vera_model.set_state_dict(vera_state_dict)
+        else:
+            logger.error(f"VeRA weights not found under {vera_path}, creating VeRA weights from scratch")
+
+        return vera_model
+
+    def set_state_dict(self, state_dict):
+        import warnings
+
+        warnings.filterwarnings(
+            action="ignore", message=".*Skip loading for.*", category=Warning, lineno=0, append=False
+        )
+        self.model.set_state_dict(state_dict)
+        logger.info("Loaded VeRA weights successfully")
+
+    def save_pretrained(self, save_directory: str, merge_tensor_parallel: bool = False, **kwargs):
+        logger.info("Saving VeRA weights and config")
+        save_model_config = kwargs.get("save_model_config", True)
+
+        if self.is_pipelinemodel:
+            self.model._single_to_pp_mapping = None
+        if self.quantized and merge_tensor_parallel and self.vera_config.tensor_parallel_degree > 1:
+            merge_tensor_parallel = False
+            logger.warning(
+                "Quantized strategy does not support merge_tensor_parallel. Set merge_tensor_parallel to False."
+            )
+        if self.is_pipelinemodel and merge_tensor_parallel and self.vera_config.tensor_parallel_degree > 1:
+            merge_tensor_parallel = False
+            logger.warning(
+                "Pipeline parallelism does not support merge_tensor_parallel. Set merge_tensor_parallel to False."
+            )
+
+        variant = kwargs.get("variant", None)
+        is_main_process = kwargs.get("is_main_process", paddle.distributed.get_rank() == 0)
+
+        assert not os.path.isfile(
+            save_directory
+        ), f"Saving directory ({save_directory}) should be a directory, not a file"
+        os.makedirs(save_directory, exist_ok=True)
+
+        vera_config_to_save = VeRAConfig(**self.vera_config.to_dict())
+
+        logger.info(f"vera config to save is {vera_config_to_save}")
+
+        trainable_state_dict = self.get_trainable_state_dict()
+
+        # save vera weight
+        vera_weight_name = _add_variant(VERA_WEIGHTS_NAME, variant)
+        weight_filename = os.path.join(save_directory, vera_weight_name)
+        paddle.save(trainable_state_dict, weight_filename)
+
+        # save vera config
+        if is_main_process:
+            vera_config_to_save.save_pretrained(save_directory)
+            if save_model_config:
+                model_config_to_save = copy.deepcopy(self.model.config)
+                if merge_tensor_parallel:
+                    model_config_to_save.tensor_parallel_degree = -1
+                model_config_to_save.save_pretrained(save_directory)
+
+    def _find_and_replace_module(self, model, module_name, vera_config, enable_vera):
+        parent_module = model
+        attribute_chain = module_name.split(".")
+        for name in attribute_chain[:-1]:
+            parent_module = getattr(parent_module, name)
+        module = getattr(parent_module, attribute_chain[-1])
+        vera_module = None
+        if enable_vera is None:
+            if isinstance(module, nn.Linear):
+                vera_module = VeRALinear(
+                    # pass the base linear module
+                    base_linear_module=module,
+                    in_features=module.weight.shape[0],
+                    out_features=module.weight.shape[1],
+                    r=vera_config.r,
+                    vera_alpha=vera_config.vera_alpha,
+                    vera_dropout=vera_config.vera_dropout,
+                    bias_attr=False if module.bias is None else None,
+                    pissa_init=vera_config.pissa_init,
+                )
+
+        if vera_module is None:
+            raise ValueError(
+                f"VeRA strategy only supports paddle.nn.Linear. {module}({module_name}) is not supported."
+            )
+
+        if module.bias is not None:
+            vera_module.bias = module.bias
+
+        setattr(parent_module, attribute_chain[-1], vera_module)
+
+    def _find_and_restore_module(self, module_name):
+        parent_module = self.model
+        attribute_chain = module_name.split(".")
+        for name in attribute_chain[:-1]:
+            parent_module = getattr(parent_module, name)
+        module = getattr(parent_module, attribute_chain[-1])
+        original_model_class = self.restore_layer_map[module.__class__]
+        original_module = original_model_class(in_features=module.weight.shape[0], out_features=module.weight.shape[1])
+        original_module.weight = module.weight
+        if module.bias is not None:
+            original_module.bias = module.bias
+        setattr(parent_module, attribute_chain[-1], original_module)
+
+    def get_trainable_state_dict(self):
+        trainable_state_dict = OrderedDict()
+        for name, weight in self.model.state_dict().items():
+            # get vera parameter
+            if not weight.stop_gradient:
+                trainable_state_dict[name] = weight
+        return trainable_state_dict
+
+    def print_trainable_parameters(self) -> None:
+        freeze_numel = 0
+        trainable_numel = 0
+        for _, weight in self.model.state_dict().items():
+            if weight.stop_gradient:
+                freeze_numel += np.prod(weight.shape)
+            else:
+                trainable_numel += np.prod(weight.shape)
+        logger.debug(
+            f"Frozen parameters: {freeze_numel:.2e} || Trainable parameters: {trainable_numel:.2e} || Total parameters: {freeze_numel + trainable_numel:.2e} || Trainable: {trainable_numel / (freeze_numel + trainable_numel):.2%}"
+        )
+
+    def mark_only_vera_as_trainable(self, notfreezeB=False) -> None:
+        for _, layer in self.model.named_sublayers():
+            if isinstance(layer, VeRALinear):
+                for name, weight in layer.state_dict().items():
+                    if self.vera_config.trainable_bias in ["vera", "all"] and "bias" in name:
+                        weight.stop_gradient = False
+                    elif "vera" in name:
+                        # notfreezeB=True: vera_b, vera_d and vera_B are trainable
+                        # notfreezeB=False: only vera_b and vera_d are trainable
+                        if "vera_b" in name or "vera_d" in name:
+                            weight.stop_gradient = False
+                        elif "vera_B" in name and notfreezeB:
+                            weight.stop_gradient = False
+                        else:
+                            weight.stop_gradient = True
+                    else:
+                        weight.stop_gradient = True
+            else:
+                for name, weight in layer.state_dict().items():
+                    if self.vera_config.trainable_bias == "all" and "bias" in name:
+                        weight.stop_gradient = False
+                    else:
+                        weight.stop_gradient = True
+        if self.vera_config.trainable_modules is not None:
+            for name, weight in self.model.state_dict().items():
+                if any(
+                    re.fullmatch(trainable_module, name) for trainable_module in self.vera_config.trainable_modules
+                ):
+                    weight.stop_gradient = False
+
+    def get_vera_model(self, model: Union[PretrainedModel, nn.Layer], vera_config: VeRAConfig):
+        if vera_config.target_modules is None:
+            return model
+        elif isinstance(vera_config.target_modules, str):
+            target_modules = [vera_config.target_modules]
+            enable_vera_list = [None]
+        else:
+            target_modules = vera_config.target_modules
+            enable_vera_list = [None for _ in range(len(target_modules))]
+
+        for target_module, enable_vera in zip(target_modules, enable_vera_list):
+            for module_name, _ in model.named_sublayers():
+                if re.fullmatch(target_module, module_name):
+                    self._find_and_replace_module(model, module_name, vera_config, enable_vera)
+        return model
+
+    def restore_original_model(self):
+        for layer_name, layer in self.model.named_sublayers():
+            if isinstance(layer, VeRALinear):
+                self._find_and_restore_module(layer_name)
else: + raise NotImplementedError(f"{layer} restoration is not supported yet.") + return self.model + + def __getattr__(self, name: str): + """Forward missing attributes to the wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Layer's logic + except AttributeError: + return getattr(self.model, name) + + def train(self): + self.training = True + self.model.training = True + for layer in self.model.sublayers(): + layer.training = True + layer.train() + + def eval(self): + self.training = False + self.model.training = False + for layer in self.model.sublayers(): + layer.training = False + layer.eval() diff --git a/paddlenlp/trainer/integrations.py b/paddlenlp/trainer/integrations.py index b00441bf46fc..5c3eb93d8e44 100644 --- a/paddlenlp/trainer/integrations.py +++ b/paddlenlp/trainer/integrations.py @@ -23,7 +23,7 @@ import tempfile from pathlib import Path -from ..peft import LoRAModel, PrefixModelForCausalLM +from ..peft import LoRAModel, PrefixModelForCausalLM, VeRAModel from ..transformers import PretrainedModel from ..utils.log import logger from .trainer_callback import TrainerCallback @@ -116,7 +116,11 @@ def on_train_begin(self, args, state, control, **kwargs): self.vdl_writer.add_text("args", args.to_json_string()) if "model" in kwargs and logger.logger.level < 20: model = kwargs["model"] - if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM): + if ( + isinstance(model, LoRAModel) + or isinstance(model, PrefixModelForCausalLM) + or isinstance(model, VeRAModel) + ): model = kwargs["model"].model if isinstance(model, PretrainedModel) and model.constructed_from_pretrained_config(): model.config.architectures = [model.__class__.__name__] diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index 01e5fccbc02e..ab8c88e99eb8 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -81,7 +81,7 @@ DistDataLoader, default_data_collator, ) -from ..peft import LoRAModel, PrefixModelForCausalLM +from ..peft import LoRAModel, PrefixModelForCausalLM, VeRAModel try: from ..quantization.quantization_linear import QuantizationLinear @@ -107,6 +107,7 @@ SAFE_MASTER_WEIGHTS_INDEX_NAME, SAFE_PEFT_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME, + VERA_WEIGHTS_NAME, ) from ..utils.import_utils import is_datasets_available, is_paddle_cuda_available from ..utils.log import logger @@ -433,7 +434,11 @@ def __init__( if train_dataset is not None and not isinstance(train_dataset, collections.abc.Sized) and args.max_steps <= 0: raise ValueError("train_dataset does not implement __len__, max_steps has to be specified") - if isinstance(self.model, LoRAModel) or isinstance(self.model, PrefixModelForCausalLM): + if ( + isinstance(self.model, LoRAModel) + or isinstance(self.model, PrefixModelForCausalLM) + or isinstance(self.model, VeRAModel) + ): if self.args.unified_checkpoint and "skip_save_model_weight" in self.args.unified_checkpoint_config: self.args.unified_checkpoint_config.remove("skip_save_model_weight") logger.warning( @@ -572,6 +577,8 @@ def _load_from_peft_checkpoint(self, resume_from_checkpoint=None): weights_file = os.path.join(resume_from_checkpoint, PREFIX_WEIGHTS_NAME) if self.model.prefix_config.tensor_parallel_degree > 1: convert_tp = True + elif isinstance(self.model, VeRAModel): + weights_file = os.path.join(resume_from_checkpoint, VERA_WEIGHTS_NAME) if self.args.dataset_rank == 0: logger.info(f"Loading model from {resume_from_checkpoint} .") @@ -626,7 +633,11 @@ def _load_from_checkpoint(self, 
resume_from_checkpoint=None): self.runtime_timer.stop() return - if isinstance(self.model, LoRAModel) or isinstance(self.model, PrefixModelForCausalLM): + if ( + isinstance(self.model, LoRAModel) + or isinstance(self.model, PrefixModelForCausalLM) + or isinstance(self.model, VeRAModel) + ): self._load_from_peft_checkpoint(resume_from_checkpoint) self.runtime_timer.stop() return @@ -2514,7 +2525,11 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_ merge_tensor_parallel = merge_tensor_parallel and self.args.use_hybrid_parallel # peft model - if isinstance(self.model, LoRAModel) or isinstance(self.model, PrefixModelForCausalLM): + if ( + isinstance(self.model, LoRAModel) + or isinstance(self.model, PrefixModelForCausalLM) + or isinstance(self.model, VeRAModel) + ): self.model.save_pretrained( output_dir, variant=self.args.weight_name_suffix, diff --git a/paddlenlp/utils/env.py b/paddlenlp/utils/env.py index f617ff760ad1..e51e87753e51 100644 --- a/paddlenlp/utils/env.py +++ b/paddlenlp/utils/env.py @@ -75,6 +75,9 @@ def _get_bool_env(env_key: str, default_value: str) -> bool: LORA_CONFIG_NAME = "lora_config.json" LORA_WEIGHTS_NAME = "lora_model_state.pdparams" +VERA_CONFIG_NAME = "vera_config.json" +VERA_WEIGHTS_NAME = "vera_model_state.pdparams" + PREFIX_CONFIG_NAME = "prefix_config.json" PREFIX_WEIGHTS_NAME = "prefix_model_state.pdparams" PADDLE_PEFT_WEIGHTS_INDEX_NAME = "peft_model.pdparams.index.json" diff --git a/tests/fixtures/llm/vera.yaml b/tests/fixtures/llm/vera.yaml new file mode 100644 index 000000000000..72752ddffa91 --- /dev/null +++ b/tests/fixtures/llm/vera.yaml @@ -0,0 +1,66 @@ +vera: + base: + dataset_name_or_path: "./data" + per_device_train_batch_size: 4 + gradient_accumulation_steps: 4 + per_device_eval_batch_size: 8 + eval_accumulation_steps: 16 + num_train_epochs: 3 + learning_rate: 3e-04 + warmup_steps: 30 + logging_steps: 1 + evaluation_strategy: "epoch" + save_strategy: "epoch" + src_length: 1024 + max_length: 2048 + fp16: true + fp16_opt_level: "O2" + do_train: true + do_eval: true + disable_tqdm: true + load_best_model_at_end: true + eval_with_do_generation: false + metric_for_best_model: "accuracy" + recompute: true + save_total_limit: 1 + tensor_parallel_degree: 1 + pipeline_parallel_degree: 1 + vera: true + + default: + llama: + model_name_or_path: __internal_testing__/tiny-random-llama + chatglm: + model_name_or_path: __internal_testing__/tiny-fused-chatglm + chatglm2: + model_name_or_path: __internal_testing__/tiny-fused-chatglm2 + bloom: + model_name_or_path: __internal_testing__/tiny-fused-bloom + qwen: + model_name_or_path: __internal_testing__/tiny-fused-qwen + qwen2: + model_name_or_path: __internal_testing__/tiny-random-qwen2 + qwen2moe: + model_name_or_path: __internal_testing__/tiny-random-qwen2moe + baichuan: + model_name_or_path: __internal_testing__/tiny-fused-baichuan + +inference-predict: + default: + mode: dynamic + max_length: 20 + batch_size: 2 + decode_strategy: greedy_search + dtype: float16 + +inference-to-static: + default: + dtype: float16 + +inference-infer: + default: + mode: static + dtype: float16 + batch_size: 2 + decode_strategy: greedy_search + max_length: 20 \ No newline at end of file diff --git a/tests/llm/test_vera.py b/tests/llm/test_vera.py new file mode 100644 index 000000000000..a3f81ede72e3 --- /dev/null +++ b/tests/llm/test_vera.py @@ -0,0 +1,82 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import os +import sys +import unittest + +import paddle +from parameterized import parameterized_class + +from tests.testing_utils import argv_context_guard, load_test_config + +from .testing_utils import LLMTest + + +@parameterized_class( + ["model_dir"], + [ + ["llama"], + ["chatglm"], + ["chatglm2"], + ["bloom"], + ["qwen"], + ["baichuan"], + ], +) +class VeraTest(LLMTest, unittest.TestCase): + config_path: str = "./tests/fixtures/llm/vera.yaml" + model_dir: str = None + + def setUp(self) -> None: + LLMTest.setUp(self) + + self.model_codes_dir = os.path.join(self.root_path, self.model_dir) + sys.path.insert(0, self.model_codes_dir) + + def tearDown(self) -> None: + LLMTest.tearDown(self) + sys.path.remove(self.model_codes_dir) + + def test_vera(self): + self.disable_static() + paddle.set_default_dtype("float32") + + vera_config = load_test_config(self.config_path, "vera", self.model_dir) + vera_config["output_dir"] = self.output_dir + vera_config["dataset_name_or_path"] = self.data_dir + + with argv_context_guard(vera_config): + from run_finetune import main + + main() + + # merge weights + merge_vera_weights_config = { + "vera_path": vera_config["output_dir"], + "merge_vera_model_path": vera_config["output_dir"], + "device": "gpu", + "low_gpu_mem": True, + } + with argv_context_guard(merge_vera_weights_config): + from tools.merge_vera_params import merge + + merge() + + # TODO(wj-Mcat): disable chatglm2 test temporarily + if self.model_dir not in ["qwen", "baichuan", "chatglm2"]: + self.run_predictor({"inference_model": True}) + + self.run_predictor({"inference_model": False}) diff --git a/tests/peft/test_vera.py b/tests/peft/test_vera.py new file mode 100644 index 000000000000..8670576edb18 --- /dev/null +++ b/tests/peft/test_vera.py @@ -0,0 +1,203 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
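+
+# Unit tests for the new VeRA building blocks: VeRALinear (rank validation, forward
+# output, train/eval parity, save/load round-trips) and VeRAModel (trainable-parameter
+# marking, save/load, config validation).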
+ +import copy +import os +import re +import unittest +from tempfile import NamedTemporaryFile, TemporaryDirectory + +import numpy as np +import paddle +from paddle import nn +from parameterized import parameterized + +from paddlenlp.peft.vera import VeRAConfig, VeRALinear, VeRAModel +from paddlenlp.transformers import AutoModel + + +class TestVeraLayer(unittest.TestCase): + def test_r_raise_exception(self): + with self.assertRaises(ValueError): + VeRALinear( + in_features=16, + out_features=16, + r=0, + vera_dropout=0.1, + vera_alpha=4, + base_linear_module=nn.Linear(in_features=16, out_features=16), + ) + + def test_forward(self): + vera_layer = VeRALinear( + in_features=16, + out_features=16, + r=4, + vera_dropout=0.1, + vera_alpha=4, + base_linear_module=nn.Linear(16, 16), + pissa_init=True, + ) + x = paddle.randn([2, 4, 16], "float32") + output = vera_layer(x) + self.assertFalse(vera_layer.vera_b.stop_gradient) + self.assertFalse(vera_layer.vera_d.stop_gradient) + self.assertTrue(vera_layer.weight.stop_gradient) + self.assertFalse(vera_layer.bias.stop_gradient) + self.assertEqual(output.shape, [2, 4, 16]) + + def test_train_eval(self): + x = paddle.randn([2, 4, 16], "float32") + vera_layer = VeRALinear( + in_features=16, out_features=16, r=4, base_linear_module=nn.Linear(in_features=16, out_features=16) + ) + vera_layer.train() + train_result = vera_layer(x) + train_weight = copy.deepcopy(vera_layer.weight) # deep copy since this is a pointer + vera_layer.eval() + eval_result = vera_layer(x) + eval_weight = vera_layer.weight + self.assertTrue(paddle.allclose(train_result, eval_result)) + self.assertTrue(paddle.allclose(train_weight, eval_weight)) + + def test_save_load(self): + with TemporaryDirectory() as tempdir: + vera_layer = VeRALinear( + in_features=16, out_features=16, r=4, base_linear_module=nn.Linear(in_features=16, out_features=16) + ) + weights_path = os.path.join(tempdir, "model.pdparams") + paddle.save(vera_layer.state_dict(), weights_path) + new_vera_layer = VeRALinear( + in_features=16, out_features=16, r=4, base_linear_module=nn.Linear(in_features=16, out_features=16) + ) + state_dict = paddle.load(weights_path) + new_vera_layer.set_dict(state_dict) + x = paddle.randn([2, 4, 16], "float32") + self.assertTrue(paddle.allclose(new_vera_layer(x), vera_layer(x))) + + def test_load_regular_linear(self): + with TemporaryDirectory() as tempdir: + regular_linear = paddle.nn.Linear(in_features=16, out_features=16) + weights_path = os.path.join(tempdir, "model.pdparams") + paddle.save(regular_linear.state_dict(), weights_path) + state_dict = paddle.load(weights_path) + # should be identical to regular linear + vera_layer_r8 = VeRALinear( + in_features=16, out_features=16, r=8, base_linear_module=nn.Linear(in_features=16, out_features=16) + ) + vera_layer_r4 = VeRALinear( + in_features=16, out_features=16, r=4, base_linear_module=nn.Linear(in_features=16, out_features=16) + ) + vera_layer_r8.set_dict(state_dict) + vera_layer_r4.set_dict(state_dict) + x = paddle.randn([2, 4, 16], "float32") + self.assertTrue(paddle.allclose(vera_layer_r8(x), regular_linear(x))) + self.assertTrue(paddle.allclose(vera_layer_r4(x), regular_linear(x))) + + +class TestVeraModel(unittest.TestCase): + @parameterized.expand([(None,), ("all",), ("vera",)]) + def test_vera_model_constructor(self, bias): + vera_config = VeRAConfig( + target_modules=[".*q_proj.*", ".*v_proj.*"], r=4, vera_alpha=4, head_dim=2, pissa_init=True + ) + # turn off plm dropout for to test train vs test + model = 
AutoModel.from_pretrained( + "__internal_testing__/tiny-random-bert", hidden_dropout_prob=0, attention_probs_dropout_prob=0 + ) + vera_model = VeRAModel(model, vera_config) + vera_model.mark_only_vera_as_trainable() + + for name, weight in vera_model.state_dict().items(): + if any([re.fullmatch(target_module, name) for target_module in vera_config.target_modules]): + if "vera_b" in name or "vera_d" in name: + self.assertFalse(weight.stop_gradient) + else: + self.assertTrue(weight.stop_gradient) + + input_ids = paddle.to_tensor(np.random.randint(100, 200, [1, 20])) + vera_model.train() + train_forward_results = vera_model(input_ids) + self.assertIsNotNone(train_forward_results) + vera_model.eval() + eval_forward_results = vera_model(input_ids) + self.assertIsNotNone(eval_forward_results) + self.assertTrue(paddle.allclose(train_forward_results[0], eval_forward_results[0])) + + def test_vera_model_save_load(self): + with TemporaryDirectory() as tempdir: + input_ids = paddle.to_tensor(np.random.randint(100, 200, [1, 20])) + vera_config = VeRAConfig( + target_modules=[".*q_proj.*", ".*v_proj.*"], + r=4, + vera_alpha=4, + ) + model = AutoModel.from_pretrained("__internal_testing__/tiny-random-bert") + vera_model = VeRAModel(model, vera_config) + vera_model.eval() + original_results = vera_model(input_ids) + vera_model.save_pretrained(tempdir) + + loaded_vera_model = VeRAModel.from_pretrained(model, tempdir) + loaded_vera_model.eval() + loaded_results = loaded_vera_model(input_ids) + self.assertTrue(paddle.allclose(original_results[0], loaded_results[0])) + + config_loaded_vera_model = VeRAModel.from_pretrained(model, tempdir, vera_config=vera_config) + config_loaded_vera_model.eval() + config_loaded_results = config_loaded_vera_model(input_ids) + self.assertTrue(paddle.allclose(original_results[0], config_loaded_results[0])) + + def test_restore_original_model(self): + vera_config = VeRAConfig( + target_modules=[".*q_proj.*", ".*v_proj.*"], + r=4, + vera_alpha=4, + ) + model = AutoModel.from_pretrained("__internal_testing__/tiny-random-bert") + vera_model = VeRAModel(model, vera_config) + with self.assertRaises(NotImplementedError): + vera_model.restore_original_model() + + def test_vera_module_raise_exception(self): + vera_config = VeRAConfig(target_modules=[".*norm1.*"], r=4, vera_alpha=4) + model = AutoModel.from_pretrained("__internal_testing__/tiny-random-bert") + with self.assertRaises(ValueError): + VeRAModel(model, vera_config) + + def test_pissa_raise_exception(self): + vera_config = VeRAConfig(target_modules=[".*q_proj.*"], r=4, vera_alpha=8, pissa_init=True) + model = AutoModel.from_pretrained("__internal_testing__/tiny-random-bert") + with self.assertRaises(AssertionError): + VeRAModel(model, vera_config) + + +class TestVeRAConfig(unittest.TestCase): + def test_save_load(self): + with TemporaryDirectory() as tempdir: + vera_config = VeRAConfig() + vera_config.save_pretrained(tempdir) + loaded_vera_config = VeRAConfig.from_pretrained(tempdir) + self.assertEqual(vera_config, loaded_vera_config) + + def test_save_load_err(self): + with NamedTemporaryFile("w+t") as f: + with self.assertRaises(ValueError): + VeRAConfig.from_pretrained(f.name) + + def test_save_pretrained_file_error(self): + with NamedTemporaryFile("w+t") as f: + vera_config = VeRAConfig() + with self.assertRaises(AssertionError): + vera_config.save_pretrained(f.name)
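
Usage sketch (paths and model names are illustrative; the flow mirrors the existing LoRA
workflow under llm/):

    # fine-tune with VeRA using the config added in this PR
    python run_finetune.py ./config/llama/vera_argument.json

    # fold the learned vera_A/vera_B/vera_b/vera_d tensors back into the dense weights
    python tools/merge_vera_params.py \
        --model_name_or_path facebook/llama-7b \
        --vera_path ./checkpoints/vera_ckpts \
        --merge_vera_model_path ./checkpoints/vera_merged

Minimal sanity check for the new layer (a sketch against the APIs introduced above; with
vera_dropout=0 the merged weight must reproduce the unmerged forward pass):

    import paddle
    from paddle import nn

    from paddlenlp.peft.vera import VeRALinear

    # wrap a plain Linear; scaling = vera_alpha / r = 1 here
    layer = VeRALinear(
        base_linear_module=nn.Linear(16, 16), in_features=16, out_features=16, r=4, vera_alpha=4
    )
    layer.eval()
    x = paddle.randn([2, 16])
    y_unmerged = layer(x)
    layer.merge()  # W <- W + vera_A @ diag(vera_d) @ vera_B @ diag(vera_b) * scaling
    y_merged = layer(x)
    assert paddle.allclose(y_unmerged, y_merged, atol=1e-5)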