Skip to content

Commit

Permalink
prefix more type-related stuff, move it to type.h
Browse files Browse the repository at this point in the history
  • Loading branch information
Hugobros3 committed Oct 9, 2024
1 parent b1f59cd commit 356af2d
Showing 47 changed files with 636 additions and 619 deletions.
46 changes: 30 additions & 16 deletions include/shady/ir/type.h
Original file line number Diff line number Diff line change
@@ -15,42 +15,56 @@ Type* nominal_type(Module*, Nodes annotations, String name);

const Type* shd_get_actual_mask_type(IrArena* arena);

String get_address_space_name(AddressSpace);
/// Returns false iff pointers in that address space can contain different data at the same address
/// (amongst threads in the same subgroup)
bool shd_is_addr_space_uniform(IrArena*, AddressSpace);

String shd_get_type_name(IrArena* arena, const Type* t);
bool shd_is_subtype(const Type* supertype, const Type* type);
void shd_check_subtype(const Type* supertype, const Type* type);

/// Is this a type that a value in the language can have ?
bool is_value_type(const Type*);
bool shd_is_value_type(const Type*);

/// Is this a valid data type (for usage in other types and as type arguments) ?
bool is_data_type(const Type*);
bool shd_is_data_type(const Type*);

bool shd_is_arithm_type(const Type*);
bool shd_is_shiftable_type(const Type*);
bool shd_has_boolean_ops(const Type*);
bool shd_is_comparable_type(const Type*);
bool shd_is_ordered_type(const Type*);
bool shd_is_physical_ptr_type(const Type* t);
bool shd_is_generic_ptr_type(const Type* t);

bool shd_is_reinterpret_cast_legal(const Type* src_type, const Type* dst_type);
bool shd_is_conversion_legal(const Type* src_type, const Type* dst_type);

/// Returns the (possibly qualified) pointee type from a (possibly qualified) ptr type
const Type* get_pointee_type(IrArena*, const Type*);
const Type* shd_get_pointee_type(IrArena*, const Type*);

String shd_get_address_space_name(AddressSpace);
/// Returns false iff pointers in that address space can contain different data at the same address
/// (amongst threads in the same subgroup)
bool shd_is_addr_space_uniform(IrArena*, AddressSpace);

String shd_get_type_name(IrArena* arena, const Type* t);

const Type* maybe_multiple_return(IrArena* arena, Nodes types);
Nodes unwrap_multiple_yield_types(IrArena* arena, const Type* type);

/// Collects the annotated types in the list of variables
/// NB: this is different from get_values_types, that function uses node.type, whereas this one uses node.payload.var.type
/// This means this function works in untyped modules where node.type is NULL.
Nodes get_param_types(IrArena* arena, Nodes variables);
Nodes shd_get_param_types(IrArena* arena, Nodes variables);

Nodes get_values_types(IrArena*, Nodes);
Nodes shd_get_values_types(IrArena*, Nodes);

// Qualified type helpers
/// Ensures an operand has divergence-annotated type and extracts it
const Type* get_unqualified_type(const Type*);
bool is_qualified_type_uniform(const Type*);
bool deconstruct_qualified_type(const Type**);
const Type* shd_get_unqualified_type(const Type*);
bool shd_is_qualified_type_uniform(const Type*);
bool shd_deconstruct_qualified_type(const Type**);

const Type* shd_as_qualified_type(const Type* type, bool uniform);

Nodes strip_qualifiers(IrArena*, Nodes);
Nodes add_qualifiers(IrArena*, Nodes, bool);
Nodes shd_strip_qualifiers(IrArena*, Nodes);
Nodes shd_add_qualifiers(IrArena*, Nodes, bool);

