Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SB] combine conn str parser logic in base handler and _common #16464

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 58 additions & 30 deletions sdk/servicebus/azure-servicebus/azure/servicebus/_base_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
from typing import cast, Optional, Tuple, TYPE_CHECKING, Dict, Any, Callable

try:
from urllib import quote_plus # type: ignore
from urllib.parse import quote_plus, urlparse
except ImportError:
from urllib.parse import quote_plus
from urllib import quote_plus # type: ignore
from urlparse import urlparse # type: ignore

import uamqp
from uamqp import utils, compat
Expand Down Expand Up @@ -48,33 +49,37 @@
_LOGGER = logging.getLogger(__name__)


def _parse_conn_str(conn_str):
# type: (str) -> Tuple[str, Optional[str], Optional[str], str, Optional[str], Optional[int]]
def _parse_conn_str(conn_str, check_case=False):
# type: (str, Optional[bool]) -> Tuple[str, Optional[str], Optional[str], str, Optional[str], Optional[int]]
endpoint = None
shared_access_key_name = None
shared_access_key = None
entity_path = None # type: Optional[str]
shared_access_signature = None # type: Optional[str]
shared_access_signature_expiry = None # type: Optional[int]
for element in conn_str.strip().split(";"):
key, _, value = element.partition("=")
if key.lower() == "endpoint":
endpoint = value.rstrip("/")
elif key.lower() == "hostname":
endpoint = value.rstrip("/")
elif key.lower() == "sharedaccesskeyname":
shared_access_key_name = value
elif key.lower() == "sharedaccesskey":
shared_access_key = value
elif key.lower() == "entitypath":
entity_path = value
elif key.lower() == "sharedaccesssignature":

# split connection string into properties
conn_properties = [s.split("=", 1) for s in conn_str.strip().rstrip(";").split(";")]
if any(len(tup) != 2 for tup in conn_properties):
raise ValueError("Connection string is either blank or malformed.")
conn_settings = dict(conn_properties) # type: ignore

# case sensitive check when parsing for connection string properties
if check_case:
shared_access_key = conn_settings.get("SharedAccessKey")
shared_access_key_name = conn_settings.get("SharedAccessKeyName")
endpoint = conn_settings.get("Endpoint")
entity_path = conn_settings.get("EntityPath")

# non case sensitive check when parsing connection string for internal use
for key, value in conn_settings.items():
# only sas check is non case sensitive for both conn str properties and internal use
if key.lower() == "sharedaccesssignature":
shared_access_signature = value
try:
# Expiry can be stored in the "se=<timestamp>" clause of the token. ('&'-separated key-value pairs)
# type: ignore
shared_access_signature_expiry = int(
shared_access_signature.split("se=")[1].split("&")[0]
shared_access_signature.split("se=")[1].split("&")[0] # type: ignore
)
except (
IndexError,
Expand All @@ -83,19 +88,42 @@ def _parse_conn_str(conn_str):
): # Fallback since technically expiry is optional.
# An arbitrary, absurdly large number, since you can't renew.
shared_access_signature_expiry = int(time.time() * 2)
if not (
all((endpoint, shared_access_key_name, shared_access_key))
or all((endpoint, shared_access_signature))
) or all(
(shared_access_key_name, shared_access_signature)
): # this latter clause since we don't accept both
if not check_case:
if key.lower() == "endpoint":
endpoint = value.rstrip("/")
elif key.lower() == "hostname":
endpoint = value.rstrip("/")
elif key.lower() == "sharedaccesskeyname":
shared_access_key_name = value
elif key.lower() == "sharedaccesskey":
shared_access_key = value
elif key.lower() == "entitypath":
entity_path = value

entity = cast(str, entity_path)

# check that endpoint is valid
if not endpoint:
raise ValueError("Connection string is either blank or malformed.")
parsed = urlparse(endpoint)
if not parsed.netloc:
raise ValueError("Invalid Endpoint on the Connection String.")
host = cast(str, parsed.netloc.strip())

if any([shared_access_key, shared_access_key_name]) and not all(
[shared_access_key, shared_access_key_name]
):
raise ValueError(
"Invalid connection string. Should be in the format: "
"Endpoint=sb://<FQDN>/;SharedAccessKeyName=<KeyName>;SharedAccessKey=<KeyValue>"
"\nWith alternate option of providing SharedAccessSignature instead of SharedAccessKeyName and Key"
"Connection string must have both SharedAccessKeyName and SharedAccessKey."
)
if shared_access_signature and shared_access_key:
raise ValueError(
"Only one of the SharedAccessKey or SharedAccessSignature must be present."
)
if not shared_access_signature and not shared_access_key:
raise ValueError(
"At least one of the SharedAccessKey or SharedAccessSignature must be present."
)
entity = cast(str, entity_path)
host = cast(str, strip_protocol_from_uri(cast(str, endpoint)))

