Various improvements to scico.flax and related example scripts (#498)
* Rename flax data files

* Minor docstring improvements

* Minor docstring improvement

* Remove astra-toolbox channel

* Minor improvement

* Add timer to trainer class

* Minor clean up

* Minor change to log format

* Improve log formatting

* Improve log formatting

* Update submodule

* Typo fix

* Improve log format

* Docs fixes

* Docs consistency

* Minor edit

* Docs consistency

* Fix broken cross-references

* Docs consistency

* Improve docs

* Rename functions

* Update function docs

* Update URL

* Clean up log format

* Update submodule

* Fix overly simple regex

* Trivial edit

* Clean up some scripts

* Update submodule

* Overlooked change from recent astra PR

* Add note on GPU support test script

* Overlooked change from recent astra PR

* Add script for removing error output from notebooks

* Update submodule

* Fix tests

* Fix tests

* Update submodule

---------

Co-authored-by: Brendt Wohlberg <brendt@lanl.gov>
bwohlberg and Brendt Wohlberg authored Feb 5, 2024
1 parent: c483304, commit: 13ef0a8
Showing 28 changed files with 209 additions and 163 deletions.
CHANGES.rst (3 additions, 0 deletions)

@@ -7,6 +7,9 @@ Version 0.0.6 (unreleased)
 ----------------------------
 
 • Significant changes to ``linop.xray.astra`` API.
+• Rename ``scico.flax.save_weights`` and ``scico.flax.load_weights`` to
+  ``scico.flax.save_variables`` and ``scico.flax.load_variables``
+  respectively.



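For downstream code, the rename appears to be a drop-in substitution. A minimal migration sketch (only the function names are confirmed by this diff; the argument lists below are illustrative assumptions):

# Before (removed names):
#   from scico.flax import save_weights, load_weights
# After (new names):
from scico.flax import load_variables, save_variables

save_variables(variables, workdir)   # `variables` and `workdir` are placeholder names
variables = load_variables(workdir)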
docs/source/install.rst (5 additions, 0 deletions)

@@ -130,6 +130,11 @@ a version with GPU support:
 numbers.
 
+
+The script `misc/envinfo.py <https://github.com/lanl/scico/blob/main/misc/envinfo.py>`_
+in the source distribution is provided as an aid to debugging GPU support
+issues.
+
 
 Additional Dependencies
 -----------------------
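The contents of misc/envinfo.py are not part of this diff, but a first GPU-support check of the kind such a script performs can be sketched in a few lines (an assumption-laden illustration, not the actual script):

import jax

# On a working GPU install this lists GPU devices; on a CPU-only
# install it reports only CPU devices.
print(jax.devices())
print(jax.default_backend())  # "gpu" when JAX selected a GPU backend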
examples/jnb.py (24 additions, 2 deletions)

@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright (C) 2022-2023 by SCICO Developers
+# Copyright (C) 2022-2024 by SCICO Developers
 # All rights reserved. BSD 3-clause License.
 # This file is part of the SCICO package. Details of the copyright and
 # user license can be found in the 'LICENSE' file distributed with the
@@ -44,7 +44,7 @@ def py_file_to_string(src):
         else:
             # Set flag indicating that an import statement has been seen once one has
             # been encountered
-            if re.match("^(import|from)", line):
+            if re.match("^import|^from .* import", line):
                 import_seen = True
             lines.append(line)
     # Backtrack through list of lines to find last import statement
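The regex fix above addresses a false positive: the old pattern treated any line whose first characters were ``import`` or ``from`` as an import statement. A standalone comparison of the two patterns (``from_config`` and ``get_config`` are hypothetical names, used only to trigger the false positive):

import re

old = "^(import|from)"
new = "^import|^from .* import"

for line in [
    "import os",                   # real import: both patterns match
    "from scico import plot",      # real import: both patterns match
    "from_config = get_config()",  # not an import: only the old pattern matches
]:
    print(f"{line!r}: old={bool(re.match(old, line))} new={bool(re.match(new, line))}")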
@@ -221,3 +221,25 @@ def replace_markdown_cells(src, dst):
         # the dst cell
         if srccell[n]["cell_type"] == "markdown":
             dstcell[n]["source"] = srccell[n]["source"]
+
+
+def remove_error_output(src):
+    """Remove output to stderr from all cells in `src`."""
+
+    if "cells" in src:
+        cells = src["cells"]
+    else:
+        cells = src["worksheets"][0]["cells"]
+
+    modified = False
+    for c in cells:
+        if "outputs" in c:
+            dellist = []
+            for n, out in enumerate(c["outputs"]):
+                if "name" in out and out["name"] == "stderr":
+                    dellist.append(n)
+                    modified = True
+            for n in dellist[::-1]:
+                del c["outputs"][n]
+
+    return modified
examples/removejnberr.py (18 additions, 0 deletions, new file)

@@ -0,0 +1,18 @@
+#!/usr/bin/env python
+
+# Remove output to stderr in notebooks. NB: use with caution!
+# Run as
+#     python removejnberr.py
+
+import glob
+import os
+
+from jnb import read_notebook, remove_error_output
+from py2jn.tools import write_notebook
+
+for src in glob.glob(os.path.join("notebooks", "*.ipynb")):
+    nb = read_notebook(src)
+    modflg = remove_error_output(nb)
+    if modflg:
+        print(f"Removing output to stderr from {src}")
+        write_notebook(nb, src)
examples/scripts/ct_astra_modl_train_foam2.py (8 additions, 9 deletions)

@@ -54,7 +54,7 @@
 from scico import metric, plot
 from scico.flax.examples import load_ct_data
 from scico.flax.train.traversals import clip_positive, construct_traversal
-from scico.linop.xray.astra import XRayTransform
+from scico.linop.xray.astra import XRayTransform2D
 
 """
 Prepare parallel processing. Set an arbitrary processor count (only
@@ -81,9 +81,9 @@
 Build CT projection operator.
 """
 angles = np.linspace(0, np.pi, n_projection)  # evenly spaced projection angles
-A = XRayTransform(
+A = XRayTransform2D(
     input_shape=(N, N),
-    detector_spacing=1,
+    det_spacing=1,
     det_count=N,
     angles=angles,
 )  # CT projection operator
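For scripts outside this repository that still use the old wrapper, the hunk above implies a small migration, sketched below using only what this diff shows (``N`` and the angle count are hypothetical values; the broader ``linop.xray.astra`` API changes noted in CHANGES.rst may also apply):

import numpy as np

from scico.linop.xray.astra import XRayTransform2D  # formerly XRayTransform

N = 256  # hypothetical image size
angles = np.linspace(0, np.pi, 45)  # hypothetical projection angles

A = XRayTransform2D(
    input_shape=(N, N),
    det_spacing=1,  # renamed from detector_spacing
    det_count=N,
    angles=angles,
)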
@@ -138,7 +138,7 @@


"""
Construct functionality for making sure that the learned
Construct functionality for ensuring that the learned
regularization parameter is always positive.
"""
lmbdatrav = construct_traversal("lmbda") # select lmbda parameters in model
@@ -152,8 +152,8 @@
"""
Print configuration of distributed run.
"""
print(f"{'JAX process: '}{jax.process_index()}{' / '}{jax.process_count()}")
print(f"{'JAX local devices: '}{jax.local_devices()}")
print(f"\nJAX process: {jax.process_index()}{' / '}{jax.process_count()}")
print(f"JAX local devices: {jax.local_devices()}\n")


"""
@@ -212,9 +212,8 @@
     cg_iter=model_conf["cg_iter_1"],
 )
 # First stage: initialization training loop.
-workdir = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "modl_ct_out")
-
-train_conf["workdir"] = workdir
+workdir1 = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "modl_ct_out")
+train_conf["workdir"] = workdir1
 train_conf["post_lst"] = [lmbdapos]
 # Construct training object
 trainer = sflax.BasicFlaxTrainer(
examples/scripts/ct_astra_odp_train_foam2.py (8 additions, 10 deletions)

@@ -58,7 +58,7 @@
 from scico import metric, plot
 from scico.flax.examples import load_ct_data
 from scico.flax.train.traversals import clip_positive, construct_traversal
-from scico.linop.xray.astra import XRayTransform
+from scico.linop.xray.astra import XRayTransform2D
 
 """
 Prepare parallel processing. Set an arbitrary processor count (only
@@ -85,9 +85,9 @@
 Build CT projection operator.
 """
 angles = np.linspace(0, np.pi, n_projection)  # evenly spaced projection angles
-A = XRayTransform(
+A = XRayTransform2D(
     input_shape=(N, N),
-    detector_spacing=1,
+    det_spacing=1,
     det_count=N,
     angles=angles,
 )  # CT projection operator
@@ -138,7 +138,7 @@


"""
Construct functionality for making sure that the learned fidelity weight
Construct functionality for ensuring that the learned fidelity weight
parameter is always positive.
"""
alphatrav = construct_traversal("alpha") # select alpha parameters in model
@@ -152,8 +152,8 @@
"""
Print configuration of distributed run.
"""
print(f"{'JAX process: '}{jax.process_index()}{' / '}{jax.process_count()}")
print(f"{'JAX local devices: '}{jax.local_devices()}")
print(f"\nJAX process: {jax.process_index()}{' / '}{jax.process_count()}")
print(f"JAX local devices: {jax.local_devices()}\n")


"""
@@ -185,10 +185,7 @@
     train_ds,
     test_ds,
 )
-
-start_time = time()
 modvar, stats_object = trainer.train()
-time_train = time() - start_time
 
 
 """
@@ -215,13 +212,14 @@
 psnr_eval = metric.psnr(test_ds["label"][:maxn], output)
 print(
     f"{'ODPNet training':18s}{'epochs:':2s}{epochs:>5d}{'':21s}"
-    f"{'time[s]:':10s}{time_train:>7.2f}"
+    f"{'time[s]:':10s}{trainer.train_time:>7.2f}"
 )
 print(
     f"{'ODPNet testing':18s}{'SNR:':5s}{snr_eval:>5.2f}{' dB'}{'':3s}"
     f"{'PSNR:':6s}{psnr_eval:>5.2f}{' dB'}{'':3s}{'time[s]:':10s}{time_eval:>7.2f}"
 )
 
+
 """
 Plot comparison.
 """
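The recurring edit in these training scripts, dropping the manual ``time()`` bracketing around ``trainer.train()``, reflects the timer added to the trainer class (per the commit messages above). A minimal sketch of the new pattern, assuming ``train_conf``, ``model``, ``train_ds``, and ``test_ds`` are constructed as earlier in each script:

from scico import flax as sflax

trainer = sflax.BasicFlaxTrainer(train_conf, model, train_ds, test_ds)
modvar, stats_object = trainer.train()

# The trainer now records its own elapsed training time, so no
# start_time = time() ... time() - start_time bracketing is needed.
print(f"train time [s]: {trainer.train_time:>7.2f}")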
examples/scripts/ct_astra_unet_train_foam2.py (3 additions, 8 deletions)

@@ -104,21 +104,16 @@
"""
workdir = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "unet_ct_out")
train_conf["workdir"] = workdir
print(f"{'JAX process: '}{jax.process_index()}{' / '}{jax.process_count()}")
print(f"{'JAX local devices: '}{jax.local_devices()}")
print(f"\nJAX process: {jax.process_index()}{' / '}{jax.process_count()}")
print(f"JAX local devices: {jax.local_devices()}\n")


# Construct training object
trainer = sflax.BasicFlaxTrainer(
train_conf,
model,
train_ds,
test_ds,
)

start_time = time()
modvar, stats_object = trainer.train()
time_train = time() - start_time


"""
@@ -144,7 +139,7 @@
 psnr_eval = metric.psnr(test_ds["label"][:maxn], output)
 print(
     f"{'UNet training':15s}{'epochs:':2s}{train_conf['num_epochs']:>5d}"
-    f"{'':21s}{'time[s]:':10s}{time_train:>7.2f}"
+    f"{'':21s}{'time[s]:':10s}{trainer.train_time:>7.2f}"
 )
 print(
     f"{'UNet testing':15s}{'SNR:':5s}{snr_eval:>5.2f}{' dB'}{'':3s}"
examples/scripts/deconv_datagen_bsds.py (0 additions, 2 deletions)

@@ -30,7 +30,6 @@
 blur_sigma = 5  # Gaussian blur kernel parameter
 
 opBlur = PaddedCircularConvolve(output_size, channels, blur_shape, blur_sigma)
-
 opBlur_vmap = vmap(opBlur)  # for batch processing


@@ -47,7 +46,6 @@
 stride = 100  # stride to sample multiple patches from each image
 augment = True  # augment data via rotations and flips
 
-
 train_ds, test_ds = load_image_data(
     train_nimg,
     test_nimg,
examples/scripts/deconv_modl_train_foam1.py (5 additions, 7 deletions)

@@ -77,7 +77,6 @@

 ishape = (output_size, output_size)
 opBlur = CircularConvolve(h=psf, input_shape=ishape)
-
 opBlur_vmap = jax.vmap(opBlur)  # for batch processing in data generation


@@ -133,7 +132,7 @@


"""
Construct functionality for making sure that the learned regularization
Construct functionality for ensuring that the learned regularization
parameter is always positive.
"""
lmbdatrav = construct_traversal("lmbda") # select lmbda parameters in model
@@ -147,8 +146,8 @@
"""
Print configuration of distributed run.
"""
print(f"{'JAX process: '}{jax.process_index()}{' / '}{jax.process_count()}")
print(f"{'JAX local devices: '}{jax.local_devices()}")
print(f"\nJAX process: {jax.process_index()}{' / '}{jax.process_count()}")
print(f"JAX local devices: {jax.local_devices()}\n")


"""
@@ -204,9 +203,8 @@
     cg_iter=model_conf["cg_iter"],
 )
 # First stage: initialization training loop.
-workdir = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "modl_dcnv_out")
-
-train_conf["workdir"] = workdir
+workdir1 = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "modl_dcnv_out")
+train_conf["workdir"] = workdir1
 train_conf["post_lst"] = [lmbdapos]
 # Construct training object
 trainer = sflax.BasicFlaxTrainer(
examples/scripts/deconv_odp_train_foam1.py (5 additions, 9 deletions)

@@ -85,7 +85,6 @@

 ishape = (output_size, output_size)
 opBlur = CircularConvolve(h=psf, input_shape=ishape)
-
 opBlur_vmap = jax.vmap(opBlur)  # for batch processing in data generation


@@ -153,7 +152,7 @@


"""
Construct functionality for making sure that the learned fidelity weight
Construct functionality for ensuring that the learned fidelity weight
parameter is always positive.
"""
alphatrav = construct_traversal("alpha") # select alpha parameters in model
@@ -167,10 +166,10 @@
"""
Run training loop.
"""
workdir = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "odp_dcnv_out")
print(f"{'JAX process: '}{jax.process_index()}{' / '}{jax.process_count()}")
print(f"{'JAX local devices: '}{jax.local_devices()}")
print(f"\nJAX process: {jax.process_index()}{' / '}{jax.process_count()}")
print(f"JAX local devices: {jax.local_devices()}\n")

workdir = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "odp_dcnv_out")
train_conf["workdir"] = workdir
train_conf["post_lst"] = [alphapos]
# Construct training object
@@ -180,10 +179,7 @@
     train_ds,
     test_ds,
 )
-
-start_time = time()
 modvar, stats_object = trainer.train()
-time_train = time() - start_time
 
 
 """
@@ -210,7 +206,7 @@
 psnr_eval = metric.psnr(test_ds["label"][:maxn], output)
 print(
     f"{'ODPNet training':18s}{'epochs:':2s}{train_conf['num_epochs']:>5d}"
-    f"{'':21s}{'time[s]:':10s}{time_train:>7.2f}"
+    f"{'':21s}{'time[s]:':10s}{trainer.train_time:>7.2f}"
 )
 print(
     f"{'ODPNet testing':18s}{'SNR:':5s}{snr_eval:>5.2f}{' dB'}{'':3s}"
examples/scripts/denoise_dncnn_train_bsds.py (4 additions, 9 deletions)

@@ -48,7 +48,6 @@
 noise_range = False  # Use fixed noise level
 stride = 23  # Stride to sample multiple patches from each image
 
-
 train_ds, test_ds = load_image_data(
     train_nimg,
     test_nimg,
@@ -105,19 +104,15 @@
"""
workdir = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "dncnn_out")
train_conf["workdir"] = workdir
print(f"{'JAX process: '}{jax.process_index()}{' / '}{jax.process_count()}")
print(f"{'JAX local devices: '}{jax.local_devices()}")
print(f"\nJAX local devices: {jax.local_devices()}\n")

trainer = sflax.BasicFlaxTrainer(
train_conf,
model,
train_ds,
test_ds,
)

start_time = time()
modvar, stats_object = trainer.train()
time_train = time() - start_time


"""
@@ -138,7 +133,7 @@
 psnr_eval = metric.psnr(test_ds["label"][:test_patches], output)
 print(
     f"{'DnCNNNet training':18s}{'epochs:':2s}{train_conf['num_epochs']:>5d}"
-    f"{'':21s}{'time[s]:':10s}{time_train:>7.2f}"
+    f"{'':21s}{'time[s]:':10s}{trainer.train_time:>7.2f}"
 )
 print(
     f"{'DnCNNNet testing':18s}{'SNR:':5s}{snr_eval:>5.2f}{' dB'}{'':3s}"
@@ -147,8 +142,8 @@


"""
Plot comparison. Note that patches have small sizes, thus, plots may
correspond to unidentifiable fragments.
Plot comparison. Note that plots may display unidentifiable image
fragments due to the small patch size.
"""
np.random.seed(123)
indx = np.random.randint(0, high=test_patches)
(Diffs for the remaining changed files of the 28 in this commit are not shown.)