// Pack (vector) type helpers
const Type* get_packed_type_element(const Type*);
8 changes: 4 additions & 4 deletions src/backend/c/emit_c.c
Original file line number Diff line number Diff line change
@@ -189,8 +189,8 @@ void c_emit_global_variable_definition(Emitter* emitter, AddressSpace as, String
break;
}
default: {
prefix = shd_format_string_arena(emitter->arena->arena, "/* %s */", get_address_space_name(as));
shd_warn_print("warning: address space %s not supported in CUDA for global variables\n", get_address_space_name(as));
prefix = shd_format_string_arena(emitter->arena->arena, "/* %s */", shd_get_address_space_name(as));
shd_warn_print("warning: address space %s not supported in CUDA for global variables\n", shd_get_address_space_name(as));
break;
}
}
@@ -209,8 +209,8 @@ void c_emit_global_variable_definition(Emitter* emitter, AddressSpace as, String
break;
}
default: {
prefix = shd_format_string_arena(emitter->arena->arena, "/* %s */", get_address_space_name(as));
shd_warn_print("warning: address space %s not supported in GLSL for global variables\n", get_address_space_name(as));
prefix = shd_format_string_arena(emitter->arena->arena, "/* %s */", shd_get_address_space_name(as));
shd_warn_print("warning: address space %s not supported in GLSL for global variables\n", shd_get_address_space_name(as));
break;
}
}
2 changes: 1 addition & 1 deletion src/backend/c/emit_c_control_flow.c
Original file line number Diff line number Diff line change
@@ -148,7 +148,7 @@ static void emit_loop(Emitter* emitter, FnEmitter* fn, Printer* p, Loop loop) {
arr[i] = unique_name(emitter->arena, "phi");
}
Strings param_names = shd_strings(emitter->arena, variables.count, arr);
Strings eparams = emit_variable_declarations(emitter, fn, p, NULL, &param_names, get_param_types(emitter->arena, params), true, &loop.initial_args);
Strings eparams = emit_variable_declarations(emitter, fn, p, NULL, &param_names, shd_get_param_types(emitter->arena, params), true, &loop.initial_args);
for (size_t i = 0; i < params.count; i++)
register_emitted(&sub_emiter, fn, params.nodes[i], term_from_cvalue(eparams.strings[i]));

