Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enrich Dex logs with real IP and request ID #3661

Merged
merged 3 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 31 additions & 9 deletions cmd/dex/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"log/slog"
"net/http"
"net/netip"
"os"
"slices"
"strings"
Expand Down Expand Up @@ -182,15 +183,36 @@ type OAuth2 struct {

// Web is the config format for the HTTP server.
type Web struct {
HTTP string `json:"http"`
HTTPS string `json:"https"`
Headers Headers `json:"headers"`
TLSCert string `json:"tlsCert"`
TLSKey string `json:"tlsKey"`
TLSMinVersion string `json:"tlsMinVersion"`
TLSMaxVersion string `json:"tlsMaxVersion"`
AllowedOrigins []string `json:"allowedOrigins"`
AllowedHeaders []string `json:"allowedHeaders"`
HTTP string `json:"http"`
HTTPS string `json:"https"`
Headers Headers `json:"headers"`
TLSCert string `json:"tlsCert"`
TLSKey string `json:"tlsKey"`
TLSMinVersion string `json:"tlsMinVersion"`
TLSMaxVersion string `json:"tlsMaxVersion"`
AllowedOrigins []string `json:"allowedOrigins"`
AllowedHeaders []string `json:"allowedHeaders"`
ClientRemoteIP ClientRemoteIP `json:"clientRemoteIP"`
}

type ClientRemoteIP struct {
Header string `json:"header"`
TrustedProxyCIDRs []string `json:"trustedProxyCIDRs"`
sagikazarmark marked this conversation as resolved.
Show resolved Hide resolved
}

func (cr *ClientRemoteIP) ToParsedCIDRs() (string, []netip.Prefix, error) {
nabokihms marked this conversation as resolved.
Show resolved Hide resolved
if cr == nil {
return "", nil, nil
}
nabokihms marked this conversation as resolved.
Show resolved Hide resolved
trusted := make([]netip.Prefix, 0, len(cr.TrustedProxyCIDRs))
for _, cidr := range cr.TrustedProxyCIDRs {
ipNet, err := netip.ParsePrefix(cidr)
if err != nil {
return "", nil, fmt.Errorf("failed to parse CIDR %q: %v", cidr, err)
}
trusted = append(trusted, ipNet)
}
nabokihms marked this conversation as resolved.
Show resolved Hide resolved
return cr.Header, trusted, nil
}

type Headers struct {
Expand Down
67 changes: 67 additions & 0 deletions cmd/dex/logger.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package main

import (
"context"
"fmt"
"log/slog"
"os"
"strings"

"github.com/dexidp/dex/server"
)

var logFormats = []string{"json", "text"}

func newLogger(level slog.Level, format string) (*slog.Logger, error) {
var handler slog.Handler
switch strings.ToLower(format) {
case "", "text":
handler = slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: level,
})
case "json":
handler = slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{
Level: level,
})
default:
return nil, fmt.Errorf("log format is not one of the supported values (%s): %s", strings.Join(logFormats, ", "), format)
}

return slog.New(newRequestContextHandler(handler)), nil
}

var _ slog.Handler = requestContextHandler{}

type requestContextHandler struct {
handler slog.Handler
}

func newRequestContextHandler(handler slog.Handler) slog.Handler {
return requestContextHandler{
handler: handler,
}
}

func (h requestContextHandler) Enabled(ctx context.Context, level slog.Level) bool {
return h.handler.Enabled(ctx, level)
}

func (h requestContextHandler) Handle(ctx context.Context, record slog.Record) error {
if v, ok := ctx.Value(server.LogRequestKeyRemoteIP).(string); ok {
record.AddAttrs(slog.String(string(server.LogRequestKeyRemoteIP), v))
}

if v, ok := ctx.Value(server.LogRequestKeyRequestID).(string); ok {
record.AddAttrs(slog.String(string(server.LogRequestKeyRequestID), v))
}

return h.handler.Handle(ctx, record)
}

func (h requestContextHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
return requestContextHandler{h.handler.WithAttrs(attrs)}
}

func (h requestContextHandler) WithGroup(name string) slog.Handler {
return h.handler.WithGroup(name)
}
26 changes: 6 additions & 20 deletions cmd/dex/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,12 @@ func runServe(options serveOptions) error {
}

