Skip to content

Commit

Permalink
Merge pull request #872 from DalgoT4D/handle-ssl-mode-warehouse
Browse files Browse the repository at this point in the history
Handle ssl mode warehouse
  • Loading branch information
Ishankoradia authored Oct 7, 2024
2 parents f3fb60f + 2b3780d commit 66241f6
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 5 deletions.
49 changes: 49 additions & 0 deletions ddpui/core/dbtfunctions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from typing import Union
from pathlib import Path
from ddpui.models.org import Org
from ddpui.ddpdbt.schema import DbtProjectParams
Expand All @@ -14,6 +15,7 @@ def gather_dbt_project_params(org: Org):
dbtrepodir = Path(os.getenv("CLIENTDBT_ROOT")) / org.slug / "dbtrepo"
project_dir = str(dbtrepodir)
target = org.dbt.default_schema
org_project_dir = Path(os.getenv("CLIENTDBT_ROOT")) / org.slug

return (
DbtProjectParams(
Expand All @@ -22,6 +24,53 @@ def gather_dbt_project_params(org: Org):
dbt_repo_dir=dbtrepodir,
target=target,
project_dir=project_dir,
org_project_dir=org_project_dir,
),
None,
)


def map_airbyte_destination_spec_to_dbtcli_profile(
conn_info: dict, dbt_project_params: Union[DbtProjectParams | None]
):
"""
Dbt doesn't support tunnel methods
So the translation to tunnel params is for our proxy service
To do a hack & run dbt using ssh tunnel
"""
if "tunnel_method" in conn_info:
method = conn_info["tunnel_method"]

if method["tunnel_method"] in ["SSH_KEY_AUTH", "SSH_PASSWORD_AUTH"]:
conn_info["ssh_host"] = method["tunnel_host"]
conn_info["ssh_port"] = method["tunnel_port"]
conn_info["ssh_username"] = method["tunnel_user"]

if method["tunnel_method"] == "SSH_KEY_AUTH":
conn_info["ssh_pkey"] = method["ssh_key"]
conn_info["ssh_private_key_password"] = method.get("tunnel_private_key_password")

elif method["tunnel_method"] == "SSH_PASSWORD_AUTH":
conn_info["ssh_password"] = method.get("tunnel_user_password")

if "username" in conn_info:
conn_info["user"] = conn_info["username"]

# handle dbt ssl params
if "ssl_mode" in conn_info:
ssl_data = conn_info["ssl_mode"]
mode = ssl_data["mode"] if "mode" in ssl_data else None
ca_certificate = ssl_data["ca_certificate"] if "ca_certificate" in ssl_data else None
# client_key_password = (
# ssl_data["client_key_password"] if "client_key_password" in ssl_data else None
# )
if mode:
conn_info["sslmode"] = mode

if ca_certificate and dbt_project_params.org_project_dir:
file_path = os.path.join(dbt_project_params.org_project_dir, "sslrootcert.pem")
with open(file_path, "w") as file:
file.write(ca_certificate)
conn_info["sslrootcert"] = file_path

