Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement GSSAPI authentication #1122

Merged
merged 2 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .github/workflows/install-krb5.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#!/bin/bash

set -Eexuo pipefail

if [ "$RUNNER_OS" == "Linux" ]; then
# Assume Ubuntu since this is the only Linux used in CI.
sudo apt-get update
sudo apt-get install -y --no-install-recommends \
libkrb5-dev krb5-user krb5-kdc krb5-admin-server
fi
2 changes: 2 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ jobs:
- name: Install Python Deps
if: steps.release.outputs.version == 0
run: |
.github/workflows/install-krb5.sh
python -m pip install -U pip setuptools wheel
python -m pip install -e .[test]

Expand Down Expand Up @@ -122,6 +123,7 @@ jobs:
- name: Install Python Deps
if: steps.release.outputs.version == 0
run: |
.github/workflows/install-krb5.sh
python -m pip install -U pip setuptools wheel
python -m pip install -e .[test]

Expand Down
19 changes: 15 additions & 4 deletions asyncpg/connect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def parse(cls, sslmode):
'direct_tls',
'server_settings',
'target_session_attrs',
'krbsrvname',
])


Expand Down Expand Up @@ -261,7 +262,7 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:
def _parse_connect_dsn_and_args(*, dsn, host, port, user,
password, passfile, database, ssl,
direct_tls, server_settings,
target_session_attrs):
target_session_attrs, krbsrvname):
# `auth_hosts` is the version of host information for the purposes
# of reading the pgpass file.
auth_hosts = None
Expand Down Expand Up @@ -383,6 +384,11 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
if target_session_attrs is None:
target_session_attrs = dsn_target_session_attrs

if 'krbsrvname' in query:
val = query.pop('krbsrvname')
if krbsrvname is None:
krbsrvname = val
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add handling of the PGKRBSRVNAME environment variable as well to stay compatible with libpq.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


if query:
if server_settings is None:
server_settings = query
Expand Down Expand Up @@ -650,11 +656,15 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
)
) from None

if krbsrvname is None:
krbsrvname = os.getenv('PGKRBSRVNAME')

params = _ConnectionParameters(
user=user, password=password, database=database, ssl=ssl,
sslmode=sslmode, direct_tls=direct_tls,
server_settings=server_settings,
target_session_attrs=target_session_attrs)
target_session_attrs=target_session_attrs,
krbsrvname=krbsrvname)

return addrs, params

Expand All @@ -665,7 +675,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
max_cached_statement_lifetime,
max_cacheable_statement_size,
ssl, direct_tls, server_settings,
target_session_attrs):
target_session_attrs, krbsrvname):
local_vars = locals()
for var_name in {'max_cacheable_statement_size',
'max_cached_statement_lifetime',
Expand Down Expand Up @@ -694,7 +704,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
password=password, passfile=passfile, ssl=ssl,
direct_tls=direct_tls, database=database,
server_settings=server_settings,
target_session_attrs=target_session_attrs)
target_session_attrs=target_session_attrs,
krbsrvname=krbsrvname)

