From 1eef88bdd535e78b9b01eb14441eb79945a15cbe Mon Sep 17 00:00:00 2001 From: mdeshmu Date: Sun, 4 Jun 2023 21:17:06 +0530 Subject: [PATCH] use pure-sasl with python 3.11 --- dev_requirements.txt | 2 + pyhive/hive.py | 56 ++++-- pyhive/sasl_compat.py | 56 ++++++ pyhive/tests/test_hive.py | 11 +- pyhive/tests/test_sasl_compat.py | 333 +++++++++++++++++++++++++++++++ setup.py | 3 + 6 files changed, 436 insertions(+), 25 deletions(-) create mode 100644 pyhive/sasl_compat.py create mode 100644 pyhive/tests/test_sasl_compat.py diff --git a/dev_requirements.txt b/dev_requirements.txt index 0bf6d8a7..40bb605a 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -12,6 +12,8 @@ pytest-timeout==1.2.0 requests>=1.0.0 requests_kerberos>=0.12.0 sasl>=0.2.1 +pure-sasl>=0.6.2 +kerberos>=1.3.0 thrift>=0.10.0 #thrift_sasl>=0.1.0 git+https://github.com/cloudera/thrift_sasl # Using master branch in order to get Python 3 SASL patches diff --git a/pyhive/hive.py b/pyhive/hive.py index 3f71df33..c1287488 100644 --- a/pyhive/hive.py +++ b/pyhive/hive.py @@ -49,6 +49,45 @@ } +def get_sasl_client(host, sasl_auth, service=None, username=None, password=None): + import sasl + sasl_client = sasl.Client() + sasl_client.setAttr('host', host) + + if sasl_auth == 'GSSAPI': + sasl_client.setAttr('service', service) + elif sasl_auth == 'PLAIN': + sasl_client.setAttr('username', username) + sasl_client.setAttr('password', password) + else: + raise ValueError("sasl_auth only supports GSSAPI and PLAIN") + + sasl_client.init() + return sasl_client + + +def get_pure_sasl_client(host, sasl_auth, service=None, username=None, password=None): + from pyhive.sasl_compat import PureSASLClient + + if sasl_auth == 'GSSAPI': + sasl_kwargs = {'service': service} + elif sasl_auth == 'PLAIN': + sasl_kwargs = {'username': username, 'password': password} + else: + raise ValueError("sasl_auth only supports GSSAPI and PLAIN") + + return PureSASLClient(host=host, **sasl_kwargs) + + +def get_installed_sasl(host, sasl_auth, service=None, username=None, password=None): + try: + return get_sasl_client(host=host, sasl_auth=sasl_auth, service=service, username=username, password=password) + # The sasl library is available + except ImportError: + # Fallback to pure-sasl library + return get_pure_sasl_client(host=host, sasl_auth=sasl_auth, service=service, username=username, password=password) + + def _parse_timestamp(value): if value: match = _TIMESTAMP_PATTERN.match(value) @@ -200,7 +239,6 @@ def __init__( self._transport = thrift.transport.TTransport.TBufferedTransport(socket) elif auth in ('LDAP', 'KERBEROS', 'NONE', 'CUSTOM'): # Defer import so package dependency is optional - import sasl import thrift_sasl if auth == 'KERBEROS': @@ -211,20 +249,8 @@ def __init__( if password is None: # Password doesn't matter in NONE mode, just needs to be nonempty. password = 'x' - - def sasl_factory(): - sasl_client = sasl.Client() - sasl_client.setAttr('host', host) - if sasl_auth == 'GSSAPI': - sasl_client.setAttr('service', kerberos_service_name) - elif sasl_auth == 'PLAIN': - sasl_client.setAttr('username', username) - sasl_client.setAttr('password', password) - else: - raise AssertionError - sasl_client.init() - return sasl_client - self._transport = thrift_sasl.TSaslClientTransport(sasl_factory, sasl_auth, socket) + + self._transport = thrift_sasl.TSaslClientTransport(lambda: get_installed_sasl(host=host, sasl_auth=sasl_auth, service=kerberos_service_name, username=username, password=password), sasl_auth, socket) else: # All HS2 config options: # https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2#SettingUpHiveServer2-Configuration diff --git a/pyhive/sasl_compat.py b/pyhive/sasl_compat.py new file mode 100644 index 00000000..dc65abe9 --- /dev/null +++ b/pyhive/sasl_compat.py @@ -0,0 +1,56 @@ +# Original source of this file is https://github.com/cloudera/impyla/blob/master/impala/sasl_compat.py +# which uses Apache-2.0 license as of 21 May 2023. +# This code was added to Impyla in 2016 as a compatibility layer to allow use of either python-sasl or pure-sasl +# via PR https://github.com/cloudera/impyla/pull/179 +# Even though thrift_sasl lists pure-sasl as dependency here https://github.com/cloudera/thrift_sasl/blob/master/setup.py#L34 +# but it still calls functions native to python-sasl in this file https://github.com/cloudera/thrift_sasl/blob/master/thrift_sasl/__init__.py#L82 +# Hence this code is required for the fallback to work. + + +from puresasl.client import SASLClient, SASLError +from contextlib import contextmanager + +@contextmanager +def error_catcher(self, Exc = Exception): + try: + self.error = None + yield + except Exc as e: + self.error = str(e) + + +class PureSASLClient(SASLClient): + def __init__(self, *args, **kwargs): + self.error = None + super(PureSASLClient, self).__init__(*args, **kwargs) + + def start(self, mechanism): + with error_catcher(self, SASLError): + if isinstance(mechanism, list): + self.choose_mechanism(mechanism) + else: + self.choose_mechanism([mechanism]) + return True, self.mechanism, self.process() + # else + return False, mechanism, None + + def encode(self, incoming): + with error_catcher(self): + return True, self.unwrap(incoming) + # else + return False, None + + def decode(self, outgoing): + with error_catcher(self): + return True, self.wrap(outgoing) + # else + return False, None + + def step(self, challenge=None): + with error_catcher(self): + return True, self.process(challenge) + # else + return False, None + + def getError(self): + return self.error diff --git a/pyhive/tests/test_hive.py b/pyhive/tests/test_hive.py index c70ed962..b49fc190 100644 --- a/pyhive/tests/test_hive.py +++ b/pyhive/tests/test_hive.py @@ -17,7 +17,6 @@ from decimal import Decimal import mock -import sasl import thrift.transport.TSocket import thrift.transport.TTransport import thrift_sasl @@ -204,15 +203,7 @@ def test_custom_transport(self): socket = thrift.transport.TSocket.TSocket('localhost', 10000) sasl_auth = 'PLAIN' - def sasl_factory(): - sasl_client = sasl.Client() - sasl_client.setAttr('host', 'localhost') - sasl_client.setAttr('username', 'test_username') - sasl_client.setAttr('password', 'x') - sasl_client.init() - return sasl_client - - transport = thrift_sasl.TSaslClientTransport(sasl_factory, sasl_auth, socket) + transport = thrift_sasl.TSaslClientTransport(lambda: hive.get_installed_sasl(host='localhost', sasl_auth=sasl_auth, username='test_username', password='x'), sasl_auth, socket) conn = hive.connect(thrift_transport=transport) with contextlib.closing(conn): with contextlib.closing(conn.cursor()) as cursor: diff --git a/pyhive/tests/test_sasl_compat.py b/pyhive/tests/test_sasl_compat.py new file mode 100644 index 00000000..49516249 --- /dev/null +++ b/pyhive/tests/test_sasl_compat.py @@ -0,0 +1,333 @@ +''' +http://www.opensource.org/licenses/mit-license.php + +Copyright 2007-2011 David Alan Cridland +Copyright 2011 Lance Stout +Copyright 2012 Tyler L Hobbs + +Permission is hereby granted, free of charge, to any person obtaining a copy of this +software and associated documentation files (the "Software"), to deal in the Software +without restriction, including without limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons +to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or +substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR +PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE +FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +''' +# This file was generated by referring test cases from the pure-sasl repo i.e. https://github.com/thobbs/pure-sasl/tree/master/tests/unit +# and by refactoring them to cover wrapper functions in sasl_compat.py along with added coverage for functions exclusive to sasl_compat.py. + +import unittest +import base64 +import hashlib +import hmac +import kerberos +from mock import patch +import six +import struct +from puresasl import SASLProtocolException, QOP +from puresasl.client import SASLError +from pyhive.sasl_compat import PureSASLClient, error_catcher + + +class TestPureSASLClient(unittest.TestCase): + """Test cases for initialization of SASL client using PureSASLClient class""" + + def setUp(self): + self.sasl_kwargs = {} + self.sasl = PureSASLClient('localhost', **self.sasl_kwargs) + + def test_start_no_mechanism(self): + """Test starting SASL authentication with no mechanism.""" + success, mechanism, response = self.sasl.start(mechanism=None) + self.assertFalse(success) + self.assertIsNone(mechanism) + self.assertIsNone(response) + self.assertEqual(self.sasl.getError(), 'None of the mechanisms listed meet all required properties') + + def test_start_wrong_mechanism(self): + """Test starting SASL authentication with a single unsupported mechanism.""" + success, mechanism, response = self.sasl.start(mechanism='WRONG') + self.assertFalse(success) + self.assertEqual(mechanism, 'WRONG') + self.assertIsNone(response) + self.assertEqual(self.sasl.getError(), 'None of the mechanisms listed meet all required properties') + + def test_start_list_of_invalid_mechanisms(self): + """Test starting SASL authentication with a list of unsupported mechanisms.""" + self.sasl.start(['invalid1', 'invalid2']) + self.assertEqual(self.sasl.getError(), 'None of the mechanisms listed meet all required properties') + + def test_start_list_of_valid_mechanisms(self): + """Test starting SASL authentication with a list of supported mechanisms.""" + self.sasl.start(['PLAIN', 'DIGEST-MD5', 'CRAM-MD5']) + # Validate right mechanism is chosen based on score. + self.assertEqual(self.sasl._chosen_mech.name, 'DIGEST-MD5') + + def test_error_catcher_no_error(self): + """Test the error_catcher with no error.""" + with error_catcher(self.sasl): + result, _, _ = self.sasl.start(mechanism='ANONYMOUS') + + self.assertEqual(self.sasl.getError(), None) + self.assertEqual(result, True) + + def test_error_catcher_with_error(self): + """Test the error_catcher with an error.""" + with error_catcher(self.sasl): + result, _, _ = self.sasl.start(mechanism='WRONG') + + self.assertEqual(result, False) + self.assertEqual(self.sasl.getError(), 'None of the mechanisms listed meet all required properties') + +"""Assuming Client initilization went well and a mechanism is chosen, Below are the test cases for different mechanims""" + +class _BaseMechanismTests(unittest.TestCase): + """Base test case for SASL mechanisms.""" + + mechanism = 'ANONYMOUS' + sasl_kwargs = {} + + def setUp(self): + self.sasl = PureSASLClient('localhost', mechanism=self.mechanism, **self.sasl_kwargs) + self.mechanism_class = self.sasl._chosen_mech + + def test_init_basic(self, *args): + sasl = PureSASLClient('localhost', mechanism=self.mechanism, **self.sasl_kwargs) + mech = sasl._chosen_mech + self.assertIs(mech.sasl, sasl) + + def test_step_basic(self, *args): + success, response = self.sasl.step(six.b('string')) + self.assertTrue(success) + self.assertIsInstance(response, six.binary_type) + + def test_decode_encode(self, *args): + self.assertEqual(self.sasl.encode('msg'), (False, None)) + self.assertEqual(self.sasl.getError(), '') + self.assertEqual(self.sasl.decode('msg'), (False, None)) + self.assertEqual(self.sasl.getError(), '') + + +class AnonymousMechanismTest(_BaseMechanismTests): + """Test case for the Anonymous SASL mechanism.""" + + mechanism = 'ANONYMOUS' + + +class PlainTextMechanismTest(_BaseMechanismTests): + """Test case for the PlainText SASL mechanism.""" + + mechanism = 'PLAIN' + username = 'user' + password = 'pass' + sasl_kwargs = {'username': username, 'password': password} + + def test_step(self): + for challenge in (None, '', b'asdf', u"\U0001F44D"): + success, response = self.sasl.step(challenge) + self.assertTrue(success) + self.assertEqual(response, six.b(f'\x00{self.username}\x00{self.password}')) + self.assertIsInstance(response, six.binary_type) + + def test_step_with_authorization_id_or_identity(self): + challenge = u"\U0001F44D" + identity = 'user2' + + # Test that we can pass an identity + sasl_kwargs = self.sasl_kwargs.copy() + sasl_kwargs.update({'identity': identity}) + sasl = PureSASLClient('localhost', mechanism=self.mechanism, **sasl_kwargs) + success, response = sasl.step(challenge) + self.assertTrue(success) + self.assertEqual(response, six.b(f'{identity}\x00{self.username}\x00{self.password}')) + self.assertIsInstance(response, six.binary_type) + self.assertTrue(sasl.complete) + + # Test that the sasl authorization_id has priority over identity + auth_id = 'user3' + sasl_kwargs.update({'authorization_id': auth_id}) + sasl = PureSASLClient('localhost', mechanism=self.mechanism, **sasl_kwargs) + success, response = sasl.step(challenge) + self.assertTrue(success) + self.assertEqual(response, six.b(f'{auth_id}\x00{self.username}\x00{self.password}')) + self.assertIsInstance(response, six.binary_type) + self.assertTrue(sasl.complete) + + def test_decode_encode(self): + msg = 'msg' + self.assertEqual(self.sasl.decode(msg), (True, msg)) + self.assertEqual(self.sasl.encode(msg), (True, msg)) + + +class ExternalMechanismTest(_BaseMechanismTests): + """Test case for the External SASL mechanisms""" + + mechanism = 'EXTERNAL' + + def test_step(self): + self.assertEqual(self.sasl.step(), (True, b'')) + + def test_decode_encode(self): + msg = 'msg' + self.assertEqual(self.sasl.decode(msg), (True, msg)) + self.assertEqual(self.sasl.encode(msg), (True, msg)) + + +@patch('puresasl.mechanisms.kerberos.authGSSClientStep') +@patch('puresasl.mechanisms.kerberos.authGSSClientResponse', return_value=base64.b64encode(six.b('some\x00 response'))) +class GSSAPIMechanismTest(_BaseMechanismTests): + """Test case for the GSSAPI SASL mechanism.""" + + mechanism = 'GSSAPI' + service = 'GSSAPI' + sasl_kwargs = {'service': service} + + @patch('puresasl.mechanisms.kerberos.authGSSClientWrap') + @patch('puresasl.mechanisms.kerberos.authGSSClientUnwrap') + def test_decode_encode(self, _inner1, _inner2, authGSSClientResponse, *args): + # bypassing step setup by setting qop directly + self.mechanism_class.qop = QOP.AUTH + msg = b'msg' + self.assertEqual(self.sasl.decode(msg), (True, msg)) + self.assertEqual(self.sasl.encode(msg), (True, msg)) + + # Test for behavior with different QOP like data integrity and confidentiality for Kerberos authentication + for qop in (QOP.AUTH_INT, QOP.AUTH_CONF): + self.mechanism_class.qop = qop + with patch('puresasl.mechanisms.kerberos.authGSSClientResponseConf', return_value=1): + self.assertEqual(self.sasl.decode(msg), (True, base64.b64decode(authGSSClientResponse.return_value))) + self.assertEqual(self.sasl.encode(msg), (True, base64.b64decode(authGSSClientResponse.return_value))) + if qop == QOP.AUTH_CONF: + with patch('puresasl.mechanisms.kerberos.authGSSClientResponseConf', return_value=0): + self.assertEqual(self.sasl.encode(msg), (False, None)) + self.assertEqual(self.sasl.getError(), 'Error: confidentiality requested, but not honored by the server.') + + def test_step_no_user(self, authGSSClientResponse, *args): + msg = six.b('whatever') + + # no user + self.assertEqual(self.sasl.step(msg), (True, base64.b64decode(authGSSClientResponse.return_value))) + with patch('puresasl.mechanisms.kerberos.authGSSClientResponse', return_value=''): + self.assertEqual(self.sasl.step(msg), (True, six.b(''))) + + username = 'username' + # with user; this has to be last because it sets mechanism.user + with patch('puresasl.mechanisms.kerberos.authGSSClientStep', return_value=kerberos.AUTH_GSS_COMPLETE): + with patch('puresasl.mechanisms.kerberos.authGSSClientUserName', return_value=six.b(username)): + self.assertEqual(self.sasl.step(msg), (True, six.b(''))) + self.assertEqual(self.mechanism_class.user, six.b(username)) + + @patch('puresasl.mechanisms.kerberos.authGSSClientUnwrap') + def test_step_qop(self, *args): + self.mechanism_class._have_negotiated_details = True + self.mechanism_class.user = 'user' + msg = six.b('msg') + self.assertEqual(self.sasl.step(msg), (False, None)) + self.assertEqual(self.sasl.getError(), 'Bad response from server') + + max_len = 100 + self.assertLess(max_len, self.sasl.max_buffer) + for i, qop in QOP.bit_map.items(): + qop_size = struct.pack('!i', i << 24 | max_len) + response = base64.b64encode(qop_size) + with patch('puresasl.mechanisms.kerberos.authGSSClientResponse', return_value=response): + with patch('puresasl.mechanisms.kerberos.authGSSClientWrap') as authGSSClientWrap: + self.mechanism_class.complete = False + self.assertEqual(self.sasl.step(msg), (True, qop_size)) + self.assertTrue(self.mechanism_class.complete) + self.assertEqual(self.mechanism_class.qop, qop) + self.assertEqual(self.mechanism_class.max_buffer, max_len) + + args = authGSSClientWrap.call_args[0] + out_data = args[1] + out = base64.b64decode(out_data) + self.assertEqual(out[:4], qop_size) + self.assertEqual(out[4:], six.b(self.mechanism_class.user)) + + +class CramMD5MechanismTest(_BaseMechanismTests): + """Test case for the CRAM-MD5 SASL mechanism.""" + + mechanism = 'CRAM-MD5' + username = 'user' + password = 'pass' + sasl_kwargs = {'username': username, 'password': password} + + def test_step(self): + success, response = self.sasl.step(None) + self.assertTrue(success) + self.assertIsNone(response) + challenge = six.b('msg') + hash = hmac.HMAC(key=six.b(self.password), digestmod=hashlib.md5) + hash.update(challenge) + success, response = self.sasl.step(challenge) + self.assertTrue(success) + self.assertIn(six.b(self.username), response) + self.assertIn(six.b(hash.hexdigest()), response) + self.assertIsInstance(response, six.binary_type) + self.assertTrue(self.sasl.complete) + + def test_decode_encode(self): + msg = 'msg' + self.assertEqual(self.sasl.decode(msg), (True, msg)) + self.assertEqual(self.sasl.encode(msg), (True, msg)) + + +class DigestMD5MechanismTest(_BaseMechanismTests): + """Test case for the DIGEST-MD5 SASL mechanism.""" + + mechanism = 'DIGEST-MD5' + username = 'user' + password = 'pass' + sasl_kwargs = {'username': username, 'password': password} + + def test_decode_encode(self): + msg = 'msg' + self.assertEqual(self.sasl.decode(msg), (True, msg)) + self.assertEqual(self.sasl.encode(msg), (True, msg)) + + def test_step_basic(self, *args): + pass + + def test_step(self): + """Test a SASL step with dummy challenge for DIGEST-MD5 mechanism.""" + testChallenge = ( + b'nonce="rmD6R8aMYVWH+/ih9HGBr3xNGAR6o2DUxpKlgDz6gUQ=",r' + b'ealm="example.org",qop="auth,auth-int,auth-conf",cipher="rc4-40,rc' + b'4-56,rc4,des,3des",maxbuf=65536,charset=utf-8,algorithm=md5-sess' + ) + result, response = self.sasl.step(testChallenge) + self.assertTrue(result) + self.assertIsNotNone(response) + + def test_step_server_answer(self): + """Test a SASL step with a proper server answer for DIGEST-MD5 mechanism.""" + sasl_kwargs = {'username': "chris", 'password': "secret"} + sasl = PureSASLClient('elwood.innosoft.com', + service="imap", + mechanism=self.mechanism, + mutual_auth=True, + **sasl_kwargs) + testChallenge = ( + b'utf-8,username="chris",realm="elwood.innosoft.com",' + b'nonce="OA6MG9tEQGm2hh",nc=00000001,cnonce="OA6MHXh6VqTrRk",' + b'digest-uri="imap/elwood.innosoft.com",' + b'response=d388dad90d4bbd760a152321f2143af7,qop=auth' + ) + sasl.step(testChallenge) + sasl._chosen_mech.cnonce = b"OA6MHXh6VqTrRk" + + serverResponse = ( + b'rspauth=ea40f60335c427b5527b84dbabcdfffd' + ) + sasl.step(serverResponse) + # assert that step choses the only supported QOP for for DIGEST-MD5 + self.assertEqual(self.sasl.qop, QOP.AUTH) diff --git a/setup.py b/setup.py index be593fc0..d141ea1b 100755 --- a/setup.py +++ b/setup.py @@ -46,6 +46,7 @@ def run_tests(self): 'presto': ['requests>=1.0.0'], 'trino': ['requests>=1.0.0'], 'hive': ['sasl>=0.2.1', 'thrift>=0.10.0', 'thrift_sasl>=0.1.0'], + 'hive_pure_sasl': ['pure-sasl>=0.6.2', 'thrift>=0.10.0', 'thrift_sasl>=0.1.0'], 'sqlalchemy': ['sqlalchemy>=1.3.0'], 'kerberos': ['requests_kerberos>=0.12.0'], }, @@ -56,6 +57,8 @@ def run_tests(self): 'requests>=1.0.0', 'requests_kerberos>=0.12.0', 'sasl>=0.2.1', + 'pure-sasl>=0.6.2', + 'kerberos>=1.3.0', 'sqlalchemy>=1.3.0', 'thrift>=0.10.0', ],