diff --git a/mlir/utils/spirv/gen_spirv_dialect.py b/mlir/utils/spirv/gen_spirv_dialect.py index 70c3d9db168891..99ed3489b4cbda 100755 --- a/mlir/utils/spirv/gen_spirv_dialect.py +++ b/mlir/utils/spirv/gen_spirv_dialect.py @@ -127,44 +127,6 @@ def split_list_into_sublists(items): return chuncks -def uniquify_enum_cases(lst): - """Prunes duplicate enum cases from the list. - - Arguments: - - lst: List whose elements are to be uniqued. Assumes each element is a - (symbol, value) pair and elements already sorted according to value. - - Returns: - - A list with all duplicates removed. The elements are sorted according to - value and, for each value, uniqued according to symbol. - original list, - - A map from deduplicated cases to the uniqued case. - """ - cases = lst - uniqued_cases = [] - duplicated_cases = {} - - # First sort according to the value - cases.sort(key=lambda x: x[1]) - - # Then group them according to the value - for _, groups in itertools.groupby(cases, key=lambda x: x[1]): - # For each value, sort according to the enumerant symbol. - sorted_group = sorted(groups, key=lambda x: x[0]) - # Keep the "smallest" case, which is typically the symbol without extension - # suffix. But we have special cases that we want to fix. - case = sorted_group[0] - for i in range(1, len(sorted_group)): - duplicated_cases[sorted_group[i][0]] = case[0] - if case[0] == "HlslSemanticGOOGLE": - assert len(sorted_group) == 2, "unexpected new variant for HlslSemantic" - case = sorted_group[1] - duplicated_cases[sorted_group[0][0]] = case[0] - uniqued_cases.append(case) - - return uniqued_cases, duplicated_cases - - def toposort(dag, sort_fn): """Topologically sorts the given dag. @@ -197,14 +159,12 @@ def get_next_batch(dag): return sorted_nodes -def toposort_capabilities(all_cases, capability_mapping): +def toposort_capabilities(all_cases): """Returns topologically sorted capability (symbol, value) pairs. Arguments: - all_cases: all capability cases (containing symbol, value, and implied capabilities). - - capability_mapping: mapping from duplicated capability symbols to the - canonicalized symbol chosen for SPIRVBase.td. Returns: A list containing topologically sorted capability (symbol, value) pairs. @@ -215,13 +175,10 @@ def toposort_capabilities(all_cases, capability_mapping): # Get the current capability. cur = case["enumerant"] name_to_value[cur] = case["value"] - # Ignore duplicated symbols. - if cur in capability_mapping: - continue # Get capabilities implied by the current capability. prev = case.get("capabilities", []) - uniqued_prev = set([capability_mapping.get(c, c) for c in prev]) + uniqued_prev = set(prev) dag[cur] = uniqued_prev sorted_caps = toposort(dag, lambda x: name_to_value[x]) @@ -229,36 +186,12 @@ def toposort_capabilities(all_cases, capability_mapping): return [(c, name_to_value[c]) for c in sorted_caps] -def get_capability_mapping(operand_kinds): - """Returns the capability mapping from duplicated cases to canonicalized ones. - - Arguments: - - operand_kinds: all operand kinds' grammar spec - - Returns: - - A map mapping from duplicated capability symbols to the canonicalized - symbol chosen for SPIRVBase.td. - """ - # Find the operand kind for capability - cap_kind = {} - for kind in operand_kinds: - if kind["kind"] == "Capability": - cap_kind = kind - - kind_cases = [(case["enumerant"], case["value"]) for case in cap_kind["enumerants"]] - _, capability_mapping = uniquify_enum_cases(kind_cases) - - return capability_mapping - - -def get_availability_spec(enum_case, capability_mapping, for_op, for_cap): +def get_availability_spec(enum_case, for_op, for_cap): """Returns the availability specification string for the given enum case. Arguments: - enum_case: the enum case to generate availability spec for. It may contain 'version', 'lastVersion', 'extensions', or 'capabilities'. - - capability_mapping: mapping from duplicated capability symbols to the - canonicalized symbol chosen for SPIRVBase.td. - for_op: bool value indicating whether this is the availability spec for an op itself. - for_cap: bool value indicating whether this is the availability spec for @@ -313,10 +246,7 @@ def get_availability_spec(enum_case, capability_mapping, for_op, for_cap): if caps: canonicalized_caps = [] for c in caps: - if c in capability_mapping: - canonicalized_caps.append(capability_mapping[c]) - else: - canonicalized_caps.append(c) + canonicalized_caps.append(c) prefixed_caps = [ "SPIRV_C_{}".format(c) for c in sorted(set(canonicalized_caps)) ] @@ -357,7 +287,7 @@ def get_availability_spec(enum_case, capability_mapping, for_op, for_cap): return "{}{}{}".format(implies, "\n " if implies and avail else "", avail) -def gen_operand_kind_enum_attr(operand_kind, capability_mapping): +def gen_operand_kind_enum_attr(operand_kind): """Generates the TableGen EnumAttr definition for the given operand kind. Returns: @@ -388,13 +318,12 @@ def get_case_symbol(kind_name, case_name): # Special treatment for capability cases: we need to sort them topologically # because a capability can refer to another via the 'implies' field. kind_cases = toposort_capabilities( - operand_kind["enumerants"], capability_mapping + operand_kind["enumerants"] ) else: kind_cases = [ (case["enumerant"], case["value"]) for case in operand_kind["enumerants"] ] - kind_cases, _ = uniquify_enum_cases(kind_cases) max_len = max([len(symbol) for (symbol, _) in kind_cases]) # Generate the definition for each enum case @@ -412,7 +341,6 @@ def get_case_symbol(kind_name, case_name): value = int(case_pair[1]) avail = get_availability_spec( name_to_case_dict[name], - capability_mapping, False, kind_name == "Capability", ) @@ -648,11 +576,9 @@ def update_td_enum_attrs(path, operand_kinds, filter_list): ] filter_list.extend(existing_kinds) - capability_mapping = get_capability_mapping(operand_kinds) - # Generate definitions for all enums in filter list defs = [ - gen_operand_kind_enum_attr(kind, capability_mapping) + gen_operand_kind_enum_attr(kind) for kind in operand_kinds if kind["kind"] in filter_list ] @@ -762,7 +688,7 @@ def get_description(text, appendix): def get_op_definition( - instruction, opname, doc, existing_info, capability_mapping, settings + instruction, opname, doc, existing_info, settings ): """Generates the TableGen op definition for the given SPIR-V instruction. @@ -771,8 +697,6 @@ def get_op_definition( - doc: the instruction's SPIR-V HTML doc - existing_info: a dict containing potential manually specified sections for this instruction - - capability_mapping: mapping from duplicated capability symbols to the - canonicalized symbol chosen for SPIRVBase.td Returns: - A string containing the TableGen op definition @@ -840,7 +764,7 @@ def get_op_definition( operands = instruction.get("operands", []) # Op availability - avail = get_availability_spec(instruction, capability_mapping, True, False) + avail = get_availability_spec(instruction, True, False) if avail: avail = "\n\n {0}".format(avail) @@ -1021,7 +945,7 @@ def extract_td_op_info(op_def): def update_td_op_definitions( - path, instructions, docs, filter_list, inst_category, capability_mapping, settings + path, instructions, docs, filter_list, inst_category, settings ): """Updates SPIRVOps.td with newly generated op definition. @@ -1030,8 +954,6 @@ def update_td_op_definitions( - instructions: SPIR-V JSON grammar for all instructions - docs: SPIR-V HTML doc for all instructions - filter_list: a list containing new opnames to include - - capability_mapping: mapping from duplicated capability symbols to the - canonicalized symbol chosen for SPIRVBase.td. Returns: - A string containing all the TableGen op definitions @@ -1079,7 +1001,6 @@ def update_td_op_definitions( opname, docs[fixed_opname], op_info_dict.get(opname, {"inst_category": inst_category}), - capability_mapping, settings, ) ) @@ -1186,14 +1107,12 @@ def update_td_op_definitions( if args.new_inst is not None: assert args.op_td_path is not None docs = get_spirv_doc_from_html_spec(ext_html_url, args) - capability_mapping = get_capability_mapping(operand_kinds) update_td_op_definitions( args.op_td_path, instructions, docs, args.new_inst, args.inst_category, - capability_mapping, args, ) print("Done. Note that this script just generates a template; ", end="")