Skip to content

Commit

Permalink
Enrich Dex logs with real IP and request ID
Browse files Browse the repository at this point in the history
- Include client's real IP in logs.
- Capture and log unique request IDs.
- Update middleware to extract and forward this data.
- Enhance debugging and traceability.
Improves insight and troubleshooting efficiency.

Signed-off-by: m.nabokikh <maksim.nabokikh@flant.com>
  • Loading branch information
nabokihms committed Jul 30, 2024
1 parent 2a6ddc1 commit ac1c12b
Show file tree
Hide file tree
Showing 11 changed files with 329 additions and 187 deletions.
40 changes: 31 additions & 9 deletions cmd/dex/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"log/slog"
"net/http"
"net/netip"
"os"
"slices"
"strings"
Expand Down Expand Up @@ -182,15 +183,36 @@ type OAuth2 struct {

// Web is the config format for the HTTP server.
type Web struct {
HTTP string `json:"http"`
HTTPS string `json:"https"`
Headers Headers `json:"headers"`
TLSCert string `json:"tlsCert"`
TLSKey string `json:"tlsKey"`
TLSMinVersion string `json:"tlsMinVersion"`
TLSMaxVersion string `json:"tlsMaxVersion"`
AllowedOrigins []string `json:"allowedOrigins"`
AllowedHeaders []string `json:"allowedHeaders"`
HTTP string `json:"http"`
HTTPS string `json:"https"`
Headers Headers `json:"headers"`
TLSCert string `json:"tlsCert"`
TLSKey string `json:"tlsKey"`
TLSMinVersion string `json:"tlsMinVersion"`
TLSMaxVersion string `json:"tlsMaxVersion"`
AllowedOrigins []string `json:"allowedOrigins"`
AllowedHeaders []string `json:"allowedHeaders"`
ClientRemoteIP ClientRemoteIP `json:"clientRemoteIP"`
}

type ClientRemoteIP struct {
Header string `json:"header"`
TrustedProxyCIDRs []string `json:"trustedProxyCIDRs"`
}

func (cr *ClientRemoteIP) ToParsedCIDRs() (string, []netip.Prefix, error) {
if cr == nil {
return "", nil, nil
}
trusted := make([]netip.Prefix, 0, len(cr.TrustedProxyCIDRs))
for _, cidr := range cr.TrustedProxyCIDRs {
ipNet, err := netip.ParsePrefix(cidr)
if err != nil {
return "", nil, fmt.Errorf("failed to parse CIDR %q: %v", cidr, err)
}
trusted = append(trusted, ipNet)
}
return cr.Header, trusted, nil
}

type Headers struct {
Expand Down
67 changes: 67 additions & 0 deletions cmd/dex/logger.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package main

import (
"context"
"fmt"
"log/slog"
"os"
"strings"

"github.com/dexidp/dex/server"
)

var logFormats = []string{"json", "text"}

func newLogger(level slog.Level, format string) (*slog.Logger, error) {
var handler slog.Handler
switch strings.ToLower(format) {
case "", "text":
handler = slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: level,
})
case "json":
handler = slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{
Level: level,
})
default:
return nil, fmt.Errorf("log format is not one of the supported values (%s): %s", strings.Join(logFormats, ", "), format)
}

return slog.New(newRequestContextHandler(handler)), nil
}

var _ slog.Handler = requestContextHandler{}

type requestContextHandler struct {
handler slog.Handler
}

func newRequestContextHandler(handler slog.Handler) slog.Handler {
return requestContextHandler{
handler: handler,
}
}

func (h requestContextHandler) Enabled(ctx context.Context, level slog.Level) bool {
return h.handler.Enabled(ctx, level)
}

func (h requestContextHandler) Handle(ctx context.Context, record slog.Record) error {
if v, ok := ctx.Value(server.LogRequestKeyRemoteIP).(string); ok {
record.AddAttrs(slog.String(string(server.LogRequestKeyRemoteIP), v))
}

if v, ok := ctx.Value(server.LogRequestKeyRequestID).(string); ok {
record.AddAttrs(slog.String(string(server.LogRequestKeyRequestID), v))
}

return h.handler.Handle(ctx, record)
}

func (h requestContextHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
return requestContextHandler{h.handler.WithAttrs(attrs)}
}

