diff --git a/Project.toml b/Project.toml index 67c1b24..44e2fba 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SLEEFPirates" uuid = "476501e8-09a2-5ece-8869-fb82de89a1fa" authors = ["chriselrod "] -version = "0.6.33" +version = "0.6.34" [deps] IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" diff --git a/src/SLEEFPirates.jl b/src/SLEEFPirates.jl index ce03223..e6f8134 100644 --- a/src/SLEEFPirates.jl +++ b/src/SLEEFPirates.jl @@ -231,6 +231,14 @@ ldexp(x::Float16, q::Int) = Float16(ldexpk(Float32(x), q)) max_tanh(::Type{Float64}) = 19.06154746539849599509609553228539867418786340504817671278462587964799037885145 max_tanh(::Type{Float32}) = 9.010913339828708369989037671244720498805572920317272822795576296065428827978905f0 +@inline function tanh_fast(x::AbstractSIMD{W,Float32}) where {W} + # stolen from https://github.com/FluxML/NNlib.jl/pull/345 + # https://github.com/FluxML/NNlib.jl/blob/5dd04df4e95f9f9b70d6232fac546f3e98899fc2/src/activations.jl#L766-L773 + x2 = abs2(x) + n = evalpoly(x2, (1.0f0, 0.1346604f0, 0.0035974074f0, 2.2332108f-5, 1.587199f-8)) + d = evalpoly(x2, (1.0f0, 0.4679937f0, 0.026262015f0, 0.0003453992f0, 8.7767893f-7)) + ifelse(x2 < 66f0, @fastmath(x * (n / d)), sign(x)) +end @inline function tanh_fast(x) exp2xm1 = expm1_fast(Base.FastMath.add_fast(x, x)) # Division is faster than approximate inversion in