docs: artefact_management_doc (#2504)

Add Deep Learning Artefact Management Docs to Lance Docs.

tanaymeh authored Jun 25, 2024
1 parent f51c5f0 commit 06a6a09

Showing 2 changed files with 178 additions and 1 deletion.

176 changes: 176 additions & 0 deletions docs/examples/artefact_management.rst

Deep Learning Artefact Management using Lance
---------------------------------------------
Along with datasets, the Lance file format can also be used for saving and versioning deep learning model weights.
In fact, using Lance for PyTorch model weights makes deep learning artefact management more streamlined than vanilla weight-saving methods.

In this example, we will demonstrate how to save, version, and load a PyTorch model's weights using Lance. More specifically, we will load a pre-trained ResNet model, save it in Lance file format, load it back into PyTorch, and verify that the weights are indeed unchanged.
We will also demonstrate how to keep multiple versions of your model weights in a single Lance dataset, thanks to Lance's zero-copy automatic versioning.

**Key Idea:** When you save a model's weights (i.e. its state dictionary) in PyTorch, they are stored as key-value pairs in an :meth:`OrderedDict`, with each key being a weight's name and each value the corresponding weight tensor.
To emulate this as closely as possible, we will save the weights in three columns: the first holds the name of the weight, the second the weight itself flattened into a list, and the third the original shape of the weight so it can be reconstructed when loading it back into a model.
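
For instance, a single 2x2 weight tensor would map to one row of the table, roughly as sketched below (the name ``fc.weight`` and the values are purely illustrative):

.. code-block:: python

    import torch

    weight = torch.tensor([[0.1, 0.2], [0.3, 0.4]])  # a hypothetical 2x2 weight

    row = {
        "name": "fc.weight",                 # the weight's key in the state dict
        "value": weight.flatten().tolist(),  # [0.1, 0.2, 0.3, 0.4]
        "shape": list(weight.size()),        # [2, 2], kept so we can reconstruct
    }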

Imports and Setup
~~~~~~~~~~~~~~~~~
We will start by importing and loading all the necessary modules.

.. code-block:: python

    import os
    import shutil
    from collections import OrderedDict

    import lance
    import pyarrow as pa
    import torch

We will also define a :meth:`GLOBAL_SCHEMA` that dictates how the weights table will look.

.. code-block:: python

    GLOBAL_SCHEMA = pa.schema(
        [
            pa.field("name", pa.string()),
            pa.field("value", pa.list_(pa.float64(), -1)),
            # A variable-length list, because weights can have any number of dims
            pa.field("shape", pa.list_(pa.int64(), -1)),
        ]
    )

As covered earlier, the weights table has three columns: one for the weight's name, one for the flattened weight values, and one for the original weight shape (used to reconstruct the weights when loading them back).

Saving and Versioning Models
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
First, we will focus on saving the model. Let's start by writing a utility function that takes a model's state dict, goes over each weight, flattens it, and then yields the weight's name, the flattened weight, and the weight's original shape as a PyArrow :meth:`RecordBatch`.

.. code-block:: python

    def _save_model_writer(state_dict):
        """Yields a RecordBatch for each parameter in the model state dict"""
        for param_name, param in state_dict.items():
            param_shape = list(param.size())
            param_value = param.flatten().tolist()
            yield pa.RecordBatch.from_arrays(
                [
                    pa.array([param_name], pa.string()),
                    pa.array([param_value], pa.list_(pa.float64(), -1)),
                    pa.array([param_shape], pa.list_(pa.int64(), -1)),
                ],
                ["name", "value", "shape"],
            )

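To see what this yields, here is a quick sketch with a toy state dict (the parameter name ``w`` is purely illustrative):

.. code-block:: python

    toy_state_dict = OrderedDict({"w": torch.zeros(2, 3)})
    batch = next(_save_model_writer(toy_state_dict))
    print(batch.to_pydict()["shape"])  # [[2, 3]]
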
Now, about versioning: let's say you trained your model on some new data but don't want to overwrite your old checkpoint. You can simply save the newly trained weights as a new version in the same Lance weights dataset.
This lets you load a specific version of the weights from one Lance dataset instead of keeping a separate folder for each model checkpoint.

Let's write a function that handles saving the model, with or without versioning.

.. code-block:: python

    def save_model(state_dict: OrderedDict, file_name: str, version=False):
        """Saves a PyTorch model in Lance file format

        Args:
            state_dict (OrderedDict): Model state dict
            file_name (str): Lance model name
            version (bool): Whether to save as a new version or overwrite the
                existing versions, if the lance file already exists
        """
        # Create a reader that yields one RecordBatch per parameter
        reader = pa.RecordBatchReader.from_batches(
            GLOBAL_SCHEMA, _save_model_writer(state_dict)
        )

        if os.path.exists(file_name):
            if version:
                # With versioning, "overwrite" mode creates a new version
                # of the existing dataset
                lance.write_dataset(
                    reader, file_name, schema=GLOBAL_SCHEMA, mode="overwrite"
                )
            else:
                # Without versioning, delete the existing dataset and write anew
                shutil.rmtree(file_name)
                lance.write_dataset(reader, file_name, schema=GLOBAL_SCHEMA)
        else:
            # If the dataset doesn't exist yet, just write it
            lance.write_dataset(reader, file_name, schema=GLOBAL_SCHEMA)

The function above takes the model's state dict, the name of the Lance dataset to save to, and a versioning flag. It starts by making a :meth:`RecordBatchReader` from the global schema and the utility function we wrote above.
If a Lance weights dataset with that name already exists, we either save the weights as a new version (if versioning is enabled) or delete the old dataset and write the weights afresh. Otherwise, the weights are written normally.
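
As a quick usage sketch (assuming ``torchvision`` is installed; the file name ``resnet.lance`` and the choice of ResNet-18 are illustrative), saving a pre-trained ResNet and then a second version of it might look like this:

.. code-block:: python

    import torchvision

    model = torchvision.models.resnet18(weights="IMAGENET1K_V1")
    save_model(model.state_dict(), "resnet.lance")

    # ... fine-tune the model on some new data ...

    # Save the updated weights as a new version of the same dataset
    save_model(model.state_dict(), "resnet.lance", version=True)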

Loading Models
~~~~~~~~~~~~~~
Loading weights from a Lance weights dataset into a model is just the reverse of saving them. The key step is reshaping the flattened weights back to their original shapes, which is straightforward thanks to the shape we saved alongside each weight.
We will divide this into three functions for better readability.

The first function is :meth:`_load_weight`, which takes a "weight" retrieved from the Lance weights dataset and returns it as a torch tensor in its original shape. Each "weight" we retrieve from the dataset is a dict with one key per column.

.. code-block:: python

    def _load_weight(weight: dict) -> torch.Tensor:
        """Converts a weight dict to a torch tensor"""
        return torch.tensor(weight["value"], dtype=torch.float64).reshape(
            weight["shape"]
        )

Optionally, you could also add an argument to specify the data type of the weights.
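
For example, a variant with a data type argument might look like the sketch below (the ``torch.float64`` default matches how we stored the values):

.. code-block:: python

    def _load_weight(weight: dict, dtype: torch.dtype = torch.float64) -> torch.Tensor:
        """Converts a weight dict to a torch tensor of the given dtype"""
        return torch.tensor(weight["value"], dtype=dtype).reshape(weight["shape"])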

The next function loads all the weights from the Lance weights dataset into a state dictionary, which is what PyTorch expects when we load the weights into our model.

.. code-block:: python

    def _load_state_dict(
        file_name: str, version: int = 1, map_location=None
    ) -> OrderedDict:
        """Reads the model weights from a lance file and returns a model state dict

        If the model weights are too large to fit in memory, this function
        will fail with a memory error.

        Args:
            file_name (str): Lance model name
            version (int): Version of the model to load
            map_location (str): Device to load the model on

        Returns:
            OrderedDict: Model state dict
        """
        ds = lance.dataset(file_name, version=version)
        weights = ds.take(list(range(ds.count_rows()))).to_pylist()

        state_dict = OrderedDict()
        for weight in weights:
            state_dict[weight["name"]] = _load_weight(weight).to(map_location)
        return state_dict

The :meth:`_load_state_dict` function expects the Lance weights dataset's file name, a version, and the device the weights should be loaded onto.
It loads all the weights from the dataset into memory, converts each one into a tensor using the utility function we wrote earlier, and then moves it to the device.
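
If you are unsure which versions a weights dataset contains, you can list them first (a small sketch, reusing the illustrative ``resnet.lance`` dataset from earlier):

.. code-block:: python

    ds = lance.dataset("resnet.lance")
    print(ds.versions())  # one entry per saved version of the weights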

One thing to note: this function will fail if the saved weights are larger than the available memory. For the sake of simplicity, we assume the weights fit in memory, so we don't have to deal with sharding.

Finally, we will write a higher-level function, the only one we need to call to load the weights.

.. code-block:: python

    def load_model(
        model: torch.nn.Module, file_name: str, version: int = 1, map_location=None
    ):
        """Loads the model weights from a lance file and sets them on the model

        Args:
            model (torch.nn.Module): PyTorch model
            file_name (str): Lance model name
            version (int): Version of the model to load
            map_location (str): Device to load the model on
        """
        state_dict = _load_state_dict(
            file_name, version=version, map_location=map_location
        )
        model.load_state_dict(state_dict)

The :meth:`load_model` function requires the model, the Lance weights dataset's name, the version of the weights to load, and the map location. It simply calls the :meth:`_load_state_dict` utility to get the state dict and then loads it into the model.
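
Putting it all together, here is a sketch of the round trip promised at the start: save a pre-trained ResNet, load it back, and verify that the weights are unchanged (again assuming ``torchvision``, with ``resnet.lance`` as the illustrative file name):

.. code-block:: python

    import torchvision

    model = torchvision.models.resnet18(weights="IMAGENET1K_V1")
    save_model(model.state_dict(), "resnet.lance")

    reloaded = torchvision.models.resnet18()
    load_model(reloaded, "resnet.lance")

    # Verify the round trip preserved every weight
    for name, param in model.state_dict().items():
        assert torch.equal(param, reloaded.state_dict()[name])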

Conclusion
~~~~~~~~~~
In conclusion, you only need to call two functions, :meth:`save_model` and :meth:`load_model`, to save and load models respectively. As long as the weights fit in memory and the model is in PyTorch, this approach should work fine.

Although experimental, this approach defines a new way of doing deep learning artefact management.
3 changes: 2 additions & 1 deletion docs/examples/examples.rst

Training LLMs using a Lance text dataset <./llm_training.rst>
Reading and writing a Lance dataset in Rust <./write_read_dataset.rst>
Creating Multi-Modal datasets using Lance <./flickr8k_dataset_creation.rst>
Training Multi-Modal models using a Lance dataset <./clip_training.rst>
Deep Learning Artefact Management using Lance <./artefact_management.rst>