36 changes: 18 additions & 18 deletions src/backend/c/emit_c_value.c
Original file line number Diff line number Diff line change
@@ -246,7 +246,7 @@ CTerm c_bind_intermediary_result(Emitter* emitter, Printer* p, const Type* t, CT

static const Type* get_first_op_scalar_type(Nodes ops) {
const Type* t = shd_first(ops)->type;
deconstruct_qualified_type(&t);
shd_deconstruct_qualified_type(&t);
deconstruct_maybe_packed_type(&t);
return t;
}
@@ -495,7 +495,7 @@ static CTerm emit_primop(Emitter* emitter, FnEmitter* fn, Printer* p, const Node
const Node* offset = prim_op->operands.nodes[1];
CValue c_offset = to_cvalue(emitter, c_emit_value(emitter, fn, offset));
if (emitter->config.dialect == CDialect_GLSL) {
if (get_unqualified_type(offset->type)->payload.int_type.width == IntTy64)
if (shd_get_unqualified_type(offset->type)->payload.int_type.width == IntTy64)
c_offset = shd_format_string_arena(arena->arena, "int(%s)", c_offset);
}
term = term_from_cvalue(shd_format_string_arena(arena->arena, "(%s %s %s)", src, prim_op->op == lshift_op ? "<<" : ">>", c_offset));
@@ -527,7 +527,7 @@ static CTerm emit_primop(Emitter* emitter, FnEmitter* fn, Printer* p, const Node
}
case convert_op: {
CTerm src = c_emit_value(emitter, fn, shd_first(prim_op->operands));
const Type* src_type = get_unqualified_type(shd_first(prim_op->operands)->type);
const Type* src_type = shd_get_unqualified_type(shd_first(prim_op->operands)->type);
const Type* dst_type = shd_first(prim_op->type_arguments);
if (emitter->config.dialect == CDialect_GLSL) {
if (is_glsl_scalar_type(src_type) && is_glsl_scalar_type(dst_type)) {
@@ -543,7 +543,7 @@ static CTerm emit_primop(Emitter* emitter, FnEmitter* fn, Printer* p, const Node
}
case reinterpret_op: {
CTerm src_value = c_emit_value(emitter, fn, shd_first(prim_op->operands));
const Type* src_type = get_unqualified_type(shd_first(prim_op->operands)->type);
const Type* src_type = shd_get_unqualified_type(shd_first(prim_op->operands)->type);
const Type* dst_type = shd_first(prim_op->type_arguments);
switch (emitter->config.dialect) {
case CDialect_CUDA:
@@ -625,7 +625,7 @@ static CTerm emit_primop(Emitter* emitter, FnEmitter* fn, Printer* p, const Node
term = term_from_cvalue(dst);
}

const Type* t = get_unqualified_type(shd_first(prim_op->operands)->type);
const Type* t = shd_get_unqualified_type(shd_first(prim_op->operands)->type);
for (size_t i = (insert ? 2 : 1); i < prim_op->operands.count; i++) {
const Node* index = prim_op->operands.nodes[i];
const IntLiteral* static_index = shd_resolve_to_int_literal(index);
@@ -678,8 +678,8 @@ static CTerm emit_primop(Emitter* emitter, FnEmitter* fn, Printer* p, const Node
String rhs_e = to_cvalue(emitter, c_emit_value(emitter, fn, prim_op->operands.nodes[1]));
const Type* lhs_t = lhs->type;
const Type* rhs_t = rhs->type;
bool lhs_u = deconstruct_qualified_type(&lhs_t);
bool rhs_u = deconstruct_qualified_type(&rhs_t);
bool lhs_u = shd_deconstruct_qualified_type(&lhs_t);
bool rhs_u = shd_deconstruct_qualified_type(&rhs_t);
size_t left_size = lhs_t->payload.pack_type.width;
// size_t total_size = lhs_t->payload.pack_type.width + rhs_t->payload.pack_type.width;
String suffixes = "xyzw";
@@ -793,13 +793,13 @@ static CTerm emit_ptr_composite_element(Emitter* emitter, FnEmitter* fn, Printer
CTerm acc = c_emit_value(emitter, fn, lea.ptr);

const Type* src_qtype = lea.ptr->type;
bool uniform = is_qualified_type_uniform(src_qtype);
const Type* curr_ptr_type = get_unqualified_type(src_qtype);
bool uniform = shd_is_qualified_type_uniform(src_qtype);
const Type* curr_ptr_type = shd_get_unqualified_type(src_qtype);
assert(curr_ptr_type->tag == PtrType_TAG);

const Type* pointee_type = get_pointee_type(arena, curr_ptr_type);
const Type* pointee_type = shd_get_pointee_type(arena, curr_ptr_type);
const Node* selector = lea.index;
uniform &= is_qualified_type_uniform(selector->type);
uniform &= shd_is_qualified_type_uniform(selector->type);
switch (is_type(pointee_type)) {
case ArrType_TAG: {
CTerm index = c_emit_value(emitter, fn, selector);
@@ -859,8 +859,8 @@ static CTerm emit_ptr_array_element_offset(Emitter* emitter, FnEmitter* fn, Prin
CTerm acc = c_emit_value(emitter, fn, lea.ptr);

const Type* src_qtype = lea.ptr->type;
bool uniform = is_qualified_type_uniform(src_qtype);
const Type* curr_ptr_type = get_unqualified_type(src_qtype);
bool uniform = shd_is_qualified_type_uniform(src_qtype);
const Type* curr_ptr_type = shd_get_unqualified_type(src_qtype);
assert(curr_ptr_type->tag == PtrType_TAG);

const IntLiteral* offset_static_value = shd_resolve_to_int_literal(lea.offset);
@@ -869,9 +869,9 @@ static CTerm emit_ptr_array_element_offset(Emitter* emitter, FnEmitter* fn, Prin
// we sadly need to drop to the value level (aka explicit pointer arithmetic) to do this
// this means such code is never going to be legal in GLSL
// also the cast is to account for our arrays-in-structs hack
const Type* pointee_type = get_pointee_type(arena, curr_ptr_type);
const Type* pointee_type = shd_get_pointee_type(arena, curr_ptr_type);
acc = term_from_cvalue(shd_format_string_arena(arena->arena, "((%s) &(%s)[%s])", c_emit_type(emitter, curr_ptr_type, NULL), to_cvalue(emitter, acc), to_cvalue(emitter, offset)));
uniform &= is_qualified_type_uniform(lea.offset->type);
uniform &= shd_is_qualified_type_uniform(lea.offset->type);
}

if (emitter->config.dialect == CDialect_ISPC)
@@ -893,7 +893,7 @@ static CTerm emit_alloca(Emitter* emitter, Printer* p, const Type* instr) {
CTerm variable = (CTerm) { .value = NULL, .var = variable_name };
c_emit_variable_declaration(emitter, p, get_allocated_type(instr), variable_name, true, NULL);
if (emitter->config.dialect == CDialect_ISPC) {
variable = ispc_varying_ptr_helper(emitter, p, get_unqualified_type(instr->type), variable);
variable = ispc_varying_ptr_helper(emitter, p, shd_get_unqualified_type(instr->type), variable);
}
return variable;
}
@@ -927,8 +927,8 @@ static CTerm emit_instruction(Emitter* emitter, FnEmitter* fn, Printer* p, const
Store payload = instruction->payload.store;
c_emit_mem(emitter, fn, payload.mem);
const Type* addr_type = payload.ptr->type;
bool addr_uniform = deconstruct_qualified_type(&addr_type);
bool value_uniform = is_qualified_type_uniform(payload.value->type);
bool addr_uniform = shd_deconstruct_qualified_type(&addr_type);
bool value_uniform = shd_is_qualified_type_uniform(payload.value->type);
assert(addr_type->tag == PtrType_TAG);
CAddr dereferenced = deref_term(emitter, c_emit_value(emitter, fn, payload.ptr));
CValue cvalue = to_cvalue(emitter, c_emit_value(emitter, fn, payload.value));
2 changes: 1 addition & 1 deletion src/backend/spirv/emit_spv.c
Original file line number Diff line number Diff line change
@@ -94,7 +94,7 @@ static void emit_function(Emitter* emitter, const Node* node) {
const Type* param_type = param->payload.param.type;
SpvId param_id = spvb_parameter(fn_builder.base, spv_emit_type(emitter, param_type));
spv_register_emitted(emitter, false, param, param_id);
deconstruct_qualified_type(&param_type);
shd_deconstruct_qualified_type(&param_type);
if (param_type->tag == PtrType_TAG && param_type->payload.ptr_type.address_space == AsGlobal) {
spvb_decorate(emitter->file_builder, param_id, SpvDecorationAliased, 0, NULL);
}
6 changes: 3 additions & 3 deletions src/backend/spirv/emit_spv_control_flow.c
Original file line number Diff line number Diff line change
@@ -52,13 +52,13 @@ static void emit_match(Emitter* emitter, FnBuilder* fn_builder, BBBuilder bb_bui
spv_emit_mem(emitter, fn_builder, match.mem);
SpvId join_bb_id = spv_find_emitted(emitter, fn_builder, match.tail);

assert(get_unqualified_type(match.inspect->type)->tag == Int_TAG);
assert(shd_get_unqualified_type(match.inspect->type)->tag == Int_TAG);
SpvId inspectee = spv_emit_value(emitter, fn_builder, match.inspect);

SpvId default_id = spv_find_emitted(emitter, fn_builder, match.default_case);

const Type* inspectee_t = match.inspect->type;
deconstruct_qualified_type(&inspectee_t);
shd_deconstruct_qualified_type(&inspectee_t);
assert(inspectee_t->tag == Int_TAG);
size_t literal_width = inspectee_t->payload.int_type.width == IntTy64 ? 2 : 1;
size_t literal_case_entry_size = literal_width + 1;
@@ -93,7 +93,7 @@ static void emit_loop(Emitter* emitter, FnBuilder* fn_builder, BBBuilder bb_buil
Nodes body_params = get_abstraction_params(loop_instr.body);
LARRAY(SpvbPhi*, loop_continue_phis, body_params.count);
for (size_t i = 0; i < body_params.count; i++) {
SpvId loop_param_type = spv_emit_type(emitter, get_unqualified_type(body_params.nodes[i]->type));
SpvId loop_param_type = spv_emit_type(emitter, shd_get_unqualified_type(body_params.nodes[i]->type));

SpvId continue_phi_id = spvb_fresh_id(emitter->file_builder);
SpvbPhi* continue_phi = spvb_add_phi(continue_builder, loop_param_type, continue_phi_id);
4 changes: 2 additions & 2 deletions src/backend/spirv/emit_spv_type.c
Original file line number Diff line number Diff line change
@@ -32,7 +32,7 @@ SpvStorageClass spv_emit_addr_space(Emitter* emitter, AddressSpace address_space
case AsUniformConstant: return SpvStorageClassUniformConstant;

default: {
shd_error_print("Cannot emit address space %s.\n", get_address_space_name(address_space));
shd_error_print("Cannot emit address space %s.\n", shd_get_address_space_name(address_space));
shd_error_die();
SHADY_UNREACHABLE;
}
@@ -215,7 +215,7 @@ SpvId spv_emit_type(Emitter* emitter, const Type* type) {
case Type_JoinPointType_TAG: shd_error("These must be lowered beforehand")
}

if (is_data_type(type)) {
if (shd_is_data_type(type)) {
if (type->tag == PtrType_TAG && type->payload.ptr_type.address_space == AsGlobal) {
//TypeMemLayout elem_mem_layout = get_mem_layout(emitter->arena, type->payload.ptr_type.pointed_type);
//spvb_decorate(emitter->file_builder, new, SpvDecorationArrayStride, 1, (uint32_t[]) {elem_mem_layout.size_in_bytes});
Loading

0 comments on commit 356af2d

Please sign in to comment.