update custom pytorch example to use deberta (#44)
* update custom pytorch example

* comment out test

* lint

* lint
edknv authored Jan 17, 2024
1 parent 5bdcdf1 commit 382a6d8
Showing 2 changed files with 26 additions and 15 deletions.
37 changes: 22 additions & 15 deletions examples/custom_pytorch_model.py
@@ -1,5 +1,6 @@
 import argparse
 import os
+from dataclasses import dataclass
 
 import dask_cudf
 import torch
@@ -14,27 +15,28 @@
 NUM_ROWS = 1_000
 
 
-class CFG:
-    model = "sentence-transformers/all-MiniLM-L6-v2"
+@dataclass
+class Config:
+    model = "microsoft/deberta-v3-base"
     fc_dropout = 0.2
     max_len = 512
     out_dim = 3
 
 
 class CustomModel(nn.Module):
-    def __init__(self, cfg, config_path=None, pretrained=False):
+    def __init__(self, config, config_path=None, pretrained=False):
         super().__init__()
-        self.cfg = cfg
+        self.config = config
         if config_path is None:
-            self.config = AutoConfig.from_pretrained(cfg.model, output_hidden_states=True)
+            self.config = AutoConfig.from_pretrained(config.model, output_hidden_states=True)
         else:
             self.config = torch.load(config_path)
         if pretrained:
-            self.model = AutoModel.from_pretrained(cfg.model, config=self.config)
+            self.model = AutoModel.from_pretrained(config.model, config=self.config)
         else:
             self.model = AutoModel(self.config)
-        self.fc_dropout = nn.Dropout(cfg.fc_dropout)
-        self.fc = nn.Linear(self.config.hidden_size, self.cfg.out_dim)
+        self.fc_dropout = nn.Dropout(config.fc_dropout)
+        self.fc = nn.Linear(self.config.hidden_size, config.out_dim)
         self._init_weights(self.fc)
 
     def _init_weights(self, module):
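
Note: the old CFG class becomes a Config dataclass pointing at microsoft/deberta-v3-base, and CustomModel now reads every setting from that object. A minimal sketch of how the pieces fit, not part of the commit, assuming the classes above are in scope and the checkpoint can be downloaded:

# Sketch only. Config is consumed as the class itself (its fields are plain
# class attributes), which is why main() below can pass MyModel(Config)
# without instantiating it.
model = CustomModel(Config, config_path=None, pretrained=True)
print(model.config.hidden_size)  # 768 for deberta-v3-base
print(model.fc)                  # Linear(in_features=768, out_features=3, bias=True)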
@@ -63,8 +65,8 @@ def forward(self, batch):
 
 
 # The user must provide a load_model function
-def load_model(cfg, device, model_path):
-    model = CustomModel(cfg, config_path=None, pretrained=True)
+def load_model(config, device, model_path):
+    model = CustomModel(config, config_path=None, pretrained=True)
     model = model.to(device)
 
     if os.path.exists(model_path):
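
load_model builds the pretrained backbone, moves it to the requested device, and only then consults model_path. A hedged usage sketch; the checkpoint filename is hypothetical, and the branch under os.path.exists is collapsed above:

# Sketch only; "finetuned_weights.pth" is a made-up path.
model = load_model(Config, device="cuda", model_path="finetuned_weights.pth")
model.eval()  # with no checkpoint at that path, the pretrained weights are kept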
@@ -77,10 +79,14 @@ def load_model(cfg, device, model_path):
 
 
 class MyModel(HFModel):
-    def load_model(self, device="cuda"):
-        return load_model(CFG, device=device, model_path=self.path_or_name)
+    def __init__(self, config):
+        self.config = config
+        super().__init__(self.config.model)
+
+    def load_model(self, model_path=None, device="cuda"):
+        return load_model(self.config, device=device, model_path=model_path or self.path_or_name)
 
-    def load_cfg(self):
+    def load_config(self):
         return AutoConfig.from_pretrained(self.path_or_name)


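MyModel now owns its Config and derives the Hugging Face model name from Config.model, so callers hand over the config class instead of a name string; load_cfg is also renamed load_config. A short sketch, assuming HFModel records the name as path_or_name, as the methods above imply:

# Sketch only.
model = MyModel(Config)          # previously MyModel(CFG.model)
hf_config = model.load_config()  # AutoConfig for microsoft/deberta-v3-base
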
@@ -114,12 +120,13 @@ def main():
     labels = ["foo", "bar", "baz"]
 
     with cf.Distributed(rmm_pool_size=args.pool_size, n_workers=args.num_workers):
-        model = MyModel(CFG.model)
+        model = MyModel(Config)
         pipe = op.Sequential(
-            op.Tokenizer(model, cols=[args.input_column]),
+            op.Tokenizer(model, cols=[args.input_column], tokenizer_type="sentencepiece"),
             op.Predictor(model, sorted_data_loader=True, batch_size=args.batch_size),
             op.Labeler(labels, cols=["preds"]),
             repartition=args.partitions,
+            keep_cols=[args.input_column],
         )
         outputs = pipe(ddf)
         outputs.to_parquet(args.output_parquet_path)
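
The Tokenizer op now passes tokenizer_type="sentencepiece" because DeBERTa-v3 ships a SentencePiece-based tokenizer rather than a WordPiece one. A self-contained check, independent of this repository:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")
print(type(tok).__name__)  # DebertaV2TokenizerFast, built on SentencePiece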
4 changes: 4 additions & 0 deletions tests/examples/test_scripts.py
@@ -8,6 +8,10 @@
 import sys  # noqa: E402
 import tempfile  # noqa: E402
 
+# from uuid import uuid4  # noqa: E402
+
+# from crossfit.dataset.load import load_dataset  # noqa: E402
+
 examples_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", "examples")

