Skip to content

Commit

Permalink
Migrate references to base adapter (#689)
Browse files Browse the repository at this point in the history
* Migrate references to get nearly all unit tests working.

* normalize exceptions paths; remove always importing the mp context

* remove refs to `runtime_config` and delete test_context.py

* add changie

* fix from_config

* fix `from_config` call

* update dev-requirements.txt to point to dbt-core main

* remove mp_context dep and ignore relation_config issues

---------

Co-authored-by: Mila Page <versusfacit@users.noreply.github.com>
Co-authored-by: Colin <colin.rogers@dbtlabs.com>
  • Loading branch information
3 people authored Jan 10, 2024
1 parent 15dafa3 commit f95c534
Show file tree
Hide file tree
Showing 15 changed files with 113 additions and 340 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20240102-152425.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Update base adapter references as part of decoupling migration
time: 2024-01-02T15:24:25.890421-08:00
custom:
Author: colin-rogers-dbt VersusFacit
Issue: "698"
36 changes: 17 additions & 19 deletions dbt/adapters/redshift/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,17 @@
import agate
import sqlparse
import redshift_connector
from dbt.adapters.exceptions import FailedToConnectError
from dbt.common.clients import agate_helper
from redshift_connector.utils.oids import get_datatype_name

from dbt.adapters.sql import SQLConnectionManager
from dbt.contracts.connection import AdapterResponse, Connection, Credentials
from dbt.contracts.util import Replaceable
from dbt.dataclass_schema import dbtClassMixin, StrEnum, ValidationError
from dbt.events import AdapterLogger
from dbt.exceptions import DbtRuntimeError, CompilationError
import dbt.flags
from dbt.helper_types import Port
from dbt.adapters.contracts.connection import AdapterResponse, Connection, Credentials
from dbt.adapters.events.logging import AdapterLogger
from dbt.common.contracts.util import Replaceable
from dbt.common.dataclass_schema import dbtClassMixin, StrEnum, ValidationError
from dbt.common.helper_types import Port
from dbt.common.exceptions import DbtRuntimeError, CompilationError, DbtDatabaseError


class SSLConfigError(CompilationError):
Expand All @@ -33,9 +34,6 @@ def get_message(self) -> str:
logger = AdapterLogger("Redshift")


drop_lock: Lock = dbt.flags.MP_CONTEXT.Lock() # type: ignore


class RedshiftConnectionMethod(StrEnum):
DATABASE = "database"
IAM = "iam"
Expand Down Expand Up @@ -185,7 +183,7 @@ def get_connect_method(self):
# this requirement is really annoying to encode into json schema,
# so validate it here
if self.credentials.password is None:
raise dbt.exceptions.FailedToConnectError(
raise FailedToConnectError(
"'password' field is required for 'database' credentials"
)

Expand All @@ -204,7 +202,7 @@ def connect():

elif method == RedshiftConnectionMethod.IAM:
if not self.credentials.cluster_id and "serverless" not in self.credentials.host:
raise dbt.exceptions.FailedToConnectError(
raise FailedToConnectError(
"Failed to use IAM method. 'cluster_id' must be provided for provisioned cluster. "
"'host' must be provided for serverless endpoint."
)
Expand All @@ -227,9 +225,7 @@ def connect():
return c

else:
raise dbt.exceptions.FailedToConnectError(
"Invalid 'method' in profile: '{}'".format(method)
)
raise FailedToConnectError("Invalid 'method' in profile: '{}'".format(method))

return connect

Expand Down Expand Up @@ -278,16 +274,16 @@ def exception_handler(self, sql):
err_msg = str(e).strip()
logger.debug(f"Redshift error: {err_msg}")
self.rollback_if_open()
raise dbt.exceptions.DbtDatabaseError(err_msg) from e
raise DbtDatabaseError(err_msg) from e

except Exception as e:
logger.debug("Error running SQL: {}", sql)
logger.debug("Rolling back transaction.")
self.rollback_if_open()
# Raise DBT native exceptions as is.
if isinstance(e, dbt.exceptions.DbtRuntimeError):
if isinstance(e, DbtRuntimeError):
raise
raise dbt.exceptions.DbtRuntimeError(str(e)) from e
raise DbtRuntimeError(str(e)) from e

@contextmanager
def fresh_transaction(self):
Expand All @@ -297,6 +293,8 @@ def fresh_transaction(self):
See drop_relation in RedshiftAdapter for more information.
"""
drop_lock: Lock = self.lock

with drop_lock:
connection = self.get_thread_connection()

Expand Down Expand Up @@ -349,7 +347,7 @@ def execute(
if fetch:
table = self.get_result_from_cursor(cursor, limit)
else:
table = dbt.clients.agate_helper.empty_table()
table = agate_helper.empty_table()
return response, table

def add_query(self, sql, auto_begin=True, bindings=None, abridge_sql_log=False):
Expand Down
14 changes: 7 additions & 7 deletions dbt/adapters/redshift/impl.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import os
from dataclasses import dataclass
from dbt.common.contracts.constraints import ConstraintType
from typing import Optional, Set, Any, Dict, Type
from collections import namedtuple
from dbt.adapters.base import PythonJobHelper
from dbt.adapters.base.impl import AdapterConfig, ConstraintSupport
from dbt.adapters.base.meta import available
from dbt.adapters.sql import SQLAdapter
from dbt.contracts.connection import AdapterResponse
from dbt.contracts.graph.nodes import ConstraintType
from dbt.events import AdapterLogger
from dbt.adapters.contracts.connection import AdapterResponse
from dbt.adapters.events.logging import AdapterLogger


import dbt.exceptions
import dbt.adapters.exceptions

from dbt.adapters.redshift import RedshiftConnectionManager, RedshiftRelation

Expand Down Expand Up @@ -96,7 +96,7 @@ def verify_database(self, database):
ra3_node = self.config.credentials.ra3_node

if database.lower() != expected.lower() and not ra3_node:
raise dbt.exceptions.NotImplementedError(
raise dbt.common.exceptions.NotImplementedError(
"Cross-db references allowed only in RA3.* node. ({} vs {})".format(
database, expected
)
Expand All @@ -109,9 +109,9 @@ def _get_catalog_schemas(self, manifest):
schemas = super(SQLAdapter, self)._get_catalog_schemas(manifest)
try:
return schemas.flatten(allow_multiple_databases=self.config.credentials.ra3_node)
except dbt.exceptions.DbtRuntimeError as exc:
except dbt.common.exceptions.DbtRuntimeError as exc:
msg = f"Cross-db references allowed only in {self.type()} RA3.* node. Got {exc.msg}"
raise dbt.exceptions.CompilationError(msg)
raise dbt.common.exceptions.CompilationError(msg)

def valid_incremental_strategies(self):
"""The set of standard builtin strategies which this adapter supports out-of-the-box.
Expand Down
24 changes: 10 additions & 14 deletions dbt/adapters/redshift/relation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass
from dbt.adapters.contracts.relation import RelationConfig
from typing import Optional

from dbt.adapters.base.relation import BaseRelation
Expand All @@ -7,10 +8,8 @@
RelationConfigChangeAction,
RelationResults,
)
from dbt.context.providers import RuntimeConfigObject
from dbt.contracts.graph.nodes import ModelNode
from dbt.contracts.relation import RelationType
from dbt.exceptions import DbtRuntimeError
from dbt.adapters.base import RelationType
from dbt.common.exceptions import DbtRuntimeError

from dbt.adapters.redshift.relation_configs import (
RedshiftMaterializedViewConfig,
Expand Down Expand Up @@ -60,31 +59,28 @@ def relation_max_name_length(self):
return MAX_CHARACTERS_IN_IDENTIFIER

@classmethod
def from_runtime_config(cls, runtime_config: RuntimeConfigObject) -> RelationConfigBase:
model_node: ModelNode = runtime_config.model
relation_type: str = model_node.config.materialized
def from_config(cls, config: RelationConfig) -> RelationConfigBase:
relation_type: str = config.config.materialized # type: ignore

if relation_config := cls.relation_configs.get(relation_type):
return relation_config.from_model_node(model_node)
return relation_config.from_relation_config(config)

raise DbtRuntimeError(
f"from_runtime_config() is not supported for the provided relation type: {relation_type}"
f"from_config() is not supported for the provided relation type: {relation_type}"
)

@classmethod
def materialized_view_config_changeset(
cls, relation_results: RelationResults, runtime_config: RuntimeConfigObject
cls, relation_results: RelationResults, relation_config: RelationConfig
) -> Optional[RedshiftMaterializedViewConfigChangeset]:
config_change_collection = RedshiftMaterializedViewConfigChangeset()

existing_materialized_view = RedshiftMaterializedViewConfig.from_relation_results(
relation_results
)
new_materialized_view = RedshiftMaterializedViewConfig.from_model_node(
runtime_config.model
new_materialized_view = RedshiftMaterializedViewConfig.from_relation_config(
relation_config
)
assert isinstance(existing_materialized_view, RedshiftMaterializedViewConfig)
assert isinstance(new_materialized_view, RedshiftMaterializedViewConfig)

if new_materialized_view.autorefresh != existing_materialized_view.autorefresh:
config_change_collection.autorefresh = RedshiftAutoRefreshConfigChange(
Expand Down
24 changes: 12 additions & 12 deletions dbt/adapters/redshift/relation_configs/base.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Dict

import agate
from dbt.adapters.base.relation import Policy
from dbt.adapters.contracts.relation import ComponentName, RelationConfig
from dbt.adapters.relation_configs import (
RelationConfigBase,
RelationResults,
)
from dbt.contracts.graph.nodes import ModelNode
from dbt.contracts.relation import ComponentName
from typing_extensions import Self

from dbt.adapters.redshift.relation_configs.policies import (
RedshiftIncludePolicy,
Expand All @@ -31,25 +31,25 @@ def quote_policy(cls) -> Policy:
return RedshiftQuotePolicy()

@classmethod
def from_model_node(cls, model_node: ModelNode) -> "RelationConfigBase":
relation_config = cls.parse_model_node(model_node)
relation = cls.from_dict(relation_config)
return relation
def from_relation_config(cls, relation_config: RelationConfig) -> Self:
relation_config_dict = cls.parse_relation_config(relation_config)
relation = cls.from_dict(relation_config_dict)
return relation # type: ignore

@classmethod
def parse_model_node(cls, model_node: ModelNode) -> dict:
def parse_relation_config(cls, relation_config: RelationConfig) -> Dict:
raise NotImplementedError(
"`parse_model_node()` needs to be implemented on this RelationConfigBase instance"
"`parse_relation_config()` needs to be implemented on this RelationConfigBase instance"
)

@classmethod
def from_relation_results(cls, relation_results: RelationResults) -> "RelationConfigBase":
def from_relation_results(cls, relation_results: RelationResults) -> Self:
relation_config = cls.parse_relation_results(relation_results)
relation = cls.from_dict(relation_config)
return relation
return relation # type: ignore

@classmethod
def parse_relation_results(cls, relation_results: RelationResults) -> dict:
def parse_relation_results(cls, relation_results: RelationResults) -> Dict:
raise NotImplementedError(
"`parse_relation_results()` needs to be implemented on this RelationConfigBase instance"
)
Expand Down
21 changes: 11 additions & 10 deletions dbt/adapters/redshift/relation_configs/dist.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dataclasses import dataclass
from typing import Optional, Set
from dbt.adapters.contracts.relation import RelationConfig
from typing import Optional, Set, Dict

import agate
from dbt.adapters.relation_configs import (
Expand All @@ -8,9 +9,9 @@
RelationConfigValidationMixin,
RelationConfigValidationRule,
)
from dbt.contracts.graph.nodes import ModelNode
from dbt.dataclass_schema import StrEnum
from dbt.exceptions import DbtRuntimeError
from dbt.common.dataclass_schema import StrEnum
from dbt.common.exceptions import DbtRuntimeError
from typing_extensions import Self

from dbt.adapters.redshift.relation_configs.base import RedshiftRelationConfigBase

Expand Down Expand Up @@ -65,29 +66,29 @@ def validation_rules(self) -> Set[RelationConfigValidationRule]:
}

@classmethod
def from_dict(cls, config_dict) -> "RedshiftDistConfig":
def from_dict(cls, config_dict) -> Self:
kwargs_dict = {
"diststyle": config_dict.get("diststyle"),
"distkey": config_dict.get("distkey"),
}
dist: "RedshiftDistConfig" = super().from_dict(kwargs_dict) # type: ignore
dist: Self = super().from_dict(kwargs_dict) # type: ignore
return dist

@classmethod
def parse_model_node(cls, model_node: ModelNode) -> dict:
def parse_relation_config(cls, relation_config: RelationConfig) -> dict:
"""
Translate ModelNode objects from the user-provided config into a standard dictionary.
Args:
model_node: the description of the distkey and diststyle from the user in this format:
relation_config: the description of the distkey and diststyle from the user in this format:
{
"dist": any("auto", "even", "all") or "<column_name>"
}
Returns: a standard dictionary describing this `RedshiftDistConfig` instance
"""
dist = model_node.config.extra.get("dist", "")
dist = relation_config.config.extra.get("dist", "") # type: ignore

diststyle = dist.lower()

Expand All @@ -107,7 +108,7 @@ def parse_model_node(cls, model_node: ModelNode) -> dict:
return config

@classmethod
def parse_relation_results(cls, relation_results_entry: agate.Row) -> dict:
def parse_relation_results(cls, relation_results_entry: agate.Row) -> Dict:
"""
Translate agate objects from the database into a standard dictionary.
Expand Down
Loading

0 comments on commit f95c534

Please sign in to comment.