Skip to content

Commit

Permalink
Add deterministic SQL metrics (#63)
Browse files Browse the repository at this point in the history
  • Loading branch information
yisz authored May 22, 2024
1 parent 344f7e9 commit 8124a22
Show file tree
Hide file tree
Showing 13 changed files with 597 additions and 268 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ print(metric(**datum))
<tr>
<td rowspan="2">Code Generation</td>
<td>Deterministic</td>
<td>CodeStringMatch, PythonASTSimilarity</td>
<td>CodeStringMatch, PythonASTSimilarity, SQLSyntaxMatch, SQLASTSimilarity</td>
</tr>
<tr>
<td>LLM-based</td>
Expand Down
Empty file.
1 change: 0 additions & 1 deletion continuous_eval/metrics/code/python/__init__.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from typing import List, Union

from munkres import Munkres
from sqlglot import diff, parse_one
from sqlglot.diff import Keep
from thefuzz import fuzz

from continuous_eval.metrics.base import Metric
Expand Down
135 changes: 135 additions & 0 deletions continuous_eval/metrics/code/sql/deterministic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import logging
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

from sqlglot import diff, parse_one, transpile
from sqlglot.diff import Insert, Keep, Move, Remove, Update
from sqlglot.optimizer import optimize

from continuous_eval.metrics.base import Metric

logger = logging.getLogger("metrics")


@dataclass(frozen=True)
class ASTDiffWeightConfig:
    """
    Configuration for assigning weights to different types of changes in the AST diff.
    Higher weights indicate more significant changes, which are expected to have a greater impact on query semantics.

    Each edit emitted by sqlglot's diff() is weighted by its type; the weighted
    sum is then normalized by tree size to produce a similarity score.
    """

    # Keep edits represent unchanged nodes, so they contribute no penalty.
    keep: float = 0.0
    # Updates are significant as they imply a modification in function or value.
    update: float = 1.5
    # Inserts affect the structure and content but are simpler than updates.
    insert: float = 1.0
    # Removes affect the structure and content but are simpler than updates.
    remove: float = 1.0
    # Moves are generally less impactful as they simply change the order.
    move: float = 0.5
    # Default weight for other types of changes
    default: float = 1.0


class _SQLMetric:
    """Shared SQL preprocessing used by the deterministic SQL metrics.

    Normalizes queries via sqlglot transpilation and, when requested,
    runs the sqlglot optimizer (optionally guided by a table schema)
    so that queries are compared in a canonical form.
    """

    def __init__(self, optimize: bool = False, schema: Optional[Dict] = None):
        # Whether to attempt sqlglot optimization during _prepare_query.
        self._optimize = optimize
        # Optional schema passed to the optimizer; may be None.
        self._schema = schema

    def _prepare_query(self, sql: str):
        """
        Parse, transpile, and optionally optimize a SQL query.
        """
        # Canonical pretty-printed form with comments stripped.
        normalized = transpile(sql, pretty=True, comments=False)[0]
        if not self._optimize:
            return normalized
        try:
            # Optimization can fail (e.g. schema mismatch); fall back to the
            # unoptimized-but-normalized form in that case.
            return optimize(parse_one(normalized), schema=self._schema).sql(pretty=True)
        except Exception as e:
            logger.warning(f"Failed to optimize SQL query given schema: {e}. Using unoptimized query.")
            return normalized


class SQLSyntaxMatch(Metric, _SQLMetric):
    """
    This metric evaluates the syntactic similarity between the generated SQL query and a set of ground truth queries.
    It uses the sqlglot library to format and compare the SQL queries.

    Returns a dict with key "SQL_Syntax_Match": 1.0 if the normalized answer
    exactly matches any normalized ground truth, else 0.0.
    """

    def __init__(self, optimize: bool = False, schema: Optional[Dict] = None):
        super().__init__()
        _SQLMetric.__init__(self, optimize=optimize, schema=schema)

    def __call__(self, answer: str, ground_truth_answers: Union[List[str], str], **kwargs):
        # The signature allows a bare string; without this guard it would be
        # iterated character by character, silently producing garbage scores.
        if isinstance(ground_truth_answers, str):
            ground_truth_answers = [ground_truth_answers]

        transformed_answer = self._prepare_query(answer)

        # Exact string match on normalized queries; stop at the first hit
        # since the score is binary.
        for gt in ground_truth_answers:
            if transformed_answer == self._prepare_query(gt):
                return {"SQL_Syntax_Match": 1.0}
        return {"SQL_Syntax_Match": 0.0}


class SQLASTSimilarity(Metric, _SQLMetric):
    """
    Compare SQL queries using AST similarity, considering different types of changes differently and improving normalization.

    Returns a dict with key "SQL_AST_Similarity" in [0, 1] (higher is more
    similar), or -1.0 when a query cannot be parsed.
    """

    def __init__(
        self,
        optimize: bool = False,
        schema: Optional[Dict] = None,
        diff_weights: ASTDiffWeightConfig = ASTDiffWeightConfig(),
    ):
        super().__init__()
        _SQLMetric.__init__(self, optimize=optimize, schema=schema)
        # Per-edit-type weights applied to the sqlglot diff result.
        self._diff_weights = diff_weights

    def __call__(self, answer: str, ground_truth_answers: Union[List[str], str], **kwargs):
        # The signature allows a bare string; without this guard it would be
        # iterated character by character, silently producing garbage scores.
        if isinstance(ground_truth_answers, str):
            ground_truth_answers = [ground_truth_answers]

        transformed_answer = self._prepare_query(answer)
        transformed_ground_truths = [self._prepare_query(gt) for gt in ground_truth_answers]

        try:
            answer_tree = parse_one(transformed_answer)
            ground_truth_trees = [parse_one(gt) for gt in transformed_ground_truths]
        except Exception:
            # Unparseable SQL cannot be scored; sentinel value documented above.
            return {"SQL_AST_Similarity": -1.0}

        similarity_scores = [
            self._calculate_similarity(answer_tree, ground_truth_tree) for ground_truth_tree in ground_truth_trees
        ]

        return {
            "SQL_AST_Similarity": max(similarity_scores) if similarity_scores else -1.0,
        }

    def _calculate_similarity(self, tree1, tree2):
        """Weighted edit distance between two ASTs, normalized by tree size."""
        diff_result = diff(tree1, tree2)
        total_changes = sum(self._apply_weights(change) for change in diff_result)
        # Normalize by the larger tree so the score is symmetric-ish and bounded.
        max_nodes = max(len(list(tree1.walk())), len(list(tree2.walk())))
        similarity_score = 1 - (total_changes / max_nodes) if max_nodes > 0 else 1
        return similarity_score

    def _apply_weights(self, change):
        """
        Assign weights to different types of changes based on their expected impact on query semantics.
        """
        if isinstance(change, Keep):
            return self._diff_weights.keep
        elif isinstance(change, Update):
            return self._diff_weights.update
        elif isinstance(change, Insert):
            return self._diff_weights.insert
        elif isinstance(change, Remove):
            return self._diff_weights.remove
        elif isinstance(change, Move):
            return self._diff_weights.move
        else:
            return self._diff_weights.default
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
---
title: StringMatch
title: Code String Match
sidebar:
order: 1
---
Expand All @@ -18,7 +18,7 @@ It outputs both the binary exact match score and the fuzzy match score in the ra
Required data items: `answer`, `ground_truth_answers`

```python
from continuous_eval.metrics.code.python import CodeStringMatch
from continuous_eval.metrics.code import CodeStringMatch

datum = {
"answer": "def function(x, y):\n return x + y",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ The metric depends on syntactically correct Python scripts to produce the Abstra
Required data items: `answer`, `ground_truth_answers`

```python
from continuous_eval.metrics.code.python import PythonASTSimilarity
from continuous_eval.metrics.code import PythonASTSimilarity

datum = {
"answer": "def function(x, y):\n return x + y",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
---
title: SQL AST Similarity
sidebar:
order: 1
---

### Definitions

**SQL AST Similarity** compares the structure of two SQL queries by analyzing their Abstract Syntax Trees (ASTs). This metric assesses similarity by matching the nodes within these trees, taking into account the statement types and their arrangement. Different types of tree differences (such as insert, remove, update, move, etc.) are weighted differently to calculate the final similarity score.

<br>

$$
\text{SQL AST Similarity} = 1 - \frac{\text{Total Weight Changes}}{\text{Maximum Possible Nodes}}
$$

<br>

:::note
The metric depends on syntactically correct SQL queries to produce the Abstract Syntax Trees (ASTs). If the scripts contain syntax errors and cannot be parsed, the metric will yield a score of -1.0.
:::

<br>

### Example Usage

Required data items: `answer`, `ground_truth_answers`

```python
from continuous_eval.metrics.code import SQLASTSimilarity

datum = {
"answer": "SELECT name, age FROM customers",
"ground_truth_answers": ["SELECT age, name FROM customers"],
}

metric = SQLASTSimilarity()
print(metric(**datum))
```

You can optionally initialize the metric to use optimized SQL queries using the [sqlglot optimizer](https://github.com/tobymao/sqlglot?tab=readme-ov-file#sql-optimizer) and optionally pass in the schema. For example:
```python
schema={"x": {"A": "INT", "B": "INT", "C": "INT", "D": "INT", "Z": "STRING"}}
sql_ast_similarity_optimized = SQLASTSimilarity(optimize=True, schema=schema)
```

You can also customize weights to different types of nodes in the AST diff.
Higher weights indicate more significant changes, which are expected to have a greater impact on query semantics.

```python
from continuous_eval.metrics.code.sql.deterministic import ASTDiffWeightConfig

weights = ASTDiffWeightConfig(
    keep=0.0,
    update=2,
    insert=1.0,
    remove=1.5,
    move=0,
    default=0,
)
ASTSimilarity = SQLASTSimilarity(diff_weights=weights)
```

### Example Output

```JSON
{
"SQL_AST_Similarity": 0.9375
}
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
---
title: SQL Syntax Match
sidebar:
order: 2
---

## Definitions

**SQL Syntax Match** evaluates the syntactic equivalence between generated SQL queries and a set of ground truth queries. The comparison is strict on syntax but tolerant of formatting differences, because queries are normalized before being compared.

## Example Usage

Required data items: `answer`, `ground_truth_answers`

```python
from continuous_eval.metrics.code import SQLSyntaxMatch

sql_syntax_match = SQLSyntaxMatch()

datum = {
    "answer": "SELECT * FROM users;",
"ground_truth_answers": [
"SELECT * from users;"
],
}

metric = SQLSyntaxMatch()
print(metric(**datum))
```

You can optionally initialize the metric to use optimized SQL queries using the [sqlglot optimizer](https://github.com/tobymao/sqlglot?tab=readme-ov-file#sql-optimizer) and optionally pass in the schema. For example:
```python
schema={"x": {"A": "INT", "B": "INT", "C": "INT", "D": "INT", "Z": "STRING"}}
sql_syntax_match_optimized = SQLSyntaxMatch(optimize=True, schema=schema)
```

## Example Output

```JSON
{
"SQL_Syntax_Match": 1.0
}
```
Loading

0 comments on commit 8124a22

Please sign in to comment.