Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Make mutating immutables easier #21912

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ export
Expr, GotoNode, LabelNode, LineNumberNode, QuoteNode,
GlobalRef, NewvarNode, SSAValue, Slot, SlotNumber, TypedSlot,
# object model functions
fieldtype, getfield, setfield!, nfields, throw, tuple, ===, isdefined, eval,
fieldtype, getfield, setfield!, nfields, throw, tuple, ===,
isdefined, eval,
# sizeof # not exported, to avoid conflicting with Base.sizeof
# type reflection
issubtype, typeof, isa, typeassert,
Expand Down
5 changes: 5 additions & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -978,6 +978,11 @@ export
catch_stacktrace,

# types
gepfield,
gepindex,
setfield,
setindex,
@setfield,
convert,
fieldoffset,
fieldname,
Expand Down
2 changes: 2 additions & 0 deletions base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,8 @@ function fieldindex(T::DataType, name::Symbol, err::Bool=true)
return Int(ccall(:jl_field_index, Cint, (Any, Any, Cint), T, name, err)+1)
end

fieldisptr(T::DataType, idx::Integer) = 1 == ccall(:jl_get_field_isptr, Cint, (Any, Cint), T, idx)

type_alignment(x::DataType) = (@_pure_meta; ccall(:jl_get_alignment, Csize_t, (Any,), x))

# return all instances, for types that can be enumerated
Expand Down
105 changes: 105 additions & 0 deletions base/refpointer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ Ref(x, i::Integer) = (i != 1 && error("Object only has one element"); Ref(x))
Ref{T}() where {T} = RefValue{T}() # Ref{T}()
Ref{T}(x) where {T} = RefValue{T}(x) # Ref{T}(x)
convert(::Type{Ref{T}}, x) where {T} = RefValue{T}(x)
copy(x::RefValue{T}) where {T} = RefValue{T}(x.x)

function unsafe_convert(P::Type{Ptr{T}}, b::RefValue{T}) where T
if isbits(T)
Expand Down Expand Up @@ -113,6 +114,110 @@ cconvert(::Type{Ref{P}}, a::Array{<:Ptr}) where {P<:Ptr} = a
cconvert(::Type{Ptr{P}}, a::Array) where {P<:Union{Ptr,Cwstring,Cstring}} = Ref{P}(a)
cconvert(::Type{Ref{P}}, a::Array) where {P<:Union{Ptr,Cwstring,Cstring}} = Ref{P}(a)


## RefField
struct RefField{T} <: Ref{T}
base
offset::UInt
# Basic constructors
global gepfield
function gepfield(x::ANY, idx::Integer)
typeof(x).mutable || error("Tried to take reference to immutable type $(typeof(x))")
new{fieldtype(typeof(x), idx)}(x, fieldoffset(typeof(x), idx))
end
function gepfield(x::RefField{T}, idx::Integer) where {T}
!fieldisptr(T, idx) || error("Can only take interior references that are inline (e.g. immutable). Tried to access field \"$(fieldname(T, idx))\" of type $T")
new{fieldtype(T, idx)}(x.base, x.offset + fieldoffset(T, idx))
end
end

function gepfield(x::ANY, sym::Symbol)
gepfield(x, Base.fieldindex(typeof(x), sym))
end
function gepfield(x::RefField{T}, sym::Symbol) where T
gepfield(x, Base.fieldindex(T, sym))
end
gepindex(x::Ref) = gepfield(x, 1)

# Tuple is defined before us in bootstrap, so it can't refer to RefField
gepindex(x::RefField{<:Tuple}, idx) = gepfield(x, idx)

function setindex!(x::RefField{T}, v::T) where T
unsafe_store!(Ptr{T}(pointer_from_objref(x.base)+x.offset), v)
v
end

function getindex(x::RefField{T}, v::T) where T
unsafe_load(Ptr{T}(pointer_from_objref(x.base)+x.offset), v)
end

function setfield(x, sym, v)
if typeof(x).mutable
y = copy(x)
setfield!(y, sym, v)
y
else
y = Ref{typeof(x)}(x)
(gepfield(y@[], sym))[] = v
y[]
end
end

function setindex(x, v, idxs...)
if typeof(x).mutable
y = copy(x)
setindex!(y, v, idxs...)
y
else
y = Ref{typeof(x)}(x)
(gepindex(y@[], idxs...))[] = v
y[]
end
end