func (h requestContextHandler) WithGroup(name string) slog.Handler {
return h.handler.WithGroup(name)
}
26 changes: 6 additions & 20 deletions cmd/dex/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,12 @@ func runServe(options serveOptions) error {
}

serverConfig.RefreshTokenPolicy = refreshTokenPolicy

serverConfig.RealIPHeader, serverConfig.TrustedRealIPCIDRs, err = c.Web.ClientRemoteIP.ToParsedCIDRs()
if err != nil {
return fmt.Errorf("failed to parse client remote IP settings: %v", err)
}

serv, err := server.NewServer(context.Background(), serverConfig)
if err != nil {
return fmt.Errorf("failed to initialize server: %v", err)
Expand Down Expand Up @@ -528,26 +534,6 @@ func runServe(options serveOptions) error {
return nil
}

var logFormats = []string{"json", "text"}

func newLogger(level slog.Level, format string) (*slog.Logger, error) {
var handler slog.Handler
switch strings.ToLower(format) {
case "", "text":
handler = slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: level,
})
case "json":
handler = slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{
Level: level,
})
default:
return nil, fmt.Errorf("log format is not one of the supported values (%s): %s", strings.Join(logFormats, ", "), format)
}

return slog.New(handler), nil
}

func applyConfigOverrides(options serveOptions, config *Config) {
if options.webHTTPAddr != "" {
config.Web.HTTP = options.webHTTPAddr
Expand Down
5 changes: 4 additions & 1 deletion examples/config-dev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ web:
# X-XSS-Protection: "1; mode=block"
# Content-Security-Policy: "default-src 'self'"
# Strict-Transport-Security: "max-age=31536000; includeSubDomains"

# clientRemoteIP:
# header: X-Forwarded-For
# trustedProxyCIDRs:
# - 10.0.0.0/8

# Configuration for dex appearance
# frontend:
Expand Down
38 changes: 19 additions & 19 deletions server/deviceflowhandlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func (s *Server) handleDeviceExchange(w http.ResponseWriter, r *http.Request) {
invalidAttempt = false
}
if err := s.templates.device(r, w, s.getDeviceVerificationURI(), userCode, invalidAttempt); err != nil {
s.logger.Error("server template error", "err", err)
s.logger.ErrorContext(r.Context(), "server template error", "err", err)
s.renderError(r, w, http.StatusNotFound, "Page not found")
}
default:
Expand All @@ -64,7 +64,7 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
case http.MethodPost:
err := r.ParseForm()
if err != nil {
s.logger.Error("could not parse Device Request body", "err", err)
s.logger.ErrorContext(r.Context(), "could not parse Device Request body", "err", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusNotFound)
return
}
Expand All @@ -85,7 +85,7 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
return
}

s.logger.Info("received device request", "client_id", clientID, "scoped", scopes)
s.logger.InfoContext(r.Context(), "received device request", "client_id", clientID, "scoped", scopes)

// Make device code
deviceCode := storage.NewDeviceCode()
Expand All @@ -107,7 +107,7 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
}

if err := s.storage.CreateDeviceRequest(ctx, deviceReq); err != nil {
s.logger.Error("failed to store device request", "err", err)
s.logger.ErrorContext(r.Context(), "failed to store device request", "err", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
return
}
Expand All @@ -126,14 +126,14 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
}

