-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
51 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import cudf | ||
import dask_cudf | ||
|
||
import crossfit as cf | ||
from crossfit import op | ||
|
||
|
||
def create_sample_ddf(): | ||
df = cudf.DataFrame( | ||
{ | ||
"text": [ | ||
"query: how much protein should a female eat", | ||
"query: summit define", | ||
"passage: As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.", | ||
"passage: Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.", | ||
] | ||
} | ||
) | ||
|
||
npartitions = 2 # assume 2 GPUs and data is small enough to fit on 2 GPUs | ||
df = dask_cudf.from_cudf(df, npartitions=npartitions) | ||
|
||
return df | ||
|
||
|
||
if __name__ == "__main__": | ||
df = create_sample_ddf() | ||
|
||
model = cf.SentenceTransformerModel("intfloat/e5-large-v2") | ||
|
||
with cf.Distributed(rmm_pool_size="16GB", n_workers=2): | ||
tokenizer = op.Tokenizer(model, cols=["text"]) | ||
tokens = tokenizer(df) | ||
|
||
num_tokens = tokens.input_ids.map_partitions( | ||
# work around `list_series.list.index(0)` not working by casting list values to int. | ||
lambda s: s.list.astype(int).list.index(0).replace(-1, s.list.len().iloc[0]), | ||
meta=("input_ids", "int"), | ||
).to_frame() | ||
num_tokens.to_parquet("temp_num_tokens.parquet") | ||
|
||
print(dask_cudf.read_parquet("temp_num_tokens.parquet").compute()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters