diff --git a/README.md b/README.md index f96ac2f..e13a0d5 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,9 @@ This is a generic middleware to rate-limit HTTP requests. **v7.x.x:** Replaced `time/rate` with `embedded time/rate` so that we can support more rate limit headers. +**v8.x.x:** Address `RemoteIP` vulnerability concern by replacing it with `RemoteIPFromIPLookup`, an explicit way to pick the IP address. + + ## Five Minute Tutorial ```go @@ -34,6 +37,7 @@ import ( "net/http" "github.com/didip/tollbooth/v7" + "github.com/didip/tollbooth/v7/limiter" ) func HelloHandler(w http.ResponseWriter, req *http.Request) { @@ -42,7 +46,15 @@ func HelloHandler(w http.ResponseWriter, req *http.Request) { func main() { // Create a request limiter per handler. - http.Handle("/", tollbooth.LimitFuncHandler(tollbooth.NewLimiter(1, nil), HelloHandler)) + lmt := tollbooth.NewLimiter(1, nil) + + // New in version >= 8, you must explicitly define how to pick the IP address. + lmt.SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + IndexFromRight: 0, + }) + + http.Handle("/", tollbooth.LimitFuncHandler(lmt, HelloHandler)) http.ListenAndServe(":12345", nil) } ``` @@ -66,10 +78,24 @@ func main() { // every token bucket in it will expire 1 hour after it was initially set. lmt = tollbooth.NewLimiter(1, &limiter.ExpirableOptions{DefaultExpirationTTL: time.Hour}) - // Configure list of places to look for IP address. - // By default it's: "RemoteAddr", "X-Forwarded-For", "X-Real-IP" - // If your application is behind a proxy, set "X-Forwarded-For" first. - lmt.SetIPLookups([]string{"RemoteAddr", "X-Forwarded-For", "X-Real-IP"}) + // New in version >= 8, you must explicitly define how to pick the IP address. + // If IP address cannot be found, rate limiter will not be activated. + lmt.SetIPLookup(limiter.IPLookup{ + // The name of lookup method. + // Possible options are: RemoteAddr, X-Forwarded-For, X-Real-IP, CF-Connecting-IP + // All other headers are considered unknown and will be ignored. + Name: "X-Real-IP", + + // The index position to pick the ip address from a comma separated list. + // The index goes from right to left. + // + // When there are multiple of the same headers, + // we will concat them together in the order of first to last seen. + // And then we pick the IP using this index position. + IndexFromRight: 0, + }) + + // In version >= 8, lmt.SetIPLookups and lmt.GetIPLookups are removed. // Limit only GET and POST requests. lmt.SetMethods([]string{"GET", "POST"}) @@ -89,8 +115,7 @@ func main() { lmt.RemoveHeaderEntries("X-Access-Token", []string{"limitless-token"}) // By the way, the setters are chainable. Example: - lmt.SetIPLookups([]string{"RemoteAddr", "X-Forwarded-For", "X-Real-IP"}). - SetMethods([]string{"GET", "POST"}). + lmt.SetMethods([]string{"GET", "POST"}). SetBasicAuthUsers([]string{"sansa"}). SetBasicAuthUsers([]string{"tyrion"}) ``` @@ -137,6 +162,12 @@ func main() { ```go lmt := tollbooth.NewLimiter(1, nil) + // New in version >= 8, you must explicitly define how to pick the IP address. + lmt.SetIPLookup(limiter.IPLookup{ + Name: "X-Forwarded-For", + IndexFromRight: 0, + }) + // Set a custom message. lmt.SetMessage("You have reached maximum request limit.") diff --git a/libstring/libstring.go b/libstring/libstring.go index 730b654..0e6b334 100644 --- a/libstring/libstring.go +++ b/libstring/libstring.go @@ -5,6 +5,8 @@ import ( "net" "net/http" "strings" + + "github.com/didip/tollbooth/v7/limiter" ) // StringInSlice finds needle in a slice of strings. @@ -17,38 +19,35 @@ func StringInSlice(sliceString []string, needle string) bool { return false } -// RemoteIP finds IP Address given http.Request struct. -func RemoteIP(ipLookups []string, forwardedForIndexFromBehind int, r *http.Request) string { - realIP := r.Header.Get("X-Real-IP") - forwardedFor := r.Header.Get("X-Forwarded-For") - - for _, lookup := range ipLookups { - if lookup == "RemoteAddr" { - // 1. Cover the basic use cases for both ipv4 and ipv6 - ip, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - // 2. Upon error, just return the remote addr. - return r.RemoteAddr - } - return ip +// RemoteIPFromIPLookup picks an ip address explicitly from limiter.IPLookup criteria. +// This function is intended to replace RemoteIP function. +func RemoteIPFromIPLookup(ipLookup limiter.IPLookup, r *http.Request) string { + switch ipLookup.Name { + case "RemoteAddr": + // 1. Cover the basic use cases for both ipv4 and ipv6 + ip, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + // 2. Upon error, just return the remote addr. + return r.RemoteAddr } - if lookup == "X-Forwarded-For" && forwardedFor != "" { - // X-Forwarded-For is potentially a list of addresses separated with "," - parts := strings.Split(forwardedFor, ",") - for i, p := range parts { - parts[i] = strings.TrimSpace(p) - } + return ip - partIndex := len(parts) - 1 - forwardedForIndexFromBehind - if partIndex < 0 { - partIndex = 0 - } + case "X-Forwarded-For", "X-Real-IP", "CF-Connecting-IP": + ipAddrListCommaSeparated := r.Header.Values(ipLookup.Name) - return parts[partIndex] + ipAddrCommaSeparated := strings.Join(ipAddrListCommaSeparated, ",") + + ips := strings.Split(ipAddrCommaSeparated, ",") + for i, p := range ips { + ips[i] = strings.TrimSpace(p) } - if lookup == "X-Real-IP" && realIP != "" { - return realIP + + ipIndex := len(ips) - 1 - ipLookup.IndexFromRight + if ipIndex < 0 { + ipIndex = 0 } + + return ips[ipIndex] } return "" diff --git a/libstring/libstring_test.go b/libstring/libstring_test.go index 08abe3c..aed70ce 100644 --- a/libstring/libstring_test.go +++ b/libstring/libstring_test.go @@ -4,6 +4,8 @@ import ( "net/http" "strings" "testing" + + "github.com/didip/tollbooth/v7/limiter" ) func TestStringInSlice(t *testing.T) { @@ -12,28 +14,7 @@ func TestStringInSlice(t *testing.T) { } } -func TestRemoteIPDefault(t *testing.T) { - ipLookups := []string{"RemoteAddr", "X-Real-IP"} - ipv6 := "2601:7:1c82:4097:59a0:a80b:2841:b8c8" - - request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) - if err != nil { - t.Errorf("Unable to create new HTTP request. Error: %v", err) - } - - request.Header.Set("X-Real-IP", ipv6) - - ip := RemoteIP(ipLookups, 0, request) - if ip != request.RemoteAddr { - t.Errorf("Did not get the right IP. IP: %v", ip) - } - if ip == ipv6 { - t.Errorf("X-Real-IP should have been skipped. IP: %v", ip) - } -} - func TestRemoteIPForwardedFor(t *testing.T) { - ipLookups := []string{"X-Forwarded-For", "X-Real-IP", "RemoteAddr"} ipv6 := "2601:7:1c82:4097:59a0:a80b:2841:b8c8" request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) @@ -44,7 +25,11 @@ func TestRemoteIPForwardedFor(t *testing.T) { request.Header.Set("X-Forwarded-For", "10.10.10.10") request.Header.Set("X-Real-IP", ipv6) - ip := RemoteIP(ipLookups, 0, request) + ip := RemoteIPFromIPLookup(limiter.IPLookup{ + Name: "X-Forwarded-For", + IndexFromRight: 0, + }, request) + if ip != "10.10.10.10" { t.Errorf("Did not get the right IP. IP: %v", ip) } @@ -54,7 +39,6 @@ func TestRemoteIPForwardedFor(t *testing.T) { } func TestRemoteIPRealIP(t *testing.T) { - ipLookups := []string{"X-Real-IP", "X-Forwarded-For", "RemoteAddr"} ipv6 := "2601:7:1c82:4097:59a0:a80b:2841:b8c8" request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) @@ -65,7 +49,11 @@ func TestRemoteIPRealIP(t *testing.T) { request.Header.Set("X-Forwarded-For", "10.10.10.10") request.Header.Set("X-Real-IP", ipv6) - ip := RemoteIP(ipLookups, 0, request) + ip := RemoteIPFromIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + IndexFromRight: 0, + }, request) + if ip != ipv6 { t.Errorf("Did not get the right IP. IP: %v", ip) } @@ -74,53 +62,64 @@ func TestRemoteIPRealIP(t *testing.T) { } } -func TestRemoteIPMultipleForwardedFor(t *testing.T) { - ipLookups := []string{"X-Forwarded-For", "X-Real-IP", "RemoteAddr"} - ipv6 := "2601:7:1c82:4097:59a0:a80b:2841:b8c8" - +func TestRemoteIPMultipleForwardedForIPAddresses(t *testing.T) { request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) if err != nil { t.Errorf("Unable to create new HTTP request. Error: %v", err) } - request.Header.Set("X-Real-IP", ipv6) - - // Missing X-Forwarded-For should not break things - ip := RemoteIP(ipLookups, 0, request) - if ip != ipv6 { - t.Errorf("X-Real-IP should have been chosen because X-Forwarded-For is missing. IP: %v", ip) - } - request.Header.Set("X-Forwarded-For", "10.10.10.10,10.10.10.11") + ip := RemoteIPFromIPLookup(limiter.IPLookup{ + Name: "X-Forwarded-For", + IndexFromRight: 0, + }, request) + // Should get the last one - ip = RemoteIP(ipLookups, 0, request) if ip != "10.10.10.11" { t.Errorf("Did not get the right IP. IP: %v", ip) } - if ip == ipv6 { - t.Errorf("X-Real-IP should have been skipped. IP: %v", ip) - } + + ip = RemoteIPFromIPLookup(limiter.IPLookup{ + Name: "X-Forwarded-For", + IndexFromRight: 1, + }, request) // Should get the 2nd from last - ip = RemoteIP(ipLookups, 1, request) if ip != "10.10.10.10" { t.Errorf("Did not get the right IP. IP: %v", ip) } - if ip == ipv6 { - t.Errorf("X-Real-IP should have been skipped. IP: %v", ip) - } // What about index out of bound? RemoteIP should simply choose index 0. - ip = RemoteIP(ipLookups, 2, request) + ip = RemoteIPFromIPLookup(limiter.IPLookup{ + Name: "X-Forwarded-For", + IndexFromRight: 2, + }, request) + if ip != "10.10.10.10" { t.Errorf("Did not get the right IP. IP: %v", ip) } - if ip == ipv6 { - t.Errorf("X-Real-IP should have been skipped. IP: %v", ip) - } } +func TestRemoteIPMultipleForwardedForHeaders(t *testing.T) { + request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) + if err != nil { + t.Errorf("Unable to create new HTTP request. Error: %v", err) + } + + request.Header.Add("X-Forwarded-For", "8.8.8.8,8.8.4.4") + request.Header.Add("X-Forwarded-For", "10.10.10.10,10.10.10.11") + + ip := RemoteIPFromIPLookup(limiter.IPLookup{ + Name: "X-Forwarded-For", + IndexFromRight: 0, + }, request) + + // Should get the last header and the last IP + if ip != "10.10.10.11" { + t.Errorf("Did not get the right IP. IP: %v", ip) + } +} func TestCanonicalizeIP(t *testing.T) { tests := []struct { name string @@ -169,10 +168,12 @@ func TestCanonicalizeIP(t *testing.T) { }, } for _, tt := range tests { - tt := tt + ip := tt.ip + want := tt.want + t.Run(tt.name, func(t *testing.T) { - if got := CanonicalizeIP(tt.ip); got != tt.want { - t.Errorf("CanonicalizeIP() = %v, want %v", got, tt.want) + if got := CanonicalizeIP(ip); got != want { + t.Errorf("CanonicalizeIP() = %v, want %v", got, want) } }) } diff --git a/limiter/limiter.go b/limiter/limiter.go index 0153dc8..38636e6 100644 --- a/limiter/limiter.go +++ b/limiter/limiter.go @@ -19,7 +19,6 @@ func New(generalExpirableOptions *ExpirableOptions) *Limiter { SetMessage("You have reached maximum request limit."). SetStatusCode(429). SetOnLimitReached(nil). - SetIPLookups([]string{"RemoteAddr", "X-Forwarded-For", "X-Real-IP"}). SetForwardedForIndexFromBehind(0). SetHeaders(make(map[string][]string)). SetContextValues(make(map[string][]string)). @@ -43,6 +42,18 @@ func New(generalExpirableOptions *ExpirableOptions) *Limiter { return lmt } +// IPLookup is a config struct to define how users want to pick the remote IP address. +type IPLookup struct { + // The name of lookup method. + // Possible options are: RemoteAddr, X-Forwarded-For, X-Real-IP, CF-Connecting-IP + // All other headers are considered unknown and will be ignored. + Name string + + // The index position to pick the ip address from a comma separated list. + // The index goes from right to left. + IndexFromRight int +} + // Limiter is a config struct to limit a particular request handler. type Limiter struct { // Maximum number of requests to limit per second. @@ -66,10 +77,9 @@ type Limiter struct { // An option to write back what you want upon reaching a limit. overrideDefaultResponseWriter bool - // List of places to look up IP address. - // Default is "RemoteAddr", "X-Forwarded-For", "X-Real-IP". - // You can rearrange the order as you like. - ipLookups []string + // Explicitly define how to look up IP address. + // This is intended to replace ipLookups + explicitIPLookup IPLookup forwardedForIndex int @@ -270,10 +280,12 @@ func (l *Limiter) ExecOnLimitReached(w http.ResponseWriter, r *http.Request) { } // SetOverrideDefaultResponseWriter is a thread-safe way of setting the response writer override variable. -func (l *Limiter) SetOverrideDefaultResponseWriter(override bool) { +func (l *Limiter) SetOverrideDefaultResponseWriter(override bool) *Limiter { l.Lock() l.overrideDefaultResponseWriter = override l.Unlock() + + return l } // GetOverrideDefaultResponseWriter is a thread-safe way of getting the response writer override variable. @@ -283,20 +295,22 @@ func (l *Limiter) GetOverrideDefaultResponseWriter() bool { return l.overrideDefaultResponseWriter } -// SetIPLookups is thread-safe way of setting list of places to look up IP address. -func (l *Limiter) SetIPLookups(ipLookups []string) *Limiter { +// SetIPLookup is thread-safe way of setting an explicit way to look up IP address. +// This method is intended to replace SetIPLookups (version 6 or older). +func (l *Limiter) SetIPLookup(lookup IPLookup) *Limiter { l.Lock() - l.ipLookups = ipLookups + l.explicitIPLookup = lookup l.Unlock() return l } -// GetIPLookups is thread-safe way of getting list of places to look up IP address. -func (l *Limiter) GetIPLookups() []string { +// GetIPLookup is thread-safe way of getting an explicit way to look up IP address. +// This method is intended to replace the old GetIPLookups (version 6 or older). +func (l *Limiter) GetIPLookup() IPLookup { l.RLock() defer l.RUnlock() - return l.ipLookups + return l.explicitIPLookup } // SetIgnoreURL is thread-safe way of setting whenever ignore the URL on rate limit keys diff --git a/limiter/limiter_setter_getter_test.go b/limiter/limiter_setter_getter_test.go index a4bd41f..55276c1 100644 --- a/limiter/limiter_setter_getter_test.go +++ b/limiter/limiter_setter_getter_test.go @@ -43,19 +43,6 @@ func TestSetGetStatusCode(t *testing.T) { } } -func TestSetGetIPLookups(t *testing.T) { - lmt := New(nil).SetMax(1) - - // Check default - if len(lmt.GetIPLookups()) != 3 { - t.Errorf("IPLookups field is incorrect. Value: %v", lmt.GetIPLookups()) - } - - if lmt.SetIPLookups([]string{"X-Real-IP"}).GetIPLookups()[0] != "X-Real-IP" { - t.Errorf("IPLookups field is incorrect. Value: %v", lmt.GetIPLookups()) - } -} - func TestSetGetMethods(t *testing.T) { lmt := New(nil).SetMax(1) diff --git a/tollbooth.go b/tollbooth.go index 0dcf82e..684a24c 100644 --- a/tollbooth.go +++ b/tollbooth.go @@ -37,8 +37,7 @@ func setRateLimitResponseHeaders(lmt *limiter.Limiter, w http.ResponseWriter, to func NewLimiter(max float64, tbOptions *limiter.ExpirableOptions) *limiter.Limiter { return limiter.New(tbOptions). SetMax(max). - SetBurst(int(math.Max(1, max))). - SetIPLookups([]string{"X-Forwarded-For", "X-Real-IP", "RemoteAddr"}) + SetBurst(int(math.Max(1, max))) } // LimitByKeys keeps track number of request made by keys separated by pipe. @@ -63,7 +62,7 @@ func ShouldSkipLimiter(lmt *limiter.Limiter, r *http.Request) bool { // --------------------------------- // Filter by remote ip // If we are unable to find remoteIP, skip limiter - remoteIP := libstring.RemoteIP(lmt.GetIPLookups(), lmt.GetForwardedForIndexFromBehind(), r) + remoteIP := libstring.RemoteIPFromIPLookup(lmt.GetIPLookup(), r) remoteIP = libstring.CanonicalizeIP(remoteIP) if remoteIP == "" { return true @@ -195,7 +194,7 @@ func ShouldSkipLimiter(lmt *limiter.Limiter, r *http.Request) bool { // BuildKeys generates a slice of keys to rate-limit by given limiter and request structs. func BuildKeys(lmt *limiter.Limiter, r *http.Request) [][]string { - remoteIP := libstring.RemoteIP(lmt.GetIPLookups(), lmt.GetForwardedForIndexFromBehind(), r) + remoteIP := libstring.RemoteIPFromIPLookup(lmt.GetIPLookup(), r) remoteIP = libstring.CanonicalizeIP(remoteIP) path := r.URL.Path sliceKeys := make([][]string, 0) diff --git a/tollbooth_benchmark_test.go b/tollbooth_benchmark_test.go index a035d83..338156f 100644 --- a/tollbooth_benchmark_test.go +++ b/tollbooth_benchmark_test.go @@ -30,9 +30,12 @@ func BenchmarkLimitByKeysWithExpiringBuckets(b *testing.B) { func BenchmarkBuildKeys(b *testing.B) { lmt := limiter.New(nil).SetMax(1) // Only 1 request per second is allowed. - lmt.SetIPLookups([]string{"X-Real-IP", "RemoteAddr", "X-Forwarded-For"}) - lmt.SetHeaders(make(map[string][]string)) - lmt.SetHeader("X-Real-IP", []string{"2601:7:1c82:4097:59a0:a80b:2841:b8c8"}) + lmt.SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + IndexFromRight: 0, + }). + SetHeaders(make(map[string][]string)). + SetHeader("X-Real-IP", []string{"2601:7:1c82:4097:59a0:a80b:2841:b8c8"}) request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) if err != nil { diff --git a/tollbooth_bug_report_test.go b/tollbooth_bug_report_test.go index 9c323d8..b8b9fbb 100644 --- a/tollbooth_bug_report_test.go +++ b/tollbooth_bug_report_test.go @@ -59,7 +59,10 @@ Top: var issue66HeaderKey = "X-Customer-ID" func issue66RateLimiter(h http.HandlerFunc, customerIDs []string) (http.HandlerFunc, *limiter.Limiter) { - allocationLimiter := NewLimiter(1, nil).SetMethods([]string{"POST"}) + allocationLimiter := NewLimiter(1, nil).SetMethods([]string{"POST"}). + SetIPLookup(limiter.IPLookup{ + Name: "RemoteAddr", + }) handler := func(w http.ResponseWriter, r *http.Request) { allocationLimiter.SetHeader(issue66HeaderKey, customerIDs) @@ -135,8 +138,7 @@ Expected to receive: %v status code. Got: %v`, func Test_Issue91_BrokenSetMethod_DontBlockGet(t *testing.T) { requestsPerSecond := float64(1) - lmt := NewLimiter(requestsPerSecond, nil) - lmt.SetMethods([]string{"POST"}) + lmt := NewLimiter(requestsPerSecond, nil).SetMethods([]string{"POST"}) methods := lmt.GetMethods() if methods[0] != "POST" { @@ -170,8 +172,10 @@ func Test_Issue91_BrokenSetMethod_DontBlockGet(t *testing.T) { func Test_Issue91_BrokenSetMethod_BlockPost(t *testing.T) { requestsPerSecond := float64(1) - lmt := NewLimiter(requestsPerSecond, nil) - lmt.SetMethods([]string{"POST"}) + lmt := NewLimiter(requestsPerSecond, nil).SetMethods([]string{"POST"}). + SetIPLookup(limiter.IPLookup{ + Name: "RemoteAddr", + }) limitReachedCounter := 0 lmt.SetOnLimitReached(func(http.ResponseWriter, *http.Request) { diff --git a/tollbooth_test.go b/tollbooth_test.go index 9d3f69c..e5865df 100644 --- a/tollbooth_test.go +++ b/tollbooth_test.go @@ -33,7 +33,10 @@ func TestLimitByKeys(t *testing.T) { } func TestDefaultBuildKeys(t *testing.T) { - lmt := NewLimiter(1, nil) + lmt := NewLimiter(1, nil).SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + IndexFromRight: 0, + }) request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) if err != nil { @@ -59,8 +62,12 @@ func TestDefaultBuildKeys(t *testing.T) { } func TestIgnoreURLBuildKeys(t *testing.T) { - lmt := NewLimiter(1, nil) - lmt.SetIgnoreURL(true) + lmt := NewLimiter(1, nil). + SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + IndexFromRight: 0, + }). + SetIgnoreURL(true) request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) if err != nil { @@ -79,8 +86,12 @@ func TestIgnoreURLBuildKeys(t *testing.T) { } func TestBasicAuthBuildKeys(t *testing.T) { - lmt := NewLimiter(1, nil) - lmt.SetBasicAuthUsers([]string{"bro"}) + lmt := NewLimiter(1, nil). + SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + IndexFromRight: 0, + }). + SetBasicAuthUsers([]string{"bro"}) request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) if err != nil { @@ -109,8 +120,12 @@ func TestBasicAuthBuildKeys(t *testing.T) { } func TestCustomHeadersBuildKeys(t *testing.T) { - lmt := NewLimiter(1, nil) - lmt.SetHeader("X-Auth-Token", []string{"totally-top-secret", "another-secret"}) + lmt := NewLimiter(1, nil). + SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + IndexFromRight: 0, + }). + SetHeader("X-Auth-Token", []string{"totally-top-secret", "another-secret"}) request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) if err != nil { @@ -138,8 +153,12 @@ func TestCustomHeadersBuildKeys(t *testing.T) { } func TestRequestMethodBuildKeys(t *testing.T) { - lmt := NewLimiter(1, nil) - lmt.SetMethods([]string{"GET"}) + lmt := NewLimiter(1, nil). + SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + IndexFromRight: 0, + }). + SetMethods([]string{"GET"}) request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) if err != nil { @@ -166,8 +185,12 @@ func TestRequestMethodBuildKeys(t *testing.T) { } func TestContextValueBuildKeys(t *testing.T) { - lmt := NewLimiter(1, nil) - lmt.SetContextValue("API-access-level", []string{"basic"}) + lmt := NewLimiter(1, nil). + SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + IndexFromRight: 0, + }). + SetContextValue("API-access-level", []string{"basic"}) request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) if err != nil { @@ -196,9 +219,13 @@ func TestContextValueBuildKeys(t *testing.T) { } func TestRequestMethodAndCustomHeadersBuildKeys(t *testing.T) { - lmt := NewLimiter(1, nil) - lmt.SetMethods([]string{"GET"}) - lmt.SetHeader("X-Auth-Token", []string{"totally-top-secret", "another-secret"}) + lmt := NewLimiter(1, nil). + SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + IndexFromRight: 0, + }). + SetMethods([]string{"GET"}). + SetHeader("X-Auth-Token", []string{"totally-top-secret", "another-secret"}) request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) if err != nil { @@ -228,9 +255,13 @@ func TestRequestMethodAndCustomHeadersBuildKeys(t *testing.T) { } func TestRequestMethodAndBasicAuthUsersBuildKeys(t *testing.T) { - lmt := NewLimiter(1, nil) - lmt.SetMethods([]string{"GET"}) - lmt.SetBasicAuthUsers([]string{"bro"}) + lmt := NewLimiter(1, nil). + SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + IndexFromRight: 0, + }). + SetMethods([]string{"GET"}). + SetBasicAuthUsers([]string{"bro"}) request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) if err != nil { @@ -258,10 +289,14 @@ func TestRequestMethodAndBasicAuthUsersBuildKeys(t *testing.T) { } func TestRequestMethodCustomHeadersAndBasicAuthUsersBuildKeys(t *testing.T) { - lmt := NewLimiter(1, nil) - lmt.SetMethods([]string{"GET"}) - lmt.SetHeader("X-Auth-Token", []string{"totally-top-secret", "another-secret"}) - lmt.SetBasicAuthUsers([]string{"bro"}) + lmt := NewLimiter(1, nil). + SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + IndexFromRight: 0, + }). + SetMethods([]string{"GET"}). + SetHeader("X-Auth-Token", []string{"totally-top-secret", "another-secret"}). + SetBasicAuthUsers([]string{"bro"}) request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) if err != nil { @@ -293,11 +328,15 @@ func TestRequestMethodCustomHeadersAndBasicAuthUsersBuildKeys(t *testing.T) { } func TestRequestMethodCustomHeadersAndBasicAuthUsersAndContextValuesBuildKeys(t *testing.T) { - lmt := NewLimiter(1, nil) - lmt.SetMethods([]string{"GET"}) - lmt.SetHeader("X-Auth-Token", []string{"totally-top-secret", "another-secret"}) - lmt.SetContextValue("API-access-level", []string{"basic"}) - lmt.SetBasicAuthUsers([]string{"bro"}) + lmt := NewLimiter(1, nil). + SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + IndexFromRight: 0, + }). + SetMethods([]string{"GET"}). + SetHeader("X-Auth-Token", []string{"totally-top-secret", "another-secret"}). + SetContextValue("API-access-level", []string{"basic"}). + SetBasicAuthUsers([]string{"bro"}) request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) if err != nil { @@ -332,9 +371,12 @@ func TestRequestMethodCustomHeadersAndBasicAuthUsersAndContextValuesBuildKeys(t } func TestLimitHandler(t *testing.T) { - lmt := limiter.New(nil).SetMax(1).SetBurst(1) - lmt.SetIPLookups([]string{"X-Real-IP", "RemoteAddr", "X-Forwarded-For"}) - lmt.SetMethods([]string{"POST"}) + lmt := limiter.New(nil).SetMax(1).SetBurst(1). + SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + IndexFromRight: 0, + }). + SetMethods([]string{"POST"}) counter := 0 lmt.SetOnLimitReached(func(http.ResponseWriter, *http.Request) { counter++ }) @@ -405,10 +447,13 @@ func TestLimitHandler(t *testing.T) { } func TestOverrideForResponseWriter(t *testing.T) { - lmt := limiter.New(nil).SetMax(1).SetBurst(1) - lmt.SetIPLookups([]string{"X-Real-IP", "RemoteAddr", "X-Forwarded-For"}) - lmt.SetMethods([]string{"POST"}) - lmt.SetOverrideDefaultResponseWriter(true) + lmt := limiter.New(nil).SetMax(1).SetBurst(1). + SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + IndexFromRight: 0, + }). + SetMethods([]string{"POST"}). + SetOverrideDefaultResponseWriter(true) counter := 0 lmt.SetOnLimitReached(func(w http.ResponseWriter, _ *http.Request) { @@ -519,7 +564,10 @@ func (lm *LockMap) Add(key string, incr int64) { func TestLimitHandlerEmptyHeader(t *testing.T) { lmt := limiter.New(nil).SetMax(1).SetBurst(1) - lmt.SetIPLookups([]string{"X-Real-IP", "RemoteAddr", "X-Forwarded-For"}) + lmt.SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + IndexFromRight: 0, + }) lmt.SetMethods([]string{"POST"}) lmt.SetHeader("user_id", []string{})