-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
13 changed files
with
597 additions
and
268 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
import logging | ||
from dataclasses import dataclass | ||
from typing import Dict, List, Optional, Union | ||
|
||
from sqlglot import diff, parse_one, transpile | ||
from sqlglot.diff import Insert, Keep, Move, Remove, Update | ||
from sqlglot.optimizer import optimize | ||
|
||
from continuous_eval.metrics.base import Metric | ||
|
||
logger = logging.getLogger("metrics") | ||
|
||
|
||
@dataclass(frozen=True)
class ASTDiffWeightConfig:
    """
    Immutable table of per-edit costs used when scoring an AST diff.

    Each field is the cost charged for a single edit of that kind. A larger
    value marks the edit as a bigger departure from the reference query.
    """

    # Cost charged for a "keep" edit; zero by default.
    keep: float = 0.0
    # An update rewrites a function or a value, so it is the most expensive edit by default.
    update: float = 1.5
    # An insert adds structure/content; treated as simpler than an update.
    insert: float = 1.0
    # A remove drops structure/content; treated as simpler than an update.
    remove: float = 1.0
    # A move only changes ordering, so it is the cheapest non-zero edit by default.
    move: float = 0.5
    # Fallback cost for any edit kind not listed above.
    default: float = 1.0
|
||
|
||
class _SQLMetric: | ||
def __init__(self, optimize: bool = False, schema: Optional[Dict] = None): | ||
self._optimize = optimize | ||
self._schema = schema | ||
|
||
def _prepare_query(self, sql: str): | ||
""" | ||
Parse, transpile, and optionally optimize a SQL query. | ||
""" | ||
formatted_sql = transpile(sql, pretty=True, comments=False)[0] | ||
if self._optimize: | ||
try: | ||
optimized_sql = optimize(parse_one(formatted_sql), schema=self._schema).sql(pretty=True) | ||
return optimized_sql | ||
except Exception as e: | ||
logger.warning(f"Failed to optimize SQL query given schema: {e}. Using unoptimized query.") | ||
return formatted_sql | ||
return formatted_sql | ||
|
||
|
||
class SQLSyntaxMatch(Metric, _SQLMetric):
    """
    Exact-match metric between a generated SQL query and reference queries.

    Both sides are normalized with ``_prepare_query`` (sqlglot formatting and,
    optionally, optimization) before being compared as strings, so differences
    that the normalization removes do not affect the result.
    """

    def __init__(self, optimize: bool = False, schema: Optional[Dict] = None):
        super(SQLSyntaxMatch, self).__init__()
        _SQLMetric.__init__(self, optimize=optimize, schema=schema)

    def __call__(self, answer: str, ground_truth_answers: Union[List[str], str], **kwargs):
        """
        Score ``answer`` against one or more reference queries.

        Args:
            answer: The generated SQL query.
            ground_truth_answers: A single reference query or a list of them.
            **kwargs: Ignored; accepted so extra datum fields can be passed through.

        Returns:
            ``{"SQL_Syntax_Match": 1.0}`` if the normalized answer equals any
            normalized reference, otherwise ``{"SQL_Syntax_Match": 0.0}``.
        """
        # A bare string is a single reference query; iterating it directly
        # would compare the answer against each individual character.
        if isinstance(ground_truth_answers, str):
            ground_truth_answers = [ground_truth_answers]

        transformed_answer = self._prepare_query(answer)
        # Normalize every reference up front so a malformed one is reported
        # regardless of whether an earlier reference already matched.
        transformed_ground_truths = [self._prepare_query(gt) for gt in ground_truth_answers]

        matched = any(transformed_answer == gt for gt in transformed_ground_truths)
        return {"SQL_Syntax_Match": float(matched)}
