Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add 400 response when boundary is missing #1617

Merged
merged 12 commits into from
May 19, 2022
2 changes: 1 addition & 1 deletion starlette/applications.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import typing

from starlette.datastructures import State, URLPath
from starlette.exceptions import ExceptionMiddleware
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.middleware.exceptions import ExceptionMiddleware
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import BaseRoute, Router
Expand Down
99 changes: 16 additions & 83 deletions starlette/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import asyncio
import http
import typing
import warnings

from starlette.concurrency import run_in_threadpool
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
__all__ = ("HTTPException",)


class HTTPException(Exception):
Expand All @@ -26,86 +23,22 @@ def __repr__(self) -> str:
return f"{class_name}(status_code={self.status_code!r}, detail={self.detail!r})"


class ExceptionMiddleware:
def __init__(
self,
app: ASGIApp,
handlers: typing.Optional[
typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]]
] = None,
debug: bool = False,
) -> None:
self.app = app
self.debug = debug # TODO: We ought to handle 404 cases if debug is set.
self._status_handlers: typing.Dict[int, typing.Callable] = {}
self._exception_handlers: typing.Dict[
typing.Type[Exception], typing.Callable
] = {HTTPException: self.http_exception}
if handlers is not None:
for key, value in handlers.items():
self.add_exception_handler(key, value)

def add_exception_handler(
self,
exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
handler: typing.Callable[[Request, Exception], Response],
) -> None:
if isinstance(exc_class_or_status_code, int):
self._status_handlers[exc_class_or_status_code] = handler
else:
assert issubclass(exc_class_or_status_code, Exception)
self._exception_handlers[exc_class_or_status_code] = handler

def _lookup_exception_handler(
self, exc: Exception
) -> typing.Optional[typing.Callable]:
for cls in type(exc).__mro__:
if cls in self._exception_handlers:
return self._exception_handlers[cls]
return None

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return
__deprecated__ = "ExceptionMiddleware"

response_started = False

async def sender(message: Message) -> None:
nonlocal response_started
def __getattr__(name: str) -> typing.Any: # pragma: no cover
Copy link
Sponsor Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need to add a test for this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Didn't realise you could do this at a module level. 😎

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the plan to remove all of this deprecation code? I get why we need it, but it'd be nice to know when we're going to remove it so that:

  1. We can put it in the code for our knowledge down the road.
  2. We can put it in the error message.

I suppose this depends on #1623 , but I think we should make an executive decision (maybe @tomchristie ?) on when / if the deprecation shim is ever going to be removed.

Personally, I feel like 3 minor releases or 1 year, whichever comes first, sounds reasonable?

Copy link
Sponsor Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was more into 6 months and 3 minors... But any number is good for me... I'd just like to have a number 👍 But let's mention those stuff on that discussion.

For here, there's no plan yet. The plan should be defined on that discussion. If this is merged, it's just common sense i.e. in some arbitrary time we just remove it...

if name == __deprecated__:
from starlette.middleware.exceptions import ExceptionMiddleware

if message["type"] == "http.response.start":
response_started = True
await send(message)

try:
await self.app(scope, receive, sender)
except Exception as exc:
handler = None

if isinstance(exc, HTTPException):
handler = self._status_handlers.get(exc.status_code)

if handler is None:
handler = self._lookup_exception_handler(exc)

if handler is None:
raise exc

if response_started:
msg = "Caught handled exception, but response already started."
raise RuntimeError(msg) from exc
warnings.warn(
f"{__deprecated__} is deprecated on `starlette.exceptions`. "
f"Import it from `starlette.middleware.exceptions` instead.",
category=DeprecationWarning,
stacklevel=3,
)
return ExceptionMiddleware
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")

request = Request(scope, receive=receive)
if asyncio.iscoroutinefunction(handler):
response = await handler(request, exc)
else:
response = await run_in_threadpool(handler, request, exc)
await response(scope, receive, sender)

