Skip to content

Commit

Permalink
Add remove_dual function
Browse files Browse the repository at this point in the history
  • Loading branch information
KnutAM committed Sep 7, 2024
1 parent d6f2771 commit ecf45d2
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 4 deletions.
1 change: 1 addition & 0 deletions docs/src/DomainBuffers/StateVariables.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ by its cell number. (The output from the mentioned functions are `Dict{Int}`)
```@docs
FerriteAssembly.create_cell_state
update_states!
FerriteAssembly.remove_dual
```
9 changes: 5 additions & 4 deletions docs/src/literate_tutorials/viscoelasticity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,11 @@ function FerriteAssembly.element_residual!(re, state, ae, m::ZenerMaterial, cv::
δ∇N = shape_symmetric_gradient(cv, q_point, i)
re[i] += (δ∇N σ) *
end
## Note that to save the state by mutation, we need to extract the value from the dual
## number. Consequently, we do this before assigning to the state vector. Note that
## if the state was a scalar, we should use `ForwardDiff.value` instead.
state[q_point] = Tensors._extract_value(ϵv)
## We only want to save the value-part of the states, and FerriteAssembly comes with
## the utility `FerriteAssembly.remove_dual` to do so for scalars and Tensors.
## Note that using `state[q_point]` instead of ϵv for any calculations
## affecting re, will result in wrong derivatives.
state[q_point] = FerriteAssembly.remove_dual(ϵv)
end
end;

Expand Down
18 changes: 18 additions & 0 deletions src/Utils/utils.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,21 @@
# Public functions for convenience
"""
remove_dual(x::T) where {T <: Number}
remove_dual(x::AbstractTensor{<:Any, <:Any, T}) where {T}
Removes the dual part if `T <: ForwardDiff.Dual`, extract the value part.
Typically used when assigning state variables during differentiation calls.
"""
function remove_dual end

# Scalars
remove_dual(x::ForwardDiff.Dual) = ForwardDiff.value(x)
remove_dual(x::Number) = x

# Tensors
remove_dual(x::AbstractTensor{<:Any, <:Any, <:ForwardDiff.Dual}) = Tensors._extract_value(x)
remove_dual(x::AbstractTensor) = x

# Internal functions used for convenience

"""
Expand Down
13 changes: 13 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Ferrite, FerriteAssembly
using ForwardDiff
using SparseArrays
using Test
import FerriteAssembly as FA
Expand Down Expand Up @@ -46,6 +47,18 @@ include("errors.jl")
@test FerriteAssembly.get_old_state(buffer, 1) === nothing

end

@testset "utility functions" begin
x = rand()
xd = ForwardDiff.Dual(x, rand(3)...)
@test x === FerriteAssembly.remove_dual(x)
@test x === FerriteAssembly.remove_dual(xd)

t = rand(Tensor{2,3})
td = Tensor{2,3}((i, j) -> ForwardDiff.Dual(t[i, j], rand(), rand()))
@test t === FerriteAssembly.remove_dual(t)
@test t === FerriteAssembly.remove_dual(td)
end
end

# Print show warning at the end if running tests single-threaded.
Expand Down

0 comments on commit ecf45d2

Please sign in to comment.