Skip to content

Commit

Permalink
feat: Add reject middleware for web security
Browse files Browse the repository at this point in the history
- Implemented reject middleware to enhance security.
- Updated server to integrate the new middleware.

close #2923
  • Loading branch information
HyeockJinKim committed Oct 22, 2024
1 parent 0dd465d commit c4d8c5d
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 1 deletion.
45 changes: 45 additions & 0 deletions src/ai/backend/web/security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from typing import Callable, Iterable

from aiohttp import web


@web.middleware
async def security_policy_middleware(request: web.Request, handler) -> web.StreamResponse:
security_policy = request.app["security_policy"]
security_policy.check_request(request)
return await handler(request)


class SecurityPolicy:
def __init__(
self,
request_policies: Iterable[Callable[[web.Request], None]],
response_policies: Iterable[Callable[[web.Response], web.Response]],
):
self.request_policies = request_policies
self.response_policies = response_policies

@classmethod
def default_policy(cls) -> "SecurityPolicy":
request_policies = [reject_metadata_local_link]
response_policies = [add_self_content_security_policy, set_content_type_nosniff]
return cls(request_policies, response_policies)

def check_request(self, request: web.Request):
for policy in self.request_policies:
policy(request)


def reject_metadata_local_link(request: web.Request):
if request.host == "169.254.169.254":
raise web.HTTPForbidden()


def add_self_content_security_policy(response: web.Response) -> web.Response:
response.headers["Content-Security-Policy"] = "default-src 'self'"
return response


def set_content_type_nosniff(response: web.Response) -> web.Response:
response.headers["X-Content-Type-Options"] = "nosniff"
return response
6 changes: 5 additions & 1 deletion src/ai/backend/web/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from ai.backend.common.web.session import setup as setup_session
from ai.backend.common.web.session.redis_storage import RedisStorage
from ai.backend.logging import BraceStyleAdapter, Logger, LogLevel
from ai.backend.web.security import SecurityPolicy, security_policy_middleware

from . import __version__, user_agent
from .auth import fill_forwarding_hdrs_to_api_session, get_client_ip
Expand Down Expand Up @@ -603,8 +604,11 @@ async def server_main(
args: Tuple[Any, ...],
) -> AsyncIterator[None]:
config = args[0]
app = web.Application(middlewares=[decrypt_payload, track_active_handlers])
app = web.Application(
middlewares=[decrypt_payload, track_active_handlers, security_policy_middleware]
)
app["config"] = config
app["security_policy"] = SecurityPolicy.default_policy()
j2env = jinja2.Environment(
extensions=[
"ai.backend.web.template.TOMLField",
Expand Down

0 comments on commit c4d8c5d

Please sign in to comment.