diff --git a/.changes/unreleased/Under the Hood-20240102-152425.yaml b/.changes/unreleased/Under the Hood-20240102-152425.yaml
new file mode 100644
index 000000000..23a3eeb46
--- /dev/null
+++ b/.changes/unreleased/Under the Hood-20240102-152425.yaml
@@ -0,0 +1,6 @@
+kind: Under the Hood
+body: Update base adapter references as part of decoupling migration
+time: 2024-01-02T15:24:25.890421-08:00
+custom:
+  Author: colin-rogers-dbt VersusFacit
+  Issue: "698"
diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py
index 0c9d1b7ed..005f03feb 100644
--- a/dbt/adapters/redshift/connections.py
+++ b/dbt/adapters/redshift/connections.py
@@ -7,16 +7,17 @@
 import agate
 import sqlparse
 import redshift_connector
+from dbt.adapters.exceptions import FailedToConnectError
+from dbt.common.clients import agate_helper
 from redshift_connector.utils.oids import get_datatype_name

 from dbt.adapters.sql import SQLConnectionManager
-from dbt.contracts.connection import AdapterResponse, Connection, Credentials
-from dbt.contracts.util import Replaceable
-from dbt.dataclass_schema import dbtClassMixin, StrEnum, ValidationError
-from dbt.events import AdapterLogger
-from dbt.exceptions import DbtRuntimeError, CompilationError
-import dbt.flags
-from dbt.helper_types import Port
+from dbt.adapters.contracts.connection import AdapterResponse, Connection, Credentials
+from dbt.adapters.events.logging import AdapterLogger
+from dbt.common.contracts.util import Replaceable
+from dbt.common.dataclass_schema import dbtClassMixin, StrEnum, ValidationError
+from dbt.common.helper_types import Port
+from dbt.common.exceptions import DbtRuntimeError, CompilationError, DbtDatabaseError


 class SSLConfigError(CompilationError):
@@ -33,9 +34,6 @@ def get_message(self) -> str:
 logger = AdapterLogger("Redshift")


-drop_lock: Lock = dbt.flags.MP_CONTEXT.Lock()  # type: ignore
-
-
 class RedshiftConnectionMethod(StrEnum):
     DATABASE = "database"
     IAM = "iam"
@@ -185,7 +183,7 @@ def get_connect_method(self):
             # this requirement is really annoying to encode into json schema,
             # so validate it here
             if self.credentials.password is None:
