Skip to content

Commit

Permalink
Relabel multiscale connected components (#31)
Browse files Browse the repository at this point in the history
The `multiscale_unwrap()` function is updated to perform a
post-processing step that relabels the connected components resulting
from tiled unwrapping using the coarse-unwrapped connected components.

When a large interferogram is unwrapped using a tiled unwrapping
approach, each tile is independently assigned connected component
labels. This can cause some issues for interpreting the resulting
connected components:

* Labels may not be unique across tiles. Two different components in two
  different tiles may be assigned the same integer label.
* If a region of reliable unwrapped phase spans multiple tiles, it may
  be assigned different labels in each of the different tiles.

The relabeling step attempts to address these issues by assigning each
connected component a new label based on the low-resolution (i.e.
coarse-unwrapped) connected component that it most overlapped with. Two
or more high-res connected components that overlapped with the same
low-res connected component will be assigned the same final label.
High-res connected components that most overlapped with different
low-res connected components will be assigned distinct labels. Each
high-res connecteed component that didn't overlap with any low-res
component will be assigned a new unique label.

It's possible for the user to specify a minimum overlap fraction via the
`min_conncomp_overlap` parameter. If the intersection between a
high-res and low-res component (as a fraction of the area of the
high-res component) is below this threshold, then the two won't be
considered overlapping for purposes of relabeling.

The final set of connected components are assigned sequential positive
integer labels [1, 2, ..., N], where N is the total number of unique
components.

Implementing the relabeling step required some refactoring of the
implementation of the `multiscale_unwrap()` function.

The relabeling process needs to see the full set of connected component
labels from all tiles, which requires the Dask task graph to be computed
for each tile prior to relabeling. Later on, when we store the final
unwrapped phase and connected component labels in their respective
output datasets, the task graph would need to be re-computed in order to
retrieve the unwrapped phase from each tile. So each tile would get
unwrapped twice(!!) -- once during relabeling and once more during the
final `dask.array.store()` step.

I've avoided this by writing the intermediate connected component labels
arrays to temporary binary files prior to relabeling. I don't expect
this to have much impact on runtime, since much of the latency of
writing to disk should be hidden by parallel processing of different
tiles, but it does make the code much messier. I haven't been able to
think of a better approach so far, though.
  • Loading branch information
gmgunter authored Sep 20, 2023
1 parent b3612c0 commit 62ab5a2
Show file tree
Hide file tree
Showing 8 changed files with 1,121 additions and 251 deletions.
1 change: 1 addition & 0 deletions src/tophu/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ._filter import *
from ._io import *
from ._label import *
from ._multilook import *
from ._multiscale import *
from ._unwrap import *
Expand Down
300 changes: 300 additions & 0 deletions src/tophu/_label.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,300 @@
from __future__ import annotations

import itertools
from collections.abc import Mapping
from typing import Any

import dask.array as da
import numpy as np
from numpy.typing import NDArray

from ._util import get_all_unique_values, mode, unique_nonzero_integers

__all__ = [
"relabel_hires_conncomps",
]


# A constant used to identify high-res connected components that don't overlap with any
# low-res component.
NO_OVERLAPPING_LABEL = -1


def find_max_overlapping_labels(
src_conncomp: NDArray[np.unsignedinteger],
dst_conncomp: NDArray[np.unsignedinteger],
*,
min_overlap: float = 0.5,
) -> dict[int, int]:
"""
Find overlapping connected components.
Given two sets of connected component labels, find the labels in the second set that
most overlap with each label in the first set. That is, for each unique label in
`src_conncomp`, compute the label from `dst_conncomp` that it has the largest
intersecting area with, if any such label exists.
The ratio of the intersecting area to the area of the original component must be at
least `min_overlap` for the two components to be considered overlapping. The special
constant `NO_OVERLAPPING_LABEL` is used to identify labels from `src_conncomp` that
did not sufficiently overlap with any connected component from `dst_conncomp`.
Zero-valued elements are not considered to be members of any connected component.
Parameters
----------
src_conncomp : numpy.ndarray
The initial set of connected component labels. An array of nonnegative integers.
dst_conncomp : numpy.ndarray
The second set of connected component labels. An array of nonnegative integers
with the same shape as `src_conncomp`.
min_overlap : float, optional
Minimum intersection between components in order to be considered overlapping,
as a fraction of the area of the component from `src_conncomp`. Must be in the
range (0, 1]. Defaults to 0.5.
Returns
-------
overlapping_labels : dict
A mapping from each unique label in `src_conncomp` to the label in
`dst_conncomp` that it most overlapped with, or to `NO_OVERLAPPING_LABEL` if no
connected component was found that satisfied the minimum overlap threshold.
"""
if dst_conncomp.shape != src_conncomp.shape:
raise ValueError(
"shape mismatch: input connected components arrays must have the same shape"
)

if min_overlap <= 0.0:
raise ValueError(f"min overlap must be > 0, got {min_overlap}")
if min_overlap > 1.0:
raise ValueError(f"min overlap must be <= 1, got {min_overlap}")

# Get the set of unique labels in the first array of connected components (CCs).
src_labels = unique_nonzero_integers(src_conncomp)

# Get a mask of nonzero values in the second array of CCs.
dst_nonzero = dst_conncomp != 0

# Given a label from `src_labels`, find the label of the CC from `dst_conncomp` that
# had the most overlapping area with the corresponding CC in `src_conncomp` (if any
# exists). If no label was found that satisfied the minimum overlap threshold,
# returns `NO_OVERLAPPING_LABEL`.
def get_max_overlapping_label(src_label: int) -> int:
# Get a mask of pixels within the current CC.
cc_mask = src_conncomp == src_label

# Get the total area of the CC (i.e. the number of nonzero values in the mask).
cc_area = np.count_nonzero(cc_mask)

# Get the most frequent label from `dst_conncomp` within the masked region.
dst_label, count = mode(dst_conncomp[cc_mask & dst_nonzero])

# Check whether there was sufficient overlap between the two labels.
if count >= min_overlap * cc_area:
return dst_label
else:
return NO_OVERLAPPING_LABEL

return {src_label: get_max_overlapping_label(src_label) for src_label in src_labels}


def relabel(
conncomp: NDArray[np.unsignedinteger],
label_mapping: Mapping[int, int],
) -> NDArray[np.unsignedinteger]:
"""
Replace each label in `conncomp` with a new label from `label_mapping`.
Given an array of provisional connected component labels `conncomp` and a mapping
from provisional labels to final labels `label_mapping`, create a new array of
connected component labels by replacing each provisional label with the
corresponding final label.
The set of unique nonzero labels in `conncomp` must be a subset of the keys of
`label_mapping`.
Zero-valued elements of `conncomp` are treated as masked out (i.e. not part of any
connected component). They are not relabeled.
Parameters
----------
conncomp : numpy.ndarray
The input array of provisional connected component labels. This array is not
modified by the function.
label_mapping : mapping
Defines a mapping from each unique nonzero label in `conncomp` to the
corresponding final label to assign to that component.
Returns
-------
relabeled : numpy.ndarray
A new array with the same shape and dtype as `conncomp` resulting from replacing
each input connected component label with the corresponding label from
`label_mapping`.
"""
# Create the new connected components (CC) array, initially filled with zeros.
new_conncomp = np.zeros_like(conncomp)

# Loop over unique CC labels in the original `conncomp` array.
for old_label in unique_nonzero_integers(conncomp):
# Get a mask of pixels within the current CC.
mask = conncomp == old_label

# Get the corresponding final label.
new_label = label_mapping[old_label]

# Assign the new label to masked pixels in the output array.
new_conncomp[mask] = new_label

return new_conncomp


def extract_scalar(arr: np.ndarray) -> Any:
"""Extract the scalar value from an array containing a single element."""
if arr.size != 1:
raise ValueError(f"array size must be equal to 1, got {arr.size}")
return np.squeeze(arr)[()]


def relabel_hires_conncomps(
conncomp_hires: da.Array,
conncomp_lores: da.Array,
*,
min_overlap: float = 0.5,
) -> da.Array:
"""
Deduplicate and merge connected component labels resulting from tiled unwrapping.
If a high-resolution interferogram is unwrapped by tiles, each tile may be assigned
a set of connected component (CC) labels independently from the surrounding tiles.
As a result, some CC labels may not be unique across tiles. Furthermore, if regions
of reliable unwrapped phase spanned multiple tiles, they may be assigned different
labels in different tiles.
This function attempts to resolve these issues as a post-processing step by using a
set of low-resolution CCs resulting from coarse unwrapping of the same
interferogram. Unlike the high-resolution CCs, each low-resolution CC is assumed to
be assigned a single unique label.
For each high-resolution CC in each tile, the low-resolution CC that it shares the
most overlapping area with is found. Then each CC is relabeled according to the
low-resolution CC that it most overlapped with. If two or more high-resolution
components share the same most-overlapping low-resolution CC, then they will be
assigned the same label. High-resolution CCs that most overlapped with different
low-resolution CCs will be assigned distinct labels. Each high-resolution CC that
did not overlap with any low-resolution CC will be assigned a unique label.
After relabeling, each unique connected component is assigned a new positive integer
label in [1, 2, ..., N], where N is the total number of unique connected components.
Zero-valued pixels are treated as masked out (i.e. not part of any connected
component). They are not relabeled.
Parameters
----------
conncomp_hires : dask.array.Array
The initial high-resolution connected components. A two-dimensional array of
nonnegative integer values. Each chunk of the array is assumed to have been
independently assigned its connected component labels, such that labels may not
be unique across chunks and some components that span multiple chunks may have
been assigned multiple labels.
conncomp_lores : dask.array.Array
An array of connected component labels resulting from coarse unwrapping. A
two-dimensional array of nonnegative integer values with the same shape and
chunk sizes as `conncomp_hires`. Unlike the high-resolution connected
components, each connected component in `conncomp_lores` is assumed to be
assigned a single unique label.
min_overlap : float, optional
Minimum intersection between components in order to be considered overlapping,
as a fraction of the area of the high-resolution component area. Must be in the
range (0, 1]. Defaults to 0.5.
Returns
-------
new_conncomp_hires : dask.array.Array
The array of updated high-resolution connected component labels.
"""
# The high-res and low-res connected component (CC) arrays should each be 2-D arrays
# with the same shape & chunk sizes.
if conncomp_hires.ndim != 2:
raise ValueError("the input connected components must be 2-D arrays")
if conncomp_hires.shape != conncomp_lores.shape:
raise ValueError(
"shape mismatch: the high-res and low-res connected components arrays must"
" have the same shape"
)
if conncomp_hires.chunks != conncomp_lores.chunks:
raise ValueError(
"the high-res and low-res connected components arrays must have the same"
" chunk sizes"
)

# For each high-res CC in each tile, find the label of the low-res CC that most
# overlapped with it, if any such component exists. The result is an array with
# shape equal to `conncomp_hires.numblocks` of dicts mapping from high-res labels to
# the corresponding low-res labels (one dict per tile).
label_mappings = da.map_blocks(
lambda cc_hires, cc_lores, min_overlap: np.atleast_2d(
find_max_overlapping_labels(cc_hires, cc_lores, min_overlap=min_overlap)
),
conncomp_hires,
conncomp_lores,
min_overlap=min_overlap,
meta=np.empty((), dtype=np.object_),
).compute()

# Get the set of all low-res CC labels from among all tiles that overlapped with any
# high-res CC. This is the set of all unique values (not keys) from among dicts in
# `label_mappings`.
mapped_labels = get_all_unique_values(label_mappings.flat)

# An inexhaustible generator that yields new unique positive-valued connected
# component labels not found in the set of existing low-res labels.
new_unique_labels = (
label for label in itertools.count(1) if label not in mapped_labels
)

# Update the label mappings to replace `NO_OVERLAPPING_LABEL` values with new unique
# labels.
for label_mapping in label_mappings.flat:
for key, val in label_mapping.items():
if val == NO_OVERLAPPING_LABEL:
label_mapping[key] = next(new_unique_labels)

# Once more, get the set of all mapped-to labels in `label_mappings` after we
# finished updating it to replace `NO_OVERLAPPING_LABEL` values with new labels.
updated_mapped_labels = get_all_unique_values(label_mappings.flat)

# We would like the final set of connected component labels to be the set of natural
# numbers 1 through N, where N is the total number of connected components.
# Currently, that's not necessarily the case in `updated_mapped_labels` -- due to
# merging of equivalent labels, there could be some "gaps" in the natural sequence
# of labels. So we define an additional mapping from `updated_mapped_labels` to the
# set of final labels, which will be the natural numbers 1 through N.
final_labels = dict(zip(updated_mapped_labels, itertools.count(1)))

# Create a new array of label mappings, with one dict per tile in the original
# high-res CC array. Each dict defines a mapping from the original high-res
# labels to the corresponding final labels for each CC in the tile.
final_label_mappings = label_mappings.copy()
for label_mapping in final_label_mappings.flat:
for key, val in label_mapping.items():
label_mapping[key] = final_labels[val]

# Break the `final_label_mappings` array up into chunks (one chunk per tile in the
# input `conncomp_hires` array).
final_label_mappings = da.from_array(final_label_mappings, chunks=(1, 1))
assert final_label_mappings.numblocks == conncomp_hires.numblocks

# Finally, create the output array of updated connected component labels by
# replacing each high-res connected component label with the new corresponding label
# from `final_label_mappings`.
return da.map_blocks(
lambda conncomp, label_mapping: relabel(
conncomp, extract_scalar(label_mapping)
),
conncomp_hires,
final_label_mappings,
meta=conncomp_hires._meta,
)
Loading

0 comments on commit 62ab5a2

Please sign in to comment.