diff --git a/cmd/dex/config.go b/cmd/dex/config.go index 474f28a6e7..e4c3988ffc 100644 --- a/cmd/dex/config.go +++ b/cmd/dex/config.go @@ -6,6 +6,7 @@ import ( "fmt" "log/slog" "net/http" + "net/netip" "os" "slices" "strings" @@ -182,15 +183,38 @@ type OAuth2 struct { // Web is the config format for the HTTP server. type Web struct { - HTTP string `json:"http"` - HTTPS string `json:"https"` - Headers Headers `json:"headers"` - TLSCert string `json:"tlsCert"` - TLSKey string `json:"tlsKey"` - TLSMinVersion string `json:"tlsMinVersion"` - TLSMaxVersion string `json:"tlsMaxVersion"` - AllowedOrigins []string `json:"allowedOrigins"` - AllowedHeaders []string `json:"allowedHeaders"` + HTTP string `json:"http"` + HTTPS string `json:"https"` + Headers Headers `json:"headers"` + TLSCert string `json:"tlsCert"` + TLSKey string `json:"tlsKey"` + TLSMinVersion string `json:"tlsMinVersion"` + TLSMaxVersion string `json:"tlsMaxVersion"` + AllowedOrigins []string `json:"allowedOrigins"` + AllowedHeaders []string `json:"allowedHeaders"` + ClientRemoteIP ClientRemoteIP `json:"clientRemoteIP"` +} + +type ClientRemoteIP struct { + Header string `json:"header"` + TrustedProxies []string `json:"trustedProxies"` +} + +func (cr *ClientRemoteIP) ParseTrustedProxies() ([]netip.Prefix, error) { + if cr == nil { + return nil, nil + } + + trusted := make([]netip.Prefix, 0, len(cr.TrustedProxies)) + for _, cidr := range cr.TrustedProxies { + ipNet, err := netip.ParsePrefix(cidr) + if err != nil { + return nil, fmt.Errorf("failed to parse CIDR %q: %v", cidr, err) + } + trusted = append(trusted, ipNet) + } + + return trusted, nil } type Headers struct { diff --git a/cmd/dex/logger.go b/cmd/dex/logger.go new file mode 100644 index 0000000000..e979011c4f --- /dev/null +++ b/cmd/dex/logger.go @@ -0,0 +1,67 @@ +package main + +import ( + "context" + "fmt" + "log/slog" + "os" + "strings" + + "github.com/dexidp/dex/server" +) + +var logFormats = []string{"json", "text"} + +func newLogger(level slog.Level, format string) (*slog.Logger, error) { + var handler slog.Handler + switch strings.ToLower(format) { + case "", "text": + handler = slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + Level: level, + }) + case "json": + handler = slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{ + Level: level, + }) + default: + return nil, fmt.Errorf("log format is not one of the supported values (%s): %s", strings.Join(logFormats, ", "), format) + } + + return slog.New(newRequestContextHandler(handler)), nil +} + +var _ slog.Handler = requestContextHandler{} + +type requestContextHandler struct { + handler slog.Handler +} + +func newRequestContextHandler(handler slog.Handler) slog.Handler { + return requestContextHandler{ + handler: handler, + } +} + +func (h requestContextHandler) Enabled(ctx context.Context, level slog.Level) bool { + return h.handler.Enabled(ctx, level) +} + +func (h requestContextHandler) Handle(ctx context.Context, record slog.Record) error { + if v, ok := ctx.Value(server.RequestKeyRemoteIP).(string); ok { + record.AddAttrs(slog.String(string(server.RequestKeyRemoteIP), v)) + } + + if v, ok := ctx.Value(server.RequestKeyRequestID).(string); ok { + record.AddAttrs(slog.String(string(server.RequestKeyRequestID), v)) + } + + return h.handler.Handle(ctx, record) +} + +func (h requestContextHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + return requestContextHandler{h.handler.WithAttrs(attrs)} +} + +func (h requestContextHandler) WithGroup(name string) slog.Handler { + return h.handler.WithGroup(name) +} diff --git a/cmd/dex/serve.go b/cmd/dex/serve.go index 863f039b30..572da8c97a 100644 --- a/cmd/dex/serve.go +++ b/cmd/dex/serve.go @@ -348,6 +348,13 @@ func runServe(options serveOptions) error { } serverConfig.RefreshTokenPolicy = refreshTokenPolicy + + serverConfig.RealIPHeader = c.Web.ClientRemoteIP.Header + serverConfig.TrustedRealIPCIDRs, err = c.Web.ClientRemoteIP.ParseTrustedProxies() + if err != nil { + return fmt.Errorf("failed to parse client remote IP settings: %v", err) + } + serv, err := server.NewServer(context.Background(), serverConfig) if err != nil { return fmt.Errorf("failed to initialize server: %v", err) @@ -528,26 +535,6 @@ func runServe(options serveOptions) error { return nil } -var logFormats = []string{"json", "text"} - -func newLogger(level slog.Level, format string) (*slog.Logger, error) { - var handler slog.Handler - switch strings.ToLower(format) { - case "", "text": - handler = slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ - Level: level, - }) - case "json": - handler = slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{ - Level: level, - }) - default: - return nil, fmt.Errorf("log format is not one of the supported values (%s): %s", strings.Join(logFormats, ", "), format) - } - - return slog.New(handler), nil -} - func applyConfigOverrides(options serveOptions, config *Config) { if options.webHTTPAddr != "" { config.Web.HTTP = options.webHTTPAddr diff --git a/examples/config-dev.yaml b/examples/config-dev.yaml index 1803511aa6..47adc04b75 100644 --- a/examples/config-dev.yaml +++ b/examples/config-dev.yaml @@ -58,7 +58,10 @@ web: # X-XSS-Protection: "1; mode=block" # Content-Security-Policy: "default-src 'self'" # Strict-Transport-Security: "max-age=31536000; includeSubDomains" - + # clientRemoteIP: + # header: X-Forwarded-For + # trustedProxies: + # - 10.0.0.0/8 # Configuration for dex appearance # frontend: diff --git a/server/deviceflowhandlers.go b/server/deviceflowhandlers.go index 6f8aae0306..06f3a7b2d5 100644 --- a/server/deviceflowhandlers.go +++ b/server/deviceflowhandlers.go @@ -48,7 +48,7 @@ func (s *Server) handleDeviceExchange(w http.ResponseWriter, r *http.Request) { invalidAttempt = false } if err := s.templates.device(r, w, s.getDeviceVerificationURI(), userCode, invalidAttempt); err != nil { - s.logger.Error("server template error", "err", err) + s.logger.ErrorContext(r.Context(), "server template error", "err", err) s.renderError(r, w, http.StatusNotFound, "Page not found") } default: @@ -64,7 +64,7 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) { case http.MethodPost: err := r.ParseForm() if err != nil { - s.logger.Error("could not parse Device Request body", "err", err) + s.logger.ErrorContext(r.Context(), "could not parse Device Request body", "err", err) s.tokenErrHelper(w, errInvalidRequest, "", http.StatusNotFound) return } @@ -85,7 +85,7 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) { return } - s.logger.Info("received device request", "client_id", clientID, "scoped", scopes) + s.logger.InfoContext(r.Context(), "received device request", "client_id", clientID, "scoped", scopes) // Make device code deviceCode := storage.NewDeviceCode() @@ -107,7 +107,7 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) { } if err := s.storage.CreateDeviceRequest(ctx, deviceReq); err != nil { - s.logger.Error("failed to store device request", "err", err) + s.logger.ErrorContext(r.Context(), "failed to store device request", "err", err) s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError) return } @@ -126,14 +126,14 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) { } if err := s.storage.CreateDeviceToken(ctx, deviceToken); err != nil { - s.logger.Error("failed to store device token", "err", err) + s.logger.ErrorContext(r.Context(), "failed to store device token", "err", err) s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError) return } u, err := url.Parse(s.issuerURL.String()) if err != nil { - s.logger.Error("could not parse issuer URL", "err", err) + s.logger.ErrorContext(r.Context(), "could not parse issuer URL", "err", err) s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError) return } @@ -211,7 +211,7 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) { deviceToken, err := s.storage.GetDeviceToken(deviceCode) if err != nil { if err != storage.ErrNotFound { - s.logger.Error("failed to get device code", "err", err) + s.logger.ErrorContext(r.Context(), "failed to get device code", "err", err) } s.tokenErrHelper(w, errInvalidRequest, "Invalid Device code.", http.StatusBadRequest) return @@ -241,7 +241,7 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) { } // Update device token last request time in storage if err := s.storage.UpdateDeviceToken(deviceCode, updater); err != nil { - s.logger.Error("failed to update device token", "err", err) + s.logger.ErrorContext(r.Context(), "failed to update device token", "err", err) s.renderError(r, w, http.StatusInternalServerError, "") return } @@ -258,7 +258,7 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) { case providedCodeVerifier != "" && codeChallengeFromStorage != "": calculatedCodeChallenge, err := s.calculateCodeChallenge(providedCodeVerifier, deviceToken.PKCE.CodeChallengeMethod) if err != nil { - s.logger.Error("failed to calculate code challenge", "err", err) + s.logger.ErrorContext(r.Context(), "failed to calculate code challenge", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return } @@ -303,7 +303,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) { if err != nil || s.now().After(authCode.Expiry) { errCode := http.StatusBadRequest if err != nil && err != storage.ErrNotFound { - s.logger.Error("failed to get auth code", "err", err) + s.logger.ErrorContext(r.Context(), "failed to get auth code", "err", err) errCode = http.StatusInternalServerError } s.renderError(r, w, errCode, "Invalid or expired auth code.") @@ -315,7 +315,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) { if err != nil || s.now().After(deviceReq.Expiry) { errCode := http.StatusBadRequest if err != nil && err != storage.ErrNotFound { - s.logger.Error("failed to get device code", "err", err) + s.logger.ErrorContext(r.Context(), "failed to get device code", "err", err) errCode = http.StatusInternalServerError } s.renderError(r, w, errCode, "Invalid or expired user code.") @@ -325,7 +325,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) { client, err := s.storage.GetClient(deviceReq.ClientID) if err != nil { if err != storage.ErrNotFound { - s.logger.Error("failed to get client", "err", err) + s.logger.ErrorContext(r.Context(), "failed to get client", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) } else { s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized) @@ -339,7 +339,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) { resp, err := s.exchangeAuthCode(ctx, w, authCode, client) if err != nil { - s.logger.Error("could not exchange auth code for clien", "client_id", deviceReq.ClientID, "err", err) + s.logger.ErrorContext(r.Context(), "could not exchange auth code for clien", "client_id", deviceReq.ClientID, "err", err) s.renderError(r, w, http.StatusInternalServerError, "Failed to exchange auth code.") return } @@ -349,7 +349,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) { if err != nil || s.now().After(old.Expiry) { errCode := http.StatusBadRequest if err != nil && err != storage.ErrNotFound { - s.logger.Error("failed to get device token", "err", err) + s.logger.ErrorContext(r.Context(), "failed to get device token", "err", err) errCode = http.StatusInternalServerError } s.renderError(r, w, errCode, "Invalid or expired device code.") @@ -362,7 +362,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) { } respStr, err := json.MarshalIndent(resp, "", " ") if err != nil { - s.logger.Error("failed to marshal device token response", "err", err) + s.logger.ErrorContext(r.Context(), "failed to marshal device token response", "err", err) s.renderError(r, w, http.StatusInternalServerError, "") return old, err } @@ -374,13 +374,13 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) { // Update refresh token in the storage, store the token and mark as complete if err := s.storage.UpdateDeviceToken(deviceReq.DeviceCode, updater); err != nil { - s.logger.Error("failed to update device token", "err", err) + s.logger.ErrorContext(r.Context(), "failed to update device token", "err", err) s.renderError(r, w, http.StatusBadRequest, "") return } if err := s.templates.deviceSuccess(r, w, client.Name); err != nil { - s.logger.Error("Server template error", "err", err) + s.logger.ErrorContext(r.Context(), "Server template error", "err", err) s.renderError(r, w, http.StatusNotFound, "Page not found") } @@ -412,10 +412,10 @@ func (s *Server) verifyUserCode(w http.ResponseWriter, r *http.Request) { deviceRequest, err := s.storage.GetDeviceRequest(userCode) if err != nil || s.now().After(deviceRequest.Expiry) { if err != nil && err != storage.ErrNotFound { - s.logger.Error("failed to get device request", "err", err) + s.logger.ErrorContext(r.Context(), "failed to get device request", "err", err) } if err := s.templates.device(r, w, s.getDeviceVerificationURI(), userCode, true); err != nil { - s.logger.Error("Server template error", "err", err) + s.logger.ErrorContext(r.Context(), "Server template error", "err", err) s.renderError(r, w, http.StatusNotFound, "Page not found") } return diff --git a/server/handlers.go b/server/handlers.go index 42f3ebe5d5..63cb612295 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -35,13 +35,13 @@ func (s *Server) handlePublicKeys(w http.ResponseWriter, r *http.Request) { // TODO(ericchiang): Cache this. keys, err := s.storage.GetKeys() if err != nil { - s.logger.Error("failed to get keys", "err", err) + s.logger.ErrorContext(r.Context(), "failed to get keys", "err", err) s.renderError(r, w, http.StatusInternalServerError, "Internal server error.") return } if keys.SigningKeyPub == nil { - s.logger.Error("no public keys found.") + s.logger.ErrorContext(r.Context(), "no public keys found.") s.renderError(r, w, http.StatusInternalServerError, "Internal server error.") return } @@ -56,7 +56,7 @@ func (s *Server) handlePublicKeys(w http.ResponseWriter, r *http.Request) { data, err := json.MarshalIndent(jwks, "", " ") if err != nil { - s.logger.Error("failed to marshal discovery data", "err", err) + s.logger.ErrorContext(r.Context(), "failed to marshal discovery data", "err", err) s.renderError(r, w, http.StatusInternalServerError, "Internal server error.") return } @@ -132,7 +132,7 @@ func (s *Server) discoveryHandler() (http.HandlerFunc, error) { func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) { // Extract the arguments if err := r.ParseForm(); err != nil { - s.logger.Error("failed to parse arguments", "err", err) + s.logger.ErrorContext(r.Context(), "failed to parse arguments", "err", err) s.renderError(r, w, http.StatusBadRequest, err.Error()) return @@ -142,7 +142,7 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) { connectors, err := s.storage.ListConnectors() if err != nil { - s.logger.Error("failed to get list of connectors", "err", err) + s.logger.ErrorContext(r.Context(), "failed to get list of connectors", "err", err) s.renderError(r, w, http.StatusInternalServerError, "Failed to retrieve connector list.") return } @@ -185,7 +185,7 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) { } if err := s.templates.login(r, w, connectorInfos); err != nil { - s.logger.Error("server template error", "err", err) + s.logger.ErrorContext(r.Context(), "server template error", "err", err) } } @@ -193,7 +193,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { ctx := r.Context() authReq, err := s.parseAuthorizationRequest(r) if err != nil { - s.logger.Error("failed to parse authorization request", "err", err) + s.logger.ErrorContext(r.Context(), "failed to parse authorization request", "err", err) switch authErr := err.(type) { case *redirectedAuthErr: @@ -209,21 +209,21 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { connID, err := url.PathUnescape(mux.Vars(r)["connector"]) if err != nil { - s.logger.Error("failed to parse connector", "err", err) + s.logger.ErrorContext(r.Context(), "failed to parse connector", "err", err) s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist") return } conn, err := s.getConnector(connID) if err != nil { - s.logger.Error("Failed to get connector", "err", err) + s.logger.ErrorContext(r.Context(), "Failed to get connector", "err", err) s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist") return } // Set the connector being used for the login. if authReq.ConnectorID != "" && authReq.ConnectorID != connID { - s.logger.Error("mismatched connector ID in auth request", + s.logger.ErrorContext(r.Context(), "mismatched connector ID in auth request", "auth_request_connector_id", authReq.ConnectorID, "connector_id", connID) s.renderError(r, w, http.StatusBadRequest, "Bad connector ID") return @@ -234,7 +234,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { // Actually create the auth request authReq.Expiry = s.now().Add(s.authRequestsValidFor) if err := s.storage.CreateAuthRequest(ctx, *authReq); err != nil { - s.logger.Error("failed to create authorization request", "err", err) + s.logger.ErrorContext(r.Context(), "failed to create authorization request", "err", err) s.renderError(r, w, http.StatusInternalServerError, "Failed to connect to the database.") return } @@ -260,7 +260,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { // TODO(ericchiang): Is this appropriate or should we also be using a nonce? callbackURL, err := conn.LoginURL(scopes, s.absURL("/callback"), authReq.ID) if err != nil { - s.logger.Error("connector returned error when creating callback", "connector_id", connID, "err", err) + s.logger.ErrorContext(r.Context(), "connector returned error when creating callback", "connector_id", connID, "err", err) s.renderError(r, w, http.StatusInternalServerError, "Login error.") return } @@ -278,7 +278,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { case connector.SAMLConnector: action, value, err := conn.POSTData(scopes, authReq.ID) if err != nil { - s.logger.Error("creating SAML data", "err", err) + s.logger.ErrorContext(r.Context(), "creating SAML data", "err", err) s.renderError(r, w, http.StatusInternalServerError, "Connector Login Error") return } @@ -309,7 +309,6 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { } func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() authID := r.URL.Query().Get("state") if authID == "" { s.renderError(r, w, http.StatusBadRequest, "User session error.") @@ -321,36 +320,36 @@ func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) { authReq, err := s.storage.GetAuthRequest(authID) if err != nil { if err == storage.ErrNotFound { - s.logger.Error("invalid 'state' parameter provided", "err", err) + s.logger.ErrorContext(r.Context(), "invalid 'state' parameter provided", "err", err) s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.") return } - s.logger.Error("failed to get auth request", "err", err) + s.logger.ErrorContext(r.Context(), "failed to get auth request", "err", err) s.renderError(r, w, http.StatusInternalServerError, "Database error.") return } connID, err := url.PathUnescape(mux.Vars(r)["connector"]) if err != nil { - s.logger.Error("failed to parse connector", "err", err) + s.logger.ErrorContext(r.Context(), "failed to parse connector", "err", err) s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist") return } else if connID != "" && connID != authReq.ConnectorID { - s.logger.Error("connector mismatch: password login triggered for different connector from authentication start", "start_connector_id", authReq.ConnectorID, "password_connector_id", connID) + s.logger.ErrorContext(r.Context(), "connector mismatch: password login triggered for different connector from authentication start", "start_connector_id", authReq.ConnectorID, "password_connector_id", connID) s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.") return } conn, err := s.getConnector(authReq.ConnectorID) if err != nil { - s.logger.Error("failed to get connector", "connector_id", authReq.ConnectorID, "err", err) + s.logger.ErrorContext(r.Context(), "failed to get connector", "connector_id", authReq.ConnectorID, "err", err) s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.") return } pwConn, ok := conn.Connector.(connector.PasswordConnector) if !ok { - s.logger.Error("expected password connector in handlePasswordLogin()", "password_connector", pwConn) + s.logger.ErrorContext(r.Context(), "expected password connector in handlePasswordLogin()", "password_connector", pwConn) s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.") return } @@ -358,29 +357,29 @@ func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: if err := s.templates.password(r, w, r.URL.String(), "", usernamePrompt(pwConn), false, backLink); err != nil { - s.logger.Error("server template error", "err", err) + s.logger.ErrorContext(r.Context(), "server template error", "err", err) } case http.MethodPost: username := r.FormValue("login") password := r.FormValue("password") scopes := parseScopes(authReq.Scopes) - identity, ok, err := pwConn.Login(ctx, scopes, username, password) + identity, ok, err := pwConn.Login(r.Context(), scopes, username, password) if err != nil { - s.logger.Error("failed to login user", "err", err) + s.logger.ErrorContext(r.Context(), "failed to login user", "err", err) s.renderError(r, w, http.StatusInternalServerError, fmt.Sprintf("Login error: %v", err)) return } if !ok { if err := s.templates.password(r, w, r.URL.String(), username, usernamePrompt(pwConn), true, backLink); err != nil { - s.logger.Error("server template error", "err", err) + s.logger.ErrorContext(r.Context(), "server template error", "err", err) } - s.logger.Error("failed login attempt: Invalid credentials.", "user", username) + s.logger.ErrorContext(r.Context(), "failed login attempt: Invalid credentials.", "user", username) return } - redirectURL, canSkipApproval, err := s.finalizeLogin(ctx, identity, authReq, conn.Connector) + redirectURL, canSkipApproval, err := s.finalizeLogin(r.Context(), identity, authReq, conn.Connector) if err != nil { - s.logger.Error("failed to finalize login", "err", err) + s.logger.ErrorContext(r.Context(), "failed to finalize login", "err", err) s.renderError(r, w, http.StatusInternalServerError, "Login error.") return } @@ -388,7 +387,7 @@ func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) { if canSkipApproval { authReq, err = s.storage.GetAuthRequest(authReq.ID) if err != nil { - s.logger.Error("failed to get finalized auth request", "err", err) + s.logger.ErrorContext(r.Context(), "failed to get finalized auth request", "err", err) s.renderError(r, w, http.StatusInternalServerError, "Login error.") return } @@ -424,29 +423,29 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) authReq, err := s.storage.GetAuthRequest(authID) if err != nil { if err == storage.ErrNotFound { - s.logger.Error("invalid 'state' parameter provided", "err", err) + s.logger.ErrorContext(r.Context(), "invalid 'state' parameter provided", "err", err) s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.") return } - s.logger.Error("failed to get auth request", "err", err) + s.logger.ErrorContext(r.Context(), "failed to get auth request", "err", err) s.renderError(r, w, http.StatusInternalServerError, "Database error.") return } connID, err := url.PathUnescape(mux.Vars(r)["connector"]) if err != nil { - s.logger.Error("failed to get connector", "connector_id", authReq.ConnectorID, "err", err) + s.logger.ErrorContext(r.Context(), "failed to get connector", "connector_id", authReq.ConnectorID, "err", err) s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.") return } else if connID != "" && connID != authReq.ConnectorID { - s.logger.Error("connector mismatch: callback triggered for different connector than authentication start", "authentication_start_connector_id", authReq.ConnectorID, "connector_id", connID) + s.logger.ErrorContext(r.Context(), "connector mismatch: callback triggered for different connector than authentication start", "authentication_start_connector_id", authReq.ConnectorID, "connector_id", connID) s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.") return } conn, err := s.getConnector(authReq.ConnectorID) if err != nil { - s.logger.Error("failed to get connector", "connector_id", authReq.ConnectorID, "err", err) + s.logger.ErrorContext(r.Context(), "failed to get connector", "connector_id", authReq.ConnectorID, "err", err) s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.") return } @@ -455,14 +454,14 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) switch conn := conn.Connector.(type) { case connector.CallbackConnector: if r.Method != http.MethodGet { - s.logger.Error("SAML request mapped to OAuth2 connector") + s.logger.ErrorContext(r.Context(), "SAML request mapped to OAuth2 connector") s.renderError(r, w, http.StatusBadRequest, "Invalid request") return } identity, err = conn.HandleCallback(parseScopes(authReq.Scopes), r) case connector.SAMLConnector: if r.Method != http.MethodPost { - s.logger.Error("OAuth2 request mapped to SAML connector") + s.logger.ErrorContext(r.Context(), "OAuth2 request mapped to SAML connector") s.renderError(r, w, http.StatusBadRequest, "Invalid request") return } @@ -473,14 +472,14 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) } if err != nil { - s.logger.Error("failed to authenticate", "err", err) + s.logger.ErrorContext(r.Context(), "failed to authenticate", "err", err) s.renderError(r, w, http.StatusInternalServerError, fmt.Sprintf("Failed to authenticate: %v", err)) return } redirectURL, canSkipApproval, err := s.finalizeLogin(ctx, identity, authReq, conn.Connector) if err != nil { - s.logger.Error("failed to finalize login", "err", err) + s.logger.ErrorContext(r.Context(), "failed to finalize login", "err", err) s.renderError(r, w, http.StatusInternalServerError, "Login error.") return } @@ -488,7 +487,7 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) if canSkipApproval { authReq, err = s.storage.GetAuthRequest(authReq.ID) if err != nil { - s.logger.Error("failed to get finalized auth request", "err", err) + s.logger.ErrorContext(r.Context(), "failed to get finalized auth request", "err", err) s.renderError(r, w, http.StatusInternalServerError, "Login error.") return } @@ -526,7 +525,7 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity, email += " (unverified)" } - s.logger.Info("login successful", + s.logger.InfoContext(ctx, "login successful", "connector_id", authReq.ConnectorID, "username", claims.Username, "preferred_username", claims.PreferredUsername, "email", email, "groups", claims.Groups) @@ -562,7 +561,7 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity, session, err := s.storage.GetOfflineSessions(identity.UserID, authReq.ConnectorID) if err != nil { if err != storage.ErrNotFound { - s.logger.Error("failed to get offline session", "err", err) + s.logger.ErrorContext(ctx, "failed to get offline session", "err", err) return "", false, err } offlineSessions := storage.OfflineSessions{ @@ -575,7 +574,7 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity, // Create a new OfflineSession object for the user and add a reference object for // the newly received refreshtoken. if err := s.storage.CreateOfflineSessions(ctx, offlineSessions); err != nil { - s.logger.Error("failed to create offline session", "err", err) + s.logger.ErrorContext(ctx, "failed to create offline session", "err", err) return "", false, err } @@ -589,7 +588,7 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity, } return old, nil }); err != nil { - s.logger.Error("failed to update offline session", "err", err) + s.logger.ErrorContext(ctx, "failed to update offline session", "err", err) return "", false, err } @@ -610,12 +609,12 @@ func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) { authReq, err := s.storage.GetAuthRequest(r.FormValue("req")) if err != nil { - s.logger.Error("failed to get auth request", "err", err) + s.logger.ErrorContext(r.Context(), "failed to get auth request", "err", err) s.renderError(r, w, http.StatusInternalServerError, "Database error.") return } if !authReq.LoggedIn { - s.logger.Error("auth request does not have an identity for approval") + s.logger.ErrorContext(r.Context(), "auth request does not have an identity for approval") s.renderError(r, w, http.StatusInternalServerError, "Login process not yet finalized.") return } @@ -634,12 +633,12 @@ func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) { case http.MethodGet: client, err := s.storage.GetClient(authReq.ClientID) if err != nil { - s.logger.Error("Failed to get client", "client_id", authReq.ClientID, "err", err) + s.logger.ErrorContext(r.Context(), "Failed to get client", "client_id", authReq.ClientID, "err", err) s.renderError(r, w, http.StatusInternalServerError, "Failed to retrieve client.") return } if err := s.templates.approval(r, w, authReq.ID, authReq.Claims.Username, client.Name, authReq.Scopes); err != nil { - s.logger.Error("server template error", "err", err) + s.logger.ErrorContext(r.Context(), "server template error", "err", err) } case http.MethodPost: if r.FormValue("approval") != "approve" { @@ -659,7 +658,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe if err := s.storage.DeleteAuthRequest(authReq.ID); err != nil { if err != storage.ErrNotFound { - s.logger.Error("Failed to delete authorization request", "err", err) + s.logger.ErrorContext(r.Context(), "Failed to delete authorization request", "err", err) s.renderError(r, w, http.StatusInternalServerError, "Internal server error.") } else { s.renderError(r, w, http.StatusBadRequest, "User session error.") @@ -705,7 +704,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe PKCE: authReq.PKCE, } if err := s.storage.CreateAuthCode(ctx, code); err != nil { - s.logger.Error("Failed to create auth code", "err", err) + s.logger.ErrorContext(r.Context(), "Failed to create auth code", "err", err) s.renderError(r, w, http.StatusInternalServerError, "Internal server error.") return } @@ -714,7 +713,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe // rejected earlier. If we got here we're using the code flow. if authReq.RedirectURI == redirectURIOOB { if err := s.templates.oob(r, w, code.ID); err != nil { - s.logger.Error("server template error", "err", err) + s.logger.ErrorContext(r.Context(), "server template error", "err", err) } return } @@ -724,16 +723,16 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe implicitOrHybrid = true var err error - accessToken, _, err = s.newAccessToken(authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, authReq.ConnectorID) + accessToken, _, err = s.newAccessToken(r.Context(), authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, authReq.ConnectorID) if err != nil { - s.logger.Error("failed to create new access token", "err", err) + s.logger.ErrorContext(r.Context(), "failed to create new access token", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return } - idToken, idTokenExpiry, err = s.newIDToken(authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, accessToken, code.ID, authReq.ConnectorID) + idToken, idTokenExpiry, err = s.newIDToken(r.Context(), authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, accessToken, code.ID, authReq.ConnectorID) if err != nil { - s.logger.Error("failed to create ID token", "err", err) + s.logger.ErrorContext(r.Context(), "failed to create ID token", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return } @@ -808,7 +807,7 @@ func (s *Server) withClientFromStorage(w http.ResponseWriter, r *http.Request, h client, err := s.storage.GetClient(clientID) if err != nil { if err != storage.ErrNotFound { - s.logger.Error("failed to get client", "err", err) + s.logger.ErrorContext(r.Context(), "failed to get client", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) } else { s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized) @@ -818,9 +817,9 @@ func (s *Server) withClientFromStorage(w http.ResponseWriter, r *http.Request, h if subtle.ConstantTimeCompare([]byte(client.Secret), []byte(clientSecret)) != 1 { if clientSecret == "" { - s.logger.Info("missing client_secret on token request", "client_id", client.ID) + s.logger.InfoContext(r.Context(), "missing client_secret on token request", "client_id", client.ID) } else { - s.logger.Info("invalid client_secret on token request", "client_id", client.ID) + s.logger.InfoContext(r.Context(), "invalid client_secret on token request", "client_id", client.ID) } s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized) return @@ -838,14 +837,14 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) { err := r.ParseForm() if err != nil { - s.logger.Error("could not parse request body", "err", err) + s.logger.ErrorContext(r.Context(), "could not parse request body", "err", err) s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest) return } grantType := r.PostFormValue("grant_type") if !contains(s.supportedGrantTypes, grantType) { - s.logger.Error("unsupported grant type", "grant_type", grantType) + s.logger.ErrorContext(r.Context(), "unsupported grant type", "grant_type", grantType) s.tokenErrHelper(w, errUnsupportedGrantType, "", http.StatusBadRequest) return } @@ -891,7 +890,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s authCode, err := s.storage.GetAuthCode(code) if err != nil || s.now().After(authCode.Expiry) || authCode.ClientID != client.ID { if err != storage.ErrNotFound { - s.logger.Error("failed to get auth code", "err", err) + s.logger.ErrorContext(r.Context(), "failed to get auth code", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) } else { s.tokenErrHelper(w, errInvalidGrant, "Invalid or expired code parameter.", http.StatusBadRequest) @@ -907,7 +906,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s case providedCodeVerifier != "" && codeChallengeFromStorage != "": calculatedCodeChallenge, err := s.calculateCodeChallenge(providedCodeVerifier, authCode.PKCE.CodeChallengeMethod) if err != nil { - s.logger.Error("failed to calculate code challenge", "err", err) + s.logger.ErrorContext(r.Context(), "failed to calculate code challenge", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return } @@ -939,22 +938,22 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s } func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, authCode storage.AuthCode, client storage.Client) (*accessTokenResponse, error) { - accessToken, _, err := s.newAccessToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID) + accessToken, _, err := s.newAccessToken(ctx, client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID) if err != nil { - s.logger.Error("failed to create new access token", "err", err) + s.logger.ErrorContext(ctx, "failed to create new access token", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return nil, err } - idToken, expiry, err := s.newIDToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, accessToken, authCode.ID, authCode.ConnectorID) + idToken, expiry, err := s.newIDToken(ctx, client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, accessToken, authCode.ID, authCode.ConnectorID) if err != nil { - s.logger.Error("failed to create ID token", "err", err) + s.logger.ErrorContext(ctx, "failed to create ID token", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return nil, err } if err := s.storage.DeleteAuthCode(authCode.ID); err != nil { - s.logger.Error("failed to delete auth code", "err", err) + s.logger.ErrorContext(ctx, "failed to delete auth code", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return nil, err } @@ -965,7 +964,7 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au // Connectors like `saml` do not implement RefreshConnector. conn, err := s.getConnector(authCode.ConnectorID) if err != nil { - s.logger.Error("connector not found", "connector_id", authCode.ConnectorID, "err", err) + s.logger.ErrorContext(ctx, "connector not found", "connector_id", authCode.ConnectorID, "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return false } @@ -1001,13 +1000,13 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au Token: refresh.Token, } if refreshToken, err = internal.Marshal(token); err != nil { - s.logger.Error("failed to marshal refresh token", "err", err) + s.logger.ErrorContext(ctx, "failed to marshal refresh token", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return nil, err } if err := s.storage.CreateRefresh(ctx, refresh); err != nil { - s.logger.Error("failed to create refresh token", "err", err) + s.logger.ErrorContext(ctx, "failed to create refresh token", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return nil, err } @@ -1020,7 +1019,7 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au if deleteToken { // Delete newly created refresh token from storage. if err := s.storage.DeleteRefresh(refresh.ID); err != nil { - s.logger.Error("failed to delete refresh token", "err", err) + s.logger.ErrorContext(ctx, "failed to delete refresh token", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return } @@ -1037,7 +1036,7 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au // Try to retrieve an existing OfflineSession object for the corresponding user. if session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID); err != nil { if err != storage.ErrNotFound { - s.logger.Error("failed to get offline session", "err", err) + s.logger.ErrorContext(ctx, "failed to get offline session", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) deleteToken = true return nil, err @@ -1052,7 +1051,7 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au // Create a new OfflineSession object for the user and add a reference object for // the newly received refreshtoken. if err := s.storage.CreateOfflineSessions(ctx, offlineSessions); err != nil { - s.logger.Error("failed to create offline session", "err", err) + s.logger.ErrorContext(ctx, "failed to create offline session", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) deleteToken = true return nil, err @@ -1061,7 +1060,7 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok { // Delete old refresh token from storage. if err := s.storage.DeleteRefresh(oldTokenRef.ID); err != nil && err != storage.ErrNotFound { - s.logger.Error("failed to delete refresh token", "err", err) + s.logger.ErrorContext(ctx, "failed to delete refresh token", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) deleteToken = true return nil, err @@ -1073,7 +1072,7 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au old.Refresh[tokenRef.ClientID] = &tokenRef return old, nil }); err != nil { - s.logger.Error("failed to update offline session", "err", err) + s.logger.ErrorContext(ctx, "failed to update offline session", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) deleteToken = true return nil, err @@ -1143,7 +1142,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli continue } - isTrusted, err := s.validateCrossClientTrust(client.ID, peerID) + isTrusted, err := s.validateCrossClientTrust(r.Context(), client.ID, peerID) if err != nil { s.tokenErrHelper(w, errInvalidClient, fmt.Sprintf("Error validating cross client trust %v.", err), http.StatusBadRequest) return @@ -1185,7 +1184,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli password := q.Get("password") identity, ok, err := passwordConnector.Login(ctx, parseScopes(scopes), username, password) if err != nil { - s.logger.Error("failed to login user", "err", err) + s.logger.ErrorContext(r.Context(), "failed to login user", "err", err) s.tokenErrHelper(w, errInvalidRequest, "Could not login user", http.StatusBadRequest) return } @@ -1204,16 +1203,16 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli Groups: identity.Groups, } - accessToken, _, err := s.newAccessToken(client.ID, claims, scopes, nonce, connID) + accessToken, _, err := s.newAccessToken(r.Context(), client.ID, claims, scopes, nonce, connID) if err != nil { - s.logger.Error("password grant failed to create new access token", "err", err) + s.logger.ErrorContext(r.Context(), "password grant failed to create new access token", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return } - idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, nonce, accessToken, "", connID) + idToken, expiry, err := s.newIDToken(r.Context(), client.ID, claims, scopes, nonce, accessToken, "", connID) if err != nil { - s.logger.Error("password grant failed to create new ID token", "err", err) + s.logger.ErrorContext(r.Context(), "password grant failed to create new ID token", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return } @@ -1253,13 +1252,13 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli Token: refresh.Token, } if refreshToken, err = internal.Marshal(token); err != nil { - s.logger.Error("failed to marshal refresh token", "err", err) + s.logger.ErrorContext(r.Context(), "failed to marshal refresh token", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return } if err := s.storage.CreateRefresh(ctx, refresh); err != nil { - s.logger.Error("failed to create refresh token", "err", err) + s.logger.ErrorContext(r.Context(), "failed to create refresh token", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return } @@ -1272,7 +1271,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli if deleteToken { // Delete newly created refresh token from storage. if err := s.storage.DeleteRefresh(refresh.ID); err != nil { - s.logger.Error("failed to delete refresh token", "err", err) + s.logger.ErrorContext(r.Context(), "failed to delete refresh token", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return } @@ -1289,7 +1288,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli // Try to retrieve an existing OfflineSession object for the corresponding user. if session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID); err != nil { if err != storage.ErrNotFound { - s.logger.Error("failed to get offline session", "err", err) + s.logger.ErrorContext(r.Context(), "failed to get offline session", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) deleteToken = true return @@ -1305,7 +1304,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli // Create a new OfflineSession object for the user and add a reference object for // the newly received refreshtoken. if err := s.storage.CreateOfflineSessions(ctx, offlineSessions); err != nil { - s.logger.Error("failed to create offline session", "err", err) + s.logger.ErrorContext(r.Context(), "failed to create offline session", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) deleteToken = true return @@ -1317,7 +1316,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli if err == storage.ErrNotFound { s.logger.Warn("database inconsistent, refresh token missing", "token_id", oldTokenRef.ID) } else { - s.logger.Error("failed to delete refresh token", "err", err) + s.logger.ErrorContext(r.Context(), "failed to delete refresh token", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) deleteToken = true return @@ -1331,7 +1330,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli old.ConnectorData = identity.ConnectorData return old, nil }); err != nil { - s.logger.Error("failed to update offline session", "err", err) + s.logger.ErrorContext(r.Context(), "failed to update offline session", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) deleteToken = true return @@ -1347,7 +1346,7 @@ func (s *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request, cli ctx := r.Context() if err := r.ParseForm(); err != nil { - s.logger.Error("could not parse request body", "err", err) + s.logger.ErrorContext(r.Context(), "could not parse request body", "err", err) s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest) return } @@ -1376,19 +1375,19 @@ func (s *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request, cli conn, err := s.getConnector(connID) if err != nil { - s.logger.Error("failed to get connector", "err", err) + s.logger.ErrorContext(r.Context(), "failed to get connector", "err", err) s.tokenErrHelper(w, errInvalidRequest, "Requested connector does not exist.", http.StatusBadRequest) return } teConn, ok := conn.Connector.(connector.TokenIdentityConnector) if !ok { - s.logger.Error("connector doesn't implement token exchange", "connector_id", connID) + s.logger.ErrorContext(r.Context(), "connector doesn't implement token exchange", "connector_id", connID) s.tokenErrHelper(w, errInvalidRequest, "Requested connector does not exist.", http.StatusBadRequest) return } identity, err := teConn.TokenIdentity(ctx, subjectTokenType, subjectToken) if err != nil { - s.logger.Error("failed to verify subject token", "err", err) + s.logger.ErrorContext(r.Context(), "failed to verify subject token", "err", err) s.tokenErrHelper(w, errAccessDenied, "", http.StatusUnauthorized) return } @@ -1408,15 +1407,15 @@ func (s *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request, cli var expiry time.Time switch requestedTokenType { case tokenTypeID: - resp.AccessToken, expiry, err = s.newIDToken(client.ID, claims, scopes, "", "", "", connID) + resp.AccessToken, expiry, err = s.newIDToken(r.Context(), client.ID, claims, scopes, "", "", "", connID) case tokenTypeAccess: - resp.AccessToken, expiry, err = s.newAccessToken(client.ID, claims, scopes, "", connID) + resp.AccessToken, expiry, err = s.newAccessToken(r.Context(), client.ID, claims, scopes, "", connID) default: s.tokenErrHelper(w, errRequestNotSupported, "Invalid requested_token_type.", http.StatusBadRequest) return } if err != nil { - s.logger.Error("token exchange failed to create new token", "requested_token_type", requestedTokenType, "err", err) + s.logger.ErrorContext(r.Context(), "token exchange failed to create new token", "requested_token_type", requestedTokenType, "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return } @@ -1452,6 +1451,7 @@ func (s *Server) toAccessTokenResponse(idToken, accessToken, refreshToken string func (s *Server) writeAccessToken(w http.ResponseWriter, resp *accessTokenResponse) { data, err := json.Marshal(resp) if err != nil { + // TODO(nabokihms): error with context s.logger.Error("failed to marshal access token response", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return @@ -1467,12 +1467,13 @@ func (s *Server) writeAccessToken(w http.ResponseWriter, resp *accessTokenRespon func (s *Server) renderError(r *http.Request, w http.ResponseWriter, status int, description string) { if err := s.templates.err(r, w, status, description); err != nil { - s.logger.Error("server template error", "err", err) + s.logger.ErrorContext(r.Context(), "server template error", "err", err) } } func (s *Server) tokenErrHelper(w http.ResponseWriter, typ string, description string, statusCode int) { if err := tokenErr(w, typ, description, statusCode); err != nil { + // TODO(nabokihms): error with context s.logger.Error("token error response", "err", err) } } diff --git a/server/introspectionhandler.go b/server/introspectionhandler.go index 8c6e4419f3..ffcbb13679 100644 --- a/server/introspectionhandler.go +++ b/server/introspectionhandler.go @@ -179,7 +179,7 @@ func (s *Server) getTokenFromRequest(r *http.Request) (string, TokenTypeEnum, er token := r.PostForm.Get("token") tokenType, err := s.guessTokenType(r.Context(), token) if err != nil { - s.logger.Error("failed to guess token type", "err", err) + s.logger.ErrorContext(r.Context(), "failed to guess token type", "err", err) return "", 0, newIntrospectInternalServerError() } @@ -193,7 +193,7 @@ func (s *Server) getTokenFromRequest(r *http.Request) (string, TokenTypeEnum, er return token, tokenType, nil } -func (s *Server) introspectRefreshToken(_ context.Context, token string) (*Introspection, error) { +func (s *Server) introspectRefreshToken(ctx context.Context, token string) (*Introspection, error) { rToken := new(internal.RefreshToken) if err := internal.Unmarshal(token, rToken); err != nil { // For backward compatibility, assume the refresh_token is a raw refresh token ID @@ -205,19 +205,19 @@ func (s *Server) introspectRefreshToken(_ context.Context, token string) (*Intro rToken = &internal.RefreshToken{RefreshId: token, Token: ""} } - rCtx, err := s.getRefreshTokenFromStorage(nil, rToken) + rCtx, err := s.getRefreshTokenFromStorage(ctx, nil, rToken) if err != nil { if errors.Is(err, invalidErr) || errors.Is(err, expiredErr) { return nil, newIntrospectInactiveTokenError() } - s.logger.Error("failed to get refresh token", "err", err) + s.logger.ErrorContext(ctx, "failed to get refresh token", "err", err) return nil, newIntrospectInternalServerError() } subjectString, sErr := genSubject(rCtx.storageToken.Claims.UserID, rCtx.storageToken.ConnectorID) if sErr != nil { - s.logger.Error("failed to marshal offline session ID", "err", err) + s.logger.ErrorContext(ctx, "failed to marshal offline session ID", "err", err) return nil, newIntrospectInternalServerError() } @@ -253,19 +253,19 @@ func (s *Server) introspectAccessToken(ctx context.Context, token string) (*Intr var claims IntrospectionExtra if err := idToken.Claims(&claims); err != nil { - s.logger.Error("error while fetching token claims", "err", err.Error()) + s.logger.ErrorContext(ctx, "error while fetching token claims", "err", err.Error()) return nil, newIntrospectInternalServerError() } clientID, err := getClientID(idToken.Audience, claims.AuthorizingParty) if err != nil { - s.logger.Error("error while fetching client_id from token:", "err", err.Error()) + s.logger.ErrorContext(ctx, "error while fetching client_id from token:", "err", err.Error()) return nil, newIntrospectInternalServerError() } client, err := s.storage.GetClient(clientID) if err != nil { - s.logger.Error("error while fetching client from storage", "err", err.Error()) + s.logger.ErrorContext(ctx, "error while fetching client from storage", "err", err.Error()) return nil, newIntrospectInternalServerError() } @@ -299,7 +299,7 @@ func (s *Server) handleIntrospect(w http.ResponseWriter, r *http.Request) { introspect, err = s.introspectRefreshToken(ctx, token) default: // Token type is neither handled token types. - s.logger.Error("unknown token type", "token_type", tokenType) + s.logger.ErrorContext(r.Context(), "unknown token type", "token_type", tokenType) introspectInactiveErr(w) return } @@ -309,7 +309,7 @@ func (s *Server) handleIntrospect(w http.ResponseWriter, r *http.Request) { if intErr, ok := err.(*introspectionError); ok { s.introspectErrHelper(w, intErr.typ, intErr.desc, intErr.code) } else { - s.logger.Error("an unknown error occurred", "err", err.Error()) + s.logger.ErrorContext(r.Context(), "an unknown error occurred", "err", err.Error()) s.introspectErrHelper(w, errServerError, "An unknown error occurred", http.StatusInternalServerError) } @@ -332,6 +332,7 @@ func (s *Server) introspectErrHelper(w http.ResponseWriter, typ string, descript } if err := tokenErr(w, typ, description, statusCode); err != nil { + // TODO(nabokihms): error with context s.logger.Error("introspect error response", "err", err) } } diff --git a/server/introspectionhandler_test.go b/server/introspectionhandler_test.go index 2b17c2e9f5..695bbad8e6 100644 --- a/server/introspectionhandler_test.go +++ b/server/introspectionhandler_test.go @@ -259,7 +259,7 @@ func TestHandleIntrospect(t *testing.T) { mockTestStorage(t, s.storage) - activeAccessToken, expiry, err := s.newIDToken("test", storage.Claims{ + activeAccessToken, expiry, err := s.newIDToken(ctx, "test", storage.Claims{ UserID: "1", Username: "jane", Email: "jane.doe@example.com", diff --git a/server/oauth2.go b/server/oauth2.go index 3d9cfc8fe7..ec972beab1 100644 --- a/server/oauth2.go +++ b/server/oauth2.go @@ -303,8 +303,8 @@ type federatedIDClaims struct { UserID string `json:"user_id,omitempty"` } -func (s *Server) newAccessToken(clientID string, claims storage.Claims, scopes []string, nonce, connID string) (accessToken string, expiry time.Time, err error) { - return s.newIDToken(clientID, claims, scopes, nonce, storage.NewID(), "", connID) +func (s *Server) newAccessToken(ctx context.Context, clientID string, claims storage.Claims, scopes []string, nonce, connID string) (accessToken string, expiry time.Time, err error) { + return s.newIDToken(ctx, clientID, claims, scopes, nonce, storage.NewID(), "", connID) } func getClientID(aud audience, azp string) (string, error) { @@ -350,10 +350,10 @@ func genSubject(userID string, connID string) (string, error) { return internal.Marshal(sub) } -func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []string, nonce, accessToken, code, connID string) (idToken string, expiry time.Time, err error) { +func (s *Server) newIDToken(ctx context.Context, clientID string, claims storage.Claims, scopes []string, nonce, accessToken, code, connID string) (idToken string, expiry time.Time, err error) { keys, err := s.storage.GetKeys() if err != nil { - s.logger.Error("failed to get keys", "err", err) + s.logger.ErrorContext(ctx, "failed to get keys", "err", err) return "", expiry, err } @@ -371,7 +371,7 @@ func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []str subjectString, err := genSubject(claims.UserID, connID) if err != nil { - s.logger.Error("failed to marshal offline session ID", "err", err) + s.logger.ErrorContext(ctx, "failed to marshal offline session ID", "err", err) return "", expiry, fmt.Errorf("failed to marshal offline session ID: %v", err) } @@ -386,7 +386,7 @@ func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []str if accessToken != "" { atHash, err := accessTokenHash(signingAlg, accessToken) if err != nil { - s.logger.Error("error computing at_hash", "err", err) + s.logger.ErrorContext(ctx, "error computing at_hash", "err", err) return "", expiry, fmt.Errorf("error computing at_hash: %v", err) } tok.AccessTokenHash = atHash @@ -395,7 +395,7 @@ func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []str if code != "" { cHash, err := accessTokenHash(signingAlg, code) if err != nil { - s.logger.Error("error computing c_hash", "err", err) + s.logger.ErrorContext(ctx, "error computing c_hash", "err", err) return "", expiry, fmt.Errorf("error computing c_hash: #{err}") } tok.CodeHash = cHash @@ -423,7 +423,7 @@ func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []str // initial auth request. continue } - isTrusted, err := s.validateCrossClientTrust(clientID, peerID) + isTrusted, err := s.validateCrossClientTrust(ctx, clientID, peerID) if err != nil { return "", expiry, err } @@ -482,7 +482,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques if err == storage.ErrNotFound { return nil, newDisplayedErr(http.StatusNotFound, "Invalid client_id (%q).", clientID) } - s.logger.Error("failed to get client", "err", err) + s.logger.ErrorContext(r.Context(), "failed to get client", "err", err) return nil, newDisplayedErr(http.StatusInternalServerError, "Database error.") } @@ -501,7 +501,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques if connectorID != "" { connectors, err := s.storage.ListConnectors() if err != nil { - s.logger.Error("failed to list connectors", "err", err) + s.logger.ErrorContext(r.Context(), "failed to list connectors", "err", err) return nil, newRedirectedErr(errServerError, "Unable to retrieve connectors") } if !validateConnectorID(connectors, connectorID) { @@ -537,7 +537,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques continue } - isTrusted, err := s.validateCrossClientTrust(clientID, peerID) + isTrusted, err := s.validateCrossClientTrust(r.Context(), clientID, peerID) if err != nil { return nil, newRedirectedErr(errServerError, "Internal server error.") } @@ -630,14 +630,14 @@ func parseCrossClientScope(scope string) (peerID string, ok bool) { return } -func (s *Server) validateCrossClientTrust(clientID, peerID string) (trusted bool, err error) { +func (s *Server) validateCrossClientTrust(ctx context.Context, clientID, peerID string) (trusted bool, err error) { if peerID == clientID { return true, nil } peer, err := s.storage.GetClient(peerID) if err != nil { if err != storage.ErrNotFound { - s.logger.Error("failed to get client", "err", err) + s.logger.ErrorContext(ctx, "failed to get client", "err", err) return false, err } return false, nil diff --git a/server/refreshhandlers.go b/server/refreshhandlers.go index 01a0f435b6..391d552251 100644 --- a/server/refreshhandlers.go +++ b/server/refreshhandlers.go @@ -80,14 +80,14 @@ type refreshContext struct { } // getRefreshTokenFromStorage checks that refresh token is valid and exists in the storage and gets its info -func (s *Server) getRefreshTokenFromStorage(clientID *string, token *internal.RefreshToken) (*refreshContext, *refreshError) { +func (s *Server) getRefreshTokenFromStorage(ctx context.Context, clientID *string, token *internal.RefreshToken) (*refreshContext, *refreshError) { refreshCtx := refreshContext{requestToken: token} // Get RefreshToken refresh, err := s.storage.GetRefresh(token.RefreshId) if err != nil { if err != storage.ErrNotFound { - s.logger.Error("failed to get refresh token", "err", err) + s.logger.ErrorContext(ctx, "failed to get refresh token", "err", err) return nil, newInternalServerError() } return nil, invalidErr @@ -95,7 +95,7 @@ func (s *Server) getRefreshTokenFromStorage(clientID *string, token *internal.Re // Only check ClientID if it was provided; if clientID != nil && (refresh.ClientID != *clientID) { - s.logger.Error("trying to claim token for different client", "client_id", clientID, "refresh_client_id", refresh.ClientID) + s.logger.ErrorContext(ctx, "trying to claim token for different client", "client_id", clientID, "refresh_client_id", refresh.ClientID) // According to https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 Dex should respond with an // invalid grant error if token has already been claimed by another client. return nil, &refreshError{msg: errInvalidGrant, desc: invalidErr.desc, code: http.StatusBadRequest} @@ -108,18 +108,18 @@ func (s *Server) getRefreshTokenFromStorage(clientID *string, token *internal.Re case refresh.ObsoleteToken != token.Token: fallthrough case refresh.ObsoleteToken == "": - s.logger.Error("refresh token claimed twice", "token_id", refresh.ID) + s.logger.ErrorContext(ctx, "refresh token claimed twice", "token_id", refresh.ID) return nil, invalidErr } } if s.refreshTokenPolicy.CompletelyExpired(refresh.CreatedAt) { - s.logger.Error("refresh token expired", "token_id", refresh.ID) + s.logger.ErrorContext(ctx, "refresh token expired", "token_id", refresh.ID) return nil, expiredErr } if s.refreshTokenPolicy.ExpiredBecauseUnused(refresh.LastUsed) { - s.logger.Error("refresh token expired due to inactivity", "token_id", refresh.ID) + s.logger.ErrorContext(ctx, "refresh token expired due to inactivity", "token_id", refresh.ID) return nil, expiredErr } @@ -128,7 +128,7 @@ func (s *Server) getRefreshTokenFromStorage(clientID *string, token *internal.Re // Get Connector refreshCtx.connector, err = s.getConnector(refresh.ConnectorID) if err != nil { - s.logger.Error("connector not found", "connector_id", refresh.ConnectorID, "err", err) + s.logger.ErrorContext(ctx, "connector not found", "connector_id", refresh.ConnectorID, "err", err) return nil, newInternalServerError() } @@ -137,7 +137,7 @@ func (s *Server) getRefreshTokenFromStorage(clientID *string, token *internal.Re switch { case err != nil: if err != storage.ErrNotFound { - s.logger.Error("failed to get offline session", "err", err) + s.logger.ErrorContext(ctx, "failed to get offline session", "err", err) return nil, newInternalServerError() } case len(refresh.ConnectorData) > 0: @@ -195,7 +195,7 @@ func (s *Server) refreshWithConnector(ctx context.Context, rCtx *refreshContext, newIdent, err := refreshConn.Refresh(ctx, parseScopes(rCtx.scopes), ident) if err != nil { - s.logger.Error("failed to refresh identity", "err", err) + s.logger.ErrorContext(ctx, "failed to refresh identity", "err", err) return ident, newInternalServerError() } @@ -205,7 +205,7 @@ func (s *Server) refreshWithConnector(ctx context.Context, rCtx *refreshContext, } // updateOfflineSession updates offline session in the storage -func (s *Server) updateOfflineSession(refresh *storage.RefreshToken, ident connector.Identity, lastUsed time.Time) *refreshError { +func (s *Server) updateOfflineSession(ctx context.Context, refresh *storage.RefreshToken, ident connector.Identity, lastUsed time.Time) *refreshError { offlineSessionUpdater := func(old storage.OfflineSessions) (storage.OfflineSessions, error) { if old.Refresh[refresh.ClientID].ID != refresh.ID { return old, errors.New("refresh token invalid") @@ -216,7 +216,7 @@ func (s *Server) updateOfflineSession(refresh *storage.RefreshToken, ident conne old.ConnectorData = ident.ConnectorData } - s.logger.Debug("saved connector data", "user_id", ident.UserID, "connector_data", ident.ConnectorData) + s.logger.DebugContext(ctx, "saved connector data", "user_id", ident.UserID, "connector_data", ident.ConnectorData) return old, nil } @@ -225,7 +225,7 @@ func (s *Server) updateOfflineSession(refresh *storage.RefreshToken, ident conne // in offline session for the user. err := s.storage.UpdateOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID, offlineSessionUpdater) if err != nil { - s.logger.Error("failed to update offline session", "err", err) + s.logger.ErrorContext(ctx, "failed to update offline session", "err", err) return newInternalServerError() } @@ -316,11 +316,11 @@ func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext) ( // Update refresh token in the storage. err := s.storage.UpdateRefreshToken(rCtx.storageToken.ID, refreshTokenUpdater) if err != nil { - s.logger.Error("failed to update refresh token", "err", err) + s.logger.ErrorContext(ctx, "failed to update refresh token", "err", err) return nil, ident, newInternalServerError() } - rerr = s.updateOfflineSession(rCtx.storageToken, ident, lastUsed) + rerr = s.updateOfflineSession(ctx, rCtx.storageToken, ident, lastUsed) if rerr != nil { return nil, ident, rerr } @@ -337,7 +337,7 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie return } - rCtx, rerr := s.getRefreshTokenFromStorage(&client.ID, token) + rCtx, rerr := s.getRefreshTokenFromStorage(r.Context(), &client.ID, token) if rerr != nil { s.refreshTokenErrHelper(w, rerr) return @@ -364,23 +364,23 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie Groups: ident.Groups, } - accessToken, _, err := s.newAccessToken(client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, rCtx.storageToken.ConnectorID) + accessToken, _, err := s.newAccessToken(r.Context(), client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, rCtx.storageToken.ConnectorID) if err != nil { - s.logger.Error("failed to create new access token", "err", err) + s.logger.ErrorContext(r.Context(), "failed to create new access token", "err", err) s.refreshTokenErrHelper(w, newInternalServerError()) return } - idToken, expiry, err := s.newIDToken(client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, accessToken, "", rCtx.storageToken.ConnectorID) + idToken, expiry, err := s.newIDToken(r.Context(), client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, accessToken, "", rCtx.storageToken.ConnectorID) if err != nil { - s.logger.Error("failed to create ID token", "err", err) + s.logger.ErrorContext(r.Context(), "failed to create ID token", "err", err) s.refreshTokenErrHelper(w, newInternalServerError()) return } rawNewToken, err := internal.Marshal(newToken) if err != nil { - s.logger.Error("failed to marshal refresh token", "err", err) + s.logger.ErrorContext(r.Context(), "failed to marshal refresh token", "err", err) s.refreshTokenErrHelper(w, newInternalServerError()) return } diff --git a/server/server.go b/server/server.go index 68294885b9..b447fa3276 100644 --- a/server/server.go +++ b/server/server.go @@ -8,7 +8,9 @@ import ( "fmt" "io/fs" "log/slog" + "net" "net/http" + "net/netip" "net/url" "os" "path" @@ -21,6 +23,7 @@ import ( gosundheit "github.com/AppsFlyer/go-sundheit" "github.com/felixge/httpsnoop" + "github.com/google/uuid" "github.com/gorilla/handlers" "github.com/gorilla/mux" "github.com/prometheus/client_golang/prometheus" @@ -85,6 +88,10 @@ type Config struct { // Headers is a map of headers to be added to the all responses. Headers http.Header + // Header to extract real ip from. + RealIPHeader string + TrustedRealIPCIDRs []netip.Prefix + // List of allowed origins for CORS requests on discovery, token and keys endpoint. // If none are indicated, CORS requests are disabled. Passing in "*" will allow any // domain. @@ -358,11 +365,52 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) } } + parseRealIP := func(r *http.Request) (string, error) { + remoteAddr, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return "", err + } + + remoteIP, err := netip.ParseAddr(remoteAddr) + if err != nil { + return "", err + } + + for _, n := range c.TrustedRealIPCIDRs { + if !n.Contains(remoteIP) { + return remoteAddr, nil // Fallback to the address from the request if the header is provided + } + } + + ipVal := r.Header.Get(c.RealIPHeader) + if ipVal != "" { + ip, err := netip.ParseAddr(ipVal) + if err == nil { + return ip.String(), nil + } + } + + return remoteAddr, nil + } + handlerWithHeaders := func(handlerName string, handler http.Handler) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { for k, v := range c.Headers { w.Header()[k] = v } + + // Context values are used for logging purposes with the log/slog logger. + rCtx := r.Context() + rCtx = WithRequestID(rCtx) + + if c.RealIPHeader != "" { + realIP, err := parseRealIP(r) + if err == nil { + rCtx = WithRemoteIP(rCtx, realIP) + } + } + + r = r.WithContext(rCtx) instrumentHandlerCounter(handlerName, handler)(w, r) } } @@ -682,3 +730,18 @@ func (s *Server) getConnector(id string) (Connector, error) { return conn, nil } + +type logRequestKey string + +const ( + RequestKeyRequestID logRequestKey = "request_id" + RequestKeyRemoteIP logRequestKey = "client_remote_addr" +) + +func WithRequestID(ctx context.Context) context.Context { + return context.WithValue(ctx, RequestKeyRequestID, uuid.NewString()) +} + +func WithRemoteIP(ctx context.Context, ip string) context.Context { + return context.WithValue(ctx, RequestKeyRemoteIP, ip) +}