Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RedPajama-V2 dataset to TFDS. #5594

Open
This pull request wants to merge 1 commit into the base branch `master` (choose a different base branch from the branch selector if needed).
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tensorflow_datasets/core/dataset_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1315,7 +1315,7 @@ def _make_download_manager(
)

return download.DownloadManager(
download_dir=download_dir,
download_dir=download_dir / self.name,
extract_dir=extract_dir,
manual_dir=manual_dir,
url_infos=self.url_infos,
Expand Down
264 changes: 131 additions & 133 deletions tensorflow_datasets/core/download/download_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,9 +316,6 @@ def downloaded_size(self):
"""Returns the total size of downloaded files."""
return sum(url_info.size for url_info in self._recorded_url_infos.values())

def _get_dl_path(self, url: str, sha256: str) -> epath.Path:
  """Returns the location under the download dir where `url` is stored."""
  filename = resource_lib.get_dl_fname(url, sha256)
  return self._download_dir / filename

@property
def register_checksums(self):
"""Returns whether checksums are being computed and recorded to file."""
Expand All @@ -341,7 +338,7 @@ def _download(self, resource: Url) -> promise.Promise[epath.Path]:

This function:

1. Reuse cache (`_get_cached_path`) or download the file
1. Reuse cache (`downloader.get_cached_path`) or download the file
2. Register or validate checksums (`_register_or_validate_checksums`)
3. Rename download to final path (`_rename_and_get_final_dl_path`)

Expand All @@ -352,37 +349,39 @@ def _download(self, resource: Url) -> promise.Promise[epath.Path]:
path: The path to the downloaded resource.
"""
# Normalize the input
if isinstance(resource, str):
url = resource
else:
url = resource.url
if not isinstance(resource, resource_lib.Resource):
resource = resource_lib.Resource(url=resource)
url = resource.url
assert url is not None, 'URL is undefined from resource.'

expected_url_info = self._url_infos.get(url)
registered_url_info = self._url_infos.get(url)

# 3 possible destinations for the path:
# * In `manual_dir` (manually downloaded data)
# * In `downloads/url_path` (checksum unknown)
# * In `downloads/checksum_path` (checksum registered)
# * In `downloads/unregistered_path` (checksum unknown)
# * In `downloads/registered_path` (checksum registered)
manually_downloaded_path = _get_manually_downloaded_path(
manual_dir=self._manual_dir,
expected_url_info=expected_url_info,
url_info=registered_url_info,
)
url_path = self._get_dl_path(
url, sha256=hashlib.sha256(url.encode('utf-8')).hexdigest()
)
checksum_path = (
self._get_dl_path(url, sha256=expected_url_info.checksum)
if expected_url_info
else None
download_dir = self._download_dir / resource.relative_download_dir
download_dir.mkdir(parents=True, exist_ok=True)
unregistered_path = download_dir / resource_lib.get_dl_fname(
url=url, checksum=hashlib.sha256(url.encode('utf-8')).hexdigest()
)
if registered_url_info:
registered_path = download_dir / resource_lib.get_dl_fname(
url=url, checksum=registered_url_info.checksum
)
else:
registered_path = None

# Get the cached path and url_info (if they exists)
dl_result = downloader.get_cached_path(
manually_downloaded_path=manually_downloaded_path,
checksum_path=checksum_path,
url_path=url_path,
expected_url_info=expected_url_info,
registered_path=registered_path,
unregistered_path=unregistered_path,
registered_url_info=registered_url_info,
)
if dl_result.path and not self._force_download: # Download was cached
logging.info(
Expand All @@ -394,130 +393,166 @@ def _download(self, resource: Url) -> promise.Promise[epath.Path]:
else:
# Download in an empty tmp directory (to avoid name collisions)
# `download_tmp_dir` is cleaned-up in `_rename_and_get_final_dl_path`
dirname = f'{resource_lib.get_dl_dirname(url)}.tmp.{uuid.uuid4().hex}'
download_tmp_dir = self._download_dir / dirname
download_tmp_dir = (
unregistered_path.parent
/ f'{unregistered_path.name}.tmp.{uuid.uuid4().hex}'
)
download_tmp_dir.mkdir()
logging.info(f'Downloading {url} into {download_tmp_dir}...')
future = self._downloader.download(
url, download_tmp_dir, verify=self._verify_ssl
)

# Post-process the result
return future.then(
lambda dl_result: self._register_or_validate_checksums( # pylint: disable=g-long-lambda
url=url,
path=dl_result.path,
computed_url_info=dl_result.url_info,
expected_url_info=expected_url_info,
checksum_path=checksum_path,
url_path=url_path,
)
)
def callback(dl_result: downloader.DownloadResult) -> epath.Path:
return self._register_or_validate_checksums(
url=url,
dl_url_info=dl_result.url_info,
registered_url_info=registered_url_info,
dl_path=dl_result.path,
registered_path=registered_path,
unregistered_path=unregistered_path,
)

return future.then(callback)

def _register_or_validate_checksums(
    self,
    url: str,
    dl_url_info: checksums.UrlInfo | None,
    registered_url_info: checksums.UrlInfo | None,
    dl_path: epath.Path,
    registered_path: epath.Path | None,
    unregistered_path: epath.Path,
) -> epath.Path:
  """Validates/records checksums and renames final downloaded path.

  Args:
    url: Url of the downloaded resource.
    dl_url_info: Checksums computed from the downloaded file (None for
      manually downloaded files, whose checksum wasn't computed).
    registered_url_info: Checksums previously registered for `url`, or None
      if unknown.
    dl_path: Current location of the downloaded file.
    registered_path: Final destination when checksums are registered, or
      None when they aren't.
    unregistered_path: Final destination when checksums are unknown.

  Returns:
    The final path of the downloaded resource.

  Raises:
    ValueError: If checksum registration was requested but no checksum could
      be computed (manually downloaded data).
  """
  if dl_url_info:
    # Used both in `.downloaded_size` and `_record_url_infos()`
    self._recorded_url_infos[url] = dl_url_info

  if self._register_checksums:
    if not dl_url_info:
      raise ValueError(
          f'Cannot register checksums for {url}: no computed checksum. '
          '--register_checksums with manually downloaded data not supported.'
      )
    # Note:
    # * We save even if `registered_url_info == dl_url_info` as
    #   `registered_url_info` might have been loaded from another dataset.
    # * `register_checksums_path` was validated in `__init__` so this
    #   shouldn't fail.
    self._record_url_infos()

    # Checksum path should now match the new registered checksum (even if
    # checksums were previously registered)
    registered_url_info = dl_url_info
    registered_path = unregistered_path.parent / resource_lib.get_dl_fname(
        url, dl_url_info.checksum
    )
  else:
    # Eventually validate checksums
    # Note:
    # * If path is cached at `unregistered_path` but
    #   `dl_url_info != registered_url_info`, a new download has
    #   been triggered (as `downloader.get_cached_path` returns None)
    # * If path was downloaded but checksums don't match expected, then
    #   the download isn't cached (re-running build will retrigger a new
    #   download). This is expected as it might mean the downloaded file
    #   was corrupted. Note: The tmp file isn't deleted to allow inspection.
    self._validate_checksums(
        url=url,
        dl_url_info=dl_url_info,
        registered_url_info=registered_url_info,
        dl_path=dl_path,
    )

  return self._rename_and_get_final_dl_path(
      url=url,
      dl_url_info=dl_url_info,
      registered_url_info=registered_url_info,
      dl_path=dl_path,
      registered_path=registered_path,
      unregistered_path=unregistered_path,
  )

def _validate_checksums(
    self,
    url: str,
    dl_url_info: checksums.UrlInfo | None,
    registered_url_info: checksums.UrlInfo | None,
    dl_path: epath.Path,
) -> None:
  """Checks the downloaded checksums against the registered ones.

  Args:
    url: Url of the downloaded resource.
    dl_url_info: Checksums computed from the download (None for manually
      downloaded files).
    registered_url_info: Checksums registered for `url`, or None if unknown.
    dl_path: Location of the downloaded file (used for error reporting and
      to compute missing checksums when validation is forced).

  Raises:
    ValueError: If validation is forced but no checksums are registered.
    NonMatchingChecksumError: If downloaded and registered checksums differ.
  """
  if self._force_checksums_validation:
    # Manually downloaded files have no checksum computed at download time,
    # so compute one now from the file itself.
    dl_url_info = dl_url_info or checksums.compute_url_info(dl_path)
    if not registered_url_info:
      raise ValueError(
          f'Missing checksums url: {url}, yet '
          '`force_checksums_validation=True`. '
          'Did you forget to register checksums?'
      )

  # Only compare when both sides are known; matching checksums are fine.
  if not (registered_url_info and dl_url_info):
    return
  if registered_url_info == dl_url_info:
    return
  raise NonMatchingChecksumError(
      f'Artifact {url}, downloaded to {dl_path}, has wrong checksum:\n'
      f'* Expected: {registered_url_info}\n'
      f'* Got: {dl_url_info}\n'
      'To debug, see: '
      'https://www.tensorflow.org/datasets/overview#fixing_nonmatchingchecksumerror'
  )

def _rename_and_get_final_dl_path(
    self,
    url: str,
    dl_url_info: checksums.UrlInfo | None,
    registered_url_info: checksums.UrlInfo | None,
    dl_path: epath.Path,
    registered_path: epath.Path | None,
    unregistered_path: epath.Path,
) -> epath.Path:
  """Eventually rename the downloaded file if checksums were recorded.

  Args:
    url: Url of the downloaded resource.
    dl_url_info: Checksums computed from the download (None for manually
      downloaded files).
    registered_url_info: Checksums registered for `url`, or None if unknown.
    dl_path: Current location of the file (manual dir, cached destination,
      or a tmp download dir).
    registered_path: Final destination when checksums are registered.
    unregistered_path: Final destination when checksums are unknown.

  Returns:
    The final path of the downloaded resource.
  """
  # Manually downloaded data
  if self._manual_dir and dl_path.is_relative_to(self._manual_dir):
    return dl_path

  # Cached at the final destination
  elif dl_path == registered_path:
    assert dl_url_info == registered_url_info  # Sanity check
    return dl_path

  # Cached at the tmp destination
  elif dl_path == unregistered_path:
    if registered_path:
      # Checksums were registered: Rename -> checksums_path
      resource_lib.replace_info_file(dl_path, registered_path)
      return dl_path.replace(registered_path)
    else:
      # Checksums not registered: -> do nothing
      return dl_path

  # Downloaded at the tmp destination
  else:
    path = registered_path or unregistered_path
    resource_lib.write_info_file(
        url=url,
        path=path,
        dataset_name=self._dataset_name,
        original_fname=dl_path.name,
        url_info=dl_url_info,
    )
    dl_path.replace(path)
    dl_path.parent.rmdir()  # Cleanup tmp dir (will fail if dir not empty)
    return path

@utils.build_synchronize_decorator()
@utils.memoize()
Expand Down Expand Up @@ -711,59 +746,22 @@ def manual_dir(self) -> epath.Path:

def _get_manually_downloaded_path(
    manual_dir: epath.Path | None,
    url_info: checksums.UrlInfo | None,
) -> epath.Path | None:
  """Checks if file is already downloaded in manual_dir.

  Args:
    manual_dir: Directory with manually downloaded files, or None if no
      manual dir was passed.
    url_info: Registered checksum info; its `filename` attribute is used to
      locate the file inside `manual_dir`.

  Returns:
    The path of the manually downloaded file, or None if it cannot be found
    (no manual dir, unknown filename, or file absent).
  """
  if not manual_dir:  # Manual dir not passed
    return None

  if not url_info or not url_info.filename:
    return None  # Filename unknown.

  manual_path = manual_dir / url_info.filename
  if not manual_path.exists():  # File not manually downloaded
    return None

  return manual_path


def _validate_checksums(
    url: str,
    path: epath.Path,
    computed_url_info: checksums.UrlInfo | None,
    expected_url_info: checksums.UrlInfo | None,
    force_checksums_validation: bool,
) -> None:
  """Checks that the computed checksums of `path` match the expected ones.

  Args:
    url: Url the file was downloaded from.
    path: Location of the downloaded file (used for error reporting and to
      compute missing checksums when validation is forced).
    computed_url_info: Checksums computed from the download (None for
      manually downloaded files).
    expected_url_info: Checksums registered for `url`, or None if unknown.
    force_checksums_validation: Whether missing checksums are an error.

  Raises:
    ValueError: If validation is forced but no checksums were registered.
    NonMatchingChecksumError: If computed and expected checksums differ.
  """
  if force_checksums_validation:
    # Manually downloaded files have no checksum computed at download time,
    # so compute one now from the file itself.
    computed_url_info = computed_url_info or checksums.compute_url_info(path)
    if not expected_url_info:
      raise ValueError(
          f'Missing checksums url: {url}, yet '
          '`force_checksums_validation=True`. '
          'Did you forget to register checksums?'
      )

  # Only compare when both sides are known; matching checksums are fine.
  if not (expected_url_info and computed_url_info):
    return
  if expected_url_info == computed_url_info:
    return
  raise NonMatchingChecksumError(
      f'Artifact {url}, downloaded to {path}, has wrong checksum:\n'
      f'* Expected: {expected_url_info}\n'
      f'* Got: {computed_url_info}\n'
      'To debug, see: '
      'https://www.tensorflow.org/datasets/overview#fixing_nonmatchingchecksumerror'
  )


def _map_promise(map_fn, all_inputs):
"""Map the function into each element and resolve the promise."""
all_promises = tree.map_structure(map_fn, all_inputs) # Apply the function
Expand Down
Loading
Loading