Skip to content

Commit

Permalink
use pure-sasl with python 3.11
Browse files Browse the repository at this point in the history
  • Loading branch information
mdeshmu committed Jun 20, 2023
1 parent 0bd6f5b commit 1eef88b
Show file tree
Hide file tree
Showing 6 changed files with 436 additions and 25 deletions.
2 changes: 2 additions & 0 deletions dev_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ pytest-timeout==1.2.0
requests>=1.0.0
requests_kerberos>=0.12.0
sasl>=0.2.1
pure-sasl>=0.6.2
kerberos>=1.3.0
thrift>=0.10.0
#thrift_sasl>=0.1.0
git+https://github.com/cloudera/thrift_sasl # Using master branch in order to get Python 3 SASL patches
56 changes: 41 additions & 15 deletions pyhive/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,45 @@
}


def get_sasl_client(host, sasl_auth, service=None, username=None, password=None):
import sasl
sasl_client = sasl.Client()
sasl_client.setAttr('host', host)

if sasl_auth == 'GSSAPI':
sasl_client.setAttr('service', service)
elif sasl_auth == 'PLAIN':
sasl_client.setAttr('username', username)
sasl_client.setAttr('password', password)
else:
raise ValueError("sasl_auth only supports GSSAPI and PLAIN")

sasl_client.init()
return sasl_client


def get_pure_sasl_client(host, sasl_auth, service=None, username=None, password=None):
from pyhive.sasl_compat import PureSASLClient

if sasl_auth == 'GSSAPI':
sasl_kwargs = {'service': service}
elif sasl_auth == 'PLAIN':
sasl_kwargs = {'username': username, 'password': password}
else:
raise ValueError("sasl_auth only supports GSSAPI and PLAIN")

return PureSASLClient(host=host, **sasl_kwargs)


def get_installed_sasl(host, sasl_auth, service=None, username=None, password=None):
try:
return get_sasl_client(host=host, sasl_auth=sasl_auth, service=service, username=username, password=password)
# The sasl library is available
except ImportError:
# Fallback to pure-sasl library
return get_pure_sasl_client(host=host, sasl_auth=sasl_auth, service=service, username=username, password=password)


def _parse_timestamp(value):
if value:
match = _TIMESTAMP_PATTERN.match(value)
Expand Down Expand Up @@ -200,7 +239,6 @@ def __init__(
self._transport = thrift.transport.TTransport.TBufferedTransport(socket)
elif auth in ('LDAP', 'KERBEROS', 'NONE', 'CUSTOM'):
# Defer import so package dependency is optional
import sasl
import thrift_sasl

if auth == 'KERBEROS':
Expand All @@ -211,20 +249,8 @@ def __init__(
if password is None:
# Password doesn't matter in NONE mode, just needs to be nonempty.
password = 'x'

def sasl_factory():
sasl_client = sasl.Client()
sasl_client.setAttr('host', host)
if sasl_auth == 'GSSAPI':
sasl_client.setAttr('service', kerberos_service_name)
elif sasl_auth == 'PLAIN':
sasl_client.setAttr('username', username)
sasl_client.setAttr('password', password)
else:
raise AssertionError
sasl_client.init()
return sasl_client
self._transport = thrift_sasl.TSaslClientTransport(sasl_factory, sasl_auth, socket)

self._transport = thrift_sasl.TSaslClientTransport(lambda: get_installed_sasl(host=host, sasl_auth=sasl_auth, service=kerberos_service_name, username=username, password=password), sasl_auth, socket)
else:
# All HS2 config options:
# https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2#SettingUpHiveServer2-Configuration
Expand Down
56 changes: 56 additions & 0 deletions pyhive/sasl_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Original source of this file is https://github.com/cloudera/impyla/blob/master/impala/sasl_compat.py
# which uses Apache-2.0 license as of 21 May 2023.
# This code was added to Impyla in 2016 as a compatibility layer to allow use of either python-sasl or pure-sasl
# via PR https://github.com/cloudera/impyla/pull/179
# Even though thrift_sasl lists pure-sasl as dependency here https://github.com/cloudera/thrift_sasl/blob/master/setup.py#L34
# but it still calls functions native to python-sasl in this file https://github.com/cloudera/thrift_sasl/blob/master/thrift_sasl/__init__.py#L82
# Hence this code is required for the fallback to work.


from puresasl.client import SASLClient, SASLError
from contextlib import contextmanager

@contextmanager
def error_catcher(self, Exc = Exception):
try:
self.error = None
yield
except Exc as e:
self.error = str(e)


class PureSASLClient(SASLClient):
def __init__(self, *args, **kwargs):
self.error = None
super(PureSASLClient, self).__init__(*args, **kwargs)

def start(self, mechanism):
with error_catcher(self, SASLError):
if isinstance(mechanism, list):
self.choose_mechanism(mechanism)
else:
self.choose_mechanism([mechanism])
return True, self.mechanism, self.process()
# else
return False, mechanism, None

def encode(self, incoming):
with error_catcher(self):
return True, self.unwrap(incoming)
# else
return False, None

def decode(self, outgoing):
with error_catcher(self):
return True, self.wrap(outgoing)
# else
return False, None

def step(self, challenge=None):
with error_catcher(self):
return True, self.process(challenge)
# else
return False, None

def getError(self):
return self.error
11 changes: 1 addition & 10 deletions pyhive/tests/test_hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from decimal import Decimal

import mock
import sasl
import thrift.transport.TSocket
import thrift.transport.TTransport
import thrift_sasl
Expand Down Expand Up @@ -204,15 +203,7 @@ def test_custom_transport(self):
socket = thrift.transport.TSocket.TSocket('localhost', 10000)
sasl_auth = 'PLAIN'

def sasl_factory():
sasl_client = sasl.Client()
sasl_client.setAttr('host', 'localhost')
sasl_client.setAttr('username', 'test_username')
sasl_client.setAttr('password', 'x')
sasl_client.init()
return sasl_client

transport = thrift_sasl.TSaslClientTransport(sasl_factory, sasl_auth, socket)
transport = thrift_sasl.TSaslClientTransport(lambda: hive.get_installed_sasl(host='localhost', sasl_auth=sasl_auth, username='test_username', password='x'), sasl_auth, socket)
conn = hive.connect(thrift_transport=transport)
with contextlib.closing(conn):
with contextlib.closing(conn.cursor()) as cursor:
Expand Down
Loading

0 comments on commit 1eef88b

Please sign in to comment.