Add some large model troubleshooting steps (#31862)
damccorm authored Aug 2, 2024
1 parent 0b4b8ea commit d96fa7d
Showing 1 changed file with 49 additions and 1 deletion.
@@ -27,7 +27,7 @@ RunInference has several mechanisms for reducing memory utilization. For example

Many Beam runners, however, run multiple Beam processes per machine at once. This can cause problems, since the memory footprint of loading large models like LLMs multiple times can be too large to fit on a single machine.
For memory-intensive models, RunInference provides a mechanism for more intelligently sharing memory across multiple processes to reduce the overall memory footprint. To enable this mode, users just have
to set the parameter `large_model` to True in their model configuration (see below for an example), and Beam will take care of the memory management. When using a custom model handler, you can override the `share_model_across_processes` function or the `model_copies` function for a similar effect.
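
For illustration, here is a minimal sketch of a custom model handler that opts in to cross-process sharing by overriding these functions. The handler name, the `load_my_model` helper, and the string element types are hypothetical placeholders.

```
from typing import Any, Dict, Iterable, Optional, Sequence

from apache_beam.ml.inference.base import ModelHandler


class MySharedModelHandler(ModelHandler[str, str, Any]):
  def load_model(self) -> Any:
    # Placeholder for your model-loading logic.
    return load_my_model()

  def run_inference(
      self,
      batch: Sequence[str],
      model: Any,
      inference_args: Optional[Dict[str, Any]] = None) -> Iterable[str]:
    return model.predict(batch)

  def share_model_across_processes(self) -> bool:
    # Load a single copy of the model per machine and share it across processes.
    return True

  def model_copies(self) -> int:
    # Optionally serve more than one copy of the model per machine.
    return 1
```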

### Running an Example Pipeline with T5

@@ -122,3 +122,51 @@ A `ModelHandler` requires parameters like:
* `device` - The device on which to run the model. If set to GPU, a GPU is used when one is available; otherwise the model runs on CPU.
* `inference_fn` - The inference function to use during RunInference.
* `large_model` - Whether to use memory minimization techniques to lower the memory footprint of your model (see `Memory Management` above). An example configuration using these parameters is sketched after this list.
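
To make these parameters concrete, the following is a minimal configuration sketch based on the T5 example above, assuming the PyTorch model handler; the checkpoint path and model name are placeholders, and `make_tensor_model_fn("generate")` is used so that inference calls `model.generate()`.

```
from transformers import AutoConfig, T5ForConditionalGeneration

from apache_beam.ml.inference.pytorch_inference import (
    PytorchModelHandlerTensor, make_tensor_model_fn)

model_handler = PytorchModelHandlerTensor(
    # Placeholder path to the saved state dict.
    state_dict_path="gs://my-bucket/t5-small.pt",
    model_class=T5ForConditionalGeneration,
    model_params={"config": AutoConfig.from_pretrained("t5-small")},
    # Use a GPU when one is available; otherwise the model runs on CPU.
    device="GPU",
    # Run model.generate() instead of model.forward() during inference.
    inference_fn=make_tensor_model_fn("generate"),
    # Enable the memory-sharing behavior described in `Memory Management`.
    large_model=True)
```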

### Troubleshooting Large Models

#### Pickling errors

When a model is shared across processes, whether through `large_model=True` or a custom model handler, Beam sends the input and output data across a process boundary.
To do this, it uses a serialization method known as [pickling](https://docs.python.org/3/library/pickle.html).
For example, if you call `output=model.my_inference_fn(input_1, input_2)`, `input_1`, `input_2`, and `output` will all need to be pickled.
The model itself does not need to be pickled since it is not passed across process boundaries.
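
As a quick local check, you can verify ahead of time that the values your inference function takes in and returns can be pickled; the sample values below are hypothetical stand-ins.

```
import pickle

sample_input = "some example text"  # stand-in for one element of your input batch
sample_output = 42                  # stand-in for what your inference function returns

for value in (sample_input, sample_output):
  # pickle.dumps raises an error if the value cannot be pickled.
  pickle.dumps(value)
```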

Most objects can be pickled without issue, but if one of these objects is unpickleable, you may run into errors like `error: can't pickle fasttext_pybind.fasttext objects`.
To work around this, there are a few options:

First, if possible, you can choose not to share your model across processes. This incurs additional memory pressure, but it may be tolerable in some cases.

Second, with a custom model handler, you can wrap your model so that it takes in and returns serializable types. For example, if your model handler looks like this:

```
from typing import Sequence

class MyModelHandler():
  def load_model(self):
    return model_loading_logic()

  def run_inference(self, batch: Sequence[str], model, inference_args):
    # Unpickleable is a placeholder for a type that cannot be pickled.
    unpickleable_object = Unpickleable(batch)
    unpickleable_returned = model.predict(unpickleable_object)
    my_output = int(unpickleable_returned[0])
    return my_output
```

you could instead wrap the unpickleable pieces in a model wrapper. Because the model wrapper sits in the inference process, this works as long as it only takes in and returns pickleable objects.

```
from typing import Sequence

class MyWrapper():
  def __init__(self, model):
    self._model = model

  def predict(self, batch: Sequence[str]):
    # Keep the unpickleable pieces inside the wrapper, which lives in the
    # inference process, and return only pickleable values.
    unpickleable_object = Unpickleable(batch)
    unpickleable_returned = self._model.predict(unpickleable_object)
    return int(unpickleable_returned[0])

class MyModelHandler():
  def load_model(self):
    return MyWrapper(model_loading_logic())

  def run_inference(self, batch: Sequence[str], model: MyWrapper, inference_args):
    return model.predict(batch)
```
