Skip to content

Commit

Permalink
Allow SASL callback modules to return {ok, ServerResponse}
Browse files Browse the repository at this point in the history
This expansion of the callback return values allows `kpro_connection` to
interrogate the server response message, in preparation for
re-authenticating SASL connections before session lifetime expires.

Authentication was moved to a separate function to allow repeating
authentication flow, which also required storing connection
configuration in process state.
  • Loading branch information
urmastalimaa committed Aug 12, 2024
1 parent 344f2b7 commit ca0bf49
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 12 deletions.
10 changes: 6 additions & 4 deletions src/kpro_auth_backend.erl
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,25 @@

-export([auth/8]).

-type server_auth_response() :: term().

-callback auth(Host :: string(), Sock :: gen_tcp:socket() | ssl:sslsocket(),
Mod :: gen_tcp | ssl, ClientName :: binary(),
Timeout :: pos_integer(), SaslOpts :: term()) ->
ok | {error, Reason :: term()}.
ok | {ok, server_auth_response()} | {error, Reason :: term()}.

-callback auth(Host :: string(), Sock :: gen_tcp:socket() | ssl:sslsocket(),
HandShakeVsn :: non_neg_integer(), Mod :: gen_tcp | ssl, ClientName :: binary(),
Timeout :: pos_integer(), SaslOpts :: term()) ->
ok | {error, Reason :: term()}.
ok | {ok, server_auth_response()} | {error, Reason :: term()}.

-optional_callbacks([auth/6]).

-spec auth(CallbackModule :: atom(), Host :: string(),
Sock :: gen_tcp:socket() | ssl:sslsocket(),
Mod :: gen_tcp | ssl, ClientName :: binary(),
Timeout :: pos_integer(), SaslOpts :: term()) ->
ok | {error, Reason :: term()}.
ok | {ok, server_auth_response()} | {error, Reason :: term()}.
auth(CallbackModule, Host, Sock, Mod, ClientName, Timeout, SaslOpts) ->
CallbackModule:auth(Host, Sock, Mod, ClientName, Timeout, SaslOpts).

Expand All @@ -43,7 +45,7 @@ auth(CallbackModule, Host, Sock, Mod, ClientName, Timeout, SaslOpts) ->
HandShakeVsn :: non_neg_integer(),
Mod :: gen_tcp | ssl, ClientName :: binary(),
Timeout :: pos_integer(), SaslOpts :: term()) ->
ok | {error, Reason :: term()}.
ok | {ok, server_auth_response()} | {error, Reason :: term()}.
auth(CallbackModule, Host, Sock, Vsn, Mod, ClientName, Timeout, SaslOpts) ->
case is_exported(CallbackModule, auth, 7) of
true ->
Expand Down
31 changes: 23 additions & 8 deletions src/kpro_connection.erl
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@

-record(state, { client_id :: client_id()
, parent :: pid()
, config :: config()
, remote :: kpro:endpoint()
, sock :: gen_tcp:socket() | ssl:sslsocket()
, mod :: ?undef | gen_tcp | ssl
Expand Down Expand Up @@ -227,6 +228,7 @@ connect(Parent, Host, Port, Config) ->
State = #state{ client_id = get_client_id(Config)
, parent = Parent
, remote = {Host, Port}
, config = Config
, sock = Sock
},
init_connection(State, Config, Deadline);
Expand Down Expand Up @@ -258,14 +260,8 @@ init_connection(#state{ client_id = ClientId
#{query_api_versions := false} -> ?undef;
_ -> query_api_versions(NewSock, Mod, ClientId, Deadline)
end,
HandshakeVsn = case Versions of
#{sasl_handshake := {_, V}} -> V;
_ -> 0
end,
SaslOpts = get_sasl_opt(Config),
ok = kpro_sasl:auth(Host, NewSock, Mod, ClientId,
timeout(Deadline), SaslOpts, HandshakeVsn),
State#state{mod = Mod, sock = NewSock, api_vsns = Versions}.
State1 = State#state{mod = Mod, sock = NewSock, api_vsns = Versions},
sasl_authenticate(State1).

query_api_versions(Sock, Mod, ClientId, Deadline) ->
Req = kpro_req_lib:make(api_versions, 0, []),
Expand Down Expand Up @@ -474,6 +470,25 @@ handle_msg(Msg, #state{} = State, Debug) ->
[?MODULE, self(), Msg]),
?MODULE:loop(State, Debug).

sasl_authenticate(#state{client_id = ClientId, mod = Mod, sock = Sock, remote = {Host, _Port}, api_vsns = Versions, config = Config} = State) ->
Timeout = get_connect_timeout(Config),
Deadline = deadline(Timeout),
SaslOpts = get_sasl_opt(Config),
HandshakeVsn = case Versions of
#{sasl_handshake := {_, V}} -> V;
_ -> 0
end,
ok = setopts(Sock, Mod, [{active, false}]),
case kpro_sasl:auth(Host, Sock, Mod, ClientId,
timeout(Deadline), SaslOpts, HandshakeVsn) of
ok ->
ok;
{ok, _ServerResponse} ->
ok
end,
ok = setopts(Sock, Mod, [{active, once}]),
State.

cast(Pid, Msg) ->
try
Pid ! Msg,
Expand Down
2 changes: 2 additions & 0 deletions src/kpro_sasl.erl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ auth(Host, Sock, Mod, ClientId, Timeout,
ClientId, Timeout, Opts) of
ok ->
ok;
{ok, ServerResponse} ->
{ok, ServerResponse};
{error, Reason} ->
?ERROR(Reason)
end;
Expand Down

0 comments on commit ca0bf49

Please sign in to comment.