Skip to content
This repository has been archived by the owner on May 17, 2024. It is now read-only.

Commit

Permalink
Merge pull request #311 from datafold/nov22_list_tables
Browse files Browse the repository at this point in the history
List tables from schema
  • Loading branch information
erezsh authored Nov 23, 2022
2 parents 1d7dc0c + 3d66325 commit 5825483
Show file tree
Hide file tree
Showing 12 changed files with 157 additions and 31 deletions.
8 changes: 6 additions & 2 deletions data_diff/sqeleton/abcs/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,13 @@ class AbstractMixin_Schema(ABC):
TODO: Move AbstractDatabase.query_table_schema() and friends over here
"""

def table_information(self) -> Compilable:
"Query to return a table of schema information about existing tables"
raise NotImplementedError()

@abstractmethod
def list_tables(self, like: Compilable = None) -> Compilable:
"""Query to select the list of tables in the schema.
def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
"""Query to select the list of tables in the schema. (query return type: table[str])
If 'like' is specified, the value is applied to the table name, using the 'like' operator.
"""
24 changes: 22 additions & 2 deletions data_diff/sqeleton/databases/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import decimal

from ..utils import is_uuid, safezip
from ..queries import Expr, Compiler, table, Select, SKIP, Explain, Code
from ..queries import Expr, Compiler, table, Select, SKIP, Explain, Code, this
from ..abcs.database_types import (
AbstractDatabase,
AbstractDialect,
Expand All @@ -30,6 +30,8 @@
DbPath,
Boolean,
)
from ..abcs.mixins import Compilable
from ..abcs.mixins import AbstractMixin_Schema

logger = logging.getLogger("database")

Expand Down Expand Up @@ -101,6 +103,22 @@ def apply_query(callback: Callable[[str], Any], sql_code: Union[str, ThreadLocal
return callback(sql_code)


class Mixin_Schema(AbstractMixin_Schema):
def table_information(self) -> Compilable:
return table("information_schema", "tables")

def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
return (
self.table_information()
.where(
this.table_schema == table_schema,
this.table_name.like(like) if like is not None else SKIP,
this.table_type == "BASE TABLE",
)
.select(this.table_name)
)


class BaseDialect(AbstractDialect):
SUPPORTS_PRIMARY_KEY = False
TYPE_CLASSES: Dict[str, type] = {}
Expand Down Expand Up @@ -354,7 +372,9 @@ def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType], whe
return

fields = [Code(self.dialect.normalize_uuid(self.dialect.quote(c), String_UUID())) for c in text_columns]
samples_by_row = self.query(table(*table_path).select(*fields).where(Code(where) if where else SKIP).limit(sample_size), list)
samples_by_row = self.query(
table(*table_path).select(*fields).where(Code(where) if where else SKIP).limit(sample_size), list
)
if not samples_by_row:
raise ValueError(f"Table {table_path} is empty.")

Expand Down
19 changes: 17 additions & 2 deletions data_diff/sqeleton/databases/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
TemporalType,
Boolean,
)
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue, AbstractMixin_Schema
from ..abcs import Compilable
from ..queries import this, table, SKIP
from .base import BaseDialect, Database, import_helper, parse_table_name, ConnectError, apply_query
from .base import TIMESTAMP_PRECISION_POS, ThreadLocalInterpreter

Expand Down Expand Up @@ -51,7 +53,20 @@ def normalize_boolean(self, value: str, coltype: Boolean) -> str:
return self.to_string(f"cast({value} as int)")


class Dialect(BaseDialect):
class Mixin_Schema(AbstractMixin_Schema):
def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
return (
table(table_schema, "INFORMATION_SCHEMA", "TABLES")
.where(
this.table_schema == table_schema,
this.table_name.like(like) if like is not None else SKIP,
this.table_type == "BASE TABLE",
)
.select(this.table_name)
)


class Dialect(BaseDialect, Mixin_Schema):
name = "BigQuery"
ROUNDS_ON_PREC_LOSS = False # Technically BigQuery doesn't allow implicit rounding or truncation
TYPE_CLASSES = {
Expand Down
4 changes: 2 additions & 2 deletions data_diff/sqeleton/databases/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
ThreadLocalInterpreter,
TIMESTAMP_PRECISION_POS,
)
from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS
from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, Mixin_Schema


@import_helper("duckdb")
Expand Down Expand Up @@ -54,7 +54,7 @@ def normalize_boolean(self, value: str, coltype: Boolean) -> str:
return self.to_string(f"{value}::INTEGER")


class Dialect(BaseDialect):
class Dialect(BaseDialect, Mixin_Schema):
name = "DuckDB"
ROUNDS_ON_PREC_LOSS = False
SUPPORTS_PRIMARY_KEY = True
Expand Down
4 changes: 2 additions & 2 deletions data_diff/sqeleton/databases/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
ConnectError,
BaseDialect,
)
from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, TIMESTAMP_PRECISION_POS
from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, TIMESTAMP_PRECISION_POS, Mixin_Schema