serverConfig.RefreshTokenPolicy = refreshTokenPolicy

serverConfig.RealIPHeader, serverConfig.TrustedRealIPCIDRs, err = c.Web.ClientRemoteIP.ToParsedCIDRs()
if err != nil {
return fmt.Errorf("failed to parse client remote IP settings: %v", err)
}

serv, err := server.NewServer(context.Background(), serverConfig)
if err != nil {
return fmt.Errorf("failed to initialize server: %v", err)
Expand Down Expand Up @@ -528,26 +534,6 @@ func runServe(options serveOptions) error {
return nil
}

var logFormats = []string{"json", "text"}

func newLogger(level slog.Level, format string) (*slog.Logger, error) {
var handler slog.Handler
switch strings.ToLower(format) {
case "", "text":
handler = slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: level,
})
case "json":
handler = slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{
Level: level,
})
default:
return nil, fmt.Errorf("log format is not one of the supported values (%s): %s", strings.Join(logFormats, ", "), format)
}

return slog.New(handler), nil
}

func applyConfigOverrides(options serveOptions, config *Config) {
if options.webHTTPAddr != "" {
config.Web.HTTP = options.webHTTPAddr
Expand Down
5 changes: 4 additions & 1 deletion examples/config-dev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ web:
# X-XSS-Protection: "1; mode=block"
# Content-Security-Policy: "default-src 'self'"
# Strict-Transport-Security: "max-age=31536000; includeSubDomains"

# clientRemoteIP:
# header: X-Forwarded-For
# trustedProxyCIDRs:
# - 10.0.0.0/8

# Configuration for dex appearance
# frontend:
Expand Down
38 changes: 19 additions & 19 deletions server/deviceflowhandlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func (s *Server) handleDeviceExchange(w http.ResponseWriter, r *http.Request) {
invalidAttempt = false
}
if err := s.templates.device(r, w, s.getDeviceVerificationURI(), userCode, invalidAttempt); err != nil {
s.logger.Error("server template error", "err", err)
s.logger.ErrorContext(r.Context(), "server template error", "err", err)
s.renderError(r, w, http.StatusNotFound, "Page not found")
}
default:
Expand All @@ -64,7 +64,7 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
case http.MethodPost:
err := r.ParseForm()
if err != nil {
s.logger.Error("could not parse Device Request body", "err", err)
s.logger.ErrorContext(r.Context(), "could not parse Device Request body", "err", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusNotFound)
return
}
Expand All @@ -85,7 +85,7 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
return
}

s.logger.Info("received device request", "client_id", clientID, "scoped", scopes)
s.logger.InfoContext(r.Context(), "received device request", "client_id", clientID, "scoped", scopes)

// Make device code
deviceCode := storage.NewDeviceCode()
Expand All @@ -107,7 +107,7 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
}

if err := s.storage.CreateDeviceRequest(ctx, deviceReq); err != nil {
s.logger.Error("failed to store device request", "err", err)
s.logger.ErrorContext(r.Context(), "failed to store device request", "err", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
return
}
Expand All @@ -126,14 +126,14 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
}