def http_exception(self, request: Request, exc: HTTPException) -> Response:
if exc.status_code in {204, 304}:
return Response(status_code=exc.status_code, headers=exc.headers)
return PlainTextResponse(
exc.detail, status_code=exc.status_code, headers=exc.headers
)
def __dir__() -> typing.List[str]:
return sorted(list(__all__) + [__deprecated__]) # pragma: no cover
10 changes: 9 additions & 1 deletion starlette/formparsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ def _user_safe_decode(src: bytes, codec: str) -> str:
return src.decode("latin-1")


class MultiPartException(Exception):
def __init__(self, message: str) -> None:
self.message = message


class FormParser:
def __init__(
self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]
Expand Down Expand Up @@ -159,7 +164,10 @@ async def parse(self) -> FormData:
charset = params.get(b"charset", "utf-8")
if type(charset) == bytes:
charset = charset.decode("latin-1")
boundary = params[b"boundary"]
try:
boundary = params[b"boundary"]
except KeyError:
raise MultiPartException("Missing boundary in multipart.")

# Callbacks dictionary.
callbacks = {
Expand Down
93 changes: 93 additions & 0 deletions starlette/middleware/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import asyncio
import typing

from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send


class ExceptionMiddleware:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 Yup - much better to have this here.

Have confirmed to myself that the class here and the class when it was in exceptions.py are identical.

def __init__(
self,
app: ASGIApp,
handlers: typing.Optional[
typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]]
] = None,
debug: bool = False,
) -> None:
self.app = app
self.debug = debug # TODO: We ought to handle 404 cases if debug is set.
self._status_handlers: typing.Dict[int, typing.Callable] = {}
self._exception_handlers: typing.Dict[
typing.Type[Exception], typing.Callable
] = {HTTPException: self.http_exception}
if handlers is not None:
for key, value in handlers.items():
self.add_exception_handler(key, value)

def add_exception_handler(
self,
exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
handler: typing.Callable[[Request, Exception], Response],
) -> None:
if isinstance(exc_class_or_status_code, int):
self._status_handlers[exc_class_or_status_code] = handler
else:
assert issubclass(exc_class_or_status_code, Exception)
self._exception_handlers[exc_class_or_status_code] = handler

def _lookup_exception_handler(
self, exc: Exception
) -> typing.Optional[typing.Callable]:
for cls in type(exc).__mro__:
if cls in self._exception_handlers:
return self._exception_handlers[cls]
return None

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return

response_started = False

async def sender(message: Message) -> None:
nonlocal response_started

if message["type"] == "http.response.start":
response_started = True
await send(message)

try:
await self.app(scope, receive, sender)
except Exception as exc:
handler = None

if isinstance(exc, HTTPException):
handler = self._status_handlers.get(exc.status_code)

if handler is None:
handler = self._lookup_exception_handler(exc)

if handler is None:
raise exc

if response_started:
msg = "Caught handled exception, but response already started."
raise RuntimeError(msg) from exc

request = Request(scope, receive=receive)
if asyncio.iscoroutinefunction(handler):
response = await handler(request, exc)
else:
response = await run_in_threadpool(handler, request, exc)
await response(scope, receive, sender)

def http_exception(self, request: Request, exc: HTTPException) -> Response:
if exc.status_code in {204, 304}:
return Response(status_code=exc.status_code, headers=exc.headers)
return PlainTextResponse(
exc.detail, status_code=exc.status_code, headers=exc.headers
)
12 changes: 9 additions & 3 deletions starlette/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import anyio

from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State
from starlette.formparsers import FormParser, MultiPartParser
from starlette.exceptions import HTTPException
from starlette.formparsers import FormParser, MultiPartException, MultiPartParser
from starlette.types import Message, Receive, Scope, Send

try:
Expand Down Expand Up @@ -250,8 +251,13 @@ async def form(self) -> FormData:
content_type_header = self.headers.get("Content-Type")
content_type, options = parse_options_header(content_type_header)
if content_type == b"multipart/form-data":
multipart_parser = MultiPartParser(self.headers, self.stream())
self._form = await multipart_parser.parse()
try:
multipart_parser = MultiPartParser(self.headers, self.stream())
self._form = await multipart_parser.parse()
except MultiPartException as exc:
if "app" in self.scope:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to have the switch here? Can we just always coerce to an HTTPException here instead? Are we already using this pattern, and if so, where?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Sponsor Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, we use the pattern in the code, as @adriangb pointed out.

