Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use dataclasses instead of NamedTuple for displacement, stitched outputs #449

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/dolphin/unwrap/_post_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def get_2pi_ambiguities(


def interpolate_masked_gaps(
unw: NDArray[np.float_], ifg: NDArray[np.complex64]
unw: NDArray[np.float64], ifg: NDArray[np.complex64]
) -> None:
"""Perform phase unwrapping using nearest neighbor interpolation of ambiguities.

Expand Down
2 changes: 1 addition & 1 deletion src/dolphin/unwrap/_unwrap_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def filled_masked_unw_regions(

def _reform_wrapped_phase(
unw_filename: PathOrStr, ifg_filenames: Sequence[PathOrStr]
) -> tuple[NDArray[np.float_], NDArray[np.complex64]]:
) -> tuple[NDArray[np.float64], NDArray[np.complex64]]:
"""Load unwrapped phase, and re-calculate the corresponding wrapped phase.

Finds the matching ifg to `unw_filename`, or uses 2 to compute the correct
Expand Down
55 changes: 24 additions & 31 deletions src/dolphin/workflows/displacement.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import multiprocessing as mp
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass
from pathlib import Path
from typing import NamedTuple

from opera_utils import group_by_burst, group_by_date # , get_dates
from tqdm.auto import tqdm
Expand All @@ -25,8 +25,9 @@
logger = logging.getLogger(__name__)


class OutputPaths(NamedTuple):
"""Named tuple of `DisplacementWorkflow` outputs."""
@dataclass
class OutputPaths:
"""Output files of the `DisplacementWorkflow`."""

comp_slc_dict: dict[str, list[Path]]
stitched_ifg_paths: list[Path]
Expand Down Expand Up @@ -188,15 +189,7 @@ def run(
# Is there one best size? dependent on `half_window` or resolution?
# For now, just pick a reasonable size
corr_window_size = (11, 11)
(
stitched_ifg_paths,
stitched_cor_paths,
stitched_temp_coh_file,
stitched_ps_file,
stitched_amp_dispersion_file,
stitched_shp_count_file,
stitched_similarity_file,
) = stitching_bursts.run(
stitched_paths = stitching_bursts.run(
ifg_file_list=ifg_file_list,
temp_coh_file_list=temp_coh_file_list,
ps_file_list=ps_file_list,
Expand All @@ -217,13 +210,13 @@ def run(
_print_summary(cfg)
return OutputPaths(
comp_slc_dict=comp_slc_dict,
stitched_ifg_paths=stitched_ifg_paths,
stitched_cor_paths=stitched_cor_paths,
stitched_temp_coh_file=stitched_temp_coh_file,
stitched_ps_file=stitched_ps_file,
stitched_amp_dispersion_file=stitched_amp_dispersion_file,
stitched_shp_count_file=stitched_shp_count_file,
stitched_similarity_file=stitched_similarity_file,
stitched_ifg_paths=stitched_paths.ifg_paths,
stitched_cor_paths=stitched_paths.interferometric_corr_paths,
stitched_temp_coh_file=stitched_paths.temp_coh_file,
stitched_ps_file=stitched_paths.ps_file,
stitched_amp_dispersion_file=stitched_paths.amp_dispersion_file,
stitched_shp_count_file=stitched_paths.shp_count_file,
stitched_similarity_file=stitched_paths.similarity_file,
unwrapped_paths=None,
conncomp_paths=None,
timeseries_paths=None,
Expand All @@ -235,9 +228,9 @@ def run(
row_looks, col_looks = cfg.phase_linking.half_window.to_looks()
nlooks = row_looks * col_looks
unwrapped_paths, conncomp_paths = unwrapping.run(
ifg_file_list=stitched_ifg_paths,
cor_file_list=stitched_cor_paths,
temporal_coherence_file=stitched_temp_coh_file,
ifg_file_list=stitched_paths.ifg_paths,
cor_file_list=stitched_paths.interferometric_corr_paths,
temporal_coherence_file=stitched_paths.temp_coh_file,
nlooks=nlooks,
unwrap_options=cfg.unwrap_options,
mask_file=cfg.mask_file,
Expand All @@ -258,8 +251,8 @@ def run(
timeseries_paths, reference_point = timeseries.run(
unwrapped_paths=unwrapped_paths,
conncomp_paths=conncomp_paths,
corr_paths=stitched_cor_paths,
condition_file=stitched_temp_coh_file,
corr_paths=stitched_paths.interferometric_corr_paths,
condition_file=stitched_paths.temp_coh_file,
condition=CallFunc.MAX,
output_dir=ts_opts._directory,
method=timeseries.InversionMethod(ts_opts.method),
Expand Down Expand Up @@ -361,13 +354,13 @@ def run(
_print_summary(cfg)
return OutputPaths(
comp_slc_dict=comp_slc_dict,
stitched_ifg_paths=stitched_ifg_paths,
stitched_cor_paths=stitched_cor_paths,
stitched_temp_coh_file=stitched_temp_coh_file,
stitched_ps_file=stitched_ps_file,
stitched_amp_dispersion_file=stitched_amp_dispersion_file,
stitched_shp_count_file=stitched_shp_count_file,
stitched_similarity_file=stitched_similarity_file,
stitched_ifg_paths=stitched_paths.ifg_paths,
stitched_cor_paths=stitched_paths.interferometric_corr_paths,
stitched_temp_coh_file=stitched_paths.temp_coh_file,
stitched_ps_file=stitched_paths.ps_file,
stitched_amp_dispersion_file=stitched_paths.amp_dispersion_file,
stitched_shp_count_file=stitched_paths.shp_count_file,
stitched_similarity_file=stitched_paths.similarity_file,
unwrapped_paths=unwrapped_paths,
# TODO: Let's keep the unwrapped_paths since all the outputs are
# corresponding to those and if we have a network unwrapping, the
Expand Down
18 changes: 10 additions & 8 deletions src/dolphin/workflows/stitching_bursts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
from __future__ import annotations

import logging
from dataclasses import dataclass
from pathlib import Path
from typing import NamedTuple, Sequence
from typing import Sequence

from dolphin import stitching
from dolphin._log import log_runtime
Expand All @@ -18,22 +19,23 @@
logger = logging.getLogger(__name__)


class StitchedOutputs(NamedTuple):
@dataclass
class StitchedOutputs:
"""Output rasters from stitching step."""

stitched_ifg_paths: list[Path]
ifg_paths: list[Path]
"""List of Paths to the stitched interferograms."""
interferometric_corr_paths: list[Path]
"""List of Paths to interferometric correlation files created."""
stitched_temp_coh_file: Path
temp_coh_file: Path
"""Path to temporal correlation file created."""
stitched_ps_file: Path
ps_file: Path
"""Path to ps mask file created."""
stitched_amp_disp_file: Path
amp_dispersion_file: Path
"""Path to amplitude dispersion file created."""
stitched_shp_count_file: Path
shp_count_file: Path
"""Path to SHP count file created."""
stitched_similarity_file: Path
similarity_file: Path
"""Path to cosine similarity file created."""


Expand Down