-                raise dbt.exceptions.FailedToConnectError(
+                raise FailedToConnectError(
                     "'password' field is required for 'database' credentials"
                 )

@@ -204,7 +202,7 @@ def connect():

         elif method == RedshiftConnectionMethod.IAM:
             if not self.credentials.cluster_id and "serverless" not in self.credentials.host:
-                raise dbt.exceptions.FailedToConnectError(
+                raise FailedToConnectError(
                     "Failed to use IAM method. 'cluster_id' must be provided for provisioned cluster. "
                     "'host' must be provided for serverless endpoint."
                 )
@@ -227,9 +225,7 @@ def connect():
                 return c

         else:
-            raise dbt.exceptions.FailedToConnectError(
-                "Invalid 'method' in profile: '{}'".format(method)
-            )
+            raise FailedToConnectError("Invalid 'method' in profile: '{}'".format(method))

         return connect

@@ -278,16 +274,16 @@ def exception_handler(self, sql):
             err_msg = str(e).strip()
             logger.debug(f"Redshift error: {err_msg}")
             self.rollback_if_open()
-            raise dbt.exceptions.DbtDatabaseError(err_msg) from e
+            raise DbtDatabaseError(err_msg) from e

         except Exception as e:
             logger.debug("Error running SQL: {}", sql)
             logger.debug("Rolling back transaction.")
             self.rollback_if_open()
             # Raise DBT native exceptions as is.
-            if isinstance(e, dbt.exceptions.DbtRuntimeError):
+            if isinstance(e, DbtRuntimeError):
                 raise
-            raise dbt.exceptions.DbtRuntimeError(str(e)) from e
+            raise DbtRuntimeError(str(e)) from e

     @contextmanager
     def fresh_transaction(self):
@@ -297,6 +293,8 @@ def fresh_transaction(self):

         See drop_relation in RedshiftAdapter for more information.
         """
+        drop_lock: Lock = self.lock
+
         with drop_lock:
             connection = self.get_thread_connection()

@@ -349,7 +347,7 @@ def execute(
         if fetch:
            table = self.get_result_from_cursor(cursor, limit)
         else:
-            table = dbt.clients.agate_helper.empty_table()
+            table = agate_helper.empty_table()
         return response, table

     def add_query(self, sql, auto_begin=True, bindings=None, abridge_sql_log=False):
diff --git a/dbt/adapters/redshift/impl.py b/dbt/adapters/redshift/impl.py
index fbb30c784..d8d47aa40 100644
--- a/dbt/adapters/redshift/impl.py
+++ b/dbt/adapters/redshift/impl.py
@@ -1,17 +1,17 @@
 import os
 from dataclasses import dataclass
+from dbt.common.contracts.constraints import ConstraintType
 from typing import Optional, Set, Any, Dict, Type
 from collections import namedtuple

 from dbt.adapters.base import PythonJobHelper
 from dbt.adapters.base.impl import AdapterConfig, ConstraintSupport
 from dbt.adapters.base.meta import available
 from dbt.adapters.sql import SQLAdapter
-from dbt.contracts.connection import AdapterResponse
-from dbt.contracts.graph.nodes import ConstraintType
-from dbt.events import AdapterLogger
+from dbt.adapters.contracts.connection import AdapterResponse
+from dbt.adapters.events.logging import AdapterLogger

-import dbt.exceptions
+import dbt.adapters.exceptions

 from dbt.adapters.redshift import RedshiftConnectionManager, RedshiftRelation

@@ -96,7 +96,7 @@ def verify_database(self, database):
         ra3_node = self.config.credentials.ra3_node
         if database.lower() != expected.lower() and not ra3_node:
-            raise dbt.exceptions.NotImplementedError(
+            raise dbt.common.exceptions.NotImplementedError(
                 "Cross-db references allowed only in RA3.* node. ({} vs {})".format(
                     database, expected
                 )
             )
@@ -109,9 +109,9 @@ def _get_catalog_schemas(self, manifest):
         schemas = super(SQLAdapter, self)._get_catalog_schemas(manifest)
         try:
             return schemas.flatten(allow_multiple_databases=self.config.credentials.ra3_node)
-        except dbt.exceptions.DbtRuntimeError as exc:
+        except dbt.common.exceptions.DbtRuntimeError as exc:
             msg = f"Cross-db references allowed only in {self.type()} RA3.* node. Got {exc.msg}"
-            raise dbt.exceptions.CompilationError(msg)
+            raise dbt.common.exceptions.CompilationError(msg)

     def valid_incremental_strategies(self):
         """The set of standard builtin strategies which this adapter supports out-of-the-box.
diff --git a/dbt/adapters/redshift/relation.py b/dbt/adapters/redshift/relation.py
index ba2ad4a5a..daa3c8f45 100644
--- a/dbt/adapters/redshift/relation.py
+++ b/dbt/adapters/redshift/relation.py
@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+from dbt.adapters.contracts.relation import RelationConfig
 from typing import Optional

 from dbt.adapters.base.relation import BaseRelation
@@ -7,10 +8,8 @@
     RelationConfigChangeAction,
     RelationResults,
 )
-from dbt.context.providers import RuntimeConfigObject
-from dbt.contracts.graph.nodes import ModelNode
-from dbt.contracts.relation import RelationType
-from dbt.exceptions import DbtRuntimeError
+from dbt.adapters.base import RelationType
+from dbt.common.exceptions import DbtRuntimeError

 from dbt.adapters.redshift.relation_configs import (
     RedshiftMaterializedViewConfig,
@@ -60,31 +59,28 @@ def relation_max_name_length(self):
         return MAX_CHARACTERS_IN_IDENTIFIER

     @classmethod
-    def from_runtime_config(cls, runtime_config: RuntimeConfigObject) -> RelationConfigBase:
-        model_node: ModelNode = runtime_config.model
-        relation_type: str = model_node.config.materialized
+    def from_config(cls, config: RelationConfig) -> RelationConfigBase:
+        relation_type: str = config.config.materialized  # type: ignore

         if relation_config := cls.relation_configs.get(relation_type):
-            return relation_config.from_model_node(model_node)
+            return relation_config.from_relation_config(config)

         raise DbtRuntimeError(
-            f"from_runtime_config() is not supported for the provided relation type: {relation_type}"
+            f"from_config() is not supported for the provided relation type: {relation_type}"
         )

     @classmethod
     def materialized_view_config_changeset(
-        cls, relation_results: RelationResults, runtime_config: RuntimeConfigObject
+        cls, relation_results: RelationResults, relation_config: RelationConfig
     ) -> Optional[RedshiftMaterializedViewConfigChangeset]:
         config_change_collection = RedshiftMaterializedViewConfigChangeset()

         existing_materialized_view = RedshiftMaterializedViewConfig.from_relation_results(
             relation_results
         )
-        new_materialized_view = RedshiftMaterializedViewConfig.from_model_node(
-            runtime_config.model
+        new_materialized_view = RedshiftMaterializedViewConfig.from_relation_config(
+            relation_config
         )
-        assert isinstance(existing_materialized_view, RedshiftMaterializedViewConfig)
-        assert isinstance(new_materialized_view, RedshiftMaterializedViewConfig)

         if new_materialized_view.autorefresh != existing_materialized_view.autorefresh:
             config_change_collection.autorefresh = RedshiftAutoRefreshConfigChange(
diff --git a/dbt/adapters/redshift/relation_configs/base.py b/dbt/adapters/redshift/relation_configs/base.py
index ebbd46b1b..c4faab664 100644
--- a/dbt/adapters/redshift/relation_configs/base.py
+++ b/dbt/adapters/redshift/relation_configs/base.py
@@ -1,14 +1,14 @@
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Dict

 import agate
 from dbt.adapters.base.relation import Policy
+from dbt.adapters.contracts.relation import ComponentName, RelationConfig
 from dbt.adapters.relation_configs import (
     RelationConfigBase,
     RelationResults,
 )
-from dbt.contracts.graph.nodes import ModelNode
-from dbt.contracts.relation import ComponentName
+from typing_extensions import Self

 from dbt.adapters.redshift.relation_configs.policies import (
     RedshiftIncludePolicy,
@@ -31,25 +31,25 @@ def quote_policy(cls) -> Policy:
         return RedshiftQuotePolicy()

     @classmethod
-    def from_model_node(cls, model_node: ModelNode) -> "RelationConfigBase":
-        relation_config = cls.parse_model_node(model_node)
-        relation = cls.from_dict(relation_config)
-        return relation
+    def from_relation_config(cls, relation_config: RelationConfig) -> Self:
+        relation_config_dict = cls.parse_relation_config(relation_config)
+        relation = cls.from_dict(relation_config_dict)
+        return relation  # type: ignore

     @classmethod
-    def parse_model_node(cls, model_node: ModelNode) -> dict:
+    def parse_relation_config(cls, relation_config: RelationConfig) -> Dict:
         raise NotImplementedError(
-            "`parse_model_node()` needs to be implemented on this RelationConfigBase instance"
+            "`parse_relation_config()` needs to be implemented on this RelationConfigBase instance"
         )

     @classmethod
-    def from_relation_results(cls, relation_results: RelationResults) -> "RelationConfigBase":
+    def from_relation_results(cls, relation_results: RelationResults) -> Self:
         relation_config = cls.parse_relation_results(relation_results)
         relation = cls.from_dict(relation_config)
-        return relation
+        return relation  # type: ignore

     @classmethod
-    def parse_relation_results(cls, relation_results: RelationResults) -> dict:
+    def parse_relation_results(cls, relation_results: RelationResults) -> Dict:
         raise NotImplementedError(
             "`parse_relation_results()` needs to be implemented on this RelationConfigBase instance"
         )
diff --git a/dbt/adapters/redshift/relation_configs/dist.py b/dbt/adapters/redshift/relation_configs/dist.py
index 668f3f65a..58812ee57 100644
--- a/dbt/adapters/redshift/relation_configs/dist.py
+++ b/dbt/adapters/redshift/relation_configs/dist.py
@@ -1,5 +1,6 @@
 from dataclasses import dataclass
-from typing import Optional, Set
+from dbt.adapters.contracts.relation import RelationConfig
+from typing import Optional, Set, Dict

 import agate
 from dbt.adapters.relation_configs import (
@@ -8,9 +9,9 @@
     RelationConfigValidationMixin,
     RelationConfigValidationRule,
 )
-from dbt.contracts.graph.nodes import ModelNode
-from dbt.dataclass_schema import StrEnum
-from dbt.exceptions import DbtRuntimeError
+from dbt.common.dataclass_schema import StrEnum
+from dbt.common.exceptions import DbtRuntimeError
+from typing_extensions import Self

 from dbt.adapters.redshift.relation_configs.base import RedshiftRelationConfigBase

@@ -65,21 +66,21 @@ def validation_rules(self) -> Set[RelationConfigValidationRule]:
         }

     @classmethod
-    def from_dict(cls, config_dict) -> "RedshiftDistConfig":
+    def from_dict(cls, config_dict) -> Self:
         kwargs_dict = {
             "diststyle": config_dict.get("diststyle"),
             "distkey": config_dict.get("distkey"),
         }
-        dist: "RedshiftDistConfig" = super().from_dict(kwargs_dict)  # type: ignore
+        dist: Self = super().from_dict(kwargs_dict)  # type: ignore
         return dist

     @classmethod
-    def parse_model_node(cls, model_node: ModelNode) -> dict:
+    def parse_relation_config(cls, relation_config: RelationConfig) -> dict:
         """
         Translate ModelNode objects from the user-provided config into a standard dictionary.

         Args:
-            model_node: the description of the distkey and diststyle from the user in this format:
+            relation_config: the description of the distkey and diststyle from the user in this format:

                 {
                     "dist": any("auto", "even", "all") or "<column_name>"
@@ -87,7 +88,7 @@ def parse_model_node(cls, model_node: ModelNode) -> dict:

         Returns: a standard dictionary describing this `RedshiftDistConfig` instance
         """
-        dist = model_node.config.extra.get("dist", "")
+        dist = relation_config.config.extra.get("dist", "")  # type: ignore

         diststyle = dist.lower()

@@ -107,7 +108,7 @@ def parse_model_node(cls, model_node: ModelNode) -> dict:
         return config

     @classmethod
-    def parse_relation_results(cls, relation_results_entry: agate.Row) -> dict:
+    def parse_relation_results(cls, relation_results_entry: agate.Row) -> Dict:
         """
         Translate agate objects from the database into a standard dictionary.

diff --git a/dbt/adapters/redshift/relation_configs/materialized_view.py b/dbt/adapters/redshift/relation_configs/materialized_view.py
index e19b45547..81f7d2931 100644
--- a/dbt/adapters/redshift/relation_configs/materialized_view.py
+++ b/dbt/adapters/redshift/relation_configs/materialized_view.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import Optional, Set
+from typing import Optional, Set, Dict, Any

 import agate
 from dbt.adapters.relation_configs import (
@@ -8,9 +8,9 @@
     RelationConfigValidationMixin,
     RelationConfigValidationRule,
 )
-from dbt.contracts.graph.nodes import ModelNode
-from dbt.contracts.relation import ComponentName
-from dbt.exceptions import DbtRuntimeError
+from dbt.adapters.contracts.relation import ComponentName, RelationConfig
+from dbt.common.exceptions import DbtRuntimeError
+from typing_extensions import Self

 from dbt.adapters.redshift.relation_configs.base import RedshiftRelationConfigBase
 from dbt.adapters.redshift.relation_configs.dist import (
@@ -95,7 +95,7 @@ def validation_rules(self) -> Set[RelationConfigValidationRule]:
         }

     @classmethod
-    def from_dict(cls, config_dict) -> "RedshiftMaterializedViewConfig":
+    def from_dict(cls, config_dict) -> Self:
         kwargs_dict = {
             "mv_name": cls._render_part(ComponentName.Identifier, config_dict.get("mv_name")),
             "schema_name": cls._render_part(ComponentName.Schema, config_dict.get("schema_name")),
@@ -114,39 +114,39 @@ def from_dict(cls, config_dict) -> "RedshiftMaterializedViewConfig":
         if sort := config_dict.get("sort"):
             kwargs_dict.update({"sort": RedshiftSortConfig.from_dict(sort)})

-        materialized_view: "RedshiftMaterializedViewConfig" = super().from_dict(kwargs_dict)  # type: ignore
+        materialized_view: Self = super().from_dict(kwargs_dict)  # type: ignore
         return materialized_view

     @classmethod
-    def parse_model_node(cls, model_node: ModelNode) -> dict:
-        config_dict = {
-            "mv_name": model_node.identifier,
-            "schema_name": model_node.schema,
-            "database_name": model_node.database,
+    def parse_relation_config(cls, config: RelationConfig) -> Dict[str, Any]:
+        config_dict: Dict[str, Any] = {
+            "mv_name": config.identifier,
+            "schema_name": config.schema,
+            "database_name": config.database,
         }

         # backup/autorefresh can be bools or strings
-        backup_value = model_node.config.extra.get("backup")
+        backup_value = config.config.extra.get("backup")  # type: ignore
         if backup_value is not None:
             config_dict["backup"] = evaluate_bool(backup_value)

-        autorefresh_value = model_node.config.extra.get("auto_refresh")
+        autorefresh_value = config.config.extra.get("auto_refresh")  # type: ignore
         if autorefresh_value is not None:
             config_dict["autorefresh"] = evaluate_bool(autorefresh_value)

-        if query := model_node.compiled_code:
+        if query := config.compiled_code:  # type: ignore
             config_dict.update({"query": query.strip()})

-        if model_node.config.get("dist"):
-            config_dict.update({"dist": RedshiftDistConfig.parse_model_node(model_node)})
+        if config.config.get("dist"):  # type: ignore
+            config_dict.update({"dist": RedshiftDistConfig.parse_relation_config(config)})

-        if model_node.config.get("sort"):
-            config_dict.update({"sort": RedshiftSortConfig.parse_model_node(model_node)})
+        if config.config.get("sort"):  # type: ignore
+            config_dict.update({"sort": RedshiftSortConfig.parse_relation_config(config)})

         return config_dict

     @classmethod
-    def parse_relation_results(cls, relation_results: RelationResults) -> dict:
+    def parse_relation_results(cls, relation_results: RelationResults) -> Dict:
         """
         Translate agate objects from the database into a standard dictionary.

diff --git a/dbt/adapters/redshift/relation_configs/sort.py b/dbt/adapters/redshift/relation_configs/sort.py
index 58104b65f..c97d137bc 100644
--- a/dbt/adapters/redshift/relation_configs/sort.py
+++ b/dbt/adapters/redshift/relation_configs/sort.py
@@ -1,5 +1,6 @@
 from dataclasses import dataclass
-from typing import Optional, FrozenSet, Set
+from dbt.adapters.contracts.relation import RelationConfig
+from typing import Optional, FrozenSet, Set, Dict, Any

 import agate
 from dbt.adapters.relation_configs import (
@@ -8,9 +9,9 @@
     RelationConfigValidationMixin,
     RelationConfigValidationRule,
 )
-from dbt.contracts.graph.nodes import ModelNode
-from dbt.dataclass_schema import StrEnum
-from dbt.exceptions import DbtRuntimeError
+from dbt.common.dataclass_schema import StrEnum
+from dbt.common.exceptions import DbtRuntimeError
+from typing_extensions import Self

 from dbt.adapters.redshift.relation_configs.base import RedshiftRelationConfigBase

@@ -97,21 +98,21 @@ def validation_rules(self) -> Set[RelationConfigValidationRule]:
         }

     @classmethod
-    def from_dict(cls, config_dict) -> "RedshiftSortConfig":
+    def from_dict(cls, config_dict) -> Self:
         kwargs_dict = {
             "sortstyle": config_dict.get("sortstyle"),
             "sortkey": frozenset(column for column in config_dict.get("sortkey", {})),
         }
-        sort: "RedshiftSortConfig" = super().from_dict(kwargs_dict)  # type: ignore
-        return sort
+        sort: Self = super().from_dict(kwargs_dict)  # type: ignore
+        return sort  # type: ignore

     @classmethod
-    def parse_model_node(cls, model_node: ModelNode) -> dict:
+    def parse_relation_config(cls, relation_config: RelationConfig) -> Dict[str, Any]:
         """
         Translate ModelNode objects from the user-provided config into a standard dictionary.

         Args:
-            model_node: the description of the sortkey and sortstyle from the user in this format:
+            relation_config: the description of the sortkey and sortstyle from the user in this format:

                 {
                     "sort_key": "<column_name>" or ["<column_name>"] or ["<column_name>",...]
@@ -122,10 +123,10 @@ def parse_model_node(cls, model_node: ModelNode) -> dict:
         """
         config_dict = {}

-        if sortstyle := model_node.config.extra.get("sort_type"):
+        if sortstyle := relation_config.config.extra.get("sort_type"):  # type: ignore
             config_dict.update({"sortstyle": sortstyle.lower()})

-        if sortkey := model_node.config.extra.get("sort"):
+        if sortkey := relation_config.config.extra.get("sort"):  # type: ignore
             # we allow users to specify the `sort_key` as a string if it's a single column
             if isinstance(sortkey, str):
                 sortkey = [sortkey]
diff --git a/dbt/include/redshift/macros/materializations/materialized_view.sql b/dbt/include/redshift/macros/materializations/materialized_view.sql
index 5cdb26504..9b1ef2d50 100644
--- a/dbt/include/redshift/macros/materializations/materialized_view.sql
+++ b/dbt/include/redshift/macros/materializations/materialized_view.sql
@@ -1,5 +1,5 @@
 {% macro redshift__get_materialized_view_configuration_changes(existing_relation, new_config) %}
     {% set _existing_materialized_view = redshift__describe_materialized_view(existing_relation) %}
-    {% set _configuration_changes = existing_relation.materialized_view_config_changeset(_existing_materialized_view, new_config) %}
+    {% set _configuration_changes = existing_relation.materialized_view_config_changeset(_existing_materialized_view, new_config.model) %}
     {% do return(_configuration_changes) %}
 {% endmacro %}
diff --git a/dbt/include/redshift/macros/relations/materialized_view/create.sql b/dbt/include/redshift/macros/relations/materialized_view/create.sql
index b84680525..06fe2b6b5 100644
--- a/dbt/include/redshift/macros/relations/materialized_view/create.sql
+++ b/dbt/include/redshift/macros/relations/materialized_view/create.sql
@@ -1,6 +1,6 @@
 {% macro redshift__get_create_materialized_view_as_sql(relation, sql) %}

-    {%- set materialized_view = relation.from_runtime_config(config) -%}
+    {%- set materialized_view = relation.from_config(config.model) -%}

     create materialized view {{ materialized_view.path }}
         backup {% if materialized_view.backup %}yes{% else %}no{% endif %}
diff --git a/tests/functional/adapter/incremental/test_incremental_strategies.py b/tests/functional/adapter/incremental/test_incremental_strategies.py
index ed27be392..b8a4c4656 100644
--- a/tests/functional/adapter/incremental/test_incremental_strategies.py
+++ b/tests/functional/adapter/incremental/test_incremental_strategies.py
@@ -1,6 +1,6 @@
 import pytest

 from dbt.tests.util import run_dbt, get_manifest
-from dbt.exceptions import DbtRuntimeError
+from dbt.common.exceptions import DbtRuntimeError
 from dbt.context.providers import generate_runtime_model_context

diff --git a/tests/unit/relation_configs/test_materialized_view.py b/tests/unit/relation_configs/test_materialized_view.py
index 42a3223d0..5e454fe5e 100644
--- a/tests/unit/relation_configs/test_materialized_view.py
+++ b/tests/unit/relation_configs/test_materialized_view.py
@@ -17,7 +17,7 @@ def test_redshift_materialized_view_config_handles_all_valid_bools(bool_value):
     model_node.config.extra.get = (
         lambda x, y=None: bool_value if x in ["auto_refresh", "backup"] else "someDistValue"
     )
-    config_dict = config.parse_model_node(model_node)
+    config_dict = config.parse_relation_config(model_node)
     assert isinstance(config_dict["autorefresh"], bool)
     assert isinstance(config_dict["backup"], bool)

@@ -37,7 +37,7 @@ def test_redshift_materialized_view_config_throws_expected_exception_with_invali
         lambda x, y=None: bool_value if x in ["auto_refresh", "backup"] else "someDistValue"
     )
     with pytest.raises(TypeError):
-        config.parse_model_node(model_node)
+        config.parse_relation_config(model_node)


 def test_redshift_materialized_view_config_throws_expected_exception_with_invalid_str():
@@ -52,4 +52,4 @@ def test_redshift_materialized_view_config_throws_expected_exception_with_invali
         lambda x, y=None: "notABool" if x in ["auto_refresh", "backup"] else "someDistValue"
     )
     with pytest.raises(ValueError):
-        config.parse_model_node(model_node)
+        config.parse_relation_config(model_node)
diff --git a/tests/unit/test_context.py b/tests/unit/test_context.py
deleted file mode 100644
index 31c436d82..000000000
--- a/tests/unit/test_context.py
+++ /dev/null
@@ -1,231 +0,0 @@
-import os
-import pytest
-import unittest
-
-from unittest import mock
-
-from .utils import config_from_parts_or_dicts, inject_adapter, clear_plugin
-from .mock_adapter import adapter_factory
-import dbt.exceptions
-
-from dbt.adapters import (
-    redshift,
-    factory,
-)
-from dbt.contracts.graph.model_config import (
-    NodeConfig,
-)
-from dbt.contracts.graph.nodes import ModelNode, DependsOn, Macro
-from dbt.context import providers
-from dbt.node_types import NodeType
-
-
-class TestRuntimeWrapper(unittest.TestCase):
-    def setUp(self):
-        self.mock_config = mock.MagicMock()
-        self.mock_config.quoting = {"database": True, "schema": True, "identifier": True}
-        adapter_class = adapter_factory()
-        self.mock_adapter = adapter_class(self.mock_config)
-        self.namespace = mock.MagicMock()
-        self.wrapper = providers.RuntimeDatabaseWrapper(self.mock_adapter, self.namespace)
-        self.responder = self.mock_adapter.responder
-
-
-PROFILE_DATA = {
-    "target": "test",
-    "quoting": {},
-    "outputs": {
-        "test": {
-            "type": "redshift",
-            "host": "localhost",
-            "schema": "analytics",
-            "user": "test",
-            "pass": "test",
-            "dbname": "test",
-            "port": 1,
-        }
-    },
-}
-
-
-PROJECT_DATA = {
-    "name": "root",
-    "version": "0.1",
-    "profile": "test",
-    "project-root": os.getcwd(),
-    "config-version": 2,
-}
-
-
-def model():
-    return ModelNode(
-        alias="model_one",
-        name="model_one",
-        database="dbt",
-        schema="analytics",
-        resource_type=NodeType.Model,
-        unique_id="model.root.model_one",
-        fqn=["root", "model_one"],
-        package_name="root",
-        original_file_path="model_one.sql",
-        root_path="/usr/src/app",
-        refs=[],
-        sources=[],
-        depends_on=DependsOn(),
-        config=NodeConfig.from_dict(
-            {
-                "enabled": True,
-                "materialized": "view",
-                "persist_docs": {},
-                "post-hook": [],
-                "pre-hook": [],
-                "vars": {},
-                "quoting": {},
-                "column_types": {},
-                "tags": [],
-            }
-        ),
-        tags=[],
-        path="model_one.sql",
-        raw_sql="",
-        description="",
-        columns={},
-    )
-
-
-def mock_macro(name, package_name):
-    macro = mock.MagicMock(
-        __class__=Macro,
-        package_name=package_name,
-        resource_type="macro",
-        unique_id=f"macro.{package_name}.{name}",
-    )
-    # Mock(name=...) does not set the `name` attribute, this does.
-    macro.name = name
-    return macro
-
-
-def mock_manifest(config):
-    manifest_macros = {}
-    for name in ["macro_a", "macro_b"]:
-        macro = mock_macro(name, config.project_name)
-        manifest_macros[macro.unique_id] = macro
-    return mock.MagicMock(macros=manifest_macros)
-
-
-def mock_model():
-    return mock.MagicMock(
-        __class__=ModelNode,
-        alias="model_one",
-        name="model_one",
-        database="dbt",
-        schema="analytics",
-        resource_type=NodeType.Model,
-        unique_id="model.root.model_one",
-        fqn=["root", "model_one"],
-        package_name="root",
-        original_file_path="model_one.sql",
-        root_path="/usr/src/app",
-        refs=[],
-        sources=[],
-        depends_on=DependsOn(),
-        config=NodeConfig.from_dict(
-            {
-                "enabled": True,
-                "materialized": "view",
-                "persist_docs": {},
-                "post-hook": [],
-                "pre-hook": [],
-                "vars": {},
-                "quoting": {},
-                "column_types": {},
-                "tags": [],
-            }
-        ),
-        tags=[],
-        path="model_one.sql",
-        raw_sql="",
-        description="",
-        columns={},
-        defer_relation=None,
-    )
-
-
-@pytest.fixture
-def get_adapter():
-    with mock.patch.object(providers, "get_adapter") as patch:
-        yield patch
-
-
-@pytest.fixture
-def get_include_paths():
-    with mock.patch.object(factory, "get_include_paths") as patch:
-        patch.return_value = []
-        yield patch
-
-
-@pytest.fixture
-def config():
-    return config_from_parts_or_dicts(PROJECT_DATA, PROFILE_DATA)
-
-
-@pytest.fixture
-def manifest_fx(config):
-    return mock_manifest(config)
-
-
-@pytest.fixture
-def manifest_extended(manifest_fx):
-    dbt_macro = mock_macro("default__some_macro", "dbt")
-    # same namespace, same name, different pkg!
-    rs_macro = mock_macro("redshift__some_macro", "dbt_redshift")
-    # same name, different package
-    package_default_macro = mock_macro("default__some_macro", "root")
-    package_rs_macro = mock_macro("redshift__some_macro", "root")
-    manifest_fx.macros[dbt_macro.unique_id] = dbt_macro
-    manifest_fx.macros[rs_macro.unique_id] = rs_macro
-    manifest_fx.macros[package_default_macro.unique_id] = package_default_macro
-    manifest_fx.macros[package_rs_macro.unique_id] = package_rs_macro
-    return manifest_fx
-
-
-@pytest.fixture
-def redshift_adapter(config, get_adapter):
-    adapter = redshift.RedshiftAdapter(config)
-    inject_adapter(adapter, redshift.Plugin)
-    get_adapter.return_value = adapter
-    yield adapter
-    clear_plugin(redshift.Plugin)
-
-
-def test_resolve_specific(config, manifest_extended, redshift_adapter, get_include_paths):
-    rs_macro = manifest_extended.macros["macro.dbt_redshift.redshift__some_macro"]
-    package_rs_macro = manifest_extended.macros["macro.root.redshift__some_macro"]
-
-    ctx = providers.generate_runtime_model_context(
-        model=mock_model(),
-        config=config,
-        manifest=manifest_extended,
-    )
-
-    ctx["adapter"].config.dispatch
-
-    # macro_a exists, but default__macro_a and redshift__macro_a do not
-    with pytest.raises(dbt.exceptions.CompilationError):
-        ctx["adapter"].dispatch("macro_a").macro
-
-    # root namespace is always preferred, unless search order is explicitly defined in 'dispatch' config
-    assert ctx["adapter"].dispatch("some_macro").macro is package_rs_macro
-    assert ctx["adapter"].dispatch("some_macro", "dbt").macro is package_rs_macro
-    assert ctx["adapter"].dispatch("some_macro", "root").macro is package_rs_macro
-
-    # override 'dbt' namespace search order, dispatch to 'root' first
-    ctx["adapter"].config.dispatch = [{"macro_namespace": "dbt", "search_order": ["root", "dbt"]}]
-    assert ctx["adapter"].dispatch("some_macro", macro_namespace="dbt").macro is package_rs_macro
-
-    # override 'dbt' namespace search order, dispatch to 'dbt' only
-    ctx["adapter"].config.dispatch = [{"macro_namespace": "dbt", "search_order": ["dbt"]}]
-    assert ctx["adapter"].dispatch("some_macro", macro_namespace="dbt").macro is rs_macro
-
-    # override 'root' namespace search order, dispatch to 'dbt' first
-    ctx["adapter"].config.dispatch = [{"macro_namespace": "root", "search_order": ["dbt", "root"]}]
diff --git a/tests/unit/test_redshift_adapter.py b/tests/unit/test_redshift_adapter.py
index aeb9d6417..feb846892 100644
--- a/tests/unit/test_redshift_adapter.py
+++ b/tests/unit/test_redshift_adapter.py
@@ -1,4 +1,6 @@
 import unittest
+
+from multiprocessing import get_context
 from unittest import mock
 from unittest.mock import Mock, call

@@ -10,8 +12,8 @@
     RedshiftAdapter,
     Plugin as RedshiftPlugin,
 )
