Merge branch 'main' into sa/weight_packed
Sara Adkins committed May 23, 2024
2 parents c3df04b + 4a0e3c5 commit 1694b0b
Showing 15 changed files with 256 additions and 98 deletions.
18 changes: 14 additions & 4 deletions src/compressed_tensors/compressors/int_quantized.py
@@ -78,7 +78,11 @@ def compress(
args=quant_args,
dtype=torch.int8,
)

elif name.endswith("zero_point"):
if torch.all(value == 0):
# all zero_points are 0, no need to include in
# compressed state_dict
continue
compressed_dict[name] = value.to("cpu")

return compressed_dict
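
The new branch drops zero-point tensors that are entirely zero, the common case for symmetric quantization; the decompress side below restores them as zeros. A minimal standalone sketch of that filtering, using a hypothetical strip_zero_points helper rather than the compressor class itself:

import torch

def strip_zero_points(state_dict):
    """Drop all-zero zero_point tensors; everything else moves to CPU."""
    compressed = {}
    for name, value in state_dict.items():
        if name.endswith("zero_point") and torch.all(value == 0):
            # symmetric quantization: an all-zero zero_point carries no information
            continue
        compressed[name] = value.to("cpu")
    return compressed

state_dict = {
    "layer.weight_scale": torch.tensor([0.02]),
    "layer.weight_zero_point": torch.zeros(1, dtype=torch.int8),
}
assert "layer.weight_zero_point" not in strip_zero_points(state_dict)
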
@@ -106,10 +110,16 @@ def decompress(
with safe_open(safe_path, framework="pt", device=device) as f:
weight_data[param_name] = f.get_tensor(full_name)

if len(weight_data) == len(self.COMPRESSION_PARAM_NAMES):
if "weight_scale" in weight_data:
zero_point = weight_data.get("weight_zero_point", None)
scale = weight_data["weight_scale"]
if zero_point is None:
# zero_point assumed to be 0 if not included in state_dict
zero_point = torch.zeros_like(scale)

decompressed = dequantize(
x_q=weight_data["weight"],
scale=weight_data["weight_scale"],
zero_point=weight_data["weight_zero_point"],
scale=scale,
zero_point=zero_point,
)
yield merge_names(weight_name, "weight"), decompressed
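
On load, the guard changes from requiring every entry in COMPRESSION_PARAM_NAMES to only requiring weight_scale, substituting a zero tensor when weight_zero_point was stripped at compression time. A hedged sketch of the same fallback using the standard affine dequantization formula; the library's dequantize helper is assumed to behave equivalently for the per-tensor case:

import torch

def dequantize_weight(weight_data):
    """Mirror the fallback: a missing zero_point is treated as all zeros."""
    scale = weight_data["weight_scale"]
    zero_point = weight_data.get("weight_zero_point")
    if zero_point is None:
        zero_point = torch.zeros_like(scale)
    # affine dequantization: x = (x_q - zero_point) * scale
    return (weight_data["weight"].to(scale.dtype) - zero_point) * scale

w_q = torch.tensor([[-4, 0, 4]], dtype=torch.int8)
restored = dequantize_weight({"weight": w_q, "weight_scale": torch.tensor(0.5)})
# tensor([[-2., 0., 2.]])
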
18 changes: 15 additions & 3 deletions src/compressed_tensors/compressors/pack_quantized.py
@@ -89,6 +89,12 @@ def compress(
compressed_dict[merge_names(prefix, "weight_packed")] = value
continue

elif name.endswith("zero_point"):
if torch.all(value == 0):
# all zero_points are 0, no need to include in
# compressed state_dict
continue

compressed_dict[name] = value.to("cpu")

return compressed_dict
@@ -116,14 +122,20 @@ def decompress(
with safe_open(safe_path, framework="pt", device=device) as f:
weight_data[param_name] = f.get_tensor(full_name)

if len(weight_data) == len(self.COMPRESSION_PARAM_NAMES):
if "weight_scale" in weight_data:
zero_point = weight_data.get("weight_zero_point", None)
scale = weight_data["weight_scale"]
if zero_point is None:
# zero_point assumed to be 0 if not included in state_dict
zero_point = torch.zeros_like(scale)

weight = weight_data["weight_packed"]
original_shape = torch.Size(weight_data["weight_shape"])
unpacked = unpack_4bit_ints(weight, original_shape)
decompressed = dequantize(
x_q=unpacked,
scale=weight_data["weight_scale"],
zero_point=weight_data["weight_zero_point"],
scale=scale,
zero_point=zero_point,
)
yield merge_names(weight_name, "weight"), decompressed

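
The packed path applies the same zero-point fallback, plus a 4-bit unpack step before dequantization. The bit layout used by unpack_4bit_ints is not shown in this hunk; the sketch below uses a simple two-values-per-byte layout purely to illustrate the round trip, and is not necessarily the library's packing scheme:

import torch

def pack_4bit_pairs(x):
    """Illustrative packing: two int4 values (range [-8, 7]) per uint8."""
    lo = (x[..., 0::2] + 8).to(torch.uint8)   # shift into [0, 15]
    hi = (x[..., 1::2] + 8).to(torch.uint8)
    return (hi << 4) | lo

def unpack_4bit_pairs(packed, shape):
    """Inverse of the sketch above; unpack_4bit_ints plays this role in the diff."""
    lo = (packed & 0x0F).to(torch.int16) - 8
    hi = (packed >> 4).to(torch.int16) - 8
    out = torch.stack([lo, hi], dim=-1).reshape(*packed.shape[:-1], -1)
    return out.reshape(shape).to(torch.int8)

x = torch.tensor([[-8, 7, 0, 3]], dtype=torch.int8)
packed = pack_4bit_pairs(x)
assert torch.equal(unpack_4bit_pairs(packed, x.shape), x)
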
2 changes: 1 addition & 1 deletion src/compressed_tensors/compressors/sparse_bitmask.py
@@ -67,7 +67,7 @@ def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
f"found an existing entry for {key}. The existing entry will "
"be replaced."
)
compressed_dict |= bitmask_dict
compressed_dict.update(bitmask_dict)

return compressed_dict

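
Swapping the in-place merge operator for dict.update keeps the behavior identical while also working on Python versions older than 3.9, where the dict merge operators from PEP 584 are unavailable. For reference:

d1 = {"layer.bitmask": 1}
d2 = {"layer.bitmask": 2, "layer.row_offsets": 3}
d1.update(d2)  # same result as `d1 |= d2`, but also valid before Python 3.9
assert d1 == {"layer.bitmask": 2, "layer.row_offsets": 3}
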
25 changes: 1 addition & 24 deletions src/compressed_tensors/quantization/lifecycle/forward.py
@@ -150,7 +150,6 @@ def _process_quantization(
q_min = torch.tensor(-bit_range / 2, device=x.device)
group_size = args.group_size

# group
if args.strategy == QuantizationStrategy.GROUP:

if do_dequantize: # if dequantizing the output should be a fp type
@@ -195,29 +194,7 @@
)
output[:, idx : (idx + group_size)] = _dequantize(input, sc, zp)

# channel-wise
elif args.strategy == QuantizationStrategy.CHANNEL: # group_size == -1
if do_quantize:
output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype)
if do_dequantize:
output = _dequantize(output if do_quantize else x, scale, zero_point)

# per-token
elif args.strategy == QuantizationStrategy.TOKEN:
# before: scale shape = [num_tokens]
# after: scale shape = [num_tokens, 1]
# x.shape = [1, num_tokens, 1]
# scale gets broadcast as expected without having [1, num_tokens, 1] shape

scale = scale.unsqueeze(1)
zero_point = zero_point.unsqueeze(1)

if do_quantize:
output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype)
if do_dequantize:
output = _dequantize(output if do_quantize else x, scale, zero_point)

else:
else: # covers channel, token and tensor strategies
if do_quantize:
output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype)
if do_dequantize:
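
The dedicated channel-wise and per-token branches collapse into the generic else branch because, once the observers return scales and zero points with singleton dimensions kept in place (see the observer changes below), ordinary broadcasting produces the same result. A rough sketch of that reliance on broadcasting, using a generic fake-quantize rather than the module's _quantize/_dequantize helpers:

import torch

def fake_quantize(x, scale, zero_point, q_min=-128, q_max=127):
    """Quantize then dequantize; the strategy is encoded entirely in the shapes."""
    x_q = torch.clamp(torch.round(x / scale + zero_point), q_min, q_max)
    return (x_q - zero_point) * scale

# channel strategy: weight [out, in], scale [out, 1] broadcasts across columns
w = torch.randn(16, 32)
w_scale = w.abs().amax(dim=1, keepdim=True) / 127
w_fq = fake_quantize(w, w_scale, torch.zeros_like(w_scale))

# token strategy: activations [batch, tokens, hidden], scale [batch, tokens, 1]
a = torch.randn(2, 4, 8)
a_scale = a.abs().amax(dim=-1, keepdim=True) / 127
a_fq = fake_quantize(a, a_scale, torch.zeros_like(a_scale))
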
50 changes: 24 additions & 26 deletions src/compressed_tensors/quantization/observers/base.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple
from typing import Any, Optional, Tuple

import torch
from compressed_tensors.quantization.quant_args import (
@@ -50,9 +50,16 @@ def forward(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
"""
return self.get_qparams(observed=observed)

def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
def calculate_qparams(
self,
observed: Tensor,
reduce_dims: Optional[Tuple[int]] = None,
) -> Tuple[FloatTensor, IntTensor]:
"""
:param observed: observed tensor to calculate quantization parameters for
:param reduce_dims: optional tuple of dimensions to reduce along,
returned scale and zero point will be shaped (1,) along the
reduced dimensions
:return: tuple of scale and zero point derived from the observed tensor
"""
raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")
@@ -70,6 +77,7 @@ def get_qparams(
Convenience function to wrap overwritten calculate_qparams
adds support to make observed tensor optional and support for tracking latest
calculated scale and zero point
:param observed: optional observed tensor to calculate quantization parameters
from
:return: tuple of scale and zero point based on last observed value
@@ -85,46 +93,36 @@
elif self.quantization_args.strategy == QuantizationStrategy.GROUP:
columns = observed.shape[1]
scales, zero_points = [], []
for i in range(0, columns, self.quantization_args.group_size):
group_idxs = range(0, columns, self.quantization_args.group_size)
for group_id, group_idx in enumerate(group_idxs):
scale, zero_point = self.get_qparams_along_dim(
observed[:, i : (i + group_size)],
observed[:, group_idx : (group_idx + group_size)],
0,
tensor_id=group_id,
)
scales.append(scale)
zero_points.append(zero_point)
self._scale = torch.stack(scales, dim=1, out=self._scale)
self._zero_point = torch.stack(zero_points, dim=1, out=self._zero_point)

self._scale = torch.cat(scales, dim=1, out=self._scale)
self._zero_point = torch.cat(zero_points, dim=1, out=self._zero_point)

elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
# assume observed is transposed, because it's the output, hence use dim 0
self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)

elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:

# use dim 1, assume the observed.shape = [batch, token, hidden]
# should be batch, token

self._scale, self._zero_point = self.get_qparams_along_dim(
observed, dim=1
)

return self._scale, self._zero_point

def get_qparams_along_dim(self, observed, dim: int):
# TODO: add documentation that specifies the shape must
# be padded with 1-dims so the scales are along the right channel
# TODO: generalize the logic for reduce_dims
scales, zero_points = [], []

# TODO: make a more generic way to get the channel
num_dims = observed.shape[dim]

for dim_idx in range(num_dims):
scale, zero_point = self.calculate_qparams(
observed.select(dim=dim, index=dim_idx)
)

scales.append(scale)
zero_points.append(zero_point)
# breakpoint()
return torch.stack(scales), torch.stack(zero_points)
def get_qparams_along_dim(
self, observed, dim: int, tensor_id: Optional[Any] = None
):
reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
return self.calculate_qparams(
observed, reduce_dims=reduce_dims, tensor_id=tensor_id
)
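
get_qparams_along_dim now reduces every dimension except the one of interest and defers to calculate_qparams, and the per-group scales are concatenated rather than stacked because each group's qparams already carry a trailing singleton dimension. A small shape walkthrough under those assumptions:

import torch

observed = torch.randn(8, 256)          # e.g. rows x columns of a weight
group_size = 128

# qparams along dim 0: reduce everything except dim 0, keeping singleton dims
reduce_dims = tuple(i for i in range(observed.ndim) if i != 0)   # -> (1,)
min_val = torch.amin(observed[:, :group_size], dim=reduce_dims, keepdim=True)
max_val = torch.amax(observed[:, :group_size], dim=reduce_dims, keepdim=True)
print(min_val.shape)                    # torch.Size([8, 1]) for one group

# one [8, 1] scale per group concatenates into [8, num_groups]
group_scales = [torch.rand(8, 1) for _ in range(observed.shape[1] // group_size)]
print(torch.cat(group_scales, dim=1).shape)   # torch.Size([8, 2])
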
26 changes: 17 additions & 9 deletions src/compressed_tensors/quantization/observers/memoryless.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple
from typing import Any, Optional, Tuple

import torch
from compressed_tensors.quantization.observers.base import Observer
@@ -30,19 +30,27 @@ class MemorylessObserver(Observer):
zero point based on the latest observed value without tracking state
"""

def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
def calculate_qparams(
self,
observed: Tensor,
tensor_id: Optional[Any] = None,
reduce_dims: Optional[Tuple[int]] = None,
) -> Tuple[FloatTensor, IntTensor]:
"""
Returns the min and max values of observed
Returns the min and max values of observed tensor
:param observed: observed tensor to calculate quantization parameters for
:param tensor_id: optional id for tensor; not used for memoryless
:param reduce_dims: optional tuple of dimensions to reduce along,
returned scale and zero point will be shaped (1,) along the
reduced dimensions
:return: tuple of scale and zero point derived from the observed tensor
"""
# TODO: Add support for full range of quantization Args, only supports 8bit
# per tensor
min_val, max_val = torch.aminmax(observed)

# ensure zero is in the range
min_val = torch.min(min_val, torch.zeros_like(min_val))
max_val = torch.max(max_val, torch.zeros_like(max_val))
if not reduce_dims:
min_val, max_val = torch.aminmax(observed)
else:
min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)

return calculate_qparams(min_val, max_val, self.quantization_args)
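
With reduce_dims supplied, the observer computes per-slice minima and maxima instead of a single global pair; the scale and zero point then come from calculate_qparams. A generic sketch of that derivation for asymmetric int8, offered as a stand-in for the library helper rather than its exact implementation:

import torch

def qparams_from_min_max(min_val, max_val, num_bits=8):
    """Generic asymmetric qparams; zero is forced into the representable range."""
    q_min, q_max = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
    min_val = torch.clamp(min_val, max=0.0)
    max_val = torch.clamp(max_val, min=0.0)
    scale = torch.clamp((max_val - min_val) / (q_max - q_min), min=1e-8)
    zero_point = torch.round(q_min - min_val / scale).to(torch.int8)
    return scale, zero_point

x = torch.randn(4, 16)
# per-row statistics; keepdim=True so the qparams broadcast back over x
min_val = torch.amin(x, dim=1, keepdim=True)
max_val = torch.amax(x, dim=1, keepdim=True)
scale, zero_point = qparams_from_min_max(min_val, max_val)
print(scale.shape, zero_point.shape)    # torch.Size([4, 1]) torch.Size([4, 1])
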
42 changes: 29 additions & 13 deletions src/compressed_tensors/quantization/observers/min_max.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple
from typing import Any, Optional, Tuple

import torch
from compressed_tensors.quantization.observers.base import Observer
@@ -36,14 +36,15 @@ def __init__(
):
super().__init__(quantization_args=quantization_args)

self.min_val = None
self.max_val = None
self.min_val = {}
self.max_val = {}
self.averaging_constant = averaging_constant

def calculate_qparams(
self,
observed: Tensor,
reduce_dims: Optional[Tuple[int]] = None,
tensor_id: Optional[Any] = None,
) -> Tuple[FloatTensor, IntTensor]:
"""
Updates the observed min and max using a moving average smoothed by the
@@ -53,28 +54,43 @@ def calculate_qparams(
:param reduce_dims: optional tuple of dimensions to reduce along,
returned scale and zero point will be shaped (1,) along the
reduced dimensions
:param tensor_id: Optional id if different ranges of observed tensors are
passed, useful for sharding tensors by group_size
:return: tuple of scale and zero point derived from the observed tensor
"""
tensor_id = tensor_id or "default"

if not reduce_dims:
min_val, max_val = torch.aminmax(observed)
else:
min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)

if self.min_val is None and self.max_val is None:
self.min_val = min_val
self.max_val = max_val
running_min_val = self.min_val.get(tensor_id, None)
running_max_val = self.max_val.get(tensor_id, None)

if running_min_val is None or running_max_val is None:
updated_min_val = min_val
updated_max_val = max_val
else:
self.min_val = self.min_val + self.averaging_constant * (
min_val - self.min_val
updated_min_val = running_min_val + self.averaging_constant * (
min_val - running_min_val
)
self.max_val = self.max_val + self.averaging_constant * (
max_val - self.max_val
updated_max_val = running_max_val + self.averaging_constant * (
max_val - running_max_val
)

return calculate_qparams(self.min_val, self.max_val, self.quantization_args)
self.min_val[tensor_id] = updated_min_val
self.max_val[tensor_id] = updated_max_val

return calculate_qparams(
updated_min_val, updated_max_val, self.quantization_args
)

def get_qparams_along_dim(self, observed, dim: int):
def get_qparams_along_dim(
self, observed, dim: int, tensor_id: Optional[Any] = None
):
reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
return self.calculate_qparams(observed, reduce_dims=reduce_dims)
return self.calculate_qparams(
observed, reduce_dims=reduce_dims, tensor_id=tensor_id
)
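
Because a group-quantized weight now calls calculate_qparams once per group, the running min/max switch from single tensors to dictionaries keyed by tensor_id, so each group keeps its own moving average. A compact sketch of that bookkeeping, written outside the observer class:

import torch

class RunningMinMax:
    """Per-tensor_id exponential moving average of observed min/max values."""

    def __init__(self, averaging_constant=0.01):
        self.min_val, self.max_val = {}, {}
        self.c = averaging_constant

    def update(self, observed, tensor_id="default"):
        min_val, max_val = torch.aminmax(observed)
        if tensor_id not in self.min_val:
            self.min_val[tensor_id], self.max_val[tensor_id] = min_val, max_val
        else:
            self.min_val[tensor_id] += self.c * (min_val - self.min_val[tensor_id])
            self.max_val[tensor_id] += self.c * (max_val - self.max_val[tensor_id])
        return self.min_val[tensor_id], self.max_val[tensor_id]

tracker = RunningMinMax()
for step in range(10):
    for group_id in range(4):           # each group tracks its own statistics
        tracker.update(torch.randn(8, 128), tensor_id=group_id)
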
25 changes: 21 additions & 4 deletions src/compressed_tensors/quantization/quant_config.py
@@ -13,11 +13,14 @@
# limitations under the License.

from enum import Enum
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union

from compressed_tensors.base import QUANTIZATION_CONFIG_NAME
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from compressed_tensors.quantization.quant_scheme import (
QuantizationScheme,
preset_name_to_scheme,
)
from compressed_tensors.quantization.utils import (
calculate_compression_ratio,
is_module_quantized,
@@ -105,7 +108,8 @@ class QuantizationConfig(BaseModel):
mapped to a QuantizationScheme in config_groups.
:param config_groups: dict of QuantizationSchemes specifying the quantization
settings for each quantized layer
settings for each quantized layer. A group could also be a reference to
a predefined scheme name, mapped to a list of its target layers/classes
:param quant_method: a constant used to differentiate sparseML quantization from
other quantization configs
:param format: specifies how the quantized model is stored on disk
@@ -117,13 +121,26 @@
are not quantized even if they match up with a target in config_groups
"""

config_groups: Dict[str, QuantizationScheme]
config_groups: Dict[str, Union[QuantizationScheme, List[str]]]
quant_method: str = "sparseml"
format: str = "fakequant"
quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
global_compression_ratio: Optional[float] = None
ignore: Optional[List[str]] = Field(default_factory=list)

def model_post_init(self, __context):
"""
updates any quantization schemes defined as presets to be fully loaded
schemes
"""
for group_name, targets_or_scheme in self.config_groups.items():
if isinstance(targets_or_scheme, QuantizationScheme):
continue # scheme already defined
self.config_groups[group_name] = preset_name_to_scheme(
name=group_name,
targets=targets_or_scheme,
)

@staticmethod
def from_model_config(model_name_or_path) -> "QuantizationConfig":
"""
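
With the Union type and the model_post_init hook, a config group may be written as preset-name mapped to a list of targets and is expanded to a full QuantizationScheme on construction. A usage sketch; the preset name "W8A8" is an assumption about the catalog behind preset_name_to_scheme, and QuantizationConfig is imported from the module shown above:

from compressed_tensors.quantization.quant_config import QuantizationConfig

config = QuantizationConfig(
    config_groups={
        # a list value marks the key as a preset name rather than a full scheme;
        # "W8A8" is a hypothetical preset used only for this sketch
        "W8A8": ["Linear"],
    },
    ignore=["lm_head"],
)

# model_post_init replaces the target list with the expanded QuantizationScheme
print(type(config.config_groups["W8A8"]).__name__)   # QuantizationScheme
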