macro setfield(base, idx)
if idx.head != :(=)
error("Expected assignment as second argument")
end
idx, rhs = idx.args
x, v = ntuple(i->gensym(), 2)
setupexprs = Expr[:($x = $base), :($v = $rhs)]
getfieldexprs, setfieldexprs = Expr[], Expr[]
res = nothing
while true
y, nv = ntuple(i->gensym(), 2)
if isa(idx, Symbol)
idx = Expr(:quote, idx)
(res !== nothing) && push!(getfieldexprs, :($res = getfield($x, $idx)))
push!(setfieldexprs, :($nv = setfield($x, $idx, $v)))
break
elseif idx.head == :vect
t = gensym()
(res !== nothing) && push!(getfieldexprs, :($res = getindex($x, $t...)))
push!(getfieldexprs, :($t = tuple($(idx.args...))))
push!(setfieldexprs, :($nv = setindex($x, $v, $t...)))
break
elseif idx.head == :(.)
(res !== nothing) && push!(getfieldexprs, :($res = getfield($y, $(idx.args[2]))))
push!(setfieldexprs, :($nv = setfield($y, $(idx.args[2]), $v)))
res, v, idx = y, nv, idx.args[1]
elseif idx.head == :ref
t = gensym()
(res !== nothing) && push!(getfieldexprs, :($res = getindex($y, $t...)))
push!(getfieldexprs, :($t = tuple($(idx.args[2:end]...))))
push!(setfieldexprs, :($nv = setindex($y, $v, $t...)))
res, v, idx = y, nv, idx.args[1]
else
error("Unknown field ref syntax")
end
end
push!(getfieldexprs, :($y = $x))
esc(Expr(:block,
setupexprs...,
reverse(getfieldexprs)...,
setfieldexprs...))
end

###

getindex(b::RefValue) = b.x
Expand Down
10 changes: 3 additions & 7 deletions base/tuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,10 @@ getindex(t::Tuple, i::Int) = getfield(t, i)
getindex(t::Tuple, i::Real) = getfield(t, convert(Int, i))
getindex(t::Tuple, r::AbstractArray{<:Any,1}) = ([t[ri] for ri in r]...)
getindex(t::Tuple, b::AbstractArray{Bool,1}) = length(b) == length(t) ? getindex(t,find(b)) : throw(BoundsError(t, b))
gepindex(x::Tuple, i) = gepfield(x, i)

# returns new tuple; N.B.: becomes no-op if i is out-of-bounds
setindex(x::Tuple, v, i::Integer) = _setindex((), x, v, i::Integer)
function _setindex(y::Tuple, r::Tuple, v, i::Integer)
@_inline_meta
_setindex((y..., ifelse(length(y) + 1 == i, v, first(r))), tail(r), v, i)
end
_setindex(y::Tuple, r::Tuple{}, v, i::Integer) = y
# returns new tuple
setindex(x::Tuple, v, i::Integer) = setfield(x, i, v)

## iterating ##

Expand Down
1 change: 1 addition & 0 deletions src/builtin_proto.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ DECLARE_BUILTIN(_apply_latest);
DECLARE_BUILTIN(isdefined); DECLARE_BUILTIN(nfields);
DECLARE_BUILTIN(tuple); DECLARE_BUILTIN(svec);
DECLARE_BUILTIN(getfield); DECLARE_BUILTIN(setfield);
DECLARE_BUILTIN(setfield_bang);
DECLARE_BUILTIN(fieldtype); DECLARE_BUILTIN(arrayref);
DECLARE_BUILTIN(arrayset); DECLARE_BUILTIN(arraysize);
DECLARE_BUILTIN(apply_type); DECLARE_BUILTIN(applicable);
Expand Down
33 changes: 31 additions & 2 deletions src/builtins.c
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ JL_CALLABLE(jl_f_getfield)
return fval;
}

JL_CALLABLE(jl_f_setfield)
JL_CALLABLE(jl_f_setfield_bang)
{
JL_NARGS(setfield!, 3, 3);
jl_value_t *v = args[0];
Expand Down Expand Up @@ -650,6 +650,34 @@ JL_CALLABLE(jl_f_setfield)
return args[2];
}

JL_CALLABLE(jl_f_setfield)
{
JL_NARGS(setfield, 3, 3);
jl_value_t *v = args[0];
jl_value_t *vt = (jl_value_t*)jl_typeof(v);
if (vt == (jl_value_t*)jl_module_type || !jl_is_datatype(vt))
jl_type_error("setfield", (jl_value_t*)jl_datatype_type, v);
jl_datatype_t *st = (jl_datatype_t*)vt;
size_t idx;
if (jl_is_long(args[1])) {
idx = jl_unbox_long(args[1])-1;
if (idx >= jl_datatype_nfields(st))
jl_bounds_error(args[0], args[1]);
}
else {
JL_TYPECHK(setfield!, symbol, args[1]);
idx = jl_field_index(st, (jl_sym_t*)args[1], 1);
}
jl_value_t *ft = jl_field_type(st,idx);
if (!jl_isa(args[2], ft)) {
jl_type_error("setfield", ft, args[2]);
}
jl_value_t *newv = jl_new_struct_uninit(vt);
memcpy(newv, v, st->size);
jl_set_nth_field(newv, idx, args[2]);
return newv;
}

static jl_value_t *get_fieldtype(jl_value_t *t, jl_value_t *f)
{
if (jl_is_unionall(t)) {
Expand Down Expand Up @@ -1077,7 +1105,8 @@ void jl_init_primitives(void)

// field access
add_builtin_func("getfield", jl_f_getfield);
add_builtin_func("setfield!", jl_f_setfield);
add_builtin_func("setfield", jl_f_setfield);
add_builtin_func("setfield!", jl_f_setfield_bang);
add_builtin_func("fieldtype", jl_f_fieldtype);
add_builtin_func("nfields", jl_f_nfields);
add_builtin_func("isdefined", jl_f_isdefined);
Expand Down
Loading