Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: datasource save, improve data validation #22038

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1306,6 +1306,9 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument
# Typically these should not be allowed.
PREVENT_UNSAFE_DB_CONNECTIONS = True

# Prevents unsafe default endpoints to be registered on datasets.
PREVENT_UNSAFE_DEFAULT_URLS_ON_DATASET = True

# Path used to store SSL certificates that are generated when using custom certs.
# Defaults to temporary directory.
# Example: SSL_CERT_PATH = "/certs"
Expand Down
19 changes: 18 additions & 1 deletion superset/utils/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import unicodedata
import urllib
from typing import Any
from urllib.parse import urlparse

from flask import current_app, url_for
from flask import current_app, request, url_for


def get_url_host(user_friendly: bool = False) -> str:
Expand Down Expand Up @@ -48,3 +50,18 @@ def modify_url_query(url: str, **kwargs: Any) -> str:

parts[3] = "&".join(f"{k}={urllib.parse.quote(v[0])}" for k, v in params.items())
return urllib.parse.urlunsplit(parts)


def is_safe_url(url: str) -> bool:
if url.startswith("///"):
return False
try:
ref_url = urlparse(request.host_url)
test_url = urlparse(url)
except ValueError:
return False
if unicodedata.category(url[0])[0] == "C":
return False
if test_url.scheme != ref_url.scheme or ref_url.netloc != test_url.netloc:
return False
return True
17 changes: 16 additions & 1 deletion superset/views/datasource/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from collections import Counter
from typing import Any

from flask import redirect, request
from flask import current_app, redirect, request
from flask_appbuilder import expose, permission_name
from flask_appbuilder.api import rison
from flask_appbuilder.security.decorators import has_access, has_access_api
Expand All @@ -40,6 +40,7 @@
from superset.models.core import Database
from superset.superset_typing import FlaskResponse
from superset.utils.core import DatasourceType
from superset.utils.urls import is_safe_url
from superset.views.base import (
api,
BaseSupersetView,
Expand Down Expand Up @@ -77,6 +78,20 @@ def save(self) -> FlaskResponse:
datasource_id = datasource_dict.get("id")
datasource_type = datasource_dict.get("type")
database_id = datasource_dict["database"].get("id")
default_endpoint = datasource_dict["default_endpoint"]
if (
default_endpoint
and not is_safe_url(default_endpoint)
and current_app.config["PREVENT_UNSAFE_DEFAULT_URLS_ON_DATASET"]
):
return json_error_response(
_(
"The submitted URL is not considered safe,"
" only use URLs with the same domain as Superset."
),
status=400,
)

orm_datasource = DatasourceDAO.get_datasource(
db.session, DatasourceType(datasource_type), datasource_id
)
Expand Down
38 changes: 38 additions & 0 deletions tests/integration_tests/datasource_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,44 @@ def test_save(self):
print(k)
self.assertEqual(resp[k], datasource_post[k])

def test_save_default_endpoint_validation_fail(self):
self.login(username="admin")
tbl_id = self.get_table(name="birth_names").id

datasource_post = get_datasource_post()
datasource_post["id"] = tbl_id
datasource_post["owners"] = [1]
datasource_post["default_endpoint"] = "http://www.google.com"
data = dict(data=json.dumps(datasource_post))
resp = self.client.post("/datasource/save/", data=data)
assert resp.status_code == 400

def test_save_default_endpoint_validation_unsafe(self):
self.app.config["PREVENT_UNSAFE_DEFAULT_URLS_ON_DATASET"] = False
self.login(username="admin")
tbl_id = self.get_table(name="birth_names").id

datasource_post = get_datasource_post()
datasource_post["id"] = tbl_id
datasource_post["owners"] = [1]
datasource_post["default_endpoint"] = "http://www.google.com"
data = dict(data=json.dumps(datasource_post))
resp = self.client.post("/datasource/save/", data=data)
assert resp.status_code == 200
self.app.config["PREVENT_UNSAFE_DEFAULT_URLS_ON_DATASET"] = True

def test_save_default_endpoint_validation_success(self):
self.login(username="admin")
tbl_id = self.get_table(name="birth_names").id

datasource_post = get_datasource_post()
datasource_post["id"] = tbl_id
datasource_post["owners"] = [1]
datasource_post["default_endpoint"] = "http://localhost/superset/1"
data = dict(data=json.dumps(datasource_post))
resp = self.client.post("/datasource/save/", data=data)
assert resp.status_code == 200

def save_datasource_from_dict(self, datasource_post):
data = dict(data=json.dumps(datasource_post))
resp = self.get_json_resp("/datasource/save/", data)
Expand Down
26 changes: 26 additions & 0 deletions tests/unit_tests/utils/urls_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest

from superset.utils.urls import modify_url_query

EXPLORE_CHART_LINK = "http://localhost:9000/explore/?form_data=%7B%22slice_id%22%3A+76%7D&standalone=true&force=false"
Expand All @@ -33,3 +35,27 @@ def test_convert_chart_link() -> None:
def test_convert_dashboard_link() -> None:
test_url = modify_url_query(EXPLORE_DASHBOARD_LINK, standalone="0")
assert test_url == "http://localhost:9000/superset/dashboard/3/?standalone=0"


@pytest.mark.parametrize(
"url,is_safe",
[
("http://localhost/", True),
("http://localhost/superset/1", True),
("https://localhost/", False),
("https://localhost/superset/1", False),
("localhost/superset/1", False),
("ftp://localhost/superset/1", False),
("http://external.com", False),
("https://external.com", False),
("external.com", False),
("///localhost", False),
("xpto://localhost:[3/1/", False),
],
)
def test_is_safe_url(url: str, is_safe: bool) -> None:
from superset import app
from superset.utils.urls import is_safe_url

with app.test_request_context("/"):
assert is_safe_url(url) == is_safe