Skip to content

Commit

Permalink
Support WSGI/ASGI middleware (#65)
Browse files Browse the repository at this point in the history
  • Loading branch information
abersheeran authored Jul 26, 2024
1 parent 546ac17 commit 829fe62
Show file tree
Hide file tree
Showing 9 changed files with 437 additions and 75 deletions.
5 changes: 5 additions & 0 deletions baize/asgi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
WebSocketDisconnect,
WebSocketState,
)
from baize.asgi.middleware import NextRequest, NextResponse, middleware, CachedStream

__all__ = [
"empty_receive",
Expand Down Expand Up @@ -50,4 +51,8 @@
"request_response",
"websocket_session",
"decorator",
"NextRequest",
"NextResponse",
"middleware",
"CachedStream",
]
141 changes: 141 additions & 0 deletions baize/asgi/middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import functools
from tempfile import SpooledTemporaryFile
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Awaitable,
Callable,
MutableMapping,
)

from ..concurrency import run_in_threadpool
from ..datastructures import Headers
from ..typing import ASGIApp, Scope, Receive, Send, Message
from .requests import Request
from .responses import Response, StreamingResponse


class CachedStream(AsyncIterator[bytes]):
spool_max_size = 1024 * 1024

def __init__(self) -> None:
self._buffer = SpooledTemporaryFile(max_size=self.spool_max_size, mode="w+b")
self._pushed_eof = False

async def push(self, chunk: bytes) -> None:
if self._pushed_eof:
raise RuntimeError("Cannot push chunk after push EOF.") # pragma: no cover
await run_in_threadpool(self._buffer.write, chunk)

async def push_eof(self) -> None:
await run_in_threadpool(self._buffer.seek, 0)
self._pushed_eof = True

async def __anext__(self) -> bytes:
chunk = await run_in_threadpool(self._buffer.read, 4096 * 16)
if not chunk:
raise StopAsyncIteration
return chunk


class NextRequest(Request, MutableMapping[str, Any]):
def __setitem__(self, name: str, value: Any) -> None:
self._scope[name] = value

def __delitem__(self, name: str) -> None:
del self._scope[name]

def stream(self) -> AsyncIterator[bytes]:
raise RuntimeError("Cannot read request body in middleware.")


class NextResponse(StreamingResponse):
"""
This is a response object for middleware.
"""

async def render_stream(self) -> AsyncGenerator[bytes, None]:
async for chunk in self.iterable:
yield chunk

@classmethod
async def from_app(cls, app: ASGIApp, request: NextRequest) -> "NextResponse":
"""
This is a helper method to convert a ASGI application into a NextResponse object.
"""
status_code = 200
headers = Headers()
body = CachedStream()

async def send(message: Message) -> None:
nonlocal status_code
nonlocal headers
if message["type"] == "http.response.start":
status_code = message["status"]
headers = Headers(
[
(k.decode("latin-1"), v.decode("latin-1"))
for k, v in message.get("headers", [])
]
)
elif message["type"] == "http.response.body":
await body.push(message.get("body", b""))
if not message.get("more_body", False):
await body.push_eof()

await app(request, request._receive, send)
return NextResponse(body, status_code, headers)


def middleware(
handler: Callable[
[NextRequest, Callable[[NextRequest], Awaitable[NextResponse]]],
Awaitable[Response],
],
) -> Callable[[ASGIApp], ASGIApp]:
"""
This can turn a callable object into a middleware for ASGI application.
```python
@middleware
async def m(
request: NextRequest, next_call: Callable[[NextRequest], Awaitable[NextResponse]]
) -> Response:
...
response = await next_call(request)
...
return response
@m
@request_response
async def v(request: Request) -> Response:
...
# OR
@m
async def asgi(scope: Scope, receive: Receive, send: Send) -> None:
...
```
"""

@functools.wraps(handler)
def d(app: ASGIApp) -> ASGIApp:
"""
This is the actual middleware.
"""

@functools.wraps(app)
async def asgi(scope: Scope, receive: Receive, send: Send) -> None:
request = NextRequest(scope, receive, send)

async def next_call(request: NextRequest) -> NextResponse:
return await NextResponse.from_app(app, request)

response = await handler(request, next_call)
await response(scope, receive, send)

return asgi

