diff --git a/examples/bitmask_compression.ipynb b/examples/bitmask_compression.ipynb new file mode 100644 index 00000000..7658a67a --- /dev/null +++ b/examples/bitmask_compression.ipynb @@ -0,0 +1,252 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Bitmask Compression Example ##\n", + "\n", + "Bitmask compression allows sparse tensors to be stored efficiently on disk. \n", + "\n", + "Instead of storing every zero element explicitly as a number, we use a bitmask to indicate which tensor entries are zero. This approach is useful when the matrix is mostly zeros, as it saves space by not wastefully storing those zeros one by one.\n", + "\n", + "The example below shows how to save and load sparse tensors using bitmask compression. It also demonstrates the benefits of bitmask compression over the \"dense\" representation and, finally, introduces the enhanced `safetensors` file format for storing sparse weights." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import os\n", + "from safetensors import safe_open\n", + "from safetensors.torch import save_model\n", + "from compressed_tensors import save_compressed_model, load_compressed, BitmaskConfig\n", + "from transformers import AutoModelForCausalLM" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "LlamaForCausalLM(\n", + " (model): LlamaModel(\n", + " (embed_tokens): Embedding(32000, 768)\n", + " (layers): ModuleList(\n", + " (0-11): 12 x LlamaDecoderLayer(\n", + " (self_attn): LlamaSdpaAttention(\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=False)\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=False)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=False)\n", + " (o_proj): Linear(in_features=768, out_features=768, bias=False)\n", + " (rotary_emb): LlamaRotaryEmbedding()\n", + " )\n", + " (mlp): LlamaMLP(\n", + " (gate_proj): Linear(in_features=768, out_features=2048, bias=False)\n", + " (up_proj): Linear(in_features=768, out_features=2048, bias=False)\n", + " (down_proj): Linear(in_features=2048, out_features=768, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): LlamaRMSNorm()\n", + " (post_attention_layernorm): LlamaRMSNorm()\n", + " )\n", + " )\n", + " (norm): LlamaRMSNorm()\n", + " )\n", + " (lm_head): Linear(in_features=768, out_features=32000, bias=False)\n", + ")" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# load a tiny, pruned llama2 model\n", + "model_name = \"neuralmagic/llama2.c-stories110M-pruned50\"\n", + "model = AutoModelForCausalLM.from_pretrained(model_name)\n", + "model" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The example layer model.layers.0.self_attn.q_proj.weight has sparsity 50.00%\n" + ] + } + ], + "source": [ + "# most of the weights of the model are pruned to 50% sparsity (except for a few layers such as lm_head or embeddings)\n", + "state_dict = model.state_dict()\n", + "state_dict.keys()\n", + "example_layer = \"model.layers.0.self_attn.q_proj.weight\"\n", + "print(f\"The example layer {example_layer} has sparsity {torch.sum(state_dict[example_layer] == 0).item() / state_dict[example_layer].numel():.2%}\")" + ] + }, + { + "cell_type": 
"code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The model is 31.67% sparse overall\n" + ] + } + ], + "source": [ + "# we can inspect to total sparisity of the state_dict\n", + "total_num_parameters = 0\n", + "total_num_zero_parameters = 0\n", + "for key in state_dict:\n", + " total_num_parameters += state_dict[key].numel()\n", + " total_num_zero_parameters += state_dict[key].eq(0).sum().item()\n", + "print(f\"The model is {total_num_zero_parameters/total_num_parameters*100:.2f}% sparse overall\")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Compressing model: 100%|██████████| 111/111 [00:06<00:00, 17.73it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Size of the model's weights on disk using safetensors: 417.83 MB\n", + "Size of the model's weights on disk using compressed-tensors: 366.82 MB\n", + "The compression ratio is x1.14\n" + ] + } + ], + "source": [ + "# let's save the model on disk using safetensors and compressed-tensors and compare the size on disk\n", + "\n", + "## save the model using safetensors ##\n", + "save_model(model, \"model.safetensors\")\n", + "size_on_disk_mb = os.path.getsize('model.safetensors') / 1024 / 1024\n", + "\n", + "## save the model using compressed-tensors ##\n", + "save_compressed_model(model, \"compressed_model.safetensors\", compression_format=\"sparse-bitmask\")\n", + "compressed_size_on_disk_mb = os.path.getsize('compressed_model.safetensors') / 1024 / 1024\n", + "\n", + "print(f\"Size of the model's weights on disk using safetensors: {size_on_disk_mb:.2f} MB\")\n", + "print(f\"Size of the model's weights on disk using compressed-tensors: {compressed_size_on_disk_mb:.2f} MB\")\n", + "print(\"The compression ratio is x{:.2f}\".format(size_on_disk_mb / compressed_size_on_disk_mb))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Storing weights with around 30% of zero entries requires significantly less disk space when using `compressed-tensors`. The compression ratio improves radically for more sparse models. \n", + "\n", + "We can load back the `state_dict` from the compressed and uncompressed representation on disk and confirm, that they represent same tensors in memory." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Once loaded, the state_dicts from safetensors and compressed-tensors are equal: True\n" + ] + } + ], + "source": [ + "# load the safetensors and the compressed-tensors files and show that they hold the same tensors\n", + "\n", + "## load the uncompressed safetensors to memory ##\n", + "state_dict_1 = {}\n", + "with safe_open('model.safetensors', framework=\"pt\") as f:\n", + " for key in f.keys():\n", + " state_dict_1[key] = f.get_tensor(key)\n", + "\n", + "## load the compressed-tensors to memory ##\n", + "config = BitmaskConfig() # we need to specify the method for decompression\n", + "state_dict_2 = load_compressed(\"compressed_model.safetensors\", config)\n", + "\n", + "tensors_equal = all(torch.equal(state_dict_1[key], state_dict_2[key]) for key in state_dict_1)\n", + "\n", + "print(f\"Once loaded, the state_dicts from safetensors and compressed-tensors are equal: {tensors_equal}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### SafeTensors File Format\n", + "\n", + "The introduced bitmask compression is much more efficient because the information needed for decompression is stored directly in the `.safetensors` file.\n", + "For each parameter in the uncompressed `state_dict`, we store the following attributes needed for decompression in the compressed `state_dict`:\n", + "\n", + "* Compressed tensor\n", + "* Bitmask\n", + "* Uncompressed shape\n", + "* Row offsets\n", + "\n", + "```bash\n", + "# Dense\n", + "{\n", + " PARAM_NAME: uncompressed_tensor\n", + "}\n", + "\n", + "# Compressed\n", + "{\n", + " PARAM_NAME.compressed: compressed_tensor, # 1d tensor\n", + " PARAM_NAME.bitmask: value, # 2d bitmask tensor (nrows x (ncols / 8))\n", + " PARAM_NAME.shape: value, # Uncompressed shape tensor\n", + " PARAM_NAME.row_offsets: value # 1d offsets tensor\n", + "}\n", + "```" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/quantization.ipynb b/examples/quantization.ipynb new file mode 100644 index 00000000..5834dd37 --- /dev/null +++ b/examples/quantization.ipynb @@ -0,0 +1,346 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## QuantizationConfig Application ##\n", + "\n", + "`QuantizationConfig` allows for compressing the model on disk by reducing the precision of weights, e.g. from float16 to int8.\n", + "\n", + "In order to save a compressed (quantized) model, we:\n", + "\n", + "1. Create a \"vanilla\" model. For that purpose we are using a tiny, pruned Llama model\n", + "2. Define the arguments of the `QuantizationConfig`\n", + "3. Use the function `apply_quantization_config` to modify the model so that, for the relevant weight matrices, it carries parameters that simulate the quantize and dequantize operations (scale and zero point).\n", + "4. Calibrate the scales and zero points through a few forward passes over the calibration data\n", + "5. 
Using the obtained scales and zero points to quantize the weight matrices to `int8` representation.\n", + "\n", + "The example below shows how to quantize the model. It also demonstrates the benefits of the quantization over \"dense\" representation." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/nm/drive0/damian/compressed-tensors/.venv/lib/python3.10/site-packages/pydantic/_internal/_fields.py:186: UserWarning: Field name \"registry_requires_subclass\" shadows an attribute in parent \"RegistryMixin\"; \n", + " warnings.warn(\n", + "/nm/drive0/damian/compressed-tensors/.venv/lib/python3.10/site-packages/pydantic/_internal/_fields.py:186: UserWarning: Field name \"registry_requires_subclass\" shadows an attribute in parent \"CompressionConfig\"; \n", + " warnings.warn(\n", + "/nm/drive0/damian/compressed-tensors/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "from compressed_tensors.quantization import QuantizationConfig, apply_quantization_config, freeze_module_quantization\n", + "from transformers import AutoModelForCausalLM, AutoTokenizer\n", + "from datasets import load_dataset\n", + "from tqdm import tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "model_name = \"neuralmagic/llama2.c-stories110M-pruned50\" # model to quantize and calibrate\n", + "dataset_name = \"garage-bAInd/Open-Platypus\" # dataset to calibrate on\n", + "num_calibration_samples = 1 # num calibration samples to calibrate on" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "LlamaForCausalLM(\n", + " (model): LlamaModel(\n", + " (embed_tokens): Embedding(32000, 768)\n", + " (layers): ModuleList(\n", + " (0-11): 12 x LlamaDecoderLayer(\n", + " (self_attn): LlamaSdpaAttention(\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=False)\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=False)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=False)\n", + " (o_proj): Linear(in_features=768, out_features=768, bias=False)\n", + " (rotary_emb): LlamaRotaryEmbedding()\n", + " )\n", + " (mlp): LlamaMLP(\n", + " (gate_proj): Linear(in_features=768, out_features=2048, bias=False)\n", + " (up_proj): Linear(in_features=768, out_features=2048, bias=False)\n", + " (down_proj): Linear(in_features=2048, out_features=768, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): LlamaRMSNorm()\n", + " (post_attention_layernorm): LlamaRMSNorm()\n", + " )\n", + " )\n", + " (norm): LlamaRMSNorm()\n", + " )\n", + " (lm_head): Linear(in_features=768, out_features=32000, bias=False)\n", + ")" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Step 1: Load the model\n", + "model = AutoModelForCausalLM.from_pretrained(model_name)\n", + "model.eval()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Module: layers.0.self_attn.q_proj has been quantized: False\n" + ] + } + ], + "source": [ + "for name, module in 
model.model.named_modules():\n", + " module_type = module.__class__.__name__\n", + " if module_type == \"Linear\":\n", + " is_quantized = hasattr(module, \"quantization_scheme\")\n", + " print(f\"Module: {name} has been quantized: {is_quantized}\")\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "# Step 2: Define the quantization configuration\n", + "quantization_config = QuantizationConfig(\n", + " # \"fakequant\" means that the weights are still in their original format\n", + " # (e.g. float16), but quantization is emulated by adding scales/zeropoints \n", + " # to the model state_dict\n", + " format = \"fakequant\",\n", + " # \"initialize\" means that scale/zeropoints and observers have been attached \n", + " # to the layer but are set to dummy values (not yet calibrated)\n", + " quantization_status = \"calibration\",\n", + " config_groups ={\n", + " # \"group_1\" acts on all the nn.Linear layers of the model\n", + " # it quantizes the weights to 8 bits (symmetric, so zero_point = 0) \n", + " # it quantizes the input activations to 8 bits (asymmetric, so zero_point != 0)\n", + " \"group_1\": {\n", + " \"weights\": {\n", + " \"num_bits\": 8,\n", + " \"type\": \"int\",\n", + " \"symmetric\": True,\n", + " \"strategy\": \"tensor\"\n", + " },\n", + " \"input_activations\": {\n", + " \"num_bits\": 8,\n", + " \"type\": \"int\",\n", + " \"symmetric\": False,\n", + " \"strategy\": \"tensor\"\n", + " },\n", + " \"targets\": [\"Linear\"]\n", + " },\n", + " },\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "# Step 3: Apply the quantization configuration\n", + "apply_quantization_config(model, quantization_config)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Module: layers.0.self_attn.q_proj has been quantized: True\n", + "Input activation scale: Parameter containing:\n", + "tensor([])\n", + "Input activation zero point: Parameter containing:\n", + "tensor([], dtype=torch.int64)\n", + "Weight scale: Parameter containing:\n", + "tensor([])\n", + "Weight zero point Parameter containing:\n", + "tensor([], dtype=torch.int64)\n" + ] + } + ], + "source": [ + "for name, module in model.model.named_modules():\n", + " module_type = module.__class__.__name__\n", + " if module_type == \"Linear\":\n", + " is_quantized = hasattr(module, \"quantization_scheme\")\n", + " print(f\"Module: {name} has been quantized: {is_quantized}\")\n", + " print(f\"Input activation scale: {module.input_scale}\")\n", + " print(f\"Input activation zero point: {module.input_zero_point}\")\n", + " print(f\"Weight scale: {module.weight_scale}\")\n", + " print(f\"Weight zero point {module.weight_zero_point}\")\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1/1 [01:23<00:00, 83.74s/it]\n" + ] + } + ], + "source": [ + "# Step 4: Calibrate the quantized layers\n", + "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", + "tokenizer.pad_token = tokenizer.eos_token\n", + "dataset = load_dataset(dataset_name, split='train', streaming=True)\n", + "\n", + "# run calibration\n", + "for idx, sample in tqdm(enumerate(dataset), total = num_calibration_samples):\n", + " sample = tokenizer(sample['output'], return_tensors=\"pt\")\n", + " 
_ = model(**sample)\n", + "\n", + " if idx >= num_calibration_samples:\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Module: layers.0.self_attn.q_proj has been quantized: True\n", + "Input activation scale: 0.02618597261607647\n", + "Input activation zero point: 85\n", + "Weight scale: 0.008901744149625301\n", + "Weight zero point 0\n" + ] + } + ], + "source": [ + "for name, module in model.model.named_modules():\n", + " module_type = module.__class__.__name__\n", + " if module_type == \"Linear\":\n", + " is_quantized = hasattr(module, \"quantization_scheme\")\n", + " print(f\"Module: {name} has been quantized: {is_quantized}\")\n", + " print(f\"Input activation scale: {module.input_scale.item()}\")\n", + " print(f\"Input activation zero point: {module.input_zero_point.item()}\")\n", + " print(f\"Weight scale: {module.weight_scale.item()}\")\n", + " print(f\"Weight zero point {module.weight_zero_point.item()}\")\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "model.apply(freeze_module_quantization)\n", + "quantization_config.format = \"compressed\"\n", + "apply_quantization_config(model, quantization_config)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Module: layers.0.self_attn.q_proj has been quantized: True\n", + "Input activation scale: Parameter containing:\n", + "tensor([])\n", + "[Parameter containing:\n", + "tensor([[ 0.0000, 0.0000, -0.0712, ..., 0.0000, -0.0178, 0.0000],\n", + " [ 0.0000, 0.0979, 0.1246, ..., 0.1869, 0.0801, -0.1602],\n", + " [ 0.0000, -0.0712, 0.0000, ..., -0.2492, -0.1157, 0.2671],\n", + " ...,\n", + " [-0.0801, -0.0801, 0.0000, ..., 0.0000, 0.0000, 0.1068],\n", + " [-0.1157, -0.0801, -0.0712, ..., 0.0890, 0.0712, 0.0890],\n", + " [-0.1068, 0.0445, 0.0000, ..., 0.0712, 0.1246, 0.0534]],\n", + " requires_grad=True), Parameter containing:\n", + "tensor([]), Parameter containing:\n", + "tensor([], dtype=torch.int64), Parameter containing:\n", + "tensor([]), Parameter containing:\n", + "tensor([], dtype=torch.int64)]\n" + ] + } + ], + "source": [ + "for name, module in model.model.named_modules():\n", + " module_type = module.__class__.__name__\n", + " if module_type == \"Linear\":\n", + " is_quantized = hasattr(module, \"quantization_scheme\")\n", + " print(f\"Module: {name} has been quantized: {is_quantized}\")\n", + " print(f\"Input activation scale: {module.input_scale}\")\n", + " print(list(module.parameters()))\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/quantization_compression.ipynb b/examples/quantization_compression.ipynb new file mode 100644 index 00000000..d7cffc57 --- /dev/null +++ b/examples/quantization_compression.ipynb @@ -0,0 +1,344 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": 
[ + "## QuantizationConfig Application ##\n", + "\n", + "`QuantizationConfig` allows for compressing the model on disk by reducing the precision of weights e.g. from float16 to int8.\n", + "\n", + "In order to save a compressed (quantized)\n", + "\n", + "1. Create a \"vanilla\" model. For that purpose we are using a `TinyLlaMa` model\n", + "2. Define the arguments of the `QuantizationConfig`\n", + "3. Use the function `apply_quantization_config` to modify the model to, for the relevant weight matrices, add parameters that simulate the quantize and dequantize operations (scale and zero point).\n", + "4. Calibrate the scale and zero point through few forward passes of the calibration data\n", + "5. Using the obtained scales and zero points to quantize the weight matrices to `int8` representation.\n", + "\n", + "The example below shows how to quantize the model. It also demonstrates the benefits of the quantization over \"dense\" representation." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/nm/drive0/damian/compressed-tensors/.venv/lib/python3.10/site-packages/pydantic/_internal/_fields.py:186: UserWarning: Field name \"registry_requires_subclass\" shadows an attribute in parent \"RegistryMixin\"; \n", + " warnings.warn(\n", + "/nm/drive0/damian/compressed-tensors/.venv/lib/python3.10/site-packages/pydantic/_internal/_fields.py:186: UserWarning: Field name \"registry_requires_subclass\" shadows an attribute in parent \"CompressionConfig\"; \n", + " warnings.warn(\n", + "/nm/drive0/damian/compressed-tensors/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "from compressed_tensors.quantization import QuantizationConfig, apply_quantization_config, freeze_module_quantization, set_module_for_calibration\n", + "from transformers import AutoModelForCausalLM, AutoTokenizer\n", + "from datasets import load_dataset\n", + "from tqdm import tqdm\n", + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "model_name = \"neuralmagic/llama2.c-stories110M-pruned50\" # model to quantize and calibrate\n", + "dataset_name = \"roneneldan/TinyStories\" # dataset to calibrate on\n", + "num_calibration_samples = 256 # num calibration samples to calibrate on\n", + "batch_size = 32 # batch size for calibration" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "LlamaForCausalLM(\n", + " (model): LlamaModel(\n", + " (embed_tokens): Embedding(32000, 768)\n", + " (layers): ModuleList(\n", + " (0-11): 12 x LlamaDecoderLayer(\n", + " (self_attn): LlamaSdpaAttention(\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=False)\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=False)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=False)\n", + " (o_proj): Linear(in_features=768, out_features=768, bias=False)\n", + " (rotary_emb): LlamaRotaryEmbedding()\n", + " )\n", + " (mlp): LlamaMLP(\n", + " (gate_proj): Linear(in_features=768, out_features=2048, bias=False)\n", + " (up_proj): Linear(in_features=768, out_features=2048, bias=False)\n", + " (down_proj): Linear(in_features=2048, 
out_features=768, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): LlamaRMSNorm()\n", + " (post_attention_layernorm): LlamaRMSNorm()\n", + " )\n", + " )\n", + " (norm): LlamaRMSNorm()\n", + " )\n", + " (lm_head): Linear(in_features=768, out_features=32000, bias=False)\n", + ")" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Step 1: Load the model\n", + "model = AutoModelForCausalLM.from_pretrained(model_name)\n", + "model.to(\"cuda\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Module: layers.0.self_attn.q_proj has been quantized: False\n" + ] + } + ], + "source": [ + "def is_first_linear_layer_quantized(model: \"torch.nn.Module\"):\n", + " \"\"\"\n", + " Helper function that helps us determine if the first linear layer \n", + " in the model has been quantized (proxy for quantization state of the\n", + " whole model)\n", + " \"\"\"\n", + " for name, module in model.model.named_modules():\n", + " module_type = module.__class__.__name__\n", + " if module_type == \"Linear\":\n", + " is_quantized = hasattr(module, \"quantization_scheme\")\n", + " print(f\"Module: {name} has been quantized: {is_quantized}\")\n", + " if is_quantized:\n", + " print(f\"Input activation scale: {module.input_scale}\")\n", + " print(f\"Input activation zero point: {module.input_zero_point}\")\n", + " print(f\"Weight scale: {module.weight_scale}\")\n", + " print(f\"Weight zero point {module.weight_zero_point}\")\n", + " break\n", + "is_first_linear_layer_quantized(model)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# Step 2: Define the quantization configuration\n", + "quantization_config = QuantizationConfig(\n", + " # \"fakequant\" means that the weights are still in their original format\n", + " # (e.g. 
float16), but quantization is emulated by adding scales/zeropoints \n", + " # to the model state_dict\n", + " format = \"fakequant\",\n", + " # \"initialize\" means that scale/zeropoints and observers have been attached \n", + " # to the layer but are set to dummy values (not yet calibrated)\n", + " quantization_status = \"calibration\",\n", + " config_groups ={\n", + " # \"group_1\" acts on all the nn.Linear layers of the model\n", + " # it quantizes the weights to 8 bits (symmetric, so zero_point = 0) \n", + " # it quantizes the input activations to 8 bits (asymmetric, so zero_point != 0)\n", + " \"group_1\": {\n", + " \"weights\": {\n", + " \"num_bits\": 8,\n", + " \"type\": \"int\",\n", + " \"symmetric\": True,\n", + " \"strategy\": \"tensor\"\n", + " },\n", + " \"input_activations\": {\n", + " \"num_bits\": 8,\n", + " \"type\": \"int\",\n", + " \"symmetric\": False,\n", + " \"strategy\": \"tensor\"\n", + " },\n", + " \"targets\": [\"Linear\"]\n", + " },\n", + " },\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Module: layers.0.self_attn.q_proj has been quantized: True\n", + "Input activation scale: Parameter containing:\n", + "tensor([], device='cuda:0')\n", + "Input activation zero point: Parameter containing:\n", + "tensor([], device='cuda:0', dtype=torch.int64)\n", + "Weight scale: Parameter containing:\n", + "tensor([], device='cuda:0')\n", + "Weight zero point Parameter containing:\n", + "tensor([], device='cuda:0', dtype=torch.int64)\n" + ] + } + ], + "source": [ + "# Step 3: Apply the quantization configuration and set the model for calibration\n", + "apply_quantization_config(model, quantization_config)\n", + "is_first_linear_layer_quantized(model)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Repo card metadata block was not found. 
Setting CardData to empty.\n", + "100%|██████████| 8/8 [00:09<00:00, 1.23s/it]\n" + ] + } + ], + "source": [ + "# Step 4: Calibrate the quantized layers\n", + "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", + "tokenizer.pad_token = tokenizer.eos_token\n", + "dataset = load_dataset(dataset_name, split='train', streaming=True)\n", + "dataset = dataset.map(lambda x: tokenizer(x['text'], truncation=True, padding=True), batched=True)\n", + "data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)\n", + "\n", + "# # run calibration\n", + "for idx, sample in tqdm(enumerate(data_loader), total = num_calibration_samples // batch_size):\n", + " input_ids = torch.stack(sample[\"input_ids\"],axis=1).to(model.device)\n", + " attention_mask = torch.stack(sample[\"attention_mask\"],axis=1).to(model.device)\n", + "\n", + " _ = model(input_ids=input_ids, attention_mask=attention_mask)\n", + " \n", + " if idx >= num_calibration_samples // batch_size:\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Module: layers.0.self_attn.q_proj has been quantized: True\n", + "Input activation scale: Parameter containing:\n", + "tensor([0.0277], device='cuda:0')\n", + "Input activation zero point: Parameter containing:\n", + "tensor([69], device='cuda:0', dtype=torch.int8)\n", + "Weight scale: Parameter containing:\n", + "tensor([0.0089], device='cuda:0')\n", + "Weight zero point Parameter containing:\n", + "tensor(0, device='cuda:0', dtype=torch.int8)\n" + ] + } + ], + "source": [ + "model.apply(freeze_module_quantization)\n", + "is_first_linear_layer_quantized(model)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "model.apply(freeze_module_quantization)\n", + "quantization_config.format = \"compressed\"\n", + "apply_quantization_config(model, quantization_config)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Module: layers.0.self_attn.q_proj has been quantized: True\n", + "Input activation scale: Parameter containing:\n", + "tensor([], device='cuda:0')\n", + "[Parameter containing:\n", + "tensor([[ 0.0000, 0.0000, -0.0712, ..., 0.0000, -0.0178, 0.0000],\n", + " [ 0.0000, 0.0979, 0.1246, ..., 0.1869, 0.0801, -0.1602],\n", + " [ 0.0000, -0.0712, 0.0000, ..., -0.2492, -0.1157, 0.2671],\n", + " ...,\n", + " [-0.0801, -0.0801, 0.0000, ..., 0.0000, 0.0000, 0.1068],\n", + " [-0.1157, -0.0801, -0.0712, ..., 0.0890, 0.0712, 0.0890],\n", + " [-0.1068, 0.0445, 0.0000, ..., 0.0712, 0.1246, 0.0534]],\n", + " device='cuda:0', requires_grad=True), Parameter containing:\n", + "tensor([], device='cuda:0'), Parameter containing:\n", + "tensor([], device='cuda:0', dtype=torch.int64), Parameter containing:\n", + "tensor([], device='cuda:0'), Parameter containing:\n", + "tensor([], device='cuda:0', dtype=torch.int64)]\n" + ] + } + ], + "source": [ + "for name, module in model.model.named_modules():\n", + " module_type = module.__class__.__name__\n", + " if module_type == \"Linear\":\n", + " is_quantized = hasattr(module, \"quantization_scheme\")\n", + " print(f\"Module: {name} has been quantized: {is_quantized}\")\n", + " print(f\"Input activation scale: {module.input_scale}\")\n", + " print(list(module.parameters()))\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, 
+ "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py index 1c8dd29f..0153f3c5 100644 --- a/src/compressed_tensors/utils/helpers.py +++ b/src/compressed_tensors/utils/helpers.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging from pathlib import Path from typing import Dict, Literal, Optional, Union @@ -19,12 +20,20 @@ from compressed_tensors.compressors import ModelCompressor from compressed_tensors.config import CompressionConfig, CompressionFormat from safetensors import safe_open -from safetensors.torch import save_file +from safetensors.torch import save_file, save_model from torch import Tensor +from torch.nn import Module from transformers import AutoConfig -__all__ = ["infer_compressor_from_model_config", "load_compressed", "save_compressed"] +__all__ = [ + "infer_compressor_from_model_config", + "load_compressed", + "save_compressed", + "save_compressed_model", +] + +_LOGGER = logging.getLogger(__name__) def infer_compressor_from_model_config( @@ -109,7 +118,7 @@ def load_compressed( if compression_config is None: # no compression applied tensors = {} - with safe_open(compressed_tensors, framework="pt", device="cpu") as f: + with safe_open(compressed_tensors, framework="pt", device=device) as f: for key in f.keys(): tensors[key] = f.get_tensor(key) return tensors @@ -120,3 +129,35 @@ def load_compressed( compression_format, config=compression_config ) return dict(compressor.decompress(compressed_tensors)) + + +def save_compressed_model( + model: Module, + filename: str, + compression_format: Optional[CompressionFormat] = None, + force_contiguous: bool = True, +): + """ + Wrapper around safetensors `save_model` helper function, which allows for + saving compressed model to disk. Note: if compression_format is not None, + the model is assumed to have a state_dict with unique entries + + :param model: model to save on disk + :param filename: filename location to save the file + :param compression_format: compression format used for the model + :param force_contiguous: forcing the state_dict to be saved as contiguous tensors + """ + if compression_format is None: + # use the default save_model function from safetensors + save_model(model, filename, force_contiguous=force_contiguous) + return + + state_dict = model.state_dict() + if force_contiguous: + state_dict = {k: v.contiguous() for k, v in state_dict.items()} + try: + save_compressed(state_dict, filename, compression_format=compression_format) + except ValueError as e: + msg = str(e) + msg += " Or use save_compressed_model(..., force_contiguous=True), read the docs for potential caveats." # noqa E501 + raise ValueError(msg) diff --git a/tests/test_utils/test_helpers.py b/tests/test_utils/test_helpers.py index f643233c..9724d7fe 100644 --- a/tests/test_utils/test_helpers.py +++ b/tests/test_utils/test_helpers.py @@ -11,11 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. - import pytest import torch -from compressed_tensors import load_compressed, save_compressed +from compressed_tensors import load_compressed, save_compressed, save_compressed_model from compressed_tensors.config import BitmaskConfig +from safetensors import safe_open +from safetensors.torch import save_model +from transformers import AutoModelForCausalLM @pytest.fixture @@ -24,6 +26,13 @@ def tensors(): return tensors +@pytest.fixture +def llama_model(tmp_path): + model_name = "neuralmagic/llama2.c-stories110M-pruned50" + model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=tmp_path) + yield model + + def test_save_compressed_sparse_bitmask(tmp_path, tensors): save_compressed( tensors, @@ -102,3 +111,33 @@ def test_load_compressed_no_compression(tmp_path, tensors): loaded_tensors = load_compressed(tmp_path / "model.safetensors") for key in tensors: assert torch.allclose(tensors[key], loaded_tensors[key]) + + +def test_save_compressed_model(tmp_path, llama_model): + path_to_uncompressed = tmp_path / "model_uncompressed.safetensors" + path_to_compressed = tmp_path / "model_compressed.safetensors" + + # save uncompressed model + save_model(llama_model, path_to_uncompressed) + size_uncompressed_kb = path_to_uncompressed.stat().st_size / 1024 + + # save compressed model + save_compressed_model( + llama_model, path_to_compressed, compression_format="sparse-bitmask" + ) + size_compressed_kb = path_to_compressed.stat().st_size / 1024 + + # check that the two state_dicts are the same after loading + state_dict_1 = {} + with safe_open(path_to_uncompressed, framework="pt") as f: + for key in f.keys(): + state_dict_1[key] = f.get_tensor(key) + state_dict_2 = load_compressed( + path_to_compressed, BitmaskConfig(format="sparse-bitmask") + ) + assert all( + torch.allclose(state_dict_1[key], state_dict_2[key]) for key in state_dict_1 + ) + # make sure that the compressed model is smaller than the + # uncompressed one by roughly a factor of 1.14 (value established empirically) + assert pytest.approx(size_uncompressed_kb / size_compressed_kb, 0.01) == 1.14
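
For readers skimming this diff, here is a minimal, self-contained sketch of the sparse-bitmask layout described in the bitmask_compression notebook above (non-zero values + packed bitmask + row offsets + shape). It is an illustration only, not the compressor implementation shipped in compressed-tensors; the helper names `bitmask_compress` and `bitmask_decompress` are hypothetical.

```python
import numpy as np
import torch


def bitmask_compress(tensor: torch.Tensor):
    # Illustrative sketch of the layout, not the library's actual compressor.
    mask = tensor != 0
    values = tensor[mask]  # 1d tensor of non-zero values, in row-major order
    # pack each row of the boolean mask into bytes: nrows x ceil(ncols / 8)
    bitmask = torch.from_numpy(np.packbits(mask.numpy(), axis=-1))
    nonzeros_per_row = mask.sum(dim=1)
    # start index of each row's values inside the flat `values` tensor
    row_offsets = torch.cumsum(nonzeros_per_row, 0) - nonzeros_per_row
    return values, bitmask, row_offsets, torch.tensor(tensor.shape)


def bitmask_decompress(values, bitmask, row_offsets, shape):
    nrows, ncols = shape.tolist()
    # unpack the bytes back into a boolean mask, trimming the padding bits
    mask = torch.from_numpy(np.unpackbits(bitmask.numpy(), axis=-1)[:, :ncols]).bool()
    decompressed = torch.zeros(nrows, ncols, dtype=values.dtype)
    # scatter the stored values back into their original positions;
    # row_offsets is part of the stored layout but is not needed here because
    # `values` is already in row-major order
    decompressed[mask] = values
    return decompressed


dense = torch.tensor([[0.0, 1.5, 0.0, 2.0],
                      [0.0, 0.0, 3.0, 0.0]])
assert torch.equal(bitmask_decompress(*bitmask_compress(dense)), dense)
```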