
Spawn many small reprojection tasks #41

Draft
wants to merge 12 commits into main
4 changes: 2 additions & 2 deletions example_runtime_config.toml
@@ -11,8 +11,8 @@ checkpoint_mode = 'task_exit'
[apps.create_manifest]
# The path to the staging directory
# e.g. "/gscratch/dirac/kbmod/workflow/staging"
staging_directory = "/home/drew/code/kbmod-wf/dev_staging"
output_directory = "/home/drew/code/kbmod-wf/dev_staging/single_chip_workflow"
staging_directory = "/Users/drew/code/kbmod-wf/dev_staging"
output_directory = "/Users/drew/code/kbmod-wf/dev_staging/processing"
file_pattern = "*.collection"


168 changes: 168 additions & 0 deletions src/kbmod_wf/parallel_repro_single_chip_wf.py
@@ -0,0 +1,168 @@
import argparse
import os
from pathlib import Path

import toml
import parsl
from parsl import python_app, File
import parsl.executors

from kbmod_wf.utilities import (
apply_runtime_updates,
get_resource_config,
get_executors,
get_configured_logger,
)

from kbmod_wf.workflow_tasks import create_manifest, ic_to_wu_return_shards, kbmod_search


@python_app(
cache=True,
executors=get_executors(["local_dev_testing", "reproject_single_shard"]),
ignore_for_cache=["logging_file"],
)
def reproject_shard(inputs=(), outputs=(), wcses=None, runtime_config={}, logging_file=None):
from kbmod_wf.utilities.logger_utilities import get_configured_logger, ErrorLogger

logger = get_configured_logger("task.sharded_reproject", logging_file.filepath)

from kbmod_wf.task_impls.reproject_single_chip_single_night_wu_shard import reproject_shard

logger.info("Starting reproject_ic")
with ErrorLogger(logger):
reproject_shard(
original_wu_shard_filepath=inputs[0].filepath,
original_wcs=wcses,
reprojected_wu_shard_filepath=outputs[0].filepath,
runtime_config=runtime_config,
logger=logger,
)
logger.info("Completed reproject_ic")
return outputs[0]


def workflow_runner(env=None, runtime_config={}):
"""This function will load and configure Parsl, and run the workflow.

Parameters
----------
env : str, optional
Environment string used to define which resource configuration to use,
by default None
runtime_config : dict, optional
Dictionary of assorted runtime configuration parameters, by default {}
"""
resource_config = get_resource_config(env=env)
resource_config = apply_runtime_updates(resource_config, runtime_config)

app_configs = runtime_config.get("apps", {})

dfk = parsl.load(resource_config)
if dfk:
logging_file = File(os.path.join(dfk.run_dir, "kbmod.log"))
logger = get_configured_logger("workflow.workflow_runner", logging_file.filepath)

if runtime_config is not None:
logger.info(f"Using runtime configuration definition:\n{toml.dumps(runtime_config)}")

logger.info("Starting workflow")

# gather all the *.collection files that are staged for processing
create_manifest_config = app_configs.get("create_manifest", {})
manifest_file = File(
os.path.join(create_manifest_config.get("output_directory", os.getcwd()), "manifest.txt")
)

create_manifest_future = create_manifest(
inputs=[],
outputs=[manifest_file],
runtime_config=app_configs.get("create_manifest", {}),
logging_file=logging_file,
)

with open(create_manifest_future.result(), "r") as manifest:
# process each .collection file in the manifest
original_work_unit_futures = []
for line in manifest:
# Create path object for the line in the manifest
input_file = Path(line.strip())

# Create a directory to contain each work unit's shards
sharded_directory = Path(input_file.parent, input_file.stem)
sharded_directory.mkdir(exist_ok=True)

# Construct the work unit filepath
output_workunit_filepath = Path(sharded_directory, input_file.stem + ".wu")

# Create the work unit future
original_work_unit_futures.append(
ic_to_wu_return_shards(
inputs=[File(str(input_file))],
outputs=[File(str(output_workunit_filepath))],
runtime_config=app_configs.get("ic_to_wu", {}),
logging_file=logging_file,
)
)

            # Reproject each WorkUnit shard individually.
reproject_futures = []
for f in original_work_unit_futures:
shard_futures = []
shard_files, wcses = f.result()
for i in shard_files:
shard_file = Path(i)
shard_future = reproject_shard(
inputs=[File(str(shard_file))],
outputs=[File(str(shard_file.parent / (shard_file.stem + ".repro")))],
wcses=wcses,
runtime_config=app_configs.get("reproject_wu", {}),
logging_file=logging_file,
)
shard_futures.append(shard_future)
reproject_futures.append(shard_futures)

# run kbmod search on each reprojected WorkUnit
search_futures = []
for f in reproject_futures:
search_futures.append(
kbmod_search(
inputs=[i.result() for i in f],
outputs=[],
runtime_config=app_configs.get("kbmod_search", {}),
logging_file=logging_file,
)
)

[f.result() for f in search_futures]

logger.info("Workflow complete")

parsl.clear()


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--env",
type=str,
choices=["dev", "klone"],
help="The environment to run the workflow in.",
)

parser.add_argument(
"--runtime-config",
type=str,
help="The complete runtime configuration filepath to use for the workflow.",
)

args = parser.parse_args()

# if a runtime_config file was provided and exists, load the toml as a dict.
runtime_config = {}
if args.runtime_config is not None and os.path.exists(args.runtime_config):
with open(args.runtime_config, "r") as toml_runtime_config:
runtime_config = toml.load(toml_runtime_config)

workflow_runner(env=args.env, runtime_config=runtime_config)
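
Note: the new workflow pulls its per-app settings from the apps.* tables of the runtime configuration. Below is a minimal sketch of the dict that workflow_runner sees after toml.load(); only the keys this workflow actually reads are shown, and the paths are placeholders rather than values from this PR.

# Hypothetical runtime_config, mirroring the [apps.*] TOML tables after toml.load().
runtime_config = {
    "apps": {
        "create_manifest": {
            # Where manifest.txt is written; the workflow falls back to os.getcwd() if omitted.
            "output_directory": "/path/to/processing",
            "staging_directory": "/path/to/dev_staging",
            "file_pattern": "*.collection",
        },
        "ic_to_wu": {},       # forwarded to ic_to_wu_return_shards
        "reproject_wu": {},   # forwarded to reproject_shard
        "kbmod_search": {},   # forwarded to kbmod_search
    },
}
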
11 changes: 1 addition & 10 deletions src/kbmod_wf/resource_configs/dev_configuration.py
@@ -13,14 +13,5 @@ def dev_resource_config():
return Config(
        # put the log files in the top level folder, "run_logs".
run_dir=os.path.join(project_dir, "run_logs", datetime.date.today().isoformat()),
app_cache=True,
checkpoint_mode="task_exit",
checkpoint_files=get_all_checkpoints(
os.path.join(project_dir, "run_logs", datetime.date.today().isoformat())
),
executors=[
ThreadPoolExecutor(
label="local_dev_testing",
)
],
executors=[ThreadPoolExecutor(label="local_dev_testing", max_threads=3)],
)
20 changes: 20 additions & 0 deletions src/kbmod_wf/resource_configs/klone_configuration.py
@@ -9,6 +9,7 @@
"compute_bigmem": "01:00:00",
"large_mem": "04:00:00",
"sharded_reproject": "04:00:00",
"reproject_single_shard": "00:30:00",
"gpu_max": "08:00:00",
}

@@ -80,6 +81,25 @@ def klone_resource_config():
worker_init="",
),
),
HighThroughputExecutor(
label="reproject_single_shard",
max_workers=1,
provider=SlurmProvider(
partition="ckpt-g2",
account="astro",
min_blocks=0,
max_blocks=256,
init_blocks=0,
parallelism=1,
nodes_per_block=1,
cores_per_node=1,
mem_per_node=1, # only working on 1 image, so <1 GB should be required
exclusive=False,
walltime=walltimes["reproject_single_shard"],
# Command to run before starting worker - i.e. conda activate <special_env>
worker_init="",
),
),
HighThroughputExecutor(
label="gpu",
max_workers=1,
8 changes: 7 additions & 1 deletion src/kbmod_wf/task_impls/ic_to_wu.py
@@ -83,4 +83,10 @@ def create_work_unit(self):
elapsed = round(time.time() - last_time, 1)
self.logger.debug(f"Required {elapsed}[s] to write WorkUnit to disk: {self.wu_filepath}")

return self.wu_filepath
        # All of the WCS information is maintained in the header of the ImageCollection
        # as an Astropy table column. Here we unwrap the column to create a list of strings.
        # Each string can then be converted into a WCS object when needed (for reprojection)
        # using: `wcs_objects = [WCS(json.loads(i)) for i in wcses]`
wcses = [i for i in ic.data['wcs']]

return self.wu_filepath, wcses
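
Note: the serialization format here matters for the new reprojection task, since each entry in wcses is a JSON-encoded FITS header that reproject_shard turns back into a WCS. The round trip below is a minimal sketch; the serialization side is an assumption about how the ImageCollection populates its 'wcs' column, not code from this PR.

import json
from astropy.wcs import WCS

# Build a toy WCS purely for the round trip (values are arbitrary).
original = WCS(naxis=2)
original.wcs.ctype = ["RA---TAN", "DEC--TAN"]
original.wcs.crval = [150.0, 2.0]
original.wcs.crpix = [1024.5, 1024.5]
original.wcs.cdelt = [-7.3e-5, 7.3e-5]

# Serialize: assumed to mirror what the ImageCollection stores in its 'wcs' column.
wcs_string = json.dumps(dict(original.to_header(relax=True)))

# Deserialize: this matches the unpacking done in reproject_shard.
restored = WCS(json.loads(wcs_string), relax=True)
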
141 changes: 141 additions & 0 deletions src/kbmod_wf/task_impls/reproject_single_chip_single_night_wu_shard.py
@@ -0,0 +1,141 @@
import json
import os
import time
from logging import Logger

import numpy as np
import astropy.io.fits as fitsio
from astropy.wcs import WCS

# import kbmod
# from kbmod.work_unit import WorkUnit
# import kbmod.reprojection as reprojection

from reproject import reproject_adaptive
from reproject.mosaicking import find_optimal_celestial_wcs


def reproject_shard(
original_wu_shard_filepath: str = None,
original_wcs=None,
reprojected_wu_shard_filepath: str = None,
runtime_config: dict = {},
logger: Logger = None,
):
"""This task will reproject a WorkUnit to a common WCS.

Parameters
----------
original_wu_shard_filepath : str, optional
The fully resolved filepath to the input WorkUnit file, by default None
reprojected_wu_shard_filepath : str, optional
The fully resolved filepath to the resulting WorkUnit file after
reprojection, by default None
runtime_config : dict, optional
Additional configuration parameters to be used at runtime, by default {}
logger : Logger, optional
Primary logger for the workflow, by default None

Returns
-------
str
The fully resolved filepath of the resulting WorkUnit file after reflex
and reprojection.
"""

wcs_list = [WCS(json.loads(wcs), relax=True) for wcs in original_wcs]

opt_wcs, shape = find_optimal_celestial_wcs(wcs_list)
opt_wcs.array_shape = shape

shard = fitsio.open(original_wu_shard_filepath)
shard_wcs = WCS(shard[0].header)
shard[1].header.update(shard_wcs.to_header())
shard[2].header.update(shard_wcs.to_header())

    # reproject_adaptive returns (array, footprint) by default; keep only the array.
    sci, _ = reproject_adaptive(
shard,
opt_wcs,
hdu_in=0,
shape_out=opt_wcs.array_shape,
bad_value_mode="ignore",
roundtrip_coords=False,
)

    var, _ = reproject_adaptive(
shard,
opt_wcs,
hdu_in=1,
shape_out=opt_wcs.array_shape,
bad_value_mode="ignore",
roundtrip_coords=False,
)

    mask, _ = reproject_adaptive(
shard,
opt_wcs,
hdu_in=2,
shape_out=opt_wcs.array_shape,
bad_value_mode="ignore",
roundtrip_coords=False,
)

shard[0].data = sci.astype(np.float32)
shard[1].data = var.astype(np.float32)
shard[2].data = mask.astype(np.float32)

    # Overwrite the shard in place with the reprojected data.
    shard.writeto(original_wu_shard_filepath, overwrite=True)

with open(reprojected_wu_shard_filepath, "w") as f:
f.write(f"Reprojected: {original_wu_shard_filepath}")

return original_wu_shard_filepath


# class WUShardReprojector:
# def __init__(
# self,
# original_wu_filepath: str = None,
# reprojected_wu_filepath: str = None,
# runtime_config: dict = {},
# logger: Logger = None,
# ):
# self.original_wu_filepath = original_wu_filepath
# self.reprojected_wu_filepath = reprojected_wu_filepath
# self.runtime_config = runtime_config
# self.logger = logger

# # Default to 8 workers if not in the config. Value must be 0<num workers<65.
# self.n_workers = max(1, min(self.runtime_config.get("n_workers", 8), 64))

# def reproject_workunit_shard(self):
# last_time = time.time()
# self.logger.info(f"Lazy reading existing WorkUnit from disk: {self.original_wu_filepath}")
# directory_containing_shards, wu_filename = os.path.split(self.original_wu_filepath)
# wu = WorkUnit.from_sharded_fits(wu_filename, directory_containing_shards, lazy=True)
# elapsed = round(time.time() - last_time, 1)
# self.logger.info(f"Required {elapsed}[s] to lazy read original WorkUnit {self.original_wu_filepath}.")

# directory_containing_reprojected_shards, reprojected_wu_filename = os.path.split(
# self.reprojected_wu_filepath
# )

# # Reproject to a common WCS using the WCS for our patch
# self.logger.info(f"Reprojecting WorkUnit with {self.n_workers} workers...")
# last_time = time.time()

# opt_wcs, shape = find_optimal_celestial_wcs(list(wu._per_image_wcs))
# opt_wcs.array_shape = shape
# reprojection.reproject_work_unit(
# wu,
# opt_wcs,
# max_parallel_processes=self.n_workers,
# write_output=True,
# directory=directory_containing_reprojected_shards,
# filename=reprojected_wu_filename,
# )

# elapsed = round(time.time() - last_time, 1)
# self.logger.info(f"Required {elapsed}[s] to create the sharded reprojected WorkUnit.")

# return self.reprojected_wu_filepath
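
For reference, a standalone call to the new task implementation might look like the sketch below. The shard and marker paths are hypothetical, and wcs_strings stands in for the list of JSON-serialized WCS headers returned alongside the shard list by ic_to_wu_return_shards.

from kbmod_wf.task_impls.reproject_single_chip_single_night_wu_shard import reproject_shard

wcs_strings = ["..."]  # placeholder: one JSON-serialized WCS header per image

result_path = reproject_shard(
    original_wu_shard_filepath="/path/to/work_unit/shard_0.fits",  # hypothetical shard path
    original_wcs=wcs_strings,
    reprojected_wu_shard_filepath="/path/to/work_unit/shard_0.repro",  # hypothetical marker path
)
# The shard file is overwritten in place with the reprojected data and its path is
# returned; the .repro file is a small marker recording that reprojection finished.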