-from dbt.clients import agate_helper
-from dbt.exceptions import FailedToConnectError
+from dbt.common.clients import agate_helper
+from dbt.adapters.exceptions import FailedToConnectError
 from dbt.adapters.redshift.connections import RedshiftConnectMethodFactory, RedshiftSSLConfig
 from .utils import (
     config_from_parts_or_dicts,
@@ -59,7 +61,7 @@ def setUp(self):
     @property
     def adapter(self):
         if self._adapter is None:
-            self._adapter = RedshiftAdapter(self.config)
+            self._adapter = RedshiftAdapter(self.config, get_context("spawn"))
             inject_adapter(self._adapter, RedshiftPlugin)
         return self._adapter

@@ -235,7 +237,7 @@ def test_explicit_region_failure(self):
             region=None,
         )

-        with self.assertRaises(dbt.exceptions.FailedToConnectError):
+        with self.assertRaises(dbt.adapters.exceptions.FailedToConnectError):
            connection = self.adapter.acquire_connection("dummy")
            connection.handle
            redshift_connector.connect.assert_called_once_with(
@@ -264,7 +266,7 @@ def test_explicit_invalid_region(self):
            region=None,
        )

-        with self.assertRaises(dbt.exceptions.FailedToConnectError):
+        with self.assertRaises(dbt.adapters.exceptions.FailedToConnectError):
            connection = self.adapter.acquire_connection("dummy")
            connection.handle
            redshift_connector.connect.assert_called_once_with(
@@ -385,7 +387,7 @@ def test_serverless_iam_failure(self):
            iam_profile="test",
            host="doesnotexist.1233.us-east-2.redshift-srvrlss.amazonaws.com",
        )
