Skip to content

Commit

Permalink
Feature bq is unique (#114)
Browse files Browse the repository at this point in the history
* adapt conftest bq to run tests locally

* added is_unique with test cases

* black, flake and type ok on bq
  • Loading branch information
vestalisvirginis authored Aug 27, 2023
1 parent dc6d2c4 commit 398cb62
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 45 deletions.
21 changes: 13 additions & 8 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
from cuallee import Check, CheckLevel
from pyspark.sql import SparkSession
from pathlib import Path
import logging
import duckdb

Expand Down Expand Up @@ -84,14 +85,18 @@ def db() -> duckdb.DuckDBPyConnection:

@pytest.fixture(scope="session")
def bq_client():
from google.oauth2 import service_account
import os
import json
with open("key.json", "w") as writer:
json.dump(json.loads(os.getenv("GOOGLE_KEY")), writer)

credentials = service_account.Credentials.from_service_account_file("key.json")


if Path('temp/key.json').exists()==True:
credentials = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
else:
from google.oauth2 import service_account
import os
import json
with open("key.json", "w") as writer:
json.dump(json.loads(os.getenv("GOOGLE_KEY")), writer)

credentials = service_account.Credentials.from_service_account_file("key.json")

try:
client = bigquery.Client(project="cuallee-bigquery-386709", credentials=credentials)
return client
Expand Down
59 changes: 39 additions & 20 deletions cuallee/bigquery_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@
import operator
import pandas as pd
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Dict, List, Union
from google.cloud import bigquery
from cuallee import Check, ComputeEngine, Rule


class ComputeMethod(enum.Enum):
    """Strategies available for computing rule expressions; BigQuery only supports SQL."""

    SQL = "SQL"


@dataclass
class ComputeInstruction:
    """Container for one rule's computation: its row predicate, the SQL
    expression to evaluate, and the compute strategy to use.

    `predicate` is None for aggregate-only rules (e.g. is_unique), a single
    string for single-column rules, or a list of strings for multi-column rules.
    """

    predicate: Union[str, List[str], None]
    expression: str
    compute_method: ComputeMethod

Expand Down Expand Up @@ -41,7 +43,19 @@ def are_complete(self, rule: Rule):
predicate = [f"{c} IS NOT NULL" for c in rule.column]
self.compute_instruction = ComputeInstruction(
predicate,
"("+f"+".join([self._sum_predicate_to_integer(p) for p in predicate])+f")/{len(rule.column)}",
"("
+ "+".join([self._sum_predicate_to_integer(p) for p in predicate])
+ f")/{len(rule.column)}",
ComputeMethod.SQL,
)
return self.compute_instruction

def is_unique(self, rule: Rule):
    """Validation for unique values in column"""
    # No row-level predicate: uniqueness is an aggregate check, so the
    # pass rate is derived from COUNT(DISTINCT ...) against total rows.
    self.compute_instruction = ComputeInstruction(
        None,
        f"COUNT(DISTINCT {rule.column})",
        ComputeMethod.SQL,
    )
    return self.compute_instruction
Expand All @@ -50,16 +64,15 @@ def are_complete(self, rule: Rule):
def _get_expressions(compute_set: Dict[str, ComputeInstruction]) -> str:
    """Get the expression for all the rules in check in one string.

    Each expression is aliased as KEY<hash> so results can be matched back
    to their rule by hash key.
    """
    # NOTE(review): the rendered diff interleaved pre/post-commit lines here;
    # this is the post-commit (black-formatted) version, deduplicated.
    return ", ".join(
        compute_instruction.expression + f" AS KEY{key}"
        for key, compute_instruction in compute_set.items()
    )


def _build_query(expression_string: str, dataframe: bigquery.table.Table) -> str:
    """Build the final SELECT statement over the BigQuery table.

    `dataframe` is interpolated via str(), yielding the fully-qualified
    table id (project.dataset.table).
    """
    return f"SELECT {expression_string} FROM {dataframe}"
Expand All @@ -71,10 +84,15 @@ def _compute_query_method(client, query: str) -> Dict:
return client.query(query).to_arrow().to_pandas().to_dict(orient="records")


def _compute_row(client, dataframe: bigquery.table.Table) -> Dict:
    """Get the number of rows in the table.

    Returns a one-record list of dicts, e.g. [{"count": N}]; callers take
    [0]["count"].
    """
    return (
        client.query(f"SELECT COUNT(*) AS count FROM {dataframe}")
        .to_arrow()
        .to_pandas()
        .to_dict(orient="records")
    )


def _calculate_violations(result, nrows) -> Union[int, float]:
Expand Down Expand Up @@ -106,7 +124,7 @@ def _calculate_pass_rate(result, nrows) -> float:
elif abs(result) == nrows:
return 0.5
else:
nrows / abs(result)
return nrows / abs(result)
else:
return result / nrows

Expand All @@ -115,9 +133,9 @@ def _evaluate_status(pass_rate, pass_threshold) -> str:
"""Return the status for each rule"""

if pass_rate >= pass_threshold:
return f"PASS"
return "PASS"
else:
return f"FAIL"
return "FAIL"


def validate_data_types(rules: List[Rule], dataframe: str) -> bool:
Expand All @@ -132,13 +150,14 @@ def compute(rules: Dict[str, Rule]) -> Dict:

def summary(check: Check, dataframe: bigquery.table.Table):
"""Compute all rules in this check from table loaded in BigQuery"""


# Check that user is connected to BigQuery
try:
try:
client = bigquery.Client()
except:
print('You are not connected to the BigQuery cloud. Please verify the steps followed during the Authenticate API requests step.')
print(
"You are not connected to the BigQuery cloud. Please verify the steps followed during the Authenticate API requests step."
)

# Compute the expression
computed_expressions = compute(check._rule)
Expand All @@ -148,7 +167,7 @@ def summary(check: Check, dataframe: bigquery.table.Table):
query_result = _compute_query_method(client, query)[0]

# Compute the total number of rows
rows = _compute_row(client, dataframe)[0]['count']
rows = _compute_row(client, dataframe)[0]["count"]

# Results
computation_basis = [
Expand All @@ -165,11 +184,11 @@ def summary(check: Check, dataframe: bigquery.table.Table):
"pass_rate": _calculate_pass_rate(query_result[f"KEY{hash_key}"], rows),
"pass_threshold": rule.coverage,
"status": _evaluate_status(
_calculate_pass_rate(query_result[f"KEY{hash_key}"], rows),
rule.coverage,
),
_calculate_pass_rate(query_result[f"KEY{hash_key}"], rows),
rule.coverage,
),
}
for index, (hash_key, rule) in enumerate(check._rule.items(), 1)
]

