Skip to content

Commit

Permalink
server: allow TLS
Browse files Browse the repository at this point in the history
  • Loading branch information
zadjadr committed Jul 21, 2024
1 parent b3fcfda commit 529b31c
Show file tree
Hide file tree
Showing 3 changed files with 317 additions and 16 deletions.
111 changes: 100 additions & 11 deletions cmd/prometheus-cve-exporter/main.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
package main

import (
"context"
"crypto/tls"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"syscall"
"time"

"zops.top/prometheus-cve-exporter/config"
"zops.top/prometheus-cve-exporter/internal/exporter"
Expand All @@ -12,23 +18,106 @@ import (
"github.com/prometheus/client_golang/prometheus/promhttp"
)

func main() {
cfg, err := config.Load()
if err != nil {
log.Fatalf("Failed to load configuration: %v", err)
}
type UpdateMetricsFunc func(*config.Config)

go exporter.UpdateMetrics(cfg)
startServer(cfg)
type Server struct {
cfg *config.Config
logger *log.Logger
mux *http.ServeMux
server *http.Server
updateMetrics UpdateMetricsFunc
}

func startServer(cfg *config.Config) {
http.Handle("/metrics", promhttp.HandlerFor(
func NewServer(cfg *config.Config, logger *log.Logger, updateMetrics UpdateMetricsFunc) *Server {
return &Server{
cfg: cfg,
logger: logger,
mux: http.NewServeMux(),
updateMetrics: updateMetrics,
}
}

func (s *Server) SetupRouter() {
s.mux.Handle("/metrics", promhttp.HandlerFor(
prometheus.DefaultGatherer,
promhttp.HandlerOpts{
EnableOpenMetrics: true,
},
))
fmt.Printf("Starting server on :%d\n", cfg.Port)
log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", cfg.Port), nil))

s.mux.HandleFunc("/", s.homeHandler)
}

func (s *Server) homeHandler(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/" {
http.NotFound(w, r)
return
}
w.Header().Set("Content-Type", "text/html; charset=utf-8")
fmt.Fprint(w, `<a href="/metrics">Go to metrics</a>`)
}

func (s *Server) Start() {
s.server = &http.Server{
Addr: fmt.Sprintf(":%d", s.cfg.Port),
Handler: s.mux,
ReadTimeout: 5 * time.Second,
WriteTimeout: 10 * time.Second,
IdleTimeout: 120 * time.Second,
}

if s.cfg.UseTLS {
s.server.TLSConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
PreferServerCipherSuites: true,
}
}

go func() {
var err error
s.logger.Printf("Starting server on :%d\n", s.cfg.Port)
if s.cfg.UseTLS {
s.logger.Println("TLS enabled")
err = s.server.ListenAndServeTLS(s.cfg.TLSCert, s.cfg.TLSKey)
} else {
s.logger.Println("TLS disabled")
err = s.server.ListenAndServe()
}
if err != nil && err != http.ErrServerClosed {
s.logger.Fatalf("Could not listen on %d: %v\n", s.cfg.Port, err)
}
}()
}

func (s *Server) GracefulShutdown() {
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
sig := <-quit
s.logger.Printf("Received signal: %v. Initiating shutdown...", sig)

ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

s.server.SetKeepAlivesEnabled(false)
if err := s.server.Shutdown(ctx); err != nil {
s.logger.Fatalf("Server forced to shutdown: %v", err)
}

s.logger.Println("Server exiting")
}

func main() {
logger := log.New(os.Stdout, "", log.LstdFlags)

cfg, err := config.Load()
if err != nil {
logger.Fatalf("Failed to load configuration: %v", err)
}

server := NewServer(cfg, logger, exporter.UpdateMetrics)
go server.updateMetrics(cfg)

server.SetupRouter()
server.Start()
server.GracefulShutdown()
}
204 changes: 204 additions & 0 deletions cmd/prometheus-cve-exporter/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
package main

import (
"context"
"fmt"
"log"
"net/http"
"net/http/httptest"
"os"
"sync"
"syscall"
"testing"
"time"

"zops.top/prometheus-cve-exporter/config"
)

func mockUpdateMetrics(*config.Config) {}

func TestNewServer(t *testing.T) {
cfg := &config.Config{}
logger := log.New(os.Stdout, "", log.LstdFlags)
server := NewServer(cfg, logger, mockUpdateMetrics)

if server.cfg != cfg {
t.Errorf("Expected cfg to be %v, got %v", cfg, server.cfg)
}
if server.logger != logger {
t.Errorf("Expected logger to be %v, got %v", logger, server.logger)
}
if server.mux == nil {
t.Error("Expected mux to be initialized")
}
if server.updateMetrics == nil {
t.Error("Expected updateMetrics to be initialized")
}
}

func TestSetupRouter(t *testing.T) {
server := NewServer(&config.Config{}, log.New(os.Stdout, "", log.LstdFlags), mockUpdateMetrics)
server.SetupRouter()

testCases := []struct {
path string
expectedCode int
}{
{"/metrics", http.StatusOK},
{"/", http.StatusOK},
{"/nonexistent", http.StatusNotFound},
}

for _, tc := range testCases {
req, err := http.NewRequest("GET", tc.path, nil)
if err != nil {
t.Fatalf("Could not create request: %v", err)
}

rr := httptest.NewRecorder()
server.mux.ServeHTTP(rr, req)

if rr.Code != tc.expectedCode {
t.Errorf("handler returned wrong status code for %s: got %v want %v",
tc.path, rr.Code, tc.expectedCode)
}
}
}

func TestHomeHandler(t *testing.T) {
server := NewServer(&config.Config{}, log.New(os.Stdout, "", log.LstdFlags), mockUpdateMetrics)

req, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal(err)
}

rr := httptest.NewRecorder()
handler := http.HandlerFunc(server.homeHandler)
handler.ServeHTTP(rr, req)

if status := rr.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}

expected := `<a href="/metrics">Go to metrics</a>`
if rr.Body.String() != expected {
t.Errorf("handler returned unexpected body: got %v want %v",
rr.Body.String(), expected)
}
}

func TestStart(t *testing.T) {
cfg := &config.Config{Port: 20000}
logger := log.New(os.Stdout, "", log.LstdFlags)
server := NewServer(cfg, logger, mockUpdateMetrics)
server.SetupRouter()

var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
server.Start()
}()

// Give some time for the server to start
time.Sleep(100 * time.Millisecond)

resp, err := http.Get(fmt.Sprintf("http://localhost:%d", cfg.Port))
if err != nil {
t.Fatalf("Could not send GET request: %v", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status OK; got %v", resp.Status)
}

// Shutdown the server
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
server.server.Shutdown(ctx)

Check failure on line 121 in cmd/prometheus-cve-exporter/main_test.go

View workflow job for this annotation

GitHub Actions / test

Error return value of `server.server.Shutdown` is not checked (errcheck)

Check failure on line 121 in cmd/prometheus-cve-exporter/main_test.go

View workflow job for this annotation

GitHub Actions / test

Error return value of `server.server.Shutdown` is not checked (errcheck)

wg.Wait()
}

func TestUpdateMetricsExecution(t *testing.T) {
updateMetricsCalled := false
mockUpdateMetrics := func(*config.Config) {
updateMetricsCalled = true
}

cfg := &config.Config{}
logger := log.New(os.Stdout, "", log.LstdFlags)
server := NewServer(cfg, logger, mockUpdateMetrics)

server.updateMetrics(cfg)

if !updateMetricsCalled {
t.Error("Expected UpdateMetrics to be called")
}
}

func TestMainIntegration(t *testing.T) {
// Backup original os.Args
oldArgs := os.Args
defer func() { os.Args = oldArgs }()

// Set up a test config file
testConfigPath := "test_config.json"
testPort := 20001
testConfigContent := []byte(fmt.Sprintf(`{
"nvd_feed_url": "https://test.nvd.feed.url",
"update_interval": "2h",
"port": %d,
"severity": ["HIGH", "CRITICAL"],
"package_file": "",
"use_tls": false
}`, testPort))
err := os.WriteFile(testConfigPath, testConfigContent, 0644)
if err != nil {
t.Fatalf("Failed to create test config file: %v", err)
}
defer os.Remove(testConfigPath)

// Set the command-line argument to use our test config
os.Args = []string{"cmd", "-config", testConfigPath}

// Run main in a goroutine
go func() {
main()
}()

// Give some time for the server to start
time.Sleep(100 * time.Millisecond)

// Test if the server is running
resp, err := http.Get(fmt.Sprintf("http://localhost:%d", testPort))
if err != nil {
t.Fatalf("Could not send GET request: %v", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status OK; got %v", resp.Status)
}

// Send shutdown signal
syscall.Kill(syscall.Getpid(), syscall.SIGINT)

Check failure on line 188 in cmd/prometheus-cve-exporter/main_test.go

View workflow job for this annotation

GitHub Actions / test

Error return value of `syscall.Kill` is not checked (errcheck)

Check failure on line 188 in cmd/prometheus-cve-exporter/main_test.go

View workflow job for this annotation

GitHub Actions / test

Error return value of `syscall.Kill` is not checked (errcheck)

// Give some time for the server to shut down
time.Sleep(500 * time.Millisecond)

// Verify that the server has shut down
_, err = http.Get(fmt.Sprintf("http://localhost:%d", testPort))
if err == nil {
t.Error("Expected an error when connecting to a shutdown server")
}
}

func TestMain(m *testing.M) {
// Run tests
code := m.Run()
os.Exit(code)
}
18 changes: 13 additions & 5 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ func Load() (*Config, error) {
return nil, err
}

if cfg.TLSCert != defaultTLSCert && cfg.TLSKey != defaultTLSKey {
cfg.UseTLS = true
}

fmt.Print(prettyfyCfg(cfg))
return cfg, nil
}
Expand All @@ -83,7 +87,6 @@ func parseFlags(cfg *Config) string {
flag.DurationVar(&cfg.UpdateInterval, "update-interval", defaultUpdateInterval, "Update interval duration")
flag.IntVar(&cfg.Port, "port", defaultPort, "Port to run the server on")
flag.StringVar(&cfg.PackageFile, "package-file", defaultPackageFile, "Path to file containing packages and versions")
flag.BoolVar(&cfg.UseTLS, "use-tls", defaultUseTLS, "Use TLS for the server")
flag.StringVar(&cfg.TLSCert, "tls-cert", defaultTLSCert, "Path to TLS certificate file")
flag.StringVar(&cfg.TLSKey, "tls-key", defaultTLSKey, "Path to TLS key file")

Expand Down Expand Up @@ -160,17 +163,22 @@ func validateConfig(cfg *Config) error {
return fmt.Errorf("the file %s does not exist", cfg.PackageFile)
}
}
if cfg.UseTLS {
if cfg.TLSCert == "" || cfg.TLSKey == "" {
return fmt.Errorf("TLS is enabled but certificate or key file is not specified")
}
tlsFilesProvided := 0
if cfg.TLSCert != defaultTLSCert {
tlsFilesProvided += 1
if _, err := os.Stat(cfg.TLSCert); os.IsNotExist(err) {
return fmt.Errorf("the TLS certificate file %s does not exist", cfg.TLSCert)
}
}
if cfg.TLSKey != defaultTLSKey {
tlsFilesProvided += 1
if _, err := os.Stat(cfg.TLSKey); os.IsNotExist(err) {
return fmt.Errorf("the TLS key file %s does not exist", cfg.TLSKey)
}
}
if tlsFilesProvided != 2 {
return fmt.Errorf("you must provide both TLSKey and TLSCert - provided %d file(s)", tlsFilesProvided)
}
return nil
}

Expand Down

0 comments on commit 529b31c

Please sign in to comment.