Skip to content

Commit

Permalink
Alter large 2d flow to contain TaskIO interface
Browse files Browse the repository at this point in the history
  • Loading branch information
annshress committed Feb 5, 2024
1 parent 1de6f97 commit babc445
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 57 deletions.
115 changes: 61 additions & 54 deletions em_workflows/lrg_2d_rgb/flow.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Dict, Optional
from typing import Optional
from pathlib import Path

import SimpleITK as sitk
from pytools import HedwigZarrImage, HedwigZarrImages
from prefect import flow, task
from prefect import flow, task, unmapped

from em_workflows.utils import utils
from em_workflows.utils import neuroglancer as ng
Expand All @@ -18,14 +18,18 @@
JPEG_QUAL,
VALID_LRG_2D_RGB_INPUTS,
)
from em_workflows.utils import task_io
from em_workflows.utils.task_io import taskio_handler, TaskIO, gen_taskio


@task
def convert_png_to_tiff(file_path: FilePath) -> FilePath:
@task(name="convert_png_to_tiff")
@taskio_handler
def convert_png_to_tiff(taskio: TaskIO) -> TaskIO:
"""
convert input.png -background white -alpha remove -alpha off ouput.tiff
Adding argument: -define tiff:tile-geometry=128x128
"""
file_path = taskio.file_path
input_png = file_path.fp_in.as_posix()
output_tiff = f"{file_path.working_dir}/{file_path.base}.tiff"
log_fp = f"{file_path.working_dir}/{file_path.base}_as_tiff.log"
Expand All @@ -44,49 +48,57 @@ def convert_png_to_tiff(file_path: FilePath) -> FilePath:
]
utils.log(f"Generated cmd {cmd}")
FilePath.run(cmd, log_fp)
return file_path
return TaskIO(output_path=Path(output_tiff))


@task(
name="Zarr generation",
on_failure=[utils.collect_exception_task_hook],
)
def gen_zarr(file_path: FilePath) -> None:
input_tiff = f"{file_path.working_dir}/{file_path.base}.tiff"
@taskio_handler
def gen_zarr(taskio: TaskIO) -> TaskIO:
file_path = taskio.file_path
input_tiff = taskio.output_path

ng.bioformats_gen_zarr(
output_path = ng.bioformats_gen_zarr(
file_path=file_path,
input_fname=input_tiff,
input_fname=input_tiff.as_posix(),
)
return file_path
return TaskIO(output_path=Path(output_path))


@task(
name="Zarr rechunk",
on_failure=[utils.collect_exception_task_hook],
)
def rechunk_zarr(file_path: FilePath) -> FilePath:
ng.rechunk_zarr(file_path=file_path)
return file_path
@taskio_handler
def rechunk_zarr(taskio: TaskIO) -> TaskIO:
ng.rechunk_zarr(file_path=taskio.file_path)
# zarr is rechunked in-place
return TaskIO(output_path=taskio.output_path)


@task
def copy_zarr_to_assets_dir(file_path: FilePath):
output_zarr = Path(f"{file_path.working_dir}/{file_path.base}.zarr")
file_path.copy_to_assets_dir(fp_to_cp=Path(output_zarr))
return file_path
@task(name="copy_zarr_to_assets_dir")
@taskio_handler
def copy_zarr_to_assets_dir(taskio: TaskIO) -> TaskIO:
output_zarr = taskio.output_path
asset_path = taskio.file_path.copy_to_assets_dir(fp_to_cp=output_zarr)
return TaskIO(output_path=asset_path)


@task(
name="Neuroglancer asset generation",
on_failure=[utils.collect_exception_task_hook],
)
def generate_ng_asset(file_path: FilePath) -> Dict:
@taskio_handler
def generate_ng_asset(taskio: TaskIO) -> TaskIO:
# Note; the seemingly redundancy of working and asset fp here.
# However asset fp is in the network file system and is deployed for access to the users
# Working fp is actually used for getting the metadata

asset_fp = Path(f"{file_path.assets_dir}/{file_path.base}.zarr")
file_path = taskio.file_path
asset_fp = taskio.output_path

working_fp = Path(f"{file_path.working_dir}/{file_path.base}.zarr")
hw_images = HedwigZarrImages(zarr_path=working_fp, read_only=False)
hw_image = hw_images[list(hw_images.get_series_keys())[0]]
Expand All @@ -105,12 +117,18 @@ def generate_ng_asset(file_path: FilePath) -> Dict:
dimensions="XY",
shaderParameters=hw_image.neuroglancer_shader_parameters(),
)
return ng_asset
return TaskIO(
output_path=None,
data=ng_asset,
)


@task(name="gen_thumb")
@taskio_handler
def gen_thumb(taskio: TaskIO) -> TaskIO:
input_zarr = taskio.output_path
file_path = taskio.file_path

