Skip to content

Commit

Permalink
Simplify rule inference algo
Browse files Browse the repository at this point in the history
  • Loading branch information
danieltrt committed Jul 5, 2023
1 parent c4b552f commit c3bacd2
Showing 1 changed file with 32 additions and 59 deletions.
91 changes: 32 additions & 59 deletions experimental/rule_inference/static_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,11 @@ def __attrs_post_init__(self):

def static_infer(self) -> RawRule:
if len(self.nodes_after) > 0 and len(self.nodes_before) > 0:
return self.create_replacement()
return self.create_rule(self.nodes_before, self.nodes_after)
elif len(self.nodes_after) > 0:
raise self.create_addition()
elif len(self.nodes_before) > 0:
return self.create_deletion()
return self.create_rule(self.nodes_before)

def find_nodes_to_change(self, node_before: Node, node_after: Node):
"""
Expand All @@ -162,48 +162,53 @@ def find_nodes_to_change(self, node_before: Node, node_after: Node):

return node_before, node_after

def create_replacement(self) -> RawRule:
"""
Create a rule based on the node before and after.
"""
# For replacements (---- +++++)
if len(self.nodes_before) == 1:
if len(self.nodes_after) == 1:
self.nodes_before[0], self.nodes_after[0] = self.find_nodes_to_change(
self.nodes_before[0], self.nodes_after[0]
def create_rule(
self, nodes_before: List[Node], nodes_after: List[Node] = None
) -> RawRule:
# If there is only one node
if len(nodes_before) == 1:
if len(nodes_after) == 1:
nodes_before[0], nodes_after[0] = self.find_nodes_to_change(
nodes_before[0], nodes_after[0]
)
qw = QueryWriter([self.nodes_before[0]])
qw.write()
lines_affected = " ".join(
[NodeUtils.convert_to_source(node) for node in self.nodes_after]
)
replacement_str = qw.replace_with_tags(lines_affected)
node = nodes_before[0]
qw = QueryWriter([node])
query = qw.write()

replacement_str = ""
if nodes_after: # If it's a replacement, not a deletion
lines_affected = " ".join(
[NodeUtils.convert_to_source(node) for node in nodes_after]
)
replacement_str = qw.replace_with_tags(lines_affected)

return RawRule(
name=self.name,
query=qw.query_str,
query=query,
replace_node=qw.outer_most_node[1:],
replace=replacement_str,
)

# If there are multiple nodes
else:
# find the smallest common ancestor of _nodes_before
ancestor = NodeUtils.find_lowest_common_ancestor(self.nodes_before)
ancestor = NodeUtils.find_lowest_common_ancestor(nodes_before)
replacement_str = NodeUtils.convert_to_source(
ancestor, exclude=self.nodes_before
ancestor, exclude=nodes_before
)

replacement_str = replacement_str.replace(
"{placeholder}", "", len(self.nodes_before) - 1
"{placeholder}", "", len(nodes_before) - 1
)

lines_affected = " ".join(
[NodeUtils.convert_to_source(node) for node in self.nodes_after]
)
# If it's a replacement, not a deletion
lines_affected = ""
if nodes_after:
lines_affected = " ".join(
[NodeUtils.convert_to_source(node) for node in nodes_after]
)

replacement_str = replacement_str.replace(
"{placeholder}", lines_affected, 1
)

qw = QueryWriter([ancestor])
qw.write()
replacement_str = qw.replace_with_tags(replacement_str)
Expand All @@ -215,37 +220,5 @@ def create_replacement(self) -> RawRule:
replace=replacement_str,
)

def create_deletion(self) -> RawRule:
if len(self.nodes_before) == 1:
node_before = self.nodes_before[0]
qw = QueryWriter([node_before])
query = qw.write()
return RawRule(
name=self.name,
query=query,
replace_node=qw.outer_most_node[1:],
replace="",
)
else:
# find the smallest common ancestor of _nodes_before
ancestor = NodeUtils.find_lowest_common_ancestor(self.nodes_before)
deletion_str = NodeUtils.convert_to_source(
ancestor, exclude=self.nodes_before
)

deletion_str = deletion_str.replace(
"{placeholder}", "", len(self.nodes_before)
)

qw = QueryWriter([ancestor])
qw.write()

return RawRule(
name=self.name,
query=qw.query_str,
replace_node=qw.outer_most_node[1:],
replace=deletion_str,
)

def create_addition(self) -> str:
raise NotImplementedError

0 comments on commit c3bacd2

Please sign in to comment.