Skip to content

Commit

Permalink
Fix reinterpret performnace
Browse files Browse the repository at this point in the history
This fixes #25014 by making it more obvious what's going on to LLVM.
Instead of a memcpy loop, we use a new intrinsic that puts an actual
llvm.memcpy into the IR, which is enough for LLVM to fold everything
away. In the benchmark from #25014, we still see some regressions from
0.6, but that is because it needs to dereference through the pointers
in the reinterpret and reshape wrappers. In any real code, that
dereferencing should be loop-invariantly moved out of the inner loop.
  • Loading branch information
Keno committed Aug 16, 2018
1 parent 2715fb2 commit 0dbc329
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 22 deletions.
1 change: 1 addition & 0 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ add_tfunc(pointerref, 3, 3,
return Any
end, 4)
add_tfunc(pointerset, 4, 4, (@nospecialize(a), @nospecialize(v), @nospecialize(i), @nospecialize(align)) -> a, 5)
add_tfunc(aligned_memcpy, 4, 4, (@nospecialize(dst), @nospecialize(src), @nospecialize(n), @nospecialize(align)) -> Nothing, 5)

function typeof_tfunc(@nospecialize(t))
if isa(t, Const)
Expand Down
32 changes: 11 additions & 21 deletions base/reinterpretarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,9 @@ end
# once it knows the data layout
while nbytes_copied < sizeof(T)
s[] = a.parent[ind_start + i, tailinds...]
while nbytes_copied < sizeof(T) && sidx < sizeof(S)
unsafe_store!(tptr, unsafe_load(sptr, sidx + 1), nbytes_copied + 1)
sidx += 1
nbytes_copied += 1
end
nb = min(sizeof(S) - sidx, sizeof(T)-nbytes_copied)
Intrinsics.aligned_memcpy(tptr + nbytes_copied, sptr + sidx, nb, 1)
nbytes_copied += nb
sidx = 0
i += 1
end
Expand Down Expand Up @@ -173,34 +171,26 @@ end
# element from the original array and overwrite the relevant parts
if sidx != 0
s[] = a.parent[ind_start + i, tailinds...]
while nbytes_copied < sizeof(T) && sidx < sizeof(S)
unsafe_store!(sptr, unsafe_load(tptr, nbytes_copied + 1), sidx + 1)
sidx += 1
nbytes_copied += 1
end
nb = min(sizeof(S) - sidx, sizeof(T))
Intrinsics.aligned_memcpy(sptr + sidx, tptr, nb, 1)
nbytes_copied += nb
a.parent[ind_start + i, tailinds...] = s[]
i += 1
sidx = 0
end
# Deal with the main body of elements
while nbytes_copied < sizeof(T) && (sizeof(T) - nbytes_copied) > sizeof(S)
while nbytes_copied < sizeof(T) && sidx < sizeof(S)
unsafe_store!(sptr, unsafe_load(tptr, nbytes_copied + 1), sidx + 1)
sidx += 1
nbytes_copied += 1
end
nb = min(sizeof(S), sizeof(T) - nbytes_copied)
Intrinsics.aligned_memcpy(sptr, tptr + nbytes_copied, nb, 1)
nbytes_copied += nb
a.parent[ind_start + i, tailinds...] = s[]
i += 1
sidx = 0
end
# Deal with trailing partial elements
if nbytes_copied < sizeof(T)
s[] = a.parent[ind_start + i, tailinds...]
while nbytes_copied < sizeof(T) && sidx < sizeof(S)
unsafe_store!(sptr, unsafe_load(tptr, nbytes_copied + 1), sidx + 1)
sidx += 1
nbytes_copied += 1
end
nb = min(sizeof(S), sizeof(T) - nbytes_copied)
Intrinsics.aligned_memcpy(sptr, tptr + nbytes_copied, nb, 1)
a.parent[ind_start + i, tailinds...] = s[]
end
end
Expand Down
2 changes: 1 addition & 1 deletion base/reshapedarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ end
I = ind2sub_rs(axes(A.parent), A.mi, i)
_unsafe_getindex_rs(parent(A), I)
end
_unsafe_getindex_rs(A, i::Integer) = (@inbounds ret = A[i]; ret)
@inline _unsafe_getindex_rs(A, i::Integer) = (@inbounds ret = A[i]; ret)
@inline _unsafe_getindex_rs(A, I) = (@inbounds ret = A[I...]; ret)

