diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c69c6e7..91d002e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -35,6 +35,10 @@ jobs: pip install -r requirements.txt pip install "Django~=${{ matrix.django-version }}.0" . + - name: Run mypy + run: | + python -m mypy dj_database_url.py + - name: Run Tests run: | echo "$(python --version) / Django $(django-admin --version)" diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..13c250a --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +include py.typed diff --git a/dj_database_url.py b/dj_database_url.py index ac194a4..8a5e141 100644 --- a/dj_database_url.py +++ b/dj_database_url.py @@ -1,5 +1,8 @@ import os import urllib.parse as urlparse +from typing import Any, Dict, Optional, Union + +from typing_extensions import TypedDict # Register database schemes in URLs. urlparse.uses_netloc.append("postgres") @@ -45,15 +48,33 @@ } +# From https://docs.djangoproject.com/en/4.0/ref/settings/#databases +class DBConfig(TypedDict, total=False): + ATOMIC_REQUESTS: bool + AUTOCOMMIT: bool + CONN_MAX_AGE: int + CONN_HEALTH_CHECKS: bool + DISABLE_SERVER_SIDE_CURSORS: bool + ENGINE: str + HOST: str + NAME: str + OPTIONS: Optional[Dict[str, Any]] + PASSWORD: str + PORT: Union[str, int] + TEST: Dict[str, Any] + TIME_ZONE: str + USER: str + + def config( - env=DEFAULT_ENV, - default=None, - engine=None, - conn_max_age=0, - conn_health_checks=False, - ssl_require=False, - test_options=None, -): + env: str = DEFAULT_ENV, + default: Optional[str] = None, + engine: Optional[str] = None, + conn_max_age: int = 0, + conn_health_checks: bool = False, + ssl_require: bool = False, + test_options: Optional[Dict] = None, +) -> DBConfig: """Returns configured DATABASE dictionary from DATABASE_URL.""" s = os.environ.get(env, default) @@ -66,13 +87,13 @@ def config( def parse( - url, - engine=None, - conn_max_age=0, - conn_health_checks=False, - ssl_require=False, - test_options=None, -): + url: str, + engine: Optional[str] = None, + conn_max_age: int = 0, + conn_health_checks: bool = False, + ssl_require: bool = False, + test_options: Optional[dict] = None, +) -> DBConfig: """Parses a database URL.""" if url == "sqlite://:memory:": # this is a special case, because if we pass this URL into @@ -82,31 +103,31 @@ def parse( # note: no other settings are required for sqlite # otherwise parse the url as normal - parsed_config = {} + parsed_config: DBConfig = {} if test_options is None: test_options = {} - url = urlparse.urlsplit(url) + spliturl = urlparse.urlsplit(url) # Split query strings from path. - path = url.path[1:] - if "?" in path and not url.query: - path, query = path.split("?", 2) + path = spliturl.path[1:] + if "?" in path and not spliturl.query: + path, raw_query = path.split("?", 2) else: - path, query = path, url.query - query = urlparse.parse_qs(query) + path, raw_query = path, spliturl.query + query = urlparse.parse_qs(raw_query) # If we are using sqlite and we have no path, then assume we # want an in-memory database (this is the behaviour of sqlalchemy) - if url.scheme == "sqlite" and path == "": + if spliturl.scheme == "sqlite" and path == "": path = ":memory:" # Handle postgres percent-encoded paths. - hostname = url.hostname or "" + hostname = spliturl.hostname or "" if "%" in hostname: # Switch to url.netloc to avoid lower cased paths - hostname = url.netloc + hostname = spliturl.netloc if "@" in hostname: hostname = hostname.rsplit("@", 1)[1] if ":" in hostname: @@ -116,26 +137,26 @@ def parse( # Lookup specified engine. if engine is None: - engine = SCHEMES.get(url.scheme) + engine = SCHEMES.get(spliturl.scheme) if engine is None: raise ValueError( "No support for '%s'. We support: %s" - % (url.scheme, ", ".join(sorted(SCHEMES.keys()))) + % (spliturl.scheme, ", ".join(sorted(SCHEMES.keys()))) ) port = ( - str(url.port) - if url.port + str(spliturl.port) + if spliturl.port and engine in (SCHEMES["oracle"], SCHEMES["mssql"], SCHEMES["mssqlms"]) - else url.port + else spliturl.port ) # Update with environment configuration. parsed_config.update( { "NAME": urlparse.unquote(path or ""), - "USER": urlparse.unquote(url.username or ""), - "PASSWORD": urlparse.unquote(url.password or ""), + "USER": urlparse.unquote(spliturl.username or ""), + "PASSWORD": urlparse.unquote(spliturl.password or ""), "HOST": hostname, "PORT": port or "", "CONN_MAX_AGE": conn_max_age, @@ -150,9 +171,9 @@ def parse( ) # Pass the query string into OPTIONS. - options = {} + options: Dict[str, Any] = {} for key, values in query.items(): - if url.scheme == "mysql" and key == "ssl-ca": + if spliturl.scheme == "mysql" and key == "ssl-ca": options["ssl"] = {"ca": values[-1]} continue diff --git a/py.typed b/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/pyproject.toml b/pyproject.toml index 08ca2f5..721375d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,2 +1,8 @@ [tool.black] skip-string-normalization = 1 + +[tool.mypy] +show_error_codes=true +disallow_untyped_defs=true +disallow_untyped_calls=true +warn_redundant_casts=true diff --git a/requirements.txt b/requirements.txt index 4ebc8ae..3c036d6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ coverage +mypy diff --git a/setup.py b/setup.py index e787d12..9a18a87 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,7 @@ long_description=readme, long_description_content_type="text/x-rst", py_modules=["dj_database_url"], - install_requires=["Django>=3.2"], + install_requires=["Django>=3.2", "typing_extensions >= 3.10.0.0"], zip_safe=False, include_package_data=True, platforms="any",