Skip to content

Commit

Permalink
[SB] combine conn str parser logic in base handler and _common (#16464)
Browse files Browse the repository at this point in the history
* combine conn str parser logic in base handler and _common

* removed unused import

* adams comments

* add arg type, remove whitespace

* fix mypy errors

* strip whitespace around conn str
  • Loading branch information
swathipil authored Feb 5, 2021
1 parent e08c230 commit ca1303e
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 67 deletions.
88 changes: 58 additions & 30 deletions sdk/servicebus/azure-servicebus/azure/servicebus/_base_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
from typing import cast, Optional, Tuple, TYPE_CHECKING, Dict, Any, Callable

try:
from urllib import quote_plus # type: ignore
from urllib.parse import quote_plus, urlparse
except ImportError:
from urllib.parse import quote_plus
from urllib import quote_plus # type: ignore
from urlparse import urlparse # type: ignore

import uamqp
from uamqp import utils, compat
Expand Down Expand Up @@ -48,33 +49,37 @@
_LOGGER = logging.getLogger(__name__)


def _parse_conn_str(conn_str):
# type: (str) -> Tuple[str, Optional[str], Optional[str], str, Optional[str], Optional[int]]
def _parse_conn_str(conn_str, check_case=False):
# type: (str, Optional[bool]) -> Tuple[str, Optional[str], Optional[str], str, Optional[str], Optional[int]]
endpoint = None
shared_access_key_name = None
shared_access_key = None
entity_path = None # type: Optional[str]
shared_access_signature = None # type: Optional[str]
shared_access_signature_expiry = None # type: Optional[int]
for element in conn_str.strip().split(";"):
key, _, value = element.partition("=")
if key.lower() == "endpoint":
endpoint = value.rstrip("/")
elif key.lower() == "hostname":
endpoint = value.rstrip("/")
elif key.lower() == "sharedaccesskeyname":
shared_access_key_name = value
elif key.lower() == "sharedaccesskey":
shared_access_key = value
elif key.lower() == "entitypath":
entity_path = value
elif key.lower() == "sharedaccesssignature":

# split connection string into properties
conn_properties = [s.split("=", 1) for s in conn_str.strip().rstrip(";").split(";")]
if any(len(tup) != 2 for tup in conn_properties):
raise ValueError("Connection string is either blank or malformed.")
conn_settings = dict(conn_properties) # type: ignore

# case sensitive check when parsing for connection string properties
if check_case:
shared_access_key = conn_settings.get("SharedAccessKey")
shared_access_key_name = conn_settings.get("SharedAccessKeyName")
endpoint = conn_settings.get("Endpoint")
entity_path = conn_settings.get("EntityPath")

# non case sensitive check when parsing connection string for internal use
for key, value in conn_settings.items():
# only sas check is non case sensitive for both conn str properties and internal use
if key.lower() == "sharedaccesssignature":
shared_access_signature = value
try:
# Expiry can be stored in the "se=<timestamp>" clause of the token. ('&'-separated key-value pairs)
# type: ignore
shared_access_signature_expiry = int(
shared_access_signature.split("se=")[1].split("&")[0]
shared_access_signature.split("se=")[1].split("&")[0] # type: ignore
)
except (
IndexError,
Expand All @@ -83,19 +88,42 @@ def _parse_conn_str(conn_str):
): # Fallback since technically expiry is optional.
# An arbitrary, absurdly large number, since you can't renew.
shared_access_signature_expiry = int(time.time() * 2)
if not (
all((endpoint, shared_access_key_name, shared_access_key))
or all((endpoint, shared_access_signature))
) or all(
(shared_access_key_name, shared_access_signature)
): # this latter clause since we don't accept both
if not check_case:
if key.lower() == "endpoint":
endpoint = value.rstrip("/")
elif key.lower() == "hostname":
endpoint = value.rstrip("/")
elif key.lower() == "sharedaccesskeyname":
shared_access_key_name = value
elif key.lower() == "sharedaccesskey":
shared_access_key = value
elif key.lower() == "entitypath":
entity_path = value

entity = cast(str, entity_path)

# check that endpoint is valid
if not endpoint:
raise ValueError("Connection string is either blank or malformed.")
parsed = urlparse(endpoint)
if not parsed.netloc:
raise ValueError("Invalid Endpoint on the Connection String.")
host = cast(str, parsed.netloc.strip())

if any([shared_access_key, shared_access_key_name]) and not all(
[shared_access_key, shared_access_key_name]
):
raise ValueError(
"Invalid connection string. Should be in the format: "
"Endpoint=sb://<FQDN>/;SharedAccessKeyName=<KeyName>;SharedAccessKey=<KeyValue>"
"\nWith alternate option of providing SharedAccessSignature instead of SharedAccessKeyName and Key"
"Connection string must have both SharedAccessKeyName and SharedAccessKey."
)
if shared_access_signature and shared_access_key:
raise ValueError(
"Only one of the SharedAccessKey or SharedAccessSignature must be present."
)
if not shared_access_signature and not shared_access_key:
raise ValueError(
"At least one of the SharedAccessKey or SharedAccessSignature must be present."
)
entity = cast(str, entity_path)
host = cast(str, strip_protocol_from_uri(cast(str, endpoint)))

return (
host,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
try:
from urllib.parse import urlparse
except ImportError:
from urlparse import urlparse # type: ignore


from ..management._models import DictMixin
from .._base_handler import _parse_conn_str


class ServiceBusConnectionStringProperties(DictMixin):
Expand Down Expand Up @@ -71,39 +67,14 @@ def parse_connection_string(conn_str):
:type conn_str: str
:rtype: ~azure.servicebus.ServiceBusConnectionStringProperties
"""
conn_settings = [s.split("=", 1) for s in conn_str.split(";")]
if any(len(tup) != 2 for tup in conn_settings):
raise ValueError("Connection string is either blank or malformed.")
conn_settings = dict(conn_settings)
shared_access_signature = None
for key, value in conn_settings.items():
if key.lower() == "sharedaccesssignature":
shared_access_signature = value
shared_access_key = conn_settings.get("SharedAccessKey")
shared_access_key_name = conn_settings.get("SharedAccessKeyName")
if any([shared_access_key, shared_access_key_name]) and not all(
[shared_access_key, shared_access_key_name]
):
raise ValueError(
"Connection string must have both SharedAccessKeyName and SharedAccessKey."
)
if shared_access_signature is not None and shared_access_key is not None:
raise ValueError(
"Only one of the SharedAccessKey or SharedAccessSignature must be present."
)
endpoint = conn_settings.get("Endpoint")
if not endpoint:
raise ValueError("Connection string is either blank or malformed.")
parsed = urlparse(endpoint.rstrip("/"))
if not parsed.netloc:
raise ValueError("Invalid Endpoint on the Connection String.")
namespace = parsed.netloc.strip()
fully_qualified_namespace, policy, key, entity, signature = _parse_conn_str(conn_str, True)[:-1]
endpoint = "sb://" + fully_qualified_namespace + "/"
props = {
"fully_qualified_namespace": namespace,
"fully_qualified_namespace": fully_qualified_namespace,
"endpoint": endpoint,
"entity_path": conn_settings.get("EntityPath"),
"shared_access_signature": shared_access_signature,
"shared_access_key_name": shared_access_key_name,
"shared_access_key": shared_access_key,
"entity_path": entity,
"shared_access_signature": signature,
"shared_access_key_name": policy,
"shared_access_key": key,
}
return ServiceBusConnectionStringProperties(**props)
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ def test_sb_parse_malformed_conn_str_no_endpoint(self, **kwargs):
parse_result = parse_connection_string(conn_str)
assert str(e.value) == 'Connection string is either blank or malformed.'

def test_sb_parse_malformed_conn_str_no_endpoint_value(self, **kwargs):
conn_str = 'Endpoint=;SharedAccessKeyName=test;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
with pytest.raises(ValueError) as e:
parse_result = parse_connection_string(conn_str)
assert str(e.value) == 'Connection string is either blank or malformed.'

def test_sb_parse_malformed_conn_str_no_netloc(self, **kwargs):
conn_str = 'Endpoint=MALFORMED;SharedAccessKeyName=test;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
with pytest.raises(ValueError) as e:
Expand All @@ -48,6 +54,22 @@ def test_sb_parse_conn_str_sas(self, **kwargs):
assert parse_result.shared_access_signature == 'THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
assert parse_result.shared_access_key_name == None

def test_sb_parse_conn_str_whitespace_trailing_semicolon(self, **kwargs):
conn_str = ' Endpoint=sb://resourcename.servicebus.windows.net/;SharedAccessSignature=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=; '
parse_result = parse_connection_string(conn_str)
assert parse_result.endpoint == 'sb://resourcename.servicebus.windows.net/'
assert parse_result.fully_qualified_namespace == 'resourcename.servicebus.windows.net'
assert parse_result.shared_access_signature == 'THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
assert parse_result.shared_access_key_name == None

def test_sb_parse_conn_str_sas_trailing_semicolon(self, **kwargs):
conn_str = 'Endpoint=sb://resourcename.servicebus.windows.net/;SharedAccessSignature=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=;'
parse_result = parse_connection_string(conn_str)
assert parse_result.endpoint == 'sb://resourcename.servicebus.windows.net/'
assert parse_result.fully_qualified_namespace == 'resourcename.servicebus.windows.net'
assert parse_result.shared_access_signature == 'THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
assert parse_result.shared_access_key_name == None

def test_sb_parse_conn_str_no_keyname(self, **kwargs):
conn_str = 'Endpoint=sb://resourcename.servicebus.windows.net/;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
with pytest.raises(ValueError) as e:
Expand All @@ -59,3 +81,27 @@ def test_sb_parse_conn_str_no_key(self, **kwargs):
with pytest.raises(ValueError) as e:
parse_result = parse_connection_string(conn_str)
assert str(e.value) == 'Connection string must have both SharedAccessKeyName and SharedAccessKey.'

def test_sb_parse_conn_str_no_key_or_sas(self, **kwargs):
conn_str = 'Endpoint=sb://resourcename.servicebus.windows.net/'
with pytest.raises(ValueError) as e:
parse_result = parse_connection_string(conn_str)
assert str(e.value) == 'At least one of the SharedAccessKey or SharedAccessSignature must be present.'

def test_sb_parse_malformed_conn_str_lowercase_endpoint(self, **kwargs):
conn_str = 'endpoint=sb://resourcename.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
with pytest.raises(ValueError) as e:
parse_result = parse_connection_string(conn_str)
assert str(e.value) == 'Connection string is either blank or malformed.'

def test_sb_parse_malformed_conn_str_lowercase_sa_key_name(self, **kwargs):
conn_str = 'Endpoint=sb://resourcename.servicebus.windows.net/;sharedaccesskeyname=test;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
with pytest.raises(ValueError) as e:
parse_result = parse_connection_string(conn_str)
assert str(e.value) == 'Connection string must have both SharedAccessKeyName and SharedAccessKey.'

def test_sb_parse_malformed_conn_str_lowercase_sa_key_name(self, **kwargs):
conn_str = 'Endpoint=sb://resourcename.servicebus.windows.net/;SharedAccessKeyName=test;sharedaccesskey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
with pytest.raises(ValueError) as e:
parse_result = parse_connection_string(conn_str)
assert str(e.value) == 'Connection string must have both SharedAccessKeyName and SharedAccessKey.'

0 comments on commit ca1303e

Please sign in to comment.