diff --git a/internal/http/interceptors/auth/auth.go b/internal/http/interceptors/auth/auth.go index 1a10d4fccc..3348e07947 100644 --- a/internal/http/interceptors/auth/auth.go +++ b/internal/http/interceptors/auth/auth.go @@ -200,92 +200,87 @@ func authenticateUser(w http.ResponseWriter, r *http.Request, conf *config, toke return nil, err } - tkn := tokenStrategy.GetToken(r) - if tkn == "" { - log.Warn().Msg("core access token not set") + // reva token or auth token can be passed using the same tecnique (for example bearer) + // before validating it against an auth provider, we can check directly if it's a reva + // token and if not try to use it for authenticating the user. + + token := tokenStrategy.GetToken(r) + if token != "" { + if user, ok := isTokenValid(r, tokenManager, token); ok { + if err := insertGroupsInUser(ctx, userGroupsCache, client, user); err != nil { + logError(isUnprotectedEndpoint, log, err, "got an error retrieving groups for user "+user.Username, http.StatusInternalServerError, w) + return nil, err + } + return ctxWithUserInfo(ctx, r, user, token), nil + } + } - userAgentCredKeys := getCredsForUserAgent(r.UserAgent(), conf.CredentialsByUserAgent, conf.CredentialChain) + log.Warn().Msg("core access token not set") - // obtain credentials (basic auth, bearer token, ...) based on user agent - var creds *auth.Credentials - for _, k := range userAgentCredKeys { - creds, err = credChain[k].GetCredentials(w, r) - if err != nil { - log.Debug().Err(err).Msg("error retrieving credentials") - } + userAgentCredKeys := getCredsForUserAgent(r.UserAgent(), conf.CredentialsByUserAgent, conf.CredentialChain) - if creds != nil { - log.Debug().Msgf("credentials obtained from credential strategy: type: %s, client_id: %s", creds.Type, creds.ClientID) - break - } + // obtain credentials (basic auth, bearer token, ...) based on user agent + var creds *auth.Credentials + for _, k := range userAgentCredKeys { + creds, err = credChain[k].GetCredentials(w, r) + if err != nil { + log.Debug().Err(err).Msg("error retrieving credentials") } - // if no credentials are found, reply with authentication challenge depending on user agent - if creds == nil { - if !isUnprotectedEndpoint { - for _, key := range userAgentCredKeys { - if cred, ok := credChain[key]; ok { - cred.AddWWWAuthenticate(w, r, conf.Realm) - } else { - panic("auth credential strategy: " + key + "must have been loaded in init method") - } + if creds != nil { + log.Debug().Msgf("credentials obtained from credential strategy: type: %s, client_id: %s", creds.Type, creds.ClientID) + break + } + } + + // if no credentials are found, reply with authentication challenge depending on user agent + if creds == nil { + if !isUnprotectedEndpoint { + for _, key := range userAgentCredKeys { + if cred, ok := credChain[key]; ok { + cred.AddWWWAuthenticate(w, r, conf.Realm) + } else { + panic("auth credential strategy: " + key + "must have been loaded in init method") } - w.WriteHeader(http.StatusUnauthorized) } - return nil, errtypes.PermissionDenied("no credentials found") + w.WriteHeader(http.StatusUnauthorized) } + return nil, errtypes.PermissionDenied("no credentials found") + } - req := &gateway.AuthenticateRequest{ - Type: creds.Type, - ClientId: creds.ClientID, - ClientSecret: creds.ClientSecret, - } + req := &gateway.AuthenticateRequest{ + Type: creds.Type, + ClientId: creds.ClientID, + ClientSecret: creds.ClientSecret, + } - log.Debug().Msgf("AuthenticateRequest: type: %s, client_id: %s against %s", req.Type, req.ClientId, conf.GatewaySvc) + log.Debug().Msgf("AuthenticateRequest: type: %s, client_id: %s against %s", req.Type, req.ClientId, conf.GatewaySvc) - res, err := client.Authenticate(ctx, req) - if err != nil { - logError(isUnprotectedEndpoint, log, err, "error calling Authenticate", http.StatusUnauthorized, w) - return nil, err - } - - if res.Status.Code != rpc.Code_CODE_OK { - err := status.NewErrorFromCode(res.Status.Code, "auth") - logError(isUnprotectedEndpoint, log, err, "error generating access token from credentials", http.StatusUnauthorized, w) - return nil, err - } + res, err := client.Authenticate(ctx, req) + if err != nil { + logError(isUnprotectedEndpoint, log, err, "error calling Authenticate", http.StatusUnauthorized, w) + return nil, err + } - log.Info().Msg("core access token generated") - // write token to response - tkn = res.Token - tokenWriter.WriteToken(tkn, w) - } else { - log.Debug().Msg("access token is already provided") + if res.Status.Code != rpc.Code_CODE_OK { + err := status.NewErrorFromCode(res.Status.Code, "auth") + logError(isUnprotectedEndpoint, log, err, "error generating access token from credentials", http.StatusUnauthorized, w) + return nil, err } + log.Info().Msg("core access token generated") + + // write token to response + token = res.Token + tokenWriter.WriteToken(token, w) + // validate token - u, tokenScope, err := tokenManager.DismantleToken(r.Context(), tkn) + u, tokenScope, err := tokenManager.DismantleToken(r.Context(), token) if err != nil { logError(isUnprotectedEndpoint, log, err, "error dismantling token", http.StatusUnauthorized, w) return nil, err } - if sharedconf.SkipUserGroupsInToken() { - var groups []string - if groupsIf, err := userGroupsCache.Get(u.Id.OpaqueId); err == nil { - groups = groupsIf.([]string) - } else { - groupsRes, err := client.GetUserGroups(ctx, &userpb.GetUserGroupsRequest{UserId: u.Id}) - if err != nil { - logError(isUnprotectedEndpoint, log, err, "error retrieving user groups", http.StatusInternalServerError, w) - return nil, err - } - groups = groupsRes.Groups - _ = userGroupsCache.SetWithExpire(u.Id.OpaqueId, groupsRes.Groups, 3600*time.Second) - } - u.Groups = groups - } - // ensure access to the resource is allowed ok, err := scope.VerifyScope(ctx, tokenScope, r.URL.Path) if err != nil { @@ -298,14 +293,51 @@ func authenticateUser(w http.ResponseWriter, r *http.Request, conf *config, toke return nil, err } - // store user and core access token in context. - ctx = ctxpkg.ContextSetUser(ctx, u) - ctx = ctxpkg.ContextSetToken(ctx, tkn) - ctx = metadata.AppendToOutgoingContext(ctx, ctxpkg.TokenHeader, tkn) // TODO(jfd): hardcoded metadata key. use PerRPCCredentials? + return ctxWithUserInfo(ctx, r, u, token), nil +} +func ctxWithUserInfo(ctx context.Context, r *http.Request, user *userpb.User, token string) context.Context { + ctx = ctxpkg.ContextSetUser(ctx, user) + ctx = ctxpkg.ContextSetToken(ctx, token) + ctx = metadata.AppendToOutgoingContext(ctx, ctxpkg.TokenHeader, token) ctx = metadata.AppendToOutgoingContext(ctx, ctxpkg.UserAgentHeader, r.UserAgent()) - return ctx, nil + return ctx +} + +func insertGroupsInUser(ctx context.Context, userGroupsCache gcache.Cache, client gateway.GatewayAPIClient, user *userpb.User) error { + if sharedconf.SkipUserGroupsInToken() { + var groups []string + if groupsIf, err := userGroupsCache.Get(user.Id.OpaqueId); err == nil { + groups = groupsIf.([]string) + } else { + groupsRes, err := client.GetUserGroups(ctx, &userpb.GetUserGroupsRequest{UserId: user.Id}) + if err != nil { + return err + } + groups = groupsRes.Groups + _ = userGroupsCache.SetWithExpire(user.Id.OpaqueId, groupsRes.Groups, 3600*time.Second) + } + user.Groups = groups + } + return nil +} + +func isTokenValid(r *http.Request, tokenManager token.Manager, token string) (*userpb.User, bool) { + ctx := r.Context() + + u, tokenScope, err := tokenManager.DismantleToken(ctx, token) + if err != nil { + return nil, false + } + + // ensure access to the resource is allowed + ok, err := scope.VerifyScope(ctx, tokenScope, r.URL.Path) + if err != nil { + return nil, false + } + + return u, ok } func logError(isUnprotectedEndpoint bool, log *zerolog.Logger, err error, msg string, status int, w http.ResponseWriter) {