diff --git a/examples/http-server-using-redis/main_test.go b/examples/http-server-using-redis/main_test.go index 7f6528873..3da2a64f8 100644 --- a/examples/http-server-using-redis/main_test.go +++ b/examples/http-server-using-redis/main_test.go @@ -3,6 +3,7 @@ package main import ( "bytes" "context" + "fmt" "net/http" "testing" "time" @@ -22,6 +23,8 @@ import ( func TestHTTPServerUsingRedis(t *testing.T) { const host = "http://localhost:8000" + t.Setenv("METRICS_PORT", fmt.Sprint(2034)) + go main() time.Sleep(100 * time.Millisecond) // Giving some time to start the server @@ -55,6 +58,9 @@ func TestHTTPServerUsingRedis(t *testing.T) { } func TestRedisSetHandler(t *testing.T) { + t.Setenv("HTTP_PORT", "8085") + t.Setenv("METRICS_PORT", "2036") + a := gofr.New() logger := logging.NewLogger(logging.DEBUG) redisClient, mock := redismock.NewClientMock() @@ -78,6 +84,7 @@ func TestRedisSetHandler(t *testing.T) { } func TestRedisPipelineHandler(t *testing.T) { + t.Setenv("HTTP_PORT", "8086") a := gofr.New() logger := logging.NewLogger(logging.DEBUG) redisClient, mock := redismock.NewClientMock() diff --git a/examples/http-server/main_test.go b/examples/http-server/main_test.go index 8493f4717..79c582474 100644 --- a/examples/http-server/main_test.go +++ b/examples/http-server/main_test.go @@ -143,6 +143,9 @@ func TestIntegration_SimpleAPIServer_Health(t *testing.T) { } func TestRedisHandler(t *testing.T) { + t.Setenv("METRICS_PORT", "2036") + t.Setenv("HTTP_PORT", "8082") + a := gofr.New() logger := logging.NewLogger(logging.DEBUG) redisClient, mock := redismock.NewClientMock() diff --git a/examples/using-add-rest-handlers/main_test.go b/examples/using-add-rest-handlers/main_test.go index 179a9f530..d0d51701c 100644 --- a/examples/using-add-rest-handlers/main_test.go +++ b/examples/using-add-rest-handlers/main_test.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "fmt" "net/http" "testing" "time" @@ -12,6 +13,8 @@ import ( func TestIntegration_AddRESTHandlers(t *testing.T) { const host = "http://localhost:9090" + t.Setenv("METRICS_PORT", fmt.Sprint(2023)) + go main() time.Sleep(100 * time.Millisecond) // Giving some time to start the server diff --git a/examples/using-cron-jobs/main_test.go b/examples/using-cron-jobs/main_test.go index d64e0975f..8a38f1f20 100644 --- a/examples/using-cron-jobs/main_test.go +++ b/examples/using-cron-jobs/main_test.go @@ -1,6 +1,7 @@ package main import ( + "fmt" "testing" "time" @@ -8,6 +9,8 @@ import ( ) func Test_UserPurgeCron(t *testing.T) { + t.Setenv("METRICS_PORT", fmt.Sprint(2022)) + go main() time.Sleep(1100 * time.Millisecond) diff --git a/examples/using-custom-metrics/main_test.go b/examples/using-custom-metrics/main_test.go index 50864e862..3efc552ab 100644 --- a/examples/using-custom-metrics/main_test.go +++ b/examples/using-custom-metrics/main_test.go @@ -11,6 +11,9 @@ import ( func TestIntegration(t *testing.T) { const host = "http://localhost:9011" + t.Setenv("HTTP_PORT", "9011") + t.Setenv("METRICS_PORT", "2120") + go main() time.Sleep(100 * time.Millisecond) // Giving some time to start the server diff --git a/examples/using-publisher/main_test.go b/examples/using-publisher/main_test.go index 8f8e16bf1..f652f0446 100644 --- a/examples/using-publisher/main_test.go +++ b/examples/using-publisher/main_test.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "fmt" "net/http" "testing" "time" @@ -12,6 +13,8 @@ import ( func TestExamplePublisher(t *testing.T) { const host = "http://localhost:8100" + t.Setenv("METRICS_PORT", fmt.Sprint(2032)) + go main() time.Sleep(200 * time.Millisecond) diff --git a/examples/using-subscriber/main_test.go b/examples/using-subscriber/main_test.go index 51eb3fa09..5df68d385 100644 --- a/examples/using-subscriber/main_test.go +++ b/examples/using-subscriber/main_test.go @@ -40,6 +40,9 @@ func initializeTest(t *testing.T) { func TestExampleSubscriber(t *testing.T) { log := testutil.StdoutOutputForFunc(func() { + t.Setenv("HTTP_PORT", "8080") + t.Setenv("METRICS_PORT", "2031") + go main() time.Sleep(time.Second * 1) // Giving some time to start the server diff --git a/examples/using-web-socket/main_test.go b/examples/using-web-socket/main_test.go index 84f53d43b..5ce6c2851 100644 --- a/examples/using-web-socket/main_test.go +++ b/examples/using-web-socket/main_test.go @@ -11,6 +11,7 @@ import ( func Test_WebSocket_Success(t *testing.T) { wsURL := fmt.Sprintf("ws://%s/ws", "localhost:8001") + t.Setenv("METRICS_PORT", fmt.Sprint(2030)) go main() time.Sleep(100 * time.Millisecond) diff --git a/pkg/gofr/gofr.go b/pkg/gofr/gofr.go index 9cd270bb0..43c653d16 100644 --- a/pkg/gofr/gofr.go +++ b/pkg/gofr/gofr.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "net" "net/http" "os" "os/signal" @@ -40,6 +41,7 @@ const ( shutDownTimeout = 30 * time.Second gofrTraceExporter = "gofr" gofrTracerURL = "https://tracer.gofr.dev" + checkPortTimeout = 2 * time.Second ) // App is the main application in the GoFr framework. @@ -77,6 +79,10 @@ func New() *App { port = defaultMetricPort } + if !isPortAvailable(port) { + app.container.Logger.Fatalf("metrics port %d is blocked or unreachable", port) + } + app.metricServer = newMetricServer(port) // HTTP Server @@ -94,16 +100,7 @@ func New() *App { app.add(http.MethodGet, "/.well-known/alive", liveHandler) app.add(http.MethodGet, "/favicon.ico", faviconHandler) - // If the openapi.json file exists in the static directory, set up routes for OpenAPI and Swagger documentation. - if _, err = os.Stat("./static/" + gofrHTTP.DefaultSwaggerFileName); err == nil { - // Route to serve the OpenAPI JSON specification file. - app.add(http.MethodGet, "/.well-known/"+gofrHTTP.DefaultSwaggerFileName, OpenAPIHandler) - // Route to serve the Swagger UI, providing a user interface for the API documentation. - app.add(http.MethodGet, "/.well-known/swagger", SwaggerUIHandler) - // Catchall route: any request to /.well-known/{name} (e.g., /.well-known/other) - // will be handled by the SwaggerUIHandler, serving the Swagger UI. - app.add(http.MethodGet, "/.well-known/{name}", SwaggerUIHandler) - } + app.checkAndAddOpenAPIDocumentation() if app.Config.Get("APP_ENV") == "DEBUG" { app.httpServer.RegisterProfilingRoutes() @@ -130,6 +127,19 @@ func New() *App { return app } +func (a *App) checkAndAddOpenAPIDocumentation() { + // If the openapi.json file exists in the static directory, set up routes for OpenAPI and Swagger documentation. + if _, err := os.Stat("./static/" + gofrHTTP.DefaultSwaggerFileName); err == nil { + // Route to serve the OpenAPI JSON specification file. + a.add(http.MethodGet, "/.well-known/"+gofrHTTP.DefaultSwaggerFileName, OpenAPIHandler) + // Route to serve the Swagger UI, providing a user interface for the API documentation. + a.add(http.MethodGet, "/.well-known/swagger", SwaggerUIHandler) + // Catchall route: any request to /.well-known/{name} (e.g., /.well-known/other) + // will be handled by the SwaggerUIHandler, serving the Swagger UI. + a.add(http.MethodGet, "/.well-known/{name}", SwaggerUIHandler) + } +} + // NewCMD creates a command-line application. func NewCMD() *App { app := &App{} @@ -244,6 +254,17 @@ func (a *App) Shutdown(ctx context.Context) error { return err } +func isPortAvailable(port int) bool { + conn, err := net.DialTimeout("tcp", fmt.Sprintf(":%d", port), checkPortTimeout) + if err != nil { + return true + } + + conn.Close() + + return false +} + func (a *App) httpServerSetup() { // TODO: find a way to read REQUEST_TIMEOUT config only once and log it there. currently doing it twice one for populating // the value and other for logging @@ -349,6 +370,10 @@ func (a *App) PATCH(pattern string, handler Handler) { } func (a *App) add(method, pattern string, h Handler) { + if !a.httpRegistered && !isPortAvailable(a.httpServer.port) { + a.container.Logger.Fatalf("http port %d is blocked or unreachable", a.httpServer.port) + } + a.httpRegistered = true reqTimeout, err := strconv.Atoi(a.Config.Get("REQUEST_TIMEOUT")) @@ -694,6 +719,10 @@ func contains(elems []string, v string) bool { // If `filePath` starts with "./", it will be interpreted as a relative path // to the current working directory. func (a *App) AddStaticFiles(endpoint, filePath string) { + if !a.httpRegistered && !isPortAvailable(a.httpServer.port) { + a.container.Logger.Fatalf("http port %d is blocked or unreachable", a.httpServer.port) + } + a.httpRegistered = true if !strings.HasPrefix(filePath, "./") && !filepath.IsAbs(filePath) { diff --git a/pkg/gofr/gofr_test.go b/pkg/gofr/gofr_test.go index ddd437d1d..6c33d613f 100644 --- a/pkg/gofr/gofr_test.go +++ b/pkg/gofr/gofr_test.go @@ -45,6 +45,34 @@ func TestGofr_readConfig(t *testing.T) { } } +func TestGoFr_isPortAvailable(t *testing.T) { + port := testutil.GetFreePort(t) + + tests := []struct { + name string + isAvailable bool + }{ + {"Port is available", true}, + {"Port is not available", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if !tt.isAvailable { + t.Setenv("HTTP_PORT", fmt.Sprint(port)) + + g := New() + + go g.Run() + time.Sleep(100 * time.Millisecond) + } + + isAvailable := isPortAvailable(port) + require.Equal(t, tt.isAvailable, isAvailable) + }) + } +} + func TestGofr_ServerRoutes(t *testing.T) { port := testutil.GetFreePort(t) diff --git a/pkg/gofr/grpc.go b/pkg/gofr/grpc.go index db685c83a..0e2b78064 100644 --- a/pkg/gofr/grpc.go +++ b/pkg/gofr/grpc.go @@ -66,6 +66,10 @@ var ( // RegisterService adds a gRPC service to the GoFr application. func (a *App) RegisterService(desc *grpc.ServiceDesc, impl any) { + if !a.grpcRegistered && !isPortAvailable(a.grpcServer.port) { + a.container.Logger.Fatalf("gRPC port %d is blocked or unreachable", a.grpcServer.port) + } + a.container.Logger.Infof("registering gRPC Server: %s", desc.ServiceName) a.grpcServer.server.RegisterService(desc, impl)