Skip to content

Commit

Permalink
Fix reinterpret(T, ::Array) performance
Browse files Browse the repository at this point in the history
When I originally wrote the new ReinterpretArray code, I made sure that LLVM was able
to optimize reinterpret(::Array) back to a single memory access with appropriate TBAA
and alignment info. Somewhere along the line LLVM lost that ability. While we should
try to recover that capability in LLVM, this showed that it is a relatively brittle
optimization for a very simple operation. So this patch takes a different approach:
We add two new intrinsics `tbaa_pointerref` and `tbaa_pointerset` that behave like
their non-TBAA variants, but additionally take a type to use as the TBAA tag. This
allows us to write a special case for `reinterpret(T, ::Array)` that directly emits
the correct pointer access. It's also a model for what a post-1.0 pure Julia
implementation of `Array` (e.g. on top of a buffer type) may look like.

Fixes #25014
  • Loading branch information
Keno committed May 23, 2018
1 parent dfa6e4b commit a3d5e35
Show file tree
Hide file tree
Showing 10 changed files with 186 additions and 33 deletions.
4 changes: 4 additions & 0 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -812,6 +812,8 @@ end
function is_pure_intrinsic_infer(f::IntrinsicFunction)
return !(f === Intrinsics.pointerref || # this one is volatile
f === Intrinsics.pointerset || # this one is never effect-free
f === Intrinsics.tbaa_pointerref || # same as pointerref
f === Intrinsics.tbaa_pointerset || # same as pointerset
f === Intrinsics.llvmcall || # this one is never effect-free
f === Intrinsics.arraylen || # this one is volatile
f === Intrinsics.sqrt_llvm || # this one may differ at runtime (by a few ulps)
Expand All @@ -822,6 +824,8 @@ end
function is_pure_intrinsic_optim(f::IntrinsicFunction)
return !(f === Intrinsics.pointerref || # this one is volatile
f === Intrinsics.pointerset || # this one is never effect-free
f === Intrinsics.tbaa_pointerref || # same as pointerref
f === Intrinsics.tbaa_pointerset || # same as pointerset
f === Intrinsics.llvmcall || # this one is never effect-free
f === Intrinsics.arraylen || # this one is volatile
f === Intrinsics.checked_sdiv_int || # these may throw errors
Expand Down
34 changes: 19 additions & 15 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -317,22 +317,26 @@ add_tfunc(Core._expr, 1, INT_INF, (args...)->Expr, 100)
add_tfunc(applicable, 1, INT_INF, (@nospecialize(f), args...)->Bool, 100)
add_tfunc(Core.Intrinsics.arraylen, 1, 1, x->Int, 4)
add_tfunc(arraysize, 2, 2, (@nospecialize(a), @nospecialize(d))->Int, 4)
add_tfunc(pointerref, 3, 3,
function (@nospecialize(a), @nospecialize(i), @nospecialize(align))
a = widenconst(a)
if a <: Ptr
if isa(a,DataType) && isa(a.parameters[1],Type)
return a.parameters[1]
elseif isa(a,UnionAll) && !has_free_typevars(a)
unw = unwrap_unionall(a)
if isa(unw,DataType)
return rewrap_unionall(unw.parameters[1], a)
end
end
end
return Any
end, 4)
# Return-type function for pointer loads: given the lattice element `a` of the pointer
# argument, produce the element type `T` of a `Ptr{T}` when it can be determined, and
# `Any` otherwise. The index `i` and alignment `align` do not influence the result.
function pointerref_tfunc(@nospecialize(a), @nospecialize(i), @nospecialize(align))
    ptrtype = widenconst(a)
    ptrtype <: Ptr || return Any
    if isa(ptrtype, DataType)
        # Fully applied `Ptr{T}`: the answer is `T`, provided it is an actual type.
        elty = ptrtype.parameters[1]
        isa(elty, Type) && return elty
    elseif isa(ptrtype, UnionAll) && !has_free_typevars(ptrtype)
        # `Ptr{T} where ...`: take the parameter of the body and re-wrap it in the
        # same `where` clauses.
        body = unwrap_unionall(ptrtype)
        isa(body, DataType) && return rewrap_unionall(body.parameters[1], ptrtype)
    end
    return Any
end
# `pointerref(ptr, i, align)` and its TBAA-tagged variant share one return-type function;
# the extra leading tag-type argument `t` of `tbaa_pointerref` is ignored for inference.
add_tfunc(pointerref, 3, 3, pointerref_tfunc, 4)
add_tfunc(tbaa_pointerref, 4, 4, function (@nospecialize(t), @nospecialize(a), @nospecialize(i), @nospecialize(align))
pointerref_tfunc(a, i, align)
end, 4)
# Both pointer stores are inferred as whatever was inferred for the pointer argument `a`;
# again the leading tag-type argument `t` of the TBAA variant is ignored.
add_tfunc(pointerset, 4, 4, (@nospecialize(a), @nospecialize(v), @nospecialize(i), @nospecialize(align)) -> a, 5)
add_tfunc(tbaa_pointerset, 5, 5, (@nospecialize(t), @nospecialize(a), @nospecialize(v), @nospecialize(i), @nospecialize(align)) -> a, 5)

function typeof_tfunc(@nospecialize(t))
if isa(t, Const)
Expand Down
33 changes: 33 additions & 0 deletions base/reinterpretarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,39 @@ end
return a
end

# Special case for Array-backed ReinterpretArrays
#
# `reinterpret_alignment(RA)` gives the alignment (in bytes) passed to the pointer
# intrinsics when accessing elements of type `T` of a `ReinterpretArray{T,N,S}` type `RA`.
# For an `Array` parent the alignment of `T` itself is used.
reinterpret_alignment(::Type{ReinterpretArray{T,N,S,A}} where {N, S, A<:Array}) where {T} = datatype_alignment(T)
# For any other parent only the alignment common to both element types is used.
# `T` and `S` have to be bound as static parameters of the method: with a bare
# `Type{<:ReinterpretArray}` signature they are undefined names in the method body.
reinterpret_alignment(::Type{<:ReinterpretArray{T,N,S}}) where {T,N,S} =
    gcd(datatype_alignment(T), datatype_alignment(S))

# Fast path for reading from an `Array`-backed `ReinterpretArray`: for element types the
# pointer intrinsics can handle, load the value straight out of the parent's memory with
# one TBAA-tagged pointer access (tagged with the parent array type `A`). Every other
# element type is forwarded to the generic `ReinterpretArray` method.
@inline @propagate_inbounds function getindex(a::ReinterpretArray{T,N,S,A}, inds::Vararg{Int, N}) where {T,N,S,A<:Array}
    # `T` is a type, so the "not an array type" condition must be a subtype test;
    # `isa(T, Array)` would be false for every `T`.
    if isbits(T) || (isstructtype(T) && !(T <: Array) && isconcretetype(T))
        check_readable(a)
        # The raw pointer load below does no checking of its own, so validate the
        # indices here (elided when the caller used `@inbounds`).
        @boundscheck checkbounds(a, inds...)
        pa = parent(a)
        @GC.preserve pa begin
            # Address of the first parent element of the slice selected by the trailing
            # indices; `inds[1]` then counts elements of type `T` from that address.
            ptr = pointer(pa, LinearIndices(pa)[1, tail(inds)...])
            return Intrinsics.tbaa_pointerref(A, Ptr{T}(ptr), inds[1], reinterpret_alignment(typeof(a)))
        end
    end
    # Fall back to generic method
    return invoke(getindex, Tuple{ReinterpretArray{T,N,S}, Vararg{Int, N}}, a, inds...)
end

# Fast path for writing into an `Array`-backed `ReinterpretArray`: for element types the
# pointer intrinsics can handle, convert `v` to `T` and store it straight into the parent's
# memory with one TBAA-tagged pointer access (tagged with the parent array type `A`).
# Every other element type is forwarded to the generic `ReinterpretArray` method.
# Returns `a`.
@inline @propagate_inbounds function setindex!(a::ReinterpretArray{T,N,S,A}, v, inds::Vararg{Int, N}) where {T,N,S,A<:Array}
    # `T` is a type, so the "not an array type" condition must be a subtype test;
    # `isa(T, Array)` would be false for every `T`.
    if isbits(T) || (isstructtype(T) && !(T <: Array) && isconcretetype(T))
        check_writable(a)
        # The raw pointer store below does no checking of its own, so validate the
        # indices here (elided when the caller used `@inbounds`).
        @boundscheck checkbounds(a, inds...)
        v = convert(T, v)::T
        pa = parent(a)
        @GC.preserve pa begin
            # Address of the first parent element of the slice selected by the trailing
            # indices; `inds[1]` then counts elements of type `T` from that address.
            ptr = pointer(pa, LinearIndices(pa)[1, tail(inds)...])
            Intrinsics.tbaa_pointerset(A, Ptr{T}(ptr), v, inds[1], reinterpret_alignment(typeof(a)))
            return a
        end
    end
    # Fall back to generic method
    return invoke(setindex!, Tuple{ReinterpretArray{T,N,S}, Any, Vararg{Int, N}}, a, v, inds...)
end


# Padding
struct Padding
offset::Int
Expand Down
2 changes: 1 addition & 1 deletion src/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ $(BUILDDIR)/debuginfo.o $(BUILDDIR)/debuginfo.dbg.obj: \
$(addprefix $(SRCDIR)/,debuginfo.h processor.h)
$(BUILDDIR)/disasm.o $(BUILDDIR)/disasm.dbg.obj: $(SRCDIR)/debuginfo.h $(SRCDIR)/processor.h
$(BUILDDIR)/jitlayers.o $(BUILDDIR)/jitlayers.dbg.obj: $(SRCDIR)/jitlayers.h
$(BUILDDIR)/builtins.o $(BUILDDIR)/builtins.dbg.obj: $(SRCDIR)/table.c
$(BUILDDIR)/builtins.o $(BUILDDIR)/builtins.dbg.obj: $(SRCDIR)/table.c $(SRCDIR)/intrinsics.h
$(BUILDDIR)/staticdata.o $(BUILDDIR)/staticdata.dbg.obj: $(SRCDIR)/processor.h
$(BUILDDIR)/gc.o $(BUILDDIR)/gc.dbg.obj: $(SRCDIR)/gc.h
$(BUILDDIR)/gc-debug.o $(BUILDDIR)/gc-debug.dbg.obj: $(SRCDIR)/gc.h
Expand Down
20 changes: 13 additions & 7 deletions src/cgutils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1329,7 +1329,8 @@ static Value *data_pointer(jl_codectx_t &ctx, const jl_cgval_t &x)
}

static void emit_memcpy_llvm(jl_codectx_t &ctx, Value *dst, Value *src,
uint64_t sz, unsigned align, bool is_volatile, MDNode *tbaa)
uint64_t sz, unsigned align, bool is_volatile,
MDNode *dst_tbaa, MDNode *src_tbaa)
{
if (sz == 0)
return;
Expand Down Expand Up @@ -1361,21 +1362,26 @@ static void emit_memcpy_llvm(jl_codectx_t &ctx, Value *dst, Value *src,
src = emit_bitcast(ctx, src, dstty);
}
if (direct) {
auto val = tbaa_decorate(tbaa, ctx.builder.CreateAlignedLoad(src, align, is_volatile));
tbaa_decorate(tbaa, ctx.builder.CreateAlignedStore(val, dst, align, is_volatile));
auto val = tbaa_decorate(src_tbaa, ctx.builder.CreateAlignedLoad(src, align, is_volatile));
tbaa_decorate(dst_tbaa, ctx.builder.CreateAlignedStore(val, dst, align, is_volatile));
return;
}
}
// At the moment LLVM's memcpy only allows a single TBAA annotation
MDNode *tbaa = dst_tbaa == src_tbaa ? src_tbaa : NULL;
ctx.builder.CreateMemCpy(dst, src, sz, align, is_volatile, tbaa);
}

