Skip to content

Commit

Permalink
fix(merge_index): scheme was not well formatted (#576)
Browse files Browse the repository at this point in the history
* fix(merge_index): scheme was not well formatted

* fix: dbfs merge index

* fix lint
  • Loading branch information
fwertel authored Jan 26, 2024
1 parent aacea8b commit dffef31
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 25 deletions.
66 changes: 41 additions & 25 deletions streaming/base/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def merge_index(*args: Any, **kwargs: Any):
raise ValueError(f'Invalid arguments to merge_index: {args}, {kwargs}')


def _merge_index_from_list(index_file_urls: List[Union[str, Tuple[str, str]]],
def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]]],
out: Union[str, Tuple[str, str]],
keep_local: bool = True,
download_timeout: int = 60) -> None:
Expand Down Expand Up @@ -359,6 +359,44 @@ def _merge_index_from_list(index_file_urls: List[Union[str, Tuple[str, str]]],
shutil.rmtree(cu.local, ignore_errors=True)


def _not_merged_index(index_file_path: str, out: str):
"""Check if index_file_path is the merged index at folder out.
Args:
index_file_path (str): the path to index.json file
out (str): remote or local url of a folder
Return:
(bool): no if index.json sits in out instead of in the subfolders of out
"""
prefix = str(urllib.parse.urlparse(out).path)
return os.path.dirname(index_file_path).strip('/') != prefix.strip('/')


def _format_remote_index_files(remote: str, files: List[str]) -> List[str]:
    """Build full remote URLs for the partition index files found under ``remote``.

    A file is kept when its name ends with the index basename and it is not the merged index
    sitting directly in ``remote``. Each kept file is prefixed with the scheme and netloc
    parsed from ``remote``.

    Args:
        remote (str): The remote URL.
        files (List[str]): The list of files.

    Returns:
        List[str]: The formatted remote index files, in the same order as ``files``.
    """
    obj = urllib.parse.urlparse(remote)

    # The separator between scheme and the rest depends only on ``remote``, so it is
    # computed once instead of once per file.
    join_char = '://'
    if obj.scheme == 'dbfs':
        parts = Path(remote).parts
        # Remotes under ``dbfs:/Volumes`` get a single slash after the scheme. The length
        # check avoids an IndexError for a remote with fewer than two path parts.
        if len(parts) >= 2 and os.path.join(parts[0], parts[1]) == 'dbfs:/Volumes':
            join_char = ':/'

    url_prefix = obj.scheme + join_char
    index_basename = get_index_basename()
    return [
        url_prefix + os.path.join(obj.netloc, file)
        for file in files
        if file.endswith(index_basename) and _not_merged_index(file, remote)
    ]


def _merge_index_from_root(out: Union[str, Tuple[str, str]],
keep_local: bool = True,
download_timeout: int = 60) -> None:
Expand All @@ -378,18 +416,6 @@ def _merge_index_from_root(out: Union[str, Tuple[str, str]],
"""
from streaming.base.storage.upload import CloudUploader

def not_merged_index(index_file_path: str, out: str):
"""Check if index_file_path is the merged index at folder out.
Args:
index_file_path (str): the path to index.json file
out (str): remote or local url of a folder
Return:
(bool): no if index.json sits in out instead of in the subfolders of out
"""
prefix = str(urllib.parse.urlparse(out).path)
return os.path.dirname(index_file_path).strip('/') != prefix.strip('/')

if not out:
logger.warning('No MDS dataset folder specified, no index merged')
return
Expand All @@ -399,21 +425,11 @@ def not_merged_index(index_file_path: str, out: str):
local_index_files = []
cl = CloudUploader.get(cu.local, exist_ok=True, keep_local=True)
for file in cl.list_objects():
if file.endswith('.json') and not_merged_index(file, cu.local):
if file.endswith('.json') and _not_merged_index(file, cu.local):
local_index_files.append(file)

if cu.remote:
obj = urllib.parse.urlparse(cu.remote)
remote_index_files = []
for file in cu.list_objects():
if file.endswith(get_index_basename()) and not_merged_index(file, cu.remote):
join_char = '//'
if obj.scheme == 'dbfs':
path = Path(cu.remote)
prefix = os.path.join(path.parts[0], path.parts[1])
if prefix == 'dbfs:/Volumes':
join_char = '/'
remote_index_files.append(obj.scheme + join_char + os.path.join(obj.netloc, file))
remote_index_files = _format_remote_index_files(cu.list_objects(), cu.remote)
if len(local_index_files) == len(remote_index_files):
_merge_index_from_list(list(zip(local_index_files, remote_index_files)),
out,
Expand Down
24 changes: 24 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,30 @@ def get_expected(mds_root: str):
assert n_shard_files == expected_n_shard_files, f'expected {expected_n_shard_files} shard files but got {n_shard_files}'


@pytest.mark.parametrize('scheme', ['gs', 's3', 'oci', 'dbfs'])
def test_format_remote_index_files(scheme: str):
    """Validate the format of remote index files.

    Each returned URL must parse back to the requested scheme and must still end with the
    object path that was listed, so a malformed scheme separator or a dropped path is caught.
    """
    from streaming.base.util import _format_remote_index_files

    if scheme == 'dbfs':
        remote = os.path.join('dbfs:/', 'Volumes')
    else:
        full_scheme = scheme + '://'
        remote = os.path.join(full_scheme, MY_BUCKET[full_scheme], MY_PREFIX)

    index_files = [
        os.path.join('folder_1', 'index.json'),
        os.path.join('folder_2', 'index.json'),
        os.path.join('folder_3', 'index.json'),
    ]
    remote_index_files = _format_remote_index_files(remote, index_files)

    assert len(remote_index_files) == len(index_files)
    for file, index_file in zip(remote_index_files, index_files):
        obj = urllib.parse.urlparse(file)
        assert obj.scheme == scheme
        # The listed object path must survive formatting unchanged.
        assert file.endswith(index_file)


@pytest.mark.parametrize('index_file_urls_pattern', [1, 2, 3])
@pytest.mark.parametrize('keep_local', [True, False])
@pytest.mark.parametrize('scheme', ['gs://', 's3://', 'oci://'])
Expand Down

0 comments on commit dffef31

Please sign in to comment.