Skip to content

Commit

Permalink
fix: allow timeouts in the client methods that make API call
Browse files Browse the repository at this point in the history
  • Loading branch information
frrist authored and frrist committed May 8, 2024
1 parent 693c6d3 commit 1c62d0f
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 15 deletions.
7 changes: 6 additions & 1 deletion cmd/cli/version/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package version
import (
"context"
"fmt"
"time"

"github.com/jedib0t/go-pretty/v6/table"
"github.com/rs/zerolog/log"
Expand Down Expand Up @@ -100,8 +101,12 @@ func (oV *VersionOptions) Run(ctx context.Context, cmd *cobra.Command) error {
if oV.ClientOnly {
versions.ClientVersion = version.Get()
} else {
// NB(forrest): since `GetAllVersions` is an API call - in the event the server is un-reachable
// we timeout after 3 seconds to avoid waiting on an unavailable server to return its version information.
vctx, cancel := context.WithTimeout(ctx, time.Second*3)
defer cancel()
var err error
versions, err = util.GetAllVersions(ctx)
versions, err = util.GetAllVersions(vctx)
if err != nil {
// No error on fail of version check. Just print as much as we can.
log.Ctx(ctx).Warn().Err(err).Msg("failed to get updated versions")
Expand Down
4 changes: 2 additions & 2 deletions cmd/util/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ func GetAPIClientV2(cmd *cobra.Command) clientv2.API {
PersistCredential: func(cred *apimodels.HTTPCredential) error {
return WriteToken(base, cred)
},
Authenticate: func(a *clientv2.Auth) (*apimodels.HTTPCredential, error) {
return auth.RunAuthenticationFlow(cmd, a)
Authenticate: func(ctx context.Context, a *clientv2.Auth) (*apimodels.HTTPCredential, error) {
return auth.RunAuthenticationFlow(ctx, cmd, a)
},
},
)
Expand Down
7 changes: 4 additions & 3 deletions cmd/util/auth/auth.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package auth

import (
"context"
"encoding/json"
"fmt"
"io"
Expand All @@ -19,13 +20,13 @@ import (

type responder = func(request *json.RawMessage) (response []byte, err error)

func RunAuthenticationFlow(cmd *cobra.Command, auth *client.Auth) (*apimodels.HTTPCredential, error) {
func RunAuthenticationFlow(ctx context.Context, cmd *cobra.Command, auth *client.Auth) (*apimodels.HTTPCredential, error) {
supportedMethods := map[authn.MethodType]responder{
authn.MethodTypeChallenge: challenge.Respond,
authn.MethodTypeAsk: askResponder(cmd),
}

methods, err := auth.Methods(cmd.Context(), &apimodels.ListAuthnMethodsRequest{})
methods, err := auth.Methods(ctx, &apimodels.ListAuthnMethodsRequest{})
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -60,7 +61,7 @@ func RunAuthenticationFlow(cmd *cobra.Command, auth *client.Auth) (*apimodels.HT
return nil, err
}

authnResponse, err := auth.Authenticate(cmd.Context(), &apimodels.AuthnRequest{
authnResponse, err := auth.Authenticate(ctx, &apimodels.AuthnRequest{
Name: chosenMethodName,
MethodData: response,
})
Expand Down
18 changes: 9 additions & 9 deletions pkg/publicapi/client/v2/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,35 +310,35 @@ type AuthenticatingClient struct {

// Authenticate will be called when the system should run an authentication
// flow using the passed Auth API.
Authenticate func(*Auth) (*apimodels.HTTPCredential, error)
Authenticate func(context.Context, *Auth) (*apimodels.HTTPCredential, error)
}

func (t *AuthenticatingClient) Get(ctx context.Context, path string, in apimodels.GetRequest, out apimodels.GetResponse) error {
return doRequest(t, in, func(req apimodels.GetRequest) error {
return doRequest(ctx, t, in, func(req apimodels.GetRequest) error {
return t.Client.Get(ctx, path, req, out)
})
}

func (t *AuthenticatingClient) List(ctx context.Context, path string, in apimodels.ListRequest, out apimodels.ListResponse) error {
return doRequest(t, in, func(req apimodels.ListRequest) error {
return doRequest(ctx, t, in, func(req apimodels.ListRequest) error {
return t.Client.List(ctx, path, req, out)
})
}

func (t *AuthenticatingClient) Post(ctx context.Context, path string, in apimodels.PutRequest, out apimodels.PutResponse) error {
return doRequest(t, in, func(req apimodels.PutRequest) error {
return doRequest(ctx, t, in, func(req apimodels.PutRequest) error {
return t.Client.Post(ctx, path, req, out)
})
}

func (t *AuthenticatingClient) Put(ctx context.Context, path string, in apimodels.PutRequest, out apimodels.PutResponse) error {
return doRequest(t, in, func(req apimodels.PutRequest) error {
return doRequest(ctx, t, in, func(req apimodels.PutRequest) error {
return t.Client.Put(ctx, path, req, out)
})
}

func (t *AuthenticatingClient) Delete(ctx context.Context, path string, in apimodels.PutRequest, out apimodels.Response) error {
return doRequest(t, in, func(req apimodels.PutRequest) error {
return doRequest(ctx, t, in, func(req apimodels.PutRequest) error {
return t.Client.Delete(ctx, path, req, out)
})
}
Expand All @@ -349,14 +349,14 @@ func (t *AuthenticatingClient) Dial(
in apimodels.Request,
) (<-chan *concurrency.AsyncResult[[]byte], error) {
var output <-chan *concurrency.AsyncResult[[]byte]
err := doRequest(t, in, func(req apimodels.Request) (err error) {
err := doRequest(ctx, t, in, func(req apimodels.Request) (err error) {
output, err = t.Client.Dial(ctx, path, req)
return
})
return output, err
}

func doRequest[R apimodels.Request](t *AuthenticatingClient, request R, runRequest func(R) error) (err error) {
func doRequest[R apimodels.Request](ctx context.Context, t *AuthenticatingClient, request R, runRequest func(R) error) (err error) {
if t.Credential != nil {
request.SetCredential(t.Credential)
if err = runRequest(request); err == nil {
Expand All @@ -374,7 +374,7 @@ func doRequest[R apimodels.Request](t *AuthenticatingClient, request R, runReque
if t.Credential == nil || pkgerrors.Is(err, apimodels.ErrInvalidToken) {
var authErr error
auth := NewAPI(t.Client).Auth()
if t.Credential, err = t.Authenticate(auth); err != nil {
if t.Credential, err = t.Authenticate(ctx, auth); err != nil {
authErr = errors.Join(authErr, pkgerrors.Wrap(err, "failed to authorize user"))
t.Credential = nil // Don't assume Authenticate returned nil
}
Expand Down

0 comments on commit 1c62d0f

Please sign in to comment.