static void emit_memcpy_llvm(jl_codectx_t &ctx, Value *dst, Value *src,
Value *sz, unsigned align, bool is_volatile, MDNode *tbaa)
Value *sz, unsigned align, bool is_volatile,
MDNode *dst_tbaa, MDNode *src_tbaa)
{
if (auto const_sz = dyn_cast<ConstantInt>(sz)) {
emit_memcpy_llvm(ctx, dst, src, const_sz->getZExtValue(), align, is_volatile, tbaa);
emit_memcpy_llvm(ctx, dst, src, const_sz->getZExtValue(), align, is_volatile, dst_tbaa, src_tbaa);
return;
}
// At the moment LLVM's memcpy only allows a single TBAA annotation
MDNode *tbaa = dst_tbaa == src_tbaa ? src_tbaa : NULL;
ctx.builder.CreateMemCpy(dst, src, sz, align, is_volatile, tbaa);
}

Expand All @@ -1391,10 +1397,10 @@ static Value *get_value_ptr(jl_codectx_t &ctx, const jl_cgval_t &v)

template<typename T1, typename T2, typename T3>
static void emit_memcpy(jl_codectx_t &ctx, T1 &&dst, T2 &&src, T3 &&sz, unsigned align,
bool is_volatile=false, MDNode *tbaa=nullptr)
bool is_volatile=false, MDNode *dst_tbaa=nullptr, MDNode *src_tbaa=nullptr)
{
emit_memcpy_llvm(ctx, get_value_ptr(ctx, dst), get_value_ptr(ctx, src), sz, align,
is_volatile, tbaa);
is_volatile, dst_tbaa, src_tbaa);
}