return pd.DataFrame(computation_basis).set_index('id')
return pd.DataFrame(computation_basis).set_index("id")
22 changes: 12 additions & 10 deletions test/unit/bigquery/test_are_complete.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,43 +7,45 @@


def test_positive():
    """are_complete passes when both columns are fully populated."""
    df = bigquery.dataset.Table("bigquery-public-data.chicago_taxi_trips.taxi_trips")
    check = Check(CheckLevel.WARNING, "pytest")
    check.are_complete(("taxi_id", "unique_key"))
    rs = check.validate(df)
    assert rs.status.str.match("PASS")[1]
    assert rs.violations[1] == 0
    assert rs.pass_rate[1] == 1.0


def test_negative():
    """are_complete fails at full coverage when NULLs exist in either column."""
    df = bigquery.dataset.Table("bigquery-public-data.chicago_taxi_trips.taxi_trips")
    check = Check(CheckLevel.WARNING, "pytest")
    check.are_complete(("trip_start_timestamp", "trip_end_timestamp"))
    rs = check.validate(df)
    assert rs.status.str.match("FAIL")[1]
    assert rs.violations[1] == 9217
    assert rs.pass_threshold[1] == 1.0
    assert rs.pass_rate[1] == 0.9999558876219533


@pytest.mark.parametrize(
    "rule_column",
    [tuple(["taxi_id", "unique_key"]), list(["taxi_id", "unique_key"])],
    ids=("tuple", "list"),
)
def test_parameters(spark, rule_column):
    """are_complete accepts the column list as either a tuple or a list."""
    df = bigquery.dataset.Table("bigquery-public-data.chicago_taxi_trips.taxi_trips")
    check = Check(CheckLevel.WARNING, "pytest")
    check.are_complete(rule_column)
    rs = check.validate(df)
    assert rs.status.str.match("PASS")[1]


