Skip to content

Commit

Permalink
[mlir][spirv] Remove code for de-duplicating symbols in SPIR-V grammar (
Browse files Browse the repository at this point in the history
llvm#111778)

SPIR-V grammar was updated in upstream to have an "aliases" field
instead of duplicating symbols with same values. See
KhronosGroup/SPIRV-Headers#447 for details.
  • Loading branch information
cmarcelo authored Oct 30, 2024
1 parent 67c4857 commit 6e75eec
Showing 1 changed file with 10 additions and 91 deletions.
101 changes: 10 additions & 91 deletions mlir/utils/spirv/gen_spirv_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,44 +127,6 @@ def split_list_into_sublists(items):
return chuncks


def uniquify_enum_cases(lst):
"""Prunes duplicate enum cases from the list.
Arguments:
- lst: List whose elements are to be uniqued. Assumes each element is a
(symbol, value) pair and elements already sorted according to value.
Returns:
- A list with all duplicates removed. The elements are sorted according to
value and, for each value, uniqued according to symbol.
original list,
- A map from deduplicated cases to the uniqued case.
"""
cases = lst
uniqued_cases = []
duplicated_cases = {}

# First sort according to the value
cases.sort(key=lambda x: x[1])

# Then group them according to the value
for _, groups in itertools.groupby(cases, key=lambda x: x[1]):
# For each value, sort according to the enumerant symbol.
sorted_group = sorted(groups, key=lambda x: x[0])
# Keep the "smallest" case, which is typically the symbol without extension
# suffix. But we have special cases that we want to fix.
case = sorted_group[0]
for i in range(1, len(sorted_group)):
duplicated_cases[sorted_group[i][0]] = case[0]
if case[0] == "HlslSemanticGOOGLE":
assert len(sorted_group) == 2, "unexpected new variant for HlslSemantic"
case = sorted_group[1]
duplicated_cases[sorted_group[0][0]] = case[0]
uniqued_cases.append(case)

return uniqued_cases, duplicated_cases


def toposort(dag, sort_fn):
"""Topologically sorts the given dag.
Expand Down Expand Up @@ -197,14 +159,12 @@ def get_next_batch(dag):
return sorted_nodes


def toposort_capabilities(all_cases, capability_mapping):
def toposort_capabilities(all_cases):
"""Returns topologically sorted capability (symbol, value) pairs.
Arguments:
- all_cases: all capability cases (containing symbol, value, and implied
capabilities).
- capability_mapping: mapping from duplicated capability symbols to the
canonicalized symbol chosen for SPIRVBase.td.
Returns:
A list containing topologically sorted capability (symbol, value) pairs.
Expand All @@ -215,50 +175,23 @@ def toposort_capabilities(all_cases, capability_mapping):
# Get the current capability.
cur = case["enumerant"]
name_to_value[cur] = case["value"]
# Ignore duplicated symbols.
if cur in capability_mapping:
continue

# Get capabilities implied by the current capability.
prev = case.get("capabilities", [])
uniqued_prev = set([capability_mapping.get(c, c) for c in prev])
uniqued_prev = set(prev)
dag[cur] = uniqued_prev

sorted_caps = toposort(dag, lambda x: name_to_value[x])
# Attach the capability's value as the second component of the pair.
return [(c, name_to_value[c]) for c in sorted_caps]


def get_capability_mapping(operand_kinds):
"""Returns the capability mapping from duplicated cases to canonicalized ones.
Arguments:
- operand_kinds: all operand kinds' grammar spec
Returns:
- A map mapping from duplicated capability symbols to the canonicalized
symbol chosen for SPIRVBase.td.
"""
# Find the operand kind for capability
cap_kind = {}
for kind in operand_kinds:
if kind["kind"] == "Capability":
cap_kind = kind

kind_cases = [(case["enumerant"], case["value"]) for case in cap_kind["enumerants"]]
_, capability_mapping = uniquify_enum_cases(kind_cases)

return capability_mapping


def get_availability_spec(enum_case, capability_mapping, for_op, for_cap):
def get_availability_spec(enum_case, for_op, for_cap):
"""Returns the availability specification string for the given enum case.
Arguments:
- enum_case: the enum case to generate availability spec for. It may contain
'version', 'lastVersion', 'extensions', or 'capabilities'.
- capability_mapping: mapping from duplicated capability symbols to the
canonicalized symbol chosen for SPIRVBase.td.
- for_op: bool value indicating whether this is the availability spec for an
op itself.
- for_cap: bool value indicating whether this is the availability spec for
Expand Down Expand Up @@ -313,10 +246,7 @@ def get_availability_spec(enum_case, capability_mapping, for_op, for_cap):
if caps:
canonicalized_caps = []
for c in caps:
if c in capability_mapping:
canonicalized_caps.append(capability_mapping[c])
else:
canonicalized_caps.append(c)
canonicalized_caps.append(c)
prefixed_caps = [
"SPIRV_C_{}".format(c) for c in sorted(set(canonicalized_caps))
]
Expand Down Expand Up @@ -357,7 +287,7 @@ def get_availability_spec(enum_case, capability_mapping, for_op, for_cap):
return "{}{}{}".format(implies, "\n " if implies and avail else "", avail)


def gen_operand_kind_enum_attr(operand_kind, capability_mapping):
def gen_operand_kind_enum_attr(operand_kind):
"""Generates the TableGen EnumAttr definition for the given operand kind.
Returns:
Expand Down Expand Up @@ -388,13 +318,12 @@ def get_case_symbol(kind_name, case_name):
# Special treatment for capability cases: we need to sort them topologically
# because a capability can refer to another via the 'implies' field.
kind_cases = toposort_capabilities(
operand_kind["enumerants"], capability_mapping
operand_kind["enumerants"]
)
else:
kind_cases = [
(case["enumerant"], case["value"]) for case in operand_kind["enumerants"]
]
kind_cases, _ = uniquify_enum_cases(kind_cases)
max_len = max([len(symbol) for (symbol, _) in kind_cases])

# Generate the definition for each enum case
Expand All @@ -412,7 +341,6 @@ def get_case_symbol(kind_name, case_name):
value = int(case_pair[1])
avail = get_availability_spec(
name_to_case_dict[name],
capability_mapping,
False,
kind_name == "Capability",
)
Expand Down Expand Up @@ -648,11 +576,9 @@ def update_td_enum_attrs(path, operand_kinds, filter_list):
]
filter_list.extend(existing_kinds)

