From 717710f20254e19e816972f8925906a62caa4ee8 Mon Sep 17 00:00:00 2001 From: Swathi Pillalamarri Date: Mon, 1 Feb 2021 20:26:16 -0500 Subject: [PATCH 1/6] combine conn str parser logic in base handler and _common --- .../azure/servicebus/_base_handler.py | 33 +++++++++++----- .../_common/_connection_string_parser.py | 38 ++++--------------- .../tests/test_connection_string_parser.py | 6 +++ 3 files changed, 37 insertions(+), 40 deletions(-) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_base_handler.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_base_handler.py index 36e7e5a877fc..9801ba2fc2b6 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_base_handler.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_base_handler.py @@ -14,6 +14,11 @@ except ImportError: from urllib.parse import quote_plus +try: + from urllib.parse import urlparse +except ImportError: + from urlparse import urlparse # type: ignore + import uamqp from uamqp import utils, compat from uamqp.message import MessageProperties @@ -83,17 +88,27 @@ def _parse_conn_str(conn_str): ): # Fallback since technically expiry is optional. # An arbitrary, absurdly large number, since you can't renew. shared_access_signature_expiry = int(time.time() * 2) - if not ( - all((endpoint, shared_access_key_name, shared_access_key)) - or all((endpoint, shared_access_signature)) - ) or all( - (shared_access_key_name, shared_access_signature) - ): # this latter clause since we don't accept both + + if not endpoint: + raise ValueError("Connection string is either blank or malformed.") + parsed = urlparse(endpoint.rstrip("/")) + if not parsed.netloc: + raise ValueError("Invalid Endpoint on the Connection String.") + if any([shared_access_key, shared_access_key_name]) and not all( + [shared_access_key, shared_access_key_name] + ): raise ValueError( - "Invalid connection string. Should be in the format: " - "Endpoint=sb:///;SharedAccessKeyName=;SharedAccessKey=" - "\nWith alternate option of providing SharedAccessSignature instead of SharedAccessKeyName and Key" + "Connection string must have both SharedAccessKeyName and SharedAccessKey." ) + if shared_access_signature is not None and shared_access_key is not None: + raise ValueError( + "Only one of the SharedAccessKey or SharedAccessSignature must be present." + ) + if shared_access_signature is None and shared_access_key is None: + raise ValueError( + "At least one of the SharedAccessKey or SharedAccessSignature must be present." + ) + entity = cast(str, entity_path) host = cast(str, strip_protocol_from_uri(cast(str, endpoint))) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_common/_connection_string_parser.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_common/_connection_string_parser.py index a67f3816015c..a8e6c690addf 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_common/_connection_string_parser.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_common/_connection_string_parser.py @@ -9,6 +9,7 @@ from ..management._models import DictMixin +from .._base_handler import _parse_conn_str class ServiceBusConnectionStringProperties(DictMixin): @@ -71,39 +72,14 @@ def parse_connection_string(conn_str): :type conn_str: str :rtype: ~azure.servicebus.ServiceBusConnectionStringProperties """ - conn_settings = [s.split("=", 1) for s in conn_str.split(";")] - if any(len(tup) != 2 for tup in conn_settings): - raise ValueError("Connection string is either blank or malformed.") - conn_settings = dict(conn_settings) - shared_access_signature = None - for key, value in conn_settings.items(): - if key.lower() == "sharedaccesssignature": - shared_access_signature = value - shared_access_key = conn_settings.get("SharedAccessKey") - shared_access_key_name = conn_settings.get("SharedAccessKeyName") - if any([shared_access_key, shared_access_key_name]) and not all( - [shared_access_key, shared_access_key_name] - ): - raise ValueError( - "Connection string must have both SharedAccessKeyName and SharedAccessKey." - ) - if shared_access_signature is not None and shared_access_key is not None: - raise ValueError( - "Only one of the SharedAccessKey or SharedAccessSignature must be present." - ) - endpoint = conn_settings.get("Endpoint") - if not endpoint: - raise ValueError("Connection string is either blank or malformed.") - parsed = urlparse(endpoint.rstrip("/")) - if not parsed.netloc: - raise ValueError("Invalid Endpoint on the Connection String.") - namespace = parsed.netloc.strip() + namespace, policy, key, entity, signature = _parse_conn_str(conn_str)[:-1] + endpoint = "sb://" + namespace + "/" props = { "fully_qualified_namespace": namespace, "endpoint": endpoint, - "entity_path": conn_settings.get("EntityPath"), - "shared_access_signature": shared_access_signature, - "shared_access_key_name": shared_access_key_name, - "shared_access_key": shared_access_key, + "entity_path": entity, + "shared_access_signature": signature, + "shared_access_key_name": policy, + "shared_access_key": key, } return ServiceBusConnectionStringProperties(**props) diff --git a/sdk/servicebus/azure-servicebus/tests/test_connection_string_parser.py b/sdk/servicebus/azure-servicebus/tests/test_connection_string_parser.py index fd0657006f54..4d9717887df7 100644 --- a/sdk/servicebus/azure-servicebus/tests/test_connection_string_parser.py +++ b/sdk/servicebus/azure-servicebus/tests/test_connection_string_parser.py @@ -59,3 +59,9 @@ def test_sb_parse_conn_str_no_key(self, **kwargs): with pytest.raises(ValueError) as e: parse_result = parse_connection_string(conn_str) assert str(e.value) == 'Connection string must have both SharedAccessKeyName and SharedAccessKey.' + + def test_sb_parse_conn_str_no_key_or_sas(self, **kwargs): + conn_str = 'Endpoint=sb://resourcename.servicebus.windows.net/' + with pytest.raises(ValueError) as e: + parse_result = parse_connection_string(conn_str) + assert str(e.value) == 'At least one of the SharedAccessKey or SharedAccessSignature must be present.' From 46fa7f041e11fc9bd9befa8ff6e4f1abfa2f4be7 Mon Sep 17 00:00:00 2001 From: Swathi Pillalamarri Date: Tue, 2 Feb 2021 12:11:59 -0500 Subject: [PATCH 2/6] removed unused import --- .../azure/servicebus/_common/_connection_string_parser.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_common/_connection_string_parser.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_common/_connection_string_parser.py index a8e6c690addf..fa48543a87e0 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_common/_connection_string_parser.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_common/_connection_string_parser.py @@ -2,11 +2,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -try: - from urllib.parse import urlparse -except ImportError: - from urlparse import urlparse # type: ignore - from ..management._models import DictMixin from .._base_handler import _parse_conn_str From e3b8719fcfc2e3af2324c4f6660558ace22c57de Mon Sep 17 00:00:00 2001 From: Swathi Pillalamarri Date: Wed, 3 Feb 2021 18:19:19 -0500 Subject: [PATCH 3/6] adams comments --- .../azure/servicebus/_base_handler.py | 66 +++++++++++-------- .../_common/_connection_string_parser.py | 6 +- .../tests/test_connection_string_parser.py | 32 +++++++++ 3 files changed, 75 insertions(+), 29 deletions(-) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_base_handler.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_base_handler.py index 9801ba2fc2b6..c0ecbeb588f9 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_base_handler.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_base_handler.py @@ -10,13 +10,9 @@ from typing import cast, Optional, Tuple, TYPE_CHECKING, Dict, Any, Callable try: - from urllib import quote_plus # type: ignore -except ImportError: - from urllib.parse import quote_plus - -try: - from urllib.parse import urlparse + from urllib.parse import quote_plus, urlparse except ImportError: + from urllib import quote_plus # type: ignore from urlparse import urlparse # type: ignore import uamqp @@ -53,7 +49,7 @@ _LOGGER = logging.getLogger(__name__) -def _parse_conn_str(conn_str): +def _parse_conn_str(conn_str, check_case=False): # type: (str) -> Tuple[str, Optional[str], Optional[str], str, Optional[str], Optional[int]] endpoint = None shared_access_key_name = None @@ -61,19 +57,24 @@ def _parse_conn_str(conn_str): entity_path = None # type: Optional[str] shared_access_signature = None # type: Optional[str] shared_access_signature_expiry = None # type: Optional[int] - for element in conn_str.strip().split(";"): - key, _, value = element.partition("=") - if key.lower() == "endpoint": - endpoint = value.rstrip("/") - elif key.lower() == "hostname": - endpoint = value.rstrip("/") - elif key.lower() == "sharedaccesskeyname": - shared_access_key_name = value - elif key.lower() == "sharedaccesskey": - shared_access_key = value - elif key.lower() == "entitypath": - entity_path = value - elif key.lower() == "sharedaccesssignature": + + # split connection string into properties + conn_settings = [s.split("=", 1) for s in conn_str.rstrip(";").split(";")] + if any(len(tup) != 2 for tup in conn_settings): + raise ValueError("Connection string is either blank or malformed.") + conn_settings = dict(conn_settings) + + # case sensitive check when parsing for connection string properties + if check_case: + shared_access_key = conn_settings.get("SharedAccessKey") + shared_access_key_name = conn_settings.get("SharedAccessKeyName") + endpoint = conn_settings.get("Endpoint") + entity_path = conn_settings.get("EntityPath") + + # non case sensitive check when parsing connection string for internal use + for key, value in conn_settings.items(): + # only sas check is non case sensitive for both conn str properties and internal use + if key.lower() == "sharedaccesssignature": shared_access_signature = value try: # Expiry can be stored in the "se=" clause of the token. ('&'-separated key-value pairs) @@ -88,30 +89,43 @@ def _parse_conn_str(conn_str): ): # Fallback since technically expiry is optional. # An arbitrary, absurdly large number, since you can't renew. shared_access_signature_expiry = int(time.time() * 2) + if not check_case: + if key.lower() == "endpoint": + endpoint = value.rstrip("/") + elif key.lower() == "hostname": + endpoint = value.rstrip("/") + elif key.lower() == "sharedaccesskeyname": + shared_access_key_name = value + elif key.lower() == "sharedaccesskey": + shared_access_key = value + elif key.lower() == "entitypath": + entity_path = value + + entity = cast(str, entity_path) + # check that endpoint is valid if not endpoint: raise ValueError("Connection string is either blank or malformed.") - parsed = urlparse(endpoint.rstrip("/")) + parsed = urlparse(endpoint) if not parsed.netloc: raise ValueError("Invalid Endpoint on the Connection String.") + host = cast(str, parsed.netloc) + if any([shared_access_key, shared_access_key_name]) and not all( [shared_access_key, shared_access_key_name] ): raise ValueError( "Connection string must have both SharedAccessKeyName and SharedAccessKey." ) - if shared_access_signature is not None and shared_access_key is not None: + if shared_access_signature and shared_access_key: raise ValueError( "Only one of the SharedAccessKey or SharedAccessSignature must be present." ) - if shared_access_signature is None and shared_access_key is None: + if not shared_access_signature and not shared_access_key: raise ValueError( "At least one of the SharedAccessKey or SharedAccessSignature must be present." ) - entity = cast(str, entity_path) - host = cast(str, strip_protocol_from_uri(cast(str, endpoint))) - return ( host, str(shared_access_key_name) if shared_access_key_name else None, diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_common/_connection_string_parser.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_common/_connection_string_parser.py index fa48543a87e0..466deaef0a62 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_common/_connection_string_parser.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_common/_connection_string_parser.py @@ -67,10 +67,10 @@ def parse_connection_string(conn_str): :type conn_str: str :rtype: ~azure.servicebus.ServiceBusConnectionStringProperties """ - namespace, policy, key, entity, signature = _parse_conn_str(conn_str)[:-1] - endpoint = "sb://" + namespace + "/" + fully_qualified_namespace, policy, key, entity, signature = _parse_conn_str(conn_str, True)[:-1] + endpoint = "sb://" + fully_qualified_namespace + "/" props = { - "fully_qualified_namespace": namespace, + "fully_qualified_namespace": fully_qualified_namespace, "endpoint": endpoint, "entity_path": entity, "shared_access_signature": signature, diff --git a/sdk/servicebus/azure-servicebus/tests/test_connection_string_parser.py b/sdk/servicebus/azure-servicebus/tests/test_connection_string_parser.py index 4d9717887df7..5e89eab72377 100644 --- a/sdk/servicebus/azure-servicebus/tests/test_connection_string_parser.py +++ b/sdk/servicebus/azure-servicebus/tests/test_connection_string_parser.py @@ -34,6 +34,12 @@ def test_sb_parse_malformed_conn_str_no_endpoint(self, **kwargs): parse_result = parse_connection_string(conn_str) assert str(e.value) == 'Connection string is either blank or malformed.' + def test_sb_parse_malformed_conn_str_no_endpoint_value(self, **kwargs): + conn_str = 'Endpoint=;SharedAccessKeyName=test;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + with pytest.raises(ValueError) as e: + parse_result = parse_connection_string(conn_str) + assert str(e.value) == 'Connection string is either blank or malformed.' + def test_sb_parse_malformed_conn_str_no_netloc(self, **kwargs): conn_str = 'Endpoint=MALFORMED;SharedAccessKeyName=test;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' with pytest.raises(ValueError) as e: @@ -48,6 +54,14 @@ def test_sb_parse_conn_str_sas(self, **kwargs): assert parse_result.shared_access_signature == 'THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' assert parse_result.shared_access_key_name == None + def test_sb_parse_conn_str_sas_trailing_semicolon(self, **kwargs): + conn_str = 'Endpoint=sb://resourcename.servicebus.windows.net/;SharedAccessSignature=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=;' + parse_result = parse_connection_string(conn_str) + assert parse_result.endpoint == 'sb://resourcename.servicebus.windows.net/' + assert parse_result.fully_qualified_namespace == 'resourcename.servicebus.windows.net' + assert parse_result.shared_access_signature == 'THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + assert parse_result.shared_access_key_name == None + def test_sb_parse_conn_str_no_keyname(self, **kwargs): conn_str = 'Endpoint=sb://resourcename.servicebus.windows.net/;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' with pytest.raises(ValueError) as e: @@ -65,3 +79,21 @@ def test_sb_parse_conn_str_no_key_or_sas(self, **kwargs): with pytest.raises(ValueError) as e: parse_result = parse_connection_string(conn_str) assert str(e.value) == 'At least one of the SharedAccessKey or SharedAccessSignature must be present.' + + def test_sb_parse_malformed_conn_str_lowercase_endpoint(self, **kwargs): + conn_str = 'endpoint=sb://resourcename.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + with pytest.raises(ValueError) as e: + parse_result = parse_connection_string(conn_str) + assert str(e.value) == 'Connection string is either blank or malformed.' + + def test_sb_parse_malformed_conn_str_lowercase_sa_key_name(self, **kwargs): + conn_str = 'Endpoint=sb://resourcename.servicebus.windows.net/;sharedaccesskeyname=test;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + with pytest.raises(ValueError) as e: + parse_result = parse_connection_string(conn_str) + assert str(e.value) == 'Connection string must have both SharedAccessKeyName and SharedAccessKey.' + + def test_sb_parse_malformed_conn_str_lowercase_sa_key_name(self, **kwargs): + conn_str = 'Endpoint=sb://resourcename.servicebus.windows.net/;SharedAccessKeyName=test;sharedaccesskey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + with pytest.raises(ValueError) as e: + parse_result = parse_connection_string(conn_str) + assert str(e.value) == 'Connection string must have both SharedAccessKeyName and SharedAccessKey.' \ No newline at end of file From 1d709212e5cd27a8cb2b0c0a0fb8be84c2ad49c5 Mon Sep 17 00:00:00 2001 From: Swathi Pillalamarri Date: Wed, 3 Feb 2021 19:30:06 -0500 Subject: [PATCH 4/6] add arg type, remove whitespace --- .../azure-servicebus/azure/servicebus/_base_handler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_base_handler.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_base_handler.py index c0ecbeb588f9..2241268985de 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_base_handler.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_base_handler.py @@ -50,7 +50,7 @@ def _parse_conn_str(conn_str, check_case=False): - # type: (str) -> Tuple[str, Optional[str], Optional[str], str, Optional[str], Optional[int]] + # type: (str, Optional[bool]) -> Tuple[str, Optional[str], Optional[str], str, Optional[str], Optional[int]] endpoint = None shared_access_key_name = None shared_access_key = None @@ -110,7 +110,7 @@ def _parse_conn_str(conn_str, check_case=False): if not parsed.netloc: raise ValueError("Invalid Endpoint on the Connection String.") host = cast(str, parsed.netloc) - + if any([shared_access_key, shared_access_key_name]) and not all( [shared_access_key, shared_access_key_name] ): From 05236dc6847b792c4e42c988eb202992905dd76b Mon Sep 17 00:00:00 2001 From: Swathi Pillalamarri Date: Thu, 4 Feb 2021 13:17:04 -0500 Subject: [PATCH 5/6] fix mypy errors --- .../azure/servicebus/_base_handler.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_base_handler.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_base_handler.py index 2241268985de..623c6e20e6e8 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_base_handler.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_base_handler.py @@ -59,10 +59,10 @@ def _parse_conn_str(conn_str, check_case=False): shared_access_signature_expiry = None # type: Optional[int] # split connection string into properties - conn_settings = [s.split("=", 1) for s in conn_str.rstrip(";").split(";")] - if any(len(tup) != 2 for tup in conn_settings): + conn_properties = [s.split("=", 1) for s in conn_str.rstrip(";").split(";")] + if any(len(tup) != 2 for tup in conn_properties): raise ValueError("Connection string is either blank or malformed.") - conn_settings = dict(conn_settings) + conn_settings = dict(conn_properties) # type: ignore # case sensitive check when parsing for connection string properties if check_case: @@ -78,9 +78,8 @@ def _parse_conn_str(conn_str, check_case=False): shared_access_signature = value try: # Expiry can be stored in the "se=" clause of the token. ('&'-separated key-value pairs) - # type: ignore shared_access_signature_expiry = int( - shared_access_signature.split("se=")[1].split("&")[0] + shared_access_signature.split("se=")[1].split("&")[0] # type: ignore ) except ( IndexError, @@ -109,7 +108,7 @@ def _parse_conn_str(conn_str, check_case=False): parsed = urlparse(endpoint) if not parsed.netloc: raise ValueError("Invalid Endpoint on the Connection String.") - host = cast(str, parsed.netloc) + host = cast(str, parsed.netloc.strip()) if any([shared_access_key, shared_access_key_name]) and not all( [shared_access_key, shared_access_key_name] From 638c2a9e2c9769010893d46e1d4e92aaec38e44d Mon Sep 17 00:00:00 2001 From: Swathi Pillalamarri Date: Thu, 4 Feb 2021 18:05:27 -0500 Subject: [PATCH 6/6] strip whitespace around conn str --- .../azure-servicebus/azure/servicebus/_base_handler.py | 2 +- .../tests/test_connection_string_parser.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_base_handler.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_base_handler.py index 623c6e20e6e8..560bb289d4c7 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_base_handler.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_base_handler.py @@ -59,7 +59,7 @@ def _parse_conn_str(conn_str, check_case=False): shared_access_signature_expiry = None # type: Optional[int] # split connection string into properties - conn_properties = [s.split("=", 1) for s in conn_str.rstrip(";").split(";")] + conn_properties = [s.split("=", 1) for s in conn_str.strip().rstrip(";").split(";")] if any(len(tup) != 2 for tup in conn_properties): raise ValueError("Connection string is either blank or malformed.") conn_settings = dict(conn_properties) # type: ignore diff --git a/sdk/servicebus/azure-servicebus/tests/test_connection_string_parser.py b/sdk/servicebus/azure-servicebus/tests/test_connection_string_parser.py index 5e89eab72377..6877c1171ce4 100644 --- a/sdk/servicebus/azure-servicebus/tests/test_connection_string_parser.py +++ b/sdk/servicebus/azure-servicebus/tests/test_connection_string_parser.py @@ -54,6 +54,14 @@ def test_sb_parse_conn_str_sas(self, **kwargs): assert parse_result.shared_access_signature == 'THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' assert parse_result.shared_access_key_name == None + def test_sb_parse_conn_str_whitespace_trailing_semicolon(self, **kwargs): + conn_str = ' Endpoint=sb://resourcename.servicebus.windows.net/;SharedAccessSignature=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=; ' + parse_result = parse_connection_string(conn_str) + assert parse_result.endpoint == 'sb://resourcename.servicebus.windows.net/' + assert parse_result.fully_qualified_namespace == 'resourcename.servicebus.windows.net' + assert parse_result.shared_access_signature == 'THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + assert parse_result.shared_access_key_name == None + def test_sb_parse_conn_str_sas_trailing_semicolon(self, **kwargs): conn_str = 'Endpoint=sb://resourcename.servicebus.windows.net/;SharedAccessSignature=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=;' parse_result = parse_connection_string(conn_str)