Add create_nonzero_conncomp_counts to create counts of valid unwrapped outputs #485

Merged · 4 commits · Nov 7, 2024
src/dolphin/timeseries.py: 120 changes (116 additions, 4 deletions)
@@ -524,7 +524,9 @@ def invert_stack(
 
 
 def get_incidence_matrix(
-    ifg_pairs: Sequence[tuple[T, T]], sar_idxs: Sequence[T] | None = None
+    ifg_pairs: Sequence[tuple[T, T]],
+    sar_idxs: Sequence[T] | None = None,
+    delete_first_date_column: bool = True,
 ) -> np.ndarray:
     """Build the indicator matrix from a list of ifg pairs (index 1, index 2).
 
@@ -538,6 +540,10 @@ def get_incidence_matrix(
         were formed from.
         Otherwise, created from the unique entries in `ifg_pairs`.
         Only provide if there are some dates which are not present in `ifg_pairs`.
+    delete_first_date_column : bool
+        If True, removes the first column of the matrix to make it full column rank.
+        Size will be `n_sar_dates - 1` columns.
+        Otherwise, the matrix will have `n_sar_dates` columns, but rank `n_sar_dates - 1`.
 
     Returns
     -------
@@ -553,13 +559,13 @@ def get_incidence_matrix(
         sar_idxs = sorted(set(flatten(ifg_pairs)))
 
     M = len(ifg_pairs)
-    N = len(sar_idxs) - 1
+    col_iter = sar_idxs[1:] if delete_first_date_column else sar_idxs
+    N = len(col_iter)
     A = np.zeros((M, N))
 
     # Create a dictionary mapping sar dates to matrix columns
     # We take the first SAR acquisition to be time 0, leave out of matrix
-    date_to_col = {date: i for i, date in enumerate(sar_idxs[1:])}
-    # Populate the matrix
+    date_to_col = {date: i for i, date in enumerate(col_iter)}
     for i, (early, later) in enumerate(ifg_pairs):
         if early in date_to_col:
             A[i, date_to_col[early]] = -1
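To make the new flag concrete, here is a minimal sketch (not part of the diff) of its effect on a three-date network, using integer indices as stand-in dates:

    import numpy as np
    from dolphin.timeseries import get_incidence_matrix

    pairs = [(0, 1), (1, 2), (0, 2)]
    # Default: drop the first date's column -> full column rank, shape (3, 2)
    A = get_incidence_matrix(pairs)
    # New option: keep every date's column -> shape (3, 3), rank n_sar_dates - 1
    A_full = get_incidence_matrix(pairs, delete_first_date_column=False)
    # np.abs(A_full) has one column per date; its column sums count how many
    # interferograms each date participates in
    assert list(np.abs(A_full).sum(axis=0)) == [2.0, 2.0, 2.0]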
@@ -1316,3 +1322,109 @@ def invert_stack_l1(A: ArrayLike, dphi: ArrayLike) -> Array:
     # residuals = jnp.sum(residual_vecs, axis=0)
 
     return phase, residuals
+
+
+def create_nonzero_conncomp_counts(
+    conncomp_file_list: Sequence[PathOrStr],
+    output_dir: PathOrStr,
+    ifg_date_pairs: Sequence[Sequence[DateOrDatetime]] | None = None,
+    block_shape: tuple[int, int] = (256, 256),
+    num_threads: int = 4,
+) -> list[Path]:
+    """Count the number of valid interferograms per date.
+
+    Parameters
+    ----------
+    conncomp_file_list : Sequence[PathOrStr]
+        List of connected component files
+    output_dir : PathOrStr
+        The directory to save the output files
+    ifg_date_pairs : Sequence[Sequence[DateOrDatetime]], optional
+        List of date pairs corresponding to the interferograms.
+        If not provided, will be parsed from filenames.
+    block_shape : tuple[int, int], optional
+        The shape of the blocks to process in parallel.
+    num_threads : int
+        The number of parallel blocks to process at once.
+
+    Returns
+    -------
+    out_paths : list[Path]
+        List of output files, one per unique date
+
+    """
+    output_dir = Path(output_dir)
+    output_dir.mkdir(exist_ok=True, parents=True)
+
+    if ifg_date_pairs is None:
+        ifg_date_pairs = [get_dates(str(f))[:2] for f in conncomp_file_list]
+    try:
+        # Ensure it's a list of pairs
+        ifg_tuples = [(ref, sec) for (ref, sec) in ifg_date_pairs]  # noqa: C416
+    except ValueError as e:
+        raise ValueError(
+            "Each item in `ifg_date_pairs` must be a sequence of length 2"
+        ) from e
+
+    # Get unique dates and create the counting matrix
+    sar_dates: list[DateOrDatetime] = sorted(set(utils.flatten(ifg_date_pairs)))
+
+    date_counting_matrix = np.abs(
+        get_incidence_matrix(ifg_tuples, sar_dates, delete_first_date_column=False)
+    )
+
+    # Create output paths for each date
+    suffix = "_valid_count.tif"
+    out_paths = [output_dir / f"{d.strftime('%Y%m%d')}{suffix}" for d in sar_dates]
+
+    if all(p.exists() for p in out_paths):
+        logger.info("All output files exist, skipping counting")
+        return out_paths
+
+    logger.info("Counting valid interferograms per date")
+
+    # Create VRT stack for reading
+    vrt_name = Path(output_dir) / "conncomp_network.vrt"
+    conncomp_reader = io.VRTStack(
+        file_list=conncomp_file_list,
+        outfile=vrt_name,
+        skip_size_check=True,
+        read_masked=True,
+    )
+
+    def count_by_date(
+        readers: Sequence[io.StackReader], rows: slice, cols: slice
+    ) -> tuple[np.ndarray, slice, slice]:
+        """Process each block by counting valid interferograms per date."""
+        stack = readers[0][:, rows, cols]
+        valid_mask = stack.filled(0) != 0  # Shape: (n_ifgs, block_rows, block_cols)
+
+        # Use the counting matrix to map from interferograms to dates
+        # For each pixel, multiply the valid_mask to get counts per date
+        # Reshape valid_mask to (n_ifgs, -1) to handle all pixels at once
+        valid_flat = valid_mask.reshape(valid_mask.shape[0], -1)
+        # Matrix multiply to get counts per date
+        # (date_counting_matrix.T) is shape (n_sar_dates, n_ifgs), and each row
+        # has a number of 1s equal to the nonzero conncomps for that date.
+        date_count_cols = date_counting_matrix.T @ valid_flat
+        date_counts = date_count_cols.reshape(-1, *valid_mask.shape[1:])
+
+        return date_counts, rows, cols
+
+    # Setup writer for all output files
+    writer = io.BackgroundStackWriter(
+        out_paths, like_filename=conncomp_file_list[0], dtype=np.uint16, units="count"
+    )
+
+    # Process the blocks
+    io.process_blocks(
+        readers=[conncomp_reader],
+        writer=writer,
+        func=count_by_date,
+        block_shape=block_shape,
+        num_threads=num_threads,
+    )
+    writer.notify_finished()
+
+    logger.info("Completed counting valid interferograms per date")
+    return out_paths
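The core of `count_by_date` above is a single matrix multiply. A small self-contained illustration of that idea, with invented conncomp values (it mirrors the block function's logic, not the full I/O pipeline):

    import numpy as np
    from dolphin.timeseries import get_incidence_matrix

    # Three dates, three interferograms; |A| with all columns kept maps ifgs to dates
    counting = np.abs(
        get_incidence_matrix([(0, 1), (1, 2), (0, 2)], delete_first_date_column=False)
    )
    # Toy validity mask for 2 pixels: ifg (0, 2) is all-zero (failed) at pixel 0
    valid_flat = np.array(
        [
            [1, 1],  # ifg (0, 1)
            [1, 1],  # ifg (1, 2)
            [0, 1],  # ifg (0, 2)
        ]
    )  # shape (n_ifgs, n_pixels)
    date_counts = counting.T @ valid_flat  # shape (n_sar_dates, n_pixels)
    # Pixel 0 sees dates 0/1/2 in 1/2/1 valid ifgs; pixel 1 in 2/2/2
    assert (date_counts == [[1, 2], [2, 2], [1, 2]]).all()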
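Finally, a hedged usage sketch of the new function; the filenames here are hypothetical, chosen so that the date pairs can be parsed when `ifg_date_pairs` is omitted:

    from dolphin.timeseries import create_nonzero_conncomp_counts

    # Hypothetical conncomp rasters named <ref-date>_<sec-date>.conncomp.tif
    conncomp_files = [
        "20240101_20240113.conncomp.tif",
        "20240113_20240125.conncomp.tif",
        "20240101_20240125.conncomp.tif",
    ]
    # Writes one <date>_valid_count.tif per unique date into ./counts
    count_paths = create_nonzero_conncomp_counts(
        conncomp_file_list=conncomp_files,
        output_dir="counts",
    )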