Skip to content

Commit

Permalink
Use pure-sasl in python 3.11
Browse files Browse the repository at this point in the history
  • Loading branch information
mdeshmu committed May 22, 2023
1 parent 0bd6f5b commit 0421cda
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 23 deletions.
42 changes: 29 additions & 13 deletions pyhive/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,6 @@ def __init__(
self._transport = thrift.transport.TTransport.TBufferedTransport(socket)
elif auth in ('LDAP', 'KERBEROS', 'NONE', 'CUSTOM'):
# Defer import so package dependency is optional
import sasl
import thrift_sasl

if auth == 'KERBEROS':
# KERBEROS mode in hive.server2.authentication is GSSAPI in sasl library
Expand All @@ -213,17 +211,35 @@ def __init__(
password = 'x'

def sasl_factory():
sasl_client = sasl.Client()
sasl_client.setAttr('host', host)
if sasl_auth == 'GSSAPI':
sasl_client.setAttr('service', kerberos_service_name)
elif sasl_auth == 'PLAIN':
sasl_client.setAttr('username', username)
sasl_client.setAttr('password', password)
else:
raise AssertionError
sasl_client.init()
return sasl_client
try:
import sasl
# The sasl library is available
sasl_client = sasl.Client()
sasl_client.setAttr('host', host)
if sasl_auth == 'GSSAPI':
sasl_client.setAttr('service', kerberos_service_name)
elif sasl_auth == 'PLAIN':
sasl_client.setAttr('username', username)
sasl_client.setAttr('password', password)
else:
raise AssertionError
sasl_client.init()
return sasl_client
except ImportError:
# Fallback to pure-sasl library
from pyhive.sasl_compat import PureSASLClient
sasl_kwargs = {}
if sasl_auth == 'GSSAPI':
sasl_kwargs['service'] = kerberos_service_name
elif sasl_auth == 'PLAIN':
sasl_kwargs['username'] = username
sasl_kwargs['password'] = password
else:
raise AssertionError
sasl_client = PureSASLClient(host=host, **sasl_kwargs)
return sasl_client

import thrift_sasl
self._transport = thrift_sasl.TSaslClientTransport(sasl_factory, sasl_auth, socket)
else:
# All HS2 config options:
Expand Down
56 changes: 56 additions & 0 deletions pyhive/sasl_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Original source of this file is https://github.com/cloudera/impyla/blob/master/impala/sasl_compat.py
# which uses Apache-2.0 license as of 21 May 2023.
# This code was added to Impyla in 2016 as a compatibility layer to allow use of either python-sasl or pure-sasl
# via PR https://github.com/cloudera/impyla/pull/179
# Although thrift_sasl lists pure-sasl as a dependency (https://github.com/cloudera/thrift_sasl/blob/master/setup.py#L34),
# it still calls functions native to python-sasl (https://github.com/cloudera/thrift_sasl/blob/master/thrift_sasl/__init__.py#L82).
# Hence this compatibility layer is required for the pure-sasl fallback to work.


from puresasl.client import SASLClient, SASLError
from contextlib import contextmanager

@contextmanager
def error_catcher(self, Exc=Exception):
    """Capture a failure on ``self.error`` instead of letting it propagate.

    On entry ``self.error`` is reset to ``None``. If the managed body raises
    an exception matching ``Exc``, that exception is suppressed and its
    ``str()`` is stored in ``self.error``. Exceptions that do not match
    ``Exc`` propagate to the caller unchanged.
    """
    try:
        # Clear any stale message left by a previous call.
        self.error = None
        yield
    except Exc as caught:
        self.error = str(caught)


class PureSASLClient(SASLClient):
    """Expose pure-sasl's ``SASLClient`` through python-sasl style methods.

    Instead of raising, each operation returns a tuple whose first element is
    a success flag. When an operation fails, the failure message is stored
    on ``self.error`` and can be read with :meth:`getError`.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``SASLClient`` and start with no error."""
        # Set before the base initializer runs so getError() is always safe.
        self.error = None
        super(PureSASLClient, self).__init__(*args, **kwargs)

    def start(self, mechanism):
        """Choose a mechanism and produce the initial response.

        ``mechanism`` may be a single mechanism name or a list of names.

        Returns ``(True, chosen_mechanism, initial_response)`` on success, or
        ``(False, mechanism, None)`` when a ``SASLError`` was raised.
        """
        with error_catcher(self, SASLError):
            if isinstance(mechanism, list):
                self.choose_mechanism(mechanism)
            else:
                self.choose_mechanism([mechanism])
            return True, self.mechanism, self.process()
        # Reached only when error_catcher suppressed a SASLError.
        return False, mechanism, None

    def encode(self, incoming):
        """Protect data that is about to be sent to the server.

        Encoding is the send-side operation, so it maps to pure-sasl's
        ``wrap()``. The parameter keeps its historical name ``incoming`` so
        that existing callers are unaffected.

        Returns ``(True, wrapped_bytes)`` on success, ``(False, None)`` on
        failure.
        """
        # NOTE(review): python-sasl's encode() may also prepend a length
        # header that wrap() does not add -- confirm framing expectations of
        # the transport before relying on integrity/confidentiality layers.
        with error_catcher(self):
            return True, self.wrap(incoming)
        # Reached only when error_catcher suppressed an exception.
        return False, None

    def decode(self, outgoing):
        """Remove protection from data received from the server.

        Decoding is the receive-side operation, so it maps to pure-sasl's
        ``unwrap()``. The parameter keeps its historical name ``outgoing`` so
        that existing callers are unaffected.

        Returns ``(True, unwrapped_bytes)`` on success, ``(False, None)`` on
        failure.
        """
        with error_catcher(self):
            return True, self.unwrap(outgoing)
        # Reached only when error_catcher suppressed an exception.
        return False, None

    def step(self, challenge):
        """Process a server challenge.

        Returns ``(True, response)`` on success, ``(False, None)`` on failure.
        """
        with error_catcher(self):
            return True, self.process(challenge)
        # Reached only when error_catcher suppressed an exception.
        return False, None

    def getError(self):
        """Return the message of the most recent failure, or ``None``."""
        return self.error
24 changes: 14 additions & 10 deletions pyhive/tests/test_hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,9 @@
import time
import unittest
from decimal import Decimal

import mock
import sasl
import thrift.transport.TSocket
import thrift.transport.TTransport
import thrift_sasl
from thrift.transport.TTransport import TTransportException

from TCLIService import ttypes
Expand Down Expand Up @@ -205,13 +202,20 @@ def test_custom_transport(self):
sasl_auth = 'PLAIN'

def sasl_factory():
sasl_client = sasl.Client()
sasl_client.setAttr('host', 'localhost')
sasl_client.setAttr('username', 'test_username')
sasl_client.setAttr('password', 'x')
sasl_client.init()
return sasl_client

try:
import sasl
sasl_client = sasl.Client()
sasl_client.setAttr('host', 'localhost')
sasl_client.setAttr('username', 'test_username')
sasl_client.setAttr('password', 'x')
sasl_client.init()
return sasl_client
except ImportError:
from pyhive.sasl_compat import PureSASLClient
sasl_client = PureSASLClient(host='localhost', username='test_username', password='x')
return sasl_client

import thrift_sasl
transport = thrift_sasl.TSaslClientTransport(sasl_factory, sasl_auth, socket)
conn = hive.connect(thrift_transport=transport)
with contextlib.closing(conn):
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def run_tests(self):
'presto': ['requests>=1.0.0'],
'trino': ['requests>=1.0.0'],
'hive': ['sasl>=0.2.1', 'thrift>=0.10.0', 'thrift_sasl>=0.1.0'],
'hive_pure': ['pure-sasl>=0.6.2', 'thrift>=0.10.0', 'thrift_sasl>=0.1.0'],
'sqlalchemy': ['sqlalchemy>=1.3.0'],
'kerberos': ['requests_kerberos>=0.12.0'],
},
Expand All @@ -56,6 +57,7 @@ def run_tests(self):
'requests>=1.0.0',
'requests_kerberos>=0.12.0',
'sasl>=0.2.1',
'pure-sasl>=0.6.2',
'sqlalchemy>=1.3.0',
'thrift>=0.10.0',
],
Expand Down

0 comments on commit 0421cda

Please sign in to comment.