From fdb456a1750fb5c1cb7d0127236b00a976035931 Mon Sep 17 00:00:00 2001 From: Michiel <918128+mvanderlee@users.noreply.github.com> Date: Sat, 14 May 2022 15:59:57 -0300 Subject: [PATCH] initial commit --- .editorconfig | 24 + .gitignore | 127 +++++ .pre-commit-config.yaml | 33 ++ LICENSE.md | 21 + examples/goals.ipynb | 439 ++++++++++++++++++ examples/sample_server.py | 48 ++ requirements.txt | 23 + setup.cfg | 12 + setup.py | 50 ++ starmallow/__init__.py | 1 + starmallow/applications.py | 362 +++++++++++++++ starmallow/constants.py | 2 + starmallow/exception_handlers.py | 30 ++ starmallow/exceptions.py | 13 + starmallow/params.py | 54 +++ starmallow/routing.py | 764 +++++++++++++++++++++++++++++++ starmallow/types.py | 3 + starmallow/utils.py | 63 +++ 18 files changed, 2069 insertions(+) create mode 100644 .editorconfig create mode 100644 .gitignore create mode 100644 .pre-commit-config.yaml create mode 100644 LICENSE.md create mode 100644 examples/goals.ipynb create mode 100644 examples/sample_server.py create mode 100644 requirements.txt create mode 100644 setup.cfg create mode 100644 setup.py create mode 100644 starmallow/__init__.py create mode 100644 starmallow/applications.py create mode 100644 starmallow/constants.py create mode 100644 starmallow/exception_handlers.py create mode 100644 starmallow/exceptions.py create mode 100644 starmallow/params.py create mode 100644 starmallow/routing.py create mode 100644 starmallow/types.py create mode 100644 starmallow/utils.py diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..9e6705d --- /dev/null +++ b/.editorconfig @@ -0,0 +1,24 @@ +# EditorConfig is awesome: http://EditorConfig.org + +# top-most EditorConfig file +root = true + +# Unix-style newlines with a newline ending every file +[*] +charset = utf-8 +end_of_line = lf +insert_final_newline = true +trim_trailing_whitespace = true +indent_style = space +indent_size = 2 + +[*.{yml,yaml}] +indent_size = 2 + +[*.py] +indent_size = 4 + +[{Makefile,makefile,**.mk}] +# Use tabs for indentation (Makefiles require tabs) +indent_style = tab +indent_size = 4 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c4d8708 --- /dev/null +++ b/.gitignore @@ -0,0 +1,127 @@ +# Created by .ignore support plugin (hsz.mobi) +### Python template +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*,cover +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ +### JetBrains template +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio + +*.iml + +## Directory-based project format: +.idea/ +# if you remove the above rule, at least ignore the following: + +# User-specific stuff: +# .idea/workspace.xml +# .idea/tasks.xml +# .idea/dictionaries + +# Sensitive or high-churn files: +# .idea/dataSources.ids +# .idea/dataSources.xml +# .idea/sqlDataSources.xml +# .idea/dynamic.xml +# .idea/uiDesigner.xml + +# Gradle: +# .idea/gradle.xml +# .idea/libraries + +# Mongo Explorer plugin: +# .idea/mongoSettings.xml + +## File-based project format: +*.ipr +*.iws + +## Plugin-specific files: + +# IntelliJ +/out/ + +# mpeltonen/sbt-idea plugin +.idea_modules/ + +# JIRA plugin +atlassian-ide-plugin.xml + +# Crashlytics plugin (for Android Studio and IntelliJ) +com_crashlytics_export_strings.xml +crashlytics.properties +crashlytics-build.properties +### IPythonNotebook template +# Temporary data +.ipynb_checkpoints/ +### VirtualEnv template +# Virtualenv +# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/ +.Python +[Bb]in +[Ii]nclude +[Ll]ib +pyvenv.cfg +pip-selfcheck.json + +.vscode +.venv +/.env +.minio.sys/ +tests/data/minio/bucket1/Storage/ +scripts/*.jar diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..ad308a9 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,33 @@ +exclude: '^.*.ipynb$' +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.0.1 + hooks: + - id: check-added-large-files + - id: check-shebang-scripts-are-executable + - id: check-yaml + - id: detect-aws-credentials + - id: detect-private-key + - id: mixed-line-ending + - id: end-of-file-fixer + - id: trailing-whitespace + +- repo: https://github.com/asottile/add-trailing-comma + rev: v2.1.0 + hooks: + - id: add-trailing-comma + +- repo: https://github.com/timothycrosley/isort + rev: 5.8.0 + hooks: + - id: isort + +- repo: https://gitlab.com/pycqa/flake8 + rev: 3.8.4 + hooks: + - id: flake8 + additional_dependencies: + - flake8-commas + - flake8-comprehensions + - flake8-isort + - flake8-printf-formatting diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000..3e92463 --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2018 Sebastián Ramírez + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/examples/goals.ipynb b/examples/goals.ipynb new file mode 100644 index 0000000..aecd4c8 --- /dev/null +++ b/examples/goals.ipynb @@ -0,0 +1,439 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# StarMallow Goals/examples" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Detect Schema from marshmallow-dataclass\n", + "\n", + "By default, read from JSON body and return JSON body" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import marshmallow.fields as mf\n", + "from marshmallow_dataclass import dataclass as ma_dataclass\n", + "from starmallow import APIRouter\n", + "from starmallow.params import (\n", + " Header, \n", + " Body, \n", + " Query,\n", + " Path,\n", + " Cookie,\n", + " Form,\n", + " File,\n", + ")\n", + "\n", + "@ma_dataclass\n", + "class CreateRequest:\n", + " pass\n", + "\n", + "@ma_dataclass\n", + "class CreateResponse:\n", + " pass\n", + "\n", + "\n", + "router = APIRouter(prefix='/user', name='user')\n", + "\n", + "## Read from json body \n", + "@router.post('/create', status_code=202)\n", + "async def create_user(\n", + " create_request: CreateRequest,\n", + " create_request_json: CreateRequest = Body(None),\n", + " create_request_query: CreateRequest = Query(None),\n", + " create_request_path: CreateRequest = Path(None),\n", + " create_request_header: CreateRequest = Header(None),\n", + " create_request_cookie: CreateRequest = Cookie(None),\n", + " create_request_form: CreateRequest = Form(None),\n", + " create_request_file: CreateRequest = File(None),\n", + "\n", + " # Custom schema\n", + " create_request_custom: CreateRequest = Body(None, model=CreateRequest.Schema),\n", + "\n", + " # Allow usign marshmallow fields directly as well?\n", + " my_int: int = Body(None),\n", + " email: str = Body(None, schema=mf.Email())\n", + ") -> CreateResponse:\n", + " pass\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Playground" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Any, Optional, Union\n", + "\n", + "import marshmallow as ma\n", + "import marshmallow.fields as mf\n", + "\n", + "\n", + "class FieldInfo:\n", + "\n", + " def __init__(\n", + " self,\n", + " default: Any,\n", + " *,\n", + " deprecated: Optional[bool] = None,\n", + " include_in_schema: bool = True,\n", + " model: Union[ma.Schema, mf.Field] = None,\n", + " ) -> None:\n", + " self.default = default\n", + " self.deprecated = deprecated\n", + " self.include_in_schema = include_in_schema\n", + " self.model = model\n", + "\n", + "\n", + "class Path(FieldInfo):\n", + " pass\n", + "\n", + "\n", + "class Query(FieldInfo):\n", + " pass\n", + "\n", + "\n", + "class Header(FieldInfo):\n", + " pass\n", + "\n", + "\n", + "class Cookie(FieldInfo):\n", + " pass\n", + "\n", + "\n", + "class Body(FieldInfo):\n", + " pass\n", + "\n", + "\n", + "class Form(FieldInfo):\n", + " pass\n", + "\n", + "\n", + "class File(FieldInfo):\n", + " pass\n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "FOO\n", + "bar\n" + ] + } + ], + "source": [ + "\n", + "from typing import Any, Callable, TypeVar, Generic, Optional, _SpecialForm\n", + "\n", + "import marshmallow.fields as mf\n", + "from marshmallow_dataclass import dataclass as ma_dataclass\n", + "\n", + "@ma_dataclass\n", + "class CreateRequest:\n", + " my_int2: int\n", + "\n", + "@ma_dataclass\n", + "class CreateResponse:\n", + " pass\n", + "\n", + "\n", + " \n", + "@ma_dataclass\n", + "class Foobar:\n", + " foo: str\n", + " bar: str\n", + "\n", + "def test(foobar: Foobar) -> None:\n", + " print(foobar.foo)\n", + " print(foobar.bar)\n", + "\n", + "test(Foobar(foo='FOO', bar='bar'))\n", + "\n", + "\n", + "async def create_user(\n", + " create_request: CreateRequest,\n", + " create_request_json: CreateRequest = Body(...),\n", + " create_request_query: CreateRequest = Query(...),\n", + " create_request_path: CreateRequest = Path(...),\n", + " create_request_header: CreateRequest = Header(...),\n", + " create_request_cookie: CreateRequest = Cookie(...),\n", + " create_request_form: CreateRequest = Form(...),\n", + " \n", + " # Custom schema\n", + " create_request_custom: CreateRequest = Body(..., model=CreateRequest.Schema),\n", + "\n", + " # Allow usign marshmallow fields directly as well?\n", + " my_int: Optional[int] = Body(5),\n", + " email: str = Body(..., model=mf.Email()),\n", + "\n", + " my_int2: int = Header(...),\n", + "\n", + " optional: Optional[CreateRequest] = Body(None),\n", + ") -> CreateResponse:\n", + " pass\n" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "<__main__.Body at 0x1ed665cc888>" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import inspect\n", + "\n", + "parameters = dict(inspect.signature(create_user).parameters)\n", + "\n", + "parameters['create_request'].default\n", + "parameters['create_request_json'].default\n" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "import marshmallow as ma\n", + "import marshmallow.fields as mf\n", + "\n", + "t = mf.Integer(required=True)\n", + "\n", + "try:\n", + " t.deserialize('1d', 'foobar', {'foobar': '1d'})\n", + "except Exception as e:\n", + " ex = e" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'messages': ['Not a valid integer.'],\n", + " 'field_name': '_schema',\n", + " 'data': None,\n", + " 'valid_data': None,\n", + " 'kwargs': {}}" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ex.__dict__" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Validation response\n", + "## FastAPI\n", + "{\n", + " \"detail\": [\n", + " {\n", + " \"loc\": [\n", + " \"path\",\n", + " \"item_id\"\n", + " ],\n", + " \"msg\": \"value is not a valid integer\",\n", + " \"type\": \"type_error.integer\"\n", + " }\n", + " ],\n", + " \"error\": \"ValidationError\",\n", + " \"status_code\": 422\n", + "}\n", + "\n", + "## Flask-Smorest\n", + "{\n", + " \"code\": 422,\n", + " \"errors\": {\n", + " \"json\": {\n", + " \"columns\": [\n", + " \"Missing data for required field.\"\n", + " ],\n", + " \"name\": [\n", + " \"Missing data for required field.\"\n", + " ],\n", + " \"order_columns\": [\n", + " \"Missing data for required field.\"\n", + " ],\n", + " \"sql_query\": [\n", + " \"Missing data for required field.\"\n", + " ],\n", + " \"table_names\": [\n", + " \"Missing data for required field.\"\n", + " ]\n", + " }\n", + " },\n", + " \"status\": \"Unprocessable Entity\"\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'b': 2, 'a': 'A'}" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import marshmallow as ma\n", + "import marshmallow.fields as mf\n", + "from dataclasses import dataclass\n", + "\n", + "@dataclass\n", + "class Foobar:\n", + " a: str\n", + " b: int\n", + "\n", + "class FoobarSchema(ma.Schema):\n", + " a = mf.String()\n", + " b = mf.Integer()\n", + "\n", + "\n", + "f = Foobar('A', 2)\n", + "FoobarSchema().dump(f)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Formdata request with multiple content-types" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import http.client\n", + "import mimetypes\n", + "from codecs import encode\n", + "\n", + "conn = http.client.HTTPConnection(\"127.0.0.1\", 8000)\n", + "dataList = []\n", + "boundary = 'wL36Yn8afVp8Ag7AmP8qZ0SA4n1v9T'\n", + "dataList.append(encode('--' + boundary))\n", + "dataList.append(encode('Content-Disposition: form-data; name=foobar_json;'))\n", + "\n", + "dataList.append(encode('Content-Type: {}'.format('application/json')))\n", + "dataList.append(encode(''))\n", + "\n", + "dataList.append(encode('{\"foo\": \"bar\"}'))\n", + "dataList.append(encode('--' + boundary))\n", + "dataList.append(encode('Content-Disposition: form-data; name=foobar;'))\n", + "\n", + "dataList.append(encode('Content-Type: {}'.format('text/plain')))\n", + "dataList.append(encode(''))\n", + "\n", + "dataList.append(encode(\"foo\"))\n", + "dataList.append(encode('--' + boundary))\n", + "dataList.append(encode('Content-Disposition: form-data; name=file; filename={0}'.format('/C:/Users/maste/Downloads/new_user_credentials.csv')))\n", + "\n", + "fileType = mimetypes.guess_type('/C:/Users/maste/Downloads/new_user_credentials.csv')[0] or 'application/octet-stream'\n", + "dataList.append(encode('Content-Type: {}'.format(fileType)))\n", + "dataList.append(encode(''))\n", + "\n", + "with open('C:/Users/maste/Downloads/new_user_credentials.csv', 'rb') as f:\n", + " dataList.append(f.read())\n", + "dataList.append(encode('--'+boundary+'--'))\n", + "dataList.append(encode(''))\n", + "body = b'\\r\\n'.join(dataList)\n", + "payload = body\n", + "headers = {\n", + " 'Content-type': 'multipart/form-data; boundary={}'.format(boundary) \n", + "}\n", + "conn.request(\"POST\", \"/query/run\", payload, headers)\n", + "res = conn.getresponse()\n", + "data = res.read()\n", + "print(data.decode(\"utf-8\"))" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "f7ddab2c409f2639957015aff5e68e05e927823e1d836cfda5995cef4fd43b5c" + }, + "kernelspec": { + "display_name": "Python 3.7.9 ('.venv': venv)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/sample_server.py b/examples/sample_server.py new file mode 100644 index 0000000..2a79545 --- /dev/null +++ b/examples/sample_server.py @@ -0,0 +1,48 @@ +# Add starmallow package without installing it +import os.path +import sys + +sys.path.insert(1, os.path.abspath('../')) + + +from dataclasses import asdict + +from marshmallow_dataclass import dataclass as ma_dataclass + +from starmallow.applications import StarMallow +from starmallow.params import Body, Header + +app = StarMallow() + + +@ma_dataclass +class CreateRequest: + my_string: str + my_int: int = 5 + + +@ma_dataclass +class CreateResponse: + my_string: str + + +@app.post('/test') +async def test( + create_request: CreateRequest, + limit: int, + my_string: str = Body(...), + authorization: str = Header(...), +) -> CreateResponse: + print(create_request) + print(limit) + print(authorization) + print(my_string) + + return create_request + + +@app.get('/test2') +def test(create_request: CreateRequest) -> CreateResponse: + print(create_request) + + return asdict(create_request) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..e3fcea0 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,23 @@ +# Contains all requirements for development purposes +apispec[marshmallow]>=5.1,<6 +marshmallow>=3.11.0,<4 +marshmallow-dataclass>=8.3.1,<9 +python-multipart >=0.0.5,<0.0.6 +pyyaml>=5,<6 +starlette>=0.19,<1 + +# dev only requirements +uvicorn[standard]>=0.17.6,<0.18 +colorama +flake8>=3.9.2<4 +flake8-commas +flake8-comprehensions +flake8-isort +flake8-printf-formatting +isort>=5.8.0,<6 +pre-commit +pyclean +pyflakes +pytest-flake8 +pytest-cov +pytest-sugar diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..80291e8 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,12 @@ +[flake8] +per-file-ignores= + __init__.py:F401 + conftest.py:E402 + app.py:E402 +ignore=E731,E501,W503,C901,C408,E266 +max-complexity = 15 +exclude=conftest.py,scripts/,.venv/ + +[isort] +profile=hug +line_length=100 diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..7e2f2e4 --- /dev/null +++ b/setup.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python + +import os +import re + +from setuptools import find_packages, setup + + +def get_version(package): + """ + Return package version as listed in `__version__` in `init.py`. + """ + with open(os.path.join(package, "__init__.py")) as f: + return re.search("__version__ = ['\"]([^'\"]+)['\"]", f.read()).group(1) + + +setup( + name='StarMallow', + version=get_version("starmallow"), + description='TechLock Query Service', + author='Michiel Vanderlee', + author_email='jmt.vanderlee@gmail.com', + license="MIT", + packages=find_packages(exclude=('tests', 'docs')), + include_package_data=True, + install_requires=[ + 'apispec[marshmallow]>=5.1,<6', + "marshmallow>=3.11.0,<4", + "marshmallow-dataclass>=8.3.1,<9", + "python-multipart >=0.0.5,<0.0.6", + "pyyaml>=5,<6", + "starlette>=0.19,<1", + ], + classifiers=[ + "Development Status :: 3 - Alpha", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Topic :: Internet :: WWW/HTTP", + "Typing :: Typed", + "Framework :: AnyIO", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + ], +) diff --git a/starmallow/__init__.py b/starmallow/__init__.py new file mode 100644 index 0000000..3dc1f76 --- /dev/null +++ b/starmallow/__init__.py @@ -0,0 +1 @@ +__version__ = "0.1.0" diff --git a/starmallow/applications.py b/starmallow/applications.py new file mode 100644 index 0000000..73091d0 --- /dev/null +++ b/starmallow/applications.py @@ -0,0 +1,362 @@ +from typing import ( + Any, + AsyncContextManager, + Awaitable, + Callable, + Dict, + List, + Mapping, + Optional, + Sequence, + Set, + Type, + Union, +) + +from starlette.applications import Starlette +from starlette.datastructures import State +from starlette.exceptions import HTTPException +from starlette.middleware import Middleware +from starlette.requests import Request +from starlette.responses import JSONResponse, Response +from starlette.routing import BaseRoute + +from starmallow.exception_handlers import ( + http_exception_handler, + request_validation_exception_handler, +) +from starmallow.exceptions import RequestValidationError +from starmallow.routing import APIRoute, APIRouter +from starmallow.types import DecoratedCallable +from starmallow.utils import generate_unique_id + + +class StarMallow(Starlette): + + def __init__( + self, + *args, + debug: bool = False, + routes: Optional[List[BaseRoute]] = None, + middleware: Sequence[Middleware] = None, + exception_handlers: Mapping[ + Any, + Callable[ + [Request, Exception], Union[Response, Awaitable[Response]] + ], + ] = None, + on_startup: Optional[Sequence[Callable[[], Any]]] = None, + on_shutdown: Optional[Sequence[Callable[[], Any]]] = None, + lifespan: Callable[[Starlette], AsyncContextManager] = None, + **kwargs, + ) -> None: + # The lifespan context function is a newer style that replaces + # on_startup / on_shutdown handlers. Use one or the other, not both. + assert lifespan is None or ( + on_startup is None and on_shutdown is None + ), "Use either 'lifespan' or 'on_startup'/'on_shutdown', not both." + + self._debug = debug + self.state = State() + self.router: APIRouter = APIRouter( + routes=routes, + on_startup=on_startup, + on_shutdown=on_shutdown, + lifespan=lifespan, + ) + self.exception_handlers = ( + {} if exception_handlers is None else dict(exception_handlers) + ) + self.exception_handlers.setdefault(HTTPException, http_exception_handler) + self.exception_handlers.setdefault( + RequestValidationError, request_validation_exception_handler + ) + + self.user_middleware = [] if middleware is None else list(middleware) + self.middleware_stack = self.build_middleware_stack() + + def add_api_route( + self, + path: str, + endpoint: Callable[..., Any], + *, + methods: Optional[Union[Set[str], List[str]]] = None, + name: str = None, + include_in_schema: bool = True, + status_code: Optional[int] = None, + response_model: Optional[Type[Any]] = None, + response_class: Type[Response] = JSONResponse, + # Sets the OpenAPI operationId to be used in your path operation + operation_id: Optional[str] = None, + # If operation_id is None, this function will be used to create one. + generate_unique_id_function: Callable[["APIRoute"], str] = generate_unique_id, + # Will be deeply merged with the automatically generated OpenAPI schema for the path operation. + openapi_extra: Optional[Dict[str, Any]] = None, + ) -> None: + return self.router.add_api_route( + path=path, + endpoint=endpoint, + methods=methods, + name=name, + include_in_schema=include_in_schema, + status_code=status_code, + response_model=response_model, + response_class=response_class, + operation_id=operation_id, + generate_unique_id_function=generate_unique_id_function, + openapi_extra=openapi_extra, + ) + + def api_route( + self, + path: str, + *, + methods: Optional[Union[Set[str], List[str]]] = None, + name: str = None, + include_in_schema: bool = True, + status_code: Optional[int] = None, + response_model: Optional[Type[Any]] = None, + response_class: Type[Response] = JSONResponse, + # Sets the OpenAPI operationId to be used in your path operation + operation_id: Optional[str] = None, + # If operation_id is None, this function will be used to create one. + generate_unique_id_function: Callable[["APIRoute"], str] = generate_unique_id, + # Will be deeply merged with the automatically generated OpenAPI schema for the path operation. + openapi_extra: Optional[Dict[str, Any]] = None, + ) -> Callable[[DecoratedCallable], DecoratedCallable]: + return self.router.api_route( + path=path, + methods=methods, + name=name, + include_in_schema=include_in_schema, + status_code=status_code, + response_model=response_model, + response_class=response_class, + operation_id=operation_id, + generate_unique_id_function=generate_unique_id_function, + openapi_extra=openapi_extra, + ) + + def get( + self, + path: str, + *, + name: str = None, + include_in_schema: bool = True, + status_code: Optional[int] = None, + response_model: Optional[Type[Any]] = None, + response_class: Type[Response] = JSONResponse, + # Sets the OpenAPI operationId to be used in your path operation + operation_id: Optional[str] = None, + # If operation_id is None, this function will be used to create one. + generate_unique_id_function: Callable[[APIRoute], str] = generate_unique_id, + # Will be deeply merged with the automatically generated OpenAPI schema for the path operation. + openapi_extra: Optional[Dict[str, Any]] = None, + ): + return self.router.get( + path, + name=name, + include_in_schema=include_in_schema, + status_code=status_code, + response_model=response_model, + response_class=response_class, + operation_id=operation_id, + generate_unique_id_function=generate_unique_id_function, + openapi_extra=openapi_extra, + ) + + def put( + self, + path: str, + *, + name: str = None, + include_in_schema: bool = True, + status_code: Optional[int] = None, + response_model: Optional[Type[Any]] = None, + response_class: Type[Response] = JSONResponse, + # Sets the OpenAPI operationId to be used in your path operation + operation_id: Optional[str] = None, + # If operation_id is None, this function will be used to create one. + generate_unique_id_function: Callable[[APIRoute], str] = generate_unique_id, + # Will be deeply merged with the automatically generated OpenAPI schema for the path operation. + openapi_extra: Optional[Dict[str, Any]] = None, + ): + return self.router.put( + path, + name=name, + include_in_schema=include_in_schema, + status_code=status_code, + response_model=response_model, + response_class=response_class, + operation_id=operation_id, + generate_unique_id_function=generate_unique_id_function, + openapi_extra=openapi_extra, + ) + + def post( + self, + path: str, + *, + name: str = None, + include_in_schema: bool = True, + status_code: Optional[int] = None, + response_model: Optional[Type[Any]] = None, + response_class: Type[Response] = JSONResponse, + # Sets the OpenAPI operationId to be used in your path operation + operation_id: Optional[str] = None, + # If operation_id is None, this function will be used to create one. + generate_unique_id_function: Callable[[APIRoute], str] = generate_unique_id, + # Will be deeply merged with the automatically generated OpenAPI schema for the path operation. + openapi_extra: Optional[Dict[str, Any]] = None, + ): + return self.router.post( + path, + name=name, + include_in_schema=include_in_schema, + status_code=status_code, + response_model=response_model, + response_class=response_class, + operation_id=operation_id, + generate_unique_id_function=generate_unique_id_function, + openapi_extra=openapi_extra, + ) + + def delete( + self, + path: str, + *, + name: str = None, + include_in_schema: bool = True, + status_code: Optional[int] = None, + response_model: Optional[Type[Any]] = None, + response_class: Type[Response] = JSONResponse, + # Sets the OpenAPI operationId to be used in your path operation + operation_id: Optional[str] = None, + # If operation_id is None, this function will be used to create one. + generate_unique_id_function: Callable[[APIRoute], str] = generate_unique_id, + # Will be deeply merged with the automatically generated OpenAPI schema for the path operation. + openapi_extra: Optional[Dict[str, Any]] = None, + ): + return self.router.delete( + path, + name=name, + include_in_schema=include_in_schema, + status_code=status_code, + response_model=response_model, + response_class=response_class, + operation_id=operation_id, + generate_unique_id_function=generate_unique_id_function, + openapi_extra=openapi_extra, + ) + + def options( + self, + path: str, + *, + name: str = None, + include_in_schema: bool = True, + status_code: Optional[int] = None, + response_model: Optional[Type[Any]] = None, + response_class: Type[Response] = JSONResponse, + # Sets the OpenAPI operationId to be used in your path operation + operation_id: Optional[str] = None, + # If operation_id is None, this function will be used to create one. + generate_unique_id_function: Callable[[APIRoute], str] = generate_unique_id, + # Will be deeply merged with the automatically generated OpenAPI schema for the path operation. + openapi_extra: Optional[Dict[str, Any]] = None, + ): + return self.router.options( + path, + name=name, + include_in_schema=include_in_schema, + status_code=status_code, + response_model=response_model, + response_class=response_class, + operation_id=operation_id, + generate_unique_id_function=generate_unique_id_function, + openapi_extra=openapi_extra, + ) + + def head( + self, + path: str, + *, + name: str = None, + include_in_schema: bool = True, + status_code: Optional[int] = None, + response_model: Optional[Type[Any]] = None, + response_class: Type[Response] = JSONResponse, + # Sets the OpenAPI operationId to be used in your path operation + operation_id: Optional[str] = None, + # If operation_id is None, this function will be used to create one. + generate_unique_id_function: Callable[[APIRoute], str] = generate_unique_id, + # Will be deeply merged with the automatically generated OpenAPI schema for the path operation. + openapi_extra: Optional[Dict[str, Any]] = None, + ): + return self.router.head( + path, + name=name, + include_in_schema=include_in_schema, + status_code=status_code, + response_model=response_model, + response_class=response_class, + operation_id=operation_id, + generate_unique_id_function=generate_unique_id_function, + openapi_extra=openapi_extra, + ) + + def patch( + self, + path: str, + *, + name: str = None, + include_in_schema: bool = True, + status_code: Optional[int] = None, + response_model: Optional[Type[Any]] = None, + response_class: Type[Response] = JSONResponse, + # Sets the OpenAPI operationId to be used in your path operation + operation_id: Optional[str] = None, + # If operation_id is None, this function will be used to create one. + generate_unique_id_function: Callable[[APIRoute], str] = generate_unique_id, + # Will be deeply merged with the automatically generated OpenAPI schema for the path operation. + openapi_extra: Optional[Dict[str, Any]] = None, + ): + return self.router.patch( + path, + name=name, + include_in_schema=include_in_schema, + status_code=status_code, + response_model=response_model, + response_class=response_class, + operation_id=operation_id, + generate_unique_id_function=generate_unique_id_function, + openapi_extra=openapi_extra, + ) + + def trace( + self, + path: str, + *, + name: str = None, + include_in_schema: bool = True, + status_code: Optional[int] = None, + response_model: Optional[Type[Any]] = None, + response_class: Type[Response] = JSONResponse, + # Sets the OpenAPI operationId to be used in your path operation + operation_id: Optional[str] = None, + # If operation_id is None, this function will be used to create one. + generate_unique_id_function: Callable[[APIRoute], str] = generate_unique_id, + # Will be deeply merged with the automatically generated OpenAPI schema for the path operation. + openapi_extra: Optional[Dict[str, Any]] = None, + ): + return self.router.trace( + path, + name=name, + include_in_schema=include_in_schema, + status_code=status_code, + response_model=response_model, + response_class=response_class, + operation_id=operation_id, + generate_unique_id_function=generate_unique_id_function, + openapi_extra=openapi_extra, + ) diff --git a/starmallow/constants.py b/starmallow/constants.py new file mode 100644 index 0000000..fe28ee7 --- /dev/null +++ b/starmallow/constants.py @@ -0,0 +1,2 @@ +METHODS_WITH_BODY = {"GET", "HEAD", "POST", "PUT", "DELETE", "PATCH"} +STATUS_CODES_WITH_NO_BODY = {100, 101, 102, 103, 204, 304} diff --git a/starmallow/exception_handlers.py b/starmallow/exception_handlers.py new file mode 100644 index 0000000..2152d07 --- /dev/null +++ b/starmallow/exception_handlers.py @@ -0,0 +1,30 @@ + +from starlette.exceptions import HTTPException +from starlette.requests import Request +from starlette.responses import JSONResponse + +from starmallow.exceptions import RequestValidationError + + +async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse: + headers = getattr(exc, "headers", None) + if headers: + return JSONResponse( + {"detail": exc.detail}, status_code=exc.status_code, headers=headers + ) + else: + return JSONResponse({"detail": exc.detail}, status_code=exc.status_code) + + +async def request_validation_exception_handler( + request: Request, + exc: RequestValidationError, +) -> JSONResponse: + return JSONResponse( + status_code=exc.status_code, + content={ + "detail": exc.errors, + "error": "ValidationError", + "status_code": exc.status_code, + }, + ) diff --git a/starmallow/exceptions.py b/starmallow/exceptions.py new file mode 100644 index 0000000..217f080 --- /dev/null +++ b/starmallow/exceptions.py @@ -0,0 +1,13 @@ +from typing import Any, Dict, List, Union + +from starlette.exceptions import HTTPException +from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY + + +class RequestValidationError(HTTPException): + def __init__( + self, + errors: Dict[str, Union[Any, List, Dict]], + ) -> None: + super().__init__(status_code=HTTP_422_UNPROCESSABLE_ENTITY) + self.errors = errors diff --git a/starmallow/params.py b/starmallow/params.py new file mode 100644 index 0000000..a030725 --- /dev/null +++ b/starmallow/params.py @@ -0,0 +1,54 @@ +from enum import Enum +from typing import Any, Optional, Union + +import marshmallow as ma +import marshmallow.fields as mf + + +class ParamType(Enum): + path = 'path' + query = 'query' + header = 'header' + cookie = 'cookie' + body = 'body' + form = 'form' + + +class Param: + + def __init__( + self, + default: Any, + *, + deprecated: Optional[bool] = None, + include_in_schema: bool = True, + model: Union[ma.Schema, mf.Field] = None, + ) -> None: + self.default = default + self.deprecated = deprecated + self.include_in_schema = include_in_schema + self.model = model + + +class Path(Param): + in_ = ParamType.path + + +class Query(Param): + in_ = ParamType.query + + +class Header(Param): + in_ = ParamType.header + + +class Cookie(Param): + in_ = ParamType.cookie + + +class Body(Param): + in_ = ParamType.body + + +class Form(Body): + in_ = ParamType.form diff --git a/starmallow/routing.py b/starmallow/routing.py new file mode 100644 index 0000000..9616901 --- /dev/null +++ b/starmallow/routing.py @@ -0,0 +1,764 @@ +import asyncio +import datetime as dt +import inspect +import uuid +from dataclasses import dataclass, field +from decimal import Decimal +from enum import IntEnum +from typing import ( + Any, + Callable, + Coroutine, + Dict, + Iterable, + List, + Mapping, + Optional, + Set, + Tuple, + Type, + Union, +) + +import marshmallow as ma +import marshmallow.fields as mf +from marshmallow.error_store import ErrorStore +from starlette import routing +from starlette.concurrency import run_in_threadpool +from starlette.datastructures import FormData, Headers, QueryParams +from starlette.exceptions import HTTPException +from starlette.requests import Request +from starlette.responses import JSONResponse, Response +from starlette.routing import request_response + +from starmallow.constants import STATUS_CODES_WITH_NO_BODY +from starmallow.exceptions import RequestValidationError +from starmallow.params import Body, Cookie, Form, Header, Param, ParamType, Path, Query +from starmallow.types import DecoratedCallable +from starmallow.utils import ( + Undefined, + generate_unique_id, + get_args, + is_marshmallow_dataclass, + is_marshmallow_field, + is_marshmallow_schema, + is_optional, +) + +PY_TO_MF_MAPPING = { + int: mf.Integer, + float: mf.Float, + bool: mf.Boolean, + str: mf.String, + Decimal: mf.Decimal, + dt.date: mf.Date, + dt.datetime: mf.DateTime, + dt.time: mf.Time, + dt.timedelta: mf.TimeDelta, + uuid.UUID: mf.UUID, +} + + +class SchemaModel(ma.Schema): + def __init__( + self, + schema: ma.Schema, + missing: Any = Undefined, + required: bool = True, + ) -> None: + self.schema = schema + self.missing = missing + self.required = required + + def load( + self, + data: Union[ + Mapping[str, Any], + Iterable[Mapping[str, Any]], + ], + *, + many: Optional[bool] = None, + partial: Optional[bool] = None, + unknown: Optional[str] = None, + ) -> Any: + try: + result = self.schema.load(data, many=many, partial=partial, unknown=unknown) + except Exception as e: + raise e + + return result + + +@dataclass +class EndpointModel: + path_params: Optional[Dict[str, Path]] = field(default_factory=list) + query_params: Optional[Dict[str, Query]] = field(default_factory=list) + header_params: Optional[Dict[str, Header]] = field(default_factory=list) + cookie_params: Optional[Dict[str, Cookie]] = field(default_factory=list) + body_params: Optional[Dict[str, Body]] = field(default_factory=list) + form_params: Optional[Dict[str, Form]] = field(default_factory=list) + name: Optional[str] = None + path: Optional[str] = None + call: Optional[Callable[..., Any]] = None + response_model: Optional[ma.Schema] = None + response_class: Type[Response] = JSONResponse + status_code: Optional[int] = None + + +async def get_body( + request: Request, + endpoint_model: "EndpointModel", +) -> Union[FormData, bytes, Dict[str, Any]]: + is_body_form = bool(endpoint_model.form_params) + should_process_body = is_body_form or endpoint_model.body_params + try: + body: Any = None + if should_process_body: + if is_body_form: + body = await request.form() + else: + body_bytes = await request.body() + if body_bytes: + json_body: Any = Undefined + content_type_value: str = request.headers.get("content-type") + if not content_type_value: + json_body = await request.json() + else: + main_type, sub_type = content_type_value.split('/') + if main_type == "application": + if sub_type == "json" or sub_type.endswith("+json"): + json_body = await request.json() + if json_body != Undefined: + body = json_body + else: + body = body_bytes + + return body + except Exception as e: + raise HTTPException( + status_code=400, detail="There was an error parsing the body" + ) from e + + +def request_params_to_args( + received_params: Union[Mapping[str, Any], QueryParams, Headers], + endpoint_params: Dict[str, Param], +) -> Tuple[Dict[str, Any], ErrorStore]: + values = {} + error_store = ErrorStore() + for field_name, param in endpoint_params.items(): + if isinstance(param.model, mf.Field): + try: + # Load model from specific param + values[field_name] = param.model.deserialize( + received_params.get(field_name, ma.missing), + field_name, + received_params, + ) + except ma.ValidationError as error: + error_store.store_error(error.messages, field_name) + elif isinstance(param.model, ma.Schema): + try: + # Load model from entire params + values[field_name] = param.model.load(received_params, unknown=ma.EXCLUDE) + except ma.ValidationError as error: + error_store.store_error(error.messages) + else: + raise Exception(f'Invalid model type {type(param.model)}, expected marshmallow Schema or Field') + + return values, error_store + + +async def get_request_args( + request: Request, + endpoint_model: EndpointModel, +) -> Tuple[Dict[str, Any], Dict[str, Union[Any, List, Dict]]]: + path_values, path_errors = request_params_to_args( + request.path_params, + endpoint_model.path_params, + ) + query_values, query_errors = request_params_to_args( + request.query_params, + endpoint_model.query_params, + ) + header_values, header_errors = request_params_to_args( + request.headers, + endpoint_model.header_params, + ) + cookie_values, cookie_errors = request_params_to_args( + request.cookies, + endpoint_model.cookie_params, + ) + + body = await get_body(request, endpoint_model) + form_values, form_errors = {}, None + json_values, json_errors = {}, None + if endpoint_model.form_params: + form_values, form_errors = request_params_to_args( + body if body is not None and isinstance(body, FormData) else {}, + endpoint_model.form_params, + ) + if endpoint_model.body_params: + json_values, json_errors = request_params_to_args( + body if body is not None and isinstance(body, Mapping) else {}, + endpoint_model.body_params, + ) + + values = { + **path_values, + **query_values, + **header_values, + **cookie_values, + **form_values, + **json_values, + } + errors = {} + if path_errors.errors: + errors['path'] = path_errors.errors + if query_errors.errors: + errors['query'] = query_errors.errors + if header_errors.errors: + errors['header'] = header_errors.errors + if cookie_errors.errors: + errors['cookie'] = cookie_errors.errors + if form_errors and form_errors.errors: + errors['form'] = form_errors.errors + if json_errors and json_errors.errors: + errors['json'] = json_errors.errors + + return values, errors + + +async def run_endpoint_function( + endpoint_model: EndpointModel, + values: Dict[str, Any], +) -> Any: + assert endpoint_model.call is not None, "endpoint_model.call must be a function" + + if asyncio.iscoroutinefunction(endpoint_model.call): + return await endpoint_model.call(**values) + else: + return await run_in_threadpool(endpoint_model.call, **values) + + +def get_request_handler( + endpoint_model: EndpointModel, +) -> Callable[[Request], Coroutine[Any, Any, Response]]: + assert endpoint_model.call is not None, "dependant.call must be a function" + + async def app(request: Request) -> Response: + values, errors = await get_request_args(request, endpoint_model) + + if errors: + raise RequestValidationError(errors) + + raw_response = await run_endpoint_function( + endpoint_model, + values, + ) + if isinstance(raw_response, Response): + return raw_response + + response_data = None + if endpoint_model.response_model is not None: + response_data = endpoint_model.response_model.dump(raw_response) + response_args = {} + if endpoint_model.status_code is not None: + response_args["status_code"] = endpoint_model.status_code + + response = endpoint_model.response_class(response_data, **response_args) + + return response + + return app + + +class EndpointMixin: + + def _get_param_model(self, parameter: inspect.Parameter) -> Union[ma.Schema, mf.Field]: + model = parameter.annotation + + kwargs = { + 'required': True, + } + if is_optional(parameter.annotation): + kwargs = { + 'missing': None, + 'required': False, + } + # This does not support Union[A,B,C,None]. Only Union[A,None] and Optional[A] + model = next((a for a in get_args(parameter.annotation) if a is not None), None) + + if isinstance(parameter.default, Param): + # If default is not Ellipsis, then it's optional regardless of the typehint. + # Although it's best practice to also mark the typehint as Optional + if parameter.default.default != Ellipsis: + kwargs = { + 'missing': parameter.default.default, + 'required': False, + } + + # Ignore type hint. Use provided model instead. + if parameter.default.model is not None: + model = parameter.default.model + + if is_marshmallow_dataclass(model): + model = model.Schema + + if is_marshmallow_schema(model): + # # Wrap in Nested so that we can apply required and missing + # return mf.Nested(model(), **kwargs) + # TODO: Handle default value? + # return model() + return SchemaModel(model(), **kwargs) + elif is_marshmallow_field(model): + return model(**kwargs) + elif model in PY_TO_MF_MAPPING: + return PY_TO_MF_MAPPING[model](**kwargs) + else: + raise Exception(f'Unknown model type for parameter {parameter.name}, model is {model}') + + def _get_params_from_endpoint( + self, + endpoint: Callable[..., Any], + ) -> Dict[ParamType, List[Dict[str, Param]]]: + params = {param_type: {} for param_type in ParamType} + for name, parameter in inspect.signature(endpoint).parameters.items(): + model = self._get_param_model(parameter) + + if isinstance(parameter.default, Param): + # Create new field_info with processed model + field_info = parameter.default.__class__( + parameter.default.default, + deprecated=parameter.default.deprecated, + include_in_schema=parameter.default.include_in_schema, + model=model, + ) + elif isinstance(model, mf.Field): + # If marshmallow field with now FieldInfo, default to QueryParameter + field_info = Query( + # If a default was provided, honor it. + ... if parameter.default == inspect._empty else parameter.default, + deprecated=False, + include_in_schema=True, + model=model, + ) + else: + field_info = Body(..., deprecated=False, include_in_schema=True, model=model) + + params[field_info.in_][name] = field_info + + return params + + def get_endpoint_model( + self, + path: str, + endpoint: Callable[..., Any], + name: Optional[str] = None, + + status_code: Optional[int] = None, + response_model: Optional[ma.Schema] = None, + response_class: Type[Response] = JSONResponse, + ) -> EndpointModel: + params = self._get_params_from_endpoint(endpoint) + + response_model = response_model or inspect.signature(endpoint).return_annotation + if is_marshmallow_dataclass(response_model): + response_model = response_model.Schema + if is_marshmallow_schema(response_model): + response_model = response_model() + else: + response_model = None + + return EndpointModel( + path=path, + name=name, + call=endpoint, + path_params=params[ParamType.path], + query_params=params[ParamType.query], + header_params=params[ParamType.header], + cookie_params=params[ParamType.cookie], + body_params=params[ParamType.body], + form_params=params[ParamType.form], + response_model=response_model, + response_class=response_class, + status_code=status_code, + ) + + +class APIRoute(routing.Route, EndpointMixin): + + def __init__( + self, + path: str, + endpoint: Callable[..., Any], + *, + name: Optional[str] = None, + methods: Optional[Union[Set[str], List[str]]] = None, + include_in_schema: bool = True, + + description: Optional[str] = None, + + status_code: Optional[int] = None, + response_model: Optional[ma.Schema] = None, + response_class: Type[Response] = JSONResponse, + # Sets the OpenAPI operationId to be used in your path operation + operation_id: Optional[str] = None, + # If operation_id is None, this function will be used to create one. + generate_unique_id_function: Callable[["APIRoute"], str] = generate_unique_id, + # Will be deeply merged with the automatically generated OpenAPI schema for the path operation. + openapi_extra: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__( + path, + endpoint, + methods=methods, + name=name, + include_in_schema=include_in_schema, + ) + assert callable(endpoint), "An endpoint must be a callable" + + self.status_code = status_code + self.response_model = response_model + self.response_class = response_class + self.operation_id = operation_id + self.generate_unique_id_function = generate_unique_id_function + self.openapi_extra = openapi_extra + + self.unique_id = self.operation_id or generate_unique_id_function(self) + + # normalize enums e.g. http.HTTPStatus + if isinstance(status_code, IntEnum): + status_code = int(status_code) + self.status_code = status_code + + self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "") + # if a "form feed" character (page break) is found in the description text, + # truncate description text to the content preceding the first "form feed" + self.description = self.description.split("\f")[0] + + if self.response_model: + assert ( + status_code not in STATUS_CODES_WITH_NO_BODY + ), f"Status code {status_code} must not have a response body" + + endpoint_model = self.get_endpoint_model( + path, + endpoint, + name=name, + status_code=status_code, + response_model=response_model, + response_class=response_class, + ) + + self.app = request_response(get_request_handler(endpoint_model)) + + +class APIRouter(routing.Router): + + def __init__( + self, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + + def add_api_route( + self, + path: str, + endpoint: Callable[..., Any], + *, + methods: Optional[Union[Set[str], List[str]]] = None, + name: str = None, + include_in_schema: bool = True, + status_code: Optional[int] = None, + response_model: Optional[Type[Any]] = None, + response_class: Type[Response] = JSONResponse, + # Sets the OpenAPI operationId to be used in your path operation + operation_id: Optional[str] = None, + # If operation_id is None, this function will be used to create one. + generate_unique_id_function: Callable[["APIRoute"], str] = generate_unique_id, + # Will be deeply merged with the automatically generated OpenAPI schema for the path operation. + openapi_extra: Optional[Dict[str, Any]] = None, + ) -> None: + route = APIRoute( + path, + endpoint, + methods=methods, + name=name, + include_in_schema=include_in_schema, + status_code=status_code, + response_model=response_model, + response_class=response_class, + operation_id=operation_id, + generate_unique_id_function=generate_unique_id_function, + openapi_extra=openapi_extra, + ) + + self.routes.append(route) + + def api_route( + self, + path: str, + *, + methods: Optional[Union[Set[str], List[str]]] = None, + name: str = None, + include_in_schema: bool = True, + status_code: Optional[int] = None, + response_model: Optional[Type[Any]] = None, + response_class: Type[Response] = JSONResponse, + # Sets the OpenAPI operationId to be used in your path operation + operation_id: Optional[str] = None, + # If operation_id is None, this function will be used to create one. + generate_unique_id_function: Callable[["APIRoute"], str] = generate_unique_id, + # Will be deeply merged with the automatically generated OpenAPI schema for the path operation. + openapi_extra: Optional[Dict[str, Any]] = None, + ) -> Callable[[DecoratedCallable], DecoratedCallable]: + def decorator(func: DecoratedCallable) -> DecoratedCallable: + self.add_api_route( + path, + func, + methods=methods, + name=name, + include_in_schema=include_in_schema, + status_code=status_code, + response_model=response_model, + response_class=response_class, + operation_id=operation_id, + generate_unique_id_function=generate_unique_id_function, + openapi_extra=openapi_extra, + ) + return func + return decorator + + def get( + self, + path: str, + *, + name: str = None, + include_in_schema: bool = True, + status_code: Optional[int] = None, + response_model: Optional[Type[Any]] = None, + response_class: Type[Response] = JSONResponse, + # Sets the OpenAPI operationId to be used in your path operation + operation_id: Optional[str] = None, + # If operation_id is None, this function will be used to create one. + generate_unique_id_function: Callable[["APIRoute"], str] = generate_unique_id, + # Will be deeply merged with the automatically generated OpenAPI schema for the path operation. + openapi_extra: Optional[Dict[str, Any]] = None, + ): + return self.api_route( + path, + methods=['GET'], + name=name, + include_in_schema=include_in_schema, + status_code=status_code, + response_model=response_model, + response_class=response_class, + operation_id=operation_id, + generate_unique_id_function=generate_unique_id_function, + openapi_extra=openapi_extra, + ) + + def put( + self, + path: str, + *, + name: str = None, + include_in_schema: bool = True, + status_code: Optional[int] = None, + response_model: Optional[Type[Any]] = None, + response_class: Type[Response] = JSONResponse, + # Sets the OpenAPI operationId to be used in your path operation + operation_id: Optional[str] = None, + # If operation_id is None, this function will be used to create one. + generate_unique_id_function: Callable[["APIRoute"], str] = generate_unique_id, + # Will be deeply merged with the automatically generated OpenAPI schema for the path operation. + openapi_extra: Optional[Dict[str, Any]] = None, + ): + return self.api_route( + path, + methods=['PUT'], + name=name, + include_in_schema=include_in_schema, + status_code=status_code, + response_model=response_model, + response_class=response_class, + operation_id=operation_id, + generate_unique_id_function=generate_unique_id_function, + openapi_extra=openapi_extra, + ) + + def post( + self, + path: str, + *, + name: str = None, + include_in_schema: bool = True, + status_code: Optional[int] = None, + response_model: Optional[Type[Any]] = None, + response_class: Type[Response] = JSONResponse, + # Sets the OpenAPI operationId to be used in your path operation + operation_id: Optional[str] = None, + # If operation_id is None, this function will be used to create one. + generate_unique_id_function: Callable[["APIRoute"], str] = generate_unique_id, + # Will be deeply merged with the automatically generated OpenAPI schema for the path operation. + openapi_extra: Optional[Dict[str, Any]] = None, + ): + return self.api_route( + path, + methods=['POST'], + name=name, + include_in_schema=include_in_schema, + status_code=status_code, + response_model=response_model, + response_class=response_class, + operation_id=operation_id, + generate_unique_id_function=generate_unique_id_function, + openapi_extra=openapi_extra, + ) + + def delete( + self, + path: str, + *, + name: str = None, + include_in_schema: bool = True, + status_code: Optional[int] = None, + response_model: Optional[Type[Any]] = None, + response_class: Type[Response] = JSONResponse, + # Sets the OpenAPI operationId to be used in your path operation + operation_id: Optional[str] = None, + # If operation_id is None, this function will be used to create one. + generate_unique_id_function: Callable[["APIRoute"], str] = generate_unique_id, + # Will be deeply merged with the automatically generated OpenAPI schema for the path operation. + openapi_extra: Optional[Dict[str, Any]] = None, + ): + return self.api_route( + path, + methods=['DELETE'], + name=name, + include_in_schema=include_in_schema, + status_code=status_code, + response_model=response_model, + response_class=response_class, + operation_id=operation_id, + generate_unique_id_function=generate_unique_id_function, + openapi_extra=openapi_extra, + ) + + def options( + self, + path: str, + *, + name: str = None, + include_in_schema: bool = True, + status_code: Optional[int] = None, + response_model: Optional[Type[Any]] = None, + response_class: Type[Response] = JSONResponse, + # Sets the OpenAPI operationId to be used in your path operation + operation_id: Optional[str] = None, + # If operation_id is None, this function will be used to create one. + generate_unique_id_function: Callable[["APIRoute"], str] = generate_unique_id, + # Will be deeply merged with the automatically generated OpenAPI schema for the path operation. + openapi_extra: Optional[Dict[str, Any]] = None, + ): + return self.api_route( + path, + methods=['OPTIONS'], + name=name, + include_in_schema=include_in_schema, + status_code=status_code, + response_model=response_model, + response_class=response_class, + operation_id=operation_id, + generate_unique_id_function=generate_unique_id_function, + openapi_extra=openapi_extra, + ) + + def head( + self, + path: str, + *, + name: str = None, + include_in_schema: bool = True, + status_code: Optional[int] = None, + response_model: Optional[Type[Any]] = None, + response_class: Type[Response] = JSONResponse, + # Sets the OpenAPI operationId to be used in your path operation + operation_id: Optional[str] = None, + # If operation_id is None, this function will be used to create one. + generate_unique_id_function: Callable[["APIRoute"], str] = generate_unique_id, + # Will be deeply merged with the automatically generated OpenAPI schema for the path operation. + openapi_extra: Optional[Dict[str, Any]] = None, + ): + return self.api_route( + path, + methods=['HEAD'], + name=name, + include_in_schema=include_in_schema, + status_code=status_code, + response_model=response_model, + response_class=response_class, + operation_id=operation_id, + generate_unique_id_function=generate_unique_id_function, + openapi_extra=openapi_extra, + ) + + def patch( + self, + path: str, + *, + name: str = None, + include_in_schema: bool = True, + status_code: Optional[int] = None, + response_model: Optional[Type[Any]] = None, + response_class: Type[Response] = JSONResponse, + # Sets the OpenAPI operationId to be used in your path operation + operation_id: Optional[str] = None, + # If operation_id is None, this function will be used to create one. + generate_unique_id_function: Callable[["APIRoute"], str] = generate_unique_id, + # Will be deeply merged with the automatically generated OpenAPI schema for the path operation. + openapi_extra: Optional[Dict[str, Any]] = None, + ): + return self.api_route( + path, + methods=['PATCH'], + name=name, + include_in_schema=include_in_schema, + status_code=status_code, + response_model=response_model, + response_class=response_class, + operation_id=operation_id, + generate_unique_id_function=generate_unique_id_function, + openapi_extra=openapi_extra, + ) + + def trace( + self, + path: str, + *, + name: str = None, + include_in_schema: bool = True, + status_code: Optional[int] = None, + response_model: Optional[Type[Any]] = None, + response_class: Type[Response] = JSONResponse, + # Sets the OpenAPI operationId to be used in your path operation + operation_id: Optional[str] = None, + # If operation_id is None, this function will be used to create one. + generate_unique_id_function: Callable[["APIRoute"], str] = generate_unique_id, + # Will be deeply merged with the automatically generated OpenAPI schema for the path operation. + openapi_extra: Optional[Dict[str, Any]] = None, + ): + return self.api_route( + path, + methods=['TRACE'], + name=name, + include_in_schema=include_in_schema, + status_code=status_code, + response_model=response_model, + response_class=response_class, + operation_id=operation_id, + generate_unique_id_function=generate_unique_id_function, + openapi_extra=openapi_extra, + ) diff --git a/starmallow/types.py b/starmallow/types.py new file mode 100644 index 0000000..e0bca46 --- /dev/null +++ b/starmallow/types.py @@ -0,0 +1,3 @@ +from typing import Any, Callable, TypeVar + +DecoratedCallable = TypeVar("DecoratedCallable", bound=Callable[..., Any]) diff --git a/starmallow/utils.py b/starmallow/utils.py new file mode 100644 index 0000000..09faca0 --- /dev/null +++ b/starmallow/utils.py @@ -0,0 +1,63 @@ +import inspect +import re +import warnings +from dataclasses import is_dataclass +from typing import TYPE_CHECKING, Generic, Union + +import marshmallow as ma +import marshmallow.fields as mf + +if TYPE_CHECKING: # pragma: nocover + from starmallow.routing import APIRoute + +# Python >= 3.8 - Source: https://stackoverflow.com/a/58841311/3776765 +try: + from typing import get_args, get_origin +# Compatibility +except ImportError: + get_args = lambda t: getattr(t, '__args__', ()) if t is not Generic else Generic + get_origin = lambda t: getattr(t, '__origin__', None) + + +class Undefined: + '''Allows us to check if something is undefined vs None''' + pass + + +def is_optional(field): + return get_origin(field) is Union and type(None) in get_args(field) + + +def generate_operation_id_for_path( + *, name: str, path: str, method: str +) -> str: # pragma: nocover + warnings.warn( + "fastapi.utils.generate_operation_id_for_path() was deprecated, " + "it is not used internally, and will be removed soon", + DeprecationWarning, + stacklevel=2, + ) + operation_id = name + path + operation_id = re.sub("[^0-9a-zA-Z_]", "_", operation_id) + operation_id = operation_id + "_" + method.lower() + return operation_id + + +def generate_unique_id(route: "APIRoute") -> str: + operation_id = route.name + route.path_format + operation_id = re.sub("[^0-9a-zA-Z_]", "_", operation_id) + assert route.methods + operation_id = operation_id + "_" + list(route.methods)[0].lower() + return operation_id + + +def is_marshmallow_schema(obj): + return (inspect.isclass(obj) and issubclass(obj, ma.Schema)) or isinstance(obj, ma.Schema) + + +def is_marshmallow_field(obj): + return (inspect.isclass(obj) and issubclass(obj, mf.Field)) or isinstance(obj, mf.Field) + + +def is_marshmallow_dataclass(obj): + return is_dataclass(obj) and hasattr(obj, 'Schema') and is_marshmallow_schema(obj.Schema)