Skip to content

Commit

Permalink
Update InjectAutoreloadMiddleware to be compatible with starlette >…
Browse files Browse the repository at this point in the history
…= 0.35.0 (#1013)

Co-authored-by: Winston Chang <winston@posit.co>
  • Loading branch information
schloerke and wch authored Mar 2, 2024
1 parent 7579f24 commit f0a05fa
Showing 1 changed file with 32 additions and 8 deletions.
40 changes: 32 additions & 8 deletions shiny/_autoreload.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
import secrets
import threading
import webbrowser
from typing import Callable, Optional
from typing import Callable, Optional, cast

import starlette.types
from asgiref.typing import (
ASGI3Application,
ASGIReceiveCallable,
Expand Down Expand Up @@ -90,8 +91,19 @@ class InjectAutoreloadMiddleware:
because we want autoreload to be effective even when displaying an error page.
"""

def __init__(self, app: ASGI3Application):
self.app = app
def __init__(
self,
app: starlette.types.ASGIApp | ASGI3Application,
*args: object,
**kwargs: object,
):
if len(args) > 0 or len(kwargs) > 0:
raise TypeError(
f"InjectAutoreloadMiddleware does not support positional or keyword arguments, received {args}, {kwargs}"
)
# The starlette types and the asgiref types are compatible, but we'll use the
# latter internally. See the note in the __call__ method for more details.
self.app = cast(ASGI3Application, app)
ws_url = autoreload_url()
self.script = (
f""" <script src="__shared/shiny-autoreload.js" data-ws-url="{html.escape(ws_url)}"></script>
Expand All @@ -103,19 +115,31 @@ def __init__(self, app: ASGI3Application):
)

async def __call__(
self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
self,
scope: starlette.types.Scope | Scope,
receive: starlette.types.Receive | ASGIReceiveCallable,
send: starlette.types.Send | ASGISendCallable,
) -> None:
if scope["type"] != "http" or scope["path"] != "/" or len(self.script) == 0:
return await self.app(scope, receive, send)
# The starlette types and the asgiref types are compatible, but the latter are
# more rigorous. In the call interface, we accept both types for compatibility
# with both. But internally we'll use the more rigorous types.
# See https://github.com/encode/starlette/blob/39dccd9/docs/middleware.md#type-annotations
scope = cast(Scope, scope)
receive_casted = cast(ASGIReceiveCallable, receive)
send_casted = cast(ASGISendCallable, send)
if scope["type"] != "http":
return await self.app(scope, receive_casted, send_casted)
if scope["path"] != "/" or len(self.script) == 0:
return await self.app(scope, receive_casted, send_casted)

def mangle_callback(body: bytes) -> tuple[bytes, bool]:
if b"</head>" in body:
return (body.replace(b"</head>", self.script, 1), True)
else:
return (body, False)

mangler = ResponseMangler(send, mangle_callback)
await self.app(scope, receive, mangler.send)
mangler = ResponseMangler(send_casted, mangle_callback)
await self.app(scope, receive_casted, mangler.send)


# PARENT PROCESS ------------------------------------------------------------
Expand Down

0 comments on commit f0a05fa

Please sign in to comment.