diff --git a/Dockerfile.dev b/Dockerfile.dev index f6baf63896..b7c5104bbc 100644 --- a/Dockerfile.dev +++ b/Dockerfile.dev @@ -12,27 +12,21 @@ MAINTAINER Flyte Team LABEL org.opencontainers.image.source https://github.com/flyteorg/flytekit WORKDIR /root -ENV PYTHONPATH /root ARG VERSION -ARG DOCKER_IMAGE RUN apt-get update && apt-get install build-essential vim -y -COPY . /code/flytekit -WORKDIR /code/flytekit +COPY . /flytekit # Pod tasks should be exposed in the default image -RUN pip install -e . -RUN pip install -e plugins/flytekit-k8s-pod -RUN pip install -e plugins/flytekit-deck-standard +RUN pip install -e /flytekit +RUN pip install -e /flytekit/plugins/flytekit-k8s-pod +RUN pip install -e /flytekit/plugins/flytekit-deck-standard RUN pip install scikit-learn -ENV PYTHONPATH "/code/flytekit:/code/flytekit/plugins/flytekit-k8s-pod:/code/flytekit/plugins/flytekit-deck-standard:" +ENV PYTHONPATH "/flytekit:/flytekit/plugins/flytekit-k8s-pod:/flytekit/plugins/flytekit-deck-standard:" -WORKDIR /root RUN useradd -u 1000 flytekit RUN chown flytekit: /root USER flytekit - -ENV FLYTE_INTERNAL_IMAGE "$DOCKER_IMAGE" diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 136831c0bc..c45ec3f150 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -81,7 +81,7 @@ def convert( raise ValueError( f"Currently only directories containing one file are supported, found [{len(files)}] files found in {p.resolve()}" ) - return Directory(dir_path=value, local_file=files[0].resolve()) + return Directory(dir_path=str(p), local_file=files[0].resolve()) raise click.BadParameter(f"parameter should be a valid directory path, {value}") diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index 8fb73ebd8c..ea36689874 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -107,16 +107,16 @@ def data_config(self) -> DataConfig: return self._data_config def get_filesystem( - self, protocol: typing.Optional[str] = None, anonymous: bool = False + self, protocol: typing.Optional[str] = None, anonymous: bool = False, **kwargs ) -> typing.Optional[fsspec.AbstractFileSystem]: if not protocol: return self._default_remote - kwargs = {} # type: typing.Dict[str, typing.Any] if protocol == "file": - kwargs = {"auto_mkdir": True} + kwargs["auto_mkdir"] = True elif protocol == "s3": - kwargs = s3_setup_args(self._data_config.s3, anonymous=anonymous) - return fsspec.filesystem(protocol, **kwargs) # type: ignore + s3kwargs = s3_setup_args(self._data_config.s3, anonymous=anonymous) + s3kwargs.update(kwargs) + return fsspec.filesystem(protocol, **s3kwargs) # type: ignore elif protocol == "gs": if anonymous: kwargs["token"] = _ANON @@ -128,9 +128,9 @@ def get_filesystem( return fsspec.filesystem(protocol, **kwargs) # type: ignore - def get_filesystem_for_path(self, path: str = "", anonymous: bool = False) -> fsspec.AbstractFileSystem: + def get_filesystem_for_path(self, path: str = "", anonymous: bool = False, **kwargs) -> fsspec.AbstractFileSystem: protocol = get_protocol(path) - return self.get_filesystem(protocol, anonymous=anonymous) + return self.get_filesystem(protocol, anonymous=anonymous, **kwargs) @staticmethod def is_remote(path: Union[str, os.PathLike]) -> bool: diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index f21e93a774..306c4116ad 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -129,7 +129,6 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: f"Conversion to python value expected type {expected_python_type} from literal not implemented" ) - @abstractmethod def to_html(self, ctx: FlyteContext, python_val: T, expected_python_type: Type[T]) -> str: """ Converts any python val (dataframe, int, float) to a html string, and it will be wrapped in the HTML div diff --git a/flytekit/types/directory/types.py b/flytekit/types/directory/types.py index 7d576f9353..f4f23eb72f 100644 --- a/flytekit/types/directory/types.py +++ b/flytekit/types/directory/types.py @@ -2,20 +2,25 @@ import os import pathlib +import random import typing from dataclasses import dataclass, field from pathlib import Path +from typing import Any, Generator, Tuple +from uuid import UUID +import fsspec from dataclasses_json import config, dataclass_json +from fsspec.utils import get_protocol from marshmallow import fields -from flytekit.core.context_manager import FlyteContext +from flytekit.core.context_manager import FlyteContext, FlyteContextManager from flytekit.core.type_engine import TypeEngine, TypeTransformer from flytekit.models import types as _type_models from flytekit.models.core import types as _core_types from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar from flytekit.models.types import LiteralType -from flytekit.types.file import FileExt +from flytekit.types.file import FileExt, FlyteFile T = typing.TypeVar("T") PathType = typing.Union[str, os.PathLike] @@ -148,6 +153,18 @@ def __fspath__(self): def extension(cls) -> str: return "" + @classmethod + def new_remote(cls) -> FlyteDirectory: + """ + Create a new FlyteDirectory object using the currently configured default remote in the context (i.e. + the raw_output_prefix configured in the current FileAccessProvider object in the context). + This is used if you explicitly have a folder somewhere that you want to create files under. + If you want to write a whole folder, you can let your task return a FlyteDirectory object, + and let flytekit handle the uploading. + """ + d = FlyteContext.current_context().file_access.get_random_remote_directory() + return FlyteDirectory(path=d) + def __class_getitem__(cls, item: typing.Union[typing.Type, str]) -> typing.Type[FlyteDirectory]: if item is None: return cls @@ -176,6 +193,12 @@ def downloaded(self) -> bool: def remote_directory(self) -> typing.Optional[str]: return self._remote_directory + @property + def sep(self) -> str: + if os.name == "nt" and get_protocol(self.path or self.remote_source or self.remote_directory) == "file": + return "\\" + return "/" + @property def remote_source(self) -> str: """ @@ -184,9 +207,67 @@ def remote_source(self) -> str: """ return typing.cast(str, self._remote_source) + def new_file(self, name: typing.Optional[str] = None) -> FlyteFile: + """ + This will create a new file under the current folder. + If given a name, it will use the name given, otherwise it'll pick a random string. + Collisions are not checked. + """ + # TODO we may want to use - https://github.com/fsspec/universal_pathlib + if not name: + name = UUID(int=random.getrandbits(128)).hex + new_path = self.sep.join([str(self.path).rstrip(self.sep), name]) # trim trailing sep if any and join + return FlyteFile(path=new_path) + + def new_dir(self, name: typing.Optional[str] = None) -> FlyteDirectory: + """ + This will create a new folder under the current folder. + If given a name, it will use the name given, otherwise it'll pick a random string. + Collisions are not checked. + """ + if not name: + name = UUID(int=random.getrandbits(128)).hex + + new_path = self.sep.join([str(self.path).rstrip(self.sep), name]) # trim trailing sep if any and join + return FlyteDirectory(path=new_path) + def download(self) -> str: return self.__fspath__() + def crawl( + self, maxdepth: typing.Optional[int] = None, topdown: bool = True, **kwargs + ) -> Generator[Tuple[typing.Union[str, os.PathLike[Any]], typing.Dict[Any, Any]], None, None]: + """ + Crawl returns a generator of all files prefixed by any sub-folders under the given "FlyteDirectory". + if details=True is passed, then it will return a dictionary as specified by fsspec. + + Example: + + >>> list(fd.crawl()) + [("/base", "file1"), ("/base", "dir1/file1"), ("/base", "dir2/file1"), ("/base", "dir1/dir/file1")] + + >>> list(x.crawl(detail=True)) + [('/tmp/test', {'my-dir/ab.py': {'name': '/tmp/test/my-dir/ab.py', 'size': 0, 'type': 'file', + 'created': 1677720780.2318847, 'islink': False, 'mode': 33188, 'uid': 501, 'gid': 0, + 'mtime': 1677720780.2317934, 'ino': 1694329, 'nlink': 1}})] + """ + final_path = self.path + if self.remote_source: + final_path = self.remote_source + elif self.remote_directory: + final_path = self.remote_directory + ctx = FlyteContextManager.current_context() + fs = ctx.file_access.get_filesystem_for_path(final_path) + base_path_len = len(fsspec.core.strip_protocol(final_path)) + 1 # Add additional `/` at the end + for base, _, files in fs.walk(final_path, maxdepth, topdown, **kwargs): + current_base = base[base_path_len:] + if isinstance(files, dict): + for f, v in files.items(): + yield final_path, {os.path.join(current_base, f): v} + else: + for f in files: + yield final_path, os.path.join(current_base, f) + def __repr__(self): return self.path diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index 6537f85cae..23f4137344 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -3,12 +3,13 @@ import os import pathlib import typing +from contextlib import contextmanager from dataclasses import dataclass, field from dataclasses_json import config, dataclass_json from marshmallow import fields -from flytekit.core.context_manager import FlyteContext +from flytekit.core.context_manager import FlyteContext, FlyteContextManager from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError from flytekit.loggers import logger from flytekit.models.core.types import BlobType @@ -27,7 +28,9 @@ def noop(): @dataclass_json @dataclass class FlyteFile(os.PathLike, typing.Generic[T]): - path: typing.Union[str, os.PathLike] = field(default=None, metadata=config(mm_field=fields.String())) # type: ignore + path: typing.Union[str, os.PathLike] = field( + default=None, metadata=config(mm_field=fields.String()) + ) # type: ignore """ Since there is no native Python implementation of files and directories for the Flyte Blob type, (like how int exists for Flyte's Integer type) we need to create one so that users can express that their tasks take @@ -148,6 +151,15 @@ def t2() -> flytekit_typing.FlyteFile["csv"]: def extension(cls) -> str: return "" + @classmethod + def new_remote_file(cls, name: typing.Optional[str] = None) -> FlyteFile: + """ + Create a new FlyteFile object with a remote path. + """ + ctx = FlyteContextManager.current_context() + remote_path = ctx.file_access.get_random_remote_path(name) + return cls(path=remote_path) + def __class_getitem__(cls, item: typing.Union[str, typing.Type]) -> typing.Type[FlyteFile]: from . import FileExt @@ -226,6 +238,57 @@ def remote_source(self) -> str: def download(self) -> str: return self.__fspath__() + @contextmanager + def open( + self, + mode: str, + cache_type: typing.Optional[str] = None, + cache_options: typing.Optional[typing.Dict[str, typing.Any]] = None, + ): + """ + Returns a streaming File handle + + .. code-block:: python + + @task + def copy_file(ff: FlyteFile) -> FlyteFile: + new_file = FlyteFile.new_remote_file(ff.name) + with ff.open("rb", cache_type="readahead", cache={}) as r: + with new_file.open("wb") as w: + w.write(r.read()) + return new_file + + Alternatively + + .. code-block:: python + + @task + def copy_file(ff: FlyteFile) -> FlyteFile: + new_file = FlyteFile.new_remote_file(ff.name) + with fsspec.open(f"readahead::{ff.remote_path}", "rb", readahead={}) as r: + with new_file.open("wb") as w: + w.write(r.read()) + return new_file + + + :param mode: str Open mode like 'rb', 'rt', 'wb', ... + :param cache_type: optional str Specify if caching is to be used. Cache protocol can be ones supported by + fsspec https://filesystem-spec.readthedocs.io/en/latest/api.html#readbuffering, + especially useful for large file reads + :param cache_options: optional Dict[str, Any] Refer to fsspec caching options. This is strongly coupled to the + cache_protocol + """ + ctx = FlyteContextManager.current_context() + final_path = self.path + if self.remote_source: + final_path = self.remote_source + elif self.remote_path: + final_path = self.remote_path + fs = ctx.file_access.get_filesystem_for_path(final_path) + f = fs.open(final_path, mode, cache_type=cache_type, cache_options=cache_options) + yield f + f.close() + def __repr__(self): return self.path diff --git a/flytekit/types/structured/basic_dfs.py b/flytekit/types/structured/basic_dfs.py index ae3e8a00d9..c8f4ef3baa 100644 --- a/flytekit/types/structured/basic_dfs.py +++ b/flytekit/types/structured/basic_dfs.py @@ -62,22 +62,6 @@ def encode( structured_dataset_type.format = PARQUET return literals.StructuredDataset(uri=uri, metadata=StructuredDatasetMetadata(structured_dataset_type)) - def ddencode( - self, - ctx: FlyteContext, - structured_dataset: StructuredDataset, - structured_dataset_type: StructuredDatasetType, - ) -> literals.StructuredDataset: - - path = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory() - df = typing.cast(pd.DataFrame, structured_dataset.dataframe) - local_dir = ctx.file_access.get_random_local_directory() - local_path = os.path.join(local_dir, f"{0:05}") - df.to_parquet(local_path, coerce_timestamps="us", allow_truncated_timestamps=False) - ctx.file_access.upload_directory(local_dir, path) - structured_dataset_type.format = PARQUET - return literals.StructuredDataset(uri=path, metadata=StructuredDatasetMetadata(structured_dataset_type)) - class ParquetToPandasDecodingHandler(StructuredDatasetDecoder): def __init__(self): @@ -101,20 +85,6 @@ def decode( kwargs = get_storage_options(ctx.file_access.data_config, uri, anon=True) return pd.read_parquet(uri, columns=columns, storage_options=kwargs) - def dcccecode( - self, - ctx: FlyteContext, - flyte_value: literals.StructuredDataset, - current_task_metadata: StructuredDatasetMetadata, - ) -> pd.DataFrame: - path = flyte_value.uri - local_dir = ctx.file_access.get_random_local_directory() - ctx.file_access.get_data(path, local_dir, is_multipart=True) - if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: - columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] - return pd.read_parquet(local_dir, columns=columns) - return pd.read_parquet(local_dir) - class ArrowToParquetEncodingHandler(StructuredDatasetEncoder): def __init__(self): diff --git a/tests/flytekit/unit/core/test_data.py b/tests/flytekit/unit/core/test_data.py index 880036f636..1b33ad2923 100644 --- a/tests/flytekit/unit/core/test_data.py +++ b/tests/flytekit/unit/core/test_data.py @@ -1,13 +1,17 @@ import os +import random import shutil import tempfile +from uuid import UUID import fsspec import mock import pytest from flytekit.configuration import Config, S3Config +from flytekit.core.context_manager import FlyteContextManager from flytekit.core.data_persistence import FileAccessProvider, default_local_file_access_provider, s3_setup_args +from flytekit.types.directory.types import FlyteDirectory local = fsspec.filesystem("file") root = os.path.abspath(os.sep) @@ -99,6 +103,8 @@ def source_folder(): nested_dir = os.path.join(src_dir, "nested") local.mkdir(nested_dir) local.touch(os.path.join(src_dir, "original.txt")) + with open(os.path.join(src_dir, "original.txt"), "w") as fh: + fh.write("hello original") local.touch(os.path.join(nested_dir, "more.txt")) yield src_dir shutil.rmtree(parent_temp) @@ -213,3 +219,112 @@ def test_s3_setup_args_env_aws(mock_os, mock_get_config_file): kwargs = s3_setup_args(S3Config.auto()) # not explicitly in kwargs, since fsspec/boto3 will use these env vars by default assert kwargs == {} + + +def test_crawl_local_nt(source_folder): + """ + running this to see what it prints + """ + if os.name != "nt": # don't + return + source_folder = os.path.join(source_folder, "") # ensure there's a trailing / or \ + fd = FlyteDirectory(path=source_folder) + res = fd.crawl() + split = [(x, y) for x, y in res] + print(f"NT split {split}") + + # Test crawling a directory without trailing / or \ + source_folder = source_folder[:-1] + fd = FlyteDirectory(path=source_folder) + res = fd.crawl() + files = [os.path.join(x, y) for x, y in res] + print(f"NT files joined {files}") + + +def test_crawl_local_non_nt(source_folder): + """ + crawl on the source folder fixture should return for example + ('/var/folders/jx/54tww2ls58n8qtlp9k31nbd80000gp/T/tmpp14arygf/source/', 'original.txt') + ('/var/folders/jx/54tww2ls58n8qtlp9k31nbd80000gp/T/tmpp14arygf/source/', 'nested/more.txt') + """ + if os.name == "nt": # don't + return + source_folder = os.path.join(source_folder, "") # ensure there's a trailing / or \ + fd = FlyteDirectory(path=source_folder) + res = fd.crawl() + split = [(x, y) for x, y in res] + files = [os.path.join(x, y) for x, y in split] + assert set(split) == {(source_folder, "original.txt"), (source_folder, os.path.join("nested", "more.txt"))} + expected = {os.path.join(source_folder, "original.txt"), os.path.join(source_folder, "nested", "more.txt")} + assert set(files) == expected + + # Test crawling a directory without trailing / or \ + source_folder = source_folder[:-1] + fd = FlyteDirectory(path=source_folder) + res = fd.crawl() + files = [os.path.join(x, y) for x, y in res] + assert set(files) == expected + + # Test crawling a single file + fd = FlyteDirectory(path=os.path.join(source_folder, "original.txt")) + res = fd.crawl() + files = [os.path.join(x, y) for x, y in res] + assert len(files) == 0 + + +@pytest.mark.sandbox_test +def test_crawl_s3(source_folder): + """ + ('s3://my-s3-bucket/testdata/5b31492c032893b515650f8c76008cf7', 'original.txt') + ('s3://my-s3-bucket/testdata/5b31492c032893b515650f8c76008cf7', 'nested/more.txt') + """ + # Running mkdir on s3 filesystem doesn't do anything so leaving out for now + dc = Config.for_sandbox().data_config + provider = FileAccessProvider( + local_sandbox_dir="/tmp/unittest", raw_output_prefix="s3://my-s3-bucket/testdata/", data_config=dc + ) + s3_random_target = provider.get_random_remote_directory() + provider.put_data(source_folder, s3_random_target, is_multipart=True) + ctx = FlyteContextManager.current_context() + expected = {f"{s3_random_target}/original.txt", f"{s3_random_target}/nested/more.txt"} + + with FlyteContextManager.with_context(ctx.with_file_access(provider)): + fd = FlyteDirectory(path=s3_random_target) + res = fd.crawl() + res = [(x, y) for x, y in res] + files = [os.path.join(x, y) for x, y in res] + assert set(files) == expected + assert set(res) == {(s3_random_target, "original.txt"), (s3_random_target, os.path.join("nested", "more.txt"))} + + fd_file = FlyteDirectory(path=f"{s3_random_target}/original.txt") + res = fd_file.crawl() + files = [r for r in res] + assert len(files) == 1 + + +@pytest.mark.sandbox_test +def test_walk_local_copy_to_s3(source_folder): + dc = Config.for_sandbox().data_config + explicit_empty_folder = UUID(int=random.getrandbits(128)).hex + raw_output_path = f"s3://my-s3-bucket/testdata/{explicit_empty_folder}" + provider = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix=raw_output_path, data_config=dc) + + ctx = FlyteContextManager.current_context() + local_fd = FlyteDirectory(path=source_folder) + local_fd_crawl = local_fd.crawl() + local_fd_crawl = [x for x in local_fd_crawl] + with FlyteContextManager.with_context(ctx.with_file_access(provider)): + fd = FlyteDirectory.new_remote() + assert raw_output_path in fd.path + + # Write source folder files to new remote path + for root_path, suffix in local_fd_crawl: + new_file = fd.new_file(suffix) # noqa + with open(os.path.join(root_path, suffix), "rb") as r: # noqa + with new_file.open("w") as w: + print(f"Writing, t {type(w)} p {new_file.path} |{suffix}|") + w.write(str(r.read())) + + new_crawl = fd.crawl() + new_suffixes = [y for x, y in new_crawl] + assert len(new_suffixes) == 2 # should have written two files diff --git a/tests/flytekit/unit/core/test_flyte_file.py b/tests/flytekit/unit/core/test_flyte_file.py index e2123222e0..1c1593ad4c 100644 --- a/tests/flytekit/unit/core/test_flyte_file.py +++ b/tests/flytekit/unit/core/test_flyte_file.py @@ -7,9 +7,8 @@ import pytest import flytekit.configuration -from flytekit.configuration import Image, ImageConfig -from flytekit.core import context_manager -from flytekit.core.context_manager import ExecutionState +from flytekit.configuration import Config, Image, ImageConfig +from flytekit.core.context_manager import ExecutionState, FlyteContextManager from flytekit.core.data_persistence import FileAccessProvider, flyte_tmp_dir from flytekit.core.dynamic_workflow_task import dynamic from flytekit.core.launch_plan import LaunchPlan @@ -81,11 +80,10 @@ def t1() -> FlyteFile: def my_wf() -> FlyteFile: return t1() - random_dir = context_manager.FlyteContext.current_context().file_access.get_random_local_directory() - # print(f"Random: {random_dir}") + random_dir = FlyteContextManager.current_context().file_access.get_random_local_directory() fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "mock_remote")) - ctx = context_manager.FlyteContext.current_context() - with context_manager.FlyteContextManager.with_context(ctx.with_file_access(fs)): + ctx = FlyteContextManager.current_context() + with FlyteContextManager.with_context(ctx.with_file_access(fs)): top_level_files = os.listdir(random_dir) assert len(top_level_files) == 1 # the flytekit_local folder @@ -108,10 +106,10 @@ def t1() -> FlyteFile: def my_wf() -> FlyteFile: return t1() - random_dir = context_manager.FlyteContext.current_context().file_access.get_random_local_directory() + random_dir = FlyteContextManager.current_context().file_access.get_random_local_directory() fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "mock_remote")) - ctx = context_manager.FlyteContext.current_context() - with context_manager.FlyteContextManager.with_context(ctx.with_file_access(fs)): + ctx = FlyteContextManager.current_context() + with FlyteContextManager.with_context(ctx.with_file_access(fs)): top_level_files = os.listdir(random_dir) assert len(top_level_files) == 1 # the flytekit_local folder @@ -137,12 +135,12 @@ def my_wf() -> FlyteFile: return t1() # This creates a random directory that we know is empty. - random_dir = context_manager.FlyteContext.current_context().file_access.get_random_local_directory() + random_dir = FlyteContextManager.current_context().file_access.get_random_local_directory() # Creating a new FileAccessProvider will add two folderst to the random dir print(f"Random {random_dir}") fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "mock_remote")) - ctx = context_manager.FlyteContext.current_context() - with context_manager.FlyteContextManager.with_context(ctx.with_file_access(fs)): + ctx = FlyteContextManager.current_context() + with FlyteContextManager.with_context(ctx.with_file_access(fs)): working_dir = os.listdir(random_dir) assert len(working_dir) == 1 # the local_flytekit folder @@ -189,11 +187,11 @@ def my_wf() -> FlyteFile: return t1() # This creates a random directory that we know is empty. - random_dir = context_manager.FlyteContext.current_context().file_access.get_random_local_directory() + random_dir = FlyteContextManager.current_context().file_access.get_random_local_directory() # Creating a new FileAccessProvider will add two folderst to the random dir fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "mock_remote")) - ctx = context_manager.FlyteContext.current_context() - with context_manager.FlyteContextManager.with_context(ctx.with_file_access(fs)): + ctx = FlyteContextManager.current_context() + with FlyteContextManager.with_context(ctx.with_file_access(fs)): working_dir = os.listdir(random_dir) assert len(working_dir) == 1 # the local_flytekit dir @@ -243,8 +241,8 @@ def dyn(in1: FlyteFile): fd = FlyteFile("s3://anything") - with context_manager.FlyteContextManager.with_context( - context_manager.FlyteContextManager.current_context().with_serialization_settings( + with FlyteContextManager.with_context( + FlyteContextManager.current_context().with_serialization_settings( flytekit.configuration.SerializationSettings( project="test_proj", domain="test_domain", @@ -254,8 +252,8 @@ def dyn(in1: FlyteFile): ) ) ): - ctx = context_manager.FlyteContextManager.current_context() - with context_manager.FlyteContextManager.with_context( + ctx = FlyteContextManager.current_context() + with FlyteContextManager.with_context( ctx.with_execution_state(ctx.new_execution_state().with_params(mode=ExecutionState.Mode.TASK_EXECUTION)) ) as ctx: lit = TypeEngine.to_literal( @@ -433,3 +431,44 @@ def wf(path: str) -> os.PathLike: return t2(ff=n1) assert flyte_tmp_dir in wf(path="s3://somewhere").path + + +@pytest.mark.sandbox_test +def test_file_open_things(): + @task + def write_this_file_to_s3() -> FlyteFile: + ctx = FlyteContextManager.current_context() + dest = ctx.file_access.get_random_remote_path() + ctx.file_access.put(__file__, dest) + return FlyteFile(path=dest) + + @task + def copy_file(ff: FlyteFile) -> FlyteFile: + new_file = FlyteFile.new_remote_file(ff.remote_path) + with ff.open("r") as r: + with new_file.open("w") as w: + w.write(r.read()) + return new_file + + @task + def print_file(ff: FlyteFile): + with open(ff, "r") as fh: + print(len(fh.readlines())) + + dc = Config.for_sandbox().data_config + with tempfile.TemporaryDirectory() as new_sandbox: + provider = FileAccessProvider( + local_sandbox_dir=new_sandbox, raw_output_prefix="s3://my-s3-bucket/testdata/", data_config=dc + ) + ctx = FlyteContextManager.current_context() + local = ctx.file_access.get_filesystem("file") # get a local file system. + with FlyteContextManager.with_context(ctx.with_file_access(provider)): + f = write_this_file_to_s3() + copy_file(ff=f) + files = local.find(new_sandbox) + # copy_file was done via streaming so no files should have been written + assert len(files) == 0 + print_file(ff=f) + # print_file uses traditional download semantics so now a file should have been created + files = local.find(new_sandbox) + assert len(files) == 1 diff --git a/tests/flytekit/unit/core/tracker/d.py b/tests/flytekit/unit/core/tracker/d.py index 9385b0f08d..c84e36fe59 100644 --- a/tests/flytekit/unit/core/tracker/d.py +++ b/tests/flytekit/unit/core/tracker/d.py @@ -9,3 +9,7 @@ def tasks(): @task def foo(): pass + + +def inner_function(a: str) -> str: + return "hello" diff --git a/tests/flytekit/unit/core/tracker/test_tracking.py b/tests/flytekit/unit/core/tracker/test_tracking.py index 33ae18acd5..b33725436d 100644 --- a/tests/flytekit/unit/core/tracker/test_tracking.py +++ b/tests/flytekit/unit/core/tracker/test_tracking.py @@ -79,3 +79,10 @@ def test_extract_task_module(test_input, expected): except Exception: FeatureFlags.FLYTE_PYTHON_PACKAGE_ROOT = old raise + + +local_task = task(d.inner_function) + + +def test_local_task_wrap(): + assert local_task.instantiated_in == "tests.flytekit.unit.core.tracker.test_tracking" diff --git a/tests/flytekit/unit/extras/sqlite3/chinook.zip b/tests/flytekit/unit/extras/sqlite3/chinook.zip new file mode 100644 index 0000000000..6dd568fa61 Binary files /dev/null and b/tests/flytekit/unit/extras/sqlite3/chinook.zip differ diff --git a/tests/flytekit/unit/extras/sqlite3/test_task.py b/tests/flytekit/unit/extras/sqlite3/test_task.py index 40fc94a3d2..f8014f244b 100644 --- a/tests/flytekit/unit/extras/sqlite3/test_task.py +++ b/tests/flytekit/unit/extras/sqlite3/test_task.py @@ -1,3 +1,5 @@ +import os + import pandas import pytest @@ -10,8 +12,7 @@ from flytekit.types.schema import FlyteSchema ctx = context_manager.FlyteContextManager.current_context() -EXAMPLE_DB = ctx.file_access.get_random_local_path("chinook.zip") -ctx.file_access.get_data("https://www.sqlitetutorial.net/wp-content/uploads/2018/03/chinook.zip", EXAMPLE_DB) +EXAMPLE_DB = os.path.join(os.path.dirname(os.path.realpath(__file__)), "chinook.zip") # This task belongs to test_task_static but is intentionally here to help test tracking tk = SQLite3Task( diff --git a/tests/flytekit/unit/types/directory/__init__.py b/tests/flytekit/unit/types/directory/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/flytekit/unit/types/directory/test_types.py b/tests/flytekit/unit/types/directory/test_types.py new file mode 100644 index 0000000000..199b788733 --- /dev/null +++ b/tests/flytekit/unit/types/directory/test_types.py @@ -0,0 +1,31 @@ +import mock + +from flytekit import FlyteContext +from flytekit.types.directory import FlyteDirectory +from flytekit.types.file import FlyteFile + + +def test_new_file_dir(): + fd = FlyteDirectory(path="s3://my-bucket") + assert fd.sep == "/" + inner_dir = fd.new_dir("test") + assert inner_dir.path == "s3://my-bucket/test" + fd = FlyteDirectory(path="s3://my-bucket/") + inner_dir = fd.new_dir("test") + assert inner_dir.path == "s3://my-bucket/test" + f = inner_dir.new_file("test") + assert isinstance(f, FlyteFile) + assert f.path == "s3://my-bucket/test/test" + + +def test_new_remote_dir(): + fd = FlyteDirectory.new_remote() + assert FlyteContext.current_context().file_access.raw_output_prefix in fd.path + + +@mock.patch("flytekit.types.directory.types.os.name", "nt") +def test_sep_nt(): + fd = FlyteDirectory(path="file://mypath") + assert fd.sep == "\\" + fd = FlyteDirectory(path="s3://mypath") + assert fd.sep == "/" diff --git a/tests/flytekit/unit/types/file/__init__.py b/tests/flytekit/unit/types/file/__init__.py new file mode 100644 index 0000000000..e69de29bb2