static bool emit_getfield_unknownidx(jl_codectx_t &ctx,
Expand Down
89 changes: 81 additions & 8 deletions src/intrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,13 @@ static void jl_init_intrinsic_functions_codegen(Module *m)
args4.push_back(T_prjlvalue); \
args4.push_back(T_prjlvalue); \
args4.push_back(T_prjlvalue); \
args4.push_back(T_prjlvalue);
args4.push_back(T_prjlvalue); \
std::vector<Type *> args5(0); \
args5.push_back(T_prjlvalue); \
args5.push_back(T_prjlvalue); \
args5.push_back(T_prjlvalue); \
args5.push_back(T_prjlvalue); \
args5.push_back(T_prjlvalue);

#define ADD_I(name, nargs) do { \
Function *func = Function::Create(FunctionType::get(T_prjlvalue, args##nargs, false), \
Expand Down Expand Up @@ -598,7 +604,12 @@ static jl_cgval_t emit_runtime_pointerref(jl_codectx_t &ctx, jl_cgval_t *argv)
return emit_runtime_call(ctx, pointerref, argv, 3);
}

static jl_cgval_t emit_pointerref(jl_codectx_t &ctx, jl_cgval_t *argv)
// Fallback used when `tbaa_pointerref` cannot be lowered inline: hands all four
// arguments in `argv` to emit_runtime_call for the `tbaa_pointerref` intrinsic.
static jl_cgval_t emit_runtime_tbaa_pointerref(jl_codectx_t &ctx, jl_cgval_t *argv)
{
return emit_runtime_call(ctx, tbaa_pointerref, argv, 4);
}

static jl_cgval_t emit_pointerref_internal(jl_codectx_t &ctx, jl_cgval_t *argv, MDNode *src_tbaa)
{
const jl_cgval_t &e = argv[0];
const jl_cgval_t &i = argv[1];
Expand Down Expand Up @@ -643,25 +654,56 @@ static jl_cgval_t emit_pointerref(jl_codectx_t &ctx, jl_cgval_t *argv)
LLT_ALIGN(size, jl_datatype_align(ety))));
Value *thePtr = emit_unbox(ctx, T_pint8, e, e.typ);
thePtr = ctx.builder.CreateGEP(T_int8, emit_bitcast(ctx, thePtr, T_pint8), im1);
emit_memcpy(ctx, strct, thePtr, size, 1);
emit_memcpy(ctx, strct, thePtr, size, align_nb, /* volatile= */ 0, /* dst_tbaa */ NULL, src_tbaa);
return mark_julia_type(ctx, strct, true, ety);
}
else {
bool isboxed;
Type *ptrty = julia_type_to_llvm(ety, &isboxed);
assert(!isboxed);
Value *thePtr = emit_unbox(ctx, ptrty->getPointerTo(), e, e.typ);
return typed_load(ctx, thePtr, im1, ety, tbaa_data, true, align_nb);
return typed_load(ctx, thePtr, im1, ety, src_tbaa, true, align_nb);
}
}