-        with self.assertRaises(dbt.exceptions.FailedToConnectError) as context:
+        with self.assertRaises(dbt.adapters.exceptions.FailedToConnectError) as context:
            connection = self.adapter.acquire_connection("dummy")
            connection.handle
            redshift_connector.connect.assert_called_once_with(
@@ -507,12 +509,12 @@ def test_dbname_verification_is_case_insensitive(self):
         }
         self.config = config_from_parts_or_dicts(project_cfg, profile_cfg)
         self.adapter.cleanup_connections()
-        self._adapter = RedshiftAdapter(self.config)
+        self._adapter = RedshiftAdapter(self.config, get_context("spawn"))
         self.adapter.verify_database("redshift")

     def test_execute_with_fetch(self):
         cursor = mock.Mock()
-        table = dbt.clients.agate_helper.empty_table()
+        table = dbt.common.clients.agate_helper.empty_table()
         with mock.patch.object(self.adapter.connections, "add_query") as mock_add_query:
             mock_add_query.return_value = (
                 None,
@@ -552,7 +554,7 @@ def test_add_query_with_no_cursor(self):
         ) as mock_get_thread_connection:
             mock_get_thread_connection.return_value = None
             with self.assertRaisesRegex(
-                dbt.exceptions.DbtRuntimeError, "Tried to run invalid SQL: on "
+                dbt.common.exceptions.DbtRuntimeError, "Tried to run invalid SQL: on "
             ):
                 self.adapter.connections.add_query(sql="")
             mock_get_thread_connection.assert_called_once()
diff --git a/tests/unit/utils.py b/tests/unit/utils.py
index f2ca418e3..8b6b85501 100644
--- a/tests/unit/utils.py
+++ b/tests/unit/utils.py
@@ -9,7 +9,7 @@
 import agate
 import pytest

-from dbt.dataclass_schema import ValidationError
+from dbt.common.dataclass_schema import ValidationError

 from dbt.config.project import PartialProject

@@ -233,7 +233,7 @@ def assert_fails_validation(dct, cls):
 class TestAdapterConversions(TestCase):
     @staticmethod
     def _get_tester_for(column_type):
-        from dbt.clients import agate_helper
+        from dbt.common.clients import agate_helper

         if column_type is agate.TimeDelta:  # dbt never makes this!
             return agate.TimeDelta()