Skip to content

Commit

Permalink
Fix reinterpret performance
Browse files Browse the repository at this point in the history
This fixes #25014 by making it more obvious what's going on to LLVM.
Instead of a memcpy loop, we use a ccall to :memcpy and turn this into
llvm.memcpy at the IR level, which is enough for LLVM to fold everything
away. In the benchmark from #25014, we still see some regressions from
0.6, but that is because it needs to dereference through the pointers
in the reinterpret and reshape wrappers. In any real code, that
dereferencing should be loop-invariantly moved out of the inner loop.

(cherry picked from commit 777810b)
  • Loading branch information
Keno authored and KristofferC committed Feb 11, 2019
1 parent 0f9d23d commit 5a706f3
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 22 deletions.
34 changes: 13 additions & 21 deletions base/reinterpretarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ end
_getindex_ra(a, inds[1], tail(inds))
end

@inline _memcpy!(dst, src, n) = ccall(:memcpy, Cvoid, (Ptr{UInt8}, Ptr{UInt8}, Csize_t), dst, src, n)

@inline @propagate_inbounds function _getindex_ra(a::ReinterpretArray{T,N,S}, i1::Int, tailinds::TT) where {T,N,S,TT}
# Make sure to match the scalar reinterpret if that is applicable
if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0
Expand All @@ -123,11 +125,9 @@ end
# once it knows the data layout
while nbytes_copied < sizeof(T)
s[] = a.parent[ind_start + i, tailinds...]
while nbytes_copied < sizeof(T) && sidx < sizeof(S)
unsafe_store!(tptr, unsafe_load(sptr, sidx + 1), nbytes_copied + 1)
sidx += 1
nbytes_copied += 1
end
nb = min(sizeof(S) - sidx, sizeof(T)-nbytes_copied)
_memcpy!(tptr + nbytes_copied, sptr + sidx, nb)
nbytes_copied += nb
sidx = 0
i += 1
end
Expand Down Expand Up @@ -173,34 +173,26 @@ end
# element from the original array and overwrite the relevant parts
if sidx != 0
s[] = a.parent[ind_start + i, tailinds...]
while nbytes_copied < sizeof(T) && sidx < sizeof(S)
unsafe_store!(sptr, unsafe_load(tptr, nbytes_copied + 1), sidx + 1)
sidx += 1
nbytes_copied += 1
end
nb = min(sizeof(S) - sidx, sizeof(T))
_memcpy!(sptr + sidx, tptr, nb)
nbytes_copied += nb
a.parent[ind_start + i, tailinds...] = s[]
i += 1
sidx = 0
end
# Deal with the main body of elements
while nbytes_copied < sizeof(T) && (sizeof(T) - nbytes_copied) > sizeof(S)
while nbytes_copied < sizeof(T) && sidx < sizeof(S)
unsafe_store!(sptr, unsafe_load(tptr, nbytes_copied + 1), sidx + 1)
sidx += 1
nbytes_copied += 1
end
nb = min(sizeof(S), sizeof(T) - nbytes_copied)
_memcpy!(sptr, tptr + nbytes_copied, nb)
nbytes_copied += nb
a.parent[ind_start + i, tailinds...] = s[]
i += 1
sidx = 0
end
# Deal with trailing partial elements
if nbytes_copied < sizeof(T)
s[] = a.parent[ind_start + i, tailinds...]
while nbytes_copied < sizeof(T) && sidx < sizeof(S)
unsafe_store!(sptr, unsafe_load(tptr, nbytes_copied + 1), sidx + 1)
sidx += 1
nbytes_copied += 1
end
nb = min(sizeof(S), sizeof(T) - nbytes_copied)
_memcpy!(sptr, tptr + nbytes_copied, nb)
a.parent[ind_start + i, tailinds...] = s[]
end
end
Expand Down
2 changes: 1 addition & 1 deletion base/reshapedarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ end
I = ind2sub_rs(axes(A.parent), A.mi, i)
_unsafe_getindex_rs(parent(A), I)
end
_unsafe_getindex_rs(A, i::Integer) = (@inbounds ret = A[i]; ret)
@inline _unsafe_getindex_rs(A, i::Integer) = (@inbounds ret = A[i]; ret)
@inline _unsafe_getindex_rs(A, I) = (@inbounds ret = A[I...]; ret)

@inline function setindex!(A::ReshapedArrayLF, val, index::Int)
Expand Down
14 changes: 14 additions & 0 deletions src/ccall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1820,6 +1820,20 @@ static jl_cgval_t emit_ccall(jl_codectx_t &ctx, jl_value_t **args, size_t nargs)
JL_GC_POP();
return mark_or_box_ccall_result(ctx, strp, retboxed, rt, unionall, static_rt);
}
else if (is_libjulia_func(memcpy)) {
const jl_cgval_t &dst = argv[0];
const jl_cgval_t &src = argv[1];
const jl_cgval_t &n = argv[2];
ctx.builder.CreateMemCpy(
ctx.builder.CreateIntToPtr(
emit_unbox(ctx, T_size, dst, (jl_value_t*)jl_voidpointer_type), T_pint8),
ctx.builder.CreateIntToPtr(
emit_unbox(ctx, T_size, src, (jl_value_t*)jl_voidpointer_type), T_pint8),
emit_unbox(ctx, T_size, n, (jl_value_t*)jl_ulong_type), 1,
false);
JL_GC_POP();
return ghostValue(jl_void_type);
}

jl_cgval_t retval = sig.emit_a_ccall(
ctx,
Expand Down

0 comments on commit 5a706f3

Please sign in to comment.