diff --git a/backend/analytics_server/mhq/service/ai/ai_analytics_service.py b/backend/analytics_server/mhq/service/ai/ai_analytics_service.py index 14a3d0dc..e7b01b74 100644 --- a/backend/analytics_server/mhq/service/ai/ai_analytics_service.py +++ b/backend/analytics_server/mhq/service/ai/ai_analytics_service.py @@ -2,7 +2,7 @@ import requests from http import HTTPStatus from enum import Enum -from typing import Dict, List +from typing import Dict, List, Union class AIProvider(Enum): @@ -44,7 +44,30 @@ def __init__(self, llm: LLM, access_token: str): def _get_message(self, message: str, role: str = "user"): return {"role": role, "content": message} + def _handle_api_response(self, response) -> Dict[str, Union[str, int]]: + """ + Handles the API response, returning a success or error structure that the frontend can use. + """ + if response.status_code == HTTPStatus.OK: + return { + "status": "success", + "data": response.json()["choices"][0]["message"]["content"], + } + elif response.status_code == HTTPStatus.UNAUTHORIZED: + return { + "status": "error", + "message": "Unauthorized Access: Your access token is either missing, expired, or invalid. Please ensure that you are providing a valid token. ", + } + else: + return { + "status": "error", + "message": f"Unexpected error: {response.text}", + } + def _open_ai_fetch_completion_open_ai(self, messages: List[Dict[str, str]]): + """ + Handles the request to OpenAI API for fetching completions. + """ payload = { "model": self.LLM_NAME_TO_MODEL_MAP[self._llm], "temperature": 0.6, @@ -53,13 +76,12 @@ def _open_ai_fetch_completion_open_ai(self, messages: List[Dict[str, str]]): api_url = "https://api.openai.com/v1/chat/completions" response = requests.post(api_url, headers=self._headers, json=payload) - print(payload, api_url, response) - if response.status_code != HTTPStatus.OK: - raise Exception(response.json()) - - return response.json() + return self._handle_api_response(response) def _fireworks_ai_fetch_completions(self, messages: List[Dict[str, str]]): + """ + Handles the request to Fireworks AI API for fetching completions. + """ payload = { "model": self.LLM_NAME_TO_MODEL_MAP[self._llm], "temperature": 0.6, @@ -73,28 +95,28 @@ def _fireworks_ai_fetch_completions(self, messages: List[Dict[str, str]]): api_url = "https://api.fireworks.ai/inference/v1/chat/completions" response = requests.post(api_url, headers=self._headers, json=payload) - if response.status_code != HTTPStatus.OK: - raise Exception(response.json()) - - return response.json() - - def _fetch_completion(self, messages: List[Dict[str, str]]): + return self._handle_api_response(response) + def _fetch_completion( + self, messages: List[Dict[str, str]] + ) -> Dict[str, Union[str, int]]: + """ + Fetches the completion using the appropriate AI provider based on the LLM. + """ if self._ai_provider == AIProvider.FIREWORKS_AI: - return self._fireworks_ai_fetch_completions(messages)["choices"][0][ - "message" - ]["content"] + return self._fireworks_ai_fetch_completions(messages) if self._ai_provider == AIProvider.OPEN_AI: - return self._open_ai_fetch_completion_open_ai(messages)["choices"][0][ - "message" - ]["content"] + return self._open_ai_fetch_completion_open_ai(messages) - raise Exception(f"Invalid AI provider {self._ai_provider}") + return { + "status": "error", + "message": f"Invalid AI provider {self._ai_provider}", + } def get_dora_metrics_score( self, four_keys_data: Dict[str, float] - ) -> Dict[str, str]: + ) -> Dict[str, Union[str, int]]: """ Calculate the DORA metrics score using input data and an LLM (Language Learning Model). diff --git a/web-server/pages/api/internal/ai/dora_metrics.ts b/web-server/pages/api/internal/ai/dora_metrics.ts index 551ff75f..d14145aa 100644 --- a/web-server/pages/api/internal/ai/dora_metrics.ts +++ b/web-server/pages/api/internal/ai/dora_metrics.ts @@ -64,50 +64,78 @@ const postSchema = yup.object().shape({ }); const endpoint = new Endpoint(nullSchema); - endpoint.handle.POST(postSchema, async (req, res) => { const { data, model, access_token } = req.payload; + const dora_data = data as TeamDoraMetricsApiResponseType; + + try { + const [ + doraMetricsScore, + leadTimeSummary, + CFRSummary, + MTTRSummary, + deploymentFrequencySummary, + doraTrendSummary + ] = await Promise.all( + [ + getDoraMetricsScore, + getLeadTimeSummary, + getCFRSummary, + getMTTRSummary, + getDeploymentFrequencySummary, + getDoraTrendsCorrelationSummary + ].map((fn) => fn(dora_data, model, access_token)) + ); + + const aggregatedData = { + ...doraMetricsScore, + ...leadTimeSummary, + ...CFRSummary, + ...MTTRSummary, + ...deploymentFrequencySummary, + ...doraTrendSummary + }; - const dora_data = data as unknown as TeamDoraMetricsApiResponseType; - - const [ - dora_metrics_score, - lead_time_trends_summary, - change_failure_rate_trends_summary, - mean_time_to_recovery_trends_summary, - deployment_frequency_trends_summary, - dora_trend_summary - ] = await Promise.all( - [ - getDoraMetricsScore, - getLeadTimeSummary, - getCFRSummary, - getMTTRSummary, - getDeploymentFrequencySummary, - getDoraTrendsCorrelationSummary - ].map((f) => f(dora_data, model, access_token)) - ); + const compiledSummary = await getDORACompiledSummary( + aggregatedData, + model, + access_token + ); - const aggregated_dora_data = { - ...dora_metrics_score, - ...lead_time_trends_summary, - ...change_failure_rate_trends_summary, - ...mean_time_to_recovery_trends_summary, - ...deployment_frequency_trends_summary, - ...dora_trend_summary - } as AggregatedDORAData; - - const dora_compiled_summary = await getDORACompiledSummary( - aggregated_dora_data, - model, - access_token - ); + const responses = { + ...aggregatedData, + ...compiledSummary + }; - res.send({ - ...aggregated_dora_data, - ...dora_compiled_summary - }); + const { status, message } = checkForErrors(responses); + + if (status === 'error') { + return res.status(400).send({ message }); + } + + const simplifiedData = Object.fromEntries( + Object.entries(responses).map(([key, value]) => [key, value.data]) + ); + + return res.status(200).send(simplifiedData); + } catch (error) { + return res.status(500).send({ + message: 'Internal Server Error', + error: error.message + }); + } }); +const checkForErrors = ( + responses: Record +): { status: string; message: string } => { + const errorResponse = Object.values(responses).find( + (value) => value.status === 'error' + ); + + return errorResponse + ? { status: 'error', message: errorResponse.message } + : { status: 'success', message: '' }; +}; const getDoraMetricsScore = ( dora_data: TeamDoraMetricsApiResponseType,