
Added a new Functional for TV Norm #456

Merged: 30 commits merged into lanl:main on Nov 5, 2023

Conversation

@shnaqvi (Contributor) commented Oct 5, 2023:

Added a new TV2DNorm functional, implementing its proximal operator using the fast subiteration-free algorithm proposed by Kamilov (2016).

P.S. Tested it on a simple deblurring problem with TV regularization, using a synthetic image generated with an anisotropic Gaussian blur. The reconstructed result is promising, with 10 iterations completing in 11 s on a 4k×3k image without GPU acceleration.

[image: deblurring result with TV regularization]
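
For context, a minimal sketch of the subiteration-free prox as described above: cycle-spin over shifts along each axis, take a single-level 1D Haar transform, soft-threshold the highpass (difference) band, invert, and average. This is an illustrative reading of Kamilov (2016), not the code in this PR; the threshold scaling `2 * K * lam` is an assumption.

```python
import jax.numpy as jnp

def soft_threshold(x, tau):
    """Elementwise soft-thresholding (shrinkage) operator."""
    return jnp.sign(x) * jnp.maximum(jnp.abs(x) - tau, 0.0)

def haar_shrink_1d(z, thresh, axis):
    """Haar-transform `z` along `axis`, shrink the highpass band, invert."""
    z = jnp.moveaxis(z, axis, 0)                    # target axis first
    z0, z1 = z[0::2], z[1::2]                       # assumes even length
    lo = (z0 + z1) / jnp.sqrt(2.0)                  # lowpass (averages)
    hi = (z0 - z1) / jnp.sqrt(2.0)                  # highpass (differences)
    hi = soft_threshold(hi, thresh)                 # shrink highpass only
    r0 = (lo + hi) / jnp.sqrt(2.0)                  # inverse transform
    r1 = (lo - hi) / jnp.sqrt(2.0)
    out = jnp.stack([r0, r1], axis=1).reshape(z.shape)  # re-interleave samples
    return jnp.moveaxis(out, 0, axis)

def prox_tv2d(y, lam):
    """Approximate prox of lam * TV(y) by averaging over axes and shifts."""
    K = 2                                           # shifts per axis (cycle spinning)
    acc = jnp.zeros_like(y)
    for axis in (0, 1):
        for shift in range(K):
            z = jnp.roll(y, -shift, axis=axis)
            r = haar_shrink_1d(z, 2.0 * K * lam, axis)
            acc = acc + jnp.roll(r, shift, axis=axis)  # undo the shift
    return acc / (2.0 * K)
```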

@bwohlberg self-assigned this on Oct 6, 2023
@bwohlberg added the enhancement (New feature or request) label on Oct 6, 2023
@bwohlberg linked an issue on Oct 6, 2023 that may be closed by this pull request
@bwohlberg (Collaborator) commented:

I'll check this carefully as soon as I can find the time, but for now, can you try to track down why some of the tests are failing?

@shnaqvi (Contributor, Author) commented Oct 6, 2023:

> can you try to track down why some of the tests are failing?

I've fixed all but one error, which causes the same test to fail on both macOS and Ubuntu. My TV2DNorm has the same output shape as its input, which is also the case for the already-defined L1MinusL2Norm, but TV2DNorm keeps failing that test with a ValueError: "Cannot stack LinearOperators with nested output shapes". Do you have a suggestion as to why this could be?

@bwohlberg (Collaborator) commented:

I will try to find some time to figure out what's causing the error. Could you give me permission to push to your fork/branch?

@shnaqvi (Contributor, Author) commented Oct 9, 2023:

> Could you give me permission to push to your fork/branch?

Just did!

@bwohlberg (Collaborator) commented:

Received, thanks.

@bwohlberg (Collaborator) commented:

The test error was a result of TV2DNorm being tested with BlockArray input, which it doesn't support. I've disabled that test, but the new functional still needs tests of its own.

codecov bot commented Oct 9, 2023:

Codecov Report

Merging #456 (3b7f75b) into main (27e2aec) will increase coverage by 0.02%.
The diff coverage is 95.35%.

@@            Coverage Diff             @@
##             main     #456      +/-   ##
==========================================
+ Coverage   94.56%   94.59%   +0.02%     
==========================================
  Files          88       89       +1     
  Lines        5498     5541      +43     
==========================================
+ Hits         5199     5241      +42     
- Misses        299      300       +1     
Flag        Coverage Δ
unittests   94.59% <95.35%> (+0.02%) ⬆️

Flags with carried forward coverage won't be shown.

Files                           Coverage Δ
scico/functional/__init__.py    100.00% <100.00%> (ø)
scico/functional/_tvnorm.py     95.24% <95.24%> (ø)

... and 1 file with indirect coverage changes

@bwohlberg (Collaborator) commented:

The cross-section of the solution in the plot in the initial PR comment looks suspicious: it's piecewise constant at the bottom, but not at the top. It would be worth checking whether this is due to a bug in the code.

@shnaqvi (Contributor, Author) commented Oct 10, 2023:

> It would be worth checking whether this is due to a bug in the code.

I see what you mean, @bwohlberg, but I'm not able to see what could be causing it. I've gone through Kamilov's algorithm again in detail and don't see any obvious issues in my straightforward 50-line implementation of his proposed proximal operator for the TV norm. I've also sent another email to Kamilov today to see if he can review my reference implementation. I get better results after 30 iterations, but there's still some overshoot at the top of the cross-section profile.
[image: result after 30 iterations]

@shnaqvi (Contributor, Author) commented Oct 11, 2023:

@bwohlberg, I heard back from Ulugbek Kamilov, who couldn't share a reference implementation but didn't point out any issues with this implementation either. He said, and I can verify, that results with this TV regularizer converge to piecewise constant as the number of iterations grows. Please let me know if you have any other reservations.

One point I did notice is that in Kamilov's implementation, he applies the shrinkage only to the difference operators. I made that change on my end, and while I'm not seeing any obvious change in the general behavior we've been observing, I'll make the change for posterity and push the update. However, with that change enacted, I don't get piecewise constant even at the bottom of the profile within 10 iterations. So I will need to combine the TV prior with a nonnegativity prior, for which I have opened a separate issue, as I couldn't figure out how to do it. It'd be great if you could help me with that! Thanks.

[commit] …perator of the haar transform as in Kamilov, 2016
@bwohlberg (Collaborator) commented:

Strange that the fix made it worse. For now, let's just focus on getting a correct implementation of the original method. Can you share your test script? (Preferably as an attachment so it doesn't obscure the discussion thread here.)

@shnaqvi (Contributor, Author) commented Oct 12, 2023:

> Can you share your test script? (Preferably as an attachment so it doesn't obscure the discussion thread here.)

@bwohlberg, please see the test code in the attached script. I agree that if the shrinkage is applied to the entire output of the Haar transform, as I had done earlier, we get better results, i.e., convergence to piecewise constant in the limit of iterations (see image A below). But if we apply the shrinkage operator just to the difference operator's output within the Haar transform, we don't get a piecewise constant profile even after 100 iterations (see image B below).

So it might be better to consciously keep the shrinkage applied after taking the full Haar transform in the algorithm. I asked Kamilov by email what he sees as the implication of applying the shrinkage to the entire transform, but I haven't received a reply from him on that. (A short sketch contrasting the two variants appears after the attachment below.)

A: [image: shrinkage applied to the entire Haar transform output]

B: [image: shrinkage applied only to the difference operator's output]

scico_deconv.ipynb.zip
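
To make the two variants concrete, a minimal side-by-side sketch, given the lowpass/highpass bands of a single-level Haar transform along one axis. The function names are mine (the "Shrink-Highpass"/"Shrink-Forward" labels appear later in this thread), for illustration only.

```python
import jax.numpy as jnp

def soft_threshold(x, tau):
    return jnp.sign(x) * jnp.maximum(jnp.abs(x) - tau, 0.0)

def shrink_highpass(lo, hi, tau):
    # Variant B: threshold only the difference (highpass) band.
    return lo, soft_threshold(hi, tau)

def shrink_forward(lo, hi, tau):
    # Variant A: threshold the entire transform output, lowpass included;
    # this also attenuates local averages, which clips the dynamic range.
    return soft_threshold(lo, tau), soft_threshold(hi, tau)
```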

@bwohlberg (Collaborator) commented:

See the attached test script -- it looks as if the prox doesn't actually do anything.

tvnormtest.zip

        Args:
            tau: Parameter :math:`\tau` in the norm definition.
        """
        self.tau = tau
@bwohlberg (Collaborator) commented:

Does this play a useful role? The base Functional already handles multiplication by a scalar.

@shnaqvi (Contributor, Author) commented:

Sorry, I didn't understand. Do you mean the role of the tau parameter? It affects the threshold in the prox function.

@bwohlberg (Collaborator) commented:

But it just multiplies the prox scaling parameter, no? So what does it add over that parameter?

@shnaqvi (Contributor, Author) commented:

@bwohlberg, sure, it does scale the prox parameter in the prox() function, but it is also used in the __call__ function to scale the value of the norm, right?
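
A hedged sketch of the point under discussion, assuming scalar multiplication of a Functional is handled by the base class, as the reviewer states. The equality below is an expectation, not a test from this PR.

```python
from scico import functional

g1 = functional.TV2DNorm(tau=0.5)   # tau baked into the functional
g2 = 0.5 * functional.TV2DNorm()    # scaling via the Functional base class
# Expected: g1(x) == g2(x) and g1.prox(v, lam) == g2.prox(v, lam),
# in which case tau duplicates the base-class scaling mechanism.
```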

@@ -477,3 +479,110 @@ def prox(
        svdU, svdS, svdV = snp.linalg.svd(v, full_matrices=False)
        svdS = snp.maximum(0, svdS - lam)
        return svdU @ snp.diag(svdS) @ svdV


class TV2DNorm(Functional):
@bwohlberg (Collaborator) commented:

It would be better for this to be able to handle arbitrary numbers of dimensions.

@shnaqvi (Contributor, Author) commented:

@bwohlberg, we can extend it to 3D by adding another set of conditional statements for axis == 2. However, it doesn't look straightforward to extend it to an arbitrary number of dimensions, because we need to make in-place updates, and in JAX we do this using the arr.at[...].set() framework. While NumPy allows indexing along an arbitrary axis using the take() function, JAX doesn't seem to allow updates through it. Do you know of a way to update immutable arrays in JAX along an arbitrary axis, i.e., an equivalent of take() combined with at[...].set()?
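
One possible approach, as a hedged sketch: avoid per-axis conditionals by moving the target axis to the front before the functional update. `set_along_axis` is a hypothetical helper, not scico or JAX API.

```python
import jax.numpy as jnp

def set_along_axis(arr, idx, values, axis):
    """Set arr[..., idx, ...] = values along `axis` (hypothetical helper)."""
    a = jnp.moveaxis(arr, axis, 0)   # bring the target axis to the front
    a = a.at[idx].set(values)        # functional (copy-on-write) update
    return jnp.moveaxis(a, 0, axis)  # restore the original axis order
```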

@bwohlberg (Collaborator) commented:

Let's come back to this after having resolved the other issues.

@shnaqvi (Contributor, Author) commented Oct 31, 2023:

> I'm not convinced that the order of the 2D transform and shrinkage is correct, which is a consequence of the 1D forward transform and shrinkage being fused into a single function. The correct approach would be to compute the full 2D transform, shrink the highpass coefficients, and then apply the pseudo-inverse.

@bwohlberg, that is what I was initially doing, which gave me better results. Yes, we can totally do that: instead of shrinking just the coefficients from the difference operator in the Haar transform, apply the shrinkage operator to the entire forward transform output. And I did check how the TV norm does in our simple test case of a sinusoid; the results are visually the same.

@bwohlberg (Collaborator) commented:

> @bwohlberg, that is what I was initially doing, which gave me better results. Yes, we can totally do that: instead of shrinking just the coefficients from the difference operator in the Haar transform, apply the shrinkage operator to the entire forward transform output. And I did check how the TV norm does in our simple test case of a sinusoid; the results are visually the same.

That's not quite what I meant. It seems pretty clear from the paper that the shrinkage should only be applied to the highpass coefficients, but this should be done after the decomposition has been computed in both spatial directions.

@shnaqvi (Contributor, Author) commented Nov 1, 2023:

> The TV2DNorm implementation is not the same as in the PR branch?

Sorry @bwohlberg, I went back and investigated. Indeed, this is different (attached again here: _norm.py.zip). It is basically the original version of the implementation that I had submitted in the PR, which applies the shrinkage operator to the entirety of the forward transform.

I verified that with the shrinkage operator applied inside the forward transform function, I am able to reproduce your result, namely that the proximal parameter has no visible effect. But I am not able to spot any coding errors that could be causing this behavior.

So, in short: with the shrinkage operator applied to the entirety of the forward transform we get the expected clipping behavior, and none otherwise.

@shnaqvi (Contributor, Author) commented Nov 1, 2023:

> It seems pretty clear from the paper that the shrinkage should only be applied to the highpass coefficients, but this should be done after the decomposition has been computed in both spatial directions.

I am seeing it differently. If you refer to equation (20b) in Kamilov (2016), the shrinkage operator is applied inside the summation sign, after the kth decomposition is computed, and only afterwards do we sum over both the spatial dimensions and the shifts. Do you agree?
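
For reference, the structure being described, in my own notation (a paraphrase of the averaged form, not a verbatim copy of eq. (20b)):

```latex
\operatorname{prox}_{\tau \mathrm{TV}}(\mathbf{y}) \;\approx\;
  \frac{1}{2K} \sum_{d=1}^{2} \sum_{k=1}^{K}
  W_{d,k}^{\mathsf{T}}\, \mathcal{T}\!\bigl( W_{d,k}\,\mathbf{y};\; \tau' \bigr)
```

where W_{d,k} denotes the single-level Haar decomposition along dimension d at shift k, \mathcal{T} soft-thresholds only the highpass band, and \tau' is the appropriately scaled threshold from the paper.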

@bwohlberg (Collaborator) commented:

> I am seeing it differently. If you refer to equation (20b) in Kamilov (2016), the shrinkage operator is applied inside the summation sign, after the kth decomposition is computed, and only afterwards do we sum over both the spatial dimensions and the shifts. Do you agree?

I agree. My concern was that the decompositions were being mixed along different axes with interleaved shrinkage, but on looking again I see that I was not reading the code correctly. So I don't see any obvious errors in the implementation.

@shnaqvi (Contributor, Author) commented Nov 2, 2023:

> (I would really expect to see the smooth waveform being clipped to something closer to a square wave.)

@bwohlberg, since the TV norm minimizes the sum of local differences in the image, I think we'll see image features become smoother, in the sense of high frequencies being suppressed while low frequencies are preserved, so as not to introduce new sharp edges.

Here I did some more testing with an image having both low- and high-frequency features (sinusoidal and square boundaries). I also compared the two implementations of the proximal operator of the TV norm: 1) applying the shrinkage to just the difference operator, i.e. the highpass features (labeled "Shrink-Highpass"), and 2) applying the shrinkage to the entirety of the forward Haar transform (labeled "Shrink-Forward").

Since my original implementation ("Shrink-Forward") shrinks the entire transform, we begin to see clipping of the signal and the dynamic range is also lost, whereas "Shrink-Highpass" retains the dynamic range while minimizing the integral of the image gradient magnitude (see below). Both retain the edges in the image, as expected.
[image: comparison of Shrink-Highpass and Shrink-Forward reconstructions]

And as the proximal parameter increases, the "Shrink-Highpass" proximal function returns a smoother signal (up to a certain limit), while "Shrink-Forward" returns a more clipped signal (going to absolute zero in the limit); see below.
[image: effect of increasing the proximal parameter for both variants]

So I don't see any obvious errors in the implementation.

So, if you're okay with it, can we merge the PR into master? :-)

Here are my scripts: scico_tvnormtest.zip

@bwohlberg (Collaborator) commented:

> [...]
> And as the proximal parameter increases, the "Shrink-Highpass" proximal function returns a smoother signal (up to a certain limit), while "Shrink-Forward" returns a more clipped signal (going to absolute zero in the limit); see below.
> [...]

Thanks for adding the additional results. It seems clear that the shrinkage should only be applied to the highpass component.

> So I don't see any obvious errors in the implementation.
>
> So, if you're okay with it, can we merge the PR into master? :-)

We still need tests for the new code. I am working on that, as well as an alternative implementation that supports input arrays of arbitrary dimensionality.

@shnaqvi (Contributor, Author) commented Nov 3, 2023:

> Thanks for adding the additional results. It seems clear that the shrinkage should only be applied to the highpass component.

@bwohlberg, for completeness, I further ran denoising for 100 iterations of APGM with an identity forward operator and this TV prior (Shrink-Highpass); the results show a convincingly denoised image despite heavy Gaussian noise. (A usage sketch of this experiment follows below.)

[images: noisy input and denoised result after 100 APGM iterations]
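
A hedged usage sketch of the denoising experiment just described, assuming the post-merge scico API; the class name TV2DNorm is as in this PR, and `noisy_image` is assumed given.

```python
from scico import functional, linop, loss
from scico.optimize.pgm import AcceleratedPGM

y = noisy_image                                         # noisy input (assumed given)
f = loss.SquaredL2Loss(y=y, A=linop.Identity(y.shape))  # identity forward operator
g = functional.TV2DNorm()                               # TV prior added by this PR
solver = AcceleratedPGM(f=f, g=g, L0=1e1, x0=y, maxiter=100)
x_hat = solver.solve()                                  # denoised estimate
```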

@bwohlberg (Collaborator) commented:

Please take a look at the new pull request in your fork of scico.

@bwohlberg (Collaborator) commented:

@shnaqvi: Merging shortly. Thanks for your help in getting this useful feature added.

@bwohlberg merged commit 08a5896 into lanl:main on Nov 5, 2023
16 checks passed
@shnaqvi (Contributor, Author) commented Nov 5, 2023:

The pleasure is mine, and thanks for proactively optimizing and integrating it! Next up: Proximal Averaged APGM for minimization with composite priors! 😃

Labels: enhancement (New feature or request)
Linked issue (may be closed by this pull request): Enabling use of TV prior with a forward operator with Accelerated PGM
2 participants