From 36eb383250852a26c75574b412be322f9fbed829 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Mon, 3 Oct 2022 19:31:06 -0700 Subject: [PATCH] Fix hang. --- streaming/base/dataset.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index e9e5603ec..0fb4717a7 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -399,20 +399,26 @@ def _download_shard_part(self, # Is compression used? if zip_info: - # Download the compressed form if missing (or wait on download). + # Download the compressed form if missing (or wait on its download). zip_filename = os.path.join(self.local, self.split, zip_info.basename) if not os.path.isfile(zip_filename): - self._download_file(zip_info.basename, wait) + # Waiter or doer? + if wait: + # If waiter: wait for *raw* version to exist (as the zip may be ephemeral). + wait_for_download(raw_filename, self.timeout) + else: + # If doer: download the zip version. + self._download_file(zip_info.basename) # Validate and decompress (or wait on that). self._decompress_shard_part(zip_info, zip_filename, raw_filename, compression, wait) else: - # Download the raw version. + # Download the raw form (or wait on its download). self._download_file(raw_info.basename, wait) - # Doer or waiter? + # Waiter or doer? if not wait: - # If doer: load raw, validate. + # If doer: validate if requested. if self.hash: data = open(raw_filename, 'rb').read() assert get_hash(self.hash, data) == raw_info.hashes[self.hash]