return d
4 changes: 4 additions & 0 deletions baize/wsgi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from baize.wsgi.routing import Hosts, Router, Subpaths
from baize.wsgi.shortcut import decorator, request_response
from baize.wsgi.staticfiles import Files, Pages
from baize.wsgi.middleware import NextRequest, NextResponse, middleware

__all__ = [
"HTTPConnection",
Expand All @@ -33,4 +34,7 @@
"Pages",
"request_response",
"decorator",
"NextRequest",
"NextResponse",
"middleware",
]
95 changes: 95 additions & 0 deletions baize/wsgi/middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import functools
from typing import Any, Callable, Generator, Iterable, Iterator, MutableMapping, Tuple

from ..datastructures import Headers
from ..typing import Environ, StartResponse, WSGIApp
from .requests import Request
from .responses import Response, StreamingResponse


class NextRequest(Request, MutableMapping[str, Any]):
def __setitem__(self, name: str, value: Any) -> None:
self._environ[name] = value

def __delitem__(self, name: str) -> None:
del self._environ[name]

def stream(self, chunk_size: int = 4096 * 16) -> Iterator[bytes]:
raise RuntimeError("Cannot read request body in middleware.")


class NextResponse(StreamingResponse):
"""
This is a response object for middleware.
"""

def render_stream(self) -> Generator[bytes, None, None]:
yield from self.iterable

@classmethod
def from_app(cls, app: WSGIApp, request: NextRequest) -> "NextResponse":
"""
This is a helper method to convert a WSGI application into a NextResponse object.
"""
status_code = 200
headers: Headers = Headers()

def start_response(
status: str, response_headers: Iterable[Tuple[str, str]], exc_info=None
) -> None:
nonlocal status_code
nonlocal headers
status_code = int(status.split(" ")[0])
headers = Headers(response_headers)

body = app(request, start_response)
return NextResponse(body, status_code, headers)


def middleware(
handler: Callable[[NextRequest, Callable[[NextRequest], NextResponse]], Response]
) -> Callable[[WSGIApp], WSGIApp]:
"""
This can turn a callable object into a middleware for WSGI application.
```python
@middleware
def m(request: NextRequest, next_call: Callable[[NextRequest], NextResponse]) -> Response:
...
response = next_call(request)
...
return response
@m
@request_response
def v(request: Request) -> Response:
...
# OR
@m
def wsgi(environ: Environ, start_response: StartResponse) -> Iterable[bytes]:
...
```
"""

@functools.wraps(handler)
def d(app: WSGIApp) -> WSGIApp:
"""
This is the actual middleware.
"""

@functools.wraps(app)
def wsgi(environ: Environ, start_response: StartResponse) -> Iterable[bytes]:
request = NextRequest(environ, start_response)

def next_call(request: NextRequest) -> NextResponse:
next_response = NextResponse.from_app(app, request)
return next_response

response = handler(request, next_call)
yield from response(environ, start_response)

return wsgi

return d
6 changes: 6 additions & 0 deletions docs/source/asgi.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,12 @@ If the built-in types are not enough, then you only need to write a class that i
.. autofunction:: baize.asgi.decorator
```

### middleware

```eval_rst
.. autofunction:: baize.asgi.middleware
```

### websocket_session

```eval_rst
Expand Down
6 changes: 6 additions & 0 deletions docs/source/wsgi.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,12 @@ If the built-in types are not enough, then you only need to write a class that i
.. autofunction:: baize.wsgi.decorator
```

### middleware

```eval_rst
.. autofunction:: baize.wsgi.middleware
```

## Files

```eval_rst
Expand Down
2 changes: 2 additions & 0 deletions speedup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@ def build(setup_kwargs):
not in (
"baize/multipart_helper.py",
# ASGI
"baize/asgi/middleware.py",
"baize/asgi/requests.py",
"baize/asgi/responses.py",
"baize/asgi/routing.py",
"baize/asgi/shortcut.py" if os.name == "nt" else None,
"baize/asgi/staticfiles.py",
"baize/asgi/websocket.py",
# WSGI
"baize/wsgi/middleware.py",
"baize/wsgi/requests.py",
"baize/wsgi/responses.py",
"baize/wsgi/routing.py",
Expand Down
Loading

0 comments on commit 829fe62

Please sign in to comment.