-
Notifications
You must be signed in to change notification settings - Fork 29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature/code eval metrics #29
Changes from 2 commits
54de5d3
665be35
1eb8ba1
0a44fdd
24518d0
0ab1174
515d087
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,325 @@ | ||
import ast | ||
import warnings | ||
from _ast import * | ||
|
||
from munkres import Munkres | ||
from thefuzz import fuzz | ||
|
||
from continuous_eval.metrics.base import Metric | ||
|
||
|
||
class CodeStringMatch(Metric): | ||
def calculate(self, answer, ground_truths, **kwargs): | ||
max_exact_match = 0 | ||
max_similarity_score = 0 | ||
for gt in ground_truths: | ||
exact_match = float(answer == gt) | ||
similarity_score = fuzz.ratio(answer, gt) / 100 | ||
if exact_match > max_exact_match: | ||
max_exact_match = exact_match | ||
if similarity_score > max_similarity_score: | ||
max_similarity_score = similarity_score | ||
return {"Exact_Match_Score": max_exact_match, "Fuzzy_Match_Score": max_similarity_score} | ||
|
||
|
||
class PythonASTSimilarity(Metric): | ||
''' | ||
The following functions are adapted from python-ast-comparison by Pedro Salazar Paredes | ||
Copyright (c) 2023 Pedro Salazar Paredes | ||
Licensed under the MIT License | ||
Source: https://github.com/PedroSalazarParedes/python-ast-comparison | ||
Modifications: Adjusted to be used in the context of generated code evaluation | ||
''' | ||
|
||
def _compare_ASTs(ast_a: AST, ast_b: AST, reorder_depth: int) -> int: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The docstring for the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @ellipsis-dev can you fix this docstrings for me? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
"""Compare two ASTs corresponding to python programs. | ||
|
||
Args: | ||
ast_a: The first program AST to compare. | ||
ast_b: The first program AST to compare. | ||
reorder_depth: The maximum children reorder depth for better | ||
performance. | ||
|
||
Returns: | ||
True if the ASTs are equivalent, otherwise False. | ||
""" | ||
children_a = list(ast.iter_child_nodes(ast_a)) | ||
children_b = list(ast.iter_child_nodes(ast_b)) | ||
if (type(ast_a) == type(ast_b)) and len(list(children_a)) == 0 and len(list(children_b)) == 0: | ||
return 1 | ||
|
||
if (type(ast_a) != type(ast_b)) or (len(children_a) != len(children_b)): | ||
return 0 | ||
|
||
if reorder_depth == 0: | ||
match_index = sum( | ||
map( | ||
lambda pairs: PythonASTSimilarity._compare_ASTs(pairs[0], pairs[1], reorder_depth), | ||
zip(children_a, children_b), | ||
) | ||
) | ||
return match_index + 1 | ||
|
||
elif reorder_depth > 0: | ||
match_index = PythonASTSimilarity._reorder_children_compare(ast_a, ast_b, reorder_depth - 1) | ||
return match_index + 1 | ||
|
||
return 0 | ||
|
||
def _reorder_children_compare(ast_a: AST, ast_b: AST, reorder_depth: int) -> int: | ||
"""Reorders child nodes and compares them. | ||
|
||
Args: | ||
ast_a: The first AST for child comparison. | ||
ast_b: The second AST for child comparison. | ||
reorder_depth: The maximum children reorder depth for better | ||
performance. | ||
|
||
Returns: | ||
True if there is a way to match 1-1 every child node of ast_a | ||
with every child node of ast_b, otherwise False. | ||
""" | ||
comparison_matrix = [] | ||
cost_matrix = [] | ||
best_match_value = 0 | ||
children_a = list(ast.iter_child_nodes(ast_a)) | ||
children_b = list(ast.iter_child_nodes(ast_b)) | ||
|
||
if len(children_a) <= 1 or len(children_b) <= 1: | ||
for child_a in children_a: | ||
for child_b in children_b: | ||
best_match_value += PythonASTSimilarity._compare_ASTs(child_a, child_b, reorder_depth) | ||
else: | ||
for child_a in children_a: | ||
row = [] | ||
cost_row = [] | ||
for child_b in children_b: | ||
similarity = PythonASTSimilarity._compare_ASTs(child_a, child_b, reorder_depth) | ||
row.append(similarity) | ||
cost_row.append(10000000 - similarity) | ||
|
||
comparison_matrix.append(row) | ||
cost_matrix.append(cost_row) | ||
|
||
m = Munkres() | ||
indices = m.compute(cost_matrix) | ||
|
||
for row, col in indices: | ||
best_match_value += comparison_matrix[row][col] | ||
|
||
return best_match_value | ||
|
||
def _compare_subtrees(sig_subtrees_p1: list, sig_subtrees_p2: list, reorder_depth: int) -> tuple: | ||
"""Compare two significant subtree lists reordering up to a certain depth. | ||
|
||
Args: | ||
sig_subtrees_p1: The first significant AST list for comparison. | ||
sig_subtrees_p2: The second significant AST list for comparison. | ||
reorder_depth: The maximum children reorder depth for better | ||
performance. | ||
|
||
Returns: | ||
A tuple with the ratio of matching to non-matching nodes of the | ||
significant subtrees, and a list with the best matching of subtrees. | ||
""" | ||
comparison_matrix = [] | ||
cost_matrix = [] | ||
best_match = [] | ||
best_match_value = 0 | ||
best_match_weight = 0 | ||
children_a = sig_subtrees_p1.copy() | ||
children_b = sig_subtrees_p2.copy() | ||
|
||
if len(children_a) <= 1 or len(children_b) <= 1: | ||
for child_a in children_a: | ||
best_match += [child_a] | ||
for child_b in children_b: | ||
best_match_value += PythonASTSimilarity._compare_ASTs(child_a, child_b, reorder_depth) | ||
best_match += [child_b] | ||
else: | ||
for child_a in children_a: | ||
row = [] | ||
cost_row = [] | ||
for child_b in children_b: | ||
similarity = PythonASTSimilarity._compare_ASTs(child_a, child_b, reorder_depth) | ||
row.append(similarity) | ||
cost_row.append(10000000 - similarity) | ||
|
||
comparison_matrix.append(row) | ||
cost_matrix.append(cost_row) | ||
|
||
m = Munkres() | ||
indices = m.compute(cost_matrix) | ||
|
||
for row, col in indices: | ||
best_match_weight += PythonASTSimilarity._apply_weights_to_subtrees_mult( | ||
comparison_matrix[row][col], sig_subtrees_p1[row], sig_subtrees_p2[col] | ||
) | ||
best_match += [sig_subtrees_p1[row], sig_subtrees_p2[col]] | ||
|
||
all_subtrees_weight = sum( | ||
map( | ||
lambda tree: PythonASTSimilarity._apply_weights_to_subtrees( | ||
PythonASTSimilarity._get_num_nodes(tree), tree | ||
), | ||
sig_subtrees_p1, | ||
) | ||
) + sum( | ||
map( | ||
lambda tree: PythonASTSimilarity._apply_weights_to_subtrees( | ||
PythonASTSimilarity._get_num_nodes(tree), tree | ||
), | ||
sig_subtrees_p2, | ||
) | ||
) | ||
|
||
similarity = 2 * best_match_weight / all_subtrees_weight | ||
|
||
return round(similarity, 4), best_match | ||
|
||
def _is_significant(root: AST) -> bool: | ||
"""Determine if an AST is significant. | ||
|
||
Args: | ||
root: The AST whose significance we want. | ||
|
||
Returns: | ||
True for if it is significant, False otherwise. | ||
""" | ||
return ( | ||
isinstance(root, Import) | ||
or isinstance(root, FunctionDef) | ||
or isinstance(root, If) | ||
or isinstance(root, ClassDef) | ||
or isinstance(root, While) | ||
or isinstance(root, For) | ||
or isinstance(root, comprehension) | ||
or isinstance(root, Return) | ||
) | ||
|
||
def _get_significant_subtrees(root: AST) -> list: | ||
"""Find the significant subtrees of an AST. | ||
|
||
Args: | ||
root: The root of the main AST. | ||
|
||
Returns: | ||
A list with all the significant subtrees of root. | ||
""" | ||
significant_subtrees = [] | ||
for node in ast.walk(root): | ||
if PythonASTSimilarity._is_significant(node): | ||
significant_subtrees.append(node) | ||
return significant_subtrees | ||
|
||
def _get_num_nodes(root: AST) -> int: | ||
"""Find the number of nodes for a given tree. | ||
|
||
Args: | ||
root: The root of the tree whose size we want. | ||
|
||
Returns: | ||
The number of nodes in the tree. | ||
""" | ||
return len(list(ast.walk(root))) | ||
|
||
def _apply_weights_to_subtrees(weight: float, subtree: AST) -> float: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The weights applied in the |
||
"""Apply weights to subtrees according to the time por their roots. | ||
|
||
Args: | ||
weight: The number of nodes in the subtree. | ||
subtree: The subtree. | ||
|
||
Returns: | ||
The weighed weight of the tree. | ||
""" | ||
new_weight = weight | ||
if isinstance(subtree, Import): | ||
new_weight *= 0.3 | ||
elif isinstance(subtree, Module): | ||
new_weight *= 1 | ||
elif isinstance(subtree, FunctionDef): | ||
new_weight *= 1.2 | ||
elif isinstance(subtree, If): | ||
new_weight *= 0.5 | ||
elif isinstance(subtree, ClassDef): | ||
new_weight *= 1 | ||
elif isinstance(subtree, While): | ||
new_weight *= 1 | ||
elif isinstance(subtree, For): | ||
new_weight *= 1 | ||
elif isinstance(subtree, comprehension): | ||
new_weight *= 1 | ||
elif isinstance(subtree, Return): | ||
new_weight *= 1 | ||
return new_weight | ||
|
||
def _apply_weights_to_subtrees_mult(weight: float, ast_1: AST, ast_2: AST) -> float: | ||
"""Find the average weight of both trees in order to weigh the comparison. | ||
|
||
Args: | ||
weight: The weight of the comparison. | ||
ast_1: The first compared tree. | ||
ast_2: The second compared tree. | ||
|
||
Returns: | ||
The average of the subtrees' weights. | ||
""" | ||
if weight == 0: | ||
return 0 | ||
else: | ||
return ( | ||
PythonASTSimilarity._apply_weights_to_subtrees(weight, ast_1) | ||
+ PythonASTSimilarity._apply_weights_to_subtrees(weight, ast_2) | ||
) / 2 | ||
|
||
def _compare_many(programs: list) -> list: | ||
"""Compare all of the programs in the list. | ||
|
||
Args: | ||
programs: A list of strings with python programs. | ||
|
||
Returns: | ||
A matrix with the similarity rating of between all the programs. | ||
""" | ||
tree_list = list(map(lambda prog: PythonASTSimilarity._get_significant_subtrees(ast.parse(prog)), programs)) | ||
|
||
matrix = [] | ||
for program_1_tree_num in range(0, len(tree_list)): | ||
for program_2_tree_num in range(program_1_tree_num, len(tree_list)): | ||
if program_1_tree_num == program_2_tree_num: | ||
continue | ||
|
||
subtrees1 = tree_list[program_1_tree_num] | ||
subtrees2 = tree_list[program_2_tree_num] | ||
|
||
result = PythonASTSimilarity._compare_subtrees(subtrees1, subtrees2, 1000)[0] | ||
|
||
matrix.append((program_1_tree_num, program_2_tree_num, result)) | ||
matrix.append((program_2_tree_num, program_1_tree_num, result)) | ||
|
||
return matrix | ||
|
||
def calculate(self, answer, ground_truths, **kwargs): | ||
|
||
try: | ||
answer_tree = ast.parse(answer, mode="exec") | ||
ground_truth_trees = [ast.parse(gt, mode="exec") for gt in ground_truths] | ||
except SyntaxError as e: | ||
warning_msg = f"Error: {e}: AST cannot be parsed from answer: {answer} or ground_truths:{ground_truths}. Returning -1.0." | ||
warnings.warn(warning_msg, Warning) | ||
return {"Python_AST_Similarity": -1.0} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Returning a Python_AST_Similarity score of -1.0 when a SyntaxError is caught could be misleading. Consider returning None or raising an exception to indicate that an error occurred. |
||
|
||
answer_subtree = PythonASTSimilarity._get_significant_subtrees(answer_tree) | ||
ground_truth_subtrees = [ | ||
PythonASTSimilarity._get_significant_subtrees(ground_truth_tree) for ground_truth_tree in ground_truth_trees | ||
] | ||
|
||
similarity_scores = [ | ||
PythonASTSimilarity._compare_subtrees(answer_subtree, ground_truth_subtree, 1000)[0] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
for ground_truth_subtree in ground_truth_subtrees | ||
] | ||
|
||
return { | ||
"Python_AST_Similarity": max(similarity_scores), | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The copyright year in the docstring seems to be incorrect. It is mentioned as 2023, which is a future year. Please correct it.