Skip to content
This repository has been archived by the owner on May 17, 2024. It is now read-only.

Commit

Permalink
Add Oracle, Redshift
Browse files Browse the repository at this point in the history
  • Loading branch information
erezsh committed Nov 23, 2022
1 parent 49e8a4f commit 3d66325
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 11 deletions.
25 changes: 19 additions & 6 deletions data_diff/sqeleton/databases/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
TimestampTZ,
FractionalType,
)
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue
from .base import BaseDialect, ThreadedDatabase, import_helper, ConnectError, QueryError
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue, AbstractMixin_Schema
from ..abcs import Compilable
from ..queries import this, table, SKIP
from .base import BaseDialect, ThreadedDatabase, import_helper, ConnectError, QueryError, Mixin_Schema
from .base import TIMESTAMP_PRECISION_POS

SESSION_TIME_ZONE = None # Changed by the tests
Expand Down Expand Up @@ -57,8 +59,19 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str:
format_str += "0." + "9" * (coltype.precision - 1) + "0"
return f"to_char({value}, '{format_str}')"

class Mixin_Schema(AbstractMixin_Schema):
def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
return (
table('ALL_TABLES')
.where(
this.OWNER == table_schema,
this.TABLE_NAME.like(like) if like is not None else SKIP,
)
.select(table_name = this.TABLE_NAME)
)


class Dialect(BaseDialect):
class Dialect(BaseDialect, Mixin_Schema):
name = "Oracle"
SUPPORTS_PRIMARY_KEY = True
TYPE_CLASSES: Dict[str, type] = {
Expand All @@ -73,7 +86,7 @@ class Dialect(BaseDialect):
ROUNDS_ON_PREC_LOSS = True

def quote(self, s: str):
return f"{s}"
return f'"{s}"'

def to_string(self, s: str):
return f"cast({s} as varchar(1024))"
Expand Down Expand Up @@ -143,7 +156,7 @@ class Oracle(ThreadedDatabase):
def __init__(self, *, host, database, thread_count, **kw):
self.kwargs = dict(dsn=f"{host}/{database}" if database else host, **kw)

self.default_schema = kw.get("user")
self.default_schema = kw.get("user").upper()

super().__init__(thread_count=thread_count)

Expand All @@ -168,5 +181,5 @@ def select_table_schema(self, path: DbPath) -> str:

return (
f"SELECT column_name, data_type, 6 as datetime_precision, data_precision as numeric_precision, data_scale as numeric_scale"
f" FROM ALL_TAB_COLUMNS WHERE table_name = '{table.upper()}' AND owner = '{schema.upper()}'"
f" FROM ALL_TAB_COLUMNS WHERE table_name = '{table}' AND owner = '{schema}'"
)
6 changes: 3 additions & 3 deletions tests/sqeleton/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
TEST_DATABASES = {
dbs.MySQL,
dbs.PostgreSQL,
# dbs.Oracle,
# dbs.Redshift,
dbs.Oracle,
dbs.Redshift,
dbs.Snowflake,
dbs.DuckDB,
dbs.BigQuery,
Expand Down Expand Up @@ -61,7 +61,7 @@ def test_table_list(self):
assert not db.query(q)

db.query(tbl.create())
assert db.query(q, List[str] ) == [name]
self.assertEqual( db.query(q, List[str] ), [name])

db.query( tbl.drop() )
assert not db.query(q)
5 changes: 3 additions & 2 deletions tests/test_database_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,12 +602,13 @@ def _create_indexes(conn, table):

try:
if_not_exists = "IF NOT EXISTS" if not isinstance(conn, (db.MySQL, db.Oracle)) else ""
quote = conn.dialect.quote
conn.query(
f"CREATE INDEX {if_not_exists} xa_{table[1:-1]} ON {table} (id, col)",
f"CREATE INDEX {if_not_exists} xa_{table[1:-1]} ON {table} ({quote('id')}, {quote('col')})",
None,
)
conn.query(
f"CREATE INDEX {if_not_exists} xb_{table[1:-1]} ON {table} (id)",
f"CREATE INDEX {if_not_exists} xb_{table[1:-1]} ON {table} ({quote('id')})",
None,
)
except Exception as err:
Expand Down

0 comments on commit 3d66325

Please sign in to comment.