Skip to content

Commit

Permalink
Add visual_min_max procedure
Browse files Browse the repository at this point in the history
This replaces the mrc_visual_min_max cli with a function suitable for
the requirements of the image workflow pipelines.
  • Loading branch information
blowekamp committed Apr 14, 2023
1 parent c3f9e9a commit 3aade82
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 337 deletions.
4 changes: 4 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
API
===


.. automodule:: pytools
:members:

.. automodule:: pytools.meta
:members:

Expand Down
9 changes: 2 additions & 7 deletions docs/commandline.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,17 @@ The Pytools packages contain a command line executables for various tasks. They

.. code-block :: bash
mrc_visual_min_max --help
mrc2nifti --help
Or the preferred way using the `python` executable to execute the module entry point:

.. code-block :: bash
python -m mrc_visual_min_max --help
python -m mrc2nifti --help
With either method of invoking the command line interface, the following sections describes the sub-commands available
and the command line options available.

.. click:: pytools.ng.mrc2nifti:main
:prog: mrc2nifti

.. click:: pytools.ng.mrc2ngpc:main
:prog: mrc2npgc

.. click:: pytools.ng.build_histogram:main
:prog: mrc_visual_min_max
4 changes: 3 additions & 1 deletion pytools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# limitations under the License.
#

from .workflow_functions import visual_min_max

_installed_package = "tomojs_pytools"

try:
Expand All @@ -35,4 +37,4 @@
# package is not installed
pass

__all__ = ["__version__"]
__all__ = ["__version__", "visual_min_max"]
1 change: 0 additions & 1 deletion pytools/ng/mrc2nifti.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ def sub_volume_execute(inplace=True):
def wrapper(func):
@wraps(func)
def slice_by_slice(image: sitk.Image, *args, **kwargs):

dim = image.GetDimension()
iter_dim = 2

Expand Down
258 changes: 5 additions & 253 deletions pytools/utils/histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import math
import click

import numpy as np
import logging
import SimpleITK as sitk
import zarr

from pytools.utils import MutuallyExclusiveOption
from pytools import __version__
from math import floor, ceil
from pathlib import Path
import dask.array
from typing import Union, Tuple, Any
from typing import Tuple, Any
from abc import ABC, abstractmethod