// Plain `pointerref`: same lowering as the TBAA-tagged variant, but the load is
// always decorated with the generic `tbaa_data` tag.
static jl_cgval_t emit_pointerref(jl_codectx_t &ctx, jl_cgval_t *argv)
{
return emit_pointerref_internal(ctx, argv, tbaa_data);
}

// Lower `tbaa_pointerref(tag, ptr, i, align)`. argv[0] is the tag type; the remaining
// entries are the ordinary `pointerref` arguments. Whenever the tag is not a
// compile-time constant of a supported shape, the runtime intrinsic is called instead.
static jl_cgval_t emit_tbaa_pointerref(jl_codectx_t &ctx, jl_cgval_t *argv)
{
    jl_value_t *tag = argv[0].constant;

    // For now, we only allow array types with concrete element types
    bool supported = tag != NULL && jl_is_type(tag) && jl_is_array_type(tag);
    if (supported) {
        jl_value_t *elty = jl_tparam0(tag);
        supported = (jl_isbits(elty) || jl_is_structtype(elty)) &&
                    !jl_is_array_type(elty) && jl_is_concrete_type(elty);
    }
    if (!supported)
        return emit_runtime_tbaa_pointerref(ctx, argv);

    // TODO: Once we have stronger tbaa, pick the correct type based on the tag here.
    return emit_pointerref_internal(ctx, argv + 1, tbaa_arraybuf);
}

