Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Name LLVM function arguments #50500

Merged
merged 5 commits into from
Jul 15, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 65 additions & 13 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1736,7 +1736,7 @@ jl_aliasinfo_t jl_aliasinfo_t::fromTBAA(jl_codectx_t &ctx, MDNode *tbaa) {
}

static Type *julia_type_to_llvm(jl_codectx_t &ctx, jl_value_t *jt, bool *isboxed = NULL);
static jl_returninfo_t get_specsig_function(jl_codectx_t &ctx, Module *M, Value *fval, StringRef name, jl_value_t *sig, jl_value_t *jlrettype, bool is_opaque_closure, bool gcstack_arg);
static jl_returninfo_t get_specsig_function(jl_codectx_t &ctx, Module *M, Value *fval, StringRef name, jl_value_t *sig, jl_value_t *jlrettype, bool is_opaque_closure, bool gcstack_arg, BitVector *used_arguments=nullptr, size_t *args_begin=nullptr);
static jl_cgval_t emit_expr(jl_codectx_t &ctx, jl_value_t *expr, ssize_t ssaval = -1);
static Value *global_binding_pointer(jl_codectx_t &ctx, jl_module_t *m, jl_sym_t *s,
jl_binding_t **pbnd, bool assign);
Expand Down Expand Up @@ -2363,16 +2363,16 @@ std::unique_ptr<Module> jl_create_llvm_module(StringRef name, LLVMContext &conte

static void jl_name_jlfunc_args(jl_codegen_params_t &params, Function *F) {
assert(F->arg_size() == 3);
F->getArg(0)->setName("function");
F->getArg(1)->setName("args");
F->getArg(2)->setName("nargs");
F->getArg(0)->setName("function::Core.Function");
F->getArg(1)->setName("args::Any[]");
F->getArg(2)->setName("nargs::UInt32");
}

static void jl_name_jlfuncparams_args(jl_codegen_params_t &params, Function *F) {
assert(F->arg_size() == 4);
F->getArg(0)->setName("function");
F->getArg(1)->setName("args");
F->getArg(2)->setName("nargs");
F->getArg(0)->setName("function::Core.Function");
F->getArg(1)->setName("args::Any[]");
F->getArg(2)->setName("nargs::UInt32");
F->getArg(3)->setName("sparams");
pchintalapudi marked this conversation as resolved.
Show resolved Hide resolved
}

Expand Down Expand Up @@ -4359,7 +4359,7 @@ static jl_cgval_t emit_call_specfun_boxed(jl_codectx_t &ctx, jl_value_t *jlretty
}
jl_aliasinfo_t ai = jl_aliasinfo_t::fromTBAA(ctx, ctx.tbaa().tbaa_const);
theFptr = ai.decorateInst(ctx.builder.CreateAlignedLoad(pfunc, GV, Align(sizeof(void*))));
setName(ctx.emission_context, theFptr, namep);
setName(ctx.emission_context, theFptr, specFunctionObject);
}
else {
theFptr = jl_Module->getOrInsertFunction(specFunctionObject, ctx.types().T_jlfunc).getCallee();
Expand Down Expand Up @@ -6935,10 +6935,11 @@ static Function *gen_invoke_wrapper(jl_method_instance_t *lam, jl_value_t *jlret
return w;
}

static jl_returninfo_t get_specsig_function(jl_codectx_t &ctx, Module *M, Value *fval, StringRef name, jl_value_t *sig, jl_value_t *jlrettype, bool is_opaque_closure, bool gcstack_arg)
static jl_returninfo_t get_specsig_function(jl_codectx_t &ctx, Module *M, Value *fval, StringRef name, jl_value_t *sig, jl_value_t *jlrettype, bool is_opaque_closure, bool gcstack_arg, BitVector *used_arguments, size_t *arg_offset)
{
jl_returninfo_t props = {};
SmallVector<Type*, 8> fsig;
SmallVector<std::string, 4> argnames;
Type *rt = NULL;
Type *srt = NULL;
if (jlrettype == (jl_value_t*)jl_bottom_type) {
Expand All @@ -6956,6 +6957,7 @@ static jl_returninfo_t get_specsig_function(jl_codectx_t &ctx, Module *M, Value
props.cc = jl_returninfo_t::Union;
Type *AT = ArrayType::get(getInt8Ty(ctx.builder.getContext()), props.union_bytes);
fsig.push_back(AT->getPointerTo());
argnames.push_back("union_bytes_return");
Type *pair[] = { ctx.types().T_prjlvalue, getInt8Ty(ctx.builder.getContext()) };
rt = StructType::get(ctx.builder.getContext(), makeArrayRef(pair));
}
Expand All @@ -6980,6 +6982,7 @@ static jl_returninfo_t get_specsig_function(jl_codectx_t &ctx, Module *M, Value
// sret is always passed from alloca
assert(M);
fsig.push_back(rt->getPointerTo(M->getDataLayout().getAllocaAddrSpace()));
argnames.push_back("sret_return");
srt = rt;
rt = getVoidTy(ctx.builder.getContext());
}
Expand Down Expand Up @@ -7018,6 +7021,7 @@ static jl_returninfo_t get_specsig_function(jl_codectx_t &ctx, Module *M, Value
param.addAttribute(Attribute::NoUndef);
attrs.push_back(AttributeSet::get(ctx.builder.getContext(), param));
fsig.push_back(get_returnroots_type(ctx, props.return_roots)->getPointerTo(0));
argnames.push_back("return_roots");
}

if (gcstack_arg){
Expand All @@ -7026,9 +7030,16 @@ static jl_returninfo_t get_specsig_function(jl_codectx_t &ctx, Module *M, Value
param.addAttribute(Attribute::NonNull);
attrs.push_back(AttributeSet::get(ctx.builder.getContext(), param));
fsig.push_back(PointerType::get(JuliaType::get_ppjlvalue_ty(ctx.builder.getContext()), 0));
argnames.push_back("pgcstack_arg");
}

for (size_t i = 0; i < jl_nparams(sig); i++) {
if (arg_offset)
*arg_offset = fsig.size();
size_t nparams = jl_nparams(sig);
if (used_arguments)
used_arguments->resize(nparams);

for (size_t i = 0; i < nparams; i++) {
jl_value_t *jt = jl_tparam(sig, i);
bool isboxed = false;
Type *ty = NULL;
Expand Down Expand Up @@ -7060,6 +7071,8 @@ static jl_returninfo_t get_specsig_function(jl_codectx_t &ctx, Module *M, Value
}
attrs.push_back(AttributeSet::get(ctx.builder.getContext(), param));
fsig.push_back(ty);
if (used_arguments)
used_arguments->set(i);
}

AttributeSet FnAttrs;
Expand Down Expand Up @@ -7089,8 +7102,14 @@ static jl_returninfo_t get_specsig_function(jl_codectx_t &ctx, Module *M, Value
else
fval = emit_bitcast(ctx, fval, ftype->getPointerTo());
}
if (gcstack_arg && isa<Function>(fval))
cast<Function>(fval)->setCallingConv(CallingConv::Swift);
if (auto F = dyn_cast<Function>(fval)) {
if (gcstack_arg)
F->setCallingConv(CallingConv::Swift);
assert(F->arg_size() >= argnames.size());
for (size_t i = 0; i < argnames.size(); i++) {
F->getArg(i)->setName(argnames[i]);
}
}
props.decl = FunctionCallee(ftype, fval);
props.attrs = attributes;
return props;
Expand Down Expand Up @@ -7316,11 +7335,44 @@ static jl_llvm_functions_t
Function *f = NULL;
bool has_sret = false;
if (specsig) { // assumes !va and !needsparams
BitVector used_args;
size_t args_begin;
returninfo = get_specsig_function(ctx, M, NULL, declarations.specFunctionObject, lam->specTypes,
jlrettype, ctx.is_opaque_closure, JL_FEAT_TEST(ctx,gcstack_arg));
jlrettype, ctx.is_opaque_closure, JL_FEAT_TEST(ctx,gcstack_arg), &used_args, &args_begin);
f = cast<Function>(returninfo.decl.getCallee());
has_sret = (returninfo.cc == jl_returninfo_t::SRet || returninfo.cc == jl_returninfo_t::Union);
jl_init_function(f, ctx.emission_context.TargetTriple);
auto arg_typename = [&](size_t i) JL_NOTSAFEPOINT { return jl_symbol_name(((jl_datatype_t*)jl_tparam(lam->specTypes, i))->name->name); };
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not safe to cast this to a datatype.

size_t nreal = 0;
for (size_t i = 0; i < std::min(nreq, static_cast<size_t>(used_args.size())); i++) {
jl_sym_t *argname = slot_symbol(ctx, i);
if (argname == jl_unused_sym)
continue;
if (used_args.test(i)) {
auto &arg = *f->getArg(args_begin++);
nreal++;
auto name = jl_symbol_name(argname);
if (!name[0]) {
arg.setName(StringRef("#") + Twine(nreal) + StringRef("::") + arg_typename(i));
} else {
arg.setName(name + StringRef("::") + arg_typename(i));
}
}
}
if (va && ctx.vaSlot != -1) {
size_t vidx = 0;
for (size_t i = nreq; i < used_args.size(); i++) {
if (used_args.test(i)) {
auto &arg = *f->getArg(args_begin++);
auto type = arg_typename(i);
const char *name = jl_symbol_name(slot_symbol(ctx, ctx.vaSlot));
if (!name[0])
name = "...";
vidx++;
arg.setName(name + StringRef("[") + Twine(vidx) + StringRef("]::") + type);
}
}
}

// common pattern: see if all return statements are an argument in that
// case the apply-generic call can re-use the original box for the return
Expand Down
2 changes: 1 addition & 1 deletion test/compiler/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ if opt_level > 0
load_dummy_ref_ir = get_llvm(load_dummy_ref, Tuple{Int})
@test !occursin("jl_gc_pool_alloc", load_dummy_ref_ir)
# Hopefully this is reliable enough. LLVM should be able to optimize this to a direct return.
@test occursin("ret $Iptr %0", load_dummy_ref_ir)
@test occursin("ret $Iptr %\"x::$(Int)\"", load_dummy_ref_ir)
end

# Issue 22770
Expand Down
26 changes: 18 additions & 8 deletions test/llvmpasses/fastmath.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

# RUN: julia --startup-file=no %s %t && llvm-link -S %t/* -o %t/module.ll
# RUN: julia --startup-file=no %s %t -O && llvm-link -S %t/* -o %t/module.ll
# RUN: cat %t/module.ll | FileCheck %s

## Notes:
Expand All @@ -14,21 +14,31 @@ include(joinpath("..", "testhelpers", "llvmpasses.jl"))

import Base.FastMath

# CHECK: call fast float @llvm.sqrt.f32(float %{{[0-9]+}})
# CHECK: call fast float @llvm.sqrt.f32(float %"x::Float32")
emit(FastMath.sqrt_fast, Float32)


# Float16 operations should be performed as Float32, unless @fastmath is specified
# TODO: this is not true for platforms that natively support Float16

foo(x::T,y::T) where T = x-y == zero(T)
# LOWER: fsub half %0, %1
# FINAL: %2 = fpext half %0 to float
# FINAL: %3 = fpext half %1 to float
# FINAL: fsub half %2, %3
# CHECK: define {{(swiftcc )?}}i8 @julia_foo_{{[0-9]+}}({{.*}}half %[[X:"x::Float16"]], half %[[Y:"y::Float16"]]) {{.*}}{
# CHECK-DAG: %[[XEXT:[0-9]+]] = fpext half %[[X]] to float
# CHECK-DAG: %[[YEXT:[0-9]+]] = fpext half %[[Y]] to float
# CHECK: %[[DIFF:[0-9]+]] = fsub float %[[XEXT]], %[[YEXT]]
# CHECK: %[[TRUNC:[0-9]+]] = fptrunc float %[[DIFF]] to half
# CHECK: %[[DIFFEXT:[0-9]+]] = fpext half %[[TRUNC]] to float
# CHECK: %[[CMP:[0-9]+]] = fcmp oeq float %[[DIFFEXT]], 0.000000e+00
# CHECK: %[[ZEXT:[0-9]+]] = zext i1 %[[CMP]] to i8
# CHECK: ret i8 %[[ZEXT]]
# CHECK: }
emit(foo, Float16, Float16)

@fastmath foo(x::T,y::T) where T = x-y == zero(T)
# LOWER: fsub fast half %0, %1
# FINAL: fsub fast half %0, %1
# CHECK: define {{(swiftcc )?}}i8 @julia_foo_{{[0-9]+}}({{.*}}half %[[X:"x::Float16"]], half %[[Y:"y::Float16"]]) {{.*}}{
# CHECK: %[[DIFF:[0-9]+]] = fsub fast half %[[X]], %[[Y]]
# CHECK: %[[CMP:[0-9]+]] = fcmp fast oeq half %[[DIFF]], 0xH0000
# CHECK: %[[ZEXT:[0-9]+]] = zext i1 %[[CMP]] to i8
# CHECK: ret i8 %[[ZEXT]]
# CHECK: }
emit(foo, Float16, Float16)