@inline function setindex!(A::ReshapedArrayLF, val, index::Int)
Expand Down
30 changes: 30 additions & 0 deletions src/intrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,33 @@ static jl_cgval_t emit_pointerset(jl_codectx_t &ctx, jl_cgval_t *argv)
return mark_julia_type(ctx, thePtr, false, aty);
}

static void emit_aligned_memcpy(jl_codectx_t &ctx, jl_cgval_t *argv)
{
const jl_cgval_t &dst = argv[0];
const jl_cgval_t &src = argv[1];
const jl_cgval_t &n = argv[2];
const jl_cgval_t &align = argv[3];

if (n.typ != (jl_value_t*)jl_long_type ||
align.typ != (jl_value_t*)jl_long_type ||
!jl_is_cpointer_type(dst.typ) ||
!jl_is_cpointer_type(src.typ)) {
emit_runtime_call(ctx, aligned_memcpy, argv, 4);
return;
}

size_t alignment = 1;
if (align.constant)
alignment = jl_unbox_long(align.constant);

ctx.builder.CreateMemCpy(
ctx.builder.CreateIntToPtr(dst.V, T_pint8),
ctx.builder.CreateIntToPtr(src.V, T_pint8),
n.V, alignment,
false);

}

static Value *emit_checked_srem_int(jl_codectx_t &ctx, Value *x, Value *den)
{
Type *t = den->getType();
Expand Down Expand Up @@ -937,6 +964,9 @@ static jl_cgval_t emit_intrinsic(jl_codectx_t &ctx, intrinsic f, jl_value_t **ar
return emit_pointerref(ctx, argv);
case pointerset:
return emit_pointerset(ctx, argv);
case aligned_memcpy:
emit_aligned_memcpy(ctx, argv);
return mark_julia_type(ctx, NULL, false, jl_void_type);
case bitcast:
return generic_bitcast(ctx, argv);
case trunc_int:
Expand Down
1 change: 1 addition & 0 deletions src/intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
/* pointer access */ \
ADD_I(pointerref, 3) \
ADD_I(pointerset, 4) \
ADD_I(aligned_memcpy, 4) \
/* c interface */ \
ADD_I(cglobal, 2) \
ALIAS(llvmcall, llvmcall) \
Expand Down
1 change: 1 addition & 0 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,7 @@ unsigned jl_intrinsic_nargs(int f);
JL_DLLEXPORT jl_value_t *jl_bitcast(jl_value_t *ty, jl_value_t *v);
JL_DLLEXPORT jl_value_t *jl_pointerref(jl_value_t *p, jl_value_t *i, jl_value_t *align);
JL_DLLEXPORT jl_value_t *jl_pointerset(jl_value_t *p, jl_value_t *x, jl_value_t *align, jl_value_t *i);
JL_DLLEXPORT jl_value_t *jl_aligned_memcpy(jl_value_t *dst, jl_value_t *src, jl_value_t *n, jl_value_t *align);
JL_DLLEXPORT jl_value_t *jl_cglobal(jl_value_t *v, jl_value_t *ty);
JL_DLLEXPORT jl_value_t *jl_cglobal_auto(jl_value_t *v);

Expand Down
12 changes: 12 additions & 0 deletions src/runtime_intrinsics.c
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,18 @@ JL_DLLEXPORT jl_value_t *jl_pointerset(jl_value_t *p, jl_value_t *x, jl_value_t
return p;
}

// run time version of memcpy intrinsic
JL_DLLEXPORT jl_value_t *jl_aligned_memcpy(jl_value_t *dst, jl_value_t *src, jl_value_t *n, jl_value_t *align)
{
JL_TYPECHK(pointerref, pointer, dst);
JL_TYPECHK(pointerref, pointer, src)
JL_TYPECHK(pointerref, long, n);
JL_TYPECHK(pointerref, long, align);
memcpy((uint8_t*)jl_unbox_long(dst), (uint8_t*)jl_unbox_long(src), jl_unbox_long(n));
return jl_nothing;
}


JL_DLLEXPORT jl_value_t *jl_cglobal(jl_value_t *v, jl_value_t *ty)
{
JL_TYPECHK(cglobal, type, ty);
Expand Down

0 comments on commit 0dbc329

Please sign in to comment.