Skip to content

Commit

Permalink
lift_indirect_targets: compute the CFG where the liftee is the entry …
Browse files Browse the repository at this point in the history
…and use that to compute free variables
  • Loading branch information
Hugobros3 committed Apr 27, 2024
1 parent b5552d1 commit eb37f77
Showing 1 changed file with 19 additions and 17 deletions.
36 changes: 19 additions & 17 deletions src/shady/passes/lift_indirect_targets.c
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ typedef struct Context_ {
Rewriter rewriter;
CFG* cfg;
const UsesMap* uses;
struct Dict* live_vars;

struct Dict* lifted;
bool disable_lowering;
Expand Down Expand Up @@ -83,27 +82,32 @@ static void add_to_recover_context(struct List* recover_context, struct Dict* se
}
}

static LiftedCont* lambda_lift(Context* ctx, const Node* cont, String given_name) {
assert(is_basic_block(cont) || is_case(cont));
LiftedCont** found = find_value_dict(const Node*, LiftedCont*, ctx->lifted, cont);
static LiftedCont* lambda_lift(Context* ctx, const Node* liftee, String given_name) {
assert(is_basic_block(liftee) || is_case(liftee));
LiftedCont** found = find_value_dict(const Node*, LiftedCont*, ctx->lifted, liftee);
if (found)
return *found;

IrArena* a = ctx->rewriter.dst_arena;
Nodes oparams = get_abstraction_params(cont);
const Node* obody = get_abstraction_body(cont);
Nodes oparams = get_abstraction_params(liftee);
const Node* obody = get_abstraction_body(liftee);

String name = is_basic_block(cont) ? format_string_arena(a->arena, "%s_%s", get_abstraction_name(cont->payload.basic_block.fn), get_abstraction_name(cont)) : unique_name(a, given_name);
String name = is_basic_block(liftee) ? format_string_arena(a->arena, "%s_%s", get_abstraction_name(liftee->payload.basic_block.fn), get_abstraction_name(liftee)) : unique_name(a, given_name);

// Compute the live stuff we'll need
CFNode* cf_node = cfg_lookup(ctx->cfg, cont);
CFNodeVariables* node_vars = *find_value_dict(CFNode*, CFNodeVariables*, ctx->live_vars, cf_node);
CFG* cfg_rooted_in_liftee = build_cfg(ctx->cfg->entry->node, liftee, NULL, false);
CFNode* cf_node = cfg_lookup(cfg_rooted_in_liftee, liftee);
struct Dict* live_vars = compute_cfg_variables_map(cfg_rooted_in_liftee);
CFNodeVariables* node_vars = *find_value_dict(CFNode*, CFNodeVariables*, live_vars, cf_node);
struct List* recover_context = new_list(const Node*);

add_to_recover_context(recover_context, node_vars->bound_set, cont);
add_to_recover_context(recover_context, node_vars->free_set, liftee);
size_t recover_context_size = entries_count_list(recover_context);

debugv_print("lambda_lift: free (to-be-spilled) variables at '%s' (count=%d): ", name, recover_context_size);
destroy_cfg_variables_map(live_vars);
destroy_cfg(cfg_rooted_in_liftee);

debugv_print("lambda_lift: free (to-be-spilled) variables at '%s' (count=%d): ", get_abstraction_name_safe(liftee), recover_context_size);
for (size_t i = 0; i < recover_context_size; i++) {
const Node* item = read_list(const Node*, recover_context)[i];
debugv_print("%s %%%d", get_value_name(item) ? get_value_name(item) : "", item->id);
Expand All @@ -116,9 +120,9 @@ static LiftedCont* lambda_lift(Context* ctx, const Node* cont, String given_name
Nodes new_params = recreate_variables(&ctx->rewriter, oparams);

LiftedCont* lifted_cont = calloc(sizeof(LiftedCont), 1);
lifted_cont->old_cont = cont;
lifted_cont->old_cont = liftee;
lifted_cont->save_values = recover_context;
insert_dict(const Node*, LiftedCont*, ctx->lifted, cont, lifted_cont);
insert_dict(const Node*, LiftedCont*, ctx->lifted, liftee, lifted_cont);

Context lifting_ctx = *ctx;
lifting_ctx.rewriter = create_children_rewriter(&ctx->rewriter);
Expand Down Expand Up @@ -175,15 +179,13 @@ static const Node* process_node(Context* ctx, const Node* node) {
Context fn_ctx = *ctx;
fn_ctx.cfg = build_fn_cfg(node);
fn_ctx.uses = create_uses_map(node, (NcDeclaration | NcType));
fn_ctx.live_vars = compute_cfg_variables_map(fn_ctx.cfg);
fn_ctx.disable_lowering = lookup_annotation(node, "Internal");
ctx = &fn_ctx;

Node* new = recreate_decl_header_identity(&ctx->rewriter, node);
recreate_decl_body_identity(&ctx->rewriter, node, new);

destroy_uses_map(ctx->uses);
destroy_cfg_variables_map(ctx->live_vars);
destroy_cfg(ctx->cfg);
return new;
}
Expand All @@ -205,7 +207,7 @@ static const Node* process_node(Context* ctx, const Node* node) {

const Node* otail = get_let_tail(node);
BodyBuilder* bb = begin_body(a);
LiftedCont* lifted_tail = lambda_lift(ctx, otail, unique_name(a, format_string_arena(a->arena, "post_control_%s", get_abstraction_name(ctx->cfg->entry->node))));
LiftedCont* lifted_tail = lambda_lift(ctx, otail, unique_name(a, format_string_arena(a->arena, "lifted %s", get_abstraction_name_safe(otail))));
const Node* sp = add_spill_instrs(ctx, bb, lifted_tail->save_values);
const Node* tail_ptr = fn_addr_helper(a, lifted_tail->lifted_fn);

Expand Down Expand Up @@ -253,7 +255,7 @@ Module* lift_indirect_targets(const CompilerConfig* config, Module* src) {
}
destroy_dict(ctx.lifted);
destroy_rewriter(&ctx.rewriter);
log_module(DEBUGVV, config, dst);
// log_module(DEBUGVV, config, dst);
verify_module(config, dst);
src = dst;
if (oa)
Expand Down

0 comments on commit eb37f77

Please sign in to comment.