Skip to content

Commit

Permalink
chore: use PurePosixPath type in sqlalchemy
Browse files Browse the repository at this point in the history
  • Loading branch information
olevski committed Aug 28, 2024
1 parent ed279fb commit 09ca8b3
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
mount_dir: str = "/home/jovyan/work"
uid: int = 1000
gid: int = 1000
port: int = 4180
port: int = 8888


def upgrade() -> None:
Expand Down
8 changes: 3 additions & 5 deletions components/renku_data_services/session/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ async def __insert_environment(
container_image=new_environment.container_image,
default_url=new_environment.default_url,
port=new_environment.port,
working_directory=new_environment.working_directory.as_posix(),
mount_directory=new_environment.mount_directory.as_posix(),
working_directory=new_environment.working_directory,
mount_directory=new_environment.mount_directory,
uid=new_environment.uid,
gid=new_environment.gid,
environment_kind=new_environment.environment_kind,
Expand Down Expand Up @@ -229,9 +229,7 @@ async def insert_launcher(
raise errors.UnauthorizedError(message="You do not have the required permissions for this operation.")

project_id = new_launcher.project_id
authorized = await self.project_authz.has_permission(
user, ResourceType.project, ULID.from_str(project_id), Scope.WRITE
)
authorized = await self.project_authz.has_permission(user, ResourceType.project, project_id, Scope.WRITE)
if not authorized:
raise errors.MissingResourceError(
message=f"Project with id '{project_id}' does not exist or you do not have access to it."
Expand Down
10 changes: 5 additions & 5 deletions components/renku_data_services/session/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from renku_data_services.crc.orm import ResourceClassORM
from renku_data_services.project.orm import ProjectORM
from renku_data_services.session import models
from renku_data_services.utils.sqlalchemy import ULIDType
from renku_data_services.utils.sqlalchemy import PurePosixPathType, ULIDType

metadata_obj = MetaData(schema="sessions") # Has to match alembic ini section name
JSONVariant = JSON().with_variant(JSONB(), "postgresql")
Expand Down Expand Up @@ -51,8 +51,8 @@ class EnvironmentORM(BaseORM):
"""Default URL path to open in a session."""

port: Mapped[int] = mapped_column("port")
working_directory: Mapped[str] = mapped_column("working_directory", String())
mount_directory: Mapped[str] = mapped_column("mount_directory", String())
working_directory: Mapped[PurePosixPath] = mapped_column("working_directory", PurePosixPathType)
mount_directory: Mapped[PurePosixPath] = mapped_column("mount_directory", PurePosixPathType)
uid: Mapped[int] = mapped_column("uid")
gid: Mapped[int] = mapped_column("gid")
environment_kind: Mapped[models.EnvironmentKind] = mapped_column("environment_kind")
Expand All @@ -72,8 +72,8 @@ def dump(self) -> models.Environment:
gid=self.gid,
uid=self.uid,
environment_kind=self.environment_kind,
mount_directory=PurePosixPath(self.mount_directory),
working_directory=PurePosixPath(self.working_directory),
mount_directory=self.mount_directory,
working_directory=self.working_directory,
port=self.port,
args=self.args,
command=self.command,
Expand Down
20 changes: 20 additions & 0 deletions components/renku_data_services/utils/sqlalchemy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Utilities for SQLAlchemy."""

from pathlib import PurePosixPath
from typing import cast

from sqlalchemy import Dialect, types
Expand All @@ -23,3 +24,22 @@ def process_result_value(self, value: str | None, dialect: Dialect) -> ULID | No
if value is None:
return None
return cast(ULID, ULID.from_str(value)) # cast because mypy doesn't understand ULID type annotations


class PurePosixPathType(types.TypeDecorator):
"""Wrapper type for Path <--> str conversion."""

impl = types.String
cache_ok = True

def process_bind_param(self, value: PurePosixPath | None, dialect: Dialect) -> str | None:
"""Transform value for storing in the database."""
if value is None:
return None
return value.as_posix()

def process_result_value(self, value: str | None, dialect: Dialect) -> PurePosixPath | None:
"""Transform string from database into PosixPath."""
if value is None:
return None
return PurePosixPath(value)
14 changes: 7 additions & 7 deletions test/bases/renku_data_services/data_api/test_migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,8 @@ async def test_migration_to_f34b87ddd954(


@pytest.mark.asyncio
async def test_migration_to_584598f3b769(sanic_client_no_migrations: SanicASGITestClient, app_config: Config) -> None:
async def test_migration_to_584598f3b769(app_config: Config) -> None:
run_migrations_for_app("common", "dcc1c1ee662f")
sanic_client = sanic_client_no_migrations
await app_config.kc_user_repo.initialize(app_config.kc_api)
await app_config.group_repo.generate_user_namespaces()
env_id = str(ULID())
Expand All @@ -122,14 +121,15 @@ async def test_migration_to_584598f3b769(sanic_client_no_migrations: SanicASGITe
)
)
run_migrations_for_app("common", "584598f3b769")
_, response = await sanic_client.get("/api/data/environments")
assert response.status_code == 200, response.text
assert len(response.json) == 1
env = response.json[0]
async with app_config.db.async_session_maker() as session, session.begin():
res = await session.execute(sa.text("SELECT * FROM sessions.environments"))
data = res.all()
assert len(data) == 1
env = data[0]._mapping
assert env["id"] == env_id
assert env["name"] == "test"
assert env["container_image"] == "test"
assert env["default_url"] == "/test"
assert env["port"] == 4180
assert env["port"] == 8888
assert env["uid"] == 1000
assert env["gid"] == 1000
2 changes: 1 addition & 1 deletion test/bases/renku_data_services/data_api/test_sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ async def test_patch_session_environment_unauthorized(

_, res = await sanic_client.patch(f"/api/data/environments/{environment_id}", headers=user_headers, json=payload)

assert res.status_code == 403, res.text
assert res.status_code == 401, res.text


@pytest.mark.asyncio
Expand Down

0 comments on commit 09ca8b3

Please sign in to comment.