// Fallback used when `pointerset` cannot be lowered inline: hands all four
// arguments in `argv` to emit_runtime_call for the `pointerset` intrinsic.
static jl_cgval_t emit_runtime_pointerset(jl_codectx_t &ctx, jl_cgval_t *argv)
{
return emit_runtime_call(ctx, pointerset, argv, 4);
}

// Fallback used when `tbaa_pointerset` cannot be lowered inline: hands all five
// arguments in `argv` (tag type, pointer, value, index, alignment) to emit_runtime_call
// for the `tbaa_pointerset` intrinsic. Calling the 4-argument `pointerset` here would
// pass the tag type in the pointer position.
static jl_cgval_t emit_runtime_tbaa_pointerset(jl_codectx_t &ctx, jl_cgval_t *argv)
{
    return emit_runtime_call(ctx, tbaa_pointerset, argv, 5);
}

// e[i] = x
static jl_cgval_t emit_pointerset(jl_codectx_t &ctx, jl_cgval_t *argv)
static jl_cgval_t emit_pointerset_internal(jl_codectx_t &ctx, jl_cgval_t *argv, MDNode *dest_tbaa)
{
const jl_cgval_t &e = argv[0];
const jl_cgval_t &x = argv[1];
Expand Down Expand Up @@ -696,7 +738,7 @@ static jl_cgval_t emit_pointerset(jl_codectx_t &ctx, jl_cgval_t *argv)
Instruction *store = ctx.builder.CreateAlignedStore(
emit_pointer_from_objref(ctx, boxed(ctx, x)),
ctx.builder.CreateGEP(T_size, thePtr, im1), align_nb);
tbaa_decorate(tbaa_data, store);
tbaa_decorate(dest_tbaa, store);
}
else if (!jl_isbits(ety)) {
if (!jl_is_structtype(ety) || jl_is_array_type(ety) || !jl_is_concrete_type(ety)) {
Expand All @@ -707,18 +749,45 @@ static jl_cgval_t emit_pointerset(jl_codectx_t &ctx, jl_cgval_t *argv)
uint64_t size = jl_datatype_size(ety);
im1 = ctx.builder.CreateMul(im1, ConstantInt::get(T_size,
LLT_ALIGN(size, jl_datatype_align(ety))));
emit_memcpy(ctx, ctx.builder.CreateGEP(T_int8, thePtr, im1), x, size, align_nb);
emit_memcpy(ctx, ctx.builder.CreateGEP(T_int8, thePtr, im1), x, size, align_nb,
/* volatile= */0, dest_tbaa);
}
else {
bool isboxed;
Type *ptrty = julia_type_to_llvm(ety, &isboxed);
assert(!isboxed);
thePtr = emit_unbox(ctx, ptrty->getPointerTo(), e, e.typ);
typed_store(ctx, thePtr, im1, x, ety, tbaa_data, NULL, align_nb);
typed_store(ctx, thePtr, im1, x, ety, dest_tbaa, NULL, align_nb);
}
return mark_julia_type(ctx, thePtr, false, aty);
}

