refactor: use the faster get_device_type
avik-pal committed Jul 13, 2024
1 parent 6a9ef65 commit a431979
Showing 7 changed files with 12 additions and 10 deletions.
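Every hunk below applies the same idea: ask for the device type with `get_device_type` instead of materializing a device instance with `get_device`. A minimal sketch of the difference, assuming the LuxDeviceUtils behavior implied by the hunks (a plain `Array` maps to `LuxCPUDevice`):

```julia
using LuxDeviceUtils: get_device, get_device_type, LuxCPUDevice

x = rand(Float32, 4, 4)

get_device(x)       # LuxCPUDevice() -- a device instance
get_device_type(x)  # LuxCPUDevice   -- the device type

# Old instance check vs. new type check: both answer the same question, but the
# type-level form is what the commit title calls "the faster get_device_type",
# since it can usually be resolved from typeof(x) alone.
get_device(x) isa LuxCPUDevice      # true (old style)
get_device_type(x) <: LuxCPUDevice  # true (new style)
```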
Project.toml: 2 changes (1 addition, 1 deletion)
@@ -92,7 +92,7 @@ LinearAlgebra = "1.10"
 Logging = "1.10"
 LossFunctions = "0.11.1"
 LuxCore = "0.1.16"
-LuxDeviceUtils = "0.1.22"
+LuxDeviceUtils = "0.1.26"
 LuxLib = "0.3.23"
 LuxTestUtils = "0.1.15"
 MLUtils = "0.4.3"
docs/src/api/Accelerator_Support/LuxDeviceUtils.md: 1 change (1 addition, 0 deletions)
@@ -29,6 +29,7 @@ reset_gpu_device!
 supported_gpu_backends
 default_device_rng
 get_device
+get_device_type
 LuxDeviceUtils.loaded
 LuxDeviceUtils.functional
 ```
ext/LuxForwardDiffExt/LuxForwardDiffExt.jl: 3 changes (2 additions, 1 deletion)
@@ -3,7 +3,8 @@ module LuxForwardDiffExt
 using ArgCheck: @argcheck
 using ADTypes: AutoForwardDiff
 using ChainRulesCore: ChainRulesCore
-using Lux: Lux, get_device
+using Lux: Lux
+using LuxDeviceUtils: get_device
 using FastClosures: @closure
 using ForwardDiff: ForwardDiff
 using Functors: fmap
src/distributed/public_api.jl: 6 changes (3 additions, 3 deletions)
@@ -100,7 +100,7 @@ end
 function bcast!(backend::AbstractLuxDistributedBackend, sendbuf, recvbuf; root::Int=0)
     send_dev = get_device(sendbuf)
     recv_dev = get_device(recvbuf)
-    if send_dev === recv_dev
+    if send_dev == recv_dev
         return __bcast!(backend, sendbuf, recvbuf, send_dev; root)
     else
         sendbuf_ = sendbuf |> recv_dev
@@ -134,7 +134,7 @@ function allreduce!(
         backend::AbstractLuxDistributedBackend, sendbuf, recvbuf, op::F) where {F}
     send_dev = get_device(sendbuf)
     recv_dev = get_device(recvbuf)
-    if send_dev === recv_dev
+    if send_dev == recv_dev
         return __allreduce!(backend, sendbuf, recvbuf, op, send_dev)
     else
         sendbuf_ = sendbuf |> recv_dev
@@ -167,7 +167,7 @@ function reduce!(backend::AbstractLuxDistributedBackend,
         sendbuf, recvbuf, op::F; root::Int=0) where {F}
     send_dev = get_device(sendbuf)
     recv_dev = get_device(recvbuf)
-    if send_dev === recv_dev
+    if send_dev == recv_dev
         return __reduce!(backend, sendbuf, recvbuf, op, send_dev; root)
     else
         sendbuf_ = sendbuf |> recv_dev
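The three distributed helpers also swap the identity check `===` for a value comparison `==` between the send and receive devices. A hedged illustration, using a hypothetical wrapper type (not part of LuxDeviceUtils), of why `==` is the more forgiving check when two device objects are equivalent but not the very same object:

```julia
# Hypothetical stand-in for a device wrapper; not a LuxDeviceUtils type.
struct FakeDevice
    handle::Vector{Int}  # a mutable field makes equal-looking instances non-identical
end
Base.:(==)(a::FakeDevice, b::FakeDevice) = a.handle == b.handle

a = FakeDevice([0])
b = FakeDevice([0])

a === b  # false: distinct objects, so an identity check would take the copy branch
a == b   # true:  same device value, so the buffers can be used as-is
```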
src/helpers/nested_ad.jl: 2 changes (1 addition, 1 deletion)
@@ -160,7 +160,7 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(__internal_ad_jac

     # FIXME: threading on CUDA cause unexpected errors on the first run to CUDNN
     # when doing a algorithm lookup
-    ∂x, ∂y = if get_device(x) isa LuxCPUDevice
+    ∂x, ∂y = if get_device_type(x) <: LuxCPUDevice
         tasks = map(i -> Threads.@spawn(map_fn(i)), 1:__numrows(Δ))
         mapreduce(fetch, recursive_add!!, tasks)
     else
src/layers/extension.jl: 4 changes (2 additions, 2 deletions)
@@ -122,7 +122,7 @@ end
     y = match_eltype(de, ps, st, x)
     return (
         __apply_dynamic_expression(
-            de, de.expression, de.operator_enum, y, ps.params, get_device(x)),
+            de, de.expression, de.operator_enum, y, ps.params, get_device_type(x)),
         st)
 end

@@ -248,7 +248,7 @@ end

 @inline function (sc::SimpleChainsLayer{true})(x, ps, st)
     y = match_eltype(sc, ps, st, x)
-    return convert(Array, __apply_simple_chain(sc.layer, y, ps.params, get_device(x))), st
+    return (convert(Array, __apply_simple_chain(sc.layer, y, ps.params, get_device(x))), st)
 end

 @inline __apply_simple_chain(layer, x, ps, ::LuxCPUDevice) = layer(x, ps)
src/transform/flux.jl: 4 changes (2 additions, 2 deletions)
@@ -61,10 +61,10 @@ Base.@deprecate transform(l; preserve_ps_st::Bool=false, force_preserve::Bool=fa
     FromFluxAdaptor(preserve_ps_st, force_preserve), l)

 @inline function _maybe_flip_conv_weight(x::AbstractArray)
-    return _maybe_flip_conv_weight(x, get_device(x))
+    return _maybe_flip_conv_weight(x, get_device_type(x))
 end
 @inline _maybe_flip_conv_weight(x::AbstractArray, _) = copy(x)
-@inline function _maybe_flip_conv_weight(x::AbstractArray, ::LuxAMDGPUDevice)
+@inline function _maybe_flip_conv_weight(x::AbstractArray, ::Type{<:LuxAMDGPUDevice})
     # This is a very rare operation, hence we dont mind allowing scalar operations
     return @allowscalar reverse(x; dims=ntuple(identity, ndims(x) - 2))
 end
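The layer and Flux-adaptor hunks use the new API in two halves: thread `get_device_type(x)` down as an argument, then specialize methods on `::Type{<:LuxAMDGPUDevice}` rather than on a device instance. A sketch with a hypothetical helper whose signatures mirror `_maybe_flip_conv_weight` above (minus the `@allowscalar` guard):

```julia
using LuxDeviceUtils: get_device_type, LuxAMDGPUDevice

# Hypothetical helper, not part of Lux: dispatch on the device type, not an instance.
maybe_flip(x::AbstractArray) = maybe_flip(x, get_device_type(x))
maybe_flip(x::AbstractArray, ::Type) = copy(x)  # generic fallback (CPU, CUDA, ...)
function maybe_flip(x::AbstractArray, ::Type{<:LuxAMDGPUDevice})
    # AMDGPU-specific path: flip the spatial (kernel) dimensions
    return reverse(x; dims=ntuple(identity, ndims(x) - 2))
end

maybe_flip(rand(Float32, 3, 3, 2, 4))  # a plain Array hits the generic copy method
```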