|
||
|
||
class SQLASTSimilarity(Metric, _SQLMetric):
    """
    Score SQL queries by the weighted size of the diff between their ASTs.

    Each query is normalized with ``_prepare_query``, parsed into a sqlglot
    tree, and diffed against every reference tree. Each edit in the diff is
    charged the weight configured for its kind in ``ASTDiffWeightConfig``.
    """

    def __init__(
        self,
        optimize: bool = False,
        schema: Optional[Dict] = None,
        diff_weights: ASTDiffWeightConfig = ASTDiffWeightConfig(),
    ):
        super(SQLASTSimilarity, self).__init__()
        _SQLMetric.__init__(self, optimize=optimize, schema=schema)
        self._diff_weights = diff_weights

    def __call__(self, answer: str, ground_truth_answers: Union[List[str], str], **kwargs):
        """
        Return the best similarity between ``answer`` and the reference queries.

        Args:
            answer: The generated SQL query.
            ground_truth_answers: A single reference query or a list of them.
            **kwargs: Ignored.

        Returns:
            ``{"SQL_AST_Similarity": score}`` where ``score`` is the maximum
            similarity over all references, or ``-1.0`` when any query cannot
            be prepared/parsed or when no reference is supplied.
        """
        # A bare string is a single reference query; iterating it directly
        # would treat every character as its own query.
        if isinstance(ground_truth_answers, str):
            ground_truth_answers = [ground_truth_answers]

        try:
            # Preparation transpiles the query and can itself fail on invalid
            # SQL, so it must sit inside the guard along with parsing.
            answer_tree = parse_one(self._prepare_query(answer))
            ground_truth_trees = [parse_one(self._prepare_query(gt)) for gt in ground_truth_answers]
        except Exception:
            return {"SQL_AST_Similarity": -1.0}

        similarity_scores = [
            self._calculate_similarity(answer_tree, ground_truth_tree) for ground_truth_tree in ground_truth_trees
        ]

        return {
            "SQL_AST_Similarity": max(similarity_scores) if similarity_scores else -1.0,
        }

    def _calculate_similarity(self, tree1, tree2):
        """
        Return ``1 - weighted_changes / max_nodes`` for the two trees.

        ``max_nodes`` is the larger node count of the two trees. The result is
        not clamped, so it may fall below 0 when the weighted changes exceed
        that count.
        """
        total_changes = sum(self._apply_weights(change) for change in diff(tree1, tree2))
        max_nodes = max(sum(1 for _ in tree1.walk()), sum(1 for _ in tree2.walk()))
        if max_nodes <= 0:
            return 1
        return 1 - (total_changes / max_nodes)

    def _apply_weights(self, change):
        """
        Return the configured weight for one edit from the diff.

        Edit kinds not listed below are charged the ``default`` weight.
        """
        weights = self._diff_weights
        # Checked in order; the first matching edit type wins.
        weight_by_type = (
            (Keep, weights.keep),
            (Update, weights.update),
            (Insert, weights.insert),
            (Remove, weights.remove),
            (Move, weights.move),
        )
        for change_type, weight in weight_by_type:
            if isinstance(change, change_type):
                return weight
        return weights.default
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
70 changes: 70 additions & 0 deletions
70
docs/src/content/docs/metrics/Code/Deterministic/sql_ast_similarity.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
--- | ||
title: SQL AST Similarity | ||
sidebar: | ||
order: 1 | ||
--- | ||
|
||
### Definitions | ||
|
||
**SQL AST Similarity** compares the structure of two SQL queries by analyzing their Abstract Syntax Trees (ASTs). This metric assesses similarity by matching the nodes within these trees, taking into account the statement types and their arrangement. Different types of tree differences (such as insert, remove, update, move, etc.) are weighted differently to calculate the final similarity score. | ||
|
||
<br> | ||
|
||
$$ | ||
\text{SQL AST Similarity} = 1 - \frac{\text{Total Weighted Changes}}{\text{Node Count of the Larger AST}} | ||
$$ | ||
|
||
<br> | ||
|
||
:::note | ||
The metric depends on syntactically correct SQL queries to produce the Abstract Syntax Trees (ASTs). If the scripts contain syntax errors and cannot be parsed, the metric will yield a score of -1.0. | ||
::: | ||
|
||
<br> | ||
|
||
### Example Usage | ||
|
||
Required data items: `answer`, `ground_truth_answers` | ||
|
||
```python | ||
from continuous_eval.metrics.code import SQLASTSimilarity | ||
|
||
datum = { | ||
"answer": "SELECT name, age FROM customers", | ||
"ground_truth_answers": ["SELECT age, name FROM customers"], | ||
} | ||
|
||
metric = SQLASTSimilarity() | ||
print(metric(**datum)) | ||
``` | ||
|
||
You can optionally initialize the metric to use optimized SQL queries using the [sqlglot optimizer](https://github.com/tobymao/sqlglot?tab=readme-ov-file#sql-optimizer) and optionally pass in the schema. For example: | ||
```python | ||
schema={"x": {"A": "INT", "B": "INT", "C": "INT", "D": "INT", "Z": "STRING"}} | ||
sql_ast_similarity_optimized = SQLASTSimilarity(optimize=True, schema=schema) | ||
``` | ||
|
||
You can also customize the weights assigned to the different types of changes in the AST diff. | ||
Higher weights indicate more significant changes, which are expected to have a greater impact on query semantics. | ||
|
||
```python | ||
from continuous_eval.metrics.code.sql.deterministic import ASTDiffWeightConfig | ||
|
||
weights = ASTDiffWeightConfig( | ||
keep=0.0, | ||
update=2, | ||
insert=1.0, | ||
remove=1.5, | ||
move=0, | ||
default=0, | ||
) | ||
ASTSimilarity = SQLASTSimilarity(diff_weights=weights) | ||
``` | ||
|
||
### Example Output | ||
|
||
```JSON | ||
{ | ||
"SQL_AST_Similarity": 0.9375 | ||
} | ||
``` |
43 changes: 43 additions & 0 deletions
43
docs/src/content/docs/metrics/Code/Deterministic/sql_syntax_match.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
--- | ||
title: SQL Syntax Match | ||
sidebar: | ||
order: 2 | ||
--- | ||
|
||
## Definitions | ||
|
||
**SQL Syntax Match** evaluates the syntactic equivalence between generated SQL queries and a set of ground truth queries. The strict comparison can tolerate formatting changes. | ||
|
||
## Example Usage | ||
|
||
Required data items: `answer`, `ground_truth_answers` | ||
|
||
```python | ||
from continuous_eval.metrics.code import SQLSyntaxMatch | ||
|
||
sql_syntax_match = SQLSyntaxMatch() | ||
|
||
datum = { | ||
"answer": "SELECT * FROM users;", | ||
"ground_truth_answers": [ | ||
"SELECT * from users;" | ||
], | ||
} | ||
|
||
metric = SQLSyntaxMatch() | ||
print(metric(**datum)) | ||
``` | ||
|
||
You can optionally initialize the metric to use optimized SQL queries using the [sqlglot optimizer](https://github.com/tobymao/sqlglot?tab=readme-ov-file#sql-optimizer) and optionally pass in the schema. For example: | ||
```python | ||
schema={"x": {"A": "INT", "B": "INT", "C": "INT", "D": "INT", "Z": "STRING"}} | ||
sql_syntax_match_optimized = SQLSyntaxMatch(optimize=True, schema=schema) | ||
``` | ||
|
||
## Example Output | ||
|
||
```JSON | ||
{ | ||
"SQL_Syntax_Match": 1.0 | ||
} | ||
``` |
Oops, something went wrong.