[Draft] Add AutoRound support #5486

Draft · wants to merge 2 commits into main
42 changes: 42 additions & 0 deletions examples/train_qlora/llama3_lora_sft_round.yaml
@@ -0,0 +1,42 @@
### model
model_name_or_path: /models/opt-125m
export_quantization_bit: 4
export_quantization_dataset: "data/c4_demo.json" ##not used
export_quantization_method: "auto_round"

### method
stage: sft
do_train: true
finetuning_type: lora
lora_target: all

### dataset
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: saves/llama3-8b/lora/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000

### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
14 changes: 12 additions & 2 deletions src/llamafactory/cli.py
@@ -27,7 +27,7 @@
from .extras.misc import get_device_count
from .train.tuner import export_model, run_exp
from .webui.interface import run_web_demo, run_web_ui

import shlex

USAGE = (
"-" * 70
@@ -86,11 +86,18 @@ def main():
elif command == Command.EXPORT:
export_model()
elif command == Command.TRAIN:
from .train.tuner import quantize_model
args = quantize_model()
force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
if force_torchrun or get_device_count() > 1:
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
args_list = []
for key, value in args.items():
args_list.append(f'--{key}')
args_list.append(str(value))

process = subprocess.run(
(
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
@@ -102,7 +109,7 @@
master_addr=master_addr,
master_port=master_port,
file_name=launcher.__file__,
args=" ".join(sys.argv[1:]),
args=' '.join(shlex.quote(arg) for arg in args_list)
),
shell=True,
)
@@ -119,3 +126,6 @@ def main():
print(USAGE)
else:
raise NotImplementedError("Unknown command: {}.".format(command))

if __name__ == "__main__": ##TODO remove this
main()
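
As a side note, a minimal sketch of the argument-forwarding pattern introduced above: the config dict (here hypothetical; in the patch it is the value returned by `quantize_model()`) is flattened into `--key value` pairs and shell-quoted before being interpolated into the torchrun command string.

```python
import shlex

# Hypothetical config dict standing in for the value returned by quantize_model().
args = {"model_name_or_path": "saves/autoround_quantized_model", "stage": "sft"}

args_list = []
for key, value in args.items():
    args_list.append(f"--{key}")
    args_list.append(str(value))

# shlex.quote() protects values containing spaces or shell metacharacters.
print(" ".join(shlex.quote(arg) for arg in args_list))
# --model_name_or_path saves/autoround_quantized_model --stage sft
```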
4 changes: 4 additions & 0 deletions src/llamafactory/hparams/model_args.py
@@ -96,6 +96,10 @@ class ExportArguments:
default=None,
metadata={"help": "The number of bits to quantize the exported model."},
)
export_quantization_method: Optional[str] = field(
default=None,
metadata={"help": "The method to quantize the exported model."},
)
export_quantization_dataset: Optional[str] = field(
default=None,
metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
2 changes: 1 addition & 1 deletion src/llamafactory/hparams/parser.py
@@ -55,7 +55,7 @@

def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
if args is not None:
return parser.parse_dict(args)
return parser.parse_dict(args, allow_extra_keys=True)

if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
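
For context, `allow_extra_keys=True` makes `HfArgumentParser.parse_dict` ignore keys that no registered dataclass declares (such as the AutoRound-only export options) instead of raising. A minimal sketch with a made-up dataclass:

```python
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class DummyArguments:  # hypothetical stand-in for the real argument dataclasses
    learning_rate: float = field(default=1.0e-4)


parser = HfArgumentParser(DummyArguments)
config = {"learning_rate": 5.0e-5, "export_quantization_method": "auto_round"}

# Without allow_extra_keys=True this would raise, because the dataclass does not
# declare `export_quantization_method`; with it, unknown keys are silently dropped.
(dummy_args,) = parser.parse_dict(config, allow_extra_keys=True)
print(dummy_args.learning_rate)  # 5e-05
```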
2 changes: 1 addition & 1 deletion src/llamafactory/model/model_utils/quantization.py
@@ -132,7 +132,7 @@ def configure_quantization(
quant_bits = quantization_config.get("bits", "?")
logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))

elif model_args.export_quantization_bit is not None: # auto-gptq
elif model_args.export_quantization_bit is not None and model_args.export_quantization_method != "auto_round":  # auto-gptq
if model_args.export_quantization_bit not in [8, 4, 3, 2]:
raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.")

82 changes: 78 additions & 4 deletions src/llamafactory/train/tuner.py
@@ -11,18 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import os
import shutil
from typing import TYPE_CHECKING, Any, Dict, List, Optional

import torch
from transformers import PreTrainedModel
from transformers import PreTrainedModel, AutoModelForVision2Seq, AutoModelForCausalLM
from transformers.utils.versions import require_version

from ..data import get_template_and_fix_tokenizer
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.logging import get_logger
from ..hparams import get_infer_args, get_train_args
from ..hparams.parser import _parse_train_args
from ..model import load_model, load_tokenizer
from .callbacks import LogCallback
from .dpo import run_dpo
@@ -31,12 +33,12 @@
from .pt import run_pt
from .rm import run_rm
from .sft import run_sft

from ..model.loader import _get_init_kwargs, load_config
from ..model.patcher import patch_config

if TYPE_CHECKING:
from transformers import TrainerCallback


logger = get_logger(__name__)


@@ -60,6 +62,78 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb
raise ValueError("Unknown task: {}.".format(finetuning_args.stage))


def quantize_model(args: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    # Recover the raw config dict so it can be modified and forwarded to the trainer.
    res = args if args is not None else {}
    import sys

    if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
        import yaml
        from pathlib import Path

        yaml_file = os.path.abspath(sys.argv[1])
        res = yaml.safe_load(Path(yaml_file).read_text())

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        import json

        json_file = os.path.abspath(sys.argv[1])
        with open(json_file, encoding="utf-8") as open_json_file:
            res = json.loads(open_json_file.read())

    model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
    if model_args.export_quantization_bit is None or model_args.export_quantization_method != "auto_round":
        return res  # nothing to do unless AutoRound quantization was requested

    if model_args.export_quantization_bit not in [4]:
        raise ValueError("AutoRound only accepts 4-bit quantization.")

    require_version("auto_round>=0.3.0", "To fix: pip install auto_round>=0.3.0")
    require_version("auto_gptq>=0.7.1", "To fix: pip install auto_gptq>=0.7.1")

    if model_args.adapter_name_or_path:
        raise ValueError("Please merge adapters before quantizing the model.")

    if model_args.mixture_of_depths == "load":
        raise NotImplementedError("AutoRound only supports `AutoModelForCausalLM` models.")

    if model_args.train_from_scratch:
        raise NotImplementedError("AutoRound only supports trained models.")

    if model_args.export_dir is None:
        export_dir = "saves/autoround_quantized_model"
        logger.warning("`export_dir` has not been specified, setting it to `saves/autoround_quantized_model`.")
    else:
        export_dir = model_args.export_dir

    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    # get_template_and_fix_tokenizer(tokenizer, data_args)

    from transformers.integrations import is_deepspeed_zero3_enabled
    from transformers.modeling_utils import is_fsdp_enabled

    if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
        raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")

    # Load the full-precision model on CPU with the KV cache disabled for calibration.
    init_kwargs = _get_init_kwargs(model_args)
    config = load_config(model_args)
    if type(config) in AutoModelForVision2Seq._model_mapping.keys():
        raise NotImplementedError("AutoRound only supports `AutoModelForCausalLM` models.")

    patch_config(config, tokenizer, model_args, init_kwargs, is_trainable=False)
    init_kwargs["config"] = config
    init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
    init_kwargs["device_map"] = "cpu"
    init_kwargs["torch_dtype"] = "auto"
    init_kwargs["config"].use_cache = False
    model = AutoModelForCausalLM.from_pretrained(**init_kwargs)

    bits, group_size, sym = model_args.export_quantization_bit, 128, False
    from auto_round import AutoRound

    autoround = AutoRound(
        model, tokenizer, bits=bits, group_size=group_size, sym=sym, nsamples=2, iters=2
    )  # TODO: pass more configs and change it back
    autoround.quantize()
    autoround.save_quantized(export_dir, format="auto_gptq", inplace=True)
    torch.cuda.empty_cache()
    gc.collect()

    # Point the remaining training run at the freshly quantized checkpoint.
    res.pop("export_quantization_bit", None)
    res["model_name_or_path"] = export_dir

    return res
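
For reference, the AutoRound calls used above can be exercised standalone outside of LLaMA-Factory. A minimal sketch, assuming `auto_round>=0.3.0` and `auto_gptq>=0.7.1` are installed; the model name and output path are illustrative only:

```python
# Standalone sketch of the AutoRound flow above (illustrative model and paths only).
from transformers import AutoModelForCausalLM, AutoTokenizer

from auto_round import AutoRound

model_name = "facebook/opt-125m"  # hypothetical small model for demonstration
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="cpu")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# 4-bit, group size 128, asymmetric quantization, matching the defaults in the patch.
autoround = AutoRound(model, tokenizer, bits=4, group_size=128, sym=False)
autoround.quantize()

# Export in AutoGPTQ format so the result can be reloaded as a GPTQ-quantized model.
autoround.save_quantized("saves/autoround_quantized_model", format="auto_gptq", inplace=True)
```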


def export_model(args: Optional[Dict[str, Any]] = None) -> None:
model_args, data_args, finetuning_args, _ = get_infer_args(args)
