diff --git a/cmdutil/debug/debug.go b/cmdutil/debug/debug.go index f8a72306..d43cc642 100644 --- a/cmdutil/debug/debug.go +++ b/cmdutil/debug/debug.go @@ -15,7 +15,12 @@ package debug import ( + "context" "fmt" + "net/http" + "net/http/pprof" + "runtime" + "time" "github.com/google/gops/agent" "github.com/sirupsen/logrus" @@ -71,3 +76,83 @@ func (s *Server) Stop(_ error) { close(s.done) } + +// PProfServer wraps a pprof server. +type PProfServer struct { + logger logrus.FieldLogger + addr string + done chan struct{} + pprofServer *http.Server +} + +// ProfileConfig holds the configuration for the pprof server. +type PProfServerConfig struct { + Addr string + MutexProfileFraction int +} + +// defaultMutexProfileFraction is the default value for MutexProfileFraction +const defaultMutexProfileFraction = 2 + +// NewPProfServer sets up a pprof server with configurable profiling types and returns a PProfServer instance. +func NewPProfServer(config PProfServerConfig, l logrus.FieldLogger) *PProfServer { + if config.Addr == "" { + config.Addr = "127.0.0.1:9998" // Default port + } + + // Use a local variable for the mutex profile fraction + mpf := defaultMutexProfileFraction + if config.MutexProfileFraction != 0 { + mpf = config.MutexProfileFraction + } + runtime.SetMutexProfileFraction(mpf) + + httpServer := &http.Server{ + Addr: config.Addr, + Handler: http.HandlerFunc(pprof.Index), + ReadHeaderTimeout: 5 * time.Second, + } + + return &PProfServer{ + logger: l, + addr: config.Addr, + done: make(chan struct{}), + pprofServer: httpServer, + } +} + +// Run starts the pprof server. +// +// It implements oklog group's runFn. +func (s *PProfServer) Run() error { + if s.pprofServer == nil { + return fmt.Errorf("pprofServer is nil") + } + + s.logger.WithFields(logrus.Fields{ + "at": "binding", + "service": "pprof", + "addr": s.addr, + }).Info() + + if err := s.pprofServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { + return err + } + + <-s.done + return nil +} + +// Stop shuts down the pprof server. +// +// It implements oklog group's interruptFn. +func (s *PProfServer) Stop(_ error) { + if s.pprofServer != nil { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := s.pprofServer.Shutdown(ctx); err != nil { + s.logger.WithError(err).Error("Error shutting down pprof server") + } + } + close(s.done) +} diff --git a/cmdutil/debug/debug_test.go b/cmdutil/debug/debug_test.go new file mode 100644 index 00000000..350a08b0 --- /dev/null +++ b/cmdutil/debug/debug_test.go @@ -0,0 +1,100 @@ +package debug + +import ( + "net/http" + "runtime" + "testing" + "time" + + "github.com/sirupsen/logrus" +) + +func TestNewPProfServer(t *testing.T) { + logger := logrus.New() + + tests := []struct { + name string + config PProfServerConfig + expectedAddr string + expectedMutexFraction int + }{ + { + name: "DefaultAddr", + config: PProfServerConfig{}, + expectedAddr: "127.0.0.1:9998", + expectedMutexFraction: defaultMutexProfileFraction, + }, + { + name: "CustomAddr", + config: PProfServerConfig{Addr: "127.0.0.1:9090"}, + expectedAddr: "127.0.0.1:9090", + expectedMutexFraction: defaultMutexProfileFraction, + }, + { + name: "CustomMutexProfileFraction", + config: PProfServerConfig{MutexProfileFraction: 5}, + expectedAddr: "127.0.0.1:9998", + expectedMutexFraction: 5, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := NewPProfServer(tt.config, logger) + + // Check server address + if server.addr != tt.expectedAddr { + t.Errorf("NewPProfServer() addr = %v, want %v", server.addr, tt.expectedAddr) + } + + // Start the server + go func() { + if err := server.Run(); err != nil { + t.Errorf("NewPProfServer() run error = %v", err) + } + }() + + // Give the server a moment to start + time.Sleep(100 * time.Millisecond) + + // Check mutex profile fraction + if got := runtime.SetMutexProfileFraction(0); got != tt.expectedMutexFraction { + t.Errorf("runtime.SetMutexProfileFraction() = %v, want %v", got, tt.expectedMutexFraction) + } + runtime.SetMutexProfileFraction(tt.expectedMutexFraction) // Reset to the expected value + + // Perform HTTP GET request to the root path + url := "http://" + server.addr + "/debug/pprof/" + client := &http.Client{} + + t.Run("GET "+url, func(t *testing.T) { + req, err := http.NewRequest("GET", url, nil) + if err != nil { + t.Errorf("http.NewRequest(%s) error = %v", url, err) + } + + resp, err := client.Do(req) + if err != nil { + t.Errorf("http.Client.Do() error = %v", err) + } + + if resp.StatusCode != http.StatusOK { + t.Errorf("http.Client.Do() status = %v, want %v", resp.StatusCode, http.StatusOK) + } + + resp.Body.Close() + }) + + // Stop the server + server.Stop(nil) + + // Ensure the server is stopped + select { + case <-server.done: + // success + case <-time.After(1 * time.Second): + t.Fatal("server did not stop in time") + } + }) + } +}