diff --git a/README.rst b/README.rst index 2938bcf..019e7dd 100644 --- a/README.rst +++ b/README.rst @@ -22,12 +22,6 @@ also a `conn_max_age` argument to easily enable Django's connection pool. If you'd rather not use an environment variable, you can pass a URL in directly instead to ``dj_database_url.parse``. -Supported Databases -------------------- - -Support currently exists for PostgreSQL, PostGIS, MySQL, MySQL (GIS), -Oracle, Oracle (GIS), Redshift, CockroachDB, Timescale, Timescale (GIS) and SQLite. - Installation ------------ @@ -148,6 +142,63 @@ and should instead be passed as: DATABASES['default'] = dj_database_url.config(default='postgres://...', test_options={'NAME': 'mytestdatabase'}) +Supported Databases +------------------- + +Support currently exists for PostgreSQL, PostGIS, MySQL, MySQL (GIS), +Oracle, Oracle (GIS), Redshift, CockroachDB, Timescale, Timescale (GIS) and SQLite. + +If you want to use +some non-default backends, you need to register them first: + +.. code-block:: python + + import dj_database_url + + # registration should be performed only once + dj_database_url.register("mysql-connector", "mysql.connector.django") + + assert dj_database_url.parse("mysql-connector://user:password@host:3306/db-name") == { + "ENGINE": "mysql.connector.django", + # ...other connection params + } + +Some backends need further config adjustments (e.g. oracle and mssql +expect ``PORT`` to be a string). For such cases you can provide a +post-processing function to ``register()`` (note that ``register()`` is +used as a **decorator(!)** in this case): + +.. 
code-block:: python + + import dj_database_url + + @dj_database_url.register("mssql", "sql_server.pyodbc") + def stringify_port(config): + config["PORT"] = str(config["PORT"]) + + @dj_database_url.register("redshift", "django_redshift_backend") + def apply_current_schema(config): + options = config["OPTIONS"] + schema = options.pop("currentSchema", None) + if schema: + options["options"] = f"-c search_path={schema}" + + @dj_database_url.register("snowflake", "django_snowflake") + def adjust_snowflake_config(config): + config.pop("PORT", None) + config["ACCOUNT"] = config.pop("HOST") + name, _, schema = config["NAME"].partition("/") + if schema: + config["SCHEMA"] = schema + config["NAME"] = name + options = config.get("OPTIONS", {}) + warehouse = options.pop("warehouse", None) + if warehouse: + config["WAREHOUSE"] = warehouse + role = options.pop("role", None) + if role: + config["ROLE"] = role + URL schema ---------- diff --git a/dj_database_url/__init__.py b/dj_database_url/__init__.py index d620a3d..4db4cdc 100644 --- a/dj_database_url/__init__.py +++ b/dj_database_url/__init__.py @@ -1,50 +1,15 @@ import logging import os import urllib.parse as urlparse -from typing import Any, Dict, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union from typing_extensions import TypedDict DEFAULT_ENV = "DATABASE_URL" +ENGINE_SCHEMES: Dict[str, "Engine"] = {} -SCHEMES = { - "postgres": "django.db.backends.postgresql", - "postgresql": "django.db.backends.postgresql", - "pgsql": "django.db.backends.postgresql", - "postgis": "django.contrib.gis.db.backends.postgis", - "mysql": "django.db.backends.mysql", - "mysql2": "django.db.backends.mysql", - "mysqlgis": "django.contrib.gis.db.backends.mysql", - "mysql-connector": "mysql.connector.django", - "mssql": "sql_server.pyodbc", - "mssqlms": "mssql", - "spatialite": "django.contrib.gis.db.backends.spatialite", - "sqlite": "django.db.backends.sqlite3", - "oracle": "django.db.backends.oracle", - "oraclegis": 
"django.contrib.gis.db.backends.oracle", - "redshift": "django_redshift_backend", - "cockroach": "django_cockroachdb", - "timescale": "timescale.db.backends.postgresql", - "timescalegis": "timescale.db.backends.postgis", -} - -SCHEMES_WITH_SEARCH_PATH = [ - "postgres", - "postgresql", - "pgsql", - "postgis", - "redshift", - "timescale", - "timescalegis", -] - -# Register database schemes in URLs. -for key in SCHEMES.keys(): - urlparse.uses_netloc.append(key) -del key - - -# From https://docs.djangoproject.com/en/4.0/ref/settings/#databases + +# From https://docs.djangoproject.com/en/stable/ref/settings/#databases class DBConfig(TypedDict, total=False): ATOMIC_REQUESTS: bool AUTOCOMMIT: bool @@ -62,11 +27,109 @@ class DBConfig(TypedDict, total=False): USER: str +PostprocessCallable = Callable[[DBConfig], None] +OptionType = Union[int, str, bool] + + +class ParseError(ValueError): + def __str__(self) -> str: + return ( + "This string is not a valid url, possibly because some of its parts" + " is not properly urllib.parse.quote()'ed." + ) + + +class UnknownSchemeError(ValueError): + def __init__(self, scheme: str) -> None: + self.scheme = scheme + + def __str__(self) -> str: + schemes = ", ".join(sorted(ENGINE_SCHEMES.keys())) + return ( + f"Scheme '{self.scheme}://' is unknown." + " Did you forget to register custom backend?" + f" Following schemes have registered backends: {schemes}." 
+ ) + + +def default_postprocess(parsed_config: DBConfig) -> None: + pass + + +class Engine: + def __init__( + self, + backend: str, + postprocess: PostprocessCallable = default_postprocess, + ) -> None: + self.backend = backend + self.postprocess = postprocess + + +def register( + scheme: str, backend: str +) -> Callable[[PostprocessCallable], PostprocessCallable]: + engine = Engine(backend) + if scheme not in ENGINE_SCHEMES: + urlparse.uses_netloc.append(scheme) + ENGINE_SCHEMES[scheme] = engine + + def inner(func: PostprocessCallable) -> PostprocessCallable: + engine.postprocess = func + return func + + return inner + + +register("spatialite", "django.contrib.gis.db.backends.spatialite") +register("mysql-connector", "mysql.connector.django") +register("mysqlgis", "django.contrib.gis.db.backends.mysql") +register("oraclegis", "django.contrib.gis.db.backends.oracle") +register("cockroach", "django_cockroachdb") + + +@register("sqlite", "django.db.backends.sqlite3") +def default_to_in_memory_db(parsed_config: DBConfig) -> None: + # mimic sqlalchemy behaviour + if parsed_config["NAME"] == "": + parsed_config["NAME"] = ":memory:" + + +@register("oracle", "django.db.backends.oracle") +@register("mssqlms", "mssql") +@register("mssql", "sql_server.pyodbc") +def stringify_port(parsed_config: DBConfig) -> None: + parsed_config["PORT"] = str(parsed_config["PORT"]) + + +@register("mysql", "django.db.backends.mysql") +@register("mysql2", "django.db.backends.mysql") +def apply_ssl_ca(parsed_config: DBConfig) -> None: + options = parsed_config["OPTIONS"] + ca = options.pop("ssl-ca", None) + if ca: + options["ssl"] = {"ca": ca} + + +@register("postgres", "django.db.backends.postgresql") +@register("postgresql", "django.db.backends.postgresql") +@register("pgsql", "django.db.backends.postgresql") +@register("postgis", "django.contrib.gis.db.backends.postgis") +@register("redshift", "django_redshift_backend") +@register("timescale", "timescale.db.backends.postgresql") 
+@register("timescalegis", "timescale.db.backends.postgis") +def apply_current_schema(parsed_config: DBConfig) -> None: + options = parsed_config["OPTIONS"] + schema = options.pop("currentSchema", None) + if schema: + options["options"] = f"-c search_path={schema}" + + def config( env: str = DEFAULT_ENV, default: Optional[str] = None, engine: Optional[str] = None, - conn_max_age: Optional[int] = 0, + conn_max_age: int = 0, conn_health_checks: bool = False, disable_server_side_cursors: bool = False, ssl_require: bool = False, @@ -77,7 +140,7 @@ def config( if s is None: logging.warning( - "No %s environment variable set, and so no databases setup" % env + "No %s environment variable set, and so no databases setup", env ) if s: @@ -97,107 +160,98 @@ def config( def parse( url: str, engine: Optional[str] = None, - conn_max_age: Optional[int] = 0, + conn_max_age: int = 0, conn_health_checks: bool = False, disable_server_side_cursors: bool = False, ssl_require: bool = False, test_options: Optional[dict] = None, ) -> DBConfig: - """Parses a database URL.""" + """Parses a database URL and returns configured DATABASE dictionary.""" + settings = _convert_to_settings( + engine, + conn_max_age, + conn_health_checks, + disable_server_side_cursors, + ssl_require, + test_options, + ) + if url == "sqlite://:memory:": # this is a special case, because if we pass this URL into # urlparse, urlparse will choke trying to interpret "memory" # as a port number - return {"ENGINE": SCHEMES["sqlite"], "NAME": ":memory:"} + return {"ENGINE": ENGINE_SCHEMES["sqlite"].backend, "NAME": ":memory:"} # note: no other settings are required for sqlite - # otherwise parse the url as normal - parsed_config: DBConfig = {} - - if test_options is None: - test_options = {} - - spliturl = urlparse.urlsplit(url) - - # Split query strings from path.
- path = spliturl.path[1:] - query = urlparse.parse_qs(spliturl.query) - - # If we are using sqlite and we have no path, then assume we - # want an in-memory database (this is the behaviour of sqlalchemy) - if spliturl.scheme == "sqlite" and path == "": - path = ":memory:" - - # Handle postgres percent-encoded paths. - hostname = spliturl.hostname or "" - if "%" in hostname: - # Switch to url.netloc to avoid lower cased paths - hostname = spliturl.netloc - if "@" in hostname: - hostname = hostname.rsplit("@", 1)[1] - # Use URL Parse library to decode % encodes - hostname = urlparse.unquote(hostname) - - # Lookup specified engine. - if engine is None: - engine = SCHEMES.get(spliturl.scheme) - if engine is None: - raise ValueError( - "No support for '%s'. We support: %s" - % (spliturl.scheme, ", ".join(sorted(SCHEMES.keys()))) - ) - - port = ( - str(spliturl.port) - if spliturl.port - and engine in (SCHEMES["oracle"], SCHEMES["mssql"], SCHEMES["mssqlms"]) - else spliturl.port - ) - - # Update with environment configuration.
- parsed_config.update( - { - "NAME": urlparse.unquote(path or ""), - "USER": urlparse.unquote(spliturl.username or ""), - "PASSWORD": urlparse.unquote(spliturl.password or ""), - "HOST": hostname, - "PORT": port or "", - "CONN_MAX_AGE": conn_max_age, - "CONN_HEALTH_CHECKS": conn_health_checks, - "DISABLE_SERVER_SIDE_CURSORS": disable_server_side_cursors, - "ENGINE": engine, + try: + split_result = urlparse.urlsplit(url) + engine_obj = ENGINE_SCHEMES.get(split_result.scheme) + if engine_obj is None: + if not engine: + raise UnknownSchemeError(split_result.scheme) + # an explicit ``engine`` override makes an unregistered scheme acceptable + engine_obj = Engine(engine) + path = split_result.path[1:] + query = urlparse.parse_qs(split_result.query) + options = {k: _parse_option_values(v) for k, v in query.items()} + parsed_config: DBConfig = { + "ENGINE": engine_obj.backend, + "USER": urlparse.unquote(split_result.username or ""), + "PASSWORD": urlparse.unquote(split_result.password or ""), + "HOST": urlparse.unquote(split_result.hostname or ""), + "PORT": split_result.port or "", + "NAME": urlparse.unquote(path), + "OPTIONS": options, } - ) - if test_options: - parsed_config.update( - { - 'TEST': test_options, - } - ) + except UnknownSchemeError: + raise + except ValueError: + raise ParseError() from None - # Pass the query string into OPTIONS. - options: Dict[str, Any] = {} - for key, values in query.items(): - if spliturl.scheme == "mysql" and key == "ssl-ca": - options["ssl"] = {"ca": values[-1]} - continue - - value = values[-1] - if value.isdigit(): - options[key] = int(value) - elif value.lower() in ("true", "false"): - options[key] = value.lower() == "true" - else: - options[key] = value + # Guarantee that config has options, possibly empty, when postprocess() is called + assert isinstance(parsed_config["OPTIONS"], dict) + engine_obj.postprocess(parsed_config) - if ssl_require: - options["sslmode"] = "require" + # Update the final config with any settings passed in explicitly.
+ parsed_config["OPTIONS"].update(settings.pop("OPTIONS", {})) + parsed_config.update(settings) + + if not parsed_config["OPTIONS"]: + parsed_config.pop("OPTIONS") + return parsed_config - # Support for Postgres Schema URLs - if "currentSchema" in options and spliturl.scheme in SCHEMES_WITH_SEARCH_PATH: - options["options"] = "-c search_path={0}".format(options.pop("currentSchema")) - if options: - parsed_config["OPTIONS"] = options +def _parse_option_values(values: List[str]) -> Union[OptionType, List[OptionType]]: + parsed_values = [_parse_value(v) for v in values] + return parsed_values[0] if len(parsed_values) == 1 else parsed_values - return parsed_config + +def _parse_value(value: str) -> OptionType: + if value.isdecimal(): + return int(value) + if value.lower() in ("true", "false"): + return value.lower() == "true" + return value + + +def _convert_to_settings( + engine: Optional[str], + conn_max_age: int, + conn_health_checks: bool, + disable_server_side_cursors: bool, + ssl_require: bool, + test_options: Optional[dict], +) -> DBConfig: + settings: DBConfig = { + "CONN_MAX_AGE": conn_max_age, + "CONN_HEALTH_CHECKS": conn_health_checks, + "DISABLE_SERVER_SIDE_CURSORS": disable_server_side_cursors, + } + if engine: + settings["ENGINE"] = engine + if ssl_require: + settings["OPTIONS"] = {} + settings["OPTIONS"]["sslmode"] = "require" + if test_options: + settings["TEST"] = test_options + return settings diff --git a/tests/test_dj_database_url.py b/tests/test_dj_database_url.py index 6e46e27..4b48bd3 100644 --- a/tests/test_dj_database_url.py +++ b/tests/test_dj_database_url.py @@ -1,6 +1,8 @@ import os +import re import unittest from unittest import mock +from urllib.parse import uses_netloc import dj_database_url @@ -203,6 +205,24 @@ def test_memory_sqlite_url(self): assert url["ENGINE"] == "django.db.backends.sqlite3" assert url["NAME"] == ":memory:" + def test_sqlite_relative_url(self): + url = "sqlite:///db.sqlite3" + config = dj_database_url.parse(url) + +
assert config["ENGINE"] == "django.db.backends.sqlite3" + assert config["NAME"] == "db.sqlite3" + + def test_sqlite_absolute_url(self): + # 4 slashes are needed: + # two are part of scheme + # one separates host:port from path + # and the fourth goes to "NAME" value + url = "sqlite:////db.sqlite3" + config = dj_database_url.parse(url) + + assert config["ENGINE"] == "django.db.backends.sqlite3" + assert config["NAME"] == "/db.sqlite3" + def test_parse_engine_setting(self): engine = "django_mysqlpool.backends.mysqlpool" url = "mysql://bea6eb025ca0d8:69772142@us-cdbr-east.cleardb.com/heroku_97681db3eff7580?reconnect=true" @@ -588,9 +608,36 @@ def test_no_env_variable(self): 'WARNING:root:No DATABASE_URL environment variable set, and so no databases setup' ], cm.output - def test_bad_url_parsing(self): - with self.assertRaisesRegex(ValueError, "No support for 'foo'. We support: "): - dj_database_url.parse("foo://bar") + def test_credentials_unquoted__raise_value_error(self): + expected_message = ( + "This string is not a valid url, possibly because some of its parts " + r"is not properly urllib.parse.quote()'ed." + ) + with self.assertRaisesRegex(ValueError, re.escape(expected_message)): + dj_database_url.parse("postgres://user:passw#ord!@localhost/foobar") + + def test_credentials_quoted__ok(self): + url = "postgres://user%40domain:p%23ssword!@localhost/foobar" + config = dj_database_url.parse(url) + assert config["USER"] == "user@domain" + assert config["PASSWORD"] == "p#ssword!" + + def test_unknown_scheme__raise_value_error(self): + expected_message = ( + "Scheme 'unknown-scheme://' is unknown. " + "Did you forget to register custom backend? 
Following schemes have registered backends:" + ) + with self.assertRaisesRegex(ValueError, re.escape(expected_message)): + dj_database_url.parse("unknown-scheme://user:password@localhost/foobar") + + def test_register_multiple_times__no_duplicates_in_uses_netloc(self): + # make sure that when register() function is misused, + # it won't pollute urllib.parse.uses_netloc list with duplicates. + # Otherwise, it might cause performance issue if some code assumes that + # the list is short and performs linear search on it. + dj_database_url.register("bag-end", "django.contrib.db.backends.bag_end") + dj_database_url.register("bag-end", "django.contrib.db.backends.bag_end") + assert len(uses_netloc) == len(set(uses_netloc)) @mock.patch.dict( os.environ,