def test_coverage():
    """Lowering the threshold to 0.7 turns the same violations into a PASS."""
    df = bigquery.dataset.Table("bigquery-public-data.chicago_taxi_trips.taxi_trips")
    check = Check(CheckLevel.WARNING, "pytest")
    check.are_complete(("trip_start_timestamp", "trip_end_timestamp"), 0.7)
    rs = check.validate(df)
    assert rs.status.str.match("PASS")[1]
    assert rs.violations[1] == 9217
    assert rs.pass_threshold[1] == 0.7
    assert rs.pass_rate[1] == 0.9999558876219533  # 207167439/207176656
14 changes: 7 additions & 7 deletions test/unit/bigquery/test_is_complete.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,21 @@


def test_positive():
    """is_complete passes when the column has no NULLs."""
    df = bigquery.dataset.Table("bigquery-public-data.chicago_taxi_trips.taxi_trips")
    check = Check(CheckLevel.WARNING, "pytest")
    check.is_complete("taxi_id")
    rs = check.validate(df)
    assert rs.status.str.match("PASS")[1]
    assert rs.violations[1] == 0
    assert rs.pass_rate[1] == 1.0


def test_negative():
    """is_complete fails at full coverage when the column contains NULLs."""
    df = bigquery.dataset.Table("bigquery-public-data.chicago_taxi_trips.taxi_trips")
    check = Check(CheckLevel.WARNING, "pytest")
    check.is_complete("trip_end_timestamp")
    rs = check.validate(df)
    assert rs.status.str.match("FAIL")[1]
    assert rs.violations[1] == 18434
    assert rs.pass_threshold[1] == 1.0
    assert rs.pass_rate[1] == 0.9999117752439066
Expand All @@ -31,11 +31,11 @@ def test_negative():


def test_coverage():
    """Lowering the threshold to 0.7 turns the same violations into a PASS."""
    df = bigquery.dataset.Table("bigquery-public-data.chicago_taxi_trips.taxi_trips")
    check = Check(CheckLevel.WARNING, "pytest")
    check.is_complete("trip_end_timestamp", 0.7)
    rs = check.validate(df)
    assert rs.status.str.match("PASS")[1]
    assert rs.violations[1] == 18434
    assert rs.pass_threshold[1] == 0.7
    assert rs.pass_rate[1] == 0.9999117752439066  # 207158222/207176656
41 changes: 41 additions & 0 deletions test/unit/bigquery/test_is_unique.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import pandas as pd

from google.cloud import bigquery

from cuallee import Check, CheckLevel


def test_positive():
    """is_unique passes on a column where every value is distinct."""
    table = bigquery.dataset.Table("bigquery-public-data.chicago_taxi_trips.taxi_trips")
    suite = Check(CheckLevel.WARNING, "pytest")
    suite.is_unique("unique_key")
    outcome = suite.validate(table)
    assert outcome.status.str.match("PASS")[1]
    assert outcome.violations[1] == 0
    assert outcome.pass_rate[1] == 1.0


def test_negative():
    """is_unique fails on taxi_id, which repeats across trips."""
    table = bigquery.dataset.Table("bigquery-public-data.chicago_taxi_trips.taxi_trips")
    suite = Check(CheckLevel.WARNING, "pytest")
    suite.is_unique("taxi_id")
    outcome = suite.validate(table)
    assert outcome.status.str.match("FAIL")[1]
    assert outcome.violations[1] == 208933883
    assert outcome.pass_threshold[1] == 1.0
    assert outcome.pass_rate[1] == 9738 / 208943621


# def test_parameters():
# return "😅 No parameters to be tested!"


def test_coverage():
    """A tiny threshold (7e-6) lets the duplicate-heavy column pass."""
    table = bigquery.dataset.Table("bigquery-public-data.chicago_taxi_trips.taxi_trips")
    suite = Check(CheckLevel.WARNING, "pytest")
    suite.is_unique("taxi_id", 0.000007)
    outcome = suite.validate(table)
    assert outcome.status.str.match("PASS")[1]
    assert outcome.violations[1] == 208933883
    assert outcome.pass_threshold[1] == 0.000007
    assert outcome.pass_rate[1] == 9738 / 208943621

0 comments on commit 398cb62

Please sign in to comment.