Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Add DoRedirects, DoTimeout and DoDeadline to Proxy middleware #2332

Merged
merged 6 commits into from
Feb 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions middleware/proxy/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ func Balancer(config Config) fiber.Handler
func Forward(addr string, clients ...*fasthttp.Client) fiber.Handler
// Do performs the given http request and fills the given http response.
func Do(c *fiber.Ctx, addr string, clients ...*fasthttp.Client) error
// DoRedirects performs the given http request and fills the given http response while following up to maxRedirectsCount redirects.
func DoRedirects(c *fiber.Ctx, addr string, maxRedirectsCount int, clients ...*fasthttp.Client) error
// DoDeadline performs the given request and waits for response until the given deadline.
func DoDeadline(c *fiber.Ctx, addr string, deadline time.Time, clients ...*fasthttp.Client) error
// DoTimeout performs the given request and waits for response during the given timeout duration.
func DoTimeout(c *fiber.Ctx, addr string, timeout time.Duration, clients ...*fasthttp.Client) error
// DomainForward the given http request based on the given domain and fills the given http response
func DomainForward(hostname string, addr string, clients ...*fasthttp.Client) fiber.Handler
// BalancerForward performs the given http request based round robin balancer and fills the given http response
Expand Down Expand Up @@ -73,6 +79,36 @@ app.Get("/:id", func(c *fiber.Ctx) error {
return nil
})

// Make proxy requests while following redirects
app.Get("/proxy", func(c *fiber.Ctx) error {
if err := proxy.DoRedirects(c, "http://google.com", 3); err != nil {
return err
}
// Remove Server header from response
c.Response().Header.Del(fiber.HeaderServer)
return nil
})

// Make proxy requests and wait up to 5 seconds before timing out
app.Get("/proxy", func(c *fiber.Ctx) error {
if err := proxy.DoTimeout(c, "http://localhost:3000", time.Second * 5); err != nil {
return err
}
// Remove Server header from response
c.Response().Header.Del(fiber.HeaderServer)
return nil
})

// Make proxy requests, timeout a minute from now
app.Get("/proxy", func(c *fiber.Ctx) error {
if err := DoDeadline(c, "http://localhost", time.Now().Add(time.Minute)); err != nil {
return err
}
// Remove Server header from response
c.Response().Header.Del(fiber.HeaderServer)
return nil
})

// Minimal round robin balancer
app.Use(proxy.Balancer(proxy.Config{
Servers: []string{
Expand Down
47 changes: 42 additions & 5 deletions middleware/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net/url"
"strings"
"sync"
"time"

"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
Expand Down Expand Up @@ -139,16 +140,53 @@ func Forward(addr string, clients ...*fasthttp.Client) fiber.Handler {
// Do performs the given http request and fills the given http response.
// This method can be used within a fiber.Handler
func Do(c *fiber.Ctx, addr string, clients ...*fasthttp.Client) error {
return doAction(c, addr, func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error {
return cli.Do(req, resp)
}, clients...)
}

// DoRedirects performs the given http request and fills the given http response, following up to maxRedirectsCount redirects.
// When the redirect count exceeds maxRedirectsCount, ErrTooManyRedirects is returned.
// This method can be used within a fiber.Handler
func DoRedirects(c *fiber.Ctx, addr string, maxRedirectsCount int, clients ...*fasthttp.Client) error {
return doAction(c, addr, func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error {
return cli.DoRedirects(req, resp, maxRedirectsCount)
}, clients...)
}

// DoDeadline performs the given request and waits for response until the given deadline.
// This method can be used within a fiber.Handler
func DoDeadline(c *fiber.Ctx, addr string, deadline time.Time, clients ...*fasthttp.Client) error {
return doAction(c, addr, func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error {
return cli.DoDeadline(req, resp, deadline)
}, clients...)
}

// DoTimeout performs the given request and waits for response during the given timeout duration.
// This method can be used within a fiber.Handler
func DoTimeout(c *fiber.Ctx, addr string, timeout time.Duration, clients ...*fasthttp.Client) error {
return doAction(c, addr, func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error {
return cli.DoTimeout(req, resp, timeout)
}, clients...)
}

func doAction(
c *fiber.Ctx,
addr string,
action func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error,
clients ...*fasthttp.Client,
) error {
var cli *fasthttp.Client

// set local or global client
if len(clients) != 0 {
// Set local client
cli = clients[0]
} else {
// Set global client
lock.RLock()
cli = client
lock.RUnlock()
}

req := c.Request()
res := c.Response()
originalURL := utils.CopyString(c.OriginalURL())
Expand All @@ -157,14 +195,13 @@ func Do(c *fiber.Ctx, addr string, clients ...*fasthttp.Client) error {
copiedURL := utils.CopyString(addr)
req.SetRequestURI(copiedURL)
// NOTE: if req.isTLS is true, SetRequestURI keeps the scheme as https.
// issue reference:
// https://github.com/gofiber/fiber/issues/1762
// Reference: https://github.com/gofiber/fiber/issues/1762
if scheme := getScheme(utils.UnsafeBytes(copiedURL)); len(scheme) > 0 {
req.URI().SetSchemeBytes(scheme)
}

req.Header.Del(fiber.HeaderConnection)
if err := cli.Do(req, res); err != nil {
if err := action(cli, req, res); err != nil {
return err
}
res.Header.Del(fiber.HeaderConnection)
Expand Down
179 changes: 168 additions & 11 deletions middleware/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package proxy

import (
"crypto/tls"
"errors"
"io"
"net"
"net/http/httptest"
Expand Down Expand Up @@ -48,6 +49,19 @@ func Test_Proxy_Empty_Upstream_Servers(t *testing.T) {
app.Use(Balancer(Config{Servers: []string{}}))
}

// go test -run Test_Proxy_Empty_Config
func Test_Proxy_Empty_Config(t *testing.T) {
t.Parallel()

defer func() {
if r := recover(); r != nil {
utils.AssertEqual(t, "Servers cannot be empty", r)
}
}()
app := fiber.New()
app.Use(New(Config{}))
}

// go test -run Test_Proxy_Next
func Test_Proxy_Next(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -345,24 +359,167 @@ func Test_Proxy_Buffer_Size_Response(t *testing.T) {
// go test -race -run Test_Proxy_Do_RestoreOriginalURL
func Test_Proxy_Do_RestoreOriginalURL(t *testing.T) {
t.Parallel()
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
return c.SendString("proxied")
})

app := fiber.New()
app.Get("/proxy", func(c *fiber.Ctx) error {
return c.SendString("ok")
app.Get("/test", func(c *fiber.Ctx) error {
return Do(c, "http://"+addr)
})
resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, nil, err1)
utils.AssertEqual(t, "/test", resp.Request.URL.String())
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "proxied", string(body))
}

// go test -race -run Test_Proxy_Do_WithRealURL
func Test_Proxy_Do_WithRealURL(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
originalURL := utils.CopyString(c.OriginalURL())
if err := Do(c, "/proxy"); err != nil {
return err
}
utils.AssertEqual(t, originalURL, c.OriginalURL())
return c.SendString("ok")
return Do(c, "https://www.google.com")
})

resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, nil, err1)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "/test", resp.Request.URL.String())
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, strings.Contains(string(body), "https://www.google.com/"))
}

// go test -race -run Test_Proxy_Do_WithRedirect
func Test_Proxy_Do_WithRedirect(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return Do(c, "https://google.com")
})

resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, nil, err1)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, strings.Contains(string(body), "https://www.google.com/"))
utils.AssertEqual(t, 301, resp.StatusCode)
}