if err := s.storage.CreateDeviceToken(ctx, deviceToken); err != nil {
s.logger.Error("failed to store device token", "err", err)
s.logger.ErrorContext(r.Context(), "failed to store device token", "err", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
return
}

u, err := url.Parse(s.issuerURL.String())
if err != nil {
s.logger.Error("could not parse issuer URL", "err", err)
s.logger.ErrorContext(r.Context(), "could not parse issuer URL", "err", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
return
}
Expand Down Expand Up @@ -211,7 +211,7 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
deviceToken, err := s.storage.GetDeviceToken(deviceCode)
if err != nil {
if err != storage.ErrNotFound {
s.logger.Error("failed to get device code", "err", err)
s.logger.ErrorContext(r.Context(), "failed to get device code", "err", err)
}
s.tokenErrHelper(w, errInvalidRequest, "Invalid Device code.", http.StatusBadRequest)
return
Expand Down Expand Up @@ -241,7 +241,7 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
}
// Update device token last request time in storage
if err := s.storage.UpdateDeviceToken(deviceCode, updater); err != nil {
s.logger.Error("failed to update device token", "err", err)
s.logger.ErrorContext(r.Context(), "failed to update device token", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "")
return
}
Expand All @@ -258,7 +258,7 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
case providedCodeVerifier != "" && codeChallengeFromStorage != "":
calculatedCodeChallenge, err := s.calculateCodeChallenge(providedCodeVerifier, deviceToken.PKCE.CodeChallengeMethod)
if err != nil {
s.logger.Error("failed to calculate code challenge", "err", err)
s.logger.ErrorContext(r.Context(), "failed to calculate code challenge", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
Expand Down Expand Up @@ -303,7 +303,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
if err != nil || s.now().After(authCode.Expiry) {
errCode := http.StatusBadRequest
if err != nil && err != storage.ErrNotFound {
s.logger.Error("failed to get auth code", "err", err)
s.logger.ErrorContext(r.Context(), "failed to get auth code", "err", err)
errCode = http.StatusInternalServerError
}
s.renderError(r, w, errCode, "Invalid or expired auth code.")
Expand All @@ -315,7 +315,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
if err != nil || s.now().After(deviceReq.Expiry) {
errCode := http.StatusBadRequest
if err != nil && err != storage.ErrNotFound {
s.logger.Error("failed to get device code", "err", err)
s.logger.ErrorContext(r.Context(), "failed to get device code", "err", err)
errCode = http.StatusInternalServerError
}
s.renderError(r, w, errCode, "Invalid or expired user code.")
Expand All @@ -325,7 +325,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
client, err := s.storage.GetClient(deviceReq.ClientID)
if err != nil {
if err != storage.ErrNotFound {
s.logger.Error("failed to get client", "err", err)
s.logger.ErrorContext(r.Context(), "failed to get client", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
} else {
s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized)
Expand All @@ -339,7 +339,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {

resp, err := s.exchangeAuthCode(ctx, w, authCode, client)
if err != nil {
s.logger.Error("could not exchange auth code for clien", "client_id", deviceReq.ClientID, "err", err)
s.logger.ErrorContext(r.Context(), "could not exchange auth code for clien", "client_id", deviceReq.ClientID, "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Failed to exchange auth code.")
return
}
Expand All @@ -349,7 +349,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
if err != nil || s.now().After(old.Expiry) {
errCode := http.StatusBadRequest
if err != nil && err != storage.ErrNotFound {
s.logger.Error("failed to get device token", "err", err)
s.logger.ErrorContext(r.Context(), "failed to get device token", "err", err)
errCode = http.StatusInternalServerError
}
s.renderError(r, w, errCode, "Invalid or expired device code.")
Expand All @@ -362,7 +362,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
}
respStr, err := json.MarshalIndent(resp, "", " ")
if err != nil {
s.logger.Error("failed to marshal device token response", "err", err)
s.logger.ErrorContext(r.Context(), "failed to marshal device token response", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "")
return old, err
}
Expand All @@ -374,13 +374,13 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {

// Update refresh token in the storage, store the token and mark as complete
if err := s.storage.UpdateDeviceToken(deviceReq.DeviceCode, updater); err != nil {
s.logger.Error("failed to update device token", "err", err)
s.logger.ErrorContext(r.Context(), "failed to update device token", "err", err)
s.renderError(r, w, http.StatusBadRequest, "")
return
}

if err := s.templates.deviceSuccess(r, w, client.Name); err != nil {
s.logger.Error("Server template error", "err", err)
s.logger.ErrorContext(r.Context(), "Server template error", "err", err)
s.renderError(r, w, http.StatusNotFound, "Page not found")
}

Expand Down Expand Up @@ -412,10 +412,10 @@ func (s *Server) verifyUserCode(w http.ResponseWriter, r *http.Request) {
deviceRequest, err := s.storage.GetDeviceRequest(userCode)
if err != nil || s.now().After(deviceRequest.Expiry) {
if err != nil && err != storage.ErrNotFound {
s.logger.Error("failed to get device request", "err", err)
s.logger.ErrorContext(r.Context(), "failed to get device request", "err", err)
}
if err := s.templates.device(r, w, s.getDeviceVerificationURI(), userCode, true); err != nil {
s.logger.Error("Server template error", "err", err)
s.logger.ErrorContext(r.Context(), "Server template error", "err", err)
s.renderError(r, w, http.StatusNotFound, "Page not found")
}
return
Expand Down
Loading

0 comments on commit ac1c12b

Please sign in to comment.