-
-
Notifications
You must be signed in to change notification settings - Fork 906
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add 400 response when boundary
is missing
#1617
Changes from all commits
efc4342
b7b4e92
d2e190a
d52dd9a
f1eb5ba
7a7a31e
9245408
90aa31b
598cfac
ed1ccb2
57e16d6
234c908
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,8 @@ | ||
import asyncio | ||
import http | ||
import typing | ||
import warnings | ||
|
||
from starlette.concurrency import run_in_threadpool | ||
from starlette.requests import Request | ||
from starlette.responses import PlainTextResponse, Response | ||
from starlette.types import ASGIApp, Message, Receive, Scope, Send | ||
__all__ = ("HTTPException",) | ||
|
||
|
||
class HTTPException(Exception): | ||
|
@@ -26,86 +23,22 @@ def __repr__(self) -> str: | |
return f"{class_name}(status_code={self.status_code!r}, detail={self.detail!r})" | ||
|
||
|
||
class ExceptionMiddleware: | ||
def __init__( | ||
self, | ||
app: ASGIApp, | ||
handlers: typing.Optional[ | ||
typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]] | ||
] = None, | ||
debug: bool = False, | ||
) -> None: | ||
self.app = app | ||
self.debug = debug # TODO: We ought to handle 404 cases if debug is set. | ||
self._status_handlers: typing.Dict[int, typing.Callable] = {} | ||
self._exception_handlers: typing.Dict[ | ||
typing.Type[Exception], typing.Callable | ||
] = {HTTPException: self.http_exception} | ||
if handlers is not None: | ||
for key, value in handlers.items(): | ||
self.add_exception_handler(key, value) | ||
|
||
def add_exception_handler( | ||
self, | ||
exc_class_or_status_code: typing.Union[int, typing.Type[Exception]], | ||
handler: typing.Callable[[Request, Exception], Response], | ||
) -> None: | ||
if isinstance(exc_class_or_status_code, int): | ||
self._status_handlers[exc_class_or_status_code] = handler | ||
else: | ||
assert issubclass(exc_class_or_status_code, Exception) | ||
self._exception_handlers[exc_class_or_status_code] = handler | ||
|
||
def _lookup_exception_handler( | ||
self, exc: Exception | ||
) -> typing.Optional[typing.Callable]: | ||
for cls in type(exc).__mro__: | ||
if cls in self._exception_handlers: | ||
return self._exception_handlers[cls] | ||
return None | ||
|
||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: | ||
if scope["type"] != "http": | ||
await self.app(scope, receive, send) | ||
return | ||
__deprecated__ = "ExceptionMiddleware" | ||
|
||
response_started = False | ||
|
||
async def sender(message: Message) -> None: | ||
nonlocal response_started | ||
def __getattr__(name: str) -> typing.Any: # pragma: no cover | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What's the plan to remove all of this deprecation code? I get why we need it, but it'd be nice to know when we're going to remove it so that:
I suppose this depends on #1623 , but I think we should make an executive decision (maybe @tomchristie ?) on when / if the deprecation shim is ever going to be removed. Personally, I feel like 3 minor releases or 1 year, whichever comes first, sounds reasonable? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was more into 6 months and 3 minors... But any number is good for me... I'd just like to have a number 👍 But let's mention those stuff on that discussion. For here, there's no plan yet. The plan should be defined on that discussion. If this is merged, it's just common sense i.e. in some arbitrary time we just remove it... |
||
if name == __deprecated__: | ||
from starlette.middleware.exceptions import ExceptionMiddleware | ||
|
||
if message["type"] == "http.response.start": | ||
response_started = True | ||
await send(message) | ||
|
||
try: | ||
await self.app(scope, receive, sender) | ||
except Exception as exc: | ||
handler = None | ||
|
||
if isinstance(exc, HTTPException): | ||
handler = self._status_handlers.get(exc.status_code) | ||
|
||
if handler is None: | ||
handler = self._lookup_exception_handler(exc) | ||
|
||
if handler is None: | ||
raise exc | ||
|
||
if response_started: | ||
msg = "Caught handled exception, but response already started." | ||
raise RuntimeError(msg) from exc | ||
warnings.warn( | ||
f"{__deprecated__} is deprecated on `starlette.exceptions`. " | ||
f"Import it from `starlette.middleware.exceptions` instead.", | ||
category=DeprecationWarning, | ||
stacklevel=3, | ||
) | ||
return ExceptionMiddleware | ||
raise AttributeError(f"module '{__name__}' has no attribute '{name}'") | ||
|
||
request = Request(scope, receive=receive) | ||
if asyncio.iscoroutinefunction(handler): | ||
response = await handler(request, exc) | ||
else: | ||
response = await run_in_threadpool(handler, request, exc) | ||
await response(scope, receive, sender) | ||
|
||
def http_exception(self, request: Request, exc: HTTPException) -> Response: | ||
if exc.status_code in {204, 304}: | ||
return Response(status_code=exc.status_code, headers=exc.headers) | ||
return PlainTextResponse( | ||
exc.detail, status_code=exc.status_code, headers=exc.headers | ||
) | ||
def __dir__() -> typing.List[str]: | ||
return sorted(list(__all__) + [__deprecated__]) # pragma: no cover |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
import asyncio | ||
import typing | ||
|
||
from starlette.concurrency import run_in_threadpool | ||
from starlette.exceptions import HTTPException | ||
from starlette.requests import Request | ||
from starlette.responses import PlainTextResponse, Response | ||
from starlette.types import ASGIApp, Message, Receive, Scope, Send | ||
|
||
|
||
class ExceptionMiddleware: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 Yup - much better to have this here. Have confirmed to myself that the class here and the class when it was in |
||
def __init__( | ||
self, | ||
app: ASGIApp, | ||
handlers: typing.Optional[ | ||
typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]] | ||
] = None, | ||
debug: bool = False, | ||
) -> None: | ||
self.app = app | ||
self.debug = debug # TODO: We ought to handle 404 cases if debug is set. | ||
self._status_handlers: typing.Dict[int, typing.Callable] = {} | ||
self._exception_handlers: typing.Dict[ | ||
typing.Type[Exception], typing.Callable | ||
] = {HTTPException: self.http_exception} | ||
if handlers is not None: | ||
for key, value in handlers.items(): | ||
self.add_exception_handler(key, value) | ||
|
||
def add_exception_handler( | ||
self, | ||
exc_class_or_status_code: typing.Union[int, typing.Type[Exception]], | ||
handler: typing.Callable[[Request, Exception], Response], | ||
) -> None: | ||
if isinstance(exc_class_or_status_code, int): | ||
self._status_handlers[exc_class_or_status_code] = handler | ||
else: | ||
assert issubclass(exc_class_or_status_code, Exception) | ||
self._exception_handlers[exc_class_or_status_code] = handler | ||
|
||
def _lookup_exception_handler( | ||
self, exc: Exception | ||
) -> typing.Optional[typing.Callable]: | ||
for cls in type(exc).__mro__: | ||
if cls in self._exception_handlers: | ||
return self._exception_handlers[cls] | ||
return None | ||
|
||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: | ||
if scope["type"] != "http": | ||
await self.app(scope, receive, send) | ||
return | ||
|
||
response_started = False | ||
|
||
async def sender(message: Message) -> None: | ||
nonlocal response_started | ||
|
||
if message["type"] == "http.response.start": | ||
response_started = True | ||
await send(message) | ||
|
||
try: | ||
await self.app(scope, receive, sender) | ||
except Exception as exc: | ||
handler = None | ||
|
||
if isinstance(exc, HTTPException): | ||
handler = self._status_handlers.get(exc.status_code) | ||
|
||
if handler is None: | ||
handler = self._lookup_exception_handler(exc) | ||
|
||
if handler is None: | ||
raise exc | ||
|
||
if response_started: | ||
msg = "Caught handled exception, but response already started." | ||
raise RuntimeError(msg) from exc | ||
|
||
request = Request(scope, receive=receive) | ||
if asyncio.iscoroutinefunction(handler): | ||
response = await handler(request, exc) | ||
else: | ||
response = await run_in_threadpool(handler, request, exc) | ||
await response(scope, receive, sender) | ||
|
||
def http_exception(self, request: Request, exc: HTTPException) -> Response: | ||
if exc.status_code in {204, 304}: | ||
return Response(status_code=exc.status_code, headers=exc.headers) | ||
return PlainTextResponse( | ||
exc.detail, status_code=exc.status_code, headers=exc.headers | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,8 @@ | |
import anyio | ||
|
||
from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State | ||
from starlette.formparsers import FormParser, MultiPartParser | ||
from starlette.exceptions import HTTPException | ||
from starlette.formparsers import FormParser, MultiPartException, MultiPartParser | ||
from starlette.types import Message, Receive, Scope, Send | ||
|
||
try: | ||
|
@@ -250,8 +251,13 @@ async def form(self) -> FormData: | |
content_type_header = self.headers.get("Content-Type") | ||
content_type, options = parse_options_header(content_type_header) | ||
if content_type == b"multipart/form-data": | ||
multipart_parser = MultiPartParser(self.headers, self.stream()) | ||
self._form = await multipart_parser.parse() | ||
try: | ||
multipart_parser = MultiPartParser(self.headers, self.stream()) | ||
self._form = await multipart_parser.parse() | ||
except MultiPartException as exc: | ||
if "app" in self.scope: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need to have the switch here? Can we just always coerce to an There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep, we use the pattern in the code, as @adriangb pointed out. Always coerce to an There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Realistically, I think we could remove that pattern everywhere. IMO it's not worth supporting (nice) usage of Starlette's Request object outside of a Starlette app. If someone wants to do that, they can handle the HTTPException. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
You have a point, but I think that's another discussion. I really don't think it's a burden to maintain that, but I wonder if someone really uses that code outside |
||
raise HTTPException(status_code=400, detail=exc.message) | ||
Kludex marked this conversation as resolved.
Show resolved
Hide resolved
|
||
raise exc | ||
elif content_type == b"application/x-www-form-urlencoded": | ||
form_parser = FormParser(self.headers, self.stream()) | ||
self._form = await form_parser.parse() | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,15 @@ | ||
import os | ||
import typing | ||
from contextlib import nullcontext as does_not_raise | ||
|
||
import pytest | ||
|
||
from starlette.formparsers import UploadFile, _user_safe_decode | ||
from starlette.applications import Starlette | ||
from starlette.formparsers import MultiPartException, UploadFile, _user_safe_decode | ||
from starlette.requests import Request | ||
from starlette.responses import JSONResponse | ||
from starlette.routing import Mount | ||
from starlette.testclient import TestClient | ||
|
||
|
||
class ForceMultipartDict(dict): | ||
|
@@ -390,10 +394,19 @@ def test_user_safe_decode_ignores_wrong_charset(): | |
assert result == "abc" | ||
|
||
|
||
def test_missing_boundary_parameter(test_client_factory): | ||
@pytest.mark.parametrize( | ||
"app,expectation", | ||
[ | ||
(app, pytest.raises(MultiPartException)), | ||
(Starlette(routes=[Mount("/", app=app)]), does_not_raise()), | ||
Kludex marked this conversation as resolved.
Show resolved
Hide resolved
|
||
], | ||
) | ||
def test_missing_boundary_parameter( | ||
app, expectation, test_client_factory: typing.Callable[..., TestClient] | ||
) -> None: | ||
client = test_client_factory(app) | ||
with pytest.raises(KeyError, match="boundary"): | ||
client.post( | ||
with expectation: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd personally probably nudge towards the simpler The parameterised Just my personal perspective tho, happy to go either way. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The idea was to avoid the creation of two very similar tests. Do you think is clear if I create two tests instead? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think your way or combining an exception and non-exception test into 1 is the cleanest I've seen for that pattern, but if it's |
||
res = client.post( | ||
"/", | ||
data=( | ||
# file | ||
|
@@ -403,3 +416,5 @@ def test_missing_boundary_parameter(test_client_factory): | |
), | ||
headers={"Content-Type": "multipart/form-data; charset=utf-8"}, | ||
) | ||
assert res.status_code == 400 | ||
assert res.text == "Missing boundary in multipart." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I need to add a test for this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Didn't realise you could do this at a module level. 😎