return conn_info
9 changes: 6 additions & 3 deletions ddpui/ddpairbyte/airbytehelpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,13 @@
)
from ddpui.ddpprefect import AIRBYTESERVER
from ddpui.ddpprefect import DBTCLIPROFILE
from ddpui.ddpdbt.schema import DbtProjectParams
from ddpui.core.dbtfunctions import map_airbyte_destination_spec_to_dbtcli_profile
from ddpui.models.org import OrgDataFlowv1, OrgWarehouse
from ddpui.models.tasks import Task, OrgTask, DataflowOrgTask
from ddpui.utils.constants import TASK_AIRBYTESYNC, TASK_AIRBYTERESET
from ddpui.utils.helpers import (
generate_hash_id,
update_dict_but_not_stars,
map_airbyte_keys_to_postgres_keys,
)
from ddpui.utils import secretsmanager
from ddpui.assets.whitelist import DEMO_WHITELIST_SOURCES
Expand Down Expand Up @@ -793,7 +792,6 @@ def create_or_update_org_cli_block(org: Org, warehouse: OrgWarehouse, airbyte_cr
"""
Create/update the block in db and also in prefect
"""
dbt_creds = map_airbyte_keys_to_postgres_keys(airbyte_creds)

bqlocation = None
if warehouse.wtype == "bigquery":
Expand Down Expand Up @@ -824,6 +822,11 @@ def create_or_update_org_cli_block(org: Org, warehouse: OrgWarehouse, airbyte_cr
)
logger.error(err)

dbt_creds = map_airbyte_destination_spec_to_dbtcli_profile(airbyte_creds, dbt_project_params)

dbt_creds.pop("ssl_mode", None)
dbt_creds.pop("ssl", None)

# set defaults to target and profile
# cant create a cli profile without these two
# idea is these should be updated when we setup transformation or update the warehouse
Expand Down
1 change: 1 addition & 0 deletions ddpui/ddpdbt/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ class DbtProjectParams(Schema):
project_dir: Union[str, Path]
target: str
dbt_repo_dir: Union[str, Path]
org_project_dir: Union[str, Path]
78 changes: 78 additions & 0 deletions ddpui/tests/core/test_dbtfunctions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from ddpui.ddpdbt.schema import DbtProjectParams
from ddpui.core.dbtfunctions import map_airbyte_destination_spec_to_dbtcli_profile


def test_map_airbyte_destination_spec_to_dbtcli_profile_success_tunnel_params(tmpdir):
"""Tests all the success cases"""
dbt_project_params = DbtProjectParams(
org_project_dir=tmpdir,
dbt_env_dir="/path/to/dbt_venv",
dbt_repo_dir="/path/to/dbt_repo",
project_dir="/path/to/project_dir",
target="target",
dbt_binary="dbt_binary",
)

conn_info = {"some": "random value"}
res = map_airbyte_destination_spec_to_dbtcli_profile(conn_info, dbt_project_params)
assert res == conn_info

# SSH_KEY_AUTH
conn_info = {
"tunnel_method": {
"tunnel_method": "SSH_KEY_AUTH",
"tunnel_host": "tunnel_host",
"tunnel_port": 22,
"tunnel_user": "tunnel_user",
"ssh_key": "ssh_key",
"tunnel_private_key_password": "tunnel_private_key_password",
}
}
res = map_airbyte_destination_spec_to_dbtcli_profile(conn_info, dbt_project_params)
assert res["ssh_host"] == conn_info["tunnel_method"]["tunnel_host"]
assert res["ssh_port"] == conn_info["tunnel_method"]["tunnel_port"]
assert res["ssh_username"] == conn_info["tunnel_method"]["tunnel_user"]
assert res["ssh_pkey"] == conn_info["tunnel_method"]["ssh_key"]
assert (
res["ssh_private_key_password"] == conn_info["tunnel_method"]["tunnel_private_key_password"]
)

# SSH_PASSWORD_AUTH
conn_info = {
"tunnel_method": {
"tunnel_method": "SSH_PASSWORD_AUTH",
"tunnel_host": "tunnel_host",
"tunnel_port": 22,
"tunnel_user": "tunnel_user",
"tunnel_user_password": "tunnel_user_password",
}
}
res = map_airbyte_destination_spec_to_dbtcli_profile(conn_info, dbt_project_params)
assert res["ssh_host"] == conn_info["tunnel_method"]["tunnel_host"]
assert res["ssh_port"] == conn_info["tunnel_method"]["tunnel_port"]
assert res["ssh_username"] == conn_info["tunnel_method"]["tunnel_user"]
assert res["ssh_password"] == conn_info["tunnel_method"]["tunnel_user_password"]

# make sure the username is mapped to user
conn_info = {"username": "username"}
res = map_airbyte_destination_spec_to_dbtcli_profile(conn_info, dbt_project_params)
assert res["user"] == conn_info["username"]


def test_map_airbyte_destination_spec_to_dbtcli_profile_success_ssl_params(tmpdir):
"""Tests all the success cases"""
dbt_project_params = DbtProjectParams(
org_project_dir=tmpdir,
dbt_env_dir="/path/to/dbt_venv",
dbt_repo_dir="/path/to/dbt_repo",
project_dir="/path/to/project_dir",
target="target",
dbt_binary="dbt_binary",
)

conn_info = {"ssl_mode": {"mode": "verify-ca", "ca_certificate": "ca_certificate"}}
res = map_airbyte_destination_spec_to_dbtcli_profile(conn_info, dbt_project_params)
assert res["sslmode"] == conn_info["ssl_mode"]["mode"]
assert res["sslrootcert"] == f"{tmpdir}/sslrootcert.pem"
with open(f"{tmpdir}/sslrootcert.pem") as file:
assert file.read() == conn_info["ssl_mode"]["ca_certificate"]
8 changes: 6 additions & 2 deletions ddpui/tests/services/test_dbt_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@
from ninja.errors import HttpError

from ddpui.models.org import Org, OrgDbt, OrgPrefectBlockv1, OrgWarehouse
from ddpui.ddpdbt.dbt_service import delete_dbt_workspace, setup_local_dbt_workspace
from ddpui.ddpprefect import DBTCORE, SHELLOPERATION, DBTCLIPROFILE, SECRET
from ddpui.ddpdbt.dbt_service import (
delete_dbt_workspace,
setup_local_dbt_workspace,
)
from ddpui.ddpprefect import DBTCLIPROFILE, SECRET
from ddpui.ddpdbt.schema import DbtProjectParams
from dbt_automation import assets

pytestmark = pytest.mark.django_db
Expand Down

0 comments on commit 66241f6

Please sign in to comment.