// Plain `pointerset`: same lowering as the TBAA-tagged variant, but the store is
// always decorated with the generic `tbaa_data` tag.
static jl_cgval_t emit_pointerset(jl_codectx_t &ctx, jl_cgval_t *argv)
{
return emit_pointerset_internal(ctx, argv, tbaa_data);
}

// Lower `tbaa_pointerset(tag, ptr, x, i, align)`. argv[0] is the tag type; the remaining
// entries are the ordinary `pointerset` arguments. Whenever the tag is not a
// compile-time constant of a supported shape, the runtime fallback is used instead.
static jl_cgval_t emit_tbaa_pointerset(jl_codectx_t &ctx, jl_cgval_t *argv)
{
    jl_value_t *tag = argv[0].constant;

    // For now, we only allow array types with concrete element types
    bool supported = tag != NULL && jl_is_type(tag) && jl_is_array_type(tag);
    if (supported) {
        jl_value_t *elty = jl_tparam0(tag);
        supported = (jl_isbits(elty) || jl_is_structtype(elty)) &&
                    !jl_is_array_type(elty) && jl_is_concrete_type(elty);
    }
    if (!supported)
        return emit_runtime_tbaa_pointerset(ctx, argv);

    // TODO: Once we have stronger tbaa, pick the correct type based on the tag here.
    return emit_pointerset_internal(ctx, argv + 1, tbaa_arraybuf);
}