@import_helper("mysql")
Expand Down Expand Up @@ -47,7 +47,7 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
return f"TRIM(CAST({value} AS char))"


class Dialect(BaseDialect):
class Dialect(BaseDialect, Mixin_Schema):
name = "MySQL"
ROUNDS_ON_PREC_LOSS = True
SUPPORTS_PRIMARY_KEY = True
Expand Down
25 changes: 19 additions & 6 deletions data_diff/sqeleton/databases/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
TimestampTZ,
FractionalType,
)
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue
from .base import BaseDialect, ThreadedDatabase, import_helper, ConnectError, QueryError
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue, AbstractMixin_Schema
from ..abcs import Compilable
from ..queries import this, table, SKIP
from .base import BaseDialect, ThreadedDatabase, import_helper, ConnectError, QueryError, Mixin_Schema
from .base import TIMESTAMP_PRECISION_POS

SESSION_TIME_ZONE = None # Changed by the tests
Expand Down Expand Up @@ -57,8 +59,19 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str:
format_str += "0." + "9" * (coltype.precision - 1) + "0"
return f"to_char({value}, '{format_str}')"

class Mixin_Schema(AbstractMixin_Schema):
def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
return (
table('ALL_TABLES')
.where(
this.OWNER == table_schema,
this.TABLE_NAME.like(like) if like is not None else SKIP,
)
.select(table_name = this.TABLE_NAME)
)


class Dialect(BaseDialect):
class Dialect(BaseDialect, Mixin_Schema):
name = "Oracle"
SUPPORTS_PRIMARY_KEY = True
TYPE_CLASSES: Dict[str, type] = {
Expand All @@ -73,7 +86,7 @@ class Dialect(BaseDialect):
ROUNDS_ON_PREC_LOSS = True

def quote(self, s: str):
return f"{s}"
return f'"{s}"'

def to_string(self, s: str):
return f"cast({s} as varchar(1024))"
Expand Down Expand Up @@ -143,7 +156,7 @@ class Oracle(ThreadedDatabase):
def __init__(self, *, host, database, thread_count, **kw):
self.kwargs = dict(dsn=f"{host}/{database}" if database else host, **kw)

self.default_schema = kw.get("user")
self.default_schema = kw.get("user").upper()

super().__init__(thread_count=thread_count)

Expand All @@ -168,5 +181,5 @@ def select_table_schema(self, path: DbPath) -> str:

return (
f"SELECT column_name, data_type, 6 as datetime_precision, data_precision as numeric_precision, data_scale as numeric_scale"
f" FROM ALL_TAB_COLUMNS WHERE table_name = '{table.upper()}' AND owner = '{schema.upper()}'"
f" FROM ALL_TAB_COLUMNS WHERE table_name = '{table}' AND owner = '{schema}'"
)
9 changes: 2 additions & 7 deletions data_diff/sqeleton/databases/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,7 @@
Boolean,
)
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue
from .base import (
BaseDialect,
ThreadedDatabase,
import_helper,
ConnectError,
)
from .base import BaseDialect, ThreadedDatabase, import_helper, ConnectError, Mixin_Schema
from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, _CHECKSUM_BITSIZE, TIMESTAMP_PRECISION_POS

SESSION_TIME_ZONE = None # Changed by the tests
Expand Down Expand Up @@ -53,7 +48,7 @@ def normalize_boolean(self, value: str, coltype: Boolean) -> str:
return self.to_string(f"{value}::int")


