tear down by unloading from workers (#39)
edknv authored Dec 18, 2023
1 parent 1f7d20a commit 05accee
Showing 5 changed files with 72 additions and 4 deletions.
8 changes: 8 additions & 0 deletions crossfit/backend/torch/hf/model.py
@@ -41,6 +41,14 @@ def load_on_worker(self, worker, device="cuda"):
         worker.torch_model = self.load_model(device)
         worker.cfg = self.load_cfg()
 
+    def unload_from_worker(self, worker):
+        if hasattr(worker, "torch_model"):
+            delattr(worker, "torch_model")
+        if hasattr(worker, "cfg"):
+            delattr(worker, "cfg")
+        gc.collect()
+        torch.cuda.empty_cache()
+
     def load_model(self, device="cuda"):
         return AutoModel.from_pretrained(self.path_or_name).to(device)
 
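To make the new hook concrete, here is a minimal, runnable sketch of the attach/detach pattern it follows. The worker object and model below are stand-ins (SimpleNamespace and object(), not crossfit or Dask types); only the attribute handling and the memory-release calls mirror the change above.

import gc
from types import SimpleNamespace

worker = SimpleNamespace()  # stands in for a Dask worker

# load_on_worker: cache the model and config as attributes on the worker
worker.torch_model = object()
worker.cfg = {"example": "config"}

# unload_from_worker: drop the cached references, then reclaim memory
for attr in ("torch_model", "cfg"):
    if hasattr(worker, attr):
        delattr(worker, attr)
gc.collect()
# torch.cuda.empty_cache() would follow here on a real GPU worker

assert not hasattr(worker, "torch_model")
assert not hasattr(worker, "cfg")
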
3 changes: 3 additions & 0 deletions crossfit/backend/torch/model.py
@@ -27,6 +27,9 @@ def load_tokenizer(self):
     def load_on_worker(self, worker):
         raise NotImplementedError()
 
+    def unload_from_worker(self, worker):
+        raise NotImplementedError()
+
     def call_on_worker(self, worker, *args, **kwargs):
         return worker.torch_model(*args, **kwargs)
 
3 changes: 3 additions & 0 deletions crossfit/backend/torch/op/embed.py
@@ -45,6 +45,9 @@ def __init__(
     def setup(self):
         self.model.load_on_worker(self)
 
+    def teardown(self):
+        self.model.unload_from_worker(self)
+
     @torch.no_grad()
     def call(self, data, partition_info=None):
         index = data.index
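
The teardown hook gives the op a lifecycle that is symmetric with setup: load the model before partitions are processed, unload it once the pass is finished. The driver below is purely illustrative (crossfit schedules these hooks through Op.setup_worker/teardown_worker, not a loop like this; run_on_worker and partitions are made-up names):

# Illustrative ordering only: setup once, process partitions, always tear down
# so the worker does not keep holding GPU memory after the embedding pass.
def run_on_worker(op, partitions):
    op.setup()
    try:
        return [op.call(part) for part in partitions]
    finally:
        op.teardown()
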
36 changes: 32 additions & 4 deletions crossfit/op/base.py
@@ -16,9 +16,11 @@
 import uuid
 
 import dask.dataframe as dd
-from dask.distributed import get_worker
+from dask.distributed import get_worker, wait
 from tqdm.auto import tqdm
 
+from crossfit.backend.dask.cluster import global_dask_client
+
 
 class Op:
     def __init__(self, pre=None, cols=False, keep_cols=None):
@@ -30,25 +32,49 @@ def __init__(self, pre=None, cols=False, keep_cols=None):
     def setup(self):
         pass
 
+    def teardown(self):
+        pass
+
     def meta(self):
         return None
 
-    def setup_worker(self):
+    def get_worker(self):
         try:
             worker = get_worker()
         except ValueError:
             worker = self
 
-        self.worker_name = getattr(worker, "name", 0)
+        return worker
+
+    def _get_init_name(self):
         init_name = f"setup_done_{self.id}"
+        return init_name
+
+    def setup_worker(self):
+        worker = self.get_worker()
+
+        self.worker_name = getattr(worker, "name", 0)
+        init_name = self._get_init_name()
 
         if not hasattr(worker, init_name):
             self.setup()
             setattr(worker, init_name, True)
 
+    def teardown_worker(self):
+        worker = self.get_worker()
+
+        init_name = self._get_init_name()
+
+        if hasattr(worker, init_name):
+            delattr(worker, init_name)
+        self.teardown()
+
     def call_dask(self, data: dd.DataFrame):
         output = data.map_partitions(self, meta=self._build_dask_meta(data))
 
+        if global_dask_client():
+            wait(output)
+
         return output
 
     def create_progress_bar(self, total, partition_info=None, **kwargs):
@@ -74,7 +100,9 @@ def add_keep_cols(self, data, output):
 
     def __call__(self, data, *args, partition_info=None, **kwargs):
         if isinstance(data, dd.DataFrame):
-            return self.call_dask(data, *args, **kwargs)
+            output = self.call_dask(data, *args, **kwargs)
+            self.teardown_worker()
+            return output
 
         self.setup_worker()
 
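Two details of the base-class change are worth spelling out. First, call_dask now waits on the mapped output whenever a global Dask client is available, so the partitions are processed before __call__ goes on to call teardown_worker. Second, the per-worker guard is just an attribute named after the op's id: setup runs once per worker, and tearing down removes the flag so a later run can set up again. A standalone sketch of that guard pattern (FakeWorker and the id value are made up for illustration):

class FakeWorker:
    pass


worker = FakeWorker()
init_name = "setup_done_1234"  # same naming scheme as _get_init_name()

# setup_worker: run setup() only the first time this op touches the worker
if not hasattr(worker, init_name):
    print("running setup()")  # e.g. load the model onto the worker
    setattr(worker, init_name, True)

# a repeated call is a no-op because the flag is already set
if not hasattr(worker, init_name):
    print("never reached")

# teardown_worker: clear the flag and release resources
if hasattr(worker, init_name):
    delattr(worker, init_name)
    print("running teardown()")  # e.g. unload the model
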
26 changes: 26 additions & 0 deletions tests/backend/pytorch_backend/test_torch_ops.py
@@ -11,3 +11,29 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from unittest.mock import Mock
+
+import pytest
+
+import crossfit as cf
+
+
+class TestHFModel:
+    @pytest.fixture
+    def model(self):
+        return cf.HFModel("sentence-transformers/all-MiniLM-L6-v2")
+
+    @pytest.fixture
+    def mock_worker(self):
+        return Mock()
+
+    def test_unload_from_worker(self, model, mock_worker):
+        model.load_on_worker(mock_worker)
+
+        assert hasattr(mock_worker, "torch_model")
+        assert hasattr(mock_worker, "cfg")
+
+        model.unload_from_worker(mock_worker)
+
+        assert not hasattr(mock_worker, "torch_model")
+        assert not hasattr(mock_worker, "cfg")
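
The new test exercises the hook against a Mock worker. It can be run on its own with pytest's node-id selection; note that load_on_worker downloads the Hugging Face checkpoint and moves it to "cuda" by default, so this assumes network access and a CUDA-capable environment:

pytest tests/backend/pytorch_backend/test_torch_ops.py::TestHFModel::test_unload_from_worker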
