From 6c86cfcd7410535c29e0d9ef222d7383c0da8ae1 Mon Sep 17 00:00:00 2001 From: Karan Jariwala Date: Wed, 28 Dec 2022 22:22:25 -0800 Subject: [PATCH] Reference the shared memory object in a worker process when using spawn multiprocessing method --- streaming/base/dataset.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index ce816b0e4..cac4c7a75 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -320,6 +320,9 @@ def _get_progress(self, world: World) -> Tuple[int, int]: Returns: Tuple[int, int]: What epoch this is, and sample offset in that epoch. """ + # Reference the same shared memory object in a worker process + self._next_epoch_arr = np.ndarray(1, buffer=self._next_epoch_shm.buf, dtype=np.int64) + # Either resume from checkpoint, or start from scratch. presumed_epoch = self.next_epoch epoch, sample_in_epoch = self._resume(world, presumed_epoch)