Skip to content

Commit

Permalink
Removed Catalog.open_object() and refactor method to return file ob…
Browse files Browse the repository at this point in the history
…ject from row (#467)

* removed catalog.open_object and refactor method to return file objects from row

* removed unused method
  • Loading branch information
ilongin authored Sep 23, 2024
1 parent 4a7b17f commit 0835bf1
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 83 deletions.
68 changes: 19 additions & 49 deletions src/datachain/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
)
from datachain.dataset import DatasetVersion
from datachain.job import Job
from datachain.lib.file import File

logger = logging.getLogger("datachain")

Expand Down Expand Up @@ -1399,65 +1400,34 @@ def edit_dataset(
dataset = self.get_dataset(name)
return self.update_dataset(dataset, **update_data)

def get_file_signals(
self, dataset_name: str, dataset_version: int, row: RowDict
) -> Optional[RowDict]:
def get_file_from_row(
self, dataset_name: str, dataset_version: int, row: RowDict, signal_name: str
) -> "File":
"""
Function that returns file signals from dataset row.
Note that signal names are without prefix, so if there was 'laion__file__source'
in original row, result will have just 'source'
Example output:
{
"source": "s3://ldb-public",
"path": "animals/dogs/dog.jpg",
...
}
Function that returns specific file signal from dataset row by name.
"""
from datachain.lib.file import File
from datachain.lib.signal_schema import DEFAULT_DELIMITER, SignalSchema

version = self.get_dataset(dataset_name).get_version(dataset_version)

file_signals_values = RowDict()

schema = SignalSchema.deserialize(version.feature_schema)
for file_signals in schema.get_signals(File):
prefix = file_signals.replace(".", DEFAULT_DELIMITER) + DEFAULT_DELIMITER
file_signals_values[file_signals] = {
c_name.removeprefix(prefix): c_value
for c_name, c_value in row.items()
if c_name.startswith(prefix)
and DEFAULT_DELIMITER not in c_name.removeprefix(prefix)
}

if not file_signals_values:
return None

# there can be multiple file signals in a schema, but taking the first
# one for now. In future we might add ability to choose from which one
# to open object
return next(iter(file_signals_values.values()))

def open_object(
self,
dataset_name: str,
dataset_version: int,
row: RowDict,
use_cache: bool = True,
**config: Any,
):
from datachain.lib.file import File
# Fail fast when the requested name is not one of the File signals
# declared in the dataset version's schema.
if signal_name not in schema.get_signals(File):
    # The two f-strings are adjacent literals and concatenate into a single
    # message. Previously a comma separated them, which passed two
    # positional arguments to RuntimeError and made str(exc) render as a
    # tuple instead of one readable sentence.
    raise RuntimeError(
        f"File signal with path {signal_name} not found in "
        f"dataset {dataset_name}@v{dataset_version} signals schema"
    )

file_signals = self.get_file_signals(dataset_name, dataset_version, row)
if not file_signals:
raise RuntimeError("Cannot open object without file signals")
prefix = signal_name.replace(".", DEFAULT_DELIMITER) + DEFAULT_DELIMITER
file_signals = {
c_name.removeprefix(prefix): c_value
for c_name, c_value in row.items()
if c_name.startswith(prefix)
and DEFAULT_DELIMITER not in c_name.removeprefix(prefix)
and c_name.removeprefix(prefix) in File.model_fields
}

config = config or self.client_config
client = self.get_client(file_signals["source"], **config)
return client.open_object(
File._from_row(file_signals),
use_cache=use_cache,
)
return File(**file_signals)

def ls(
self,
Expand Down
54 changes: 20 additions & 34 deletions tests/func/test_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,7 +847,7 @@ def test_garbage_collect(cloud_test_catalog, from_cli, capsys):
assert catalog.get_temp_table_names() == []


def test_get_file_signals(cloud_test_catalog, dogs_dataset):
def test_get_file_from_row(cloud_test_catalog, dogs_dataset):
catalog = cloud_test_catalog.catalog
catalog.metastore.update_dataset_version(
dogs_dataset,
Expand All @@ -863,18 +863,22 @@ def test_get_file_signals(cloud_test_catalog, dogs_dataset):
"name": "Jon",
"age": 25,
"f1__source": "s3://first_bucket",
"f1__name": "image1.jpg",
"f1__path": "image1.jpg",
"f2__source": "s3://second_bucket",
"f2__name": "image2.jpg",
"f2__path": "image2.jpg",
}

assert catalog.get_file_signals(dogs_dataset.name, 1, row) == {
"source": "s3://first_bucket",
"name": "image1.jpg",
}
assert catalog.get_file_from_row(dogs_dataset.name, 1, row, "f1") == File(
source="s3://first_bucket",
path="image1.jpg",
)
assert catalog.get_file_from_row(dogs_dataset.name, 1, row, "f2") == File(
source="s3://second_bucket",
path="image2.jpg",
)


def test_get_file_signals_with_custom_types(cloud_test_catalog, dogs_dataset):
def test_get_file_from_row_with_custom_types(cloud_test_catalog, dogs_dataset):
catalog = cloud_test_catalog.catalog
catalog.metastore.update_dataset_version(
dogs_dataset,
Expand All @@ -885,44 +889,26 @@ def test_get_file_signals_with_custom_types(cloud_test_catalog, dogs_dataset):
"f1": "File@v1",
"f2": "File@v1",
"_custom_types": {
"File@v1": {"source": "str", "name": "str"},
"File@v1": {"source": "str", "path": "str"},
},
},
)
row = {
"name": "Jon",
"age": 25,
"f1__source": "s3://first_bucket",
"f1__name": "image1.jpg",
"f1__path": "image1.jpg",
"f2__source": "s3://second_bucket",
"f2__name": "image2.jpg",
}

assert catalog.get_file_signals(dogs_dataset.name, 1, row) == {
"source": "s3://first_bucket",
"name": "image1.jpg",
"f2__path": "image2.jpg",
}


def test_get_file_signals_no_signals(cloud_test_catalog, dogs_dataset):
catalog = cloud_test_catalog.catalog
catalog.metastore.update_dataset_version(
dogs_dataset,
1,
feature_schema={
"name": "str",
"age": "str",
},
assert catalog.get_file_from_row(dogs_dataset.name, 1, row, "f1") == File(
source="s3://first_bucket",
path="image1.jpg",
)
row = {
"name": "Jon",
"age": 25,
}

assert catalog.get_file_signals(dogs_dataset.name, 1, row) is None


def test_open_object_no_file_signals(cloud_test_catalog, dogs_dataset):
def test_get_file_from_row_no_signals(cloud_test_catalog, dogs_dataset):
catalog = cloud_test_catalog.catalog
catalog.metastore.update_dataset_version(
dogs_dataset,
Expand All @@ -938,4 +924,4 @@ def test_open_object_no_file_signals(cloud_test_catalog, dogs_dataset):
}

with pytest.raises(RuntimeError):
assert catalog.open_object(dogs_dataset.name, 1, row)
assert catalog.get_file_from_row(dogs_dataset.name, 1, row, "missing")

0 comments on commit 0835bf1

Please sign in to comment.