diff --git a/docs/source/metadata_parsing.mdx b/docs/source/metadata_parsing.mdx index 4dd14330..ed557963 100644 --- a/docs/source/metadata_parsing.mdx +++ b/docs/source/metadata_parsing.mdx @@ -92,40 +92,41 @@ export type SafetensorsShardedHeaders = Record; ### Python -In this example python script, we are parsing metadata of [gpt2](https://huggingface.co/gpt2/blob/main/model.safetensors). +[`huggingface_hub`](https://huggingface.co/docs/huggingface_hub) provides a Python API to parse safetensors metadata. +Use [`get_safetensors_metadata`](https://huggingface.co/docs/huggingface_hub/package_reference/hf_api#huggingface_hub.HfApi.get_safetensors_metadata) to get all safetensors metadata of a model. +Depending on if the model is sharded or not, one or multiple safetensors files will be parsed. ```python -import requests # pip install requests -import struct - -def parse_single_file(url): - # Fetch the first 8 bytes of the file - headers = {'Range': 'bytes=0-7'} - response = requests.get(url, headers=headers) - # Interpret the bytes as a little-endian unsigned 64-bit integer - length_of_header = struct.unpack('>> from huggingface_hub import get_safetensors_metadata + +# Parse repo with single weights file +>>> metadata = get_safetensors_metadata("bigscience/bloomz-560m") +>>> metadata +SafetensorsRepoMetadata( + metadata=None, + sharded=False, + weight_map={'h.0.input_layernorm.bias': 'model.safetensors', ...}, + files_metadata={'model.safetensors': SafetensorsFileMetadata(...)} +) +>>> metadata.files_metadata["model.safetensors"].metadata +{'format': 'pt'} + +# Parse repo with sharded model (i.e. multiple weights files) +>>> metadata = get_safetensors_metadata("bigscience/bloom") +Parse safetensors files: 100%|██████████████████████████████████████████| 72/72 [00:12<00:00, 5.78it/s] +>>> metadata +SafetensorsRepoMetadata(metadata={'total_size': 352494542848}, sharded=True, weight_map={...}, files_metadata={...}) +>>> len(metadata.files_metadata) +72 # All safetensors files have been fetched + +# Parse repo that is not a safetensors repo +>>> get_safetensors_metadata("runwayml/stable-diffusion-v1-5") +NotASafetensorsRepoError: 'runwayml/stable-diffusion-v1-5' is not a safetensors repo. Couldn't find 'model.safetensors.index.json' or 'model.safetensors' files. ``` +To parse the metadata of a single safetensors file, use [`parse_safetensors_file_metadata`](https://huggingface.co/docs/huggingface_hub/package_reference/hf_api#huggingface_hub.HfApi.parse_safetensors_file_metadata). + + ## Example output For instance, here are the number of params per dtype for a few models on the HuggingFace Hub. Also see [this issue](https://github.com/huggingface/safetensors/issues/44) for more examples of usage.