capability_mapping = get_capability_mapping(operand_kinds)

# Generate definitions for all enums in filter list
defs = [
gen_operand_kind_enum_attr(kind, capability_mapping)
gen_operand_kind_enum_attr(kind)
for kind in operand_kinds
if kind["kind"] in filter_list
]
Expand Down Expand Up @@ -762,7 +688,7 @@ def get_description(text, appendix):


def get_op_definition(
instruction, opname, doc, existing_info, capability_mapping, settings
instruction, opname, doc, existing_info, settings
):
"""Generates the TableGen op definition for the given SPIR-V instruction.
Expand All @@ -771,8 +697,6 @@ def get_op_definition(
- doc: the instruction's SPIR-V HTML doc
- existing_info: a dict containing potential manually specified sections for
this instruction
- capability_mapping: mapping from duplicated capability symbols to the
canonicalized symbol chosen for SPIRVBase.td
Returns:
- A string containing the TableGen op definition
Expand Down Expand Up @@ -840,7 +764,7 @@ def get_op_definition(
operands = instruction.get("operands", [])

# Op availability
avail = get_availability_spec(instruction, capability_mapping, True, False)
avail = get_availability_spec(instruction, True, False)
if avail:
avail = "\n\n {0}".format(avail)

Expand Down Expand Up @@ -1021,7 +945,7 @@ def extract_td_op_info(op_def):


def update_td_op_definitions(
path, instructions, docs, filter_list, inst_category, capability_mapping, settings
path, instructions, docs, filter_list, inst_category, settings
):
"""Updates SPIRVOps.td with newly generated op definition.
Expand All @@ -1030,8 +954,6 @@ def update_td_op_definitions(
- instructions: SPIR-V JSON grammar for all instructions
- docs: SPIR-V HTML doc for all instructions
- filter_list: a list containing new opnames to include
- capability_mapping: mapping from duplicated capability symbols to the
canonicalized symbol chosen for SPIRVBase.td.
Returns:
- A string containing all the TableGen op definitions
Expand Down Expand Up @@ -1079,7 +1001,6 @@ def update_td_op_definitions(
opname,
docs[fixed_opname],
op_info_dict.get(opname, {"inst_category": inst_category}),
capability_mapping,
settings,
)
)
Expand Down Expand Up @@ -1186,14 +1107,12 @@ def update_td_op_definitions(
if args.new_inst is not None:
assert args.op_td_path is not None
docs = get_spirv_doc_from_html_spec(ext_html_url, args)
capability_mapping = get_capability_mapping(operand_kinds)
update_td_op_definitions(
args.op_td_path,
instructions,
docs,
args.new_inst,
args.inst_category,
capability_mapping,
args,
)
print("Done. Note that this script just generates a template; ", end="")
Expand Down

0 comments on commit 6e75eec

Please sign in to comment.