static Value *emit_checked_srem_int(jl_codectx_t &ctx, Value *x, Value *den)
{
Type *t = den->getType();
Expand Down Expand Up @@ -939,6 +1008,10 @@ static jl_cgval_t emit_intrinsic(jl_codectx_t &ctx, intrinsic f, jl_value_t **ar
return emit_pointerref(ctx, argv);
case pointerset:
return emit_pointerset(ctx, argv);
case tbaa_pointerref:
return emit_tbaa_pointerref(ctx, argv);
case tbaa_pointerset:
return emit_tbaa_pointerset(ctx, argv);
case bitcast:
return generic_bitcast(ctx, argv);
case trunc_int:
Expand Down
2 changes: 2 additions & 0 deletions src/intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@
/* pointer access */ \
ADD_I(pointerref, 3) \
ADD_I(pointerset, 4) \
ADD_I(tbaa_pointerref, 4) \
ADD_I(tbaa_pointerset, 5) \
/* c interface */ \
ADD_I(cglobal, 2) \
ALIAS(llvmcall, llvmcall) \
Expand Down
2 changes: 2 additions & 0 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,8 @@ unsigned jl_intrinsic_nargs(int f);
JL_DLLEXPORT jl_value_t *jl_bitcast(jl_value_t *ty, jl_value_t *v);
// Runtime entry points for the pointer intrinsics. Parameter names follow the argument
// order used by the definitions and the intrinsic call sites: (..., i, align).
JL_DLLEXPORT jl_value_t *jl_pointerref(jl_value_t *p, jl_value_t *i, jl_value_t *align);
JL_DLLEXPORT jl_value_t *jl_pointerset(jl_value_t *p, jl_value_t *x, jl_value_t *i, jl_value_t *align);
JL_DLLEXPORT jl_value_t *jl_tbaa_pointerref(jl_value_t *t, jl_value_t *p, jl_value_t *i, jl_value_t *align);
JL_DLLEXPORT jl_value_t *jl_tbaa_pointerset(jl_value_t *t, jl_value_t *p, jl_value_t *x, jl_value_t *i, jl_value_t *align);
JL_DLLEXPORT jl_value_t *jl_cglobal(jl_value_t *v, jl_value_t *ty);
JL_DLLEXPORT jl_value_t *jl_cglobal_auto(jl_value_t *v);

Expand Down
27 changes: 27 additions & 0 deletions src/runtime_intrinsics.c
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,33 @@ JL_DLLEXPORT jl_value_t *jl_pointerset(jl_value_t *p, jl_value_t *x, jl_value_t
return p;
}

// Validate the tag type handed to tbaa_pointerref / tbaa_pointerset.
// Raises a Julia error (via jl_error) unless `t` is an array type whose element type is
// isbits or a struct type, is not itself an array type, and is concrete.
static void check_tbaa_type(jl_value_t *t)
{
    // For now, we only allow array types with concrete element types
    if (jl_is_array_type(t)) {
        jl_value_t *elty = jl_tparam0(t);
        int elty_ok = (jl_isbits(elty) || jl_is_structtype(elty)) &&
                      !jl_is_array_type(elty) && jl_is_concrete_type(elty);
        if (elty_ok)
            return;
        jl_error("tbaa_pointer(set/ref): TBAA array element type must be isbits or a structtype"
                 ", not an array and concrete");
    }
    jl_error("tbaa_pointer(set/ref): Type argument must be an array type");
}

// Runtime implementation of the `tbaa_pointerref` intrinsic: type-checks and validates
// the tag `t`, then performs an ordinary jl_pointerref on the remaining arguments.
// The tag has no further effect here.
JL_DLLEXPORT jl_value_t *jl_tbaa_pointerref(jl_value_t *t, jl_value_t *p, jl_value_t *i, jl_value_t *align)
{
JL_TYPECHK(tbaa_pointerref, type, t);
check_tbaa_type(t);
return jl_pointerref(p, i, align);
}

// Runtime implementation of the `tbaa_pointerset` intrinsic: type-checks and validates
// the tag `t`, then performs an ordinary jl_pointerset on the remaining arguments.
// The tag has no further effect here.
JL_DLLEXPORT jl_value_t *jl_tbaa_pointerset(jl_value_t *t, jl_value_t *p, jl_value_t *x, jl_value_t *i, jl_value_t *align)
{
    JL_TYPECHK(tbaa_pointerset, type, t);
    check_tbaa_type(t);
    // jl_pointerset takes the pointer first and the value second.
    return jl_pointerset(p, x, i, align);
}

JL_DLLEXPORT jl_value_t *jl_cglobal(jl_value_t *v, jl_value_t *ty)
{
JL_TYPECHK(cglobal, type, ty);
Expand Down
6 changes: 4 additions & 2 deletions test/reinterpretarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ B = Complex{Int64}[5+6im, 7+8im, 9+10im]
@test reinterpret(NTuple{3, Int64}, B) == [(5,6,7),(8,9,10)]

# setindex
let Ac = copy(A), Bc = copy(B)
for (Ac, Bc) in zip((copy(A), GenericArray(copy(A))),
(copy(B), GenericArray(copy(B))))
reinterpret(Complex{Int64}, Ac)[2] = -1 - 2im
@test Ac == [1, 2, -1, -2]
reinterpret(NTuple{3, Int64}, Bc)[2] = (4,5,6)
Expand All @@ -26,7 +27,8 @@ let Ac = copy(A), Bc = copy(B)
end

# same-size reinterpret where one of the types is non-primitive
let a = NTuple{4,UInt8}[(0x01,0x02,0x03,0x04)]
for a = (NTuple{4,UInt8}[(0x01,0x02,0x03,0x04)],
GenericArray(NTuple{4,UInt8}[(0x01,0x02,0x03,0x04)]))
@test reinterpret(Float32, a)[1] == reinterpret(Float32, 0x04030201)
reinterpret(Float32, a)[1] = 2.0
@test reinterpret(Float32, a)[1] == 2.0
Expand Down

0 comments on commit a3d5e35

Please sign in to comment.