diff --git a/components/renku_data_services/base_models/core.py b/components/renku_data_services/base_models/core.py index fe4f32fe7..484731e15 100644 --- a/components/renku_data_services/base_models/core.py +++ b/components/renku_data_services/base_models/core.py @@ -212,3 +212,17 @@ class Authenticator(Protocol[AnyAPIUser]): async def authenticate(self, access_token: str, request: Request) -> AnyAPIUser: """Validates the user credentials (i.e. we can say that the user is a valid Renku user).""" ... + + +@dataclass(frozen=True, eq=True, kw_only=True) +class Null: + """Parent class for distinguishing between None values.""" + + value: None = field(default=None, init=False, repr=False) + + +@dataclass(frozen=True, eq=True, kw_only=True) +class Reset(Null): + """Used to indicate a None value that has been deliberately set by the user or caller.""" + + ... diff --git a/components/renku_data_services/session/blueprints.py b/components/renku_data_services/session/blueprints.py index 75fbf25a4..b42892571 100644 --- a/components/renku_data_services/session/blueprints.py +++ b/components/renku_data_services/session/blueprints.py @@ -11,7 +11,7 @@ import renku_data_services.base_models as base_models from renku_data_services.base_api.auth import authenticate, validate_path_project_id from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint -from renku_data_services.session import apispec, models +from renku_data_services.session import apispec, converters, models from renku_data_services.session.db import SessionRepository @@ -75,9 +75,11 @@ def patch(self) -> BlueprintFactoryResponse: async def _patch( _: Request, user: base_models.APIUser, environment_id: ULID, body: apispec.EnvironmentPatch ) -> JSONResponse: - body_dict = body.model_dump(exclude_none=True) + update = converters.environment_update_from_patch(body) environment = await self.session_repo.update_environment( - user=user, environment_id=environment_id, **body_dict + user=user, + environment_id=environment_id, + update=update, ) return json(apispec.Environment.model_validate(environment).model_dump(exclude_none=True, mode="json")) @@ -172,34 +174,14 @@ def patch(self) -> BlueprintFactoryResponse: async def _patch( _: Request, user: base_models.APIUser, launcher_id: ULID, body: apispec.SessionLauncherPatch ) -> JSONResponse: - body_dict = body.model_dump(exclude_none=True, mode="json") async with self.session_repo.session_maker() as session, session.begin(): current_launcher = await self.session_repo.get_launcher(user, launcher_id) - new_env: models.UnsavedEnvironment | None = None - if ( - isinstance(body.environment, apispec.EnvironmentPatchInLauncher) - and current_launcher.environment.environment_kind == models.EnvironmentKind.GLOBAL - and body.environment.environment_kind == apispec.EnvironmentKind.CUSTOM - ): - # This means that the global environment is being swapped for a custom one, - # so we have to create a brand new environment, but we have to validate here. - validated_env = apispec.EnvironmentPostInLauncher.model_validate(body_dict.pop("environment")) - new_env = models.UnsavedEnvironment( - name=validated_env.name, - description=validated_env.description, - container_image=validated_env.container_image, - default_url=validated_env.default_url, - port=validated_env.port, - working_directory=PurePosixPath(validated_env.working_directory), - mount_directory=PurePosixPath(validated_env.mount_directory), - uid=validated_env.uid, - gid=validated_env.gid, - environment_kind=models.EnvironmentKind(validated_env.environment_kind.value), - args=validated_env.args, - command=validated_env.command, - ) + update = converters.launcher_update_from_patch(body, current_launcher) launcher = await self.session_repo.update_launcher( - user=user, launcher_id=launcher_id, new_custom_environment=new_env, session=session, **body_dict + user=user, + launcher_id=launcher_id, + session=session, + update=update, ) return json(apispec.SessionLauncher.model_validate(launcher).model_dump(exclude_none=True, mode="json")) diff --git a/components/renku_data_services/session/converters.py b/components/renku_data_services/session/converters.py new file mode 100644 index 000000000..bbdade105 --- /dev/null +++ b/components/renku_data_services/session/converters.py @@ -0,0 +1,83 @@ +"""Code used to convert from/to apispec and models.""" + +from pathlib import PurePosixPath + +from renku_data_services.base_models.core import Reset +from renku_data_services.session import apispec, models + + +def environment_update_from_patch(data: apispec.EnvironmentPatch) -> models.EnvironmentUpdate: + """Create an update object from an apispec or any other pydantic model.""" + data_dict = data.model_dump(exclude_unset=True, mode="json") + working_directory: PurePosixPath | None = None + if data.working_directory is not None: + working_directory = PurePosixPath(data.working_directory) + mount_directory: PurePosixPath | None = None + if data.mount_directory is not None: + mount_directory = PurePosixPath(data.mount_directory) + # NOTE: If the args or command are present in the data_dict and they are None they were passed in by the user. + # The None specifically passed by the user indicates that the value should be removed from the DB. + args = Reset() if "args" in data_dict and data_dict["args"] is None else data.args + command = Reset() if "command" in data_dict and data_dict["command"] is None else data.command + return models.EnvironmentUpdate( + name=data.name, + description=data.description, + container_image=data.container_image, + default_url=data.default_url, + port=data.port, + working_directory=working_directory, + mount_directory=mount_directory, + uid=data.uid, + gid=data.gid, + args=args, + command=command, + ) + + +def launcher_update_from_patch( + data: apispec.SessionLauncherPatch, + current_launcher: models.SessionLauncher | None = None, +) -> models.SessionLauncherUpdate: + """Create an update object from an apispec or any other pydantic model.""" + data_dict = data.model_dump(exclude_unset=True, mode="json") + environment: str | models.EnvironmentUpdate | models.UnsavedEnvironment | None = None + if ( + isinstance(data.environment, apispec.EnvironmentPatchInLauncher) + and current_launcher is not None + and current_launcher.environment.environment_kind == models.EnvironmentKind.GLOBAL + and data.environment.environment_kind == apispec.EnvironmentKind.CUSTOM + ): + # This means that the global environment is being swapped for a custom one, + # so we have to create a brand new environment, but we have to validate here. + validated_env = apispec.EnvironmentPostInLauncher.model_validate(data_dict["environment"]) + environment = models.UnsavedEnvironment( + name=validated_env.name, + description=validated_env.description, + container_image=validated_env.container_image, + default_url=validated_env.default_url, + port=validated_env.port, + working_directory=PurePosixPath(validated_env.working_directory), + mount_directory=PurePosixPath(validated_env.mount_directory), + uid=validated_env.uid, + gid=validated_env.gid, + environment_kind=models.EnvironmentKind(validated_env.environment_kind.value), + args=validated_env.args, + command=validated_env.command, + ) + elif isinstance(data.environment, apispec.EnvironmentPatchInLauncher): + environment = environment_update_from_patch(data.environment) + elif isinstance(data.environment, apispec.EnvironmentIdOnlyPatch): + environment = data.environment.id + resource_class_id: int | None | Reset = None + if "resource_class_id" in data_dict and data_dict["resource_class_id"] is None: + # NOTE: This means that the resource class set in the DB should be removed so that the + # default resource class currently set in the CRC will be used. + resource_class_id = Reset() + else: + resource_class_id = data_dict.get("resource_class_id") + return models.SessionLauncherUpdate( + name=data_dict.get("name"), + description=data_dict.get("description"), + environment=environment, + resource_class_id=resource_class_id, + ) diff --git a/components/renku_data_services/session/db.py b/components/renku_data_services/session/db.py index 417820e40..9a1215136 100644 --- a/components/renku_data_services/session/db.py +++ b/components/renku_data_services/session/db.py @@ -5,7 +5,6 @@ from collections.abc import Callable from contextlib import AbstractAsyncContextManager, nullcontext from datetime import UTC, datetime -from typing import Any from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -15,6 +14,7 @@ from renku_data_services import errors from renku_data_services.authz.authz import Authz, ResourceType from renku_data_services.authz.models import Scope +from renku_data_services.base_models.core import Reset from renku_data_services.crc.db import ResourcePoolRepository from renku_data_services.session import models from renku_data_services.session import orm as schemas @@ -101,53 +101,59 @@ async def insert_environment( env = await self.__insert_environment(user, session, new_environment) return env.dump() - async def __update_environment( + def __update_environment( self, - user: base_models.APIUser, - session: AsyncSession, - environment_id: ULID, - kind: models.EnvironmentKind, - **kwargs: dict, - ) -> models.Environment: - res = await session.scalars( - select(schemas.EnvironmentORM) - .where(schemas.EnvironmentORM.id == str(environment_id)) - .where(schemas.EnvironmentORM.environment_kind == kind.value) - ) - environment = res.one_or_none() - if environment is None: - raise errors.MissingResourceError(message=f"Session environment with id '{environment_id}' does not exist.") - - for key, value in kwargs.items(): - # NOTE: Only some fields can be edited - if key in [ - "name", - "description", - "container_image", - "default_url", - "port", - "working_directory", - "mount_directory", - "uid", - "gid", - "args", - "command", - ]: - setattr(environment, key, value) - - return environment.dump() + environment: schemas.EnvironmentORM, + update: models.EnvironmentUpdate, + ) -> None: + # NOTE: this is more verbose than a loop and setattr but this way we get mypy type checks + if update.name is not None: + environment.name = update.name + if update.description is not None: + environment.description = update.description + if update.container_image is not None: + environment.container_image = update.container_image + if update.default_url is not None: + environment.default_url = update.default_url + if update.port is not None: + environment.port = update.port + if update.working_directory is not None: + environment.working_directory = update.working_directory + if update.mount_directory is not None: + environment.mount_directory = update.mount_directory + if update.uid is not None: + environment.uid = update.uid + if update.gid is not None: + environment.gid = update.gid + if isinstance(update.args, Reset): + environment.args = None + elif isinstance(update.args, list): + environment.args = update.args + if isinstance(update.command, Reset): + environment.command = None + elif isinstance(update.command, list): + environment.command = update.command async def update_environment( - self, user: base_models.APIUser, environment_id: ULID, **kwargs: dict + self, user: base_models.APIUser, environment_id: ULID, update: models.EnvironmentUpdate ) -> models.Environment: """Update a global session environment entry.""" if not user.is_admin: raise errors.UnauthorizedError(message="You do not have the required permissions for this operation.") async with self.session_maker() as session, session.begin(): - return await self.__update_environment( - user, session, environment_id, models.EnvironmentKind.GLOBAL, **kwargs + res = await session.scalars( + select(schemas.EnvironmentORM) + .where(schemas.EnvironmentORM.id == str(environment_id)) + .where(schemas.EnvironmentORM.environment_kind == models.EnvironmentKind.GLOBAL) ) + environment = res.one_or_none() + if environment is None: + raise errors.MissingResourceError( + message=f"Session environment with id '{environment_id}' does not exist." + ) + self.__update_environment(environment, update) + return environment.dump() async def delete_environment(self, user: base_models.APIUser, environment_id: ULID) -> None: """Delete a global session environment entry.""" @@ -300,9 +306,8 @@ async def update_launcher( self, user: base_models.APIUser, launcher_id: ULID, - new_custom_environment: models.UnsavedEnvironment | None, + update: models.SessionLauncherUpdate, session: AsyncSession | None = None, - **kwargs: Any, ) -> models.SessionLauncher: """Update a session launcher entry.""" if not user.is_authenticated or user.id is None: @@ -336,8 +341,8 @@ async def update_launcher( if not authorized: raise errors.ForbiddenError(message="You do not have the required permissions for this operation.") - resource_class_id = kwargs.get("resource_class_id") - if resource_class_id is not None: + resource_class_id = update.resource_class_id + if isinstance(resource_class_id, int): res = await session.scalars( select(schemas.ResourceClassORM).where(schemas.ResourceClassORM.id == resource_class_id) ) @@ -354,17 +359,20 @@ async def update_launcher( message=f"You do not have access to resource class with id '{resource_class_id}'." ) - for key, value in kwargs.items(): - # NOTE: Only some fields can be updated. - if key in [ - "name", - "description", - "resource_class_id", - ]: - setattr(launcher, key, value) - - env_payload = kwargs.get("environment", {}) - await self.__update_launcher_environment(user, launcher, session, new_custom_environment, **env_payload) + # NOTE: Only some fields can be updated. + if update.name is not None: + launcher.name = update.name + if update.description is not None: + launcher.description = update.description + if isinstance(update.resource_class_id, int): + launcher.resource_class_id = update.resource_class_id + elif isinstance(update.resource_class_id, Reset): + launcher.resource_class_id = None + + if update.environment is None: + return launcher.dump() + + await self.__update_launcher_environment(user, launcher, session, update.environment) return launcher.dump() async def __update_launcher_environment( @@ -372,12 +380,11 @@ async def __update_launcher_environment( user: base_models.APIUser, launcher: schemas.SessionLauncherORM, session: AsyncSession, - new_custom_environment: models.UnsavedEnvironment | None, - **kwargs: Any, + update: models.EnvironmentUpdate | models.UnsavedEnvironment | str, ) -> None: current_env_kind = launcher.environment.environment_kind - match new_custom_environment, current_env_kind, kwargs: - case None, _, {"id": env_id, **nothing_else} if len(nothing_else) == 0: + match update, current_env_kind: + case str() as env_id, _: # The environment in the launcher is set via ID, the new ID has to refer # to an environment that is GLOBAL. old_environment = launcher.environment @@ -404,29 +411,11 @@ async def __update_launcher_environment( # We remove the custom environment to avoid accumulating custom environments that are not associated # with any launchers. await session.delete(old_environment) - case None, models.EnvironmentKind.CUSTOM, {**rest} if ( - rest.get("environment_kind") is None - or rest.get("environment_kind") == models.EnvironmentKind.CUSTOM.value - ): + case models.EnvironmentUpdate(), models.EnvironmentKind.CUSTOM: # Custom environment being updated - for key, val in rest.items(): - # NOTE: Only some fields can be updated. - if key in [ - "name", - "description", - "container_image", - "default_url", - "port", - "working_directory", - "mount_directory", - "uid", - "gid", - "args", - "command", - ]: - setattr(launcher.environment, key, val) - case models.UnsavedEnvironment(), models.EnvironmentKind.GLOBAL, {**nothing_else} if ( - len(nothing_else) == 0 and new_custom_environment.environment_kind == models.EnvironmentKind.CUSTOM + self.__update_environment(launcher.environment, update) + case models.UnsavedEnvironment() as new_custom_environment, models.EnvironmentKind.GLOBAL if ( + new_custom_environment.environment_kind == models.EnvironmentKind.CUSTOM ): # Global environment replaced by a custom one new_env = await self.__insert_environment(user, session, new_custom_environment) diff --git a/components/renku_data_services/session/models.py b/components/renku_data_services/session/models.py index 6dcff46c2..9ce6f0fbe 100644 --- a/components/renku_data_services/session/models.py +++ b/components/renku_data_services/session/models.py @@ -8,6 +8,7 @@ from ulid import ULID from renku_data_services import errors +from renku_data_services.base_models.core import Reset class EnvironmentKind(StrEnum): @@ -70,6 +71,23 @@ class Environment(BaseEnvironment): created_by: str +@dataclass(kw_only=True, frozen=True, eq=True) +class EnvironmentUpdate: + """Model for the update of some or all parts of an environment.""" + + name: str | None = None + description: str | None = None + container_image: str | None = None + default_url: str | None = None + port: int | None = None + working_directory: PurePosixPath | None = None + mount_directory: PurePosixPath | None = None + uid: int | None = None + gid: int | None = None + args: list[str] | None | Reset = None + command: list[str] | None | Reset = None + + @dataclass(frozen=True, eq=True, kw_only=True) class BaseSessionLauncher: """Session launcher model.""" @@ -99,3 +117,15 @@ class SessionLauncher(BaseSessionLauncher): creation_date: datetime created_by: str environment: Environment + + +@dataclass(frozen=True, eq=True, kw_only=True) +class SessionLauncherUpdate: + """Model for the update of a session launcher.""" + + name: str | None = None + description: str | None = None + # NOTE: When unsaved environment is used it means a brand new environment should be created for the + # launcher with the update of the launcher. + environment: str | EnvironmentUpdate | UnsavedEnvironment | None = None + resource_class_id: int | None | Reset = None diff --git a/test/bases/renku_data_services/data_api/test_sessions.py b/test/bases/renku_data_services/data_api/test_sessions.py index 732e5001a..d920cba9d 100644 --- a/test/bases/renku_data_services/data_api/test_sessions.py +++ b/test/bases/renku_data_services/data_api/test_sessions.py @@ -3,6 +3,7 @@ import os import shutil from asyncio import AbstractEventLoop +from collections.abc import Iterator from typing import Any import pytest @@ -18,8 +19,8 @@ os.environ["KUBECONFIG"] = ".k3d-config.yaml" -@pytest.fixture(scope="module", autouse=True) -def cluster() -> K3DCluster: +@pytest.fixture(scope="module") +def cluster() -> Iterator[K3DCluster]: if shutil.which("k3d") is None: pytest.skip("Requires k3d for cluster creation") @@ -172,10 +173,14 @@ async def test_patch_session_environment( env = await create_session_environment("Environment 1") environment_id = env["id"] + command = ["python", "test.py"] + args = ["arg1", "arg2"] payload = { "name": "New name", "description": "New description.", "container_image": "new_image:new_tag", + "command": command, + "args": args, } _, res = await sanic_client.patch(f"/api/data/environments/{environment_id}", headers=admin_headers, json=payload) @@ -185,6 +190,14 @@ async def test_patch_session_environment( assert res.json.get("name") == "New name" assert res.json.get("description") == "New description." assert res.json.get("container_image") == "new_image:new_tag" + assert res.json.get("args") == args + assert res.json.get("command") == command + + # Test that patching with None will reset the command and args + payload = {"args": None, "command": None} + _, res = await sanic_client.patch(f"/api/data/environments/{environment_id}", headers=admin_headers, json=payload) + assert res.json.get("args") is None + assert res.json.get("command") is None @pytest.mark.asyncio @@ -540,6 +553,7 @@ async def test_starting_session_anonymous( admin_headers, launch_session, anonymous_user_headers, + cluster, ) -> None: _, res = await sanic_client.post( "/api/data/resource_pools", diff --git a/test/conftest.py b/test/conftest.py index fdbd357d3..5f5658a64 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -115,6 +115,8 @@ def secrets_key_pair(monkeypatch, tmp_path) -> None: @pytest.fixture def app_config(authz_config, db_config, monkeypatch, worker_id, secrets_key_pair) -> Generator[DataConfig, None, None]: monkeypatch.setenv("MAX_PINNED_PROJECTS", "5") + monkeypatch.setenv("NB_SERVER_OPTIONS__DEFAULTS_PATH", "server_defaults.json") + monkeypatch.setenv("NB_SERVER_OPTIONS__UI_CHOICES_PATH", "server_options.json") config = DataConfig.from_env() app_name = "app_" + str(ULID()).lower() + "_" + worker_id