From de6b551059d11be3e4234591010405449bf86b9a Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Fri, 12 Jul 2024 13:22:38 -0700
Subject: [PATCH] refactor: use the faster `get_device_type`

---
 Project.toml                                  |  2 +-
 .../api/Accelerator_Support/LuxDeviceUtils.md |  1 +
 ext/LuxForwardDiffExt/LuxForwardDiffExt.jl    |  3 ++-
 src/distributed/public_api.jl                 |  6 ++---
 src/helpers/nested_ad.jl                      |  2 +-
 src/layers/extension.jl                       | 22 ++++++++++---------
 src/transform/flux.jl                         |  4 ++--
 7 files changed, 22 insertions(+), 18 deletions(-)

diff --git a/Project.toml b/Project.toml
index 174aa9a4ad..1be35ba747 100644
--- a/Project.toml
+++ b/Project.toml
@@ -92,7 +92,7 @@ LinearAlgebra = "1.10"
 Logging = "1.10"
 LossFunctions = "0.11.1"
 LuxCore = "0.1.16"
-LuxDeviceUtils = "0.1.22"
+LuxDeviceUtils = "0.1.25"
 LuxLib = "0.3.23"
 LuxTestUtils = "0.1.15"
 MLUtils = "0.4.3"
diff --git a/docs/src/api/Accelerator_Support/LuxDeviceUtils.md b/docs/src/api/Accelerator_Support/LuxDeviceUtils.md
index 09b9e3e7ec..75c73426a2 100644
--- a/docs/src/api/Accelerator_Support/LuxDeviceUtils.md
+++ b/docs/src/api/Accelerator_Support/LuxDeviceUtils.md
@@ -29,6 +29,7 @@ reset_gpu_device!
 supported_gpu_backends
 default_device_rng
 get_device
+get_device_type
 LuxDeviceUtils.loaded
 LuxDeviceUtils.functional
 ```
diff --git a/ext/LuxForwardDiffExt/LuxForwardDiffExt.jl b/ext/LuxForwardDiffExt/LuxForwardDiffExt.jl
index 0479b42e08..db861646ff 100644
--- a/ext/LuxForwardDiffExt/LuxForwardDiffExt.jl
+++ b/ext/LuxForwardDiffExt/LuxForwardDiffExt.jl
@@ -3,7 +3,8 @@ module LuxForwardDiffExt
 using ArgCheck: @argcheck
 using ADTypes: AutoForwardDiff
 using ChainRulesCore: ChainRulesCore
-using Lux: Lux, get_device
+using Lux: Lux
+using LuxDeviceUtils: get_device_type
 using FastClosures: @closure
 using ForwardDiff: ForwardDiff
 using Functors: fmap
diff --git a/src/distributed/public_api.jl b/src/distributed/public_api.jl
index 5e31888d1b..34df5b46c8 100644
--- a/src/distributed/public_api.jl
+++ b/src/distributed/public_api.jl
@@ -100,7 +100,7 @@ end
 function bcast!(backend::AbstractLuxDistributedBackend, sendbuf, recvbuf; root::Int=0)
     send_dev = get_device(sendbuf)
     recv_dev = get_device(recvbuf)
-    if send_dev === recv_dev
+    if send_dev == recv_dev
         return __bcast!(backend, sendbuf, recvbuf, send_dev; root)
     else
         sendbuf_ = sendbuf |> recv_dev
@@ -134,7 +134,7 @@ function allreduce!(
         backend::AbstractLuxDistributedBackend, sendbuf, recvbuf, op::F) where {F}
     send_dev = get_device(sendbuf)
     recv_dev = get_device(recvbuf)
-    if send_dev === recv_dev
+    if send_dev == recv_dev
         return __allreduce!(backend, sendbuf, recvbuf, op, send_dev)
     else
         sendbuf_ = sendbuf |> recv_dev
@@ -167,7 +167,7 @@ function reduce!(backend::AbstractLuxDistributedBackend, sendbuf, recvbuf,
         op::F; root::Int=0) where {F}
     send_dev = get_device(sendbuf)
     recv_dev = get_device(recvbuf)
-    if send_dev === recv_dev
+    if send_dev == recv_dev
         return __reduce!(backend, sendbuf, recvbuf, op, send_dev; root)
     else
         sendbuf_ = sendbuf |> recv_dev
diff --git a/src/helpers/nested_ad.jl b/src/helpers/nested_ad.jl
index d0bb26e397..7720d2ef64 100644
--- a/src/helpers/nested_ad.jl
+++ b/src/helpers/nested_ad.jl
@@ -160,7 +160,7 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(__internal_ad_jac
 
     # FIXME: threading on CUDA cause unexpected errors on the first run to CUDNN
     # when doing a algorithm lookup
-    ∂x, ∂y = if get_device(x) isa LuxCPUDevice
+    ∂x, ∂y = if get_device_type(x) <: LuxCPUDevice
         tasks = map(i -> Threads.@spawn(map_fn(i)), 1:__numrows(Δ))
         mapreduce(fetch, recursive_add!!, tasks)
     else
diff --git a/src/layers/extension.jl b/src/layers/extension.jl
index 4d5c53d838..c0a4ebd74b 100644
--- a/src/layers/extension.jl
+++ b/src/layers/extension.jl
@@ -122,12 +122,12 @@ end
     y = match_eltype(de, ps, st, x)
     return (
         __apply_dynamic_expression(
-            de, de.expression, de.operator_enum, y, ps.params, get_device(x)),
+            de, de.expression, de.operator_enum, y, ps.params, get_device_type(x)),
         st)
 end
 
 function __apply_dynamic_expression(
-        de::DynamicExpressionsLayer, expr, operator_enum, x, ps, ::LuxCPUDevice)
+        de::DynamicExpressionsLayer, expr, operator_enum, x, ps, ::Type{LuxCPUDevice})
     __update_expression_constants!(expr, ps)
     return expr(x, operator_enum; de.turbo, de.bumper)
 end
@@ -135,7 +135,7 @@ end
 function __apply_dynamic_expression_rrule end
 
 function CRC.rrule(::typeof(__apply_dynamic_expression), de::DynamicExpressionsLayer,
-        expr, operator_enum, x, ps, ::LuxCPUDevice)
+        expr, operator_enum, x, ps, ::Type{LuxCPUDevice})
     if !_is_extension_loaded(Val(:DynamicExpressions))
         error("`DynamicExpressions.jl` is not loaded. Please load it before using \
                computing gradient for `DynamicExpressionLayer`.")
@@ -143,9 +143,9 @@ function CRC.rrule(::typeof(__apply_dynamic_expression), de::DynamicExpressionsL
     return __apply_dynamic_expression_rrule(de, expr, operator_enum, x, ps)
 end
 
-function __apply_dynamic_expression(de, expr, operator_enum, x, ps, dev)
+function __apply_dynamic_expression(de, expr, operator_enum, x, ps, ::Type{DEV}) where {DEV}
     throw(ArgumentError("`DynamicExpressions.jl` only supports CPU operations. Current \
-                         device detected as $(dev). CUDA.jl will be supported after \
+                         device detected as $(DEV). CUDA.jl will be supported after \
                          https://github.com/SymbolicML/DynamicExpressions.jl/pull/65 is \
                          merged upstream."))
 end
@@ -243,19 +243,21 @@ end
 
 @inline function (sc::SimpleChainsLayer{false})(x, ps, st)
     y = match_eltype(sc, ps, st, x)
-    return __apply_simple_chain(sc.layer, y, ps.params, get_device(x)), st
+    return __apply_simple_chain(sc.layer, y, ps.params, get_device_type(x)), st
 end
 
 @inline function (sc::SimpleChainsLayer{true})(x, ps, st)
     y = match_eltype(sc, ps, st, x)
-    return convert(Array, __apply_simple_chain(sc.layer, y, ps.params, get_device(x))), st
+    return (
+        convert(Array, __apply_simple_chain(sc.layer, y, ps.params, get_device_type(x))),
+        st)
 end
 
-@inline __apply_simple_chain(layer, x, ps, ::LuxCPUDevice) = layer(x, ps)
+@inline __apply_simple_chain(layer, x, ps, ::Type{LuxCPUDevice}) = layer(x, ps)
 
-function __apply_simple_chain(layer, x, ps, dev)
+function __apply_simple_chain(layer, x, ps, ::Type{DEV}) where {DEV}
     throw(ArgumentError("`SimpleChains.jl` only supports CPU operations. Current device \
-                         detected as $(dev)."))
+                         detected as $(DEV)."))
 end
 
 # Workaround for SimpleChains not being able to handle some input types
diff --git a/src/transform/flux.jl b/src/transform/flux.jl
index 737f7a63e6..33ae4c0110 100644
--- a/src/transform/flux.jl
+++ b/src/transform/flux.jl
@@ -61,10 +61,10 @@ Base.@deprecate transform(l; preserve_ps_st::Bool=false, force_preserve::Bool=fa
     FromFluxAdaptor(preserve_ps_st, force_preserve), l)
 
 @inline function _maybe_flip_conv_weight(x::AbstractArray)
-    return _maybe_flip_conv_weight(x, get_device(x))
+    return _maybe_flip_conv_weight(x, get_device_type(x))
 end
 @inline _maybe_flip_conv_weight(x::AbstractArray, _) = copy(x)
-@inline function _maybe_flip_conv_weight(x::AbstractArray, ::LuxAMDGPUDevice)
+@inline function _maybe_flip_conv_weight(x::AbstractArray, ::Type{<:LuxAMDGPUDevice})
     # This is a very rare operation, hence we dont mind allowing scalar operations
     return @allowscalar reverse(x; dims=ntuple(identity, ndims(x) - 2))
 end
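
Note (illustrative, not part of the diff): every call site above switches from dispatching
on a device instance returned by `get_device` to dispatching on the device *type* returned
by `get_device_type` (hence the `::Type{LuxCPUDevice}` / `::Type{<:LuxAMDGPUDevice}`
signatures), which avoids constructing a device object just to branch on it. A minimal
sketch of the pattern, assuming LuxDeviceUtils >= 0.1.25; the helper `__describe_device`
exists only for this example:

    using LuxDeviceUtils: LuxCPUDevice, get_device_type

    # Dispatch on the device *type* (`::Type{<:LuxCPUDevice}`), not on a device instance.
    __describe_device(::Type{<:LuxCPUDevice}) = "running on the CPU"
    __describe_device(::Type{D}) where {D} = "running on $(D)"

    x = rand(Float32, 3, 4)
    println(__describe_device(get_device_type(x)))  # plain Arrays hit the CPU method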