config = _ClientConfiguration(
command_timeout=command_timeout,
Expand Down
13 changes: 11 additions & 2 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2007,7 +2007,8 @@ async def connect(dsn=None, *,
connection_class=Connection,
record_class=protocol.Record,
server_settings=None,
target_session_attrs=None):
target_session_attrs=None,
krbsrvname=None):
r"""A coroutine to establish a connection to a PostgreSQL server.

The connection parameters may be specified either as a connection
Expand Down Expand Up @@ -2235,6 +2236,10 @@ async def connect(dsn=None, *,
or the value of the ``PGTARGETSESSIONATTRS`` environment variable,
or ``"any"`` if neither is specified.

:param str krbsrvname:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mention the new parameter in a .. versionadded:: block below please.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Kerberos service name to use when authenticating with GSSAPI. This
must match the server configuration. Defaults to 'postgres'.

:return: A :class:`~asyncpg.connection.Connection` instance.

Example:
Expand Down Expand Up @@ -2303,6 +2308,9 @@ async def connect(dsn=None, *,
.. versionchanged:: 0.28.0
Added the *target_session_attrs* parameter.

.. versionchanged:: 0.30.0
Added the *krbsrvname* parameter.

.. _SSLContext: https://docs.python.org/3/library/ssl.html#ssl.SSLContext
.. _create_default_context:
https://docs.python.org/3/library/ssl.html#ssl.create_default_context
Expand Down Expand Up @@ -2344,7 +2352,8 @@ async def connect(dsn=None, *,
statement_cache_size=statement_cache_size,
max_cached_statement_lifetime=max_cached_statement_lifetime,
max_cacheable_statement_size=max_cacheable_statement_size,
target_session_attrs=target_session_attrs
target_session_attrs=target_session_attrs,
krbsrvname=krbsrvname,
)


Expand Down
15 changes: 5 additions & 10 deletions asyncpg/protocol/coreproto.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,6 @@ cdef enum AuthenticationMessage:
AUTH_SASL_FINAL = 12


AUTH_METHOD_NAME = {
AUTH_REQUIRED_KERBEROS: 'kerberosv5',
AUTH_REQUIRED_PASSWORD: 'password',
AUTH_REQUIRED_PASSWORDMD5: 'md5',
AUTH_REQUIRED_GSS: 'gss',
AUTH_REQUIRED_SASL: 'scram-sha-256',
AUTH_REQUIRED_SSPI: 'sspi',
}


cdef enum ResultType:
RESULT_OK = 1
RESULT_FAILED = 2
Expand Down Expand Up @@ -96,10 +86,13 @@ cdef class CoreProtocol:

object transport

object address
# Instance of _ConnectionParameters
object con_params
# Instance of SCRAMAuthentication
SCRAMAuthentication scram
# Instance of gssapi.SecurityContext
object gss_ctx

readonly int32_t backend_pid
readonly int32_t backend_secret
Expand Down Expand Up @@ -145,6 +138,8 @@ cdef class CoreProtocol:
cdef _auth_password_message_md5(self, bytes salt)
cdef _auth_password_message_sasl_initial(self, list sasl_auth_methods)
cdef _auth_password_message_sasl_continue(self, bytes server_response)
cdef _auth_gss_init(self)
cdef _auth_gss_step(self, bytes server_response)

cdef _write(self, buf)
cdef _writelines(self, list buffers)
Expand Down
63 changes: 60 additions & 3 deletions asyncpg/protocol/coreproto.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,26 @@


import hashlib
import socket


include "scram.pyx"


cdef dict AUTH_METHOD_NAME = {
AUTH_REQUIRED_KERBEROS: 'kerberosv5',
AUTH_REQUIRED_PASSWORD: 'password',
AUTH_REQUIRED_PASSWORDMD5: 'md5',
AUTH_REQUIRED_GSS: 'gss',
AUTH_REQUIRED_SASL: 'scram-sha-256',
AUTH_REQUIRED_SSPI: 'sspi',
}


cdef class CoreProtocol:

def __init__(self, con_params):
def __init__(self, addr, con_params):
self.address = addr
# type of `con_params` is `_ConnectionParameters`
self.buffer = ReadBuffer()
self.user = con_params.user
Expand All @@ -26,6 +38,8 @@ cdef class CoreProtocol:
self.encoding = 'utf-8'
# type of `scram` is `SCRAMAuthentcation`
self.scram = None
# type of `gss_ctx` is `gssapi.SecurityContext`
self.gss_ctx = None

self._reset_result()

Expand Down Expand Up @@ -619,9 +633,17 @@ cdef class CoreProtocol:
'could not verify server signature for '
'SCRAM authentciation: scram-sha-256',
)
self.scram = None

elif status == AUTH_REQUIRED_GSS:
self._auth_gss_init()
self.auth_msg = self._auth_gss_step(None)

elif status == AUTH_REQUIRED_GSS_CONTINUE:
server_response = self.buffer.consume_message()
self.auth_msg = self._auth_gss_step(server_response)

elif status in (AUTH_REQUIRED_KERBEROS, AUTH_REQUIRED_SCMCRED,
AUTH_REQUIRED_GSS, AUTH_REQUIRED_GSS_CONTINUE,
AUTH_REQUIRED_SSPI):
self.result_type = RESULT_FAILED
self.result = apg_exc.InterfaceError(
Expand All @@ -634,7 +656,8 @@ cdef class CoreProtocol:
'unsupported authentication method requested by the '
'server: {}'.format(status))

if status not in [AUTH_SASL_CONTINUE, AUTH_SASL_FINAL]:
if status not in [AUTH_SASL_CONTINUE, AUTH_SASL_FINAL,
AUTH_REQUIRED_GSS_CONTINUE]:
self.buffer.discard_message()

cdef _auth_password_message_cleartext(self):
Expand Down Expand Up @@ -691,6 +714,40 @@ cdef class CoreProtocol:

return msg

cdef _auth_gss_init(self):
try:
import gssapi
except ModuleNotFoundError:
raise RuntimeError(
'gssapi module not found; please install asyncpg[gssapi] to '
'use asyncpg with Kerberos or GSSAPI authentication'
) from None

service_name = self.con_params.krbsrvname or 'postgres'
# find the canonical name of the server host
if isinstance(self.address, str):
raise RuntimeError('GSSAPI authentication is only supported for '
'TCP/IP connections')

host = self.address[0]
host_cname = socket.gethostbyname_ex(host)[0]
gss_name = gssapi.Name(f'{service_name}/{host_cname}')
self.gss_ctx = gssapi.SecurityContext(name=gss_name, usage='initiate')

cdef _auth_gss_step(self, bytes server_response):
cdef:
WriteBuffer msg

token = self.gss_ctx.step(server_response)
if not token:
self.gss_ctx = None
return None
msg = WriteBuffer.new_message(b'p')
msg.write_bytes(token)
msg.end_message()

return msg

cdef _parse_msg_ready_for_query(self):
cdef char status = self.buffer.read_byte()

Expand Down
1 change: 0 additions & 1 deletion asyncpg/protocol/protocol.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ cdef class BaseProtocol(CoreProtocol):

cdef:
object loop
object address
ConnectionSettings settings
object cancel_sent_waiter
object cancel_waiter
Expand Down
5 changes: 2 additions & 3 deletions asyncpg/protocol/protocol.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -75,16 +75,15 @@ NO_TIMEOUT = object()
cdef class BaseProtocol(CoreProtocol):
def __init__(self, addr, connected_fut, con_params, record_class: type, loop):
# type of `con_params` is `_ConnectionParameters`
CoreProtocol.__init__(self, con_params)
CoreProtocol.__init__(self, addr, con_params)

self.loop = loop
self.transport = None
self.waiter = connected_fut
self.cancel_waiter = None
self.cancel_sent_waiter = None

self.address = addr
self.settings = ConnectionSettings((self.address, con_params.database))
self.settings = ConnectionSettings((addr, con_params.database))
self.record_class = record_class

self.statement = None
Expand Down
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,14 @@ dependencies = [
github = "https://github.com/MagicStack/asyncpg"

[project.optional-dependencies]
gssapi = [
'gssapi',
]
test = [
'flake8~=6.1',
'uvloop>=0.15.3; platform_system != "Windows" and python_version < "3.12.0"',
'gssapi; platform_system == "Linux"',
'k5test; platform_system == "Linux"',
]
docs = [
'Sphinx~=5.3.0',
Expand Down
Loading
Loading