diff --git a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py index 0b5bf8e577..4290c88ae4 100644 --- a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py +++ b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py @@ -7,6 +7,7 @@ from flytekit.models import literals from flytekit.models.literals import StructuredDatasetMetadata from flytekit.models.types import StructuredDatasetType +from flytekit.types.structured.basic_dfs import get_storage_options from flytekit.types.structured.structured_dataset import ( PARQUET, StructuredDataset, @@ -62,12 +63,12 @@ def decode( flyte_value: literals.StructuredDataset, current_task_metadata: StructuredDatasetMetadata, ) -> pl.DataFrame: - local_dir = ctx.file_access.get_random_local_directory() - ctx.file_access.get_data(flyte_value.uri, local_dir, is_multipart=True) + uri = flyte_value.uri + kwargs = get_storage_options(ctx.file_access.data_config, uri) if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] - return pl.read_parquet(local_dir, columns=columns, use_pyarrow=True) - return pl.read_parquet(local_dir, use_pyarrow=True) + return pl.read_parquet(uri, columns=columns, use_pyarrow=True, storage_options=kwargs) + return pl.read_parquet(uri, use_pyarrow=True, storage_options=kwargs) StructuredDatasetTransformerEngine.register(PolarsDataFrameToParquetEncodingHandler()) diff --git a/plugins/flytekit-polars/tests/test_polars_plugin_sd.py b/plugins/flytekit-polars/tests/test_polars_plugin_sd.py index 15a195e5d5..23fbf6d441 100644 --- a/plugins/flytekit-polars/tests/test_polars_plugin_sd.py +++ b/plugins/flytekit-polars/tests/test_polars_plugin_sd.py @@ -1,3 +1,5 @@ +import tempfile + import pandas as pd import polars as pl from flytekitplugins.polars.sd_transformers import PolarsDataFrameRenderer @@ -79,3 +81,13 @@ def create_sd() -> StructuredDataset: sd = create_sd() polars_df = sd.open(pl.DataFrame).all() assert pl.DataFrame(data).frame_equal(polars_df) + + tmp = tempfile.mktemp() + pl.DataFrame(data).write_parquet(tmp) + + @task + def t1(sd: StructuredDataset) -> pl.DataFrame: + return sd.open(pd.DataFrame).all() + + sd = StructuredDataset(uri=tmp) + t1(sd=sd).frame_equal(polars_df)