Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][spirv] Remove code for de-duplicating symbols in SPIR-V grammar #111778

Merged
merged 1 commit into from
Oct 30, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 10 additions & 91 deletions mlir/utils/spirv/gen_spirv_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,44 +127,6 @@ def split_list_into_sublists(items):
return chuncks


def uniquify_enum_cases(lst):
"""Prunes duplicate enum cases from the list.

Arguments:
- lst: List whose elements are to be uniqued. Assumes each element is a
(symbol, value) pair and elements already sorted according to value.

Returns:
- A list with all duplicates removed. The elements are sorted according to
value and, for each value, uniqued according to symbol.
original list,
- A map from deduplicated cases to the uniqued case.
"""
cases = lst
uniqued_cases = []
duplicated_cases = {}

# First sort according to the value
cases.sort(key=lambda x: x[1])

# Then group them according to the value
for _, groups in itertools.groupby(cases, key=lambda x: x[1]):
# For each value, sort according to the enumerant symbol.
sorted_group = sorted(groups, key=lambda x: x[0])
# Keep the "smallest" case, which is typically the symbol without extension
# suffix. But we have special cases that we want to fix.
case = sorted_group[0]
for i in range(1, len(sorted_group)):
duplicated_cases[sorted_group[i][0]] = case[0]
if case[0] == "HlslSemanticGOOGLE":
assert len(sorted_group) == 2, "unexpected new variant for HlslSemantic"
case = sorted_group[1]
duplicated_cases[sorted_group[0][0]] = case[0]
uniqued_cases.append(case)

return uniqued_cases, duplicated_cases


def toposort(dag, sort_fn):
"""Topologically sorts the given dag.

Expand Down Expand Up @@ -197,14 +159,12 @@ def get_next_batch(dag):
return sorted_nodes


def toposort_capabilities(all_cases, capability_mapping):
def toposort_capabilities(all_cases):
"""Returns topologically sorted capability (symbol, value) pairs.

Arguments:
- all_cases: all capability cases (containing symbol, value, and implied
capabilities).
- capability_mapping: mapping from duplicated capability symbols to the
canonicalized symbol chosen for SPIRVBase.td.

Returns:
A list containing topologically sorted capability (symbol, value) pairs.
Expand All @@ -215,50 +175,23 @@ def toposort_capabilities(all_cases, capability_mapping):
# Get the current capability.
cur = case["enumerant"]
name_to_value[cur] = case["value"]
# Ignore duplicated symbols.
if cur in capability_mapping:
continue

# Get capabilities implied by the current capability.
prev = case.get("capabilities", [])
uniqued_prev = set([capability_mapping.get(c, c) for c in prev])
uniqued_prev = set(prev)
dag[cur] = uniqued_prev

sorted_caps = toposort(dag, lambda x: name_to_value[x])
# Attach the capability's value as the second component of the pair.
return [(c, name_to_value[c]) for c in sorted_caps]


def get_capability_mapping(operand_kinds):
"""Returns the capability mapping from duplicated cases to canonicalized ones.

Arguments:
- operand_kinds: all operand kinds' grammar spec

Returns:
- A map mapping from duplicated capability symbols to the canonicalized
symbol chosen for SPIRVBase.td.
"""
# Find the operand kind for capability
cap_kind = {}
for kind in operand_kinds:
if kind["kind"] == "Capability":
cap_kind = kind

kind_cases = [(case["enumerant"], case["value"]) for case in cap_kind["enumerants"]]
_, capability_mapping = uniquify_enum_cases(kind_cases)

return capability_mapping


def get_availability_spec(enum_case, capability_mapping, for_op, for_cap):
def get_availability_spec(enum_case, for_op, for_cap):
"""Returns the availability specification string for the given enum case.

Arguments:
- enum_case: the enum case to generate availability spec for. It may contain
'version', 'lastVersion', 'extensions', or 'capabilities'.
- capability_mapping: mapping from duplicated capability symbols to the
canonicalized symbol chosen for SPIRVBase.td.
- for_op: bool value indicating whether this is the availability spec for an
op itself.
- for_cap: bool value indicating whether this is the availability spec for
Expand Down Expand Up @@ -313,10 +246,7 @@ def get_availability_spec(enum_case, capability_mapping, for_op, for_cap):
if caps:
canonicalized_caps = []
for c in caps:
if c in capability_mapping:
canonicalized_caps.append(capability_mapping[c])
else:
canonicalized_caps.append(c)
canonicalized_caps.append(c)
prefixed_caps = [
"SPIRV_C_{}".format(c) for c in sorted(set(canonicalized_caps))
]
Expand Down Expand Up @@ -357,7 +287,7 @@ def get_availability_spec(enum_case, capability_mapping, for_op, for_cap):
return "{}{}{}".format(implies, "\n " if implies and avail else "", avail)


def gen_operand_kind_enum_attr(operand_kind, capability_mapping):
def gen_operand_kind_enum_attr(operand_kind):
"""Generates the TableGen EnumAttr definition for the given operand kind.

