diff --git a/.bumpversion.cfg b/.bumpversion.cfg
index 744284849..4de02c345 100644
--- a/.bumpversion.cfg
+++ b/.bumpversion.cfg
@@ -4,7 +4,7 @@ parse = (?P<major>\d+)
 	\.(?P<minor>\d+)
 	\.(?P<patch>\d+)
 	((?P<prerelease>a|b|rc)(?P<num>\d+))?
-serialize = 
+serialize =
 	{major}.{minor}.{patch}{prerelease}{num}
 	{major}.{minor}.{patch}
 commit = False
@@ -13,7 +13,7 @@ tag = False
 [bumpversion:part:prerelease]
 first_value = a
 optional_value = final
-values = 
+values =
 	a
 	b
 	rc
@@ -25,4 +25,3 @@ first_value = 1
 [bumpversion:file:setup.py]
 
 [bumpversion:file:dbt/adapters/spark/__version__.py]
-
diff --git a/.github/ISSUE_TEMPLATE/dependabot.yml b/.github/ISSUE_TEMPLATE/dependabot.yml
index 8a8c85b9f..2a6f34492 100644
--- a/.github/ISSUE_TEMPLATE/dependabot.yml
+++ b/.github/ISSUE_TEMPLATE/dependabot.yml
@@ -5,4 +5,4 @@ updates:
     directory: "/"
     schedule:
       interval: "daily"
-    rebase-strategy: "disabled"
\ No newline at end of file
+    rebase-strategy: "disabled"
diff --git a/.github/ISSUE_TEMPLATE/release.md b/.github/ISSUE_TEMPLATE/release.md
index ac28792a3..a69349f54 100644
--- a/.github/ISSUE_TEMPLATE/release.md
+++ b/.github/ISSUE_TEMPLATE/release.md
@@ -7,4 +7,4 @@ assignees: ''
 
 ---
 
-### TBD
\ No newline at end of file
+### TBD
diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
index 60e12779b..5928b1cbf 100644
--- a/.github/pull_request_template.md
+++ b/.github/pull_request_template.md
@@ -18,4 +18,4 @@ resolves #
 - [ ] I have signed the [CLA](https://docs.getdbt.com/docs/contributor-license-agreements)
 - [ ] I have run this code in development and it appears to resolve the stated issue
 - [ ] This PR includes tests, or tests are not required/relevant for this PR
-- [ ] I have updated the `CHANGELOG.md` and added information about my change to the "dbt-spark next" section.
\ No newline at end of file
+- [ ] I have updated the `CHANGELOG.md` and added information about my change to the "dbt-spark next" section.
diff --git a/.github/workflows/jira-creation.yml b/.github/workflows/jira-creation.yml index c84e106a7..b4016befc 100644 --- a/.github/workflows/jira-creation.yml +++ b/.github/workflows/jira-creation.yml @@ -13,7 +13,7 @@ name: Jira Issue Creation on: issues: types: [opened, labeled] - + permissions: issues: write diff --git a/.github/workflows/jira-label.yml b/.github/workflows/jira-label.yml index fd533a170..3da2e3a38 100644 --- a/.github/workflows/jira-label.yml +++ b/.github/workflows/jira-label.yml @@ -13,7 +13,7 @@ name: Jira Label Mirroring on: issues: types: [labeled, unlabeled] - + permissions: issues: read @@ -24,4 +24,3 @@ jobs: JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }} JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }} JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} - diff --git a/.github/workflows/jira-transition.yml b/.github/workflows/jira-transition.yml index 71273c7a9..ed9f9cd4f 100644 --- a/.github/workflows/jira-transition.yml +++ b/.github/workflows/jira-transition.yml @@ -21,4 +21,4 @@ jobs: secrets: JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }} JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }} - JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} \ No newline at end of file + JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index b25ea884e..3b9f7c858 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -3,23 +3,23 @@ name: Build and Release on: workflow_dispatch: - + # Release version number that must be updated for each release env: version_number: '0.20.0rc2' -jobs: +jobs: Test: runs-on: ubuntu-latest steps: - name: Setup Python uses: actions/setup-python@v2.2.2 - with: + with: python-version: '3.8' - + - uses: actions/checkout@v2 - - name: Test release + - name: Test release run: | python3 -m venv env source env/bin/activate @@ -38,9 +38,9 @@ jobs: steps: - name: Setup Python uses: actions/setup-python@v2.2.2 - with: + with: python-version: '3.8' - + - uses: actions/checkout@v2 - name: Bumping version @@ -60,7 +60,7 @@ jobs: author_email: 'leah.antkiewicz@dbtlabs.com' message: 'Bumping version to ${{env.version_number}}' tag: v${{env.version_number}} - + # Need to set an output variable because env variables can't be taken as input # This is needed for the next step with releasing to GitHub - name: Find release type @@ -69,7 +69,7 @@ jobs: IS_PRERELEASE: ${{ contains(env.version_number, 'rc') || contains(env.version_number, 'b') }} run: | echo ::set-output name=isPrerelease::$IS_PRERELEASE - + - name: Create GitHub release uses: actions/create-release@v1 env: @@ -88,7 +88,7 @@ jobs: # or $ pip install "dbt-spark[PyHive]==${{env.version_number}}" ``` - + PypiRelease: name: Pypi release runs-on: ubuntu-latest @@ -97,13 +97,13 @@ jobs: steps: - name: Setup Python uses: actions/setup-python@v2.2.2 - with: + with: python-version: '3.8' - + - uses: actions/checkout@v2 with: ref: v${{env.version_number}} - + - name: Release to pypi env: TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} @@ -116,4 +116,3 @@ jobs: pip install twine wheel setuptools python setup.py sdist bdist_wheel twine upload --non-interactive dist/dbt_spark-${{env.version_number}}-py3-none-any.whl dist/dbt-spark-${{env.version_number}}.tar.gz - diff --git a/.github/workflows/version-bump.yml b/.github/workflows/version-bump.yml index 7fb8bb6eb..15da2eee8 100644 --- a/.github/workflows/version-bump.yml +++ b/.github/workflows/version-bump.yml @@ -1,16 +1,16 @@ # **what?** # This workflow will take a version number and a dry 
run flag. With that -# it will run versionbump to update the version number everywhere in the +# it will run versionbump to update the version number everywhere in the # code base and then generate an update Docker requirements file. If this # is a dry run, a draft PR will open with the changes. If this isn't a dry # run, the changes will be committed to the branch this is run on. # **why?** -# This is to aid in releasing dbt and making sure we have updated +# This is to aid in releasing dbt and making sure we have updated # the versions and Docker requirements in all places. # **when?** -# This is triggered either manually OR +# This is triggered either manually OR # from the repository_dispatch event "version-bump" which is sent from # the dbt-release repo Action @@ -25,11 +25,11 @@ on: is_dry_run: description: 'Creates a draft PR to allow testing instead of committing to a branch' required: true - default: 'true' + default: 'true' repository_dispatch: types: [version-bump] -jobs: +jobs: bump: runs-on: ubuntu-latest steps: @@ -58,19 +58,19 @@ jobs: sudo apt-get install libsasl2-dev python3 -m venv env source env/bin/activate - pip install --upgrade pip - + pip install --upgrade pip + - name: Create PR branch if: ${{ steps.variables.outputs.IS_DRY_RUN == 'true' }} run: | git checkout -b bumping-version/${{steps.variables.outputs.VERSION_NUMBER}}_$GITHUB_RUN_ID git push origin bumping-version/${{steps.variables.outputs.VERSION_NUMBER}}_$GITHUB_RUN_ID git branch --set-upstream-to=origin/bumping-version/${{steps.variables.outputs.VERSION_NUMBER}}_$GITHUB_RUN_ID bumping-version/${{steps.variables.outputs.VERSION_NUMBER}}_$GITHUB_RUN_ID - + - name: Bumping version run: | source env/bin/activate - pip install -r dev_requirements.txt + pip install -r dev_requirements.txt env/bin/bumpversion --allow-dirty --new-version ${{steps.variables.outputs.VERSION_NUMBER}} major git status @@ -100,4 +100,4 @@ jobs: draft: true base: ${{github.ref}} title: 'Bumping version to ${{steps.variables.outputs.VERSION_NUMBER}}' - branch: 'bumping-version/${{steps.variables.outputs.VERSION_NUMBER}}_${{GITHUB.RUN_ID}}' + branch: 'bumping-version/${{steps.variables.outputs.VERSION_NUMBER}}_${{GITHUB.RUN_ID}}' diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ccaa093bf..e70156dcd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -42,7 +42,7 @@ repos: alias: flake8-check stages: [manual] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.782 + rev: v0.950 hooks: - id: mypy # N.B.: Mypy is... a bit fragile. 
diff --git a/MANIFEST.in b/MANIFEST.in index 78412d5b8..cfbc714ed 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1 @@ -recursive-include dbt/include *.sql *.yml *.md \ No newline at end of file +recursive-include dbt/include *.sql *.yml *.md diff --git a/dbt/adapters/spark/__init__.py b/dbt/adapters/spark/__init__.py index 469e202b9..6ecc5eccf 100644 --- a/dbt/adapters/spark/__init__.py +++ b/dbt/adapters/spark/__init__.py @@ -8,6 +8,5 @@ from dbt.include import spark Plugin = AdapterPlugin( - adapter=SparkAdapter, - credentials=SparkCredentials, - include_path=spark.PACKAGE_PATH) + adapter=SparkAdapter, credentials=SparkCredentials, include_path=spark.PACKAGE_PATH +) diff --git a/dbt/adapters/spark/column.py b/dbt/adapters/spark/column.py index fd377ad15..4df6b301b 100644 --- a/dbt/adapters/spark/column.py +++ b/dbt/adapters/spark/column.py @@ -1,11 +1,11 @@ from dataclasses import dataclass -from typing import TypeVar, Optional, Dict, Any +from typing import Any, Dict, Optional, TypeVar, Union from dbt.adapters.base.column import Column from dbt.dataclass_schema import dbtClassMixin from hologram import JsonDict -Self = TypeVar('Self', bound='SparkColumn') +Self = TypeVar("Self", bound="SparkColumn") @dataclass @@ -31,7 +31,7 @@ def literal(self, value): @property def quoted(self) -> str: - return '`{}`'.format(self.column) + return "`{}`".format(self.column) @property def data_type(self) -> str: @@ -42,26 +42,23 @@ def __repr__(self) -> str: @staticmethod def convert_table_stats(raw_stats: Optional[str]) -> Dict[str, Any]: - table_stats = {} + table_stats: Dict[str, Union[int, str, bool]] = {} if raw_stats: # format: 1109049927 bytes, 14093476 rows stats = { - stats.split(" ")[1]: int(stats.split(" ")[0]) - for stats in raw_stats.split(', ') + stats.split(" ")[1]: int(stats.split(" ")[0]) for stats in raw_stats.split(", ") } for key, val in stats.items(): - table_stats[f'stats:{key}:label'] = key - table_stats[f'stats:{key}:value'] = val - table_stats[f'stats:{key}:description'] = '' - table_stats[f'stats:{key}:include'] = True + table_stats[f"stats:{key}:label"] = key + table_stats[f"stats:{key}:value"] = val + table_stats[f"stats:{key}:description"] = "" + table_stats[f"stats:{key}:include"] = True return table_stats - def to_column_dict( - self, omit_none: bool = True, validate: bool = False - ) -> JsonDict: + def to_column_dict(self, omit_none: bool = True, validate: bool = False) -> JsonDict: original_dict = self.to_dict(omit_none=omit_none) # If there are stats, merge them into the root of the dict - original_stats = original_dict.pop('table_stats', None) + original_stats = original_dict.pop("table_stats", None) if original_stats: original_dict.update(original_stats) return original_dict diff --git a/dbt/adapters/spark/connections.py b/dbt/adapters/spark/connections.py index 11163ccf0..59ceb9dd8 100644 --- a/dbt/adapters/spark/connections.py +++ b/dbt/adapters/spark/connections.py @@ -26,6 +26,7 @@ from hologram.helpers import StrEnum from dataclasses import dataclass, field from typing import Any, Dict, Optional + try: from thrift.transport.TSSLSocket import TSSLSocket import thrift @@ -33,11 +34,7 @@ import sasl import thrift_sasl except ImportError: - TSSLSocket = None - thrift = None - ssl = None - sasl = None - thrift_sasl = None + pass # done deliberately: setting modules to None explicitly violates MyPy contracts by degrading type semantics import base64 import time @@ -52,10 +49,10 @@ def _build_odbc_connnection_string(**kwargs) -> str: class 
SparkConnectionMethod(StrEnum): - THRIFT = 'thrift' - HTTP = 'http' - ODBC = 'odbc' - SESSION = 'session' + THRIFT = "thrift" + HTTP = "http" + ODBC = "odbc" + SESSION = "session" @dataclass @@ -71,7 +68,7 @@ class SparkCredentials(Credentials): port: int = 443 auth: Optional[str] = None kerberos_service_name: Optional[str] = None - organization: str = '0' + organization: str = "0" connect_retries: int = 0 connect_timeout: int = 10 use_ssl: bool = False @@ -81,27 +78,24 @@ class SparkCredentials(Credentials): @classmethod def __pre_deserialize__(cls, data): data = super().__pre_deserialize__(data) - if 'database' not in data: - data['database'] = None + if "database" not in data: + data["database"] = None return data def __post_init__(self): # spark classifies database and schema as the same thing - if ( - self.database is not None and - self.database != self.schema - ): + if self.database is not None and self.database != self.schema: raise dbt.exceptions.RuntimeException( - f' schema: {self.schema} \n' - f' database: {self.database} \n' - f'On Spark, database must be omitted or have the same value as' - f' schema.' + f" schema: {self.schema} \n" + f" database: {self.database} \n" + f"On Spark, database must be omitted or have the same value as" + f" schema." ) self.database = None if self.method == SparkConnectionMethod.ODBC: try: - import pyodbc # noqa: F401 + import pyodbc # noqa: F401 except ImportError as e: raise dbt.exceptions.RuntimeException( f"{self.method} connection method requires " @@ -111,22 +105,16 @@ def __post_init__(self): f"ImportError({e.msg})" ) from e - if ( - self.method == SparkConnectionMethod.ODBC and - self.cluster and - self.endpoint - ): + if self.method == SparkConnectionMethod.ODBC and self.cluster and self.endpoint: raise dbt.exceptions.RuntimeException( "`cluster` and `endpoint` cannot both be set when" f" using {self.method} method to connect to Spark" ) if ( - self.method == SparkConnectionMethod.HTTP or - self.method == SparkConnectionMethod.THRIFT - ) and not ( - ThriftState and THttpClient and hive - ): + self.method == SparkConnectionMethod.HTTP + or self.method == SparkConnectionMethod.THRIFT + ) and not (ThriftState and THttpClient and hive): raise dbt.exceptions.RuntimeException( f"{self.method} connection method requires " "additional dependencies. 
\n" @@ -148,19 +136,19 @@ def __post_init__(self): @property def type(self): - return 'spark' + return "spark" @property def unique_field(self): return self.host def _connection_keys(self): - return ('host', 'port', 'cluster', - 'endpoint', 'schema', 'organization') + return ("host", "port", "cluster", "endpoint", "schema", "organization") class PyhiveConnectionWrapper(object): """Wrap a Spark connection in a way that no-ops transactions""" + # https://forums.databricks.com/questions/2157/in-apache-spark-sql-can-we-roll-back-the-transacti.html # noqa def __init__(self, handle): @@ -178,9 +166,7 @@ def cancel(self): try: self._cursor.cancel() except EnvironmentError as exc: - logger.debug( - "Exception while cancelling query: {}".format(exc) - ) + logger.debug("Exception while cancelling query: {}".format(exc)) def close(self): if self._cursor: @@ -189,9 +175,7 @@ def close(self): try: self._cursor.close() except EnvironmentError as exc: - logger.debug( - "Exception while closing cursor: {}".format(exc) - ) + logger.debug("Exception while closing cursor: {}".format(exc)) self.handle.close() def rollback(self, *args, **kwargs): @@ -247,23 +231,20 @@ def execute(self, sql, bindings=None): dbt.exceptions.raise_database_error(poll_state.errorMessage) elif state not in STATE_SUCCESS: - status_type = ThriftState._VALUES_TO_NAMES.get( - state, - 'Unknown<{!r}>'.format(state)) + status_type = ThriftState._VALUES_TO_NAMES.get(state, "Unknown<{!r}>".format(state)) - dbt.exceptions.raise_database_error( - "Query failed with status: {}".format(status_type)) + dbt.exceptions.raise_database_error("Query failed with status: {}".format(status_type)) logger.debug("Poll status: {}, query complete".format(state)) @classmethod def _fix_binding(cls, value): """Convert complex datatypes to primitives that can be loaded by - the Spark driver""" + the Spark driver""" if isinstance(value, NUMBERS): return float(value) elif isinstance(value, datetime): - return value.strftime('%Y-%m-%d %H:%M:%S.%f')[:-3] + return value.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] else: return value @@ -273,7 +254,6 @@ def description(self): class PyodbcConnectionWrapper(PyhiveConnectionWrapper): - def execute(self, sql, bindings=None): if sql.strip().endswith(";"): sql = sql.strip()[:-1] @@ -282,19 +262,17 @@ def execute(self, sql, bindings=None): self._cursor.execute(sql) else: # pyodbc only supports `qmark` sql params! 
- query = sqlparams.SQLParams('format', 'qmark') + query = sqlparams.SQLParams("format", "qmark") sql, bindings = query.format(sql, bindings) self._cursor.execute(sql, *bindings) class SparkConnectionManager(SQLConnectionManager): - TYPE = 'spark' + TYPE = "spark" SPARK_CLUSTER_HTTP_PATH = "/sql/protocolv1/o/{organization}/{cluster}" SPARK_SQL_ENDPOINT_HTTP_PATH = "/sql/1.0/endpoints/{endpoint}" - SPARK_CONNECTION_URL = ( - "{host}:{port}" + SPARK_CLUSTER_HTTP_PATH - ) + SPARK_CONNECTION_URL = "{host}:{port}" + SPARK_CLUSTER_HTTP_PATH @contextmanager def exception_handler(self, sql): @@ -308,7 +286,7 @@ def exception_handler(self, sql): raise thrift_resp = exc.args[0] - if hasattr(thrift_resp, 'status'): + if hasattr(thrift_resp, "status"): msg = thrift_resp.status.errorMessage raise dbt.exceptions.RuntimeException(msg) else: @@ -320,10 +298,8 @@ def cancel(self, connection): @classmethod def get_response(cls, cursor) -> AdapterResponse: # https://github.com/dbt-labs/dbt-spark/issues/142 - message = 'OK' - return AdapterResponse( - _message=message - ) + message = "OK" + return AdapterResponse(_message=message) # No transactions on Spark.... def add_begin_query(self, *args, **kwargs): @@ -346,12 +322,13 @@ def validate_creds(cls, creds, required): if not hasattr(creds, key): raise dbt.exceptions.DbtProfileError( "The config '{}' is required when using the {} method" - " to connect to Spark".format(key, method)) + " to connect to Spark".format(key, method) + ) @classmethod def open(cls, connection): if connection.state == ConnectionState.OPEN: - logger.debug('Connection is already open, skipping open.') + logger.debug("Connection is already open, skipping open.") return connection creds = connection.credentials @@ -360,19 +337,18 @@ def open(cls, connection): for i in range(1 + creds.connect_retries): try: if creds.method == SparkConnectionMethod.HTTP: - cls.validate_creds(creds, ['token', 'host', 'port', - 'cluster', 'organization']) + cls.validate_creds(creds, ["token", "host", "port", "cluster", "organization"]) # Prepend https:// if it is missing host = creds.host - if not host.startswith('https://'): - host = 'https://' + creds.host + if not host.startswith("https://"): + host = "https://" + creds.host conn_url = cls.SPARK_CONNECTION_URL.format( host=host, port=creds.port, organization=creds.organization, - cluster=creds.cluster + cluster=creds.cluster, ) logger.debug("connection url: {}".format(conn_url)) @@ -381,15 +357,12 @@ def open(cls, connection): raw_token = "token:{}".format(creds.token).encode() token = base64.standard_b64encode(raw_token).decode() - transport.setCustomHeaders({ - 'Authorization': 'Basic {}'.format(token) - }) + transport.setCustomHeaders({"Authorization": "Basic {}".format(token)}) conn = hive.connect(thrift_transport=transport) handle = PyhiveConnectionWrapper(conn) elif creds.method == SparkConnectionMethod.THRIFT: - cls.validate_creds(creds, - ['host', 'port', 'user', 'schema']) + cls.validate_creds(creds, ["host", "port", "user", "schema"]) if creds.use_ssl: transport = build_ssl_transport( @@ -397,26 +370,33 @@ def open(cls, connection): port=creds.port, username=creds.user, auth=creds.auth, - kerberos_service_name=creds.kerberos_service_name) + kerberos_service_name=creds.kerberos_service_name, + ) conn = hive.connect(thrift_transport=transport) else: - conn = hive.connect(host=creds.host, - port=creds.port, - username=creds.user, - auth=creds.auth, - kerberos_service_name=creds.kerberos_service_name) # noqa + conn = hive.connect( + host=creds.host, + 
port=creds.port, + username=creds.user, + auth=creds.auth, + kerberos_service_name=creds.kerberos_service_name, + ) # noqa handle = PyhiveConnectionWrapper(conn) elif creds.method == SparkConnectionMethod.ODBC: if creds.cluster is not None: - required_fields = ['driver', 'host', 'port', 'token', - 'organization', 'cluster'] + required_fields = [ + "driver", + "host", + "port", + "token", + "organization", + "cluster", + ] http_path = cls.SPARK_CLUSTER_HTTP_PATH.format( - organization=creds.organization, - cluster=creds.cluster + organization=creds.organization, cluster=creds.cluster ) elif creds.endpoint is not None: - required_fields = ['driver', 'host', 'port', 'token', - 'endpoint'] + required_fields = ["driver", "host", "port", "token", "endpoint"] http_path = cls.SPARK_SQL_ENDPOINT_HTTP_PATH.format( endpoint=creds.endpoint ) @@ -429,13 +409,12 @@ def open(cls, connection): cls.validate_creds(creds, required_fields) dbt_spark_version = __version__.version - user_agent_entry = f"dbt-labs-dbt-spark/{dbt_spark_version} (Databricks)" # noqa + user_agent_entry = ( + f"dbt-labs-dbt-spark/{dbt_spark_version} (Databricks)" # noqa + ) # http://simba.wpengine.com/products/Spark/doc/ODBC_InstallGuide/unix/content/odbc/hi/configuring/serverside.htm - ssp = { - f"SSP_{k}": f"{{{v}}}" - for k, v in creds.server_side_parameters.items() - } + ssp = {f"SSP_{k}": f"{{{v}}}" for k, v in creds.server_side_parameters.items()} # https://www.simba.com/products/Spark/doc/v2/ODBC_InstallGuide/unix/content/odbc/options/driver.htm connection_str = _build_odbc_connnection_string( @@ -461,6 +440,7 @@ def open(cls, connection): Connection, SessionConnectionWrapper, ) + handle = SessionConnectionWrapper(Connection()) else: raise dbt.exceptions.DbtProfileError( @@ -472,9 +452,9 @@ def open(cls, connection): if isinstance(e, EOFError): # The user almost certainly has invalid credentials. # Perhaps a token expired, or something - msg = 'Failed to connect' + msg = "Failed to connect" if creds.token is not None: - msg += ', is your token valid?' + msg += ", is your token valid?" 
raise dbt.exceptions.FailedToConnectException(msg) from e retryable_message = _is_retryable_error(e) if retryable_message and creds.connect_retries > 0: @@ -496,9 +476,7 @@ def open(cls, connection): logger.warning(msg) time.sleep(creds.connect_timeout) else: - raise dbt.exceptions.FailedToConnectException( - 'failed to connect' - ) from e + raise dbt.exceptions.FailedToConnectException("failed to connect") from e else: raise exc @@ -507,56 +485,50 @@ def open(cls, connection): return connection -def build_ssl_transport(host, port, username, auth, - kerberos_service_name, password=None): +def build_ssl_transport(host, port, username, auth, kerberos_service_name, password=None): transport = None if port is None: port = 10000 if auth is None: - auth = 'NONE' + auth = "NONE" socket = TSSLSocket(host, port, cert_reqs=ssl.CERT_NONE) - if auth == 'NOSASL': + if auth == "NOSASL": # NOSASL corresponds to hive.server2.authentication=NOSASL # in hive-site.xml transport = thrift.transport.TTransport.TBufferedTransport(socket) - elif auth in ('LDAP', 'KERBEROS', 'NONE', 'CUSTOM'): + elif auth in ("LDAP", "KERBEROS", "NONE", "CUSTOM"): # Defer import so package dependency is optional - if auth == 'KERBEROS': + if auth == "KERBEROS": # KERBEROS mode in hive.server2.authentication is GSSAPI # in sasl library - sasl_auth = 'GSSAPI' + sasl_auth = "GSSAPI" else: - sasl_auth = 'PLAIN' + sasl_auth = "PLAIN" if password is None: # Password doesn't matter in NONE mode, just needs # to be nonempty. - password = 'x' + password = "x" def sasl_factory(): sasl_client = sasl.Client() - sasl_client.setAttr('host', host) - if sasl_auth == 'GSSAPI': - sasl_client.setAttr('service', kerberos_service_name) - elif sasl_auth == 'PLAIN': - sasl_client.setAttr('username', username) - sasl_client.setAttr('password', password) + sasl_client.setAttr("host", host) + if sasl_auth == "GSSAPI": + sasl_client.setAttr("service", kerberos_service_name) + elif sasl_auth == "PLAIN": + sasl_client.setAttr("username", username) + sasl_client.setAttr("password", password) else: raise AssertionError sasl_client.init() return sasl_client - transport = thrift_sasl.TSaslClientTransport(sasl_factory, - sasl_auth, socket) + transport = thrift_sasl.TSaslClientTransport(sasl_factory, sasl_auth, socket) return transport -def _is_retryable_error(exc: Exception) -> Optional[str]: - message = getattr(exc, 'message', None) - if message is None: - return None - message = message.lower() - if 'pending' in message: - return exc.message - if 'temporarily_unavailable' in message: - return exc.message - return None +def _is_retryable_error(exc: Exception) -> str: + message = str(exc).lower() + if "pending" in message or "temporarily_unavailable" in message: + return str(exc) + else: + return "" diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py index eb001fbc9..dd090a23b 100644 --- a/dbt/adapters/spark/impl.py +++ b/dbt/adapters/spark/impl.py @@ -1,7 +1,9 @@ import re from concurrent.futures import Future from dataclasses import dataclass -from typing import Optional, List, Dict, Any, Union, Iterable +from typing import Any, Dict, Iterable, List, Optional, Union +from typing_extensions import TypeAlias + import agate from dbt.contracts.relation import RelationType @@ -21,19 +23,19 @@ logger = AdapterLogger("Spark") -GET_COLUMNS_IN_RELATION_MACRO_NAME = 'get_columns_in_relation' -LIST_SCHEMAS_MACRO_NAME = 'list_schemas' -LIST_RELATIONS_MACRO_NAME = 'list_relations_without_caching' -DROP_RELATION_MACRO_NAME = 'drop_relation' 
-FETCH_TBL_PROPERTIES_MACRO_NAME = 'fetch_tbl_properties' +GET_COLUMNS_IN_RELATION_MACRO_NAME = "get_columns_in_relation" +LIST_SCHEMAS_MACRO_NAME = "list_schemas" +LIST_RELATIONS_MACRO_NAME = "list_relations_without_caching" +DROP_RELATION_MACRO_NAME = "drop_relation" +FETCH_TBL_PROPERTIES_MACRO_NAME = "fetch_tbl_properties" -KEY_TABLE_OWNER = 'Owner' -KEY_TABLE_STATISTICS = 'Statistics' +KEY_TABLE_OWNER = "Owner" +KEY_TABLE_STATISTICS = "Statistics" @dataclass class SparkConfig(AdapterConfig): - file_format: str = 'parquet' + file_format: str = "parquet" location_root: Optional[str] = None partition_by: Optional[Union[List[str], str]] = None clustered_by: Optional[Union[List[str], str]] = None @@ -44,48 +46,44 @@ class SparkConfig(AdapterConfig): class SparkAdapter(SQLAdapter): COLUMN_NAMES = ( - 'table_database', - 'table_schema', - 'table_name', - 'table_type', - 'table_comment', - 'table_owner', - 'column_name', - 'column_index', - 'column_type', - 'column_comment', - - 'stats:bytes:label', - 'stats:bytes:value', - 'stats:bytes:description', - 'stats:bytes:include', - - 'stats:rows:label', - 'stats:rows:value', - 'stats:rows:description', - 'stats:rows:include', + "table_database", + "table_schema", + "table_name", + "table_type", + "table_comment", + "table_owner", + "column_name", + "column_index", + "column_type", + "column_comment", + "stats:bytes:label", + "stats:bytes:value", + "stats:bytes:description", + "stats:bytes:include", + "stats:rows:label", + "stats:rows:value", + "stats:rows:description", + "stats:rows:include", ) - INFORMATION_COLUMNS_REGEX = re.compile( - r"^ \|-- (.*): (.*) \(nullable = (.*)\b", re.MULTILINE) + INFORMATION_COLUMNS_REGEX = re.compile(r"^ \|-- (.*): (.*) \(nullable = (.*)\b", re.MULTILINE) INFORMATION_OWNER_REGEX = re.compile(r"^Owner: (.*)$", re.MULTILINE) - INFORMATION_STATISTICS_REGEX = re.compile( - r"^Statistics: (.*)$", re.MULTILINE) + INFORMATION_STATISTICS_REGEX = re.compile(r"^Statistics: (.*)$", re.MULTILINE) HUDI_METADATA_COLUMNS = [ - '_hoodie_commit_time', - '_hoodie_commit_seqno', - '_hoodie_record_key', - '_hoodie_partition_path', - '_hoodie_file_name' + "_hoodie_commit_time", + "_hoodie_commit_seqno", + "_hoodie_record_key", + "_hoodie_partition_path", + "_hoodie_file_name", ] - Relation = SparkRelation - Column = SparkColumn - ConnectionManager = SparkConnectionManager - AdapterSpecificConfigs = SparkConfig + Relation: TypeAlias = SparkRelation + Column: TypeAlias = SparkColumn + ConnectionManager: TypeAlias = SparkConnectionManager + AdapterSpecificConfigs: TypeAlias = SparkConfig @classmethod def date_function(cls) -> str: - return 'current_timestamp()' + return "current_timestamp()" @classmethod def convert_text_type(cls, agate_table, col_idx): @@ -109,31 +107,28 @@ def convert_datetime_type(cls, agate_table, col_idx): return "timestamp" def quote(self, identifier): - return '`{}`'.format(identifier) + return "`{}`".format(identifier) def add_schema_to_cache(self, schema) -> str: """Cache a new schema in dbt. 
It will show up in `list relations`.""" if schema is None: name = self.nice_connection_name() dbt.exceptions.raise_compiler_error( - 'Attempted to cache a null schema for {}'.format(name) + "Attempted to cache a null schema for {}".format(name) ) if dbt.flags.USE_CACHE: self.cache.add_schema(None, schema) # so jinja doesn't render things - return '' + return "" def list_relations_without_caching( self, schema_relation: SparkRelation ) -> List[SparkRelation]: - kwargs = {'schema_relation': schema_relation} + kwargs = {"schema_relation": schema_relation} try: - results = self.execute_macro( - LIST_RELATIONS_MACRO_NAME, - kwargs=kwargs - ) + results = self.execute_macro(LIST_RELATIONS_MACRO_NAME, kwargs=kwargs) except dbt.exceptions.RuntimeException as e: - errmsg = getattr(e, 'msg', '') + errmsg = getattr(e, "msg", "") if f"Database '{schema_relation}' not found" in errmsg: return [] else: @@ -146,13 +141,12 @@ def list_relations_without_caching( if len(row) != 4: raise dbt.exceptions.RuntimeException( f'Invalid value from "show table extended ...", ' - f'got {len(row)} values, expected 4' + f"got {len(row)} values, expected 4" ) _schema, name, _, information = row - rel_type = RelationType.View \ - if 'Type: VIEW' in information else RelationType.Table - is_delta = 'Provider: delta' in information - is_hudi = 'Provider: hudi' in information + rel_type = RelationType.View if "Type: VIEW" in information else RelationType.Table + is_delta = "Provider: delta" in information + is_hudi = "Provider: hudi" in information relation = self.Relation.create( schema=_schema, identifier=name, @@ -166,7 +160,7 @@ def list_relations_without_caching( return relations def get_relation( - self, database: str, schema: str, identifier: str + self, database: Optional[str], schema: str, identifier: str ) -> Optional[BaseRelation]: if not self.Relation.include_policy.database: database = None @@ -174,9 +168,7 @@ def get_relation( return super().get_relation(database, schema, identifier) def parse_describe_extended( - self, - relation: Relation, - raw_rows: List[agate.Row] + self, relation: Relation, raw_rows: List[agate.Row] ) -> List[SparkColumn]: # Convert the Row to a dict dict_rows = [dict(zip(row._keys, row._values)) for row in raw_rows] @@ -185,44 +177,45 @@ def parse_describe_extended( pos = self.find_table_information_separator(dict_rows) # Remove rows that start with a hash, they are comments - rows = [ - row for row in raw_rows[0:pos] - if not row['col_name'].startswith('#') - ] - metadata = { - col['col_name']: col['data_type'] for col in raw_rows[pos + 1:] - } + rows = [row for row in raw_rows[0:pos] if not row["col_name"].startswith("#")] + metadata = {col["col_name"]: col["data_type"] for col in raw_rows[pos + 1 :]} raw_table_stats = metadata.get(KEY_TABLE_STATISTICS) table_stats = SparkColumn.convert_table_stats(raw_table_stats) - return [SparkColumn( - table_database=None, - table_schema=relation.schema, - table_name=relation.name, - table_type=relation.type, - table_owner=str(metadata.get(KEY_TABLE_OWNER)), - table_stats=table_stats, - column=column['col_name'], - column_index=idx, - dtype=column['data_type'], - ) for idx, column in enumerate(rows)] + return [ + SparkColumn( + table_database=None, + table_schema=relation.schema, + table_name=relation.name, + table_type=relation.type, + table_owner=str(metadata.get(KEY_TABLE_OWNER)), + table_stats=table_stats, + column=column["col_name"], + column_index=idx, + dtype=column["data_type"], + ) + for idx, column in enumerate(rows) + ] @staticmethod def 
find_table_information_separator(rows: List[dict]) -> int: pos = 0 for row in rows: - if not row['col_name'] or row['col_name'].startswith('#'): + if not row["col_name"] or row["col_name"].startswith("#"): break pos += 1 return pos def get_columns_in_relation(self, relation: Relation) -> List[SparkColumn]: - cached_relations = self.cache.get_relations( - relation.database, relation.schema) - cached_relation = next((cached_relation - for cached_relation in cached_relations - if str(cached_relation) == str(relation)), - None) + cached_relations = self.cache.get_relations(relation.database, relation.schema) + cached_relation = next( + ( + cached_relation + for cached_relation in cached_relations + if str(cached_relation) == str(relation) + ), + None, + ) columns = [] if cached_relation and cached_relation.information: columns = self.parse_columns_from_information(cached_relation) @@ -238,30 +231,21 @@ def get_columns_in_relation(self, relation: Relation) -> List[SparkColumn]: # spark would throw error when table doesn't exist, where other # CDW would just return and empty list, normalizing the behavior here errmsg = getattr(e, "msg", "") - if ( - "Table or view not found" in errmsg or - "NoSuchTableException" in errmsg - ): + if "Table or view not found" in errmsg or "NoSuchTableException" in errmsg: pass else: raise e # strip hudi metadata columns. - columns = [x for x in columns - if x.name not in self.HUDI_METADATA_COLUMNS] + columns = [x for x in columns if x.name not in self.HUDI_METADATA_COLUMNS] return columns - def parse_columns_from_information( - self, relation: SparkRelation - ) -> List[SparkColumn]: - owner_match = re.findall( - self.INFORMATION_OWNER_REGEX, relation.information) + def parse_columns_from_information(self, relation: SparkRelation) -> List[SparkColumn]: + owner_match = re.findall(self.INFORMATION_OWNER_REGEX, relation.information) owner = owner_match[0] if owner_match else None - matches = re.finditer( - self.INFORMATION_COLUMNS_REGEX, relation.information) + matches = re.finditer(self.INFORMATION_COLUMNS_REGEX, relation.information) columns = [] - stats_match = re.findall( - self.INFORMATION_STATISTICS_REGEX, relation.information) + stats_match = re.findall(self.INFORMATION_STATISTICS_REGEX, relation.information) raw_table_stats = stats_match[0] if stats_match else None table_stats = SparkColumn.convert_table_stats(raw_table_stats) for match_num, match in enumerate(matches): @@ -275,28 +259,25 @@ def parse_columns_from_information( table_owner=owner, column=column_name, dtype=column_type, - table_stats=table_stats + table_stats=table_stats, ) columns.append(column) return columns - def _get_columns_for_catalog( - self, relation: SparkRelation - ) -> Iterable[Dict[str, Any]]: + def _get_columns_for_catalog(self, relation: SparkRelation) -> Iterable[Dict[str, Any]]: columns = self.parse_columns_from_information(relation) for column in columns: # convert SparkColumns into catalog dicts as_dict = column.to_column_dict() - as_dict['column_name'] = as_dict.pop('column', None) - as_dict['column_type'] = as_dict.pop('dtype') - as_dict['table_database'] = None + as_dict["column_name"] = as_dict.pop("column", None) + as_dict["column_type"] = as_dict.pop("dtype") + as_dict["table_database"] = None yield as_dict def get_properties(self, relation: Relation) -> Dict[str, str]: properties = self.execute_macro( - FETCH_TBL_PROPERTIES_MACRO_NAME, - kwargs={'relation': relation} + FETCH_TBL_PROPERTIES_MACRO_NAME, kwargs={"relation": relation} ) return dict(properties) @@ -304,28 
+285,30 @@ def get_catalog(self, manifest): schema_map = self._get_catalog_schemas(manifest) if len(schema_map) > 1: dbt.exceptions.raise_compiler_error( - f'Expected only one database in get_catalog, found ' - f'{list(schema_map)}' + f"Expected only one database in get_catalog, found " f"{list(schema_map)}" ) with executor(self.config) as tpe: futures: List[Future[agate.Table]] = [] for info, schemas in schema_map.items(): for schema in schemas: - futures.append(tpe.submit_connected( - self, schema, - self._get_one_catalog, info, [schema], manifest - )) + futures.append( + tpe.submit_connected( + self, schema, self._get_one_catalog, info, [schema], manifest + ) + ) catalogs, exceptions = catch_as_completed(futures) return catalogs, exceptions def _get_one_catalog( - self, information_schema, schemas, manifest, + self, + information_schema, + schemas, + manifest, ) -> agate.Table: if len(schemas) != 1: dbt.exceptions.raise_compiler_error( - f'Expected only one schema in spark _get_one_catalog, found ' - f'{schemas}' + f"Expected only one schema in spark _get_one_catalog, found " f"{schemas}" ) database = information_schema.database @@ -335,15 +318,10 @@ def _get_one_catalog( for relation in self.list_relations(database, schema): logger.debug("Getting table schema for relation {}", relation) columns.extend(self._get_columns_for_catalog(relation)) - return agate.Table.from_object( - columns, column_types=DEFAULT_TYPE_TESTER - ) + return agate.Table.from_object(columns, column_types=DEFAULT_TYPE_TESTER) def check_schema_exists(self, database, schema): - results = self.execute_macro( - LIST_SCHEMAS_MACRO_NAME, - kwargs={'database': database} - ) + results = self.execute_macro(LIST_SCHEMAS_MACRO_NAME, kwargs={"database": database}) exists = True if schema in [row[0] for row in results] else False return exists @@ -353,7 +331,7 @@ def get_rows_different_sql( relation_a: BaseRelation, relation_b: BaseRelation, column_names: Optional[List[str]] = None, - except_operator: str = 'EXCEPT', + except_operator: str = "EXCEPT", ) -> str: """Generate SQL for a query that returns a single row with a two columns: the number of rows that are different between the two @@ -366,7 +344,7 @@ def get_rows_different_sql( names = sorted((self.quote(c.name) for c in columns)) else: names = sorted((self.quote(n) for n in column_names)) - columns_csv = ', '.join(names) + columns_csv = ", ".join(names) sql = COLUMNS_EQUAL_SQL.format( columns=columns_csv, @@ -384,7 +362,7 @@ def run_sql_for_tests(self, sql, fetch, conn): try: cursor.execute(sql) if fetch == "one": - if hasattr(cursor, 'fetchone'): + if hasattr(cursor, "fetchone"): return cursor.fetchone() else: # AttributeError: 'PyhiveConnectionWrapper' object has no attribute 'fetchone' @@ -406,7 +384,7 @@ def run_sql_for_tests(self, sql, fetch, conn): # "trivial". Which is true, though it seems like an unreasonable cause for # failure! It also doesn't like the `from foo, bar` syntax as opposed to # `from foo cross join bar`. 
-COLUMNS_EQUAL_SQL = ''' +COLUMNS_EQUAL_SQL = """ with diff_count as ( SELECT 1 as id, @@ -433,4 +411,4 @@ def run_sql_for_tests(self, sql, fetch, conn): diff_count.num_missing as num_mismatched from row_count_diff cross join diff_count -'''.strip() +""".strip() diff --git a/dbt/adapters/spark/relation.py b/dbt/adapters/spark/relation.py index 043cabfa0..249caf0d7 100644 --- a/dbt/adapters/spark/relation.py +++ b/dbt/adapters/spark/relation.py @@ -24,19 +24,19 @@ class SparkIncludePolicy(Policy): class SparkRelation(BaseRelation): quote_policy: SparkQuotePolicy = SparkQuotePolicy() include_policy: SparkIncludePolicy = SparkIncludePolicy() - quote_character: str = '`' + quote_character: str = "`" is_delta: Optional[bool] = None is_hudi: Optional[bool] = None - information: str = None + information: Optional[str] = None def __post_init__(self): if self.database != self.schema and self.database: - raise RuntimeException('Cannot set database in spark!') + raise RuntimeException("Cannot set database in spark!") def render(self): if self.include_policy.database and self.include_policy.schema: raise RuntimeException( - 'Got a spark relation with schema and database set to ' - 'include, but only one can be set' + "Got a spark relation with schema and database set to " + "include, but only one can be set" ) return super().render() diff --git a/dbt/adapters/spark/session.py b/dbt/adapters/spark/session.py index 6010df920..beb77d548 100644 --- a/dbt/adapters/spark/session.py +++ b/dbt/adapters/spark/session.py @@ -4,7 +4,7 @@ import datetime as dt from types import TracebackType -from typing import Any +from typing import Any, List, Optional, Tuple from dbt.events import AdapterLogger from dbt.utils import DECIMALS @@ -25,17 +25,17 @@ class Cursor: """ def __init__(self) -> None: - self._df: DataFrame | None = None - self._rows: list[Row] | None = None + self._df: Optional[DataFrame] = None + self._rows: Optional[List[Row]] = None def __enter__(self) -> Cursor: return self def __exit__( self, - exc_type: type[BaseException] | None, - exc_val: Exception | None, - exc_tb: TracebackType | None, + exc_type: Optional[BaseException], + exc_val: Optional[Exception], + exc_tb: Optional[TracebackType], ) -> bool: self.close() return True @@ -43,13 +43,13 @@ def __exit__( @property def description( self, - ) -> list[tuple[str, str, None, None, None, None, bool]]: + ) -> List[Tuple[str, str, None, None, None, None, bool]]: """ Get the description. Returns ------- - out : list[tuple[str, str, None, None, None, None, bool]] + out : List[Tuple[str, str, None, None, None, None, bool]] The description. Source @@ -109,13 +109,13 @@ def execute(self, sql: str, *parameters: Any) -> None: spark_session = SparkSession.builder.enableHiveSupport().getOrCreate() self._df = spark_session.sql(sql) - def fetchall(self) -> list[Row] | None: + def fetchall(self) -> Optional[List[Row]]: """ Fetch all data. Returns ------- - out : list[Row] | None + out : Optional[List[Row]] The rows. Source @@ -126,7 +126,7 @@ def fetchall(self) -> list[Row] | None: self._rows = self._df.collect() return self._rows - def fetchone(self) -> Row | None: + def fetchone(self) -> Optional[Row]: """ Fetch the first output. 
diff --git a/dbt/include/spark/__init__.py b/dbt/include/spark/__init__.py index 564a3d1e8..b177e5d49 100644 --- a/dbt/include/spark/__init__.py +++ b/dbt/include/spark/__init__.py @@ -1,2 +1,3 @@ import os + PACKAGE_PATH = os.path.dirname(__file__) diff --git a/dbt/include/spark/macros/adapters.sql b/dbt/include/spark/macros/adapters.sql index e96501c45..22381d9ea 100644 --- a/dbt/include/spark/macros/adapters.sql +++ b/dbt/include/spark/macros/adapters.sql @@ -231,7 +231,7 @@ {% set comment = column_dict[column_name]['description'] %} {% set escaped_comment = comment | replace('\'', '\\\'') %} {% set comment_query %} - alter table {{ relation }} change column + alter table {{ relation }} change column {{ adapter.quote(column_name) if column_dict[column_name]['quote'] else column_name }} comment '{{ escaped_comment }}'; {% endset %} @@ -260,25 +260,25 @@ {% macro spark__alter_relation_add_remove_columns(relation, add_columns, remove_columns) %} - + {% if remove_columns %} {% set platform_name = 'Delta Lake' if relation.is_delta else 'Apache Spark' %} {{ exceptions.raise_compiler_error(platform_name + ' does not support dropping columns from tables') }} {% endif %} - + {% if add_columns is none %} {% set add_columns = [] %} {% endif %} - + {% set sql -%} - + alter {{ relation.type }} {{ relation }} - + {% if add_columns %} add columns {% endif %} {% for column in add_columns %} {{ column.name }} {{ column.data_type }}{{ ',' if not loop.last }} {% endfor %} - + {%- endset -%} {% do run_query(sql) %} diff --git a/dbt/include/spark/macros/materializations/incremental/incremental.sql b/dbt/include/spark/macros/materializations/incremental/incremental.sql index d0b6e89ba..8d8e69d93 100644 --- a/dbt/include/spark/macros/materializations/incremental/incremental.sql +++ b/dbt/include/spark/macros/materializations/incremental/incremental.sql @@ -1,17 +1,17 @@ {% materialization incremental, adapter='spark' -%} - + {#-- Validate early so we don't run SQL if the file_format + strategy combo is invalid --#} {%- set raw_file_format = config.get('file_format', default='parquet') -%} {%- set raw_strategy = config.get('incremental_strategy', default='append') -%} - + {%- set file_format = dbt_spark_validate_get_file_format(raw_file_format) -%} {%- set strategy = dbt_spark_validate_get_incremental_strategy(raw_strategy, file_format) -%} - + {%- set unique_key = config.get('unique_key', none) -%} {%- set partition_by = config.get('partition_by', none) -%} {%- set full_refresh_mode = (should_full_refresh()) -%} - + {% set on_schema_change = incremental_validate_on_schema_change(config.get('on_schema_change'), default='ignore') %} {% set target_relation = this %} @@ -42,7 +42,7 @@ {%- endcall -%} {% do persist_docs(target_relation, model) %} - + {{ run_hooks(post_hooks) }} {{ return({'relations': [target_relation]}) }} diff --git a/dbt/include/spark/macros/materializations/incremental/strategies.sql b/dbt/include/spark/macros/materializations/incremental/strategies.sql index 215b5f3f9..28b8f2001 100644 --- a/dbt/include/spark/macros/materializations/incremental/strategies.sql +++ b/dbt/include/spark/macros/materializations/incremental/strategies.sql @@ -1,5 +1,5 @@ {% macro get_insert_overwrite_sql(source_relation, target_relation) %} - + {%- set dest_columns = adapter.get_columns_in_relation(target_relation) -%} {%- set dest_cols_csv = dest_columns | map(attribute='quoted') | join(', ') -%} insert overwrite table {{ target_relation }} @@ -41,20 +41,20 @@ {% else %} {% do predicates.append('FALSE') %} {% 
endif %} - + {{ sql_header if sql_header is not none }} - + merge into {{ target }} as DBT_INTERNAL_DEST using {{ source.include(schema=false) }} as DBT_INTERNAL_SOURCE on {{ predicates | join(' and ') }} - + when matched then update set {% if update_columns -%}{%- for column_name in update_columns %} {{ column_name }} = DBT_INTERNAL_SOURCE.{{ column_name }} {%- if not loop.last %}, {%- endif %} {%- endfor %} {%- else %} * {% endif %} - + when not matched then insert * {% endmacro %} diff --git a/dbt/include/spark/macros/materializations/incremental/validate.sql b/dbt/include/spark/macros/materializations/incremental/validate.sql index 3e9de359b..ffd56f106 100644 --- a/dbt/include/spark/macros/materializations/incremental/validate.sql +++ b/dbt/include/spark/macros/materializations/incremental/validate.sql @@ -28,13 +28,13 @@ Invalid incremental strategy provided: {{ raw_strategy }} You can only choose this strategy when file_format is set to 'delta' or 'hudi' {%- endset %} - + {% set invalid_insert_overwrite_delta_msg -%} Invalid incremental strategy provided: {{ raw_strategy }} You cannot use this strategy when file_format is set to 'delta' Use the 'append' or 'merge' strategy instead {%- endset %} - + {% set invalid_insert_overwrite_endpoint_msg -%} Invalid incremental strategy provided: {{ raw_strategy }} You cannot use this strategy when connecting via endpoint diff --git a/dbt/include/spark/macros/materializations/snapshot.sql b/dbt/include/spark/macros/materializations/snapshot.sql index 82d186ce2..9c891ef04 100644 --- a/dbt/include/spark/macros/materializations/snapshot.sql +++ b/dbt/include/spark/macros/materializations/snapshot.sql @@ -32,7 +32,7 @@ {% macro spark_build_snapshot_staging_table(strategy, sql, target_relation) %} {% set tmp_identifier = target_relation.identifier ~ '__dbt_tmp' %} - + {%- set tmp_relation = api.Relation.create(identifier=tmp_identifier, schema=target_relation.schema, database=none, diff --git a/dbt/include/spark/macros/materializations/table.sql b/dbt/include/spark/macros/materializations/table.sql index 3ae2df973..2eeb806fd 100644 --- a/dbt/include/spark/macros/materializations/table.sql +++ b/dbt/include/spark/macros/materializations/table.sql @@ -21,7 +21,7 @@ {% call statement('main') -%} {{ create_table_as(False, target_relation, sql) }} {%- endcall %} - + {% do persist_docs(target_relation, model) %} {{ run_hooks(post_hooks) }} diff --git a/dev_requirements.txt b/dev_requirements.txt index 16a41588d..3bd5187a5 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -11,7 +11,7 @@ flaky freezegun==0.3.9 ipdb mock>=1.3.0 -mypy==0.782 +mypy==0.950 pre-commit pytest-csv pytest-dotenv diff --git a/scripts/build-dist.sh b/scripts/build-dist.sh index 65e6dbc97..3c3808399 100755 --- a/scripts/build-dist.sh +++ b/scripts/build-dist.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/bin/bash set -eo pipefail diff --git a/setup.py b/setup.py index 12ecbacde..836aeed43 100644 --- a/setup.py +++ b/setup.py @@ -5,41 +5,39 @@ # require python 3.7 or newer if sys.version_info < (3, 7): - print('Error: dbt does not support this version of Python.') - print('Please upgrade to Python 3.7 or higher.') + print("Error: dbt does not support this version of Python.") + print("Please upgrade to Python 3.7 or higher.") sys.exit(1) # require version of setuptools that supports find_namespace_packages from setuptools import setup + try: from setuptools import find_namespace_packages except ImportError: # the user has a downlevel version of setuptools. 
-    print('Error: dbt requires setuptools v40.1.0 or higher.')
-    print('Please upgrade setuptools with "pip install --upgrade setuptools" '
-          'and try again')
+    print("Error: dbt requires setuptools v40.1.0 or higher.")
+    print('Please upgrade setuptools with "pip install --upgrade setuptools" ' "and try again")
     sys.exit(1)
 
 
 # pull long description from README
 this_directory = os.path.abspath(os.path.dirname(__file__))
-with open(os.path.join(this_directory, 'README.md'), 'r', encoding='utf8') as f:
+with open(os.path.join(this_directory, "README.md"), "r", encoding="utf8") as f:
     long_description = f.read()
 
 
 # get this package's version from dbt/adapters/<plugin>/__version__.py
 def _get_plugin_version_dict():
-    _version_path = os.path.join(
-        this_directory, 'dbt', 'adapters', 'spark', '__version__.py'
-    )
-    _semver = r'''(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)'''
-    _pre = r'''((?P<prekind>a|b|rc)(?P<pre>\d+))?'''
-    _version_pattern = fr'''version\s*=\s*["']{_semver}{_pre}["']'''
+    _version_path = os.path.join(this_directory, "dbt", "adapters", "spark", "__version__.py")
+    _semver = r"""(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)"""
+    _pre = r"""((?P<prekind>a|b|rc)(?P<pre>\d+))?"""
+    _version_pattern = fr"""version\s*=\s*["']{_semver}{_pre}["']"""
     with open(_version_path) as f:
         match = re.search(_version_pattern, f.read().strip())
         if match is None:
-            raise ValueError(f'invalid version at {_version_path}')
+            raise ValueError(f"invalid version at {_version_path}")
         return match.groupdict()
 
 
@@ -47,7 +45,7 @@ def _get_plugin_version_dict():
 def _get_dbt_core_version():
     parts = _get_plugin_version_dict()
     minor = "{major}.{minor}.0".format(**parts)
-    pre = (parts["prekind"]+"1" if parts["prekind"] else "")
+    pre = parts["prekind"] + "1" if parts["prekind"] else ""
     return f"{minor}{pre}"
 
 
@@ -56,33 +54,28 @@ def _get_dbt_core_version():
 dbt_core_version = _get_dbt_core_version()
 description = """The Apache Spark adapter plugin for dbt"""
 
-odbc_extras = ['pyodbc>=4.0.30']
+odbc_extras = ["pyodbc>=4.0.30"]
 pyhive_extras = [
-    'PyHive[hive]>=0.6.0,<0.7.0',
-    'thrift>=0.11.0,<0.16.0',
-]
-session_extras = [
-    "pyspark>=3.0.0,<4.0.0"
+    "PyHive[hive]>=0.6.0,<0.7.0",
+    "thrift>=0.11.0,<0.16.0",
 ]
+session_extras = ["pyspark>=3.0.0,<4.0.0"]
 all_extras = odbc_extras + pyhive_extras + session_extras
 
 setup(
     name=package_name,
     version=package_version,
-
     description=description,
     long_description=long_description,
-    long_description_content_type='text/markdown',
-
-    author='dbt Labs',
-    author_email='info@dbtlabs.com',
-    url='https://github.com/dbt-labs/dbt-spark',
-
-    packages=find_namespace_packages(include=['dbt', 'dbt.*']),
+    long_description_content_type="text/markdown",
+    author="dbt Labs",
+    author_email="info@dbtlabs.com",
+    url="https://github.com/dbt-labs/dbt-spark",
+    packages=find_namespace_packages(include=["dbt", "dbt.*"]),
     include_package_data=True,
     install_requires=[
-        'dbt-core~={}'.format(dbt_core_version),
-        'sqlparams>=3.0.0',
+        "dbt-core~={}".format(dbt_core_version),
+        "sqlparams>=3.0.0",
     ],
     extras_require={
         "ODBC": odbc_extras,
@@ -92,17 +85,14 @@ def _get_dbt_core_version():
     },
     zip_safe=False,
     classifiers=[
-        'Development Status :: 5 - Production/Stable',
-        
-        'License :: OSI Approved :: Apache Software License',
-        
-        'Operating System :: Microsoft :: Windows',
-        'Operating System :: MacOS :: MacOS X',
-        'Operating System :: POSIX :: Linux',
-
-        'Programming Language :: Python :: 3.7',
-        'Programming Language :: Python :: 3.8',
-        'Programming Language :: Python :: 3.9',
+        "Development Status :: 5 - Production/Stable",
+        "License :: OSI Approved :: Apache Software License",
+        "Operating System :: Microsoft :: Windows",
+        "Operating System :: MacOS :: MacOS X",
+        "Operating System :: POSIX :: Linux",
+        "Programming Language :: Python :: 3.7",
+        "Programming Language :: Python :: 3.8",
+        "Programming Language :: Python :: 3.9",
     ],
     python_requires=">=3.7",
 )
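
For reference, a minimal standalone sketch (not part of the patch) of how the version-parsing helpers in setup.py above behave; the sample string `version = "1.2.0rc1"` and the printed pin are illustrative, not values taken from this release.

# Illustrative sketch of the regexes and _get_dbt_core_version() logic shown in the setup.py diff.
# The sample version string below is hypothetical; the real value is read from
# dbt/adapters/spark/__version__.py.
import re

_semver = r"""(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)"""
_pre = r"""((?P<prekind>a|b|rc)(?P<pre>\d+))?"""
_version_pattern = fr"""version\s*=\s*["']{_semver}{_pre}["']"""

sample = 'version = "1.2.0rc1"'
match = re.search(_version_pattern, sample)
assert match is not None
parts = match.groupdict()
# parts == {'major': '1', 'minor': '2', 'patch': '0', 'prekind': 'rc', 'pre': '1'}

# Mirrors _get_dbt_core_version(): pin dbt-core to the same minor series,
# using the first prerelease of that series when the plugin itself is a prerelease.
minor = "{major}.{minor}.0".format(**parts)
pre = parts["prekind"] + "1" if parts["prekind"] else ""
print(f"dbt-core~={minor}{pre}")  # -> dbt-core~=1.2.0rc1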