Skip to content

Commit

Permalink
Add rnumber/rarray (#2075)
Browse files Browse the repository at this point in the history
* Add rnumber/rarray

* fix

* Update compiler.jl
  • Loading branch information
wsmoses authored Nov 8, 2024
1 parent ae171bf commit 31df08b
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ EnzymeStaticArraysExt = "StaticArrays"
BFloat16s = "0.2, 0.3, 0.4, 0.5"
CEnum = "0.4, 0.5"
ChainRulesCore = "1"
EnzymeCore = "0.8.4, 0.8.5"
EnzymeCore = "0.8.6"
Enzyme_jll = "0.0.163"
GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 1"
LLVM = "6.1, 7, 8, 9"
Expand Down
2 changes: 1 addition & 1 deletion lib/EnzymeCore/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "EnzymeCore"
uuid = "f151be2c-9106-41f4-ab19-57ee4f262869"
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>"]
version = "0.8.5"
version = "0.8.6"

[compat]
Adapt = "3, 4"
Expand Down
56 changes: 56 additions & 0 deletions lib/EnzymeCore/src/EnzymeCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -592,4 +592,60 @@ Return a new mode with its [`ABI`](@ref) set to the chosen type.
"""
function set_abi end


"""
Primitive Type usable within Reactant. See Reactant.jl for more information.
"""
@static if isdefined(Core, :BFloat16)
const ReactantPrimitive = Union{
Bool,
Int8,
UInt8,
Int16,
UInt16,
Int32,
UInt32,
Int64,
UInt64,
Float16,
Core.BFloat16,
Float32,
Float64,
Complex{Float32},
Complex{Float64},
}
else
const ReactantPrimitive = Union{
Bool,
Int8,
UInt8,
Int16,
UInt16,
Int32,
UInt32,
Int64,
UInt64,
Float16,
Float32,
Float64,
Complex{Float32},
Complex{Float64},
}
end

"""
Abstract Reactant Array type. See Reactant.jl for more information
"""
abstract type RArray{T<:ReactantPrimitive,N} <: AbstractArray{T,N} end
@inline Base.eltype(::RArray{T}) where T = T
@inline Base.eltype(::Type{<:RArray{T}}) where T = T

"""
Abstract Reactant Number type. See Reactant.jl for more information
"""
abstract type RNumber{T<:ReactantPrimitive} <: Number end
@inline Base.eltype(::RNumber{T}) where T = T
@inline Base.eltype(::Type{<:RNumber{T}}) where T = T


end # module EnzymeCore
56 changes: 51 additions & 5 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ const known_ops = Dict{DataType,Tuple{Symbol,Int,Union{Nothing,Tuple{Symbol,Data
end

const nofreefns = Set{String}((
"ClientGetDevice",
"BufferOnCPU",
"pcre2_match_8",
"julia.gcroot_flush",
"pcre2_jit_stack_assign_8",
Expand Down Expand Up @@ -311,6 +313,8 @@ const nofreefns = Set{String}((
))

const inactivefns = Set{String}((
"ClientGetDevice",
"BufferOnCPU",
"pcre2_match_data_create_from_pattern_8",
"ijl_typeassert",
"jl_typeassert",
Expand Down Expand Up @@ -517,6 +521,8 @@ end
end
end

@inline numbereltype(::Type{<:EnzymeCore.RNumber{T}}) where {T} = T
@inline ptreltype(::Type{<:EnzymeCore.RArray{T}}) where {T} = T
@inline ptreltype(::Type{Ptr{T}}) where {T} = T
@inline ptreltype(::Type{Core.LLVMPtr{T,N}}) where {T,N} = T
@inline ptreltype(::Type{Core.LLVMPtr{T} where N}) where {T} = T
Expand Down Expand Up @@ -644,10 +650,21 @@ end
return ActiveState
end

if T <: EnzymeCore.RNumber
return active_reg_inner(
numbereltype(T),
seen,
world,
Val(justActive),
Val(UnionSret),
Val(AbstractIsMixed),
)
end

if T <: Ptr ||
T <: Core.LLVMPtr ||
T <: Base.RefValue ||
T <: Array ||
T <: Array || T <: EnzymeCore.RArray
is_arrayorvararg_ty(T)
if justActive
return AnyState
Expand Down Expand Up @@ -762,7 +779,7 @@ end
end

# if abstract it must be by reference
if Base.isabstracttype(T)
if Base.isabstracttype(T) || T == Tuple
if AbstractIsMixed
return MixedState
else
Expand All @@ -779,11 +796,11 @@ end
end

@assert !Base.isabstracttype(T)
if !(Base.isconcretetype(T) || (T <: Tuple && T != Tuple) || T isa UnionAll)
if !(Base.isconcretetype(T) || T <: Tuple || T isa UnionAll)
throw(AssertionError("Type $T is not concrete type or concrete tuple"))
end

nT = if T <: Tuple && T != Tuple && !(T isa UnionAll)
nT = if T <: Tuple && !(T isa UnionAll)
Tuple{(
ntuple(length(T.parameters)) do i
Base.@_inline_meta
Expand Down Expand Up @@ -1108,7 +1125,7 @@ struct Return2
end

function force_recompute!(mod::LLVM.Module)
for f in functions(mod), bb in blocks(f), inst in instructions(bb)
for f in functions(mod), bb in blocks(f), inst in collect(instructions(bb))
if isa(inst, LLVM.LoadInst)
has_loaded = false
for u in LLVM.uses(inst)
Expand Down Expand Up @@ -1137,8 +1154,24 @@ function force_recompute!(mod::LLVM.Module)
metadata(inst)["enzyme_nocache"] = MDNode(LLVM.Metadata[])
end
end
if isa(inst, LLVM.CallInst)
cf = LLVM.called_operand(inst)
if isa(cf, LLVM.Function)
if LLVM.name(cf) == "llvm.julia.gc_preserve_begin"
has_use = false
for u2 in LLVM.uses(inst)
has_use = true
break
end
if !has_use
eraseInst(bb, inst)
end
end
end
end
end
end

function permit_inlining!(f::LLVM.Function)
for bb in blocks(f), inst in instructions(bb)
# remove illegal invariant.load and jtbaa_const invariants
Expand Down Expand Up @@ -3294,13 +3327,26 @@ function annotate!(mod, mode)
end
end

for fname in (
"UnsafeBufferPointer",
)
if haskey(funcs, fname)
for fn in funcs[fname]
if LLVM.version().major <= 15
push!(function_attributes(fn), LLVM.StringAttribute("enzyme_math", "__dynamic_cast"))
end
end
end
end

for fname in (
"jl_f_getfield",
"ijl_f_getfield",
"jl_get_nth_field_checked",
"ijl_get_nth_field_checked",
"jl_f__svec_ref",
"ijl_f__svec_ref",
"UnsafeBufferPointer"
)
if haskey(funcs, fname)
for fn in funcs[fname]
Expand Down

0 comments on commit 31df08b

Please sign in to comment.