Skip to content

Commit

Permalink
Use pure-sasl in python 3.11
Browse files Browse the repository at this point in the history
  • Loading branch information
mdeshmu committed May 29, 2023
1 parent 0bd6f5b commit dd62269
Show file tree
Hide file tree
Showing 5 changed files with 419 additions and 7 deletions.
33 changes: 30 additions & 3 deletions pyhive/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,6 @@ def __init__(
self._transport = thrift.transport.TTransport.TBufferedTransport(socket)
elif auth in ('LDAP', 'KERBEROS', 'NONE', 'CUSTOM'):
# Defer import so package dependency is optional
import sasl
import thrift_sasl

if auth == 'KERBEROS':
# KERBEROS mode in hive.server2.authentication is GSSAPI in sasl library
Expand All @@ -212,7 +210,9 @@ def __init__(
# Password doesn't matter in NONE mode, just needs to be nonempty.
password = 'x'

def sasl_factory():

def get_sasl_client():
import sasl
sasl_client = sasl.Client()
sasl_client.setAttr('host', host)
if sasl_auth == 'GSSAPI':
Expand All @@ -224,6 +224,33 @@ def sasl_factory():
raise AssertionError
sasl_client.init()
return sasl_client


def get_pure_sasl_client():
from pyhive.sasl_compat import PureSASLClient
sasl_kwargs = {}
if sasl_auth == 'GSSAPI':
sasl_kwargs['service'] = kerberos_service_name
elif sasl_auth == 'PLAIN':
sasl_kwargs['username'] = username
sasl_kwargs['password'] = password
else:
raise AssertionError
return PureSASLClient(host=host, **sasl_kwargs)


def sasl_factory():
try:
sasl_client = get_sasl_client()
# The sasl library is available
except ImportError:
# Fallback to pure-sasl library
sasl_client = get_pure_sasl_client()

return sasl_client


import thrift_sasl
self._transport = thrift_sasl.TSaslClientTransport(sasl_factory, sasl_auth, socket)
else:
# All HS2 config options:
Expand Down
56 changes: 56 additions & 0 deletions pyhive/sasl_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Original source of this file is https://github.com/cloudera/impyla/blob/master/impala/sasl_compat.py
# which uses Apache-2.0 license as of 21 May 2023.
# This code was added to Impyla in 2016 as a compatibility layer to allow use of either python-sasl or pure-sasl
# via PR https://github.com/cloudera/impyla/pull/179
# Even though thrift_sasl lists pure-sasl as dependency here https://github.com/cloudera/thrift_sasl/blob/master/setup.py#L34
# but it still calls functions native to python-sasl in this file https://github.com/cloudera/thrift_sasl/blob/master/thrift_sasl/__init__.py#L82
# Hence this code is required for the fallback to work.


from puresasl.client import SASLClient, SASLError
from contextlib import contextmanager

@contextmanager
def error_catcher(self, Exc = Exception):
try:
self.error = None
yield
except Exc as e:
self.error = str(e)


class PureSASLClient(SASLClient):
def __init__(self, *args, **kwargs):
self.error = None
super(PureSASLClient, self).__init__(*args, **kwargs)

def start(self, mechanism):
with error_catcher(self, SASLError):
if isinstance(mechanism, list):
self.choose_mechanism(mechanism)
else:
self.choose_mechanism([mechanism])
return True, self.mechanism, self.process()
# else
return False, mechanism, None

def encode(self, incoming):
with error_catcher(self):
return True, self.unwrap(incoming)
# else
return False, None

def decode(self, outgoing):
with error_catcher(self):
return True, self.wrap(outgoing)
# else
return False, None

def step(self, challenge=None):
with error_catcher(self):
return True, self.process(challenge)
# else
return False, None

def getError(self):
return self.error
27 changes: 23 additions & 4 deletions pyhive/tests/test_hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,17 @@
from __future__ import unicode_literals

import contextlib
import importlib.util
import datetime
import os
import socket
import subprocess
import time
import unittest
from decimal import Decimal

import mock
import sasl
import thrift.transport.TSocket
import thrift.transport.TTransport
import thrift_sasl
from thrift.transport.TTransport import TTransportException

from TCLIService import ttypes
Expand Down Expand Up @@ -204,14 +202,35 @@ def test_custom_transport(self):
socket = thrift.transport.TSocket.TSocket('localhost', 10000)
sasl_auth = 'PLAIN'

def sasl_factory():
spec_sasl = importlib.util.find_spec('sasl')
spec_puresasl = importlib.util.find_spec('puresasl')

def get_sasl_client():
import sasl
sasl_client = sasl.Client()
sasl_client.setAttr('host', 'localhost')
sasl_client.setAttr('username', 'test_username')
sasl_client.setAttr('password', 'x')
sasl_client.init()
return sasl_client

def get_pure_sasl_client():
from pyhive.sasl_compat import PureSASLClient
sasl_client = PureSASLClient(host='localhost', username='test_username', password='x')
return sasl_client

def sasl_factory():
if spec_sasl:
sasl_client = get_sasl_client()
return sasl_client
elif spec_puresasl:
sasl_client = get_pure_sasl_client()
return sasl_client
else:
raise ValueError("No suitable SASL module available. Please install either sasl or pure-sasl.")


import thrift_sasl
transport = thrift_sasl.TSaslClientTransport(sasl_factory, sasl_auth, socket)
conn = hive.connect(thrift_transport=transport)
with contextlib.closing(conn):
Expand Down
Loading

0 comments on commit dd62269

Please sign in to comment.