Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve type stability of cached walks #82

Merged
merged 3 commits into from
Nov 4, 2024

Conversation

chengchingwen
Copy link
Member

This PR adds a special cache type that allows the compiler to use the signature of the un-cached walk to generate corresponding type assertion to the untyped cache (IdDict{Any, Any}). This would improve the type stability of fmapand friends. It also looses the constraint of the cache type so functionality outside fmap remains the same.

@CarloLucibello
Copy link
Member

This adds some complexity to the code and some fragility as well, since it seems it could break with newer julia versions.
Can you post some benchmarks showing performance improvements?

@chengchingwen
Copy link
Member Author

Not a benchmark, but without this PR:

julia> @code_warntype gpu(Chain(Dense(3, 5), Dense(5, 2)))
MethodInstance for Flux.gpu(::Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})
  from gpu(x) @ Flux ~/.julia/packages/Flux/Wz6D4/src/functor.jl:248
Arguments
  #self#::Core.Const(Flux.gpu)
  x::Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}
Body::Chain{T} where T<:Tuple{Any, Any}
1%1 = Flux.FluxCUDAAdaptor()::Core.Const(Flux.FluxCUDAAdaptor(nothing))
│   %2 = Flux.gpu(%1, x)::Chain{T} where T<:Tuple{Any, Any}
└──      return %2

v.s. with:

julia> @code_warntype gpu(Chain(Dense(3, 5), Dense(5, 2)))
MethodInstance for Flux.gpu(::Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity
), Matrix{Float32}, Vector{Float32}}}})
  from gpu(x) @ Flux ~/.julia/packages/Flux/Wz6D4/src/functor.jl:248
Arguments
  #self#::Core.Const(Flux.gpu)
  x::Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vecto
r{Float32}}}}
Body::Union{Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.D
eviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuff
er}}}}, Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}
1%1 = Flux.FluxCUDAAdaptor()::Core.Const(Flux.FluxCUDAAdaptor(nothing))
│   %2 = Flux.gpu(%1, x)::Union{Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Fl
oat32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1,
 CUDA.Mem.DeviceBuffer}}}}, Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}
└──      return %2

@CarloLucibello
Copy link
Member

@darsnack @ToucheSir what do you think? I'm unfamiliar with expression manipulations.

@darsnack
Copy link
Member

I am also concerned about fragility. The implementation itself is sensible, but as written seems like it will need to get updated for internal changes often. The core idea is to use the return type of the walk to force the type when accessing the cache, right? That seems like a very straight-forward generated function to write with the call to return_type being the only brittle bit. Or is the rest of the current implementation necessary for performance reasons? Accessing the IdDict is the main reason Functors is type unstable, so fixing it is nice.

Pulling back, is there a use-case where we lack a function barrier between the call to gpu and the hot code path?

@chengchingwen
Copy link
Member Author

The core idea is to use the return type of the walk to force the type when accessing the cache, right? That seems like a very straight-forward generated function to write with the call to return_type being the only brittle bit. Or is the rest of the current implementation necessary for performance reasons?

Yes, essentially the whole generated function is just to generate return cache.cache[x]::(return_type(cache.walk, typeof(args))). It also seems to be doable without generated function, but with the generated function we can get the precise world-age (though I'm not familiar enough with the world-age mechanism to know if the precise world-age is required in this use-case).

Pulling back, is there a use-case where we lack a function barrier between the call to gpu and the hot code path?

if you need to handle data movement during the forward/backward pass.

@CarloLucibello
Copy link
Member

given the concerns expressed in LuxDL/Lux.jl#1017 I think we should do this.

@CarloLucibello CarloLucibello merged commit 2945731 into FluxML:master Nov 4, 2024
11 of 12 checks passed
@chengchingwen
Copy link
Member Author

@CarloLucibello Since Julia v1.10 is the new LTS, do you think we could drop v1.6 support so that we can remove that @static if VERSION >= v"1.10.0-DEV.609" branch which makes the code look fragile?

@CarloLucibello
Copy link
Member

yes, we should do that.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants