-
Notifications
You must be signed in to change notification settings - Fork 236
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
docs: artefact_management_doc (#2504)
Add Deep Learning Artefact Management Docs to Lance Docs.
- Loading branch information
Showing
2 changed files
with
178 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
Deep Learning Artefact Management using Lance | ||
--------------------------------------------- | ||
Along with datasets, Lance file format can also be used for saving and versioning deep learning model weights. | ||
In fact deep learning artefact management can be made more streamlined (compared to vanilla weight saving methods) using Lance file format for PyTorch model weights. | ||
|
||
In this example we will be demonstrating how you save, version and load a PyTorch model's weights using Lance. More specifically we will be loading a pre-trained ResNet model, saving it in Lance file format, loading it back to PyTorch and verifying if the weights are still indeed the same. | ||
We will also be demonstrating how you can version your model weights in a single lance dataset thanks to our Zero-copy, automatic versioning. | ||
|
||
**Key Idea:** When you save a model's weights (read: state dictionary) in PyTorch, weights are stored as key-value pairs in an :meth:`OrderedDict` with the keys representing the weight's name and the value representing the corresponding weight tensor. | ||
To emulate this as closely as possible, we will be saving the weights in three columns. The first column will have the name of the weight, the second will have the weight itself but flattened in a list and the third will have the original shape of the weights so they can be reconstructed for loading into a model. | ||
|
||
Imports and Setup | ||
~~~~~~~~~~~~~~~~~ | ||
We will start by importing and loading all the necessary modules. | ||
|
||
.. code-block:: python | ||
import os | ||
import shutil | ||
import lance | ||
import pyarrow as pa | ||
import torch | ||
from collections import OrderedDict | ||
We will also define a :meth:`GLOBAL_SCHEMA` that will dictate how the weights table will look like. | ||
|
||
.. code-block:: python | ||
GLOBAL_SCHEMA = pa.schema( | ||
[ | ||
pa.field("name", pa.string()), | ||
pa.field("value", pa.list_(pa.float64(), -1)), | ||
pa.field("shape", pa.list_(pa.int64(), -1)), # Is a list with variable shape because weights can have any number of dims | ||
] | ||
) | ||
As we covered earlier, the weights table will have three columns - one for storing the weight name, one for storing the flattened weight value and one for storing the original weight shape for loading them back. | ||
|
||
Saving and Versioning Models | ||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
First we will focus on the model saving part. Let's start by writing a utility function that will take a model's state dict, goes over each weight, flatten it and then return the weight name, flattened weight and weight's original shape in a pyarrow :meth:`RecordBatch`. | ||
|
||
.. code-block:: python | ||
def _save_model_writer(state_dict): | ||
"""Yields a RecordBatch for each parameter in the model state dict""" | ||
for param_name, param in state_dict.items(): | ||
param_shape = list(param.size()) | ||
param_value = param.flatten().tolist() | ||
yield pa.RecordBatch.from_arrays( | ||
[ | ||
pa.array( | ||
[param_name], | ||
pa.string(), | ||
), | ||
pa.array( | ||
[param_value], | ||
pa.list_(pa.float64(), -1), | ||
), | ||
pa.array( | ||
[param_shape], | ||
pa.list_(pa.int64(), -1), | ||
), | ||
], | ||
["name", "value", "shape"], | ||
) | ||
Now about versioning: Let's say you trained your model on some new data but don't want to overwrite your old checkpoint, you can now just save these newly trained model weights as a version in Lance weights dataset. | ||
This will allow you to load specific version of weights from one lance weight dataset instead of making separate folders for each model checkpoint to make. | ||
|
||
Let's write a function that handles the work for saving the model, whether with versions or without them. | ||
|
||
.. code-block:: python | ||
def save_model(state_dict: OrderedDict, file_name: str, version=False): | ||
"""Saves a PyTorch model in lance file format | ||
Args: | ||
state_dict (OrderedDict): Model state dict | ||
file_name (str): Lance model name | ||
version (bool): Whether to save as a new version or overwrite the existing versions, | ||
if the lance file already exists | ||
""" | ||
# Create a reader | ||
reader = pa.RecordBatchReader.from_batches( | ||
GLOBAL_SCHEMA, _save_model_writer(state_dict) | ||
) | ||
if os.path.exists(file_name): | ||
if version: | ||
# If we want versioning, we use the overwrite mode to create a new version | ||
lance.write_dataset( | ||
reader, file_name, schema=GLOBAL_SCHEMA, mode="overwrite" | ||
) | ||
else: | ||
# If we don't want versioning, we delete the existing file and write a new one | ||
shutil.rmtree(file_name) | ||
lance.write_dataset(reader, file_name, schema=GLOBAL_SCHEMA) | ||
else: | ||
# If the file doesn't exist, we write a new one | ||
lance.write_dataset(reader, file_name, schema=GLOBAL_SCHEMA) | ||
The above function will take in the model state dict, the lance saved file name and the weights version. The function will start by making a :meth:`RecordBatchReader` using the global schema and the utility function we wrote above. | ||
If the weights lance dataset already exists in the directory, we will just save it as a new version (if versioning is enabled) or delete the old file and save the weights as new. Otherwise the weights saving will be done normally. | ||
|
||
Loading Models | ||
~~~~~~~~~~~~~~ | ||
Loading weights from a Lance weight dataset into a model is just the reverse of saving them. The key part is to reshape the flattened weights back to their original shape, which is easier thanks to the shape that you saved corresponding to the weights. | ||
We will divide this into three functions for better readability. | ||
|
||
The first function will be the :meth:`_load_weight` function which will take a "weight" retrieved from the Lance weight dataset and return the weight as a torch tensor in it's original shape. The "weight" that we retrieve from the Lance weight dataset will be a dict with value corresponding to each column in form of a key. | ||
|
||
.. code-block:: python | ||
def _load_weight(weight: dict) -> torch.Tensor: | ||
"""Converts a weight dict to a torch tensor""" | ||
return torch.tensor(weight["value"], dtype=torch.float64).reshape(weight["shape"]) | ||
Optionally, you could also add an option to specify the datatype of the weights. | ||
|
||
The next function will be on loading all the weights from the lance weight dataset into a state dictionary, which is what PyTorch will expect when we load the weights into our model. | ||
|
||
.. code-block:: python | ||
def _load_state_dict(file_name: str, version: int = 1, map_location=None) -> OrderedDict: | ||
"""Reads the model weights from lance file and returns a model state dict | ||
If the model weights are too large, this function will fail with a memory error. | ||
Args: | ||
file_name (str): Lance model name | ||
version (int): Version of the model to load | ||
map_location (str): Device to load the model on | ||
Returns: | ||
OrderedDict: Model state dict | ||
""" | ||
ds = lance.dataset(file_name, version=version) | ||
weights = ds.take([x for x in range(ds.count_rows())]).to_pylist() | ||
state_dict = OrderedDict() | ||
for weight in weights: | ||
state_dict[weight["name"]] = _load_weight(weight).to(map_location) | ||
return state_dict | ||
The :meth:`load_state_dict` function will expect a lance weight dataset file name, a version and a device where the weights will be loaded into. | ||
We essentially load all the weights from the lance weight dataset into our memory and iteratively convert them into weights using the utility function we wrote earlier and then put them on the device. | ||
|
||
One thing to note here is that this function will fail if the saved weights are larger than memory. For the sake of simplicity, we assume the weights to be loaded can fit in the memory and we don't have to deal with any sharding. | ||
|
||
Finally, we will write a higher level function is the only one we will call to load the weights. | ||
|
||
.. code-block:: python | ||
def load_model( | ||
model: torch.nn.Module, file_name: str, version: int = 1, map_location=None | ||
): | ||
"""Loads the model weights from lance file and sets them to the model | ||
Args: | ||
model (torch.nn.Module): PyTorch model | ||
file_name (str): Lance model name | ||
version (int): Version of the model to load | ||
map_location (str): Device to load the model on | ||
""" | ||
state_dict = _load_state_dict(file_name, version=version, map_location=map_location) | ||
model.load_state_dict(state_dict) | ||
The :meth:`load_model` function will require the model, the lance weight dataset name, the version of weights to load in and the map location. This will just call the :meth:`_load_state_dict` utility to get the state dict and then load that state dict into the model. | ||
|
||
Conclusion | ||
~~~~~~~~~~ | ||
In conclusion, you only need to call the two function: :meth:`save_model` and :meth:`load_model` to save and load the models respectively and as long as the weights can be fit in the memory and are in PyTorch, it should be fine. | ||
|
||
Although experimental, this approach defines a new way of doing deep learning artefact management. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters