Skip to content
This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

Commit

Permalink
I think the callback works now
Browse files Browse the repository at this point in the history
  • Loading branch information
jonasteuwen committed Apr 18, 2024
1 parent 9b27d1a commit c5c8f49
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 31 deletions.
18 changes: 14 additions & 4 deletions ahcore/callbacks/converters/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,21 +63,28 @@ def worker(self) -> None:
self._results_queue.put(None) # Signal that this worker is done
break
filename, cache_filename = task
logger.info("Processing task: %s %s", filename, cache_filename)

result = self.process_task(filename, cache_filename)
logger.info("Task completed: %s (from worker)", result)
self._results_queue.put(result) # Store the result

@property
def dump_dir(self) -> Path:
    """Directory into which this callback's results are written (backed by ``self._dump_dir``)."""
    target_directory = self._dump_dir
    return target_directory

def collect_results(self):
    """Yield results from the results queue until every worker has signalled completion.

    Each worker puts ``None`` on the queue when it shuts down; iteration stops
    once ``self._max_concurrent_tasks`` such sentinels have been observed.
    """
    logger.info("Collecting results")
    finished_workers = 0
    while finished_workers < self._max_concurrent_tasks:
        result = self._results_queue.get()
        logger.info("Result: %s", result)
        if result is None:
            # ``None`` is the per-worker completion sentinel, not a real result.
            finished_workers += 1
            logger.info("Worker completed, total finished: %s", finished_workers)
            if finished_workers == self._max_concurrent_tasks:
                logger.info("All workers have completed.")
            continue
        yield result

def start(self, filename: str) -> None:
Expand All @@ -95,7 +102,10 @@ def start(self, filename: str) -> None:
self.schedule_task(filename=Path(filename), cache_filename=cache_filename)

def shutdown_workers(self):
    """Signal every worker to stop and wait for all of them to exit."""
    logger.info("Shutting down workers...")
    # One sentinel per worker so each worker loop receives its own shutdown signal.
    for _ in range(self._max_concurrent_tasks):
        self._task_queue.put(None)
    # Block until every worker process has finished.
    for worker_process in self._workers:
        worker_process.join()
    logger.info("Workers shut down.")

42 changes: 19 additions & 23 deletions ahcore/callbacks/converters/wsi_metric_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from collections import namedtuple
from multiprocessing import Process
from multiprocessing.pool import Pool
from typing import NamedTuple
from pathlib import Path
from typing import Any, Generator, Optional, Type, cast

Expand Down Expand Up @@ -74,7 +75,6 @@ def setup(self, callback: WriterCallback, trainer: pl.Trainer, pl_module: pl.Lig
logger.info("Data description: %s", self._data_description)
logger.info("Data dir: %s", self._data_description.data_dir)

# self._data_dir = self._data_description.data_dir

# For mypy
assert self._data_description
Expand All @@ -89,7 +89,6 @@ def setup(self, callback: WriterCallback, trainer: pl.Trainer, pl_module: pl.Lig

# Here we can query the database for the validation images
self._data_manager: DataManager = trainer.datamodule.data_manager # type: ignore
# self._validate_metadata_gen = self._create_validate_image_metadata_gen()

self._callback = callback
self._trainer = trainer
Expand All @@ -98,27 +97,23 @@ def setup(self, callback: WriterCallback, trainer: pl.Trainer, pl_module: pl.Lig

self._dump_dir = self._callback.dump_dir
self._data_dir = self._pl_module.data_description.data_dir
self._filenames_seen = []
# MAYBE CAN CALL
# Start worker processes

for _ in range(self._max_concurrent_tasks):
process = Process(target=self.worker)
process.start()
self._workers.append(process)

def compute_metrics(self, filename, pl_module: pl.LightningModule):
def process_task(self, filename: Path, cache_filename: Path):
# So we have the filename of the image, but now we need to get it's metadata

task_data = prepare_task_data(
filename,
self._dump_dir,
self._data_dir,
pl_module,
self._pl_module,
self._data_description,
self._data_manager,
)
logger.info("Computing metrics for %s", filename)
logger.info("Task data: %s", task_data)

curr_metrics = compute_metrics_for_case(
task_data=task_data,
Expand All @@ -129,17 +124,17 @@ def compute_metrics(self, filename, pl_module: pl.LightningModule):
save_per_image=self._save_per_image,
)

self._results_queue.put(curr_metrics)

# TODO: Ajey, you can put the results in the Manager.Queue from base
self._filenames_seen.append(filename)
logger.info("Metrics: %s", curr_metrics)
logger.info("Metrics putting in queue: %s (and returning from process_task)", curr_metrics)
return curr_metrics

def process_task(self, filename: Path, cache_filename: Path) -> None:
self.compute_metrics(filename, self._pl_module)

class WsiMetricTaskData(NamedTuple):
    """Immutable bundle of per-image inputs needed to compute WSI metrics for one case."""

    filename: Path  # image filename, joined with the data dir elsewhere to form the full path
    cache_filename: Path  # previously written output file, read back via the cache reader
    metadata: ImageMetadata  # image metadata fetched from the data manager's record
    mask: Optional[Any] = None  # mask taken from the annotations record, if present
    annotations: Optional[Any] = None  # annotations taken from the same record, if present

TaskData = namedtuple("TaskData", ["filename", "cache_filename", "metadata", "mask", "annotations"])


def prepare_task_data(
Expand All @@ -149,7 +144,7 @@ def prepare_task_data(
pl_module: pl.LightningModule,
data_description: DataDescription,
data_manager: DataManager,
) -> TaskData:
) -> WsiMetricTaskData:
cache_filename = get_output_filename(
dump_dir=dump_dir,
input_path=data_dir / filename,
Expand All @@ -161,11 +156,11 @@ def prepare_task_data(
metadata = fetch_image_metadata(image)
mask, annotations = get_mask_and_annotations_from_record(data_description.annotations_dir, image)

return TaskData(filename, cache_filename, metadata, mask, annotations)
return WsiMetricTaskData(filename, cache_filename, metadata, mask, annotations)


def compute_metrics_for_case(
task_data: TaskData,
task_data: WsiMetricTaskData,
image_reader: Type[FileImageReader],
class_names: dict[int, str],
data_description: DataDescription,
Expand All @@ -175,8 +170,6 @@ def compute_metrics_for_case(
# Extract the data from the namedtuple
filename, cache_filename, metadata, mask, annotations = task_data

logger.info("Computing metrics for %s", filename)

with image_reader(cache_filename, stitching_mode=StitchingMode.CROP) as cache_reader:
dataset_of_validation_image = _ValidationDataset(
data_description=data_description,
Expand All @@ -200,6 +193,7 @@ def compute_metrics_for_case(
wsi_metrics_dictionary = {
"image_fn": str(data_description.data_dir / metadata.filename),
"uuid": filename.stem,
"metrics": {},
}

if filename.with_suffix(".tiff").is_file():
Expand All @@ -208,9 +202,11 @@ def compute_metrics_for_case(
wsi_metrics_dictionary["cache_fn"] = str(filename)
for metric in wsi_metrics._metrics:
metric.get_wsi_score(str(filename))
wsi_metrics_dictionary[metric.name] = {
wsi_metrics_dictionary["metrics"][metric.name] = {
class_names[class_idx]: metric.wsis[str(filename)][class_idx][metric.name].item()
for class_idx in range(data_description.num_classes)
}

logger.info("Returning these metrics: %s", wsi_metrics_dictionary)

return wsi_metrics_dictionary
23 changes: 19 additions & 4 deletions ahcore/callbacks/writer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,19 +362,34 @@ def _epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None
self._cleanup_shutdown_event.set()
self._tile_counter = {}

def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Aggregate per-image results from converter callbacks and log epoch-level metrics.

    For each callback that produces results: shut down its workers, drain its
    results queue, sum every reported metric, average over the number of
    results, and log the reduced values via ``pl_module.log_dict``. Finally
    run the regular epoch-end cleanup.
    """
    # Local import mirrors the original; avoids widening module-level imports.
    from pprint import pformat

    logger.info("Ending epoch...")
    # TODO: There must be a mechanism that this can be done during the batch
    # processing? Or should there be a signal?
    for callback in self._callbacks:
        if not callback.has_returns:
            continue

        callback.shutdown_workers()

        output_metrics: dict = {}
        result_count = 0
        for result in callback.collect_results():
            if result is None:
                break
            logger.info("Results: %s", pformat(result))
            metrics = result["metrics"]
            for key, value in metrics.items():
                # Accumulate so we can average over all results afterwards.
                output_metrics[key] = output_metrics.get(key, 0.0) + value
            result_count += 1

        if result_count == 0:
            # No results came back for this callback; avoid a division by zero
            # (the previous version raised NameError / divided by the last
            # enumerate index, which was also off by one).
            continue

        reduced_metrics = {k: v / result_count for k, v in output_metrics.items()}
        pl_module.log_dict(reduced_metrics, on_step=False, on_epoch=True, prog_bar=True, logger=True)

    # Previously a second, duplicate ``on_validation_epoch_end`` definition
    # shadowed the aggregation above; merged so both aggregation and the
    # regular epoch-end cleanup run.
    self._epoch_end(trainer, pl_module)
Expand Down

0 comments on commit c5c8f49

Please sign in to comment.