Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dashboard fixes #53

Merged
merged 1 commit into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions aryaxai/common/constants.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
MODEL_TYPES = ["classification", "regression"]

TARGET_DRIFT_MODEL_TYPES = ["classification"]


DATA_DRIFT_DASHBOARD_REQUIRED_FIELDS = [
"base_line_tag",
"current_tag",
Expand Down
20 changes: 12 additions & 8 deletions aryaxai/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from aryaxai.common.xai_uris import POLL_EVENTS


def parse_float(s):
"""parse float from string, return None if not possible

Expand All @@ -16,7 +17,8 @@ def parse_float(s):
except ValueError:
return None

def parse_datetime(s, format='%Y-%m-%d %H:%M:%S'):

def parse_datetime(s, format="%Y-%m-%d %H:%M:%S"):
"""Parse datetime from string, return None if not possible

:param s: string to parse
Expand All @@ -28,21 +30,23 @@ def parse_datetime(s, format='%Y-%m-%d %H:%M:%S'):
except ValueError:
return None


def pretty_date(date: str) -> str:
"""return date in format dd-mm-YYYY HH:MM:SS

:param date: str datetime
:return: pretty datetime
"""
try:
datetime_obj = datetime.strptime(date, '%Y-%m-%dT%H:%M:%S.%f')
datetime_obj = datetime.strptime(date, "%Y-%m-%dT%H:%M:%S.%f")
except ValueError:
try:
datetime_obj = datetime.strptime(date, '%Y-%m-%d %H:%M:%S.%f')
datetime_obj = datetime.strptime(date, "%Y-%m-%d %H:%M:%S.%f")
except ValueError:
print("Date format invalid.")

return datetime_obj.strftime('%d-%m-%Y %H:%M:%S')
return datetime_obj.strftime("%d-%m-%Y %H:%M:%S")


def poll_events(
api_client: APIClient,
Expand All @@ -62,9 +66,9 @@ def poll_events(

if not event.get("success"):
raise Exception(details)
if details.get("logs"):
print(details.get("logs")[log_length:])
log_length = len(details.get("logs"))
if details.get("event_logs"):
print(details.get("event_logs")[log_length:])
log_length = len(details.get("event_logs"))
if details.get("message") != last_message:
last_message = details.get("message")
print(f"{details.get('message')}")
Expand All @@ -76,4 +80,4 @@ def poll_events(
if details.get("status") == "failed":
if handle_failed_event:
handle_failed_event()
raise Exception(details.get("message"))
raise Exception(details.get("message"))
4 changes: 3 additions & 1 deletion aryaxai/common/xai_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
API_VERSION = os.getenv("XAI_API_VERSION", "v1")
API_VERSION_V2 = "v2"

# APP
XAI_APP_URI = "https://beta.aryaxai.com"

# URIs of XAI base service starts here
# Auth
LOGIN_URI = f"{API_VERSION}/access-token/authorize"
Expand Down Expand Up @@ -48,7 +51,6 @@
TAG_DATA_URI = f"{API_VERSION_V2}/project/tag_data"

# Monitoring
HOSTED_DASHBOARD_URI = "https://beta.aryaxai.com/sdk/dashboard"
GENERATE_DASHBOARD_URI = f"{API_VERSION_V2}/dashboards/generate_dashboard"
DASHBOARD_CONFIG_URI = f"{API_VERSION_V2}/dashboards/config"
MODEL_PERFORMANCE_DASHBOARD_URI = (
Expand Down
9 changes: 7 additions & 2 deletions aryaxai/core/dashboard.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import os
from typing import Any
from pydantic import BaseModel
import json
from IPython.display import IFrame, display

from aryaxai.common.xai_uris import XAI_APP_URI


class Dashboard(BaseModel):
config: dict
url: str
query_params: str
raw_data: dict | list

def __init__(self, **kwargs):
Expand All @@ -22,7 +25,9 @@ def plot(self, width: int = "100%", height: int = 800):
width (int, optional): _description_. Defaults to 100%.
height (int, optional): _description_. Defaults to 650.
"""
display(IFrame(src=f"{self.url}", width=width, height=height))
uri = os.environ.get("XAI_APP_URL", XAI_APP_URI)
url = f"{uri}/sdk/dashboard{self.query_params}"
display(IFrame(src=f"{url}", width=width, height=height))

def get_config(self) -> dict:
"""
Expand Down
3 changes: 2 additions & 1 deletion aryaxai/core/organization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from aryaxai.client.client import APIClient
from aryaxai.common.validation import Validate
from aryaxai.common.xai_uris import (
AVAILABLE_CUSTOM_SERVERS_URI,
CREATE_WORKSPACE_URI,
GET_WORKSPACES_URI,
INVITE_USER_ORGANIZATION_URI,
Expand Down Expand Up @@ -159,7 +160,7 @@ def create_workspace(
payload = {"workspace_name": workspace_name}

if server_type:
custom_servers = self.available_custom_servers()
custom_servers = self.api_client.get(AVAILABLE_CUSTOM_SERVERS_URI)
Validate.value_against_list(
"server_type",
server_type,
Expand Down
17 changes: 6 additions & 11 deletions aryaxai/core/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
DATA_DRIFT_STAT_TESTS,
SYNTHETIC_MODELS_DEFAULT_HYPER_PARAMS,
TARGET_DRIFT_DASHBOARD_REQUIRED_FIELDS,
TARGET_DRIFT_MODEL_TYPES,
TARGET_DRIFT_STAT_TESTS,
BIAS_MONITORING_DASHBOARD_REQUIRED_FIELDS,
MODEL_PERF_DASHBOARD_REQUIRED_FIELDS,
Expand Down Expand Up @@ -85,7 +84,6 @@
GET_SYNTHETIC_MODELS_URI,
GET_SYNTHETIC_PROMPT_URI,
GET_VIEWED_CASE_URI,
HOSTED_DASHBOARD_URI,
MODEL_INFERENCES_URI,
MODEL_PARAMETERS_URI,
MODEL_SUMMARY_URI,
Expand Down Expand Up @@ -918,11 +916,10 @@ def get_default_dashboard(self, type: str) -> Dashboard:
if res["success"]:
auth_token = self.api_client.get_auth_token()
query_params = f"?project_name={self.project_name}&type={type}&access_token={auth_token}"

return Dashboard(
config=res.get("config"),
raw_data=res.get("details"),
url=f"{HOSTED_DASHBOARD_URI}{query_params}",
query_params=query_params,
)

raise Exception(
Expand Down Expand Up @@ -1118,9 +1115,7 @@ def get_target_drift_dashboard(

Validate.validate_date_feature_val(payload, tags_info["alldatetimefeatures"])

Validate.value_against_list(
"model_type", payload["model_type"], TARGET_DRIFT_MODEL_TYPES
)
Validate.value_against_list("model_type", payload["model_type"], MODEL_TYPES)

Validate.value_against_list(
"stat_test_name", payload["stat_test_name"], TARGET_DRIFT_STAT_TESTS
Expand Down Expand Up @@ -1451,7 +1446,7 @@ def get_dashboard(self, type: str, dashboard_id: str) -> Dashboard:
return Dashboard(
config=res.get("config"),
raw_data=res.get("details"),
url=f"{HOSTED_DASHBOARD_URI}{query_params}",
query_params=query_params,
)

def monitoring_triggers(self) -> pd.DataFrame:
Expand Down Expand Up @@ -1762,12 +1757,12 @@ def get_model_performance(self, model_name: str = None) -> Dashboard:
get model performance dashboard
"""
auth_token = self.api_client.get_auth_token()
url = f"{HOSTED_DASHBOARD_URI}?type=model_performance&project_name={self.project_name}&access_token={auth_token}"
query_params = f"?type=model_performance&project_name={self.project_name}&access_token={auth_token}"

if model_name:
url = f"{url}&model_name={model_name}"
query_params = f"{query_params}&model_name={model_name}"

return Dashboard(config={}, url=url, raw_data={})
return Dashboard(config={}, query_params=query_params, raw_data={})

def model_parameters(self) -> dict:
"""Model Parameters
Expand Down