diff --git a/modules/quant_loader.py b/modules/quant_loader.py index a2b484b022..7a5f8461f5 100644 --- a/modules/quant_loader.py +++ b/modules/quant_loader.py @@ -9,8 +9,17 @@ sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa"))) -# 4-bit LLaMA -def load_quantized(model_name, model_type): +def load_quantized(model_name): + if not shared.args.gptq_model_type: + # Try to determine model type from model name + model_type = model_name.split('-')[0].lower() + if model_type not in ('llama', 'opt'): + print("Can't determine model type from model name. Please specify it manually using --gptq-model-type " + "argument") + exit() + else: + model_type = shared.args.gptq_model_type.lower() + if model_type == 'llama': from llama import load_quant elif model_type == 'opt': @@ -20,7 +29,16 @@ def load_quantized(model_name, model_type): exit() path_to_model = Path(f'models/{model_name}') - pt_model = f'{model_name}-{shared.args.gptq_bits}bit.pt' + if path_to_model.name.lower().startswith('llama-7b'): + pt_model = f'llama-7b-{shared.args.gptq_bits}bit.pt' + elif path_to_model.name.lower().startswith('llama-13b'): + pt_model = f'llama-13b-{shared.args.gptq_bits}bit.pt' + elif path_to_model.name.lower().startswith('llama-30b'): + pt_model = f'llama-30b-{shared.args.gptq_bits}bit.pt' + elif path_to_model.name.lower().startswith('llama-65b'): + pt_model = f'llama-65b-{shared.args.gptq_bits}bit.pt' + else: + pt_model = f'{model_name}-{shared.args.gptq_bits}bit.pt' # Try to find the .pt both in models/ and in the subfolder pt_path = None