import math

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -106,93 +99,7 @@ def compute_histogram(self, histogram_bin_edges=None, density=False) -> Tuple[np
pass


class sitkHistogramHelper(HistogramBase):
"""
Read image slice by slice, and build a histogram. The image file must be readable by SimpleITK.
The SimpleITK is expected to support streaming the file format.
The np.histogram function is run on each image slice with the provided histogram_bin_edges, and
accumulated for the results.
:param filename: The path to the image file to read. MRC file type is recommend.
:param histogram_bin_edges: A monotonically increasing array of min edges. The resulting
histogram or weights will have n-1 elements. If None, then it will be automatically computed for integers, and
an np.bincount may be used as an optimization.
:param extract_axis: The image dimension which is sliced during image reading.
:param density: If true the sum of the results is 1.0, otherwise it is the count of values in each bin.
:param extract_step: The number of slices to read at one time.
"""

def __init__(self, filename, extract_axis=2, extract_step=1):
self.reader = sitk.ImageFileReader()
self.reader.SetFileName(str(filename))
self.reader.ReadImageInformation()

logger.info(f'Reading "{self.reader.GetFileName()}" image information...')

logger.info(f"\tPixel Type: {sitk.GetPixelIDValueAsString(self.reader.GetPixelIDValue())}")
logger.info(f"\tPixel Type: {sitk.GetPixelIDValueAsString(self.reader.GetPixelIDValue())}")
logger.info(f"\tSize: {self.reader.GetSize()}")
logger.info(f"\tSpacing: {self.reader.GetSpacing()}")
logger.info(f"\tOrigin: {self.reader.GetOrigin()}")

self.extract_axis = extract_axis
self.extract_step = extract_step

def compute_min_max(self):
img = self.reader.Execute()

min_max_filter = sitk.MinimumMaximumImageFilter()
min_max_filter.Execute(img)
return min_max_filter.GetMinimum(), min_max_filter.GetMaximum()

@property
def dtype(self):
return sitk.extra._get_numpy_dtype(self.reader)

def compute_histogram(self, histogram_bin_edges=None, density=False) -> Tuple[np.array, np.array]:
use_bincount = False
if histogram_bin_edges is None:
if np.issubdtype(self.dtype, np.integer) and np.iinfo(self.dtype).bits <= 16:
histogram_bin_edges = self.compute_histogram_bin_edges(
number_of_bins=2 ** np.iinfo(self.dtype).bits + 1
)
if self.dtype() in (np.uint8, np.uint16):
use_bincount = True
else:
histogram_bin_edges = self.compute_histogram_bin_edges()

h = np.zeros(len(histogram_bin_edges) - 1, dtype=np.int64)

extract_index = [0] * self.reader.GetDimension()

size = self.reader.GetSize()
extract_size = list(size)
extract_size[self.extract_axis] = 0
self.reader.SetExtractSize(extract_size)

for i in range(0, self.reader.GetSize()[self.extract_axis], self.extract_step):
extract_index[self.extract_axis] = i
self.reader.SetExtractIndex(extract_index)
logger.debug(f"extract_index: {extract_index}")

extract_size[self.extract_axis] = min(i + self.extract_step, size[self.extract_axis]) - i
self.reader.SetExtractSize(extract_size)
img = self.reader.Execute()

# accumulate histogram counts
if use_bincount:
h += np.bincount(sitk.GetArrayViewFromImage(img).ravel(), minlength=len(h))
else:
h += np.histogram(sitk.GetArrayViewFromImage(img).ravel(), bins=histogram_bin_edges, density=False)[0]

if density:
h /= np.sum(h)

return h, histogram_bin_edges


class daskHisogramHelper(HistogramBase):
class DaskHistogramHelper(HistogramBase):
def __init__(self, arr: dask.array):
self._arr = arr
if not self._arr.dtype.isnative:
Expand Down Expand Up @@ -226,41 +133,13 @@ def compute_histogram(self, histogram_bin_edges=None, density=False) -> Tuple[np
return h.compute(), bins


class zarrHisogramHelper(daskHisogramHelper):
class ZARRHistogramHelper(DaskHistogramHelper):
def __init__(self, filename):
za = zarr.open_array(filename, mode="r")
super().__init__(dask.array.from_zarr(za))
logging.debug(za.info)


def stream_build_histogram(
filename: Union[Path, str], histogram_bin_edges=None, extract_axis=2, density=False, extract_step=1
):
"""
Read image slice by slice, and build a histogram. The image file must be readable by SimpleITK.
The SimpleITK is expected to support streaming the file format.
The np.histogram function is run on each image slice with the provided histogram_bin_edges, and
accumulated for the results.
:param filename: The path to the image file to read. MRC file type is recommend.
:param histogram_bin_edges: A monotonically increasing array of min edges. The resulting
histogram or weights will have n-1 elements. If None, then it will be automatically computed for integers, and
an np.bincount may be used as an optimization.
:param extract_axis: The image dimension which is sliced during image reading.
:param density: If true the sum of the results is 1.0, otherwise it is the count of values in each bin.
:param extract_step: The number of slices to read at one time.
"""

input_image = Path(filename)

if input_image.is_dir() and (input_image / ".zarray").exists():
histo = zarrHisogramHelper(input_image)

else:
histo = sitkHistogramHelper(filename, extract_axis=extract_axis, extract_step=extract_step)

return histo.compute_histogram(histogram_bin_edges=histogram_bin_edges, density=density)


def histogram_robust_stats(hist, bin_edges):
"""
Computes the "median" and "mad" (Median Absolute Deviation).
Expand Down Expand Up @@ -298,130 +177,3 @@ def histogram_stats(hist, bin_edges):
results["sigma"] = math.sqrt(results["var"])

return results


@click.command()
@click.argument("input_image", type=click.Path(exists=True, dir_okay=True, path_type=Path))
@click.option(
"--mad",
"mad_scale",
type=float,
cls=MutuallyExclusiveOption,
mutually_exclusive=["sigma", "percentile-crop"],
help="Use INPUT_IMAGE's robust median absolute deviation (MAD) scale by option's value about the median for "
"minimum and maximum range. ",
)
@click.option(
"--sigma",
"sigma_scale",
type=float,
cls=MutuallyExclusiveOption,
mutually_exclusive=["mad", "percentile-crop"],
help="Use INPUT_IMAGE's standard deviation (sigma) scale by option's value about the mean for minimum and "
"maximum range. ",
)
@click.option(
"--percentile",
type=click.FloatRange(0.0, 100),
cls=MutuallyExclusiveOption,
mutually_exclusive=["sigma", "mad"],
help="Use INPUT_IMAGE's middle percentile (option's value) of data for minimum and maximum range.",
)
@click.option(
"--clamp/--no-clamp",
default=False,
help="Clamps minimum and maximum range to existing intensity values (floor and limit).",
)
@click.option(
"--output-json",
type=click.Path(exists=False, dir_okay=False, resolve_path=True),
help='The output filename produced in JSON format with "neuroglancerPrecomputedMin", '
'"neuroglancerPrecomputedMax", "neuroglancerPrecomputedFloor" and "neuroglancerPrecomputedLimit" data '
"elements of a double numeric value.",
)
@click.option(
"--log-level", default="INFO", type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR"], case_sensitive=False)
)
@click.version_option(__version__)
def main(input_image: Path, mad_scale, sigma_scale, percentile, clamp, output_json, log_level):
"""
Reads the INPUT_IMAGE to compute an estimated minimum and maximum range to be used for visualization of the
data set. The image is required to have an integer pixel type.
The optional OUTPUT_JSON filename will be created with the following data elements with integer values as strings:
"neuroglancerPrecomputedMin"
"neuroglancerPrecomputedMax"
"neuroglancerPrecomputedFloor"
"neuroglancerPrecomputedLimit"
"""

logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.getLevelName(log_level))

if input_image.is_dir() and (input_image / ".zarray").exists():
histo = zarrHisogramHelper(input_image)
elif input_image.suffix in (".nii", ".mha", ".mrc", ".rec"):
from SimpleITK.utilities.dask import from_sitk

logger.debug("Loading chunk with SimpleITK and dask...")
sitk_da = from_sitk(input_image, chunks=(1, -1, -1))
histo = daskHisogramHelper(sitk_da)
else:
logger.debug("Loading whole image with SimpleITK...")
img = sitk.ReadImage(input_image)
histo = daskHisogramHelper(dask.array.from_array(sitk.GetArrayViewFromImage(img), chunks=(1, -1, -1)))

logger.info(f'Building histogram for "{input_image}"...')
h, bins = histo.compute_histogram(histogram_bin_edges=None, density=False)

mids = 0.5 * (bins[1:] + bins[:-1])

logger.info("Computing statistics...")
if mad_scale:
stats = histogram_robust_stats(h, bins)
logger.debug(f"stats: {stats}")
min_max = (stats["median"] - stats["mad"] * mad_scale, stats["median"] + stats["mad"] * mad_scale)
elif sigma_scale:
stats = histogram_stats(h, bins)
logger.debug(f"stats: {stats}")
min_max = (stats["mean"] - stats["sigma"] * sigma_scale, stats["mean"] + stats["sigma"] * sigma_scale)
elif percentile:
lower_quantile = (0.5 * (100 - percentile)) / 100.0
upper_quantile = percentile / 100.0 + lower_quantile
logger.debug(f"quantiles: {lower_quantile} {upper_quantile}")

# cs = np.cumsum(h)
# min_max = (np.searchsorted(cs, cs[-1] * (percentile_crop * .005)),
# np.searchsorted(cs, cs[-1] * (1.0 - percentile_crop * .005)))
# min_max = (mids[min_max[0]], mids[min_max[1]])
min_max = weighted_quantile(
mids, quantiles=[lower_quantile, upper_quantile], sample_weight=h, values_sorted=True
)
else:
raise RuntimeError("Missing expected argument")

floor_limit = weighted_quantile(mids, quantiles=[0.0, 1.0], sample_weight=h, values_sorted=True)

if clamp:
min_max = (max(min_max[0], floor_limit[0]), min(min_max[1], floor_limit[1]))

output = {
"neuroglancerPrecomputedMin": str(floor(min_max[0])),
"neuroglancerPrecomputedMax": str(ceil(min_max[1])),
"neuroglancerPrecomputedFloor": str(floor(floor_limit[0])),
"neuroglancerPrecomputedLimit": str(ceil(floor_limit[1])),
}

logger.debug(f"output: {output}")
if output_json:
import json

with open(output_json, "w") as fp:
json.dump(output, fp)
else:
print(output)

return output


if __name__ == "__main__":
main()
Loading

0 comments on commit 3aade82

Please sign in to comment.