// go test -race -run Test_Proxy_DoRedirects_RestoreOriginalURL
func Test_Proxy_DoRedirects_RestoreOriginalURL(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return DoRedirects(c, "http://google.com", 1)
})

resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, nil, err1)
_, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "/test", resp.Request.URL.String())
}

// go test -race -run Test_Proxy_DoRedirects_TooManyRedirects
func Test_Proxy_DoRedirects_TooManyRedirects(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return DoRedirects(c, "http://google.com", 0)
})

resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, nil, err1)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "too many redirects detected when doing the request", string(body))
utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
utils.AssertEqual(t, "/test", resp.Request.URL.String())
}

// go test -race -run Test_Proxy_DoTimeout_RestoreOriginalURL
func Test_Proxy_DoTimeout_RestoreOriginalURL(t *testing.T) {
t.Parallel()

_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
return c.SendString("proxied")
})

app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return DoTimeout(c, "http://"+addr, time.Second)
})

resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, nil, err1)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "proxied", string(body))
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "/test", resp.Request.URL.String())
}

// go test -race -run Test_Proxy_DoTimeout_Timeout
func Test_Proxy_DoTimeout_Timeout(t *testing.T) {
t.Parallel()

_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
time.Sleep(time.Second * 5)
return c.SendString("proxied")
})

app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return DoTimeout(c, "http://"+addr, time.Second)
})

_, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
// This test requires multiple requests due to zero allocation used in fiber
_, err2 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, errors.New("test: timeout error 1000ms"), err1)
}

// go test -race -run Test_Proxy_DoDeadline_RestoreOriginalURL
func Test_Proxy_DoDeadline_RestoreOriginalURL(t *testing.T) {
t.Parallel()

_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
return c.SendString("proxied")
})

app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return DoDeadline(c, "http://"+addr, time.Now().Add(time.Second))
})

resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, nil, err1)
utils.AssertEqual(t, nil, err2)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "proxied", string(body))
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "/test", resp.Request.URL.String())
}

// go test -race -run Test_Proxy_DoDeadline_PastDeadline
func Test_Proxy_DoDeadline_PastDeadline(t *testing.T) {
t.Parallel()

_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
time.Sleep(time.Second * 5)
return c.SendString("proxied")
})

app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return DoDeadline(c, "http://"+addr, time.Now().Add(time.Second))
})

_, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, errors.New("test: timeout error 1000ms"), err1)
}

// go test -race -run Test_Proxy_Do_HTTP_Prefix_URL
Expand Down