Always coerce to an HTTPException is an option. I opted to not do it because we only use HTTPExceptions when we are in Starlette (as application) e.g. "app" in self.scope.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Realistically, I think we could remove that pattern everywhere. IMO it's not worth supporting (nice) usage of Starlette's Request object outside of a Starlette app. If someone wants to do that, they can handle the HTTPException.

Copy link
Sponsor Member Author

@Kludex Kludex May 19, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If someone wants to do that, they can handle the HTTPException.

You have a point, but I think that's another discussion.

I really don't think it's a burden to maintain that, but I wonder if someone really uses that code outside Starlette (app)... 🤔

raise HTTPException(status_code=400, detail=exc.message)
Kludex marked this conversation as resolved.
Show resolved Hide resolved
raise exc
elif content_type == b"application/x-www-form-urlencoded":
form_parser = FormParser(self.headers, self.stream())
self._form = await form_parser.parse()
Expand Down
18 changes: 17 additions & 1 deletion tests/test_exceptions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import warnings

import pytest

from starlette.exceptions import ExceptionMiddleware, HTTPException
from starlette.exceptions import HTTPException
from starlette.middleware.exceptions import ExceptionMiddleware
from starlette.responses import PlainTextResponse
from starlette.routing import Route, Router, WebSocketRoute

Expand Down Expand Up @@ -130,3 +133,16 @@ class CustomHTTPException(HTTPException):
assert repr(CustomHTTPException(500, detail="Something custom")) == (
"CustomHTTPException(status_code=500, detail='Something custom')"
)


def test_exception_middleware_deprecation() -> None:
# this test should be removed once the deprecation shim is removed
with pytest.warns(DeprecationWarning):
from starlette.exceptions import ExceptionMiddleware # noqa: F401

with warnings.catch_warnings():
warnings.simplefilter("error")
import starlette.exceptions

with pytest.warns(DeprecationWarning):
starlette.exceptions.ExceptionMiddleware
23 changes: 19 additions & 4 deletions tests/test_formparsers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import os
import typing
from contextlib import nullcontext as does_not_raise

import pytest

from starlette.formparsers import UploadFile, _user_safe_decode
from starlette.applications import Starlette
from starlette.formparsers import MultiPartException, UploadFile, _user_safe_decode
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Mount
from starlette.testclient import TestClient


class ForceMultipartDict(dict):
Expand Down Expand Up @@ -390,10 +394,19 @@ def test_user_safe_decode_ignores_wrong_charset():
assert result == "abc"


def test_missing_boundary_parameter(test_client_factory):
@pytest.mark.parametrize(
"app,expectation",
[
(app, pytest.raises(MultiPartException)),
(Starlette(routes=[Mount("/", app=app)]), does_not_raise()),
Kludex marked this conversation as resolved.
Show resolved Hide resolved
],
)
def test_missing_boundary_parameter(
app, expectation, test_client_factory: typing.Callable[..., TestClient]
) -> None:
client = test_client_factory(app)
with pytest.raises(KeyError, match="boundary"):
client.post(
with expectation:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd personally probably nudge towards the simpler pytest.raises(KeyError, match="boundary") case here, just because it's more obviously readable to me.

The parameterised expectation is neat, but also more complex to understand.

Just my personal perspective tho, happy to go either way.

Copy link
Sponsor Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea was to avoid the creation of two very similar tests. Do you think is clear if I create two tests instead?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think your way or combining an exception and non-exception test into 1 is the cleanest I've seen for that pattern, but if it's
only 2 tests and not 10 I also think making it more explicit would be nice

res = client.post(
"/",
data=(
# file
Expand All @@ -403,3 +416,5 @@ def test_missing_boundary_parameter(test_client_factory):
),
headers={"Content-Type": "multipart/form-data; charset=utf-8"},
)
assert res.status_code == 400
assert res.text == "Missing boundary in multipart."