Skip to content

Commit

Permalink
fix: allow session launcher parameters to be reset
Browse files Browse the repository at this point in the history
Allows the API to accept None as input for args, command and the session
launcher resource class ID so that they can be reset to their defaults
in patch endpoints.
  • Loading branch information
olevski committed Oct 1, 2024
1 parent c78b345 commit 659480c
Show file tree
Hide file tree
Showing 7 changed files with 223 additions and 109 deletions.
14 changes: 14 additions & 0 deletions components/renku_data_services/base_models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,17 @@ class Authenticator(Protocol[AnyAPIUser]):
async def authenticate(self, access_token: str, request: Request) -> AnyAPIUser:
"""Validates the user credentials (i.e. we can say that the user is a valid Renku user)."""
...


@dataclass(frozen=True, eq=True, kw_only=True)
class Null:
"""Parent class for distinguishing between None values."""

value: None = field(default=None, init=False, repr=False)


@dataclass(frozen=True, eq=True, kw_only=True)
class Reset(Null):
"""Used to indicate a None value that has been deliberately set by the user or caller."""

...
38 changes: 10 additions & 28 deletions components/renku_data_services/session/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import renku_data_services.base_models as base_models
from renku_data_services.base_api.auth import authenticate, validate_path_project_id
from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint
from renku_data_services.session import apispec, models
from renku_data_services.session import apispec, converters, models
from renku_data_services.session.db import SessionRepository


Expand Down Expand Up @@ -75,9 +75,11 @@ def patch(self) -> BlueprintFactoryResponse:
async def _patch(
_: Request, user: base_models.APIUser, environment_id: ULID, body: apispec.EnvironmentPatch
) -> JSONResponse:
body_dict = body.model_dump(exclude_none=True)
update = converters.environment_update_from_patch(body)
environment = await self.session_repo.update_environment(
user=user, environment_id=environment_id, **body_dict
user=user,
environment_id=environment_id,
update=update,
)
return json(apispec.Environment.model_validate(environment).model_dump(exclude_none=True, mode="json"))

Expand Down Expand Up @@ -172,34 +174,14 @@ def patch(self) -> BlueprintFactoryResponse:
async def _patch(
_: Request, user: base_models.APIUser, launcher_id: ULID, body: apispec.SessionLauncherPatch
) -> JSONResponse:
body_dict = body.model_dump(exclude_none=True, mode="json")
async with self.session_repo.session_maker() as session, session.begin():
current_launcher = await self.session_repo.get_launcher(user, launcher_id)
new_env: models.UnsavedEnvironment | None = None
if (
isinstance(body.environment, apispec.EnvironmentPatchInLauncher)
and current_launcher.environment.environment_kind == models.EnvironmentKind.GLOBAL
and body.environment.environment_kind == apispec.EnvironmentKind.CUSTOM
):
# This means that the global environment is being swapped for a custom one,
# so we have to create a brand new environment, but we have to validate here.
validated_env = apispec.EnvironmentPostInLauncher.model_validate(body_dict.pop("environment"))
new_env = models.UnsavedEnvironment(
name=validated_env.name,
description=validated_env.description,
container_image=validated_env.container_image,
default_url=validated_env.default_url,
port=validated_env.port,
working_directory=PurePosixPath(validated_env.working_directory),
mount_directory=PurePosixPath(validated_env.mount_directory),
uid=validated_env.uid,
gid=validated_env.gid,
environment_kind=models.EnvironmentKind(validated_env.environment_kind.value),
args=validated_env.args,
command=validated_env.command,
)
update = converters.launcher_update_from_patch(body, current_launcher)
launcher = await self.session_repo.update_launcher(
user=user, launcher_id=launcher_id, new_custom_environment=new_env, session=session, **body_dict
user=user,
launcher_id=launcher_id,
session=session,
update=update,
)
return json(apispec.SessionLauncher.model_validate(launcher).model_dump(exclude_none=True, mode="json"))

