diff --git a/docs/src/DomainBuffers/StateVariables.md b/docs/src/DomainBuffers/StateVariables.md index 1d6aa5a..dcbcf53 100644 --- a/docs/src/DomainBuffers/StateVariables.md +++ b/docs/src/DomainBuffers/StateVariables.md @@ -17,4 +17,5 @@ by its cell number. (The output from the mentioned functions are `Dict{Int}`) ```@docs FerriteAssembly.create_cell_state update_states! +FerriteAssembly.remove_dual ``` \ No newline at end of file diff --git a/docs/src/literate_tutorials/viscoelasticity.jl b/docs/src/literate_tutorials/viscoelasticity.jl index 846d3b9..375e4ca 100644 --- a/docs/src/literate_tutorials/viscoelasticity.jl +++ b/docs/src/literate_tutorials/viscoelasticity.jl @@ -74,10 +74,11 @@ function FerriteAssembly.element_residual!(re, state, ae, m::ZenerMaterial, cv:: δ∇N = shape_symmetric_gradient(cv, q_point, i) re[i] += (δ∇N ⊡ σ) * dΩ end - ## Note that to save the state by mutation, we need to extract the value from the dual - ## number. Consequently, we do this before assigning to the state vector. Note that - ## if the state was a scalar, we should use `ForwardDiff.value` instead. - state[q_point] = Tensors._extract_value(ϵv) + ## We only want to save the value-part of the states, and FerriteAssembly comes with + ## the utility `FerriteAssembly.remove_dual` to do so for scalars and Tensors. + ## Note that using `state[q_point]` instead of ϵv for any calculations + ## affecting re, will result in wrong derivatives. + state[q_point] = FerriteAssembly.remove_dual(ϵv) end end; diff --git a/src/Utils/utils.jl b/src/Utils/utils.jl index 519a719..e394359 100644 --- a/src/Utils/utils.jl +++ b/src/Utils/utils.jl @@ -1,3 +1,21 @@ +# Public functions for convenience +""" + remove_dual(x::T) where {T <: Number} + remove_dual(x::AbstractTensor{<:Any, <:Any, T}) where {T} + +Removes the dual part if `T <: ForwardDiff.Dual`, extract the value part. +Typically used when assigning state variables during differentiation calls. +""" +function remove_dual end + +# Scalars +remove_dual(x::ForwardDiff.Dual) = ForwardDiff.value(x) +remove_dual(x::Number) = x + +# Tensors +remove_dual(x::AbstractTensor{<:Any, <:Any, <:ForwardDiff.Dual}) = Tensors._extract_value(x) +remove_dual(x::AbstractTensor) = x + # Internal functions used for convenience """ diff --git a/test/runtests.jl b/test/runtests.jl index 02f7998..7bfa97b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,5 @@ using Ferrite, FerriteAssembly +using ForwardDiff using SparseArrays using Test import FerriteAssembly as FA @@ -46,6 +47,18 @@ include("errors.jl") @test FerriteAssembly.get_old_state(buffer, 1) === nothing end + + @testset "utility functions" begin + x = rand() + xd = ForwardDiff.Dual(x, rand(3)...) + @test x === FerriteAssembly.remove_dual(x) + @test x === FerriteAssembly.remove_dual(xd) + + t = rand(Tensor{2,3}) + td = Tensor{2,3}((i, j) -> ForwardDiff.Dual(t[i, j], rand(), rand())) + @test t === FerriteAssembly.remove_dual(t) + @test t === FerriteAssembly.remove_dual(td) + end end # Print show warning at the end if running tests single-threaded.