Feature/data insights out of ui4t #876

Merged (5 commits, Oct 11, 2024)

Changes from 2 commits
21 changes: 19 additions & 2 deletions ddpui/api/transform_api.py
@@ -26,8 +26,10 @@
LockCanvasRequestSchema,
LockCanvasResponseSchema,
)
from ddpui.utils.taskprogress import TaskProgress
from ddpui.core.transformfunctions import validate_operation_config, check_canvas_locked
from ddpui.api.warehouse_api import get_warehouse_data
from ddpui.models.tasks import TaskProgressHashPrefix

from ddpui.core import dbtautomation_service
from ddpui.core.dbtautomation_service import sync_sources_for_warehouse
@@ -114,9 +116,24 @@ def sync_sources(request):
if not orgdbt:
raise HttpError(404, "DBT workspace not set up")

task = sync_sources_for_warehouse.delay(orgdbt.id, org_warehouse.id, org_warehouse.org.slug)
task_id = str(uuid.uuid4())
hashkey = f"{TaskProgressHashPrefix.SYNCSOURCES.value}-{org.slug}"

return {"task_id": task.id}
taskprogress = TaskProgress(
task_id=task_id,
hashkey=hashkey,
expire_in_seconds=10 * 60,  # max 10 minutes
)
taskprogress.add(
{
"message": "Started syncing sources",
"status": "runnning",
}
)

sync_sources_for_warehouse.delay(orgdbt.id, org_warehouse.id, task_id, hashkey)

return {"task_id": task_id, "hashkey": hashkey}


########################## Models & Sources #############################################
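Note on the change above: previously the progress tracker was created inside the Celery task from `self.request.id`, so nothing was recorded until a worker picked the job up, and the frontend had to reconstruct the hashkey on its own. Minting a `task_id` with `uuid4` in the API lets it seed an initial progress entry and return both identifiers immediately. A condensed sketch of the pattern, using only the calls visible in this diff (the helper name `dispatch_sync_sources` is illustrative, not part of the PR):

```python
import uuid

from ddpui.utils.taskprogress import TaskProgress
from ddpui.models.tasks import TaskProgressHashPrefix
from ddpui.core.dbtautomation_service import sync_sources_for_warehouse


def dispatch_sync_sources(org, orgdbt, org_warehouse):
    """Illustrative helper: pre-register progress, then enqueue the task."""
    task_id = str(uuid.uuid4())
    hashkey = f"{TaskProgressHashPrefix.SYNCSOURCES.value}-{org.slug}"

    # seed a progress entry so the frontend sees state before any worker starts
    taskprogress = TaskProgress(task_id=task_id, hashkey=hashkey, expire_in_seconds=10 * 60)
    taskprogress.add({"message": "Started syncing sources", "status": "running"})

    # the worker reuses the same task_id/hashkey for its own progress updates
    sync_sources_for_warehouse.delay(orgdbt.id, org_warehouse.id, task_id, hashkey)
    return {"task_id": task_id, "hashkey": hashkey}
```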
82 changes: 45 additions & 37 deletions ddpui/api/warehouse_api.py
@@ -1,3 +1,5 @@
import threading
from typing import List
import json
import csv
from io import StringIO
@@ -12,6 +14,7 @@
from django.http import StreamingHttpResponse
from ddpui import auth
from ddpui.core import dbtautomation_service
from ddpui.core.warehousefunctions import get_warehouse_data, fetch_warehouse_tables
from ddpui.models.org import OrgWarehouse
from ddpui.models.org_user import OrgUser
from ddpui.models.tasks import TaskProgressHashPrefix
@@ -38,48 +41,12 @@
from ddpui.utils import secretsmanager
from ddpui.utils.helpers import convert_to_standard_types
from ddpui.utils.constants import LIMIT_ROWS_TO_SEND_TO_LLM
from ddpui.utils.redis_client import RedisClient

warehouse_router = Router()
logger = CustomLogger("ddpui")


def get_warehouse_data(request, data_type: str, **kwargs):
"""
Fetches data from a warehouse based on the data type
and optional parameters
"""
try:
org_user = request.orguser
org_warehouse = OrgWarehouse.objects.filter(org=org_user.org).first()

data = []
client = dbtautomation_service._get_wclient(org_warehouse)
if data_type == "tables":
data = client.get_tables(kwargs["schema_name"])
elif data_type == "schemas":
data = client.get_schemas()
elif data_type == "table_columns":
data = client.get_table_columns(kwargs["schema_name"], kwargs["table_name"])
elif data_type == "table_data":
data = client.get_table_data(
schema=kwargs["schema_name"],
table=kwargs["table_name"],
limit=kwargs["limit"],
page=kwargs["page"],
order_by=kwargs["order_by"],
order=kwargs["order"],
)
for element in data:
for key, value in element.items():
if (isinstance(value, list) or isinstance(value, dict)) and value:
element[key] = json.dumps(value)
except Exception as error:
logger.exception(f"Exception occurred in get_{data_type}: {error}")
raise HttpError(500, f"Failed to get {data_type}")

return convert_to_standard_types(data)


@warehouse_router.get("/tables/{schema_name}", auth=auth.CustomAuthMiddleware())
@has_permission(["can_view_warehouse_data"])
def get_table(request, schema_name: str):
@@ -471,3 +438,44 @@ def get_warehouse_llm_analysis_sessions(
for session in sessions
],
}


@warehouse_router.get(
"/sync_tables",
auth=auth.CustomAuthMiddleware(),
)
@has_permission(["can_view_warehouse_data"])
def get_warehouse_schemas_and_tables(
request,
):
"""
Get all tables under all schemas in the warehouse
Read from warehouse directly if no cache is found
"""
orguser: OrgUser = request.orguser
org = orguser.org

org_warehouse = OrgWarehouse.objects.filter(org=org).first()
if not org_warehouse:
raise HttpError(404, "Please set up your warehouse first")

res = []
cache_key = f"{org.slug}_warehouse_tables"
redis_client = RedisClient.get_instance()
try:
# fetch & set response in redis asynchronously
if redis_client.exists(cache_key):
logger.info("Fetching warehouse tables from cache")
res = json.loads(redis_client.get(cache_key))
threading.Thread(
target=fetch_warehouse_tables, args=(request, org_warehouse, cache_key)
).start()
else:
logger.info("Fetching warehouse tables directly")
res = fetch_warehouse_tables(request, org_warehouse, cache_key)

except Exception as err:
logger.error("Failed to fetch data from the warehouse - %s", err)
raise HttpError(500, "Failed to fetch data from the warehouse") from err

return res
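The new `sync_tables` endpoint uses a cache-aside pattern with background refresh: when the Redis key `<org-slug>_warehouse_tables` exists, the cached listing is returned at once and a background thread refreshes it; on a cold cache the fetch happens synchronously and populates the key. Based on `fetch_warehouse_tables` (added in `warehousefunctions.py` below), the cached payload and the endpoint response look roughly like this; the schema and table names here are invented for illustration:

```python
# Illustrative payload shape only; values are made up.
example_response = [
    {
        "schema": "staging",
        "input_name": "customers",
        "type": "src_model_node",
        "id": "staging-customers",
    },
    {
        "schema": "staging",
        "input_name": "orders",
        "type": "src_model_node",
        "id": "staging-orders",
    },
]
```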
10 changes: 6 additions & 4 deletions ddpui/core/dbtautomation_service.py
@@ -1,4 +1,4 @@
import os, uuid
import os, uuid, time
from pathlib import Path

from dbt_automation.operations.arithmetic import arithmetic, arithmetic_dbt_sql
@@ -293,14 +293,16 @@ propagate_changes_to_downstream_operations(


@app.task(bind=True)
def sync_sources_for_warehouse(self, org_dbt_id: str, org_warehouse_id: str, orgslug: str):
def sync_sources_for_warehouse(
self, org_dbt_id: str, org_warehouse_id: str, task_id: str, hashkey: str
):
"""
Sync all tables in all schemas in the warehouse.
Dbt source name will be the same as the schema name.
"""
taskprogress = TaskProgress(
task_id=self.request.id,
hashkey=f"{TaskProgressHashPrefix.SYNCSOURCES}-{orgslug}",
task_id=task_id,
hashkey=hashkey,
expire_in_seconds=10 * 60, # max 10 minutes
)

76 changes: 76 additions & 0 deletions ddpui/core/warehousefunctions.py
@@ -0,0 +1,76 @@
import json
from ninja.errors import HttpError

from ddpui.core import dbtautomation_service
from ddpui.utils.custom_logger import CustomLogger
from ddpui.utils.helpers import convert_to_standard_types
from ddpui.models.org import OrgWarehouse
from ddpui.utils.redis_client import RedisClient

logger = CustomLogger("ddpui")


def get_warehouse_data(request, data_type: str, **kwargs):
"""
Fetches data from a warehouse based on the data type
and optional parameters
"""
try:
org_warehouse = kwargs.get("org_warehouse", None)
if not org_warehouse:
org_user = request.orguser
org_warehouse = OrgWarehouse.objects.filter(org=org_user.org).first()

data = []
client = dbtautomation_service._get_wclient(org_warehouse)
if data_type == "tables":
data = client.get_tables(kwargs["schema_name"])
elif data_type == "schemas":
data = client.get_schemas()
elif data_type == "table_columns":
data = client.get_table_columns(kwargs["schema_name"], kwargs["table_name"])
elif data_type == "table_data":
data = client.get_table_data(
schema=kwargs["schema_name"],
table=kwargs["table_name"],
limit=kwargs["limit"],
page=kwargs["page"],
order_by=kwargs["order_by"],
order=kwargs["order"],
)
for element in data:
for key, value in element.items():
if (isinstance(value, list) or isinstance(value, dict)) and value:
element[key] = json.dumps(value)
except Exception as error:
logger.exception(f"Exception occurred in get_{data_type}: {error}")
raise HttpError(500, f"Failed to get {data_type}")

return convert_to_standard_types(data)


def fetch_warehouse_tables(request, org_warehouse, cache_key=None):
"""
Fetch all the tables from the warehouse
Cache the results
"""
res = []
schemas = get_warehouse_data(request, "schemas", org_warehouse=org_warehouse)
logger.info(f"Inside helper function for fetching tables : {cache_key}")
for schema in schemas:
for table in get_warehouse_data(
request, "tables", schema_name=schema, org_warehouse=org_warehouse
):
res.append(
{
"schema": schema,
"input_name": table,
"type": "src_model_node",
"id": schema + "-" + table,
}
)

if cache_key:
RedisClient.get_instance().set(cache_key, json.dumps(res))

return res
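These helpers were lifted out of `warehouse_api.py` so they can run outside a view: `get_warehouse_data` now accepts an explicit `org_warehouse` keyword argument instead of always resolving the warehouse from `request.orguser`, which is what allows the endpoint above to call `fetch_warehouse_tables` from a background thread. A minimal usage sketch, assuming the caller already has an authenticated `request` with an `OrgUser` attached (the `warm_tables_cache` name is illustrative only):

```python
from ddpui.core.warehousefunctions import fetch_warehouse_tables
from ddpui.models.org import OrgWarehouse


def warm_tables_cache(request):
    """Illustrative only: refresh the per-org warehouse-tables cache."""
    org = request.orguser.org
    org_warehouse = OrgWarehouse.objects.filter(org=org).first()
    if org_warehouse is None:
        return []

    # passing org_warehouse explicitly lets this run outside a warehouse view;
    # the result is also written to Redis under the org-specific cache key
    return fetch_warehouse_tables(
        request, org_warehouse, cache_key=f"{org.slug}_warehouse_tables"
    )
```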