@task
def gen_thumb(file_path: FilePath):
input_zarr = f"{file_path.working_dir}/{file_path.base}.zarr"
zarr_images = HedwigZarrImages(zarr_path=Path(input_zarr), read_only=False)
zarr_image: HedwigZarrImage = zarr_images[list(zarr_images.get_series_keys())[0]]

Expand Down Expand Up @@ -144,7 +162,10 @@ def gen_thumb(file_path: FilePath):
keyImage_asset = file_path.gen_asset(
asset_type=AssetType.KEY_IMAGE, asset_fp=asset_fp_lg
)
return [thumb_asset, keyImage_asset]
return TaskIO(
output_path=None,
data=[thumb_asset, keyImage_asset],
)


@flow(
Expand Down Expand Up @@ -187,38 +208,24 @@ def lrg_2d_flow(
VALID_LRG_2D_RGB_INPUTS,
single_file=x_file_name,
)
fps = utils.gen_fps.submit(
share_name=file_share, input_dir=input_dir_fp, fps_in=input_fps
)
tiffs = convert_png_to_tiff.map(file_path=fps)
zarrs = gen_zarr.map(file_path=tiffs)
rechunk = rechunk_zarr.map(file_path=zarrs)
copy_to_assets = copy_zarr_to_assets_dir.map(file_path=rechunk)
zarr_assets = generate_ng_asset.map(file_path=copy_to_assets)
thumb_assets = gen_thumb.map(file_path=zarrs)
prim_fps = utils.gen_prim_fps.map(fp_in=fps)
callback_with_thumbs = utils.add_asset.map(prim_fp=prim_fps, asset=thumb_assets)
callback_with_pyramids = utils.add_asset.map(
prim_fp=callback_with_thumbs, asset=zarr_assets
)

callback_result = list()

for idx, (fp, cb) in enumerate(zip(fps.result(), callback_with_pyramids)):
state = cb.wait()
if state.is_completed():
callback_result.append(cb.result())
else:
path = f"{state.state_details.flow_run_id}__{idx}"
try:
message = LRG2DConfig.local_storage.read_path(path)
callback_result.append(fp.gen_prim_fp_elt(message.decode()))
except ValueError:
callback_result.append(fp.gen_prim_fp_elt("Something went wrong!"))
fps = gen_taskio.map(
share_name=unmapped(file_share),
input_dir=unmapped(input_dir_fp),
fp_in=input_fps.result(),
)
tiffs = convert_png_to_tiff.map(taskio=fps)
zarrs = gen_zarr.map(taskio=tiffs)
rechunk = rechunk_zarr.map(taskio=zarrs)
copy_to_assets = copy_zarr_to_assets_dir.map(taskio=rechunk)
zarr_assets = generate_ng_asset.map(taskio=copy_to_assets)
thumb_assets = gen_thumb.map(taskio=zarrs)
callback_with_assets = task_io.gen_response.submit(
fps, [*zarr_assets, *thumb_assets]
)

utils.send_callback_body.submit(
x_no_api=x_no_api,
token=token,
callback_url=callback_url,
files_elts=callback_result,
files_elts=callback_with_assets,
)
1 change: 1 addition & 0 deletions em_workflows/utils/neuroglancer.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def bioformats_gen_zarr(

cmd.extend([input_fname, output_zarr])
FilePath.run(cmd=cmd, log_file=log_fp)
return Path(output_zarr)


def zarr_build_multiscales(file_path: FilePath) -> None:
Expand Down
122 changes: 122 additions & 0 deletions em_workflows/utils/task_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional

from prefect import task

from em_workflows.config import Config
from em_workflows.file_path import FilePath


@dataclass
class TaskIO:
    """Unit of state passed between pipeline tasks.

    Each task receives the upstream TaskIO and returns a new one; the
    ``taskio_handler`` decorator carries ``file_path`` forward and
    short-circuits a task when ``error`` is already set upstream.
    """

    # output path is a file pointer returned by the most recent task
    # this is used as input by the downstream task
    # Optional: tasks that only produce `data` (e.g. asset generators)
    # return output_path=None
    output_path: Optional[Path]

    # file path are the keys to any taskio
    # they are unique
    # Sometimes, immediate upstream task may not be the right input
    # but the file_path (the original file) can be the needed one
    file_path: FilePath = None  # set by the initializer (gen_taskio)

    # error reflects the known error of the oldest task upstream
    error: Optional[str] = None

    # along with output_path, some taskios generate data to pass to the users
    # NOTE: some tasks store a list of asset dicts here rather than a dict
    data: Optional[Dict] = None

    # Upstream history could be added to store other intermediate results
    # produced by upstream tasks (which may or may not be necessary downstream)
    # upstream_history: Dict


def taskio_handler(func):
    """
    Takes in taskio and passes into the task if it is valid.

    If the function raises an error, annotates error (into new_taskio) and
    passes downstream instead of failing the pipeline.
    If the function runs fine, passes the resulting taskio (new_taskio) as is.
    If the upstream taskio has an error, it is returned untouched and the
    wrapped function is never run.
    """
    from functools import wraps

    @wraps(func)  # preserve the task function's name for Prefect/introspection
    def wrapper(**kwargs):
        assert (
            "taskio" in kwargs
        ), "Task functions must have `taskio` keyword argument in their definition to use `taskio_handler` definition."
        prev_taskio: TaskIO = kwargs["taskio"]
        # short-circuit: propagate the first upstream error as is
        if prev_taskio.error:
            return prev_taskio

        try:
            new_taskio = func(**kwargs)
        except RuntimeError as e:
            # RuntimeError messages are considered safe to surface to users
            new_taskio = TaskIO(
                file_path=prev_taskio.file_path,
                output_path=None,
                error=f"{func.__name__} {str(e)}",
            )
        except Exception:
            # any other exception gets a generic, non-leaking message
            new_taskio = TaskIO(
                file_path=prev_taskio.file_path,
                output_path=None,
                error="Something went wrong!",
            )
        # file_path is the chain key; always carry it forward
        new_taskio.file_path = prev_taskio.file_path

        # if we want to save history of upstream tasks
        # new_taskio.upstream_history = prev_taskio.history
        # new_taskio.upstream_history[func.__name__] = new_taskio
        return new_taskio

    return wrapper


@task
def gen_response(fps: List[TaskIO], taskios: List[TaskIO]):
    """Assemble the callback payload from per-file results.

    Builds one primary-file element per input file, then folds every task
    result into it: the first error wins and clears the imageSet; otherwise
    asset data (dict or list of dicts) is appended to the first image's
    assets.
    """
    # key each primary element by its source file for O(1) lookup
    elements = {
        item.file_path.fp_in: item.file_path.gen_prim_fp_elt()
        for item in fps
    }

    for result in taskios:
        print(f"\n---\nTaskIO being processed for {result.file_path.fp_in}\n\n***")
        element = elements[result.file_path.fp_in]
        # the first registered error wins; later results are ignored
        if element["status"] == "error":
            continue

        if result.error:
            element["status"] = "error"
            element["message"] = result.error
            element["imageSet"] = None
            continue

        assets = None if element["imageSet"] is None else element["imageSet"][0]["assets"]
        if isinstance(result.data, list):
            assets.extend(result.data)
        elif isinstance(result.data, dict):
            assets.append(result.data)

    return list(elements.values())


@task(
    # persisting to retrieve again in hooks
    persist_result=True,
    result_storage=Config.local_storage,
    result_serializer=Config.pickle_serializer,
    result_storage_key="{flow_run.id}__gen_fps",
)
def gen_taskio(share_name: str, input_dir: Path, fp_in: Path) -> TaskIO:
    """Seed a TaskIO chain for one input file."""
    fp = FilePath(share_name=share_name, input_dir=input_dir, fp_in=fp_in)
    # NOTE(review): the FilePath doubles as the initial output_path here,
    # although output_path is annotated as pathlib.Path — confirm downstream
    # consumers only read file_path from this seed TaskIO.
    return TaskIO(output_path=fp, file_path=fp)


@task
def gen_prim_fps(taskio: TaskIO) -> Dict:
    """Return the primary file-path callback element for this taskio's file."""
    return taskio.file_path.gen_prim_fp_elt()
3 changes: 0 additions & 3 deletions test/test_lrg_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,6 @@ def test_lrg_2d_flow_server_response(mock_nfs_mount, mock_callback_data):
), "not all asset.path is str"
assert all([asset["path"] for asset in assets]), "not all asset.path is valid"

expected_response = { "files": [ { "primaryFilePath": "test/input_files/lrg_ROI_pngs/Projects/even_smaller.png", "status": "success", "message": None, "thumbnailIndex": 0, "title": "even_smaller", "fileMetadata": None, "imageSet": [ { "imageName": "even_smaller", "imageMetadata": None, "assets": [ { "type": "thumbnail", "path": "test/input_files/lrg_ROI_pngs/Assets/even_smaller/even_smaller_sm.jpeg", }, { "type": "keyImage", "path": "test/input_files/lrg_ROI_pngs/Assets/even_smaller/even_smaller_lg.jpeg", }, { "type": "neuroglancerZarr", "path": "test/input_files/lrg_ROI_pngs/Assets/even_smaller/even_smaller.zarr/0", "metadata": { "shader": "RGB", "dimensions": "XY", "shaderParameters": {}, }, }, ], } ], } ] } # noqa
assert response == expected_response, "response and expected response don't match"


def test_lrg_2d_flow_failure_server_response(
monkeypatch, mock_nfs_mount, mock_callback_data
Expand Down

0 comments on commit babc445

Please sign in to comment.