class PostgresqlDialect(BaseDialect):
class PostgresqlDialect(BaseDialect, Mixin_Schema):
name = "PostgreSQL"
ROUNDS_ON_PREC_LOSS = True
SUPPORTS_PRIMARY_KEY = True
Expand Down
4 changes: 2 additions & 2 deletions data_diff/sqeleton/databases/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
Boolean,
)
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue
from .base import BaseDialect, Database, import_helper, ThreadLocalInterpreter
from .base import BaseDialect, Database, import_helper, ThreadLocalInterpreter, Mixin_Schema
from .base import (
MD5_HEXDIGITS,
CHECKSUM_HEXDIGITS,
Expand Down Expand Up @@ -69,7 +69,7 @@ def normalize_boolean(self, value: str, coltype: Boolean) -> str:
return self.to_string(f"cast ({value} as int)")


class Dialect(BaseDialect):
class Dialect(BaseDialect, Mixin_Schema):
name = "Presto"
ROUNDS_ON_PREC_LOSS = True
TYPE_CLASSES = {
Expand Down
25 changes: 23 additions & 2 deletions data_diff/sqeleton/databases/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
DbPath,
Boolean,
)
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue, AbstractMixin_Schema
from ..abcs import Compilable
from data_diff.sqeleton.queries import table, this, SKIP
from .base import BaseDialect, ConnectError, Database, import_helper, CHECKSUM_MASK, ThreadLocalInterpreter


Expand Down Expand Up @@ -46,7 +48,23 @@ def normalize_boolean(self, value: str, coltype: Boolean) -> str:
return self.to_string(f"{value}::int")


class Dialect(BaseDialect):
class Mixin_Schema(AbstractMixin_Schema):
def table_information(self) -> Compilable:
return table("INFORMATION_SCHEMA", "TABLES")

def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
return (
self.table_information()
.where(
this.TABLE_SCHEMA == table_schema,
this.TABLE_NAME.like(like) if like is not None else SKIP,
this.TABLE_TYPE == "BASE TABLE",
)
.select(table_name=this.TABLE_NAME)
)


class Dialect(BaseDialect, Mixin_Schema):
name = "Snowflake"
ROUNDS_ON_PREC_LOSS = False
TYPE_CLASSES = {
Expand All @@ -72,6 +90,9 @@ def quote(self, s: str):
def to_string(self, s: str):
return f"cast({s} as string)"

def table_information(self) -> Compilable:
return table("INFORMATION_SCHEMA", "TABLES")


class Snowflake(Database):
dialect = Dialect()
Expand Down
21 changes: 19 additions & 2 deletions data_diff/sqeleton/databases/vertica.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
Boolean,
ColType_UUID,
)
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue, AbstractMixin_Schema
from ..abcs import Compilable
from ..queries import table, this, SKIP


@import_helper("vertica")
Expand Down Expand Up @@ -60,7 +62,22 @@ def normalize_boolean(self, value: str, coltype: Boolean) -> str:
return self.to_string(f"cast ({value} as int)")


class Dialect(BaseDialect):
class Mixin_Schema(AbstractMixin_Schema):
def table_information(self) -> Compilable:
return table("v_catalog", "tables")

def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
return (
self.table_information()
.where(
this.table_schema == table_schema,
this.table_name.like(like) if like is not None else SKIP,
)
.select(this.table_name)
)


class Dialect(BaseDialect, Mixin_Schema):
name = "Vertica"
ROUNDS_ON_PREC_LOSS = True

Expand Down
40 changes: 40 additions & 0 deletions tests/sqeleton/test_database.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,33 @@
from typing import Callable, List
import unittest

from ..common import str_to_checksum, TEST_MYSQL_CONN_STRING
from ..common import str_to_checksum, test_each_database_in_list, TestPerDatabase, get_conn, random_table_suffix
# from data_diff.sqeleton import databases as db
# from data_diff.sqeleton import connect

from data_diff.sqeleton.queries import table

from data_diff import databases as dbs
from data_diff.databases import connect


TEST_DATABASES = {
dbs.MySQL,
dbs.PostgreSQL,
dbs.Oracle,
dbs.Redshift,
dbs.Snowflake,
dbs.DuckDB,
dbs.BigQuery,
dbs.Presto,
dbs.Trino,
dbs.Vertica,
}

test_each_database: Callable = test_each_database_in_list(TEST_DATABASES)


class TestDatabase(unittest.TestCase):
def setUp(self):
self.mysql = connect(TEST_MYSQL_CONN_STRING)
Expand All @@ -25,3 +49,19 @@ def test_bad_uris(self):
self.assertRaises(ValueError, connect, "postgresql:///bla/foo")
self.assertRaises(ValueError, connect, "snowflake://user:pass@bya42734/xdiffdev/TEST1")
self.assertRaises(ValueError, connect, "snowflake://user:pass@bya42734/xdiffdev/TEST1?warehouse=ha&schema=dup")


@test_each_database
class TestSchema(TestPerDatabase):
def test_table_list(self):
name = self.table_src_name
db = self.connection
tbl = table(db.parse_table_name(name), schema={'id': int})
q = db.dialect.list_tables(db.default_schema, name)
assert not db.query(q)

db.query(tbl.create())
self.assertEqual( db.query(q, List[str] ), [name])

db.query( tbl.drop() )
assert not db.query(q)
5 changes: 3 additions & 2 deletions tests/test_database_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,12 +602,13 @@ def _create_indexes(conn, table):

try:
if_not_exists = "IF NOT EXISTS" if not isinstance(conn, (db.MySQL, db.Oracle)) else ""
quote = conn.dialect.quote
conn.query(
f"CREATE INDEX {if_not_exists} xa_{table[1:-1]} ON {table} (id, col)",
f"CREATE INDEX {if_not_exists} xa_{table[1:-1]} ON {table} ({quote('id')}, {quote('col')})",
None,
)
conn.query(
f"CREATE INDEX {if_not_exists} xb_{table[1:-1]} ON {table} (id)",
f"CREATE INDEX {if_not_exists} xb_{table[1:-1]} ON {table} ({quote('id')})",
None,
)
except Exception as err:
Expand Down

0 comments on commit 5825483

Please sign in to comment.