From 7b3d9e9234cda1a768d88956c031124dd0231865 Mon Sep 17 00:00:00 2001 From: Andrii Ieroshenko Date: Wed, 15 Nov 2023 10:47:01 -0800 Subject: [PATCH 01/10] handle config validation error on the backend --- .../jupyter-ai/jupyter_ai/config_manager.py | 140 +++++++++++------- packages/jupyter-ai/jupyter_ai/handlers.py | 5 + packages/jupyter-ai/jupyter_ai/models.py | 6 + 3 files changed, 96 insertions(+), 55 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index 63bb003f8..1383b62e9 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -7,7 +7,12 @@ from deepmerge import always_merger as Merger from jsonschema import Draft202012Validator as Validator -from jupyter_ai.models import DescribeConfigResponse, GlobalConfig, UpdateConfigRequest +from jupyter_ai.models import ( + APIErrorModel, + DescribeConfigResponse, + GlobalConfig, + UpdateConfigRequest, +) from jupyter_ai_magics.utils import ( AnyProvider, EmProvidersDict, @@ -16,6 +21,7 @@ get_lm_provider, ) from jupyter_core.paths import jupyter_data_dir +from pydantic import ValidationError from traitlets import Integer, Unicode from traitlets.config import Configurable @@ -70,6 +76,16 @@ def _validate_provider_authn(config: GlobalConfig, provider: AnyProvider): ) +def _format_validation_errors(error: ValidationError): + """Format Pydantic validation errors for user-friendly output.""" + messages = [] + for e in error.errors(): + field_path = " -> ".join(map(str, e["loc"])) + error_message = f"Error in '{field_path}': {e['msg']}. Please review and correct this field." + messages.append(error_message) + return "Configuration Error: " + " | ".join(messages) + + class ConfigManager(Configurable): """Provides model and embedding provider id along with the credentials to authenticate providers. @@ -111,6 +127,7 @@ def __init__( super().__init__(*args, **kwargs) self.log = log + self._config_error = None self._lm_providers = lm_providers """List of LM providers.""" self._em_providers = em_providers @@ -146,60 +163,70 @@ def _init_validator(self) -> Validator: self.validator = Validator(schema) def _init_config(self): - if os.path.exists(self.config_path): - with open(self.config_path, encoding="utf-8") as f: - config = GlobalConfig(**json.loads(f.read())) - lm_id = config.model_provider_id - em_id = config.embeddings_provider_id - - # if the currently selected language or embedding model are - # forbidden, set them to `None` and log a warning. - if lm_id is not None and not self._validate_model( - lm_id, raise_exc=False - ): - self.log.warning( - f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None." - ) - config.model_provider_id = None - if em_id is not None and not self._validate_model( - em_id, raise_exc=False - ): - self.log.warning( - f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None." - ) - config.embeddings_provider_id = None - - # if the currently selected language or embedding model ids are - # not associated with models, set them to `None` and log a warning. - if ( - lm_id is not None - and not get_lm_provider(lm_id, self._lm_providers)[1] - ): - self.log.warning( - f"No language model is associated with '{lm_id}'. Setting to None." - ) - config.model_provider_id = None - if ( - em_id is not None - and not get_em_provider(em_id, self._em_providers)[1] - ): - self.log.warning( - f"No embedding model is associated with '{em_id}'. Setting to None." - ) - config.embeddings_provider_id = None - - # re-write to the file to validate the config and apply any - # updates to the config file immediately - self._write_config(config) - return - - properties = self.validator.schema.get("properties", {}) - field_list = GlobalConfig.__fields__.keys() - field_dict = { - field: properties.get(field).get("default") for field in field_list - } - default_config = GlobalConfig(**field_dict) - self._write_config(default_config) + try: + if os.path.exists(self.config_path): + with open(self.config_path, encoding="utf-8") as f: + config = GlobalConfig(**json.loads(f.read())) + lm_id = config.model_provider_id + em_id = config.embeddings_provider_id + + # if the currently selected language or embedding model are + # forbidden, set them to `None` and log a warning. + if lm_id is not None and not self._validate_model( + lm_id, raise_exc=False + ): + self.log.warning( + f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None." + ) + config.model_provider_id = None + if em_id is not None and not self._validate_model( + em_id, raise_exc=False + ): + self.log.warning( + f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None." + ) + config.embeddings_provider_id = None + + # if the currently selected language or embedding model ids are + # not associated with models, set them to `None` and log a warning. + if ( + lm_id is not None + and not get_lm_provider(lm_id, self._lm_providers)[1] + ): + self.log.warning( + f"No language model is associated with '{lm_id}'. Setting to None." + ) + config.model_provider_id = None + if ( + em_id is not None + and not get_em_provider(em_id, self._em_providers)[1] + ): + self.log.warning( + f"No embedding model is associated with '{em_id}'. Setting to None." + ) + config.embeddings_provider_id = None + + # re-write to the file to validate the config and apply any + # updates to the config file immediately + self._write_config(config) + return + + properties = self.validator.schema.get("properties", {}) + field_list = GlobalConfig.__fields__.keys() + field_dict = { + field: properties.get(field).get("default") for field in field_list + } + default_config = GlobalConfig(**field_dict) + self._write_config(default_config) + + except ValidationError as e: + formatted_error = _format_validation_errors(e) + self.config_error = APIErrorModel( + type="ValidationError", + message="Configuration validation failed", + details=formatted_error, + ) + self.log.error(f"Configuration validation error: {self.config_error}") def _read_config(self) -> GlobalConfig: """Returns the user's current configuration as a GlobalConfig object. @@ -333,6 +360,9 @@ def delete_api_key(self, key_name: str): config_dict["api_keys"].pop(key_name, None) self._write_config(GlobalConfig(**config_dict)) + def get_config_error(self): + return self._config_error + def update_config(self, config_update: UpdateConfigRequest): last_write = os.stat(self.config_path).st_mtime_ns if config_update.last_read and config_update.last_read < last_write: diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index 126a6c94c..bc70f95cf 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -361,6 +361,11 @@ def config_manager(self): @web.authenticated def get(self): + if self.config_manager.config_error: + self.set_status(400) + self.finish(self.config_manager.config_error.json()) + return + config = self.config_manager.get_config() if not config: raise HTTPError(500, "No config found.") diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index 41509a74d..261704c89 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -143,3 +143,9 @@ class GlobalConfig(BaseModel): send_with_shift_enter: bool fields: Dict[str, Dict[str, Any]] api_keys: Dict[str, str] + + +class APIErrorModel(BaseModel): + type: str = "APIError" + message: str + details: str = None From abe6ab3cb9c001a31e36a144ae1cc0bd800461c1 Mon Sep 17 00:00:00 2001 From: Andrii Ieroshenko Date: Wed, 15 Nov 2023 23:43:11 -0800 Subject: [PATCH 02/10] refactor _init_config --- .../jupyter-ai/jupyter_ai/config_manager.py | 127 +++++++++--------- packages/jupyter-ai/jupyter_ai/handlers.py | 5 +- 2 files changed, 67 insertions(+), 65 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index 1383b62e9..65d3ab159 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -163,70 +163,71 @@ def _init_validator(self) -> Validator: self.validator = Validator(schema) def _init_config(self): - try: - if os.path.exists(self.config_path): - with open(self.config_path, encoding="utf-8") as f: - config = GlobalConfig(**json.loads(f.read())) - lm_id = config.model_provider_id - em_id = config.embeddings_provider_id - - # if the currently selected language or embedding model are - # forbidden, set them to `None` and log a warning. - if lm_id is not None and not self._validate_model( - lm_id, raise_exc=False - ): - self.log.warning( - f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None." - ) - config.model_provider_id = None - if em_id is not None and not self._validate_model( - em_id, raise_exc=False - ): - self.log.warning( - f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None." - ) - config.embeddings_provider_id = None - - # if the currently selected language or embedding model ids are - # not associated with models, set them to `None` and log a warning. - if ( - lm_id is not None - and not get_lm_provider(lm_id, self._lm_providers)[1] - ): - self.log.warning( - f"No language model is associated with '{lm_id}'. Setting to None." - ) - config.model_provider_id = None - if ( - em_id is not None - and not get_em_provider(em_id, self._em_providers)[1] - ): - self.log.warning( - f"No embedding model is associated with '{em_id}'. Setting to None." - ) - config.embeddings_provider_id = None - - # re-write to the file to validate the config and apply any - # updates to the config file immediately - self._write_config(config) - return - - properties = self.validator.schema.get("properties", {}) - field_list = GlobalConfig.__fields__.keys() - field_dict = { - field: properties.get(field).get("default") for field in field_list - } - default_config = GlobalConfig(**field_dict) - self._write_config(default_config) - - except ValidationError as e: - formatted_error = _format_validation_errors(e) - self.config_error = APIErrorModel( - type="ValidationError", - message="Configuration validation failed", - details=formatted_error, + # try: + if os.path.exists(self.config_path): + self._process_existing_config() + else: + self._create_default_config() + # except ValidationError as e: + # self._handle_validation_error(e) + + def _process_existing_config(self): + with open(self.config_path, encoding="utf-8") as f: + config = GlobalConfig(**json.loads(f.read())) + self._validate_lm_em_id(config) + self._write_config(config) + + def _create_default_config(self): + properties = self.validator.schema.get("properties", {}) + field_list = GlobalConfig.__fields__.keys() + field_dict = { + field: properties.get(field).get("default") for field in field_list + } + default_config = GlobalConfig(**field_dict) + self._write_config(default_config) + + def _validate_lm_em_id(self, config): + lm_id = config.model_provider_id + em_id = config.embeddings_provider_id + + # if the currently selected language or embedding model are + # forbidden, set them to `None` and log a warning. + if lm_id is not None and not self._validate_model(lm_id, raise_exc=False): + self.log.warning( + f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None." + ) + config.model_provider_id = None + if em_id is not None and not self._validate_model(em_id, raise_exc=False): + self.log.warning( + f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None." ) - self.log.error(f"Configuration validation error: {self.config_error}") + config.embeddings_provider_id = None + + # if the currently selected language or embedding model ids are + # not associated with models, set them to `None` and log a warning. + if lm_id is not None and not get_lm_provider(lm_id, self._lm_providers)[1]: + self.log.warning( + f"No language model is associated with '{lm_id}'. Setting to None." + ) + config.model_provider_id = None + if em_id is not None and not get_em_provider(em_id, self._em_providers)[1]: + self.log.warning( + f"No embedding model is associated with '{em_id}'. Setting to None." + ) + config.embeddings_provider_id = None + + # re-write to the file to validate the config and apply any + # updates to the config file immediately + self._write_config(config) + + def _handle_validation_error(self, e: ValidationError): + formatted_error = _format_validation_errors(e) + self._config_error = APIErrorModel( + type="ValidationError", + message="Configuration validation failed", + details=formatted_error, + ) + self.log.error(f"Configuration validation error: {self.config_error}") def _read_config(self) -> GlobalConfig: """Returns the user's current configuration as a GlobalConfig object. diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index bc70f95cf..e69f33894 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -361,9 +361,10 @@ def config_manager(self): @web.authenticated def get(self): - if self.config_manager.config_error: + config_error = self.config_manager.get_config_error() + if config_error: self.set_status(400) - self.finish(self.config_manager.config_error.json()) + self.finish(config_error.json()) return config = self.config_manager.get_config() From 9f8ac022b1bf21a4effc28e585a4b7ec5e24e812 Mon Sep 17 00:00:00 2001 From: Andrii Ieroshenko Date: Wed, 15 Nov 2023 23:46:22 -0800 Subject: [PATCH 03/10] rename APIErrorModel -> ErrorModel --- packages/jupyter-ai/jupyter_ai/config_manager.py | 4 ++-- packages/jupyter-ai/jupyter_ai/models.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index 65d3ab159..68b44406d 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -8,7 +8,7 @@ from deepmerge import always_merger as Merger from jsonschema import Draft202012Validator as Validator from jupyter_ai.models import ( - APIErrorModel, + ErrorModel, DescribeConfigResponse, GlobalConfig, UpdateConfigRequest, @@ -222,7 +222,7 @@ def _validate_lm_em_id(self, config): def _handle_validation_error(self, e: ValidationError): formatted_error = _format_validation_errors(e) - self._config_error = APIErrorModel( + self._config_error = ErrorModel( type="ValidationError", message="Configuration validation failed", details=formatted_error, diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index 261704c89..2b76abf59 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -145,7 +145,7 @@ class GlobalConfig(BaseModel): api_keys: Dict[str, str] -class APIErrorModel(BaseModel): - type: str = "APIError" - message: str +class ErrorModel(BaseModel): + type: str = None + message: str = None details: str = None From 0da3683bbac1fd175eb42a29a6f12bc0c9c573a1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 16 Nov 2023 07:46:33 +0000 Subject: [PATCH 04/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- packages/jupyter-ai/jupyter_ai/config_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index 68b44406d..cf7bf7fed 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -8,8 +8,8 @@ from deepmerge import always_merger as Merger from jsonschema import Draft202012Validator as Validator from jupyter_ai.models import ( - ErrorModel, DescribeConfigResponse, + ErrorModel, GlobalConfig, UpdateConfigRequest, ) From eb0b755db5a905ba0eea1388a27caab4c2088a54 Mon Sep 17 00:00:00 2001 From: Andrii Ieroshenko Date: Thu, 16 Nov 2023 13:24:35 -0800 Subject: [PATCH 05/10] refactor initialize_settings to check for settings config error --- packages/jupyter-ai/jupyter_ai/extension.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 8ab8c0cc6..772a6f529 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -130,10 +130,22 @@ def initialize_settings(self): blocked_models=self.blocked_models, ) - self.log.info("Registered providers.") + config_error = self.settings["jai_config_manager"].get_config_error() + if config_error: + # Log the error and proceed with limited functionality + self.log.error(f"Configuration error detected: {config_error}") + # TODO: self._initialize_limited_functionality() + else: + # Full functionality initialization + self._initialize_full_functionality() + self.log.info("Registered providers.") self.log.info(f"Registered {self.name} server extension") + latency_ms = round((time.time() - start) * 1000) + self.log.info(f"Initialized Jupyter AI server extension in {latency_ms} ms.") + + def _initialize_full_functionality(self): # Store chat clients in a dictionary self.settings["chat_clients"] = {} self.settings["jai_root_chat_handlers"] = {} @@ -190,8 +202,5 @@ def initialize_settings(self): "/help": help_chat_handler, } - latency_ms = round((time.time() - start) * 1000) - self.log.info(f"Initialized Jupyter AI server extension in {latency_ms} ms.") - async def _get_dask_client(self): return DaskClient(processes=False, asynchronous=True) From a43955a4173574be7c2a9199b32c8363f8ded9d0 Mon Sep 17 00:00:00 2001 From: Andrii Ieroshenko Date: Mon, 27 Nov 2023 21:57:14 -0800 Subject: [PATCH 06/10] use Exception-based ConfigError instead of ErrorModel --- .../jupyter-ai/jupyter_ai/config_manager.py | 80 +++++++++++++------ packages/jupyter-ai/jupyter_ai/extension.py | 6 +- packages/jupyter-ai/jupyter_ai/models.py | 6 -- 3 files changed, 59 insertions(+), 33 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index cf7bf7fed..782b1b7b3 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -1,3 +1,4 @@ +from enum import Enum import json import logging import os @@ -9,7 +10,6 @@ from jsonschema import Draft202012Validator as Validator from jupyter_ai.models import ( DescribeConfigResponse, - ErrorModel, GlobalConfig, UpdateConfigRequest, ) @@ -22,6 +22,7 @@ ) from jupyter_core.paths import jupyter_data_dir from pydantic import ValidationError +from tornado.web import HTTPError from traitlets import Integer, Unicode from traitlets.config import Configurable @@ -65,6 +66,21 @@ class BlockedModelError(Exception): pass +class ConfigErrorType(Enum): + CRITICAL = "Critical" + WARNING = "Warning" + + +class ConfigError(Exception): + def __init__(self, error_type: ConfigErrorType, message: str, details: str = None): + self.error_type = error_type + self.message = message + self.details = details + + def __str__(self): + return f"{self.error_type.value} ConfigError: {self.message} - {self.details or ''}" + + def _validate_provider_authn(config: GlobalConfig, provider: AnyProvider): # TODO: handle non-env auth strategies if not provider.auth_strategy or provider.auth_strategy.type != "env": @@ -127,7 +143,7 @@ def __init__( super().__init__(*args, **kwargs) self.log = log - self._config_error = None + self._config_errors = [] self._lm_providers = lm_providers """List of LM providers.""" self._em_providers = em_providers @@ -163,13 +179,13 @@ def _init_validator(self) -> Validator: self.validator = Validator(schema) def _init_config(self): - # try: - if os.path.exists(self.config_path): - self._process_existing_config() - else: - self._create_default_config() - # except ValidationError as e: - # self._handle_validation_error(e) + try: + if os.path.exists(self.config_path): + self._process_existing_config() + else: + self._create_default_config() + except ValidationError as e: + self._handle_validation_error(e) def _process_existing_config(self): with open(self.config_path, encoding="utf-8") as f: @@ -193,28 +209,42 @@ def _validate_lm_em_id(self, config): # if the currently selected language or embedding model are # forbidden, set them to `None` and log a warning. if lm_id is not None and not self._validate_model(lm_id, raise_exc=False): - self.log.warning( - f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None." - ) + warning_message = f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None." + self.log.warning(warning_message) config.model_provider_id = None - if em_id is not None and not self._validate_model(em_id, raise_exc=False): - self.log.warning( - f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None." + self._config_errors.append = ConfigError( + ConfigErrorType.WARNING, warning_message ) + + if em_id is not None and not self._validate_model(em_id, raise_exc=False): + warning_message = f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None." + self.log.warning(warning_message) config.embeddings_provider_id = None + self._config_errors.append = ConfigError( + ConfigErrorType.WARNING, warning_message + ) # if the currently selected language or embedding model ids are # not associated with models, set them to `None` and log a warning. if lm_id is not None and not get_lm_provider(lm_id, self._lm_providers)[1]: - self.log.warning( + warning_message = ( f"No language model is associated with '{lm_id}'. Setting to None." ) + self.log.warning(warning_message) config.model_provider_id = None + self._config_errors.append = ConfigError( + ConfigErrorType.WARNING, warning_message + ) + if em_id is not None and not get_em_provider(em_id, self._em_providers)[1]: - self.log.warning( + warning_message = ( f"No embedding model is associated with '{em_id}'. Setting to None." ) + self.log.warning(warning_message) config.embeddings_provider_id = None + self._config_errors.append = ConfigError( + ConfigErrorType.WARNING, warning_message + ) # re-write to the file to validate the config and apply any # updates to the config file immediately @@ -222,12 +252,11 @@ def _validate_lm_em_id(self, config): def _handle_validation_error(self, e: ValidationError): formatted_error = _format_validation_errors(e) - self._config_error = ErrorModel( - type="ValidationError", - message="Configuration validation failed", - details=formatted_error, + error_message = "Configuration validation failed" + self._config_errors.append( + ConfigError(ConfigErrorType.CRITICAL, error_message, formatted_error) ) - self.log.error(f"Configuration validation error: {self.config_error}") + self.log.error(f"{error_message}: {formatted_error}") def _read_config(self) -> GlobalConfig: """Returns the user's current configuration as a GlobalConfig object. @@ -361,8 +390,11 @@ def delete_api_key(self, key_name: str): config_dict["api_keys"].pop(key_name, None) self._write_config(GlobalConfig(**config_dict)) - def get_config_error(self): - return self._config_error + def get_config_errors(self): + if self._config_errors: + return self._config_errors + else: + return None def update_config(self, config_update: UpdateConfigRequest): last_write = os.stat(self.config_path).st_mtime_ns diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 772a6f529..69441cecc 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -130,10 +130,10 @@ def initialize_settings(self): blocked_models=self.blocked_models, ) - config_error = self.settings["jai_config_manager"].get_config_error() - if config_error: + config_errors = self.settings["jai_config_manager"].get_config_errors() + if config_errors: # Log the error and proceed with limited functionality - self.log.error(f"Configuration error detected: {config_error}") + self.log.error(f"Configuration error detected: {config_errors}") # TODO: self._initialize_limited_functionality() else: # Full functionality initialization diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index 2b76abf59..41509a74d 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -143,9 +143,3 @@ class GlobalConfig(BaseModel): send_with_shift_enter: bool fields: Dict[str, Dict[str, Any]] api_keys: Dict[str, str] - - -class ErrorModel(BaseModel): - type: str = None - message: str = None - details: str = None From 883078a5a3b0545ffb8744db3894aeb3bd247b99 Mon Sep 17 00:00:00 2001 From: Andrii Ieroshenko Date: Mon, 27 Nov 2023 22:42:47 -0800 Subject: [PATCH 07/10] pass config error to frontend --- .../jupyter-ai/jupyter_ai/config_manager.py | 30 +++++-------------- packages/jupyter-ai/jupyter_ai/extension.py | 14 +++++---- packages/jupyter-ai/jupyter_ai/handlers.py | 6 ---- packages/jupyter-ai/jupyter_ai/models.py | 15 ++++++++++ 4 files changed, 31 insertions(+), 34 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index 782b1b7b3..ad65924a5 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -1,4 +1,3 @@ -from enum import Enum import json import logging import os @@ -9,6 +8,8 @@ from deepmerge import always_merger as Merger from jsonschema import Draft202012Validator as Validator from jupyter_ai.models import ( + ConfigErrorModel, + ConfigErrorType, DescribeConfigResponse, GlobalConfig, UpdateConfigRequest, @@ -21,7 +22,7 @@ get_lm_provider, ) from jupyter_core.paths import jupyter_data_dir -from pydantic import ValidationError +from pydantic import BaseModel, ValidationError from tornado.web import HTTPError from traitlets import Integer, Unicode from traitlets.config import Configurable @@ -66,21 +67,6 @@ class BlockedModelError(Exception): pass -class ConfigErrorType(Enum): - CRITICAL = "Critical" - WARNING = "Warning" - - -class ConfigError(Exception): - def __init__(self, error_type: ConfigErrorType, message: str, details: str = None): - self.error_type = error_type - self.message = message - self.details = details - - def __str__(self): - return f"{self.error_type.value} ConfigError: {self.message} - {self.details or ''}" - - def _validate_provider_authn(config: GlobalConfig, provider: AnyProvider): # TODO: handle non-env auth strategies if not provider.auth_strategy or provider.auth_strategy.type != "env": @@ -212,7 +198,7 @@ def _validate_lm_em_id(self, config): warning_message = f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None." self.log.warning(warning_message) config.model_provider_id = None - self._config_errors.append = ConfigError( + self._config_errors.append = ConfigErrorModel( ConfigErrorType.WARNING, warning_message ) @@ -220,7 +206,7 @@ def _validate_lm_em_id(self, config): warning_message = f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None." self.log.warning(warning_message) config.embeddings_provider_id = None - self._config_errors.append = ConfigError( + self._config_errors.append = ConfigErrorModel( ConfigErrorType.WARNING, warning_message ) @@ -232,7 +218,7 @@ def _validate_lm_em_id(self, config): ) self.log.warning(warning_message) config.model_provider_id = None - self._config_errors.append = ConfigError( + self._config_errors.append = ConfigErrorModel( ConfigErrorType.WARNING, warning_message ) @@ -242,7 +228,7 @@ def _validate_lm_em_id(self, config): ) self.log.warning(warning_message) config.embeddings_provider_id = None - self._config_errors.append = ConfigError( + self._config_errors.append = ConfigErrorModel( ConfigErrorType.WARNING, warning_message ) @@ -254,7 +240,7 @@ def _handle_validation_error(self, e: ValidationError): formatted_error = _format_validation_errors(e) error_message = "Configuration validation failed" self._config_errors.append( - ConfigError(ConfigErrorType.CRITICAL, error_message, formatted_error) + ConfigErrorModel(ConfigErrorType.CRITICAL, error_message, formatted_error) ) self.log.error(f"{error_message}: {formatted_error}") diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 69441cecc..3dc8bbb40 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -15,7 +15,7 @@ LearnChatHandler, ) from .chat_handlers.help import HelpMessage -from .config_manager import ConfigManager +from .config_manager import ConfigErrorType, ConfigManager from .handlers import ( ApiKeysHandler, ChatHistoryHandler, @@ -131,13 +131,15 @@ def initialize_settings(self): ) config_errors = self.settings["jai_config_manager"].get_config_errors() - if config_errors: - # Log the error and proceed with limited functionality - self.log.error(f"Configuration error detected: {config_errors}") - # TODO: self._initialize_limited_functionality() - else: + if config_errors is None or all( + error.error_type != ConfigErrorType.CRITICAL for error in config_errors + ): # Full functionality initialization self._initialize_full_functionality() + else: + # Log the error and proceed with limited functionality + self.log.error(f"Configuration errors detected: {config_errors}") + # TODO: self._initialize_limited_functionality() self.log.info("Registered providers.") self.log.info(f"Registered {self.name} server extension") diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index e69f33894..126a6c94c 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -361,12 +361,6 @@ def config_manager(self): @web.authenticated def get(self): - config_error = self.config_manager.get_config_error() - if config_error: - self.set_status(400) - self.finish(config_error.json()) - return - config = self.config_manager.get_config() if not config: raise HTTPError(500, "No config found.") diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index 41509a74d..6cbfec14c 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -1,3 +1,4 @@ +from enum import Enum from typing import Any, Dict, List, Literal, Optional, Union from jupyter_ai_magics.providers import AuthStrategy, Field @@ -143,3 +144,17 @@ class GlobalConfig(BaseModel): send_with_shift_enter: bool fields: Dict[str, Dict[str, Any]] api_keys: Dict[str, str] + + +class ConfigErrorType(Enum): + CRITICAL = "Critical" + WARNING = "Warning" + + +class ConfigErrorModel(BaseModel): + error_type: ConfigErrorType + message: str + details: str = None + + def __str__(self): + return f"{self.error_type.value} ConfigError: {self.message} - {self.details or ''}" From 625c0120e76d66cee8a38c8f3ebbd0bcaa8d8561 Mon Sep 17 00:00:00 2001 From: Andrii Ieroshenko Date: Mon, 27 Nov 2023 22:43:30 -0800 Subject: [PATCH 08/10] remove unused imports --- packages/jupyter-ai/jupyter_ai/config_manager.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index ad65924a5..b8da56bb4 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -22,8 +22,7 @@ get_lm_provider, ) from jupyter_core.paths import jupyter_data_dir -from pydantic import BaseModel, ValidationError -from tornado.web import HTTPError +from pydantic import ValidationError from traitlets import Integer, Unicode from traitlets.config import Configurable From de7fbaf1e6a7a24030ed93fccf7961dc2adca9ca Mon Sep 17 00:00:00 2001 From: Andrii Ieroshenko Date: Mon, 27 Nov 2023 23:32:43 -0800 Subject: [PATCH 09/10] show errors and warnings on the frontend --- .../jupyter-ai/jupyter_ai/config_manager.py | 35 +++++++++++++------ packages/jupyter-ai/jupyter_ai/models.py | 29 +++++++-------- packages/jupyter-ai/src/components/chat.tsx | 13 ++++++- packages/jupyter-ai/src/handler.ts | 12 +++++++ 4 files changed, 64 insertions(+), 25 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index b8da56bb4..c937ff614 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -197,16 +197,20 @@ def _validate_lm_em_id(self, config): warning_message = f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None." self.log.warning(warning_message) config.model_provider_id = None - self._config_errors.append = ConfigErrorModel( - ConfigErrorType.WARNING, warning_message + self._config_errors.append( + ConfigErrorModel( + error_type=ConfigErrorType.WARNING, message=warning_message + ) ) if em_id is not None and not self._validate_model(em_id, raise_exc=False): warning_message = f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None." self.log.warning(warning_message) config.embeddings_provider_id = None - self._config_errors.append = ConfigErrorModel( - ConfigErrorType.WARNING, warning_message + self._config_errors.append( + ConfigErrorModel( + error_type=ConfigErrorType.WARNING, message=warning_message + ) ) # if the currently selected language or embedding model ids are @@ -217,8 +221,10 @@ def _validate_lm_em_id(self, config): ) self.log.warning(warning_message) config.model_provider_id = None - self._config_errors.append = ConfigErrorModel( - ConfigErrorType.WARNING, warning_message + self._config_errors.append( + ConfigErrorModel( + error_type=ConfigErrorType.WARNING, message=warning_message + ) ) if em_id is not None and not get_em_provider(em_id, self._em_providers)[1]: @@ -227,8 +233,10 @@ def _validate_lm_em_id(self, config): ) self.log.warning(warning_message) config.embeddings_provider_id = None - self._config_errors.append = ConfigErrorModel( - ConfigErrorType.WARNING, warning_message + self._config_errors.append( + ConfigErrorModel( + error_type=ConfigErrorType.WARNING, message=warning_message + ) ) # re-write to the file to validate the config and apply any @@ -239,7 +247,11 @@ def _handle_validation_error(self, e: ValidationError): formatted_error = _format_validation_errors(e) error_message = "Configuration validation failed" self._config_errors.append( - ConfigErrorModel(ConfigErrorType.CRITICAL, error_message, formatted_error) + ConfigErrorModel( + error_type=ConfigErrorType.CRITICAL, + message=error_message, + details=formatted_error, + ) ) self.log.error(f"{error_message}: {formatted_error}") @@ -404,7 +416,10 @@ def get_config(self): config_dict = config.dict(exclude_unset=True) api_key_names = list(config_dict.pop("api_keys").keys()) return DescribeConfigResponse( - **config_dict, api_keys=api_key_names, last_read=self._last_read + **config_dict, + api_keys=api_key_names, + last_read=self._last_read, + config_errors=self.get_config_errors(), ) @property diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index 6cbfec14c..2c27f739c 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -100,6 +100,20 @@ class IndexMetadata(BaseModel): dirs: List[IndexedDir] +class ConfigErrorType(Enum): + CRITICAL = "Critical" + WARNING = "Warning" + + +class ConfigErrorModel(BaseModel): + error_type: ConfigErrorType + message: str + details: str = None + + def __str__(self): + return f"{self.error_type.value} ConfigError: {self.message} - {self.details or ''}" + + class DescribeConfigResponse(BaseModel): model_provider_id: Optional[str] embeddings_provider_id: Optional[str] @@ -111,6 +125,7 @@ class DescribeConfigResponse(BaseModel): # timestamp indicating when the configuration file was last read. should be # passed to the subsequent UpdateConfig request. last_read: int + config_errors: Optional[List[ConfigErrorModel]] = None def forbid_none(cls, v): @@ -144,17 +159,3 @@ class GlobalConfig(BaseModel): send_with_shift_enter: bool fields: Dict[str, Dict[str, Any]] api_keys: Dict[str, str] - - -class ConfigErrorType(Enum): - CRITICAL = "Critical" - WARNING = "Warning" - - -class ConfigErrorModel(BaseModel): - error_type: ConfigErrorType - message: str - details: str = None - - def __str__(self): - return f"{self.error_type.value} ConfigError: {self.message} - {self.details or ''}" diff --git a/packages/jupyter-ai/src/components/chat.tsx b/packages/jupyter-ai/src/components/chat.tsx index ded339c70..83ba69c42 100644 --- a/packages/jupyter-ai/src/components/chat.tsx +++ b/packages/jupyter-ai/src/components/chat.tsx @@ -1,6 +1,6 @@ import React, { useState, useEffect } from 'react'; import { Box } from '@mui/system'; -import { Button, IconButton, Stack } from '@mui/material'; +import { Alert, AlertTitle, Button, IconButton, Stack } from '@mui/material'; import SettingsIcon from '@mui/icons-material/Settings'; import ArrowBackIcon from '@mui/icons-material/ArrowBack'; import type { Awareness } from 'y-protocols/awareness'; @@ -35,6 +35,7 @@ function ChatBody({ const [input, setInput] = useState(''); const [selection, replaceSelectionFn] = useSelectionContext(); const [sendWithShiftEnter, setSendWithShiftEnter] = useState(true); + const [configErrors, setConfigErrors] = useState([]); /** * Effect: fetch history and config on initial render @@ -51,6 +52,9 @@ function ChatBody({ if (!config.model_provider_id) { setShowWelcomeMessage(true); } + if (config.config_errors) { + setConfigErrors(config.config_errors); + } } catch (e) { console.error(e); } @@ -125,6 +129,13 @@ function ChatBody({ }} > + {configErrors && + configErrors.map((error, idx) => ( + + {error.error_type} + {error.message} {error.details && `- ${error.details}`} + + ))}

Welcome to Jupyter AI! To get started, please select a language model to chat with from the settings panel. You may also need to diff --git a/packages/jupyter-ai/src/handler.ts b/packages/jupyter-ai/src/handler.ts index ec91c1f3f..b397d34ad 100644 --- a/packages/jupyter-ai/src/handler.ts +++ b/packages/jupyter-ai/src/handler.ts @@ -104,6 +104,17 @@ export namespace AiService { messages: ChatMessage[]; }; + export enum ConfigErrorType { + CRITICAL = 'Critical', + WARNING = 'Warning' + } + + export type ConfigError = { + error_type: ConfigErrorType; + message: string; + details?: string; + }; + export type DescribeConfigResponse = { model_provider_id: string | null; embeddings_provider_id: string | null; @@ -111,6 +122,7 @@ export namespace AiService { send_with_shift_enter: boolean; fields: Record>; last_read: number; + config_errors?: ConfigError[]; }; export type UpdateConfigRequest = { From 742f20bc9c6a10f84ccb9479f73a2a560b4140ae Mon Sep 17 00:00:00 2001 From: Andrii Ieroshenko Date: Fri, 8 Dec 2023 10:43:40 -0800 Subject: [PATCH 10/10] initialize config manager with errors --- .../jupyter-ai/jupyter_ai/config_manager.py | 104 +++++++++++------- packages/jupyter-ai/jupyter_ai/extension.py | 38 +++++-- packages/jupyter-ai/src/components/chat.tsx | 34 ++++-- packages/jupyter-ai/src/index.ts | 24 +++- .../jupyter-ai/src/widgets/chat-sidebar.tsx | 2 +- 5 files changed, 135 insertions(+), 67 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index c937ff614..325be3dd5 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -164,19 +164,29 @@ def _init_validator(self) -> Validator: self.validator = Validator(schema) def _init_config(self): - try: - if os.path.exists(self.config_path): - self._process_existing_config() - else: - self._create_default_config() - except ValidationError as e: - self._handle_validation_error(e) + # try: + if os.path.exists(self.config_path): + self._process_existing_config() + else: + self._create_default_config() + # except ValidationError as e: + # self._handle_validation_error(e) + # self._config = GlobalConfig( + # send_with_shift_enter=False, fields={}, api_keys={} + # ) def _process_existing_config(self): with open(self.config_path, encoding="utf-8") as f: - config = GlobalConfig(**json.loads(f.read())) - self._validate_lm_em_id(config) + raw_config = json.loads(f.read()) + + validated_raw_config = self._validate_lm_em_id(raw_config) + + try: + config = GlobalConfig(**validated_raw_config) self._write_config(config) + except ValidationError as e: + corrected_config = self._handle_validation_error(e, validated_raw_config) + self._write_config(corrected_config) def _create_default_config(self): properties = self.validator.schema.get("properties", {}) @@ -187,16 +197,16 @@ def _create_default_config(self): default_config = GlobalConfig(**field_dict) self._write_config(default_config) - def _validate_lm_em_id(self, config): - lm_id = config.model_provider_id - em_id = config.embeddings_provider_id + def _validate_lm_em_id(self, raw_config): + lm_id = raw_config.get("model_provider_id") + em_id = raw_config.get("embeddings_provider_id") # if the currently selected language or embedding model are # forbidden, set them to `None` and log a warning. if lm_id is not None and not self._validate_model(lm_id, raise_exc=False): warning_message = f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None." self.log.warning(warning_message) - config.model_provider_id = None + raw_config["model_provider_id"] = None self._config_errors.append( ConfigErrorModel( error_type=ConfigErrorType.WARNING, message=warning_message @@ -206,7 +216,7 @@ def _validate_lm_em_id(self, config): if em_id is not None and not self._validate_model(em_id, raise_exc=False): warning_message = f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None." self.log.warning(warning_message) - config.embeddings_provider_id = None + raw_config["embeddings_provider_id"] = None self._config_errors.append( ConfigErrorModel( error_type=ConfigErrorType.WARNING, message=warning_message @@ -220,7 +230,7 @@ def _validate_lm_em_id(self, config): f"No language model is associated with '{lm_id}'. Setting to None." ) self.log.warning(warning_message) - config.model_provider_id = None + raw_config["model_provider_id"] = None self._config_errors.append( ConfigErrorModel( error_type=ConfigErrorType.WARNING, message=warning_message @@ -232,28 +242,43 @@ def _validate_lm_em_id(self, config): f"No embedding model is associated with '{em_id}'. Setting to None." ) self.log.warning(warning_message) - config.embeddings_provider_id = None + raw_config["embeddings_provider_id"] = None self._config_errors.append( ConfigErrorModel( error_type=ConfigErrorType.WARNING, message=warning_message ) ) - # re-write to the file to validate the config and apply any - # updates to the config file immediately - self._write_config(config) - - def _handle_validation_error(self, e: ValidationError): - formatted_error = _format_validation_errors(e) - error_message = "Configuration validation failed" - self._config_errors.append( - ConfigErrorModel( - error_type=ConfigErrorType.CRITICAL, - message=error_message, - details=formatted_error, - ) - ) - self.log.error(f"{error_message}: {formatted_error}") + return raw_config + + def _handle_validation_error(self, e: ValidationError, raw_config): + # Extract default values from schema + properties = self.validator.schema.get("properties", {}) + field_list = GlobalConfig.__fields__.keys() + default_values = { + field: properties.get(field).get("default") for field in field_list + } + + # Apply default values to erroneous fields + for error in e.errors(): + field = error["loc"][0] + if field in default_values: + raw_config[field] = default_values[field] + warning_message = f"Error in '{field}': {error['msg']}. Resetting to default value ('{default_values[field]}')." + self.log.warning(warning_message) + self._config_errors.append( + ConfigErrorModel( + error_type=ConfigErrorType.WARNING, message=warning_message + ) + ) + + # Create a config with default values for erroneous fields + config = GlobalConfig(**raw_config) + self.log.warning("\n\n\n Config \n\n\n") + + self.log.warning(config) + self._validate_config(config) + return config def _read_config(self) -> GlobalConfig: """Returns the user's current configuration as a GlobalConfig object. @@ -264,12 +289,15 @@ def _read_config(self) -> GlobalConfig: if last_write <= self._last_read: return self._config - with open(self.config_path, encoding="utf-8") as f: - self._last_read = time.time_ns() - raw_config = json.loads(f.read()) - config = GlobalConfig(**raw_config) - self._validate_config(config) - return config + with open(self.config_path, encoding="utf-8") as f: + self._last_read = time.time_ns() + raw_config = json.loads(f.read()) + try: + config = GlobalConfig(**raw_config) + except ValidationError as e: + config = self._handle_validation_error(e, raw_config) + self._validate_config(config) + return config def _validate_config(self, config: GlobalConfig): """Method used to validate the configuration. This is called after every @@ -414,7 +442,7 @@ def update_config(self, config_update: UpdateConfigRequest): def get_config(self): config = self._read_config() config_dict = config.dict(exclude_unset=True) - api_key_names = list(config_dict.pop("api_keys").keys()) + api_key_names = list(config_dict.pop("api_keys", {}).keys()) return DescribeConfigResponse( **config_dict, api_keys=api_key_names, diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 3dc8bbb40..50a91d1e1 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -28,14 +28,7 @@ class AiExtension(ExtensionApp): name = "jupyter_ai" - handlers = [ - (r"api/ai/api_keys/(?P\w+)", ApiKeysHandler), - (r"api/ai/config/?", GlobalConfigHandler), - (r"api/ai/chats/?", RootChatHandler), - (r"api/ai/chats/history?", ChatHistoryHandler), - (r"api/ai/providers?", ModelProviderHandler), - (r"api/ai/providers/embeddings?", EmbeddingsModelProviderHandler), - ] + handlers = [(r"api/ai/config/?", GlobalConfigHandler)] allowed_providers = List( Unicode(), @@ -139,15 +132,24 @@ def initialize_settings(self): else: # Log the error and proceed with limited functionality self.log.error(f"Configuration errors detected: {config_errors}") - # TODO: self._initialize_limited_functionality() + self._initialize_limited_functionality(config_errors) - self.log.info("Registered providers.") self.log.info(f"Registered {self.name} server extension") latency_ms = round((time.time() - start) * 1000) self.log.info(f"Initialized Jupyter AI server extension in {latency_ms} ms.") def _initialize_full_functionality(self): + self.handlers.extend( + [ + (r"api/ai/api_keys/(?P\w+)", ApiKeysHandler), + (r"api/ai/chats/?", RootChatHandler), + (r"api/ai/chats/history?", ChatHistoryHandler), + (r"api/ai/providers?", ModelProviderHandler), + (r"api/ai/providers/embeddings?", EmbeddingsModelProviderHandler), + ] + ) + # Store chat clients in a dictionary self.settings["chat_clients"] = {} self.settings["jai_root_chat_handlers"] = {} @@ -204,5 +206,21 @@ def _initialize_full_functionality(self): "/help": help_chat_handler, } + self.log.info("Registered providers.") + + def _initialize_limited_functionality(self, config_errors): + """ + Initialize the extension with limited functionality due to configuration errors. + """ + self.log.warning( + "Initializing Jupyter AI extension with limited functionality due to configuration errors." + ) + + # Capture configuration error details + config_errors = self.settings["jai_config_manager"].get_config_errors() + self.settings["config_errors"] = config_errors + + self.settings["jai_chat_handlers"] = [] + async def _get_dask_client(self): return DaskClient(processes=False, asynchronous=True) diff --git a/packages/jupyter-ai/src/components/chat.tsx b/packages/jupyter-ai/src/components/chat.tsx index 83ba69c42..e87e0eb13 100644 --- a/packages/jupyter-ai/src/components/chat.tsx +++ b/packages/jupyter-ai/src/components/chat.tsx @@ -20,7 +20,7 @@ import { CollaboratorsContextProvider } from '../contexts/collaborators-context' import { ScrollContainer } from './scroll-container'; type ChatBodyProps = { - chatHandler: ChatHandler; + chatHandler: ChatHandler | null; setChatView: (view: ChatView) => void; }; @@ -43,12 +43,22 @@ function ChatBody({ useEffect(() => { async function fetchHistory() { try { - const [history, config] = await Promise.all([ - chatHandler.getHistory(), - AiService.getConfig() - ]); + const config = await AiService.getConfig(); setSendWithShiftEnter(config.send_with_shift_enter ?? false); - setMessages(history.messages); + + // Check if there are critical errors + const hasCriticalErrors = config.config_errors?.some( + error => error.error_type === AiService.ConfigErrorType.CRITICAL + ); + console.log('\n\n\n *** \n\n\n'); + console.log(hasCriticalErrors); + if (!hasCriticalErrors && chatHandler) { + const history = await chatHandler.getHistory(); + setMessages(history.messages); + } else { + setMessages([]); + } + if (!config.model_provider_id) { setShowWelcomeMessage(true); } @@ -78,9 +88,9 @@ function ChatBody({ setMessages(messageGroups => [...messageGroups, message]); } - chatHandler.addListener(handleChatEvents); + chatHandler?.addListener(handleChatEvents); return function cleanup() { - chatHandler.removeListener(handleChatEvents); + chatHandler?.removeListener(handleChatEvents); }; }, [chatHandler]); @@ -96,18 +106,18 @@ function ChatBody({ : ''); // send message to backend - const messageId = await chatHandler.sendMessage({ prompt }); + const messageId = await chatHandler?.sendMessage({ prompt }); // await reply from agent // no need to append to messageGroups state variable, since that's already // handled in the effect hooks. - const reply = await chatHandler.replyFor(messageId); + const reply = await chatHandler?.replyFor(messageId ?? ''); if (replaceSelection && selection) { const { cellId, ...selectionProps } = selection; replaceSelectionFn({ ...selectionProps, ...(cellId && { cellId }), - text: reply.body + text: reply?.body ?? '' }); } }; @@ -187,7 +197,7 @@ function ChatBody({ export type ChatProps = { selectionWatcher: SelectionWatcher; - chatHandler: ChatHandler; + chatHandler: ChatHandler | null; globalAwareness: Awareness | null; chatView?: ChatView; }; diff --git a/packages/jupyter-ai/src/index.ts b/packages/jupyter-ai/src/index.ts index e48e2b211..c8e57bd93 100644 --- a/packages/jupyter-ai/src/index.ts +++ b/packages/jupyter-ai/src/index.ts @@ -12,6 +12,7 @@ import { buildChatSidebar } from './widgets/chat-sidebar'; import { SelectionWatcher } from './selection-watcher'; import { ChatHandler } from './chat_handler'; import { buildErrorWidget } from './widgets/chat-error'; +import { AiService } from './handler'; export type DocumentTracker = IWidgetTracker; @@ -32,14 +33,25 @@ const plugin: JupyterFrontEndPlugin = { */ const selectionWatcher = new SelectionWatcher(app.shell); - /** - * Initialize chat handler, open WS connection - */ - const chatHandler = new ChatHandler(); - let chatWidget: ReactWidget | null = null; + let chatHandler: ChatHandler | null = null; + try { - await chatHandler.initialize(); + // Fetch configuration to check for critical errors + const config = await AiService.getConfig(); + console.log('\n\n\n *** \n\n\n'); + console.log(config.config_errors); + const hasCriticalErrors = config.config_errors?.some( + error => error.error_type === AiService.ConfigErrorType.CRITICAL + ); + + if (!hasCriticalErrors) { + /** + * Initialize chat handler, open WS connection + */ + chatHandler = new ChatHandler(); + await chatHandler.initialize(); + } chatWidget = buildChatSidebar( selectionWatcher, chatHandler, diff --git a/packages/jupyter-ai/src/widgets/chat-sidebar.tsx b/packages/jupyter-ai/src/widgets/chat-sidebar.tsx index 8bc7df12c..abcd81ba4 100644 --- a/packages/jupyter-ai/src/widgets/chat-sidebar.tsx +++ b/packages/jupyter-ai/src/widgets/chat-sidebar.tsx @@ -9,7 +9,7 @@ import { ChatHandler } from '../chat_handler'; export function buildChatSidebar( selectionWatcher: SelectionWatcher, - chatHandler: ChatHandler, + chatHandler: ChatHandler | null, globalAwareness: Awareness | null ): ReactWidget { const ChatWidget = ReactWidget.create(