Commit

Make --load-8bit flag work with weights in safetensors format (#2698)
xuguodong1999 authored Nov 22, 2023
1 parent 85c797e commit 99d19ac
Showing 1 changed file with 13 additions and 1 deletion.
fastchat/model/compression.py (13 additions, 1 deletion)
@@ -168,15 +168,27 @@ def load_compress_model(model_path, device, torch_dtype, use_fast, revision="main"):
     base_pattern = os.path.join(model_path, "pytorch_model*.bin")
 
     files = glob.glob(base_pattern)
+    use_safetensors = False
     if len(files) == 0:
+        base_pattern = os.path.join(model_path, "*.safetensors")
+        files = glob.glob(base_pattern)
+        use_safetensors = True
+    if len(files) == 0:
         raise ValueError(
             f"Cannot find any model weight files. "
             f"Please check your (cached) weight path: {model_path}"
         )
 
     compressed_state_dict = {}
+    if use_safetensors:
+        from safetensors.torch import load_file
     for filename in tqdm(files):
-        tmp_state_dict = torch.load(filename, map_location=lambda storage, loc: storage)
+        if use_safetensors:
+            tmp_state_dict = load_file(filename)
+        else:
+            tmp_state_dict = torch.load(
+                filename, map_location=lambda storage, loc: storage
+            )
         for name in tmp_state_dict:
            if name in linear_weights:
                tensor = tmp_state_dict[name].to(device, dtype=torch_dtype)
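For readers following along, the sketch below isolates the loading logic this commit introduces: try the pytorch_model*.bin pattern first, fall back to *.safetensors, and dispatch to safetensors.torch.load_file or torch.load accordingly. The helper name iter_weight_files and its standalone form are illustrative only; FastChat inlines this logic inside load_compress_model rather than exposing such a function.

    import glob
    import os

    import torch


    def iter_weight_files(model_path):
        """Yield (filename, state_dict) per shard, mirroring the fallback added here.

        Hypothetical helper for illustration, not part of the FastChat API.
        """
        files = glob.glob(os.path.join(model_path, "pytorch_model*.bin"))
        use_safetensors = False
        if len(files) == 0:
            # No .bin shards found: fall back to the safetensors layout.
            files = glob.glob(os.path.join(model_path, "*.safetensors"))
            use_safetensors = True
        if len(files) == 0:
            raise ValueError(f"Cannot find any model weight files under: {model_path}")

        if use_safetensors:
            # load_file returns a dict of tensors, placed on CPU by default.
            from safetensors.torch import load_file

        for filename in files:
            if use_safetensors:
                state_dict = load_file(filename)
            else:
                # map_location keeps tensors on CPU regardless of how they were saved.
                state_dict = torch.load(filename, map_location=lambda storage, loc: storage)
            yield filename, state_dict

Loading each shard onto CPU first keeps peak GPU memory low; the caller then moves only the tensors it actually needs, as load_compress_model does with .to(device, dtype=torch_dtype) in the trailing context lines above.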
