This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

OOM error during inference. #19

Closed
AjeyPaiK opened this issue Nov 18, 2023 · 12 comments

@AjeyPaiK
Member

AjeyPaiK commented Nov 18, 2023

Describe the bug
I am using the WriteH5Callback at inference time. I tracked the RSS memory used during inference; below is what I found.

[screenshot: RSS memory usage tracked during inference]

After writing each H5 file corresponding to one WSI from the inference dataset:

  1. The RSS memory held by the callback increases consistently.
  2. In this example, I deliberately requested 400 GB of memory and, sure enough, the job throws an OOM error once this limit is crossed (check below).

[screenshots: job logs showing the OOM error once the 400 GB limit is crossed]

To Reproduce

Run the following after configuring this version of ahcore

python /ahcore/tools/inference.py callbacks=default data_description=tissue_subtypes/segmentation_inference datamodule=dataset datamodule.num_workers=16 datamodule.batch_size=8 pre_transform=segmentation augmentations=segmentation lit_module=monai_segmentation/attention_unet

Expected behavior
The RSS memory shouldn't keep increasing with every new image while performing inference.

Environment
dlup version: 0.3.32
Python version: 3.10
Operating System: Linux

Additional Context

I am trying to run inference on a large batch of WSIs from a clinical dataset (n=1072) with my trained models. That's when I encountered this problem.

@AjeyPaiK AjeyPaiK added the bug Something isn't working label Nov 18, 2023
@AjeyPaiK AjeyPaiK self-assigned this Nov 18, 2023
@AjeyPaiK
Member Author

AjeyPaiK commented Nov 21, 2023

I investigated this a little bit more and I am fairly certain that we don't have a bug in our callback. But let me know what you think @jonasteuwen. I did the following things:

  1. Kept track of the reference counts to H5FileImageWriter and ensured they are explicitly dereferenced after each WSI output has been stored to disk, by adding the following line within the process management:

    self._writers[self._current_filename] = {}

  2. Explicitly deleted variables or references to classes which may be leaking memory. For example, I deleted self._data at the end of the consume function in H5FileImageWriter. This was immediately followed by a gc.collect() step.

  3. Line-by-line profiling of the callback to analyse which lines allocate memory. Check the snippet below:

Filename: /home/a.karkala/ahcore/ahcore/callbacks.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   367   4262.4 MiB   4262.4 MiB           1       @profile
   368                                             def _batch_end(
   369                                                     self,
   370                                                     trainer: pl.Trainer,
   371                                                     pl_module: pl.LightningModule,
   372                                                     outputs: Any,
   373                                                     batch: Any,
   374                                                     batch_idx: int,
   375                                                     stage: str,
   376                                                     dataloader_idx: int = 0,
   377                                             ) -> None:
   378   4262.4 MiB      0.0 MiB           1           filename = batch["path"][0]  # Filenames are constant across the batch.
   379   4262.4 MiB      0.0 MiB          11           if any([filename != path for path in batch["path"]]):
   380                                                     raise ValueError(
   381                                                         "All paths in a batch must be the same. "
   382                                                         "Either use batch_size=1 or ahcore.data.samplers.WsiBatchSampler."
   383                                                     )
   384                                         
   385   4262.4 MiB      0.0 MiB           1           if filename != self._current_filename:
   386                                                     output_filename = _get_h5_output_filename(
   387                                                         self.dump_dir,
   388                                                         filename,
   389                                                         model_name=str(pl_module.name),
   390                                                         step=pl_module.global_step,
   391                                                     )
   392                                                     output_filename.parent.mkdir(parents=True, exist_ok=True)
   393                                                     link_fn = (
   394                                                             self.dump_dir / "outputs" / f"{pl_module.name}" / f"step_{pl_module.global_step}" / "image_h5_link.txt"
   395                                                     )
   396                                                     with open(link_fn, "a" if link_fn.is_file() else "w") as file:
   397                                                         file.write(f"{filename},{output_filename}\n")
   398                                         
   399                                                     self._logger.debug("%s -> %s", filename, output_filename)
   400                                                     if self._current_filename is not None:
   401                                                         self.__process_management()
   402                                                         self._semaphore.release()
   403                                         
   404                                                     self._semaphore.acquire()
   405                                         
   406                                                     if stage == "validate":
   407                                                         total_dataset: ConcatDataset = trainer.datamodule.validate_dataset  # type: ignore
   408                                                     elif stage == "predict":
   409                                                         total_dataset: ConcatDataset = trainer.predict_dataloaders.dataset  # type: ignore
   410                                                     else:
   411                                                         raise NotImplementedError(f"Stage {stage} is not supported for {self.__class__.__name__}.")
   412                                         
   413                                                     current_dataset: TiledWsiDataset
   414                                                     current_dataset, _ = total_dataset.index_to_dataset(self._dataset_index)  # type: ignore
   415                                                     slide_image = current_dataset.slide_image
   416                                         
   417                                                     data_description: DataDescription = pl_module.data_description  # type: ignore
   418                                                     inference_grid: GridDescription = data_description.inference_grid
   419                                         
   420                                                     mpp = inference_grid.mpp
   421                                                     if mpp is None:
   422                                                         mpp = slide_image.mpp
   423                                         
   424                                                     scaling = slide_image.get_scaling(mpp)
   425                                                     slide_image.close()
   426                                                     # Below, we set the flag use_limit_bounds to account for the slide bounds used during dataset creation.
   427                                                     size = slide_image.get_scaled_size(scaling, use_limit_bounds=self._limit_bounds)
   428                                                     num_samples = len(current_dataset)
   429                                         
   430                                                     # Let's get the data_description, so we can figure out the tile size and things like that
   431                                                     tile_size = inference_grid.tile_size
   432                                                     tile_overlap = inference_grid.tile_overlap
   433                                         
   434                                                     # TODO: We are really putting strange things in the Queue if we may believe mypy
   435                                                     new_queue: Queue[Any] = Queue()  # pylint: disable=unsubscriptable-object
   436                                                     parent_conn, child_conn = Pipe()
   437                                                     new_writer = H5FileImageWriter(
   438                                                         output_filename,
   439                                                         size=size,
   440                                                         mpp=mpp,
   441                                                         tile_size=tile_size,
   442                                                         tile_overlap=tile_overlap,
   443                                                         num_samples=num_samples,
   444                                                         progress=None,
   445                                                     )
   446                                                     self._logger.info(f"Number of references to H5Writer: {len(gc.get_referrers(H5FileImageWriter))}")
   447                                                     new_process = Process(target=new_writer.consume, args=(self.generator(new_queue), child_conn))
   448                                                     new_process.start()
   449                                                     self._writers[filename] = {
   450                                                         "queue": new_queue,
   451                                                         "writer": new_writer,
   452                                                         "process": new_process,
   453                                                         "connection": parent_conn,
   454                                                     }
   455                                                     self._current_filename = filename
   456                                         
   457   4294.3 MiB     31.9 MiB           1           prediction = outputs["prediction"].detach().cpu().numpy()
   458   4294.3 MiB      0.0 MiB           1           coordinates_x, coordinates_y = batch["coordinates"]
   459   4294.3 MiB      0.0 MiB           1           coordinates = torch.stack([coordinates_x, coordinates_y]).T.detach().cpu().numpy()
   460   4294.3 MiB      0.0 MiB           1           self._writers[filename]["queue"].put((coordinates, prediction))
   461   4294.3 MiB      0.0 MiB           1           self._dataset_index += prediction.shape[0]


Filename: /home/a.karkala/ahcore/ahcore/callbacks.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   367   4294.4 MiB   4294.4 MiB           1       @profile
   368                                             def _batch_end(
   369                                                     self,
   370                                                     trainer: pl.Trainer,
   371                                                     pl_module: pl.LightningModule,
   372                                                     outputs: Any,
   373                                                     batch: Any,
   374                                                     batch_idx: int,
   375                                                     stage: str,
   376                                                     dataloader_idx: int = 0,
   377                                             ) -> None:
   378   4294.4 MiB      0.0 MiB           1           filename = batch["path"][0]  # Filenames are constant across the batch.
   379   4294.4 MiB      0.0 MiB          11           if any([filename != path for path in batch["path"]]):
   380                                                     raise ValueError(
   381                                                         "All paths in a batch must be the same. "
   382                                                         "Either use batch_size=1 or ahcore.data.samplers.WsiBatchSampler."
   383                                                     )
   384                                         
   385   4294.4 MiB      0.0 MiB           1           if filename != self._current_filename:
   386                                                     output_filename = _get_h5_output_filename(
   387                                                         self.dump_dir,
   388                                                         filename,
   389                                                         model_name=str(pl_module.name),
   390                                                         step=pl_module.global_step,
   391                                                     )
   392                                                     output_filename.parent.mkdir(parents=True, exist_ok=True)
   393                                                     link_fn = (
   394                                                             self.dump_dir / "outputs" / f"{pl_module.name}" / f"step_{pl_module.global_step}" / "image_h5_link.txt"
   395                                                     )
   396                                                     with open(link_fn, "a" if link_fn.is_file() else "w") as file:
   397                                                         file.write(f"{filename},{output_filename}\n")
   398                                         
   399                                                     self._logger.debug("%s -> %s", filename, output_filename)
   400                                                     if self._current_filename is not None:
   401                                                         self.__process_management()
   402                                                         self._semaphore.release()
   403                                         
   404                                                     self._semaphore.acquire()
   405                                         
   406                                                     if stage == "validate":
   407                                                         total_dataset: ConcatDataset = trainer.datamodule.validate_dataset  # type: ignore
   408                                                     elif stage == "predict":
   409                                                         total_dataset: ConcatDataset = trainer.predict_dataloaders.dataset  # type: ignore
   410                                                     else:
   411                                                         raise NotImplementedError(f"Stage {stage} is not supported for {self.__class__.__name__}.")
   412                                         
   413                                                     current_dataset: TiledWsiDataset
   414                                                     current_dataset, _ = total_dataset.index_to_dataset(self._dataset_index)  # type: ignore
   415                                                     slide_image = current_dataset.slide_image
   416                                         
   417                                                     data_description: DataDescription = pl_module.data_description  # type: ignore
   418                                                     inference_grid: GridDescription = data_description.inference_grid
   419                                         
   420                                                     mpp = inference_grid.mpp
   421                                                     if mpp is None:
   422                                                         mpp = slide_image.mpp
   423                                         
   424                                                     scaling = slide_image.get_scaling(mpp)
   425                                                     slide_image.close()
   426                                                     # Below, we set the flag use_limit_bounds to account for the slide bounds used during dataset creation.
   427                                                     size = slide_image.get_scaled_size(scaling, use_limit_bounds=self._limit_bounds)
   428                                                     num_samples = len(current_dataset)
   429                                         
   430                                                     # Let's get the data_description, so we can figure out the tile size and things like that
   431                                                     tile_size = inference_grid.tile_size
   432                                                     tile_overlap = inference_grid.tile_overlap
   433                                         
   434                                                     # TODO: We are really putting strange things in the Queue if we may believe mypy
   435                                                     new_queue: Queue[Any] = Queue()  # pylint: disable=unsubscriptable-object
   436                                                     parent_conn, child_conn = Pipe()
   437                                                     new_writer = H5FileImageWriter(
   438                                                         output_filename,
   439                                                         size=size,
   440                                                         mpp=mpp,
   441                                                         tile_size=tile_size,
   442                                                         tile_overlap=tile_overlap,
   443                                                         num_samples=num_samples,
   444                                                         progress=None,
   445                                                     )
   446                                                     self._logger.info(f"Number of references to H5Writer: {len(gc.get_referrers(H5FileImageWriter))}")
   447                                                     new_process = Process(target=new_writer.consume, args=(self.generator(new_queue), child_conn))
   448                                                     new_process.start()
   449                                                     self._writers[filename] = {
   450                                                         "queue": new_queue,
   451                                                         "writer": new_writer,
   452                                                         "process": new_process,
   453                                                         "connection": parent_conn,
   454                                                     }
   455                                                     self._current_filename = filename
   456                                         
   457   4326.3 MiB     31.9 MiB           1           prediction = outputs["prediction"].detach().cpu().numpy()
   458   4326.3 MiB      0.0 MiB           1           coordinates_x, coordinates_y = batch["coordinates"]
   459   4326.3 MiB      0.0 MiB           1           coordinates = torch.stack([coordinates_x, coordinates_y]).T.detach().cpu().numpy()
   460   4326.3 MiB      0.0 MiB           1           self._writers[filename]["queue"].put((coordinates, prediction))
   461   4326.3 MiB      0.0 MiB           1           self._dataset_index += prediction.shape[0]
  4. I changed the threshold for garbage collection. By default, the threshold after which a garbage collection is triggered for the three generations is (700, 10, 10). I changed this to (1, 1, 1), only for the duration of the callback, but it didn't help. A short sketch of attempts 2 and 4 follows below.
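
As a rough illustration only, this is what attempts 2 and 4 amount to; the helper names are mine and not part of ahcore:

    import gc

    def cleanup_after_consume(writer) -> None:
        # Attempt 2 (sketch): explicitly drop the writer's buffered data and force a collection.
        del writer._data
        gc.collect()

    def run_with_aggressive_gc(fn) -> None:
        # Attempt 4 (sketch): lower the gc thresholds only while fn runs, then restore them.
        old_thresholds = gc.get_threshold()  # the default is (700, 10, 10)
        gc.set_threshold(1, 1, 1)
        try:
            fn()
        finally:
            gc.set_threshold(*old_thresholds)

Neither the explicit deletion nor the aggressive thresholds changed the memory growth, which suggests the growth was not coming from collectable Python objects.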

The issue we have looks the same as what is reported here. They identified a problem with the hdf5 C library against which h5py is compiled. Unfortunately, the issue they pointed out is still not fixed in upstream hdf5; check here. One of the contributors has taken this up and marked the issue for the next release; check here.

For now, maybe we should compile h5py from source against a non-leaking hdf5 library.
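
Assuming a standard h5py install, the linked HDF5 version can be checked quickly with a snippet like this (not part of ahcore):

    import h5py

    print("h5py:", h5py.__version__)           # Python wrapper version
    print("hdf5:", h5py.version.hdf5_version)  # C library h5py was compiled against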

@moerlemans
Contributor

[screenshot: package versions of a working environment (hdf5 1.10.6)]
This version does work for me.

@AjeyPaiK
Member Author

AjeyPaiK commented Nov 21, 2023

Thanks @moerlemans! So, your working environment uses hdf5 1.10.6. Let me try with these versions.

@AjeyPaiK
Member Author

AjeyPaiK commented Nov 22, 2023

I did some more checks and I am capturing them here. It now looks like it wasn't a problem with h5py or the underlying library either (at most, it may have added to the trouble): I installed older versions of h5py with older versions of the hdf5 library and the problem still persisted.

Specifically, I tried:
h5py 3.7.0 compiled against hdf5 1.10.6 (same setup as @moerlemans above).
h5py 3.6.0 compiled against hdf5 1.10.6 (just for the heck of it, to really rule this out).

Then, I started looking elsewhere. I investigated how much memory the child processes end up taking while doing the h5 writing. Following are some screenshots I made using the top -u <username> command.

[screenshot: output from top while writing the H5 for the first WSI]
[screenshot: output from top while writing the H5 for the second WSI]

When the prediction begins for the first image, the process with id 3328375 is the parent and 3336839 is the child process. After the write is done, the child process exits properly. During this time, the resident set size and the virtual memory of the parent process have increased significantly. The next child process, with id 3387563, inherits from the parent (multiprocessing uses forking by default, so the memory isn't copied but shared, and it is copied only when the child tries to modify it).
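
Instead of eyeballing top, the same numbers could be logged programmatically; here is a hypothetical helper using psutil (not part of ahcore):

    import time
    import psutil

    def log_rss(interval: float = 5.0) -> None:
        parent = psutil.Process()  # the inference process
        while True:
            print(f"parent pid={parent.pid} rss={parent.memory_info().rss / 2**30:.2f} GiB")
            for child in parent.children(recursive=True):  # e.g. the H5 writer processes
                print(f"  child pid={child.pid} rss={child.memory_info().rss / 2**30:.2f} GiB")
            time.sleep(interval)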

These observations rule out any problem with the multiprocessing we currently have in place; in fact, the problem doesn't seem to be in the H5 writing or the callback at all. So, I disabled all the callbacks and simply ran a prediction loop. Much to my surprise, it crashed very quickly. It turns out there is an open issue on the pytorch lightning 2.0 GitHub which is similar (but not exactly the same).

@AjeyPaiK AjeyPaiK changed the title OOM error while using the WriteH5Callback OOM error during inference. Nov 22, 2023
@AjeyPaiK
Member Author

AjeyPaiK commented Nov 22, 2023

I downgraded to pytorch_lightning 1.9.1 while retaining the latest pytorch (2.1.1). The issue doesn't seem to go away.

@AjeyPaiK
Member Author

I downgraded pytorch to 1.12.1 and the issue didn't go away.

@AjeyPaiK
Member Author

[plot: memory profile of ahcore during inference, from start until the crash]

Here is a memory profile for ahcore during inference from start to the point when the program crashes.

@AjeyPaiK
Member Author

AjeyPaiK commented Nov 23, 2023

From the looks of it, it seems like the predictions are not being deallocated after the prediction step. So, to test this hypothesis, I returned None from predict_step() in the lit_module (a sketch follows below). The memory profile now looks much better.

Please note that, in this run, my batch size was 256 tiles and all the callbacks were switched off.

[plot: memory profile with predict_step() returning None]
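
A minimal sketch of this test (hedged; the helper names mirror the snippet in the next comment and may not match the lit_module exactly):

    from typing import Any

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        if self._augmentations and "predict" in self._augmentations:
            batch = self._augmentations["predict"](batch)
        # The forward pass still runs, but nothing is handed back to Lightning,
        # so nothing can be accumulated across the predict epoch.
        batch = {**batch, **self._get_inference_prediction(batch["image"])}
        return None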

@AjeyPaiK
Member Author

AjeyPaiK commented Nov 23, 2023

@EricMarcus-ai and I looked at the issue today. We disabled virtually everything that could cause memory leaks. Concretely, we did the following:

Commented out all the lines within predict_step() and ran the script. Even though the model isn't being used at this point, memory usage was through the roof and it eventually crashed.

Just as a sanity check, we also downgraded the pytorch lightning version and repeated the run. It still broke.

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        if self._augmentations and "predict" in self._augmentations:
            batch = self._augmentations["predict"](batch)

        #_relevant_dict = {k: v for k, v in batch.items() if k in self.RELEVANT_KEYS}
        #batch = {**batch, **self._get_inference_prediction(batch["image"])}
        #pred_debug = self._get_inference_prediction(batch["image"])
        #_prediction = batch["prediction"]
        #output = {"prediction": _prediction, **_relevant_dict}

        # This is a sanity check. We expect the filenames to be constant across the batch.
        filename = batch["path"][0]
        if any([filename != f for f in batch["path"]]):
            raise ValueError("Filenames are not constant across the batch.")
        return None

@AjeyPaiK
Member Author

Today, I investigated the role of the return_predictions flag in trainer.predict(). It looks like when this is set to True, pytorch lightning aggregates the results from all batches throughout the epoch, hence holding them all in main memory. Below, we see this clearly (note: callbacks are completely switched off).

[plots: memory profiles illustrating the effect of the return_predictions flag]
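
Based on this, a hedged sketch of the corresponding call in the inference entry point would look like the following (trainer, lit_module and datamodule stand in for the configured objects):

    # Disable prediction aggregation: Lightning then no longer keeps every batch's
    # output in main memory for the whole epoch, while callbacks such as
    # WriteH5Callback still receive the per-batch outputs.
    predictions = trainer.predict(lit_module, datamodule=datamodule, return_predictions=False)
    assert predictions is None  # nothing is returned when aggregation is off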

@jonasteuwen
Contributor

@AjeyPaiK can this be closed?

@AjeyPaiK
Member Author

AjeyPaiK commented Dec 8, 2023

Yes. I made the necessary changes in this PR.

@AjeyPaiK AjeyPaiK closed this as completed Dec 8, 2023
@AjeyPaiK AjeyPaiK reopened this Dec 8, 2023
AjeyPaiK added a commit that referenced this issue Dec 8, 2023