[load_from_hf_hub] Add dataset_length, set_index (#339)
This PR adds two things to the `load_from_hf_hub` reusable component:

- a `dataset_length` argument, which is required when the user specifies `n_rows_to_load`. I added this because I hit an issue when `n_rows_to_load` was larger than the partition size: the current code loads only the first partition, so even though I specified `n_rows_to_load` to be 150k, I only got 69,000 rows. To solve this, the component now counts how many partitions are needed to cover the requested number of rows and returns approximately `n_rows_to_load` rows (see the sketch below this description).
- a monotonically increasing index, as suggested by a Stack Overflow post, to solve the issue of duplicate indices caused by every partition having an index that starts at 0 (a short demonstration of the indexing trick follows the `main.py` diff below).
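For reference, a minimal standalone sketch of the partition-counting idea on a toy Dask dataframe; the toy data, the cutoff of 60 rows, and the variable names are illustrative and not the component's actual inputs:

```python
import dask.dataframe as dd
import pandas as pd

# Toy dataframe: 100 rows spread over 4 partitions of 25 rows each.
df = dd.from_pandas(pd.DataFrame({"a": range(100)}), npartitions=4)

n_rows_to_load = 60  # pretend the user asked for 60 rows

# Walk over the partitions until enough rows have been seen, so that `head`
# is allowed to read from enough partitions instead of only the first one.
rows_seen = 0
npartitions_needed = df.npartitions
for i in range(df.npartitions):
    rows_seen += len(df.partitions[i])
    if rows_seen >= n_rows_to_load:
        npartitions_needed = i + 1
        break

subset = df.head(n_rows_to_load, npartitions=npartitions_needed)
print(len(subset))  # 60 (head returns a pandas DataFrame)
```

`head(n, npartitions=k)` only reads from the first `k` partitions, which is why the required partition count has to be determined up front.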
NielsRogge authored Aug 8, 2023
1 parent 74aca21 commit dfdcee3
Showing 2 changed files with 18 additions and 15 deletions.
components/load_from_hf_hub/src/main.py (17 additions & 3 deletions)
@@ -44,12 +44,26 @@ def load(self) -> dd.DataFrame:
         )
 
         # 3) Rename columns
         logger.info("Renaming columns...")
         dask_df = dask_df.rename(columns=self.column_name_mapping)
 
         # 4) Optional: only return specific amount of rows
-        if self.n_rows_to_load:
-            dask_df = dask_df.head(self.n_rows_to_load)
-            dask_df = dd.from_pandas(dask_df, npartitions=1)
+        if self.n_rows_to_load is not None:
+            partitions_length = 0
+            for npartitions, partition in enumerate(dask_df.partitions):
+                if partitions_length >= self.n_rows_to_load:
+                    logger.info(f"""Required number of partitions to load\n
+                    {self.n_rows_to_load} is {npartitions}""")
+                    break
+                partitions_length += len(partition)
+            dask_df = dask_df.head(self.n_rows_to_load, npartitions=npartitions)
+            dask_df = dd.from_pandas(dask_df, npartitions=npartitions)
+
+        # Set monotonically increasing index
+        logger.info("Setting the index...")
+        dask_df["id"] = 1
+        dask_df["id"] = dask_df.id.cumsum()
+        dask_df = dask_df.set_index("id", sort=True)
 
         return dask_df
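As a usage note (not part of the commit), here is a minimal sketch of what the index-setting lines above do, on a toy two-partition dataframe; the toy data and the `map_partitions` reset are illustrative stand-ins for the per-shard indices the component actually sees:

```python
import dask.dataframe as dd
import pandas as pd

# Toy dataframe with 2 partitions; reset each partition's index so every
# partition starts at 0, mimicking shards that each carry a local index.
ddf = dd.from_pandas(pd.DataFrame({"a": range(6)}), npartitions=2)
ddf = ddf.map_partitions(lambda pdf: pdf.reset_index(drop=True))
print(ddf.compute().index.tolist())  # [0, 1, 2, 0, 1, 2] -> duplicate indices

# The cumulative sum of a constant column is globally increasing across
# partitions, so it can serve as a unique, monotonically increasing index.
ddf["id"] = 1
ddf["id"] = ddf.id.cumsum()
ddf = ddf.set_index("id", sort=True)
print(ddf.compute().index.tolist())  # [1, 2, 3, 4, 5, 6]
```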
examples/pipelines/datacomp/pipeline.py (1 addition & 12 deletions)
@@ -7,14 +7,13 @@
 
 from pipeline_configs import PipelineConfigs
 
-from fondant.compiler import DockerCompiler
 from fondant.pipeline import ComponentOp, Pipeline, Client
 
 logger = logging.getLogger(__name__)
 
 # Initialize pipeline and client
 pipeline = Pipeline(
-    pipeline_name="Datacomp filtering pipeline",
+    pipeline_name="datacomp-filtering",
     pipeline_description="A pipeline for filtering the Datacomp dataset",
     # base_path=PipelineConfigs.BASE_PATH,
     base_path="/Users/nielsrogge/Documents/fondant_artifacts_datacomp",
@@ -69,13 +68,3 @@
 pipeline.add_op(filter_complexity_op, dependencies=filter_image_resolution_op)
 pipeline.add_op(cluster_image_embeddings_op, dependencies=filter_complexity_op)
 # TODO add more ops
-
-# compile
-if __name__ == "__main__":
-    compiler = DockerCompiler()
-    # mount the gcloud credentials to the container
-    extra_volumes = [
-        "$HOME/.config/gcloud/application_default_credentials.json:/root/.config/gcloud/application_default_credentials.json:ro"
-    ]
-    compiler.compile(pipeline=pipeline, extra_volumes=extra_volumes)
-    logger.info("Run `docker compose up` to run the pipeline.")
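Since the in-file compile block is removed here, the pipeline presumably gets compiled separately. A minimal sketch of a standalone compile script that reuses the deleted snippet; the file name `compile_datacomp.py` and the `from pipeline import pipeline` import are hypothetical, and whether this is the intended new workflow is an assumption, not something stated in the commit:

```python
# compile_datacomp.py -- hypothetical standalone script reusing the removed
# compile block; not part of this commit.
import logging

from fondant.compiler import DockerCompiler

from pipeline import pipeline  # assumed import of the Pipeline object defined in pipeline.py

logger = logging.getLogger(__name__)

if __name__ == "__main__":
    compiler = DockerCompiler()
    # Mount the gcloud credentials into the container, as in the removed block.
    extra_volumes = [
        "$HOME/.config/gcloud/application_default_credentials.json:/root/.config/gcloud/application_default_credentials.json:ro"
    ]
    compiler.compile(pipeline=pipeline, extra_volumes=extra_volumes)
    logger.info("Run `docker compose up` to run the pipeline.")
```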
