From 13ef0a876cb1c9a83c497dd50a65ecb3eda3f8e2 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Mon, 5 Feb 2024 15:45:46 -0700 Subject: [PATCH] Various improvements to `scico.flax` and related example scripts (#498) * Rename flax data files * Minor docstring improvements * Minor docstring improvement * Remove astra-toolbox channel * Minor improvement * Add timer to trainer class * Minor clean up * Minor change to log format * Improve log formatting * Improve log formatting * Update submodule * Typo fix * Improve log format * Docs fixes * Docs consistency * Minor edit * Docs consistency * Fix broken cross-references * Docs consistency * Improve docs * Rename functions * Update function docs * Update URL * Clean up log format * Update submodule * Fix overly simple regex * Trivial edit * Clean up some scripts * Update submodule * Overlooked change from recent astra PR * Add note on GPU support test script * Overlooked change from recent astra PR * Add script for removing error output from notebooks * Update submodule * Fix tests * Fix tests * Update submodule --------- Co-authored-by: Brendt Wohlberg --- CHANGES.rst | 3 ++ data | 2 +- docs/source/install.rst | 5 +++ examples/jnb.py | 26 +++++++++++- examples/removejnberr.py | 18 ++++++++ examples/scripts/ct_astra_modl_train_foam2.py | 17 ++++---- examples/scripts/ct_astra_odp_train_foam2.py | 18 ++++---- examples/scripts/ct_astra_unet_train_foam2.py | 11 ++--- examples/scripts/deconv_datagen_bsds.py | 2 - examples/scripts/deconv_modl_train_foam1.py | 12 +++--- examples/scripts/deconv_odp_train_foam1.py | 14 +++---- examples/scripts/denoise_dncnn_train_bsds.py | 13 ++---- misc/conda/make_conda_env.sh | 3 +- scico/denoiser.py | 6 +-- scico/flax/__init__.py | 36 +++++++++++++--- scico/flax/_flax.py | 17 ++++---- scico/flax/_models.py | 24 +++++------ scico/flax/blocks.py | 14 +++---- scico/flax/examples/data_preprocessing.py | 2 +- scico/flax/examples/examples.py | 6 +-- scico/flax/inverse.py | 42 ++++++++----------- scico/flax/train/clu_utils.py | 21 ++++------ scico/flax/train/trainer.py | 34 ++++++++------- scico/flax/train/typed_dict.py | 4 +- scico/ray/__init__.py | 4 +- scico/scipy/__init__.py | 4 +- scico/test/flax/test_clu.py | 8 ++-- scico/test/flax/test_flax.py | 6 +-- 28 files changed, 209 insertions(+), 163 deletions(-) create mode 100755 examples/removejnberr.py diff --git a/CHANGES.rst b/CHANGES.rst index c675d6f9d..0fc4bd0fc 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -7,6 +7,9 @@ Version 0.0.6 (unreleased) ---------------------------- • Significant changes to ``linop.xray.astra`` API. +• Rename ``scico.flax.save_weights`` and ``scico.flax.load_weights`` to + ``scico.flax.save_variables`` and ``scico.flax.load_variables`` + respectively. diff --git a/data b/data index 097ae716c..330f4a514 160000 --- a/data +++ b/data @@ -1 +1 @@ -Subproject commit 097ae716ce965e4f2b1f0a966bd1ade0ff23149b +Subproject commit 330f4a5144be92a2acf19e27f60391a7f1fcd6f2 diff --git a/docs/source/install.rst b/docs/source/install.rst index 39e84387e..dcff541ea 100644 --- a/docs/source/install.rst +++ b/docs/source/install.rst @@ -130,6 +130,11 @@ a version with GPU support: numbers. +The script `misc/envinfo.py `_ +in the source distribution is provided as an aid to debugging GPU support +issues. 
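+For example, a quick check of whether JAX can see a GPU (a sketch
+independent of that script; :func:`jax.devices` and
+:func:`jax.default_backend` are standard JAX functions) is
+
+::
+
+   import jax
+   print(jax.default_backend())  # "gpu" if a GPU backend is active
+   print(jax.devices())          # devices that JAX will use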
+ + Additional Dependencies ----------------------- diff --git a/examples/jnb.py b/examples/jnb.py index a0d98d02d..5b8ead6ee 100644 --- a/examples/jnb.py +++ b/examples/jnb.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2022-2023 by SCICO Developers +# Copyright (C) 2022-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -44,7 +44,7 @@ def py_file_to_string(src): else: # Set flag indicating that an import statement has been seen once one has # been encountered - if re.match("^(import|from)", line): + if re.match("^import|^from .* import", line): import_seen = True lines.append(line) # Backtrack through list of lines to find last import statement @@ -221,3 +221,25 @@ def replace_markdown_cells(src, dst): # the dst cell if srccell[n]["cell_type"] == "markdown": dstcell[n]["source"] = srccell[n]["source"] + + +def remove_error_output(src): + """Remove output to stderr from all cells in `src`.""" + + if "cells" in src: + cells = src["cells"] + else: + cells = src["worksheets"][0]["cells"] + + modified = False + for c in cells: + if "outputs" in c: + dellist = [] + for n, out in enumerate(c["outputs"]): + if "name" in out and out["name"] == "stderr": + dellist.append(n) + modified = True + for n in dellist[::-1]: + del c["outputs"][n] + + return modified diff --git a/examples/removejnberr.py b/examples/removejnberr.py new file mode 100755 index 000000000..878b32de6 --- /dev/null +++ b/examples/removejnberr.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python + +# Remove output to stderr in notebooks. NB: use with caution! +# Run as +# python removejnberr.py + +import glob +import os + +from jnb import read_notebook, remove_error_output +from py2jn.tools import write_notebook + +for src in glob.glob(os.path.join("notebooks", "*.ipynb")): + nb = read_notebook(src) + modflg = remove_error_output(nb) + if modflg: + print(f"Removing output to stderr from {src}") + write_notebook(nb, src) diff --git a/examples/scripts/ct_astra_modl_train_foam2.py b/examples/scripts/ct_astra_modl_train_foam2.py index c6e10bde5..dc25ebb38 100644 --- a/examples/scripts/ct_astra_modl_train_foam2.py +++ b/examples/scripts/ct_astra_modl_train_foam2.py @@ -54,7 +54,7 @@ from scico import metric, plot from scico.flax.examples import load_ct_data from scico.flax.train.traversals import clip_positive, construct_traversal -from scico.linop.xray.astra import XRayTransform +from scico.linop.xray.astra import XRayTransform2D """ Prepare parallel processing. Set an arbitrary processor count (only @@ -81,9 +81,9 @@ Build CT projection operator. """ angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles -A = XRayTransform( +A = XRayTransform2D( input_shape=(N, N), - detector_spacing=1, + det_spacing=1, det_count=N, angles=angles, ) # CT projection operator @@ -138,7 +138,7 @@ """ -Construct functionality for making sure that the learned +Construct functionality for ensuring that the learned regularization parameter is always positive. """ lmbdatrav = construct_traversal("lmbda") # select lmbda parameters in model @@ -152,8 +152,8 @@ """ Print configuration of distributed run. 
""" -print(f"{'JAX process: '}{jax.process_index()}{' / '}{jax.process_count()}") -print(f"{'JAX local devices: '}{jax.local_devices()}") +print(f"\nJAX process: {jax.process_index()}{' / '}{jax.process_count()}") +print(f"JAX local devices: {jax.local_devices()}\n") """ @@ -212,9 +212,8 @@ cg_iter=model_conf["cg_iter_1"], ) # First stage: initialization training loop. - workdir = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "modl_ct_out") - - train_conf["workdir"] = workdir + workdir1 = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "modl_ct_out") + train_conf["workdir"] = workdir1 train_conf["post_lst"] = [lmbdapos] # Construct training object trainer = sflax.BasicFlaxTrainer( diff --git a/examples/scripts/ct_astra_odp_train_foam2.py b/examples/scripts/ct_astra_odp_train_foam2.py index ec8a21875..8c5d9ad61 100644 --- a/examples/scripts/ct_astra_odp_train_foam2.py +++ b/examples/scripts/ct_astra_odp_train_foam2.py @@ -58,7 +58,7 @@ from scico import metric, plot from scico.flax.examples import load_ct_data from scico.flax.train.traversals import clip_positive, construct_traversal -from scico.linop.xray.astra import XRayTransform +from scico.linop.xray.astra import XRayTransform2D """ Prepare parallel processing. Set an arbitrary processor count (only @@ -85,9 +85,9 @@ Build CT projection operator. """ angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles -A = XRayTransform( +A = XRayTransform2D( input_shape=(N, N), - detector_spacing=1, + det_spacing=1, det_count=N, angles=angles, ) # CT projection operator @@ -138,7 +138,7 @@ """ -Construct functionality for making sure that the learned fidelity weight +Construct functionality for ensuring that the learned fidelity weight parameter is always positive. """ alphatrav = construct_traversal("alpha") # select alpha parameters in model @@ -152,8 +152,8 @@ """ Print configuration of distributed run. """ -print(f"{'JAX process: '}{jax.process_index()}{' / '}{jax.process_count()}") -print(f"{'JAX local devices: '}{jax.local_devices()}") +print(f"\nJAX process: {jax.process_index()}{' / '}{jax.process_count()}") +print(f"JAX local devices: {jax.local_devices()}\n") """ @@ -185,10 +185,7 @@ train_ds, test_ds, ) - -start_time = time() modvar, stats_object = trainer.train() -time_train = time() - start_time """ @@ -215,13 +212,14 @@ psnr_eval = metric.psnr(test_ds["label"][:maxn], output) print( f"{'ODPNet training':18s}{'epochs:':2s}{epochs:>5d}{'':21s}" - f"{'time[s]:':10s}{time_train:>7.2f}" + f"{'time[s]:':10s}{trainer.train_time:>7.2f}" ) print( f"{'ODPNet testing':18s}{'SNR:':5s}{snr_eval:>5.2f}{' dB'}{'':3s}" f"{'PSNR:':6s}{psnr_eval:>5.2f}{' dB'}{'':3s}{'time[s]:':10s}{time_eval:>7.2f}" ) + """ Plot comparison. 
""" diff --git a/examples/scripts/ct_astra_unet_train_foam2.py b/examples/scripts/ct_astra_unet_train_foam2.py index 33a1500b4..72e82e81d 100644 --- a/examples/scripts/ct_astra_unet_train_foam2.py +++ b/examples/scripts/ct_astra_unet_train_foam2.py @@ -104,21 +104,16 @@ """ workdir = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "unet_ct_out") train_conf["workdir"] = workdir -print(f"{'JAX process: '}{jax.process_index()}{' / '}{jax.process_count()}") -print(f"{'JAX local devices: '}{jax.local_devices()}") +print(f"\nJAX process: {jax.process_index()}{' / '}{jax.process_count()}") +print(f"JAX local devices: {jax.local_devices()}\n") - -# Construct training object trainer = sflax.BasicFlaxTrainer( train_conf, model, train_ds, test_ds, ) - -start_time = time() modvar, stats_object = trainer.train() -time_train = time() - start_time """ @@ -144,7 +139,7 @@ psnr_eval = metric.psnr(test_ds["label"][:maxn], output) print( f"{'UNet training':15s}{'epochs:':2s}{train_conf['num_epochs']:>5d}" - f"{'':21s}{'time[s]:':10s}{time_train:>7.2f}" + f"{'':21s}{'time[s]:':10s}{trainer.train_time:>7.2f}" ) print( f"{'UNet testing':15s}{'SNR:':5s}{snr_eval:>5.2f}{' dB'}{'':3s}" diff --git a/examples/scripts/deconv_datagen_bsds.py b/examples/scripts/deconv_datagen_bsds.py index 660a6a511..785e4540a 100644 --- a/examples/scripts/deconv_datagen_bsds.py +++ b/examples/scripts/deconv_datagen_bsds.py @@ -30,7 +30,6 @@ blur_sigma = 5 # Gaussian blur kernel parameter opBlur = PaddedCircularConvolve(output_size, channels, blur_shape, blur_sigma) - opBlur_vmap = vmap(opBlur) # for batch processing @@ -47,7 +46,6 @@ stride = 100 # stride to sample multiple patches from each image augment = True # augment data via rotations and flips - train_ds, test_ds = load_image_data( train_nimg, test_nimg, diff --git a/examples/scripts/deconv_modl_train_foam1.py b/examples/scripts/deconv_modl_train_foam1.py index 025779e8d..2916d7a1f 100644 --- a/examples/scripts/deconv_modl_train_foam1.py +++ b/examples/scripts/deconv_modl_train_foam1.py @@ -77,7 +77,6 @@ ishape = (output_size, output_size) opBlur = CircularConvolve(h=psf, input_shape=ishape) - opBlur_vmap = jax.vmap(opBlur) # for batch processing in data generation @@ -133,7 +132,7 @@ """ -Construct functionality for making sure that the learned regularization +Construct functionality for ensuring that the learned regularization parameter is always positive. """ lmbdatrav = construct_traversal("lmbda") # select lmbda parameters in model @@ -147,8 +146,8 @@ """ Print configuration of distributed run. """ -print(f"{'JAX process: '}{jax.process_index()}{' / '}{jax.process_count()}") -print(f"{'JAX local devices: '}{jax.local_devices()}") +print(f"\nJAX process: {jax.process_index()}{' / '}{jax.process_count()}") +print(f"JAX local devices: {jax.local_devices()}\n") """ @@ -204,9 +203,8 @@ cg_iter=model_conf["cg_iter"], ) # First stage: initialization training loop. 
- workdir = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "modl_dcnv_out") - - train_conf["workdir"] = workdir + workdir1 = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "modl_dcnv_out") + train_conf["workdir"] = workdir1 train_conf["post_lst"] = [lmbdapos] # Construct training object trainer = sflax.BasicFlaxTrainer( diff --git a/examples/scripts/deconv_odp_train_foam1.py b/examples/scripts/deconv_odp_train_foam1.py index 517f1bf3f..ffe852e7a 100644 --- a/examples/scripts/deconv_odp_train_foam1.py +++ b/examples/scripts/deconv_odp_train_foam1.py @@ -85,7 +85,6 @@ ishape = (output_size, output_size) opBlur = CircularConvolve(h=psf, input_shape=ishape) - opBlur_vmap = jax.vmap(opBlur) # for batch processing in data generation @@ -153,7 +152,7 @@ """ -Construct functionality for making sure that the learned fidelity weight +Construct functionality for ensuring that the learned fidelity weight parameter is always positive. """ alphatrav = construct_traversal("alpha") # select alpha parameters in model @@ -167,10 +166,10 @@ """ Run training loop. """ -workdir = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "odp_dcnv_out") -print(f"{'JAX process: '}{jax.process_index()}{' / '}{jax.process_count()}") -print(f"{'JAX local devices: '}{jax.local_devices()}") +print(f"\nJAX process: {jax.process_index()}{' / '}{jax.process_count()}") +print(f"JAX local devices: {jax.local_devices()}\n") +workdir = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "odp_dcnv_out") train_conf["workdir"] = workdir train_conf["post_lst"] = [alphapos] # Construct training object @@ -180,10 +179,7 @@ train_ds, test_ds, ) - -start_time = time() modvar, stats_object = trainer.train() -time_train = time() - start_time """ @@ -210,7 +206,7 @@ psnr_eval = metric.psnr(test_ds["label"][:maxn], output) print( f"{'ODPNet training':18s}{'epochs:':2s}{train_conf['num_epochs']:>5d}" - f"{'':21s}{'time[s]:':10s}{time_train:>7.2f}" + f"{'':21s}{'time[s]:':10s}{trainer.train_time:>7.2f}" ) print( f"{'ODPNet testing':18s}{'SNR:':5s}{snr_eval:>5.2f}{' dB'}{'':3s}" diff --git a/examples/scripts/denoise_dncnn_train_bsds.py b/examples/scripts/denoise_dncnn_train_bsds.py index 6e58ecb9b..a55df76d0 100644 --- a/examples/scripts/denoise_dncnn_train_bsds.py +++ b/examples/scripts/denoise_dncnn_train_bsds.py @@ -48,7 +48,6 @@ noise_range = False # Use fixed noise level stride = 23 # Stride to sample multiple patches from each image - train_ds, test_ds = load_image_data( train_nimg, test_nimg, @@ -105,8 +104,7 @@ """ workdir = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "dncnn_out") train_conf["workdir"] = workdir -print(f"{'JAX process: '}{jax.process_index()}{' / '}{jax.process_count()}") -print(f"{'JAX local devices: '}{jax.local_devices()}") +print(f"\nJAX local devices: {jax.local_devices()}\n") trainer = sflax.BasicFlaxTrainer( train_conf, @@ -114,10 +112,7 @@ train_ds, test_ds, ) - -start_time = time() modvar, stats_object = trainer.train() -time_train = time() - start_time """ @@ -138,7 +133,7 @@ psnr_eval = metric.psnr(test_ds["label"][:test_patches], output) print( f"{'DnCNNNet training':18s}{'epochs:':2s}{train_conf['num_epochs']:>5d}" - f"{'':21s}{'time[s]:':10s}{time_train:>7.2f}" + f"{'':21s}{'time[s]:':10s}{trainer.train_time:>7.2f}" ) print( f"{'DnCNNNet testing':18s}{'SNR:':5s}{snr_eval:>5.2f}{' dB'}{'':3s}" @@ -147,8 +142,8 @@ """ -Plot comparison. 
Note that patches have small sizes, thus, plots may
-correspond to unidentifiable fragments.
+Plot comparison. Note that plots may display unidentifiable image
+fragments due to the small patch size.
 """
 np.random.seed(123)
 indx = np.random.randint(0, high=test_patches)
diff --git a/misc/conda/make_conda_env.sh b/misc/conda/make_conda_env.sh
index e424df832..34b48c615 100755
--- a/misc/conda/make_conda_env.sh
+++ b/misc/conda/make_conda_env.sh
@@ -216,9 +216,8 @@ conda create $CONDA_FLAGS -n $ENVNM python=$PYVER
 eval "$(conda shell.bash hook)"  # required to avoid errors re: `conda init`
 conda activate $ENVNM  # Q: why not `source activate`? A: not always in the path
 
-# Add conda-forge and astra-toolbox channels
+# Add conda-forge channel
 conda config --env --append channels conda-forge
-conda config --env --append channels astra-toolbox
 
 # Install mamba
 conda install mamba -n base -c conda-forge
diff --git a/scico/denoiser.py b/scico/denoiser.py
index 574aa2596..11a313ee3 100644
--- a/scico/denoiser.py
+++ b/scico/denoiser.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright (C) 2020-2023 by SCICO Developers
+# Copyright (C) 2020-2024 by SCICO Developers
 # All rights reserved. BSD 3-clause License.
 # This file is part of the SCICO package. Details of the copyright and
 # user license can be found in the 'LICENSE' file distributed with the
@@ -34,7 +34,7 @@
 
 import scico.numpy as snp
 from scico.data import _flax_data_path
-from scico.flax import DnCNNNet, FlaxMap, load_weights
+from scico.flax import DnCNNNet, FlaxMap, load_variables
 
 
 def bm3d(x: snp.Array, sigma: float, is_rgb: bool = False, profile: Union[BM3DProfile, str] = "np"):
@@ -234,7 +234,7 @@ def __init__(self, variant: str = "6M"):
             self.is_blind = True
 
         model = DnCNNNet(depth=nlayer, channels=channels, num_filters=64, dtype=np.float32)
-        variables = load_weights(_flax_data_path("dncnn%s.npz" % variant))
+        variables = load_variables(_flax_data_path("dncnn%s.mpk" % variant))
         super().__init__(model, variables)
 
     def __call__(self, x: snp.Array, sigma: Optional[float] = None) -> snp.Array:
diff --git a/scico/flax/__init__.py b/scico/flax/__init__.py
index d4dce72d2..c141ae47c 100644
--- a/scico/flax/__init__.py
+++ b/scico/flax/__init__.py
@@ -1,16 +1,42 @@
 # -*- coding: utf-8 -*-
-# Copyright (C) 2021-2022 by SCICO Developers
+# Copyright (C) 2021-2024 by SCICO Developers
 # All rights reserved. BSD 3-clause License.
 # This file is part of the SCICO package. Details of the copyright and
 # user license can be found in the 'LICENSE' file distributed with the
 # package.
 
-"""Neural network models implemented in Flax and utility functions."""
+"""Neural network models implemented in `Flax `_ and utility functions.
+
+Many of the function and parameter names used in this sub-package are
+based on the somewhat non-standard Flax terminology for neural network
+components:
+
+`model`
+  The model is an abstract representation of the network structure that
+  does not include specific weight values.
+
+`parameters`
+  The parameters of a model are the weights of the network represented
+  by the model.
+
+`variables`
+  The variables encompass both the parameters (i.e. network weights)
+  and secondary values that are set from training data, such as
+  layer-dependent statistics used in batch normalization.
+
+`state`
+  The state encompasses both the model parameters and the optimizer
+  parameters involved in training that model. Storing the state rather
+  than just the variables enables a warm start for additional training.
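+
+As a rough illustration of these distinctions (a sketch only; the
+names below are illustrative):
+
+::
+
+    model = DnCNNNet(depth=6, channels=1, num_filters=64)  # model: structure only
+    variables = model.init(key, x)  # variables: parameters plus batch stats
+    params = variables["params"]    # parameters: network weights alone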
+
+|
+"""
 
 import sys
 
 # isort: off
-from ._flax import FlaxMap, load_weights, save_weights
+from ._flax import FlaxMap, load_variables, save_variables
 from ._models import ConvBNNet, DnCNNNet, ResNet, UNet
 from .inverse import MoDLNet, ODPNet
 from .train.input_pipeline import create_input_iter
@@ -21,8 +47,8 @@
 
 __all__ = [
     "FlaxMap",
-    "load_weights",
-    "save_weights",
+    "load_variables",
+    "save_variables",
     "ConvBNNet",
     "DnCNNNet",
     "ResNet",
diff --git a/scico/flax/_flax.py b/scico/flax/_flax.py
index 7af04ce93..7838ced51 100644
--- a/scico/flax/_flax.py
+++ b/scico/flax/_flax.py
@@ -20,15 +20,15 @@
 PyTree = Any
 
 
-def load_weights(filename: str) -> PyTree:
-    """Load trained model weights.
+def load_variables(filename: str) -> PyTree:
+    """Load trained model variables.
 
     Args:
-        filename: Name of file containing parameters for trained model.
+        filename: Name of file containing trained model variables.
 
     Returns:
-        A tree-like structure containing the values of the parameters of
-        the model.
+        A tree-like structure containing the values of the model
+        variables.
     """
     with open(filename, "rb") as data_file:
         bytes_input = data_file.read()
@@ -40,12 +40,13 @@
     return var_in
 
 
-def save_weights(variables: PyTree, filename: str):
-    """Save trained model weights.
+def save_variables(variables: PyTree, filename: str):
+    """Save trained model variables.
 
     Args:
-        filename: Name of file to save parameters of trained model.
-        variables: Parameters of model to save.
+        variables: Model variables to save.
+        filename: Name of file to which model variables should be
+            saved.
     """
 
     bytes_output = serialization.msgpack_serialize(variables)
diff --git a/scico/flax/_models.py b/scico/flax/_models.py
index 74fae83f1..c59c8c7f6 100644
--- a/scico/flax/_models.py
+++ b/scico/flax/_models.py
@@ -48,10 +48,11 @@ class DnCNNNet(Module):
         depth: Number of layers in the neural network.
         channels: Number of channels of input tensor.
         num_filters: Number of filters in the convolutional layers.
-        kernel_size: Size of the convolution filters. Default: (3, 3).
-        strides: Convolution strides. Default: (1, 1).
+        kernel_size: Size of the convolution filters.
+        strides: Convolution strides.
         dtype: Output dtype. Default: :attr:`~numpy.float32`.
-        act: Class of activation function to apply. Default: `nn.relu`.
+        act: Class of activation function to apply. Default:
+            :func:`~flax.linen.activation.relu`.
     """
 
     depth: int
@@ -130,8 +131,8 @@ class ResNet(Module):
         num_filters: Number of filters in the layers of the block.
             Corresponds to the number of channels in the network
             processing.
-        kernel_size: Size of the convolution filters. Default: 3x3.
-        strides: Convolution strides. Default: 1x1.
+        kernel_size: Size of the convolution filters.
+        strides: Convolution strides.
         dtype: Output dtype. Default: :attr:`~numpy.float32`.
     """
 
@@ -204,8 +205,8 @@ class ConvBNNet(Module):
         num_filters: Number of filters in the layers of the block.
             Corresponds to the number of channels in the network
             processing.
-        kernel_size: Size of the convolution filters. Default: 3x3.
-        strides: Convolution strides. Default: 1x1.
+        kernel_size: Size of the convolution filters.
+        strides: Convolution strides.
         dtype: Output dtype. Default: :attr:`~numpy.float32`.
     """
 
@@ -272,12 +273,11 @@ class UNet(Module):
         num_filters: Number of filters in the convolutional layer of
             the block. Corresponds to the number of channels in the
             network processing.
-        kernel_size: Size of the convolution filters. Default: 3x3.
-        strides: Convolution strides. Default: 1x1.
-        block_depth: Number of processing layers per block. Default: 2.
+        kernel_size: Size of the convolution filters.
+        strides: Convolution strides.
+        block_depth: Number of processing layers per block.
         window_shape: Window for reduction for pooling and downsampling.
-            Default: 2x2.
-        upsampling: Factor for expanding. Default: 2.
+        upsampling: Factor for expanding.
         dtype: Output dtype. Default: :attr:`~numpy.float32`.
     """
 
diff --git a/scico/flax/blocks.py b/scico/flax/blocks.py
index 26dfc0f6d..4b0664aac 100644
--- a/scico/flax/blocks.py
+++ b/scico/flax/blocks.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright (C) 2021-2023 by SCICO Developers
+# Copyright (C) 2021-2024 by SCICO Developers
 # All rights reserved. BSD 3-clause License.
 # This file is part of the SCICO package. Details of the copyright and
 # user license can be found in the 'LICENSE' file distributed with the
 # package.
@@ -39,9 +39,9 @@ class ConvBNBlock(Module):
             apply.
         act: Flax function defining the activation operation to apply.
         kernel_size: A shape tuple defining the size of the convolution
-            filters. Default: (3, 3).
+            filters.
         strides: A shape tuple defining the size of strides in
-            convolution. Default: (1, 1).
+            convolution.
     """
 
     num_filters: int
@@ -83,9 +83,9 @@ class ConvBlock(Module):
         conv: Flax module implementing the convolution layer to apply.
         act: Flax function defining the activation operation to apply.
         kernel_size: A shape tuple defining the size of the convolution
-            filters. Default: (3, 3).
+            filters.
         strides: A shape tuple defining the size of strides in
-            convolution. Default: (1, 1).
+            convolution.
     """
 
     num_filters: int
@@ -230,9 +230,9 @@ class ConvBNMultiBlock(Module):
             apply.
         act: Flax function defining the activation operation to apply.
         kernel_size: A shape tuple defining the size of the convolution
-            filters. Default: (3, 3).
+            filters.
         strides: A shape tuple defining the size of strides in
-            convolution. Default: (1, 1).
+            convolution.
     """
 
     num_blocks: int
diff --git a/scico/flax/examples/data_preprocessing.py b/scico/flax/examples/data_preprocessing.py
index 09ba2ffc0..9f1ff89d0 100644
--- a/scico/flax/examples/data_preprocessing.py
+++ b/scico/flax/examples/data_preprocessing.py
@@ -418,7 +418,7 @@ def get_bsds_data(path: str, verbose: bool = False):  # pragma: no cover
         verbose: Flag indicating whether to print status messages.
     """
     # data source URL and filenames
-    data_base_url = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/"
+    data_base_url = "https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/"
     data_tar_file = "BSR_bsds500.tgz"
     # ensure path directory exists
     if not os.path.isdir(path):
diff --git a/scico/flax/examples/examples.py b/scico/flax/examples/examples.py
index 02ee1c658..0bf428a33 100644
--- a/scico/flax/examples/examples.py
+++ b/scico/flax/examples/examples.py
@@ -452,9 +452,9 @@ def load_image_data(
     )
 
     print(
-        "NOTE: If blur kernel or noise parameter are changed, the cache"
-        " must be manually deleted to ensure that the training data "
-        " is regenerated with these new parameters."
+        "NOTE: If the blur kernel or noise parameter is changed, the cache "
+        "must be manually\n deleted to ensure that the training data"
+        " is regenerated with the new\n parameters."
) return train_ds, test_ds diff --git a/scico/flax/inverse.py b/scico/flax/inverse.py index 96e34d087..d6f4846b6 100644 --- a/scico/flax/inverse.py +++ b/scico/flax/inverse.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2022-2023 by SCICO Developers +# Copyright (C) 2022-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -40,18 +40,17 @@ class MoDLNet(Module): Args: operator: Operator for computing forward and adjoint mappings. - depth: Depth of MoDL net. Default: 1. + depth: Depth of MoDL net. channels: Number of channels of input tensor. num_filters: Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the output tensor. block_depth: Number of layers in the computational block. - kernel_size: Size of the convolution filters. Default: (3, 3). - strides: Convolution strides. Default: (1, 1). + kernel_size: Size of the convolution filters. + strides: Convolution strides. lmbda_ini: Initial value of the regularization weight `lambda`. - Default: 0.5. dtype: Output dtype. Default: :attr:`~numpy.float32`. - cg_iter: Number of iterations for cg solver. Default: 10. + cg_iter: Number of iterations for cg solver. """ operator: ModuleDef @@ -105,8 +104,7 @@ def lmbda_init_wrap(rng: PRNGKey, shape: Shape, dtype: DType = self.dtype) -> Ar for i in range(self.depth): z = resnet(x, train) - # Solve: - # (AH A + lmbda I) x = Ahb + lmbda * z + # Solve: (AH A + lmbda I) x = Ahb + lmbda * z b = Ahb + lmbda * z x = lax.map(cgsol, b) return x @@ -130,8 +128,8 @@ def cg_solver(A: Callable, b: Array, x0: Array = None, maxiter: int = 50) -> Arr A: Function implementing linear operator :math:`A`, should be positive definite. b: Input array :math:`\mb{b}`. - x0: Initial solution. Default: ``None``. - maxiter: Maximum iterations. Default: 50. + x0: Initial solution. + maxiter: Maximum iterations. Returns: x: Solution array. @@ -175,10 +173,9 @@ class ODPProxDnBlock(Module): num_filters: Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the output tensor. - kernel_size: Size of the convolution filters. Default: (3, 3). - strides: Convolution strides. Default: (1, 1). + kernel_size: Size of the convolution filters. + strides: Convolution strides. alpha_ini: Initial value of the fidelity weight `alpha`. - Default: 0.2. dtype: Output dtype. Default: :attr:`~numpy.float32`. """ @@ -243,10 +240,9 @@ class ODPProxDcnvBlock(Module): num_filters: Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the output tensor. - kernel_size: Size of the convolution filters. Default: (3, 3). - strides: Convolution strides. Default: (1, 1). + kernel_size: Size of the convolution filters. + strides: Convolution strides. alpha_ini: Initial value of the fidelity weight `alpha`. - Default: 0.99. dtype: Output dtype. Default: :attr:`~numpy.float32`. """ @@ -335,10 +331,9 @@ class ODPGrDescBlock(Module): num_filters: Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the output tensor. - kernel_size: Size of the convolution filters. Default: (3, 3). - strides: Convolution strides. Default: (1, 1). + kernel_size: Size of the convolution filters. + strides: Convolution strides. alpha_ini: Initial value of the fidelity weight `alpha`. - Default: 0.2. dtype: Output dtype. 
Default: :attr:`~numpy.float32`.
     """
 
     operator: ModuleDef
@@ -403,19 +398,18 @@ class ODPNet(Module):
 
     Args:
         operator: Operator for computing forward and adjoint mappings.
-        depth: Depth of MoDL net. Default: 1.
+        depth: Depth of ODP net.
         channels: Number of channels of input tensor.
         num_filters: Number of filters in the convolutional layer of
            the block. Corresponds to the number of channels in the
            output tensor.
         block_depth: Number of layers in the computational block.
-        kernel_size: Size of the convolution filters. Default: (3, 3).
-        strides: Convolution strides. Default: (1, 1).
+        kernel_size: Size of the convolution filters.
+        strides: Convolution strides.
         alpha_ini: Initial value of the fidelity weight `alpha`.
-            Default: 0.5.
         dtype: Output dtype. Default: :attr:`~numpy.float32`.
-        odp_block: processing block to apply. Default
-            :class:`ODPProxDnBlock`.
+        odp_block: Processing block to apply. Default:
+            :class:`.ODPProxDnBlock`.
     """
 
     operator: ModuleDef
diff --git a/scico/flax/train/clu_utils.py b/scico/flax/train/clu_utils.py
index 59b0c575c..78cd8ec9b 100644
--- a/scico/flax/train/clu_utils.py
+++ b/scico/flax/train/clu_utils.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright (C) 2022-2023 by SCICO Developers
+# Copyright (C) 2022-2024 by SCICO Developers
 # All rights reserved. BSD 3-clause License.
 # This file is part of the SCICO package. Details of the copyright and
 # user license can be found in the 'LICENSE' file distributed with the
@@ -8,13 +8,11 @@
 """Utilities for displaying Flax models."""
 
 # These utilities have been copied from the Common Loop Utils (CLU)
-# https://github.com/google/CommonLoopUtils/tree/main/clu
+# https://github.com/google/CommonLoopUtils/tree/main/clu
 # and have been modified to remove TensorFlow dependencies
-
-# CLU is licensed under the Apache License, Version 2.0 (the "License");
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
+# CLU is licensed under the Apache License, Version 2.0, which may
+# be obtained from
+# http://www.apache.org/licenses/LICENSE-2.0
 
 import warnings
@@ -129,7 +127,7 @@ def make_row(name, value):
 
 
 def _default_table_value_formatter(value):
-    """Format ints with "," between thousands and floats to 3 digits."""
+    """Format ints with "," between thousands, and floats to 3 digits."""
     if isinstance(value, bool):
         return str(value)
     elif isinstance(value, int):
@@ -159,7 +157,7 @@ def make_table(
         max_lines: Don't render a table longer than this.
 
     Returns:
-        A string representation of the table as in the example below.
+        A string representation of a table as in the example below.
 
     ::
 
@@ -234,8 +232,7 @@ def get_parameter_overview(
       | FC_2/weights:0 | (1024, 32)    |     32,768 |
      | FC_2/biases:0  | (32,)         |         32 |
       +----------------+---------------+------------+
-
-      Total: 65,172,512
+      Total weights: 65,172,512
     """
     if isinstance(params, (dict, flax.core.FrozenDict)):
         params = jax.tree_util.tree_map(np.asarray, params)
@@ -245,4 +242,4 @@
     # Pass in `column_names` to enable rendering empty tables.
column_names = [field.name for field in dataclasses.fields(RowType)] table = make_table(rows, max_lines=max_lines, column_names=column_names) - return table + f"\nTotal: {total_weights:,}" + return table + f"\nTotal weights: {total_weights:,}" diff --git a/scico/flax/train/trainer.py b/scico/flax/train/trainer.py index 333a17636..ad569f20e 100644 --- a/scico/flax/train/trainer.py +++ b/scico/flax/train/trainer.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2022-2023 by SCICO Developers +# Copyright (C) 2022-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -56,7 +56,7 @@ def sync_batch_stats(state: TrainState) -> TrainState: class BasicFlaxTrainer: - """Class for encapsulating Flax training configuration and execution.""" + """Class encapsulating Flax training configuration and execution.""" def __init__( self, @@ -82,7 +82,6 @@ def __init__( test_ds: Dictionary of testing data (includes images and labels). variables0: Optional initial state of model parameters. - Default: ``None``. """ # Configure seed if "seed" not in config: @@ -331,8 +330,8 @@ def construct_data_iterators( self.ishape = train_ds["image"].shape[1:3] self.log( - "Channels: %d, training signals: %d, testing" - " signals: %d, signal size: %d" + "channels: %d training signals: %d testing" + " signals: %d signal size: %d\n" % ( train_ds["label"].shape[-1], train_ds["label"].shape[0], @@ -393,7 +392,6 @@ def initialize_training_state( key: A PRNGKey used as the random key. model: Flax model to train. variables0: Optional initial state of model parameters. - Default: ``None``. """ # Create Flax training state state = self.create_train_state( @@ -404,9 +402,11 @@ def initialize_training_state( ok_no_ckpt = True # It is ok if no checkpoint is found state = checkpoint_restore(state, self.workdir, ok_no_ckpt) - self.log(get_parameter_overview(state.params)) + self.log("Network Structure:") + self.log(get_parameter_overview(state.params) + "\n") if hasattr(state, "batch_stats"): - self.log(get_parameter_overview(state.batch_stats)) + self.log("Batch Normalization:") + self.log(get_parameter_overview(state.batch_stats) + "\n") self.state = state @@ -414,12 +414,12 @@ def train(self) -> Tuple[Dict[str, Any], Optional[IterationStats]]: """Execute training loop. Returns: - Model variables extracted from TrainState and iteration - stats object obtained after executing the training loop. - Alternatively the TrainState can be returned directly instead - of the model variables. Note that the iteration stats object - is not ``None`` only if log is enabled when configuring the - training loop. + Model variables extracted from :class:`.TrainState` and + iteration stats object obtained after executing the training + loop. Alternatively the :class:`.TrainState` can be returned + directly instead of the model variables. Note that the + iteration stats object is not ``None`` only if log is enabled + when configuring the training loop. 
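+
+        A minimal usage sketch (assuming a populated configuration
+        dictionary and datasets structured as in the example
+        scripts)::
+
+            trainer = BasicFlaxTrainer(config, model, train_ds, test_ds)
+            modvar, stats_object = trainer.train()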
""" state = self.state step_offset = int(state.step) # > 0 if restarting from checkpoint @@ -428,7 +428,7 @@ def train(self) -> Tuple[Dict[str, Any], Optional[IterationStats]]: state = jax_utils.replicate(state) # Execute training loop and register stats t0 = time.time() - self.log("Initial compilation, this might take some minutes...") + self.log("Initial compilation, which might take some time ...") train_metrics: List[Any] = [] @@ -437,7 +437,7 @@ def train(self) -> Tuple[Dict[str, Any], Optional[IterationStats]]: # Training metrics computed in step train_metrics.append(metrics) if step == step_offset: - self.log("Initial compilation completed.") + self.log("Initial compilation completed.\n") if (step + 1) % self.log_every_steps == 0: # sync batch statistics across replicas state = sync_batch_stats(state) @@ -468,6 +468,8 @@ def train(self) -> Tuple[Dict[str, Any], Optional[IterationStats]]: "batch_stats": state.batch_stats, } + self.train_time = time.time() - t0 + return dvar, self.itstat_object # type: ignore def update_metrics(self, state: TrainState, step: int, train_metrics: List[MetricsDict], t0): diff --git a/scico/flax/train/typed_dict.py b/scico/flax/train/typed_dict.py index 78c69283a..c025db698 100644 --- a/scico/flax/train/typed_dict.py +++ b/scico/flax/train/typed_dict.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2022-2023 by SCICO Developers +# Copyright (C) 2022-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -31,7 +31,7 @@ class DataSetDict(TypedDict): class ConfigDict(TypedDict): - """Dictionary structure for training parmeters. + """Dictionary structure for training parameters. Definition of the dictionary structure expected for specifying training parameters.""" diff --git a/scico/ray/__init__.py b/scico/ray/__init__.py index c7beb1d5e..9228f96b7 100644 --- a/scico/ray/__init__.py +++ b/scico/ray/__init__.py @@ -1,11 +1,11 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2022-2023 by SCICO Developers +# Copyright (C) 2022-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. -"""Simplified interfaces to :doc:`ray `.""" +"""Simplified interfaces to :doc:`Ray `.""" try: diff --git a/scico/scipy/__init__.py b/scico/scipy/__init__.py index 77bb4670e..2291ff85d 100644 --- a/scico/scipy/__init__.py +++ b/scico/scipy/__init__.py @@ -1,11 +1,11 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2021 by SCICO Developers +# Copyright (C) 2021-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the # package. -"""Wrapped versions of :mod:`jax.scipy` functions. +"""Wrapped versions of `jax.scipy `_ functions. This modules currently serves simply as a namespace for :mod:`scico.scipy.special`. 
""" diff --git a/scico/test/flax/test_clu.py b/scico/test/flax/test_clu.py index 366926055..346ecc364 100644 --- a/scico/test/flax/test_clu.py +++ b/scico/test/flax/test_clu.py @@ -60,7 +60,7 @@ def test_count_parameters_empty(): | Name | Shape | Size | Mean | Std | +------+-------+------+------+-----+ +------+-------+------+------+-----+ -Total: 0""" +Total weights: 0""" FLAX_CONV2D_PARAMETER_OVERVIEW = """+-------------+--------------+------+ | Name | Shape | Size | @@ -68,7 +68,7 @@ def test_count_parameters_empty(): | conv/bias | (2,) | 2 | | conv/kernel | (3, 3, 3, 2) | 54 | +-------------+--------------+------+ -Total: 56""" +Total weights: 56""" FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_STATS = """+-------------+--------------+------+------+-----+ | Name | Shape | Size | Mean | Std | @@ -76,7 +76,7 @@ def test_count_parameters_empty(): | conv/bias | (2,) | 2 | 1.0 | 0.0 | | conv/kernel | (3, 3, 3, 2) | 54 | 1.0 | 0.0 | +-------------+--------------+------+------+-----+ -Total: 56""" +Total weights: 56""" FLAX_CONV2D_MAPPING_PARAMETER_OVERVIEW_WITH_STATS = """+--------------------+--------------+------+------+-----+ | Name | Shape | Size | Mean | Std | @@ -84,7 +84,7 @@ def test_count_parameters_empty(): | params/conv/bias | (2,) | 2 | 1.0 | 0.0 | | params/conv/kernel | (3, 3, 3, 2) | 54 | 1.0 | 0.0 | +--------------------+--------------+------+------+-----+ -Total: 56""" +Total weights: 56""" # From https://github.com/google/CommonLoopUtils/blob/main/clu/parameter_overview_test.py def test_get_parameter_overview_empty(): diff --git a/scico/test/flax/test_flax.py b/scico/test/flax/test_flax.py index 0bca9a83b..c036527d9 100644 --- a/scico/test/flax/test_flax.py +++ b/scico/test/flax/test_flax.py @@ -310,7 +310,7 @@ def test_variable_load(variant): model = sflax.DnCNNNet(depth=nlayer, channels=chn, num_filters=64, dtype=np.float32) # Load weights for DnCNN. - variables = sflax.load_weights(_flax_data_path("dncnn%s.npz" % variant)) + variables = sflax.load_variables(_flax_data_path("dncnn%s.mpk" % variant)) try: fmap = sflax.FlaxMap(model, variables) @@ -328,7 +328,7 @@ def test_variable_load_mismatch(): nlayer = 6 model = sflax.ResNet(depth=nlayer, channels=chn, num_filters=64, dtype=np.float32) # Load weights for DnCNN. - variables = sflax.load_weights(_flax_data_path("dncnn6L.npz")) + variables = sflax.load_variables(_flax_data_path("dncnn6L.mpk")) # created with mismatched parameters fmap = sflax.FlaxMap(model, variables) @@ -350,7 +350,7 @@ def test_variable_save(): try: temp_dir = tempfile.TemporaryDirectory() - sflax.save_weights(unfreeze(variables), os.path.join(temp_dir.name, "vres6.npz")) + sflax.save_variables(unfreeze(variables), os.path.join(temp_dir.name, "vres6.mpk")) except Exception as e: print(e) assert 0