Returns:
Expand Down Expand Up @@ -388,13 +318,12 @@ def get_case_symbol(kind_name, case_name):
# Special treatment for capability cases: we need to sort them topologically
# because a capability can refer to another via the 'implies' field.
kind_cases = toposort_capabilities(
operand_kind["enumerants"], capability_mapping
operand_kind["enumerants"]
)
else:
kind_cases = [
(case["enumerant"], case["value"]) for case in operand_kind["enumerants"]
]
kind_cases, _ = uniquify_enum_cases(kind_cases)
max_len = max([len(symbol) for (symbol, _) in kind_cases])

# Generate the definition for each enum case
Expand All @@ -412,7 +341,6 @@ def get_case_symbol(kind_name, case_name):
value = int(case_pair[1])
avail = get_availability_spec(
name_to_case_dict[name],
capability_mapping,
False,
kind_name == "Capability",
)
Expand Down Expand Up @@ -648,11 +576,9 @@ def update_td_enum_attrs(path, operand_kinds, filter_list):
]
filter_list.extend(existing_kinds)

capability_mapping = get_capability_mapping(operand_kinds)

# Generate definitions for all enums in filter list
defs = [
gen_operand_kind_enum_attr(kind, capability_mapping)
gen_operand_kind_enum_attr(kind)
for kind in operand_kinds
if kind["kind"] in filter_list
]
Expand Down Expand Up @@ -762,7 +688,7 @@ def get_description(text, appendix):


def get_op_definition(
instruction, opname, doc, existing_info, capability_mapping, settings
instruction, opname, doc, existing_info, settings
):
"""Generates the TableGen op definition for the given SPIR-V instruction.

Expand All @@ -771,8 +697,6 @@ def get_op_definition(
- doc: the instruction's SPIR-V HTML doc
- existing_info: a dict containing potential manually specified sections for
this instruction
- capability_mapping: mapping from duplicated capability symbols to the
canonicalized symbol chosen for SPIRVBase.td

Returns:
- A string containing the TableGen op definition
Expand Down Expand Up @@ -840,7 +764,7 @@ def get_op_definition(
operands = instruction.get("operands", [])

# Op availability
avail = get_availability_spec(instruction, capability_mapping, True, False)
avail = get_availability_spec(instruction, True, False)
if avail:
avail = "\n\n {0}".format(avail)

Expand Down Expand Up @@ -1019,7 +943,7 @@ def extract_td_op_info(op_def):


def update_td_op_definitions(
path, instructions, docs, filter_list, inst_category, capability_mapping, settings
path, instructions, docs, filter_list, inst_category, settings
):
"""Updates SPIRVOps.td with newly generated op definition.

Expand All @@ -1028,8 +952,6 @@ def update_td_op_definitions(
- instructions: SPIR-V JSON grammar for all instructions
- docs: SPIR-V HTML doc for all instructions
- filter_list: a list containing new opnames to include
- capability_mapping: mapping from duplicated capability symbols to the
canonicalized symbol chosen for SPIRVBase.td.

Returns:
- A string containing all the TableGen op definitions
Expand Down Expand Up @@ -1077,7 +999,6 @@ def update_td_op_definitions(
opname,
docs[fixed_opname],
op_info_dict.get(opname, {"inst_category": inst_category}),
capability_mapping,
settings,
)
)
Expand Down Expand Up @@ -1184,14 +1105,12 @@ def update_td_op_definitions(
if args.new_inst is not None:
assert args.op_td_path is not None
docs = get_spirv_doc_from_html_spec(ext_html_url, args)
capability_mapping = get_capability_mapping(operand_kinds)
update_td_op_definitions(
args.op_td_path,
instructions,
docs,
args.new_inst,
args.inst_category,
capability_mapping,
args,
)
print("Done. Note that this script just generates a template; ", end="")
Expand Down
Loading