return (
host,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
try:
from urllib.parse import urlparse
except ImportError:
from urlparse import urlparse # type: ignore


from ..management._models import DictMixin
from .._base_handler import _parse_conn_str


class ServiceBusConnectionStringProperties(DictMixin):
Expand Down Expand Up @@ -71,39 +67,14 @@ def parse_connection_string(conn_str):
:type conn_str: str
:rtype: ~azure.servicebus.ServiceBusConnectionStringProperties
"""
conn_settings = [s.split("=", 1) for s in conn_str.split(";")]
if any(len(tup) != 2 for tup in conn_settings):
raise ValueError("Connection string is either blank or malformed.")
conn_settings = dict(conn_settings)
shared_access_signature = None
for key, value in conn_settings.items():
if key.lower() == "sharedaccesssignature":
shared_access_signature = value
shared_access_key = conn_settings.get("SharedAccessKey")
shared_access_key_name = conn_settings.get("SharedAccessKeyName")
swathipil marked this conversation as resolved.
Show resolved Hide resolved
if any([shared_access_key, shared_access_key_name]) and not all(
[shared_access_key, shared_access_key_name]
):
raise ValueError(
"Connection string must have both SharedAccessKeyName and SharedAccessKey."
)
if shared_access_signature is not None and shared_access_key is not None:
raise ValueError(
"Only one of the SharedAccessKey or SharedAccessSignature must be present."
)
endpoint = conn_settings.get("Endpoint")
if not endpoint:
raise ValueError("Connection string is either blank or malformed.")
parsed = urlparse(endpoint.rstrip("/"))
if not parsed.netloc:
raise ValueError("Invalid Endpoint on the Connection String.")
namespace = parsed.netloc.strip()
fully_qualified_namespace, policy, key, entity, signature = _parse_conn_str(conn_str, True)[:-1]
endpoint = "sb://" + fully_qualified_namespace + "/"
props = {
"fully_qualified_namespace": namespace,
"fully_qualified_namespace": fully_qualified_namespace,
"endpoint": endpoint,
"entity_path": conn_settings.get("EntityPath"),
swathipil marked this conversation as resolved.
Show resolved Hide resolved
"shared_access_signature": shared_access_signature,
"shared_access_key_name": shared_access_key_name,
"shared_access_key": shared_access_key,
"entity_path": entity,
"shared_access_signature": signature,
"shared_access_key_name": policy,
"shared_access_key": key,
}
return ServiceBusConnectionStringProperties(**props)
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ def test_sb_parse_malformed_conn_str_no_endpoint(self, **kwargs):
parse_result = parse_connection_string(conn_str)
assert str(e.value) == 'Connection string is either blank or malformed.'

def test_sb_parse_malformed_conn_str_no_endpoint_value(self, **kwargs):
conn_str = 'Endpoint=;SharedAccessKeyName=test;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
with pytest.raises(ValueError) as e:
parse_result = parse_connection_string(conn_str)
assert str(e.value) == 'Connection string is either blank or malformed.'

def test_sb_parse_malformed_conn_str_no_netloc(self, **kwargs):
conn_str = 'Endpoint=MALFORMED;SharedAccessKeyName=test;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
with pytest.raises(ValueError) as e:
Expand All @@ -48,6 +54,22 @@ def test_sb_parse_conn_str_sas(self, **kwargs):
assert parse_result.shared_access_signature == 'THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
assert parse_result.shared_access_key_name == None

def test_sb_parse_conn_str_whitespace_trailing_semicolon(self, **kwargs):
conn_str = ' Endpoint=sb://resourcename.servicebus.windows.net/;SharedAccessSignature=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=; '
parse_result = parse_connection_string(conn_str)
assert parse_result.endpoint == 'sb://resourcename.servicebus.windows.net/'
assert parse_result.fully_qualified_namespace == 'resourcename.servicebus.windows.net'
assert parse_result.shared_access_signature == 'THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
assert parse_result.shared_access_key_name == None

def test_sb_parse_conn_str_sas_trailing_semicolon(self, **kwargs):
conn_str = 'Endpoint=sb://resourcename.servicebus.windows.net/;SharedAccessSignature=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=;'
parse_result = parse_connection_string(conn_str)
assert parse_result.endpoint == 'sb://resourcename.servicebus.windows.net/'
assert parse_result.fully_qualified_namespace == 'resourcename.servicebus.windows.net'
assert parse_result.shared_access_signature == 'THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
assert parse_result.shared_access_key_name == None

def test_sb_parse_conn_str_no_keyname(self, **kwargs):
conn_str = 'Endpoint=sb://resourcename.servicebus.windows.net/;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
with pytest.raises(ValueError) as e:
Expand All @@ -59,3 +81,27 @@ def test_sb_parse_conn_str_no_key(self, **kwargs):
with pytest.raises(ValueError) as e:
parse_result = parse_connection_string(conn_str)
assert str(e.value) == 'Connection string must have both SharedAccessKeyName and SharedAccessKey.'

def test_sb_parse_conn_str_no_key_or_sas(self, **kwargs):
conn_str = 'Endpoint=sb://resourcename.servicebus.windows.net/'
with pytest.raises(ValueError) as e:
parse_result = parse_connection_string(conn_str)
assert str(e.value) == 'At least one of the SharedAccessKey or SharedAccessSignature must be present.'

def test_sb_parse_malformed_conn_str_lowercase_endpoint(self, **kwargs):
conn_str = 'endpoint=sb://resourcename.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
with pytest.raises(ValueError) as e:
parse_result = parse_connection_string(conn_str)
assert str(e.value) == 'Connection string is either blank or malformed.'

def test_sb_parse_malformed_conn_str_lowercase_sa_key_name(self, **kwargs):
conn_str = 'Endpoint=sb://resourcename.servicebus.windows.net/;sharedaccesskeyname=test;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
with pytest.raises(ValueError) as e:
parse_result = parse_connection_string(conn_str)
assert str(e.value) == 'Connection string must have both SharedAccessKeyName and SharedAccessKey.'

def test_sb_parse_malformed_conn_str_lowercase_sa_key_name(self, **kwargs):
conn_str = 'Endpoint=sb://resourcename.servicebus.windows.net/;SharedAccessKeyName=test;sharedaccesskey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
with pytest.raises(ValueError) as e:
parse_result = parse_connection_string(conn_str)
assert str(e.value) == 'Connection string must have both SharedAccessKeyName and SharedAccessKey.'