if err := s.storage.CreateDeviceToken(ctx, deviceToken); err != nil {
s.logger.Error("failed to store device token", "err", err)
s.logger.ErrorContext(r.Context(), "failed to store device token", "err", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
return
}

u, err := url.Parse(s.issuerURL.String())
if err != nil {
s.logger.Error("could not parse issuer URL", "err", err)
s.logger.ErrorContext(r.Context(), "could not parse issuer URL", "err", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
return
}
Expand Down Expand Up @@ -211,7 +211,7 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
deviceToken, err := s.storage.GetDeviceToken(deviceCode)
if err != nil {
if err != storage.ErrNotFound {
s.logger.Error("failed to get device code", "err", err)
s.logger.ErrorContext(r.Context(), "failed to get device code", "err", err)
}
s.tokenErrHelper(w, errInvalidRequest, "Invalid Device code.", http.StatusBadRequest)
return
Expand Down Expand Up @@ -241,7 +241,7 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
}
// Update device token last request time in storage
if err := s.storage.UpdateDeviceToken(deviceCode, updater); err != nil {
s.logger.Error("failed to update device token", "err", err)
s.logger.ErrorContext(r.Context(), "failed to update device token", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "")
return
}
Expand All @@ -258,7 +258,7 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
case providedCodeVerifier != "" && codeChallengeFromStorage != "":
calculatedCodeChallenge, err := s.calculateCodeChallenge(providedCodeVerifier, deviceToken.PKCE.CodeChallengeMethod)
if err != nil {
s.logger.Error("failed to calculate code challenge", "err", err)
s.logger.ErrorContext(r.Context(), "failed to calculate code challenge", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
Expand Down Expand Up @@ -303,7 +303,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
if err != nil || s.now().After(authCode.Expiry) {
errCode := http.StatusBadRequest
if err != nil && err != storage.ErrNotFound {
s.logger.Error("failed to get auth code", "err", err)
s.logger.ErrorContext(r.Context(), "failed to get auth code", "err", err)
errCode = http.StatusInternalServerError
}
s.renderError(r, w, errCode, "Invalid or expired auth code.")
Expand All @@ -315,7 +315,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
if err != nil || s.now().After(deviceReq.Expiry) {
errCode := http.StatusBadRequest
if err != nil && err != storage.ErrNotFound {
s.logger.Error("failed to get device code", "err", err)
s.logger.ErrorContext(r.Context(), "failed to get device code", "err", err)
errCode = http.StatusInternalServerError
}
s.renderError(r, w, errCode, "Invalid or expired user code.")
Expand All @@ -325,7 +325,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
client, err := s.storage.GetClient(deviceReq.ClientID)
if err != nil {
if err != storage.ErrNotFound {
s.logger.Error("failed to get client", "err", err)
s.logger.ErrorContext(r.Context(), "failed to get client", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
} else {
s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized)
Expand All @@ -339,7 +339,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {

resp, err := s.exchangeAuthCode(ctx, w, authCode, client)
if err != nil {
s.logger.Error("could not exchange auth code for clien", "client_id", deviceReq.ClientID, "err", err)
s.logger.ErrorContext(r.Context(), "could not exchange auth code for clien", "client_id", deviceReq.ClientID, "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Failed to exchange auth code.")
return
}
Expand All @@ -349,7 +349,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
if err != nil || s.now().After(old.Expiry) {
errCode := http.StatusBadRequest
if err != nil && err != storage.ErrNotFound {
s.logger.Error("failed to get device token", "err", err)
s.logger.ErrorContext(r.Context(), "failed to get device token", "err", err)
errCode = http.StatusInternalServerError
}
s.renderError(r, w, errCode, "Invalid or expired device code.")
Expand All @@ -362,7 +362,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
}
respStr, err := json.MarshalIndent(resp, "", " ")
if err != nil {
s.logger.Error("failed to marshal device token response", "err", err)
s.logger.ErrorContext(r.Context(), "failed to marshal device token response", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "")
return old, err
}
Expand All @@ -374,13 +374,13 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {

// Update refresh token in the storage, store the token and mark as complete
if err := s.storage.UpdateDeviceToken(deviceReq.DeviceCode, updater); err != nil {
s.logger.Error("failed to update device token", "err", err)
s.logger.ErrorContext(r.Context(), "failed to update device token", "err", err)
s.renderError(r, w, http.StatusBadRequest, "")
return
}

if err := s.templates.deviceSuccess(r, w, client.Name); err != nil {
s.logger.Error("Server template error", "err", err)
s.logger.ErrorContext(r.Context(), "Server template error", "err", err)
s.renderError(r, w, http.StatusNotFound, "Page not found")
}

Expand Down Expand Up @@ -412,10 +412,10 @@ func (s *Server) verifyUserCode(w http.ResponseWriter, r *http.Request) {
deviceRequest, err := s.storage.GetDeviceRequest(userCode)
if err != nil || s.now().After(deviceRequest.Expiry) {
if err != nil && err != storage.ErrNotFound {
s.logger.Error("failed to get device request", "err", err)
s.logger.ErrorContext(r.Context(), "failed to get device request", "err", err)
}
if err := s.templates.device(r, w, s.getDeviceVerificationURI(), userCode, true); err != nil {
s.logger.Error("Server template error", "err", err)
s.logger.ErrorContext(r.Context(), "Server template error", "err", err)
s.renderError(r, w, http.StatusNotFound, "Page not found")
}
return
Expand Down
Loading