
Determinant could be faster if use LU #456

Closed
cortner opened this issue Jun 26, 2021 · 16 comments
Labels
performance Related to improving computational performance

Comments

@cortner

cortner commented Jun 26, 2021

I noticed the rrule for det does not use the LU factorisation. Is this intentional? Or is it implicit?

EDIT: implicitly both rrule and frule do use the factorisation, but it could be reused for efficiency. For rrule a question of numerical stability remains, and it is not yet clear to me whether it is resolvable.

@oxinabox changed the title from "Determinant" to "Determinant could be faster if use LU" on Jun 26, 2021
@oxinabox
Member

Yeah, I think that makes sense; det would use LU under the hood.
But if we computed the LU ourselves, we could reuse that factorization to do /x_lu instead of *inv(x) in the pullback, right?

function rrule(::typeof(det), x::Union{Number, AbstractMatrix})
    Ω = det(x)
    function det_pullback(ΔΩ)
        ∂x = x isa Number ? ΔΩ : Ω * ΔΩ * inv(x)'
        return (NoTangent(), ∂x)
    end
    return Ω, det_pullback
end

It would be one of the cases where we can benefit from changing the primal computation.
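
For example, a minimal sketch of such a changed-primal rule (not the actual ChainRules implementation; it assumes a square, nonsingular strided matrix, and in practice it would live inside the package rather than be defined by a user):

using LinearAlgebra
import ChainRulesCore: rrule, NoTangent

# Sketch only: factorize once, then reuse the factorization both for the
# determinant and for the inverse needed in the pullback.
function rrule(::typeof(det), x::StridedMatrix{<:LinearAlgebra.BlasFloat})
    F = lu(x)                      # single LU factorization (throws if x is singular)
    Ω = det(F)                     # determinant from the stored LU factors
    function det_pullback(ΔΩ)
        # same formula as the generic rule above, but inv comes from the
        # already-available factors instead of refactorizing x
        ∂x = Ω * ΔΩ * inv(F)'
        return (NoTangent(), ∂x)
    end
    return Ω, det_pullback
end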

@oxinabox added the performance (Related to improving computational performance) label on Jun 26, 2021
@cortner
Author

cortner commented Jun 26, 2021

Yes that was my point.

One difficulty (??) - what if the LinearAlgebra implementation changes? Unlikely here but maybe not in general. I guess ChainRules needs to keep track of such changes? But is it also a bit concerning that this strategy requires code duplication?

@oxinabox
Member

what if the LinearAlgebra implementation changes? Unlikely here but maybe not in general.

The math remains the same even if the implementation changes.
We don't promise the same primal answer as running the original primal function call; we only aim to meet the same accuracy thresholds where documented (just as across Julia versions).
If the primal function is accurate to 5 ULP, then we will also aim for 5 ULP, though the error might be in the opposite direction.

But is it also a bit concerning that this strategy requires code duplication?

That's how it is.
Sometimes you want to modify the primal for exactly this reason: you need some extra intermediate information like this (or you want to calculate it in a totally different way).

In theory, not having a rule can be better, if the optimizer can inline everything, see the common subexpressions (like computing lu for det and for inv), and eliminate them.
In practice that is very rare (and doing it between the pullback and the primal computation is impossible without opaque closures).

@cortner
Author

cortner commented Jun 26, 2021

Thanks for the thoughts

@sethaxen
Member

IIRC, we didn't use lu here because this rule is for abstract types, meaning it will, in an almost type-piratical way, make any specialized implementation of det for some other matrix type invisible to the AD; so instead we do something generic. That specialized matrix type might have some other very efficient way to compute det that doesn't use the LU decomposition, and the decomposition may be very slow for that type. So, to try to salvage this bad situation, we potentially do twice the work, i.e. implicitly decomposing twice.

@cortner
Author

cortner commented Jun 27, 2021

I think it is a little worse than that. By calling inv you are assembling the inverse matrix, which should almost never be done. Even when computed via LU, is this even numerically stable? Ok to have a fallback, but I feel at least for any matrix that has an LU decomposition defined one should do it via LU rather than inv.

@oxinabox
Member

We can at least add the LU path for StridedMatrix{<:BlasFloat}, which IIRC should all have LU.

@sethaxen
Member

I think it is a little worse than that. By calling inv you are assembling the inverse matrix, which should almost never be done. Even when computed via LU, is this even numerically stable? Ok to have a fallback, but I feel at least for any matrix that has an LU decomposition defined one should do it via LU rather than inv.

If you work out the rrule for computing the determinant from the LU decomposition and then compose it with the rrule for the LU decomposition, you end up with (conj(d) * Δd) * inv(F)', where F is the LU decomposition, d is the determinant, and Δd is the cotangent of the determinant, and you'll notice this is identical to our rrule for det, so there's no way around computing the inverse here.

From the LU decomposition, the matrix inverse can be computed quickly, cheaply, and in-place using two applications of backwards substitution, so the only remaining question is one of stability. The matrix inverse does not exist exactly when the determinant is exactly zero. We don't currently do any special-casing for the zero-determinant case, but perhaps we should. By the subgradient convention, the cotangent should then be the zero matrix.
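
As a small illustration of this point (hypothetical example matrix), the inverse can be formed from an already-available LU factorization with two triangular solves against the permuted identity, without any extra factorization:

using LinearAlgebra

A = rand(4, 4) + 4I                    # hypothetical well-conditioned example
F = lu(A)
P = Matrix{Float64}(I, 4, 4)[F.p, :]   # permutation matrix from the pivot vector
X = F.U \ (F.L \ P)                    # forward substitution with L, then back substitution with U
@assert X ≈ inv(A)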

We can at least add the LU path for StridedMatrix{<:BlasFloat}, which IIRC should all have LU.

This would be safe to do. Since I think det is almost always computed from a factorization, the best solution ultimately is probably to ensure we have rules for all common factorizations, remove this det rrule entirely, and then only add a few det rules with factorization arguments if necessary. For det from lu, this will probably be a little less efficient than calling lu in the rrule, because instead of just calling inv(F), it goes through a few more steps. I think the main factorization we're missing rules for here is Bunch-Kaufman for Symmetric{<:BlasFloat} matrices.

@cortner
Author

cortner commented Jun 27, 2021

When I take grad(det(A)), sure, there is no other way, except possibly to have a lazy representation of the inverse (which I'd personally prefer, but I appreciate this may lead to other problems).

However, if A = A(p) and I want grad(det(A(p))), then wouldn't the process be

  • compute A + store the pullback ddet -> dA(p)'[ddet]
  • compute det(A) + ddet
  • now apply the pullback

When applying the pullback, would I not want to use the LU factorisation instead of the "collected" matrix?

@sethaxen
Member

However, if A = A(p) and I want grad(det(A(p))), then wouldn't the process be

  • compute A + store the pullback ddet -> dA(p)'[ddet]
  • compute det(A) + ddet

I wasn't able to follow. What do these terms mean?

@cortner
Author

cortner commented Jun 27, 2021

f(p) = det(A(p))

Forward pass:

  1. Compute A and the pullback closure ddet -> adjoint(dA(p)) * ddet
  2. Compute det(A) and ddet = ddet(A)

Backward pass:

  1. apply the closure to get grad_f(p) = adjoint(dA(p)) * ddet

(perfectly possible I'm missing something or getting something wrong here ... I only just started to think about AD at the implementation level.)

@sethaxen
Member

I'm sorry, it's still not entirely clear to me what your notation means. e.g. you seem to be using the prefix of d to denote a cotangent, but cotangents cannot be multiplied by each other (due to linearity of the pullback operator), so I'm not certain what adjoint(dA(p)) * ddet means.

@cortner
Author

cortner commented Jun 27, 2021

I'm applying an adjoint operator - this application is denoted by *; I'm not multiplying two tangents. I apologize if my notation is not as precise as you'd like, but as an expert you can probably make an educated guess at what I mean.

@cortner
Author

cortner commented Jun 30, 2021

So I'm trying to write down a concrete use-case, and at least the ones I had in mind when posting this seem to be more relevant for frule than rrule. I will momentarily edit the first post to add this point.

Re the frule:

  D[ det(A) ] = det(A) tr[ A \ DA ]

and for the A \ DA operation one can again reuse the factorisation. (I'm specifically interested in cases where DA has special structure, e.g., low rank.)
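
A minimal sketch of what such an frule could look like (not ChainRules' actual rule; it assumes a square, nonsingular strided matrix):

using LinearAlgebra
import ChainRulesCore: frule

function frule((_, Δx), ::typeof(det), x::StridedMatrix{<:LinearAlgebra.BlasFloat})
    F = lu(x)              # one factorization, reused below
    Ω = det(F)             # determinant from the LU factors
    ΔΩ = Ω * tr(F \ Δx)    # D[det(A)] = det(A) * tr(A \ DA), with A \ DA via the same factorization
    return Ω, ΔΩ
end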

But I appreciate the problem of a generic vs specialised implementation.

I'll write something else on rrule later.

@cortner
Author

cortner commented Jun 30, 2021

Returning to rrule - the following is not quite my use-case, but it is reasonably close and doesn't require further explanation. I'm actually confused about how the rrule applies here.

We have parameters p = (pj) and

   L = sum_i   f( det(A(xi, p)) - yi )

with D = d / dpj, Ai = A(xi,p), fi' = f'(det(Ai) - yi), DAi = DA(xi, p), then

  DL = sum_i fi' * det(Ai) * tr[ Ai \ DAi ]

I'm now struggling to re-order the operations to see the backpropagation. Is it simply this?

  DL = tr[  sum_i  { fi' * det Ai * inv(Ai) } *  DAi  ]

? I.e. the backpropagation would be:

  • compute [ fi' ]
  • compute the matrices Gi = det Ai / fi' * inv(Ai)
  • sum_i Gi \ DAi
  • tr

And would you agree that this indicates I should use the frule in such a scenario?

If what I've written is correct, then there is still the issue that "collecting" inv(Ai) is not necessary and, for numerical stability, could be replaced with a lazy inverse operation?
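
For what it's worth, in the real-valued case the per-term contribution works out to Gi = fi' * det(Ai) * inv(Ai)', which is easy to sanity-check against finite differences (f, A, and y below are made up purely for illustration):

using LinearAlgebra

f(t)      = t^2
fprime(t) = 2t
A = rand(3, 3) + 3I                        # hypothetical example matrix
y = 1.0
G = fprime(det(A) - y) * det(A) * inv(A)'  # claimed gradient of f(det(A) - y) with respect to A

h = 1e-6
E = zeros(3, 3); E[2, 3] = 1.0             # perturb a single entry
fd = (f(det(A + h*E) - y) - f(det(A - h*E) - y)) / (2h)
@assert isapprox(G[2, 3], fd; rtol = 1e-4)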

@cortner
Author

cortner commented Jul 6, 2021

I now think this issue is irrelevant; I will close it and open a new issue.