Expand Down
83 changes: 83 additions & 0 deletions components/renku_data_services/session/converters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""Code used to convert from/to apispec and models."""

from pathlib import PurePosixPath

from renku_data_services.base_models.core import Reset
from renku_data_services.session import apispec, models


def environment_update_from_patch(data: apispec.EnvironmentPatch) -> models.EnvironmentUpdate:
"""Create an update object from an apispec or any other pydantic model."""
data_dict = data.model_dump(exclude_unset=True, mode="json")
working_directory: PurePosixPath | None = None
if data.working_directory is not None:
working_directory = PurePosixPath(data.working_directory)
mount_directory: PurePosixPath | None = None
if data.mount_directory is not None:
mount_directory = PurePosixPath(data.mount_directory)
# NOTE: If the args or command are present in the data_dict and they are None they were passed in by the user.
# The None specifically passed by the user indicates that the value should be removed from the DB.
args = Reset() if "args" in data_dict and data_dict["args"] is None else data.args
command = Reset() if "command" in data_dict and data_dict["command"] is None else data.command
return models.EnvironmentUpdate(
name=data.name,
description=data.description,
container_image=data.container_image,
default_url=data.default_url,
port=data.port,
working_directory=working_directory,
mount_directory=mount_directory,
uid=data.uid,
gid=data.gid,
args=args,
command=command,
)


def launcher_update_from_patch(
data: apispec.SessionLauncherPatch,
current_launcher: models.SessionLauncher | None = None,
) -> models.SessionLauncherUpdate:
"""Create an update object from an apispec or any other pydantic model."""
data_dict = data.model_dump(exclude_unset=True, mode="json")
environment: str | models.EnvironmentUpdate | models.UnsavedEnvironment | None = None
if (
isinstance(data.environment, apispec.EnvironmentPatchInLauncher)
and current_launcher is not None
and current_launcher.environment.environment_kind == models.EnvironmentKind.GLOBAL
and data.environment.environment_kind == apispec.EnvironmentKind.CUSTOM
):
# This means that the global environment is being swapped for a custom one,
# so we have to create a brand new environment, but we have to validate here.
validated_env = apispec.EnvironmentPostInLauncher.model_validate(data_dict["environment"])
environment = models.UnsavedEnvironment(
name=validated_env.name,
description=validated_env.description,
container_image=validated_env.container_image,
default_url=validated_env.default_url,
port=validated_env.port,
working_directory=PurePosixPath(validated_env.working_directory),
mount_directory=PurePosixPath(validated_env.mount_directory),
uid=validated_env.uid,
gid=validated_env.gid,
environment_kind=models.EnvironmentKind(validated_env.environment_kind.value),
args=validated_env.args,
command=validated_env.command,
)
elif isinstance(data.environment, apispec.EnvironmentPatchInLauncher):
environment = environment_update_from_patch(data.environment)
elif isinstance(data.environment, apispec.EnvironmentIdOnlyPatch):
environment = data.environment.id
resource_class_id: int | None | Reset = None
if "resource_class_id" in data_dict and data_dict["resource_class_id"] is None:
# NOTE: This means that the resource class set in the DB should be removed so that the
# default resource class currently set in the CRC will be used.
resource_class_id = Reset()
else:
resource_class_id = data_dict.get("resource_class_id")
return models.SessionLauncherUpdate(
name=data_dict.get("name"),
description=data_dict.get("description"),
environment=environment,
resource_class_id=resource_class_id,
)
147 changes: 68 additions & 79 deletions components/renku_data_services/session/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from collections.abc import Callable
from contextlib import AbstractAsyncContextManager, nullcontext
from datetime import UTC, datetime
from typing import Any

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
Expand All @@ -15,6 +14,7 @@
from renku_data_services import errors
from renku_data_services.authz.authz import Authz, ResourceType
from renku_data_services.authz.models import Scope
from renku_data_services.base_models.core import Reset
from renku_data_services.crc.db import ResourcePoolRepository
from renku_data_services.session import models
from renku_data_services.session import orm as schemas
Expand Down Expand Up @@ -101,53 +101,59 @@ async def insert_environment(
env = await self.__insert_environment(user, session, new_environment)
return env.dump()

async def __update_environment(
def __update_environment(
self,
user: base_models.APIUser,
session: AsyncSession,
environment_id: ULID,
kind: models.EnvironmentKind,
**kwargs: dict,
) -> models.Environment:
res = await session.scalars(
select(schemas.EnvironmentORM)
.where(schemas.EnvironmentORM.id == str(environment_id))
.where(schemas.EnvironmentORM.environment_kind == kind.value)
)
environment = res.one_or_none()
if environment is None:
raise errors.MissingResourceError(message=f"Session environment with id '{environment_id}' does not exist.")

for key, value in kwargs.items():
# NOTE: Only some fields can be edited
if key in [
"name",
"description",
"container_image",
"default_url",
"port",
"working_directory",
"mount_directory",
"uid",
"gid",
"args",
"command",
]:
setattr(environment, key, value)

return environment.dump()
environment: schemas.EnvironmentORM,
update: models.EnvironmentUpdate,
) -> None:
# NOTE: this is more verbose than a loop and setattr but this way we get mypy type checks
if update.name is not None:
environment.name = update.name
if update.description is not None:
environment.description = update.description
if update.container_image is not None:
environment.container_image = update.container_image
if update.default_url is not None:
environment.default_url = update.default_url
if update.port is not None:
environment.port = update.port
if update.working_directory is not None:
environment.working_directory = update.working_directory
if update.mount_directory is not None:
environment.mount_directory = update.mount_directory
if update.uid is not None:
environment.uid = update.uid
if update.gid is not None:
environment.gid = update.gid
if isinstance(update.args, Reset):
environment.args = None
elif isinstance(update.args, list):
environment.args = update.args
if isinstance(update.command, Reset):
environment.command = None
elif isinstance(update.command, list):
environment.command = update.command

async def update_environment(
self, user: base_models.APIUser, environment_id: ULID, **kwargs: dict
self, user: base_models.APIUser, environment_id: ULID, update: models.EnvironmentUpdate
) -> models.Environment:
"""Update a global session environment entry."""
if not user.is_admin:
raise errors.UnauthorizedError(message="You do not have the required permissions for this operation.")

async with self.session_maker() as session, session.begin():
return await self.__update_environment(
user, session, environment_id, models.EnvironmentKind.GLOBAL, **kwargs
res = await session.scalars(
select(schemas.EnvironmentORM)
.where(schemas.EnvironmentORM.id == str(environment_id))
.where(schemas.EnvironmentORM.environment_kind == models.EnvironmentKind.GLOBAL)
)
environment = res.one_or_none()
if environment is None:
raise errors.MissingResourceError(
message=f"Session environment with id '{environment_id}' does not exist."
)
self.__update_environment(environment, update)
return environment.dump()

async def delete_environment(self, user: base_models.APIUser, environment_id: ULID) -> None:
"""Delete a global session environment entry."""
Expand Down Expand Up @@ -300,9 +306,8 @@ async def update_launcher(
self,
user: base_models.APIUser,
launcher_id: ULID,
new_custom_environment: models.UnsavedEnvironment | None,
update: models.SessionLauncherUpdate,
session: AsyncSession | None = None,
**kwargs: Any,
) -> models.SessionLauncher:
"""Update a session launcher entry."""
if not user.is_authenticated or user.id is None:
Expand Down Expand Up @@ -336,8 +341,8 @@ async def update_launcher(
if not authorized:
raise errors.ForbiddenError(message="You do not have the required permissions for this operation.")

resource_class_id = kwargs.get("resource_class_id")
if resource_class_id is not None:
resource_class_id = update.resource_class_id
if isinstance(resource_class_id, int):
res = await session.scalars(
select(schemas.ResourceClassORM).where(schemas.ResourceClassORM.id == resource_class_id)
)
Expand All @@ -354,30 +359,32 @@ async def update_launcher(
message=f"You do not have access to resource class with id '{resource_class_id}'."
)

for key, value in kwargs.items():
# NOTE: Only some fields can be updated.
if key in [
"name",
"description",
"resource_class_id",
]:
setattr(launcher, key, value)

env_payload = kwargs.get("environment", {})
await self.__update_launcher_environment(user, launcher, session, new_custom_environment, **env_payload)
# NOTE: Only some fields can be updated.
if update.name is not None:
launcher.name = update.name
if update.description is not None:
launcher.description = update.description
if isinstance(update.resource_class_id, int):
launcher.resource_class_id = update.resource_class_id
elif isinstance(update.resource_class_id, Reset):
launcher.resource_class_id = None

if update.environment is None:
return launcher.dump()

await self.__update_launcher_environment(user, launcher, session, update.environment)
return launcher.dump()

async def __update_launcher_environment(
self,
user: base_models.APIUser,
launcher: schemas.SessionLauncherORM,
session: AsyncSession,
new_custom_environment: models.UnsavedEnvironment | None,
**kwargs: Any,
update: models.EnvironmentUpdate | models.UnsavedEnvironment | str,
) -> None:
current_env_kind = launcher.environment.environment_kind
match new_custom_environment, current_env_kind, kwargs:
case None, _, {"id": env_id, **nothing_else} if len(nothing_else) == 0:
match update, current_env_kind:
case str() as env_id, _:
# The environment in the launcher is set via ID, the new ID has to refer
# to an environment that is GLOBAL.
old_environment = launcher.environment
Expand All @@ -404,29 +411,11 @@ async def __update_launcher_environment(
# We remove the custom environment to avoid accumulating custom environments that are not associated
# with any launchers.
await session.delete(old_environment)
case None, models.EnvironmentKind.CUSTOM, {**rest} if (
rest.get("environment_kind") is None
or rest.get("environment_kind") == models.EnvironmentKind.CUSTOM.value
):
case models.EnvironmentUpdate(), models.EnvironmentKind.CUSTOM:
# Custom environment being updated
for key, val in rest.items():
# NOTE: Only some fields can be updated.
if key in [
"name",
"description",
"container_image",
"default_url",
"port",
"working_directory",
"mount_directory",
"uid",
"gid",
"args",
"command",
]:
setattr(launcher.environment, key, val)
case models.UnsavedEnvironment(), models.EnvironmentKind.GLOBAL, {**nothing_else} if (
len(nothing_else) == 0 and new_custom_environment.environment_kind == models.EnvironmentKind.CUSTOM
self.__update_environment(launcher.environment, update)
case models.UnsavedEnvironment() as new_custom_environment, models.EnvironmentKind.GLOBAL if (
new_custom_environment.environment_kind == models.EnvironmentKind.CUSTOM
):
# Global environment replaced by a custom one
new_env = await self.__insert_environment(user, session, new_custom_environment)
Expand Down
Loading

0 comments on commit 659480c

Please sign in to comment.