diff --git a/fastchat/model/compression.py b/fastchat/model/compression.py
index 06e503f30..7329cfe0c 100644
--- a/fastchat/model/compression.py
+++ b/fastchat/model/compression.py
@@ -168,6 +168,11 @@ def load_compress_model(model_path, device, torch_dtype, use_fast, revision="mai
         base_pattern = os.path.join(model_path, "pytorch_model*.bin")
 
     files = glob.glob(base_pattern)
+    use_safetensors = False
+    if len(files) == 0:
+        base_pattern = os.path.join(model_path, "*.safetensors")
+        files = glob.glob(base_pattern)
+        use_safetensors = True
     if len(files) == 0:
         raise ValueError(
             f"Cannot find any model weight files. "
@@ -175,8 +180,15 @@ def load_compress_model(model_path, device, torch_dtype, use_fast, revision="mai
         )
 
     compressed_state_dict = {}
+    if use_safetensors:
+        from safetensors.torch import load_file
     for filename in tqdm(files):
-        tmp_state_dict = torch.load(filename, map_location=lambda storage, loc: storage)
+        if use_safetensors:
+            tmp_state_dict = load_file(filename)
+        else:
+            tmp_state_dict = torch.load(
+                filename, map_location=lambda storage, loc: storage
+            )
         for name in tmp_state_dict:
             if name in linear_weights:
                 tensor = tmp_state_dict[name].to(device, dtype=torch_dtype)
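
For reference, a minimal standalone sketch of the fallback this diff introduces: glob for pytorch_model*.bin shards first, then fall back to *.safetensors and switch the loader accordingly. The ckpt_dir path and the final print are hypothetical illustration; load_file and torch.load are called the same way as in the patch.

import glob
import os

import torch

ckpt_dir = "/path/to/checkpoint"  # hypothetical: a local checkpoint directory

# Prefer classic PyTorch shards; fall back to safetensors shards, as in the patch.
files = glob.glob(os.path.join(ckpt_dir, "pytorch_model*.bin"))
use_safetensors = False
if len(files) == 0:
    files = glob.glob(os.path.join(ckpt_dir, "*.safetensors"))
    use_safetensors = True

if use_safetensors:
    # Imported lazily, mirroring the patch, so safetensors is only required
    # when a safetensors checkpoint is actually being loaded.
    from safetensors.torch import load_file

for filename in files:
    if use_safetensors:
        # load_file returns a dict of CPU tensors; no pickle execution involved.
        state_dict = load_file(filename)
    else:
        # map_location keeps tensors on CPU rather than their serialized device.
        state_dict = torch.load(filename, map_location=lambda storage, loc: storage)
    print(f"{filename}: {len(state_dict)} tensors")

The lazy import in the patch keeps safetensors an optional dependency: users loading plain .bin checkpoints never hit the import at all.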