Skip to content

Commit

Permalink
refactor: remove Requires and move everything to extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 16, 2024
1 parent 9a1d6c0 commit bb36104
Show file tree
Hide file tree
Showing 11 changed files with 27 additions and 19 deletions.
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@ KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
NNPACK_jll = "a6bfbf70-4841-5cb9-aa18-3a8ad3c413ee"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[extensions]
Expand All @@ -27,6 +28,8 @@ NNlibCUDACUDNNExt = ["CUDA", "cuDNN"]
NNlibCUDAExt = "CUDA"
NNlibEnzymeCoreExt = "EnzymeCore"
NNlibFFTWExt = "FFTW"
NNlibForwardDiffExt = "ForwardDiff"
NNlibNNPAC_jllExt = "NNPACK_jll"

[compat]
AMDGPU = "0.9.4"
Expand All @@ -41,7 +44,6 @@ KernelAbstractions = "0.9.2"
LinearAlgebra = "1.10"
Pkg = "1.10"
Random = "1.10"
Requires = "1.0"
Statistics = "1"
cuDNN = "1"
julia = "1.10"
9 changes: 9 additions & 0 deletions ext/NNlibForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
module NNlibForwardDiffExt

using ForwardDiff: ForwardDiff
using NNlib: NNlib

NNlib.within_gradient(x::ForwardDiff.Dual) = true
NNlib.within_gradient(x::AbstractArray{<:ForwardDiff.Dual}) = true

end
File renamed without changes.
14 changes: 14 additions & 0 deletions ext/NNlibNNPACK_jllExt/NNlibNNPACK_jllExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
module NNlibNNPACK_jllExt

using NNPACK_jll, Pkg

if isdefined(NNPACK_jll, :libnnpack)
include("NNPACK.jl")
else
@warn "NNPACK not available for your platform: " *
"$( Pkg.BinaryPlatforms.platform_name(Pkg.BinaryPlatforms.platform_key_abi()))" *
"($( Pkg.BinaryPlatforms.triplet(Pkg.BinaryPlatforms.platform_key_abi())))
You will be able to use only the default Julia NNlib backend"
end

end
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
17 changes: 0 additions & 17 deletions src/NNlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ using LinearAlgebra.BLAS: @blasfunc, BlasInt
using LinearAlgebra: AdjOrTransAbsMat, Adjoint, BlasFloat, Transpose
using Pkg
using Random
using Requires
using Statistics
using Statistics: mean

Expand All @@ -26,17 +25,6 @@ export ConvDims, DenseConvDims, PoolDims, DepthwiseConvDims

is_nnpack_available() = false

@init @require NNPACK_jll="a6bfbf70-4841-5cb9-aa18-3a8ad3c413ee" begin
if isdefined(NNPACK_jll, :libnnpack)
include("nnpack/NNPACK.jl")
else
@warn "NNPACK not available for your platform: " *
"$( Pkg.BinaryPlatforms.platform_name(Pkg.BinaryPlatforms.platform_key_abi()))" *
"($( Pkg.BinaryPlatforms.triplet(Pkg.BinaryPlatforms.platform_key_abi())))
You will be able to use only the default Julia NNlib backend"
end
end

include("activations.jl")
for f in ACTIVATIONS
@eval export $(f)
Expand Down Expand Up @@ -95,11 +83,6 @@ export upsample_nearest, ∇upsample_nearest,
include("gather.jl")
include("scatter.jl")
include("utils.jl")
@init @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin
using .ForwardDiff
within_gradient(x::ForwardDiff.Dual) = true
within_gradient(x::AbstractArray{<:ForwardDiff.Dual}) = true
end

include("sampling.jl")
include("functions.jl")
Expand Down

0 comments on commit bb36104

Please sign in to comment.