Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Document huggingface_hub.get_safetensors_metadata #417

Merged
merged 5 commits into from
Jan 4, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 31 additions & 30 deletions docs/source/metadata_parsing.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -92,40 +92,41 @@ export type SafetensorsShardedHeaders = Record<FileName, SafetensorsFileHeader>;

### Python

In this example python script, we are parsing metadata of [gpt2](https://huggingface.co/gpt2/blob/main/model.safetensors).
[`huggingface_hub`](https://huggingface.co/docs/huggingface_hub) provides a Python API to parse safetensors metadata.
Use [`get_safetensors_metadata`](https://huggingface.co/docs/huggingface_hub/package_reference/hf_api#huggingface_hub.HfApi.get_safetensors_metadata) to get all safetensors metadata of a model.
Depending on if the model is sharded or not, one or multiple safetensors files will be parsed.

```python
import requests # pip install requests
import struct

def parse_single_file(url):
# Fetch the first 8 bytes of the file
headers = {'Range': 'bytes=0-7'}
response = requests.get(url, headers=headers)
# Interpret the bytes as a little-endian unsigned 64-bit integer
length_of_header = struct.unpack('<Q', response.content)[0]
# Fetch length_of_header bytes starting from the 9th byte
headers = {'Range': f'bytes=8-{7 + length_of_header}'}
response = requests.get(url, headers=headers)
# Interpret the response as a JSON object
header = response.json()
return header

url = "https://huggingface.co/gpt2/resolve/main/model.safetensors"
header = parse_single_file(url)
Comment on lines -101 to -115
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think this was still valuable to show how it actually works, no? potentially mentioning the huggingface_hub support at the end

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Linking @mishig25's PR: #424 (thanks about that 🙏)


print(header)
# {
# "__metadata__": { "format": "pt" },
# "h.10.ln_1.weight": {
# "dtype": "F32",
# "shape": [768],
# "data_offsets": [223154176, 223157248]
# },
# ...
# }
>>> from huggingface_hub import get_safetensors_metadata

# Parse repo with single weights file
>>> metadata = get_safetensors_metadata("bigscience/bloomz-560m")
Wauplin marked this conversation as resolved.
Show resolved Hide resolved
>>> metadata
SafetensorsRepoMetadata(
metadata=None,
sharded=False,
weight_map={'h.0.input_layernorm.bias': 'model.safetensors', ...},
files_metadata={'model.safetensors': SafetensorsFileMetadata(...)}
)
>>> metadata.files_metadata["model.safetensors"].metadata
{'format': 'pt'}

# Parse repo with sharded model (i.e. multiple weights files)
>>> metadata = get_safetensors_metadata("bigscience/bloom")
Parse safetensors files: 100%|██████████████████████████████████████████| 72/72 [00:12<00:00, 5.78it/s]
>>> metadata
SafetensorsRepoMetadata(metadata={'total_size': 352494542848}, sharded=True, weight_map={...}, files_metadata={...})
>>> len(metadata.files_metadata)
72 # All safetensors files have been fetched

# Parse repo that is not a safetensors repo
>>> get_safetensors_metadata("runwayml/stable-diffusion-v1-5")
NotASafetensorsRepoError: 'runwayml/stable-diffusion-v1-5' is not a safetensors repo. Couldn't find 'model.safetensors.index.json' or 'model.safetensors' files.
```

To parse the metadata of a single safetensors file, use [`parse_safetensors_file_metadata`](https://huggingface.co/docs/huggingface_hub/package_reference/hf_api#huggingface_hub.HfApi.parse_safetensors_file_metadata).


## Example output

For instance, here are the number of params per dtype for a few models on the HuggingFace Hub. Also see [this issue](https://github.com/huggingface/safetensors/issues/44) for more examples of usage.
Expand Down
Loading