From a1e62aa7472066775dd54ee8596d29ea16bece14 Mon Sep 17 00:00:00 2001 From: Didip Kerabat Date: Thu, 3 Mar 2022 21:28:31 -0800 Subject: [PATCH 1/5] We received a vulnerability disclosure due to how we pick a remote IP address. Disclosure URL: https://gist.github.com/adam-p/4b777de4bda0027f4c3daa45618adcdc This is an attempt to address the situation. 1. We no longer configure SetIPLookups on default. 2. We address the two different SetIPLookups confusion in two different place by removing both of them. 3. We add a new, explicit way, for user to define how IP address should be picked up. Tests are all updated to use the new method of picking IP address. This will be a backward incompatible change so version number has to be bumped to 7. --- libstring/libstring.go | 62 +++++++------ libstring/libstring_test.go | 123 ++++++++++++++++---------- limiter/limiter.go | 38 ++++++-- limiter/limiter_setter_getter_test.go | 13 --- tollbooth.go | 7 +- tollbooth_benchmark_test.go | 10 ++- tollbooth_bug_report_test.go | 14 +-- tollbooth_test.go | 123 +++++++++++++++++++------- 8 files changed, 247 insertions(+), 143 deletions(-) diff --git a/libstring/libstring.go b/libstring/libstring.go index c9355e2..776f1cc 100644 --- a/libstring/libstring.go +++ b/libstring/libstring.go @@ -5,6 +5,8 @@ import ( "net" "net/http" "strings" + + "github.com/didip/tollbooth/v6/limiter" ) // StringInSlice finds needle in a slice of strings. @@ -17,38 +19,40 @@ func StringInSlice(sliceString []string, needle string) bool { return false } -// RemoteIP finds IP Address given http.Request struct. -func RemoteIP(ipLookups []string, forwardedForIndexFromBehind int, r *http.Request) string { - realIP := r.Header.Get("X-Real-IP") - forwardedFor := r.Header.Get("X-Forwarded-For") - - for _, lookup := range ipLookups { - if lookup == "RemoteAddr" { - // 1. Cover the basic use cases for both ipv4 and ipv6 - ip, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - // 2. Upon error, just return the remote addr. - return r.RemoteAddr - } - return ip +// RemoteIPFromIPLookup picks an ip address explicitly from limiter.IPLookup criteria. +// This function is intended to replace RemoteIP function. +func RemoteIPFromIPLookup(ipLookup limiter.IPLookup, r *http.Request) string { + switch ipLookup.Name { + case "RemoteAddr": + // 1. Cover the basic use cases for both ipv4 and ipv6 + ip, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + // 2. Upon error, just return the remote addr. + return r.RemoteAddr + } + return ip + + case "X-Forwarded-For", "X-Real-IP", "CF-Connecting-IP": + ipAddrListCommaSeparated := r.Header.Values(ipLookup.Name) + + headerIndex := len(ipAddrListCommaSeparated) - 1 - ipLookup.HeaderIndexFromRight + if headerIndex < 0 { + headerIndex = 0 } - if lookup == "X-Forwarded-For" && forwardedFor != "" { - // X-Forwarded-For is potentially a list of addresses separated with "," - parts := strings.Split(forwardedFor, ",") - for i, p := range parts { - parts[i] = strings.TrimSpace(p) - } - - partIndex := len(parts) - 1 - forwardedForIndexFromBehind - if partIndex < 0 { - partIndex = 0 - } - - return parts[partIndex] + + ipAddrCommaSeparated := ipAddrListCommaSeparated[headerIndex] + + ips := strings.Split(ipAddrCommaSeparated, ",") + for i, p := range ips { + ips[i] = strings.TrimSpace(p) } - if lookup == "X-Real-IP" && realIP != "" { - return realIP + + ipIndex := len(ips) - 1 - ipLookup.IndexFromRight + if ipIndex < 0 { + ipIndex = 0 } + + return ips[ipIndex] } return "" diff --git a/libstring/libstring_test.go b/libstring/libstring_test.go index 0bc8713..a078d27 100644 --- a/libstring/libstring_test.go +++ b/libstring/libstring_test.go @@ -4,6 +4,8 @@ import ( "net/http" "strings" "testing" + + "github.com/didip/tollbooth/v6/limiter" ) func TestStringInSlice(t *testing.T) { @@ -12,28 +14,7 @@ func TestStringInSlice(t *testing.T) { } } -func TestRemoteIPDefault(t *testing.T) { - ipLookups := []string{"RemoteAddr", "X-Real-IP"} - ipv6 := "2601:7:1c82:4097:59a0:a80b:2841:b8c8" - - request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) - if err != nil { - t.Errorf("Unable to create new HTTP request. Error: %v", err) - } - - request.Header.Set("X-Real-IP", ipv6) - - ip := RemoteIP(ipLookups, 0, request) - if ip != request.RemoteAddr { - t.Errorf("Did not get the right IP. IP: %v", ip) - } - if ip == ipv6 { - t.Errorf("X-Real-IP should have been skipped. IP: %v", ip) - } -} - func TestRemoteIPForwardedFor(t *testing.T) { - ipLookups := []string{"X-Forwarded-For", "X-Real-IP", "RemoteAddr"} ipv6 := "2601:7:1c82:4097:59a0:a80b:2841:b8c8" request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) @@ -44,7 +25,12 @@ func TestRemoteIPForwardedFor(t *testing.T) { request.Header.Set("X-Forwarded-For", "10.10.10.10") request.Header.Set("X-Real-IP", ipv6) - ip := RemoteIP(ipLookups, 0, request) + ip := RemoteIPFromIPLookup(limiter.IPLookup{ + Name: "X-Forwarded-For", + HeaderIndexFromRight: 0, + IndexFromRight: 0, + }, request) + if ip != "10.10.10.10" { t.Errorf("Did not get the right IP. IP: %v", ip) } @@ -54,7 +40,6 @@ func TestRemoteIPForwardedFor(t *testing.T) { } func TestRemoteIPRealIP(t *testing.T) { - ipLookups := []string{"X-Real-IP", "X-Forwarded-For", "RemoteAddr"} ipv6 := "2601:7:1c82:4097:59a0:a80b:2841:b8c8" request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) @@ -65,7 +50,12 @@ func TestRemoteIPRealIP(t *testing.T) { request.Header.Set("X-Forwarded-For", "10.10.10.10") request.Header.Set("X-Real-IP", ipv6) - ip := RemoteIP(ipLookups, 0, request) + ip := RemoteIPFromIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + HeaderIndexFromRight: 0, + IndexFromRight: 0, + }, request) + if ip != ipv6 { t.Errorf("Did not get the right IP. IP: %v", ip) } @@ -74,53 +64,90 @@ func TestRemoteIPRealIP(t *testing.T) { } } -func TestRemoteIPMultipleForwardedFor(t *testing.T) { - ipLookups := []string{"X-Forwarded-For", "X-Real-IP", "RemoteAddr"} - ipv6 := "2601:7:1c82:4097:59a0:a80b:2841:b8c8" - +func TestRemoteIPMultipleForwardedForIPAddresses(t *testing.T) { request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) if err != nil { t.Errorf("Unable to create new HTTP request. Error: %v", err) } - request.Header.Set("X-Real-IP", ipv6) - - // Missing X-Forwarded-For should not break things - ip := RemoteIP(ipLookups, 0, request) - if ip != ipv6 { - t.Errorf("X-Real-IP should have been chosen because X-Forwarded-For is missing. IP: %v", ip) - } - request.Header.Set("X-Forwarded-For", "10.10.10.10,10.10.10.11") + ip := RemoteIPFromIPLookup(limiter.IPLookup{ + Name: "X-Forwarded-For", + HeaderIndexFromRight: 0, + IndexFromRight: 0, + }, request) + // Should get the last one - ip = RemoteIP(ipLookups, 0, request) if ip != "10.10.10.11" { t.Errorf("Did not get the right IP. IP: %v", ip) } - if ip == ipv6 { - t.Errorf("X-Real-IP should have been skipped. IP: %v", ip) - } + + ip = RemoteIPFromIPLookup(limiter.IPLookup{ + Name: "X-Forwarded-For", + HeaderIndexFromRight: 0, + IndexFromRight: 1, + }, request) // Should get the 2nd from last - ip = RemoteIP(ipLookups, 1, request) if ip != "10.10.10.10" { t.Errorf("Did not get the right IP. IP: %v", ip) } - if ip == ipv6 { - t.Errorf("X-Real-IP should have been skipped. IP: %v", ip) - } // What about index out of bound? RemoteIP should simply choose index 0. - ip = RemoteIP(ipLookups, 2, request) + ip = RemoteIPFromIPLookup(limiter.IPLookup{ + Name: "X-Forwarded-For", + HeaderIndexFromRight: 0, + IndexFromRight: 2, + }, request) + if ip != "10.10.10.10" { t.Errorf("Did not get the right IP. IP: %v", ip) } - if ip == ipv6 { - t.Errorf("X-Real-IP should have been skipped. IP: %v", ip) - } } +func TestRemoteIPMultipleForwardedForHeaders(t *testing.T) { + request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) + if err != nil { + t.Errorf("Unable to create new HTTP request. Error: %v", err) + } + + request.Header.Add("X-Forwarded-For", "8.8.8.8,8.8.4.4") + request.Header.Add("X-Forwarded-For", "10.10.10.10,10.10.10.11") + + ip := RemoteIPFromIPLookup(limiter.IPLookup{ + Name: "X-Forwarded-For", + HeaderIndexFromRight: 0, + IndexFromRight: 0, + }, request) + + // Should get the last header and the last IP + if ip != "10.10.10.11" { + t.Errorf("Did not get the right IP. IP: %v", ip) + } + + ip = RemoteIPFromIPLookup(limiter.IPLookup{ + Name: "X-Forwarded-For", + HeaderIndexFromRight: 1, + IndexFromRight: 0, + }, request) + + // Should get the last IP from the first header + if ip != "8.8.4.4" { + t.Errorf("Did not get the right IP. IP: %v", ip) + } + + ip = RemoteIPFromIPLookup(limiter.IPLookup{ + Name: "X-Forwarded-For", + HeaderIndexFromRight: 1, + IndexFromRight: 1, + }, request) + + // Should get the first IP from the first header + if ip != "8.8.8.8" { + t.Errorf("Did not get the right IP. IP: %v", ip) + } +} func TestCanonicalizeIP(t *testing.T) { tests := []struct { name string diff --git a/limiter/limiter.go b/limiter/limiter.go index 4a7a08a..1b6ae93 100644 --- a/limiter/limiter.go +++ b/limiter/limiter.go @@ -18,7 +18,6 @@ func New(generalExpirableOptions *ExpirableOptions) *Limiter { SetMessage("You have reached maximum request limit."). SetStatusCode(429). SetOnLimitReached(nil). - SetIPLookups([]string{"RemoteAddr", "X-Forwarded-For", "X-Real-IP"}). SetForwardedForIndexFromBehind(0). SetHeaders(make(map[string][]string)). SetContextValues(make(map[string][]string)). @@ -42,6 +41,21 @@ func New(generalExpirableOptions *ExpirableOptions) *Limiter { return lmt } +// IPLookup is a config struct to define how users want to pick the remote IP address. +type IPLookup struct { + // The name of lookup method. + // Possible options are: RemoteAddr, X-Forwarded-For, X-Real-IP, CF-Connecting-IP + Name string + + // If there are multiple of the same header, this index determines which one to use. + // The index goes from right to left. + HeaderIndexFromRight int + + // The index position to pick the ip address from a comma separated list. + // The index goes from right to left. + IndexFromRight int +} + // Limiter is a config struct to limit a particular request handler. type Limiter struct { // Maximum number of requests to limit per second. @@ -70,6 +84,10 @@ type Limiter struct { // You can rearrange the order as you like. ipLookups []string + // Explicitly define how to look up IP address. + // This is intended to replace ipLookups + explicitIPLookup IPLookup + forwardedForIndex int // List of HTTP Methods to limit (GET, POST, PUT, etc.). @@ -269,10 +287,12 @@ func (l *Limiter) ExecOnLimitReached(w http.ResponseWriter, r *http.Request) { } // SetOverrideDefaultResponseWriter is a thread-safe way of setting the response writer override variable. -func (l *Limiter) SetOverrideDefaultResponseWriter(override bool) { +func (l *Limiter) SetOverrideDefaultResponseWriter(override bool) *Limiter { l.Lock() l.overrideDefaultResponseWriter = override l.Unlock() + + return l } // GetOverrideDefaultResponseWriter is a thread-safe way of getting the response writer override variable. @@ -282,20 +302,22 @@ func (l *Limiter) GetOverrideDefaultResponseWriter() bool { return l.overrideDefaultResponseWriter } -// SetIPLookups is thread-safe way of setting list of places to look up IP address. -func (l *Limiter) SetIPLookups(ipLookups []string) *Limiter { +// SetIPLookup is thread-safe way of setting an explicit way to look up IP address. +// This method is intended to replace SetIPLookups (version 6 or older). +func (l *Limiter) SetIPLookup(lookup IPLookup) *Limiter { l.Lock() - l.ipLookups = ipLookups + l.explicitIPLookup = lookup l.Unlock() return l } -// GetIPLookups is thread-safe way of getting list of places to look up IP address. -func (l *Limiter) GetIPLookups() []string { +// GetIPLookup is thread-safe way of getting an explicit way to look up IP address. +// This method is intended to replace the old GetIPLookups (version 6 or older). +func (l *Limiter) GetIPLookup() IPLookup { l.RLock() defer l.RUnlock() - return l.ipLookups + return l.explicitIPLookup } // SetIgnoreURL is thread-safe way of setting whenever ignore the URL on rate limit keys diff --git a/limiter/limiter_setter_getter_test.go b/limiter/limiter_setter_getter_test.go index a4bd41f..55276c1 100644 --- a/limiter/limiter_setter_getter_test.go +++ b/limiter/limiter_setter_getter_test.go @@ -43,19 +43,6 @@ func TestSetGetStatusCode(t *testing.T) { } } -func TestSetGetIPLookups(t *testing.T) { - lmt := New(nil).SetMax(1) - - // Check default - if len(lmt.GetIPLookups()) != 3 { - t.Errorf("IPLookups field is incorrect. Value: %v", lmt.GetIPLookups()) - } - - if lmt.SetIPLookups([]string{"X-Real-IP"}).GetIPLookups()[0] != "X-Real-IP" { - t.Errorf("IPLookups field is incorrect. Value: %v", lmt.GetIPLookups()) - } -} - func TestSetGetMethods(t *testing.T) { lmt := New(nil).SetMax(1) diff --git a/tollbooth.go b/tollbooth.go index d3fa240..c59ce60 100644 --- a/tollbooth.go +++ b/tollbooth.go @@ -29,8 +29,7 @@ func setResponseHeaders(lmt *limiter.Limiter, w http.ResponseWriter, r *http.Req func NewLimiter(max float64, tbOptions *limiter.ExpirableOptions) *limiter.Limiter { return limiter.New(tbOptions). SetMax(max). - SetBurst(int(math.Max(1, max))). - SetIPLookups([]string{"X-Forwarded-For", "X-Real-IP", "RemoteAddr"}) + SetBurst(int(math.Max(1, max))) } // LimitByKeys keeps track number of request made by keys separated by pipe. @@ -48,7 +47,7 @@ func ShouldSkipLimiter(lmt *limiter.Limiter, r *http.Request) bool { // --------------------------------- // Filter by remote ip // If we are unable to find remoteIP, skip limiter - remoteIP := libstring.RemoteIP(lmt.GetIPLookups(), lmt.GetForwardedForIndexFromBehind(), r) + remoteIP := libstring.RemoteIPFromIPLookup(lmt.GetIPLookup(), r) remoteIP = libstring.CanonicalizeIP(remoteIP) if remoteIP == "" { return true @@ -176,7 +175,7 @@ func ShouldSkipLimiter(lmt *limiter.Limiter, r *http.Request) bool { // BuildKeys generates a slice of keys to rate-limit by given limiter and request structs. func BuildKeys(lmt *limiter.Limiter, r *http.Request) [][]string { - remoteIP := libstring.RemoteIP(lmt.GetIPLookups(), lmt.GetForwardedForIndexFromBehind(), r) + remoteIP := libstring.RemoteIPFromIPLookup(lmt.GetIPLookup(), r) remoteIP = libstring.CanonicalizeIP(remoteIP) path := r.URL.Path sliceKeys := make([][]string, 0) diff --git a/tollbooth_benchmark_test.go b/tollbooth_benchmark_test.go index 488e113..c8ce42c 100644 --- a/tollbooth_benchmark_test.go +++ b/tollbooth_benchmark_test.go @@ -30,9 +30,13 @@ func BenchmarkLimitByKeysWithExpiringBuckets(b *testing.B) { func BenchmarkBuildKeys(b *testing.B) { lmt := limiter.New(nil).SetMax(1) // Only 1 request per second is allowed. - lmt.SetIPLookups([]string{"X-Real-IP", "RemoteAddr", "X-Forwarded-For"}) - lmt.SetHeaders(make(map[string][]string)) - lmt.SetHeader("X-Real-IP", []string{"2601:7:1c82:4097:59a0:a80b:2841:b8c8"}) + lmt.SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + HeaderIndexFromRight: 0, + IndexFromRight: 0, + }). + SetHeaders(make(map[string][]string)). + SetHeader("X-Real-IP", []string{"2601:7:1c82:4097:59a0:a80b:2841:b8c8"}) request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) if err != nil { diff --git a/tollbooth_bug_report_test.go b/tollbooth_bug_report_test.go index 5b03bc1..3874840 100644 --- a/tollbooth_bug_report_test.go +++ b/tollbooth_bug_report_test.go @@ -59,7 +59,10 @@ Top: var issue66HeaderKey = "X-Customer-ID" func issue66RateLimiter(h http.HandlerFunc, customerIDs []string) (http.HandlerFunc, *limiter.Limiter) { - allocationLimiter := NewLimiter(1, nil).SetMethods([]string{"POST"}) + allocationLimiter := NewLimiter(1, nil).SetMethods([]string{"POST"}). + SetIPLookup(limiter.IPLookup{ + Name: "RemoteAddr", + }) handler := func(w http.ResponseWriter, r *http.Request) { allocationLimiter.SetHeader(issue66HeaderKey, customerIDs) @@ -135,8 +138,7 @@ Expected to receive: %v status code. Got: %v`, func Test_Issue91_BrokenSetMethod_DontBlockGet(t *testing.T) { requestsPerSecond := float64(1) - lmt := NewLimiter(requestsPerSecond, nil) - lmt.SetMethods([]string{"POST"}) + lmt := NewLimiter(requestsPerSecond, nil).SetMethods([]string{"POST"}) methods := lmt.GetMethods() if methods[0] != "POST" { @@ -170,8 +172,10 @@ func Test_Issue91_BrokenSetMethod_DontBlockGet(t *testing.T) { func Test_Issue91_BrokenSetMethod_BlockPost(t *testing.T) { requestsPerSecond := float64(1) - lmt := NewLimiter(requestsPerSecond, nil) - lmt.SetMethods([]string{"POST"}) + lmt := NewLimiter(requestsPerSecond, nil).SetMethods([]string{"POST"}). + SetIPLookup(limiter.IPLookup{ + Name: "RemoteAddr", + }) limitReachedCounter := 0 lmt.SetOnLimitReached(func(w http.ResponseWriter, r *http.Request) { diff --git a/tollbooth_test.go b/tollbooth_test.go index 2bbf020..4fd11ec 100644 --- a/tollbooth_test.go +++ b/tollbooth_test.go @@ -32,7 +32,11 @@ func TestLimitByKeys(t *testing.T) { } func TestDefaultBuildKeys(t *testing.T) { - lmt := NewLimiter(1, nil) + lmt := NewLimiter(1, nil).SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + HeaderIndexFromRight: 0, + IndexFromRight: 0, + }) request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) if err != nil { @@ -58,8 +62,13 @@ func TestDefaultBuildKeys(t *testing.T) { } func TestIgnoreURLBuildKeys(t *testing.T) { - lmt := NewLimiter(1, nil) - lmt.SetIgnoreURL(true) + lmt := NewLimiter(1, nil). + SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + HeaderIndexFromRight: 0, + IndexFromRight: 0, + }). + SetIgnoreURL(true) request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) if err != nil { @@ -78,8 +87,13 @@ func TestIgnoreURLBuildKeys(t *testing.T) { } func TestBasicAuthBuildKeys(t *testing.T) { - lmt := NewLimiter(1, nil) - lmt.SetBasicAuthUsers([]string{"bro"}) + lmt := NewLimiter(1, nil). + SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + HeaderIndexFromRight: 0, + IndexFromRight: 0, + }). + SetBasicAuthUsers([]string{"bro"}) request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) if err != nil { @@ -108,8 +122,13 @@ func TestBasicAuthBuildKeys(t *testing.T) { } func TestCustomHeadersBuildKeys(t *testing.T) { - lmt := NewLimiter(1, nil) - lmt.SetHeader("X-Auth-Token", []string{"totally-top-secret", "another-secret"}) + lmt := NewLimiter(1, nil). + SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + HeaderIndexFromRight: 0, + IndexFromRight: 0, + }). + SetHeader("X-Auth-Token", []string{"totally-top-secret", "another-secret"}) request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) if err != nil { @@ -137,8 +156,13 @@ func TestCustomHeadersBuildKeys(t *testing.T) { } func TestRequestMethodBuildKeys(t *testing.T) { - lmt := NewLimiter(1, nil) - lmt.SetMethods([]string{"GET"}) + lmt := NewLimiter(1, nil). + SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + HeaderIndexFromRight: 0, + IndexFromRight: 0, + }). + SetMethods([]string{"GET"}) request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) if err != nil { @@ -165,8 +189,13 @@ func TestRequestMethodBuildKeys(t *testing.T) { } func TestContextValueBuildKeys(t *testing.T) { - lmt := NewLimiter(1, nil) - lmt.SetContextValue("API-access-level", []string{"basic"}) + lmt := NewLimiter(1, nil). + SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + HeaderIndexFromRight: 0, + IndexFromRight: 0, + }). + SetContextValue("API-access-level", []string{"basic"}) request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) if err != nil { @@ -195,9 +224,14 @@ func TestContextValueBuildKeys(t *testing.T) { } func TestRequestMethodAndCustomHeadersBuildKeys(t *testing.T) { - lmt := NewLimiter(1, nil) - lmt.SetMethods([]string{"GET"}) - lmt.SetHeader("X-Auth-Token", []string{"totally-top-secret", "another-secret"}) + lmt := NewLimiter(1, nil). + SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + HeaderIndexFromRight: 0, + IndexFromRight: 0, + }). + SetMethods([]string{"GET"}). + SetHeader("X-Auth-Token", []string{"totally-top-secret", "another-secret"}) request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) if err != nil { @@ -227,9 +261,14 @@ func TestRequestMethodAndCustomHeadersBuildKeys(t *testing.T) { } func TestRequestMethodAndBasicAuthUsersBuildKeys(t *testing.T) { - lmt := NewLimiter(1, nil) - lmt.SetMethods([]string{"GET"}) - lmt.SetBasicAuthUsers([]string{"bro"}) + lmt := NewLimiter(1, nil). + SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + HeaderIndexFromRight: 0, + IndexFromRight: 0, + }). + SetMethods([]string{"GET"}). + SetBasicAuthUsers([]string{"bro"}) request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) if err != nil { @@ -257,10 +296,15 @@ func TestRequestMethodAndBasicAuthUsersBuildKeys(t *testing.T) { } func TestRequestMethodCustomHeadersAndBasicAuthUsersBuildKeys(t *testing.T) { - lmt := NewLimiter(1, nil) - lmt.SetMethods([]string{"GET"}) - lmt.SetHeader("X-Auth-Token", []string{"totally-top-secret", "another-secret"}) - lmt.SetBasicAuthUsers([]string{"bro"}) + lmt := NewLimiter(1, nil). + SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + HeaderIndexFromRight: 0, + IndexFromRight: 0, + }). + SetMethods([]string{"GET"}). + SetHeader("X-Auth-Token", []string{"totally-top-secret", "another-secret"}). + SetBasicAuthUsers([]string{"bro"}) request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) if err != nil { @@ -292,11 +336,16 @@ func TestRequestMethodCustomHeadersAndBasicAuthUsersBuildKeys(t *testing.T) { } func TestRequestMethodCustomHeadersAndBasicAuthUsersAndContextValuesBuildKeys(t *testing.T) { - lmt := NewLimiter(1, nil) - lmt.SetMethods([]string{"GET"}) - lmt.SetHeader("X-Auth-Token", []string{"totally-top-secret", "another-secret"}) - lmt.SetContextValue("API-access-level", []string{"basic"}) - lmt.SetBasicAuthUsers([]string{"bro"}) + lmt := NewLimiter(1, nil). + SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + HeaderIndexFromRight: 0, + IndexFromRight: 0, + }). + SetMethods([]string{"GET"}). + SetHeader("X-Auth-Token", []string{"totally-top-secret", "another-secret"}). + SetContextValue("API-access-level", []string{"basic"}). + SetBasicAuthUsers([]string{"bro"}) request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) if err != nil { @@ -331,9 +380,13 @@ func TestRequestMethodCustomHeadersAndBasicAuthUsersAndContextValuesBuildKeys(t } func TestLimitHandler(t *testing.T) { - lmt := limiter.New(nil).SetMax(1).SetBurst(1) - lmt.SetIPLookups([]string{"X-Real-IP", "RemoteAddr", "X-Forwarded-For"}) - lmt.SetMethods([]string{"POST"}) + lmt := limiter.New(nil).SetMax(1).SetBurst(1). + SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + HeaderIndexFromRight: 0, + IndexFromRight: 0, + }). + SetMethods([]string{"POST"}) counter := 0 lmt.SetOnLimitReached(func(w http.ResponseWriter, r *http.Request) { counter++ }) @@ -377,10 +430,14 @@ func TestLimitHandler(t *testing.T) { } func TestOverrideForResponseWriter(t *testing.T) { - lmt := limiter.New(nil).SetMax(1).SetBurst(1) - lmt.SetIPLookups([]string{"X-Real-IP", "RemoteAddr", "X-Forwarded-For"}) - lmt.SetMethods([]string{"POST"}) - lmt.SetOverrideDefaultResponseWriter(true) + lmt := limiter.New(nil).SetMax(1).SetBurst(1). + SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + HeaderIndexFromRight: 0, + IndexFromRight: 0, + }). + SetMethods([]string{"POST"}). + SetOverrideDefaultResponseWriter(true) counter := 0 lmt.SetOnLimitReached(func(w http.ResponseWriter, r *http.Request) { From 0e274403c12d8b8f1edd447ffbda677c8c8a2fd3 Mon Sep 17 00:00:00 2001 From: Didip Kerabat Date: Thu, 3 Mar 2022 21:35:25 -0800 Subject: [PATCH 2/5] Make golint happy. --- libstring/libstring.go | 1 - libstring/libstring_test.go | 7 +++++-- limiter/limiter.go | 5 ----- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/libstring/libstring.go b/libstring/libstring.go index 776f1cc..2928481 100644 --- a/libstring/libstring.go +++ b/libstring/libstring.go @@ -73,7 +73,6 @@ func CanonicalizeIP(ip string) string { case ':': // IPv6 isIPv6 = true - break } } if !isIPv6 { diff --git a/libstring/libstring_test.go b/libstring/libstring_test.go index a078d27..7f6d4b8 100644 --- a/libstring/libstring_test.go +++ b/libstring/libstring_test.go @@ -196,9 +196,12 @@ func TestCanonicalizeIP(t *testing.T) { }, } for _, tt := range tests { + ip := tt.ip + want := tt.want + t.Run(tt.name, func(t *testing.T) { - if got := CanonicalizeIP(tt.ip); got != tt.want { - t.Errorf("CanonicalizeIP() = %v, want %v", got, tt.want) + if got := CanonicalizeIP(ip); got != want { + t.Errorf("CanonicalizeIP() = %v, want %v", got, want) } }) } diff --git a/limiter/limiter.go b/limiter/limiter.go index 1b6ae93..1c48a17 100644 --- a/limiter/limiter.go +++ b/limiter/limiter.go @@ -79,11 +79,6 @@ type Limiter struct { // An option to write back what you want upon reaching a limit. overrideDefaultResponseWriter bool - // List of places to look up IP address. - // Default is "RemoteAddr", "X-Forwarded-For", "X-Real-IP". - // You can rearrange the order as you like. - ipLookups []string - // Explicitly define how to look up IP address. // This is intended to replace ipLookups explicitIPLookup IPLookup From 983f008bf665fdadc180f319e71e72675147274b Mon Sep 17 00:00:00 2001 From: Didip Kerabat Date: Thu, 3 Mar 2022 21:52:09 -0800 Subject: [PATCH 3/5] Update documentation. --- README.md | 52 +++++++++++++++++++++++++++++++++++++--------- limiter/limiter.go | 1 + 2 files changed, 43 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 31f37b5..9affaa6 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,8 @@ This is a generic middleware to rate-limit HTTP requests. **v6.x.x:** Replaced `go-cache` with `github.com/go-pkgz/expirable-cache` because `go-cache` leaks goroutines. +**v7.x.x:** Address `RemoteIP` vulnerability concern by replacing it with `RemoteIPFromIPLookup`, an explicit way to pick the IP address. + ## Five Minute Tutorial ```go @@ -31,7 +33,8 @@ package main import ( "net/http" - "github.com/didip/tollbooth/v6" + "github.com/didip/tollbooth/v7" + "github.com/didip/tollbooth/v7/limiter" ) func HelloHandler(w http.ResponseWriter, req *http.Request) { @@ -40,7 +43,16 @@ func HelloHandler(w http.ResponseWriter, req *http.Request) { func main() { // Create a request limiter per handler. - http.Handle("/", tollbooth.LimitFuncHandler(tollbooth.NewLimiter(1, nil), HelloHandler)) + lmt := tollbooth.NewLimiter(1, nil) + + // New in version >= 7, you must explicitly define how to pick the IP address. + lmt.SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + HeaderIndexFromRight: 0, + IndexFromRight: 0, + }) + + http.Handle("/", tollbooth.LimitFuncHandler(lmt, HelloHandler)) http.ListenAndServe(":12345", nil) } ``` @@ -52,8 +64,8 @@ func main() { import ( "time" - "github.com/didip/tollbooth/v6" - "github.com/didip/tollbooth/v6/limiter" + "github.com/didip/tollbooth/v7" + "github.com/didip/tollbooth/v7/limiter" ) lmt := tollbooth.NewLimiter(1, nil) @@ -64,10 +76,24 @@ func main() { // every token bucket in it will expire 1 hour after it was initially set. lmt = tollbooth.NewLimiter(1, &limiter.ExpirableOptions{DefaultExpirationTTL: time.Hour}) - // Configure list of places to look for IP address. - // By default it's: "RemoteAddr", "X-Forwarded-For", "X-Real-IP" - // If your application is behind a proxy, set "X-Forwarded-For" first. - lmt.SetIPLookups([]string{"RemoteAddr", "X-Forwarded-For", "X-Real-IP"}) + // New in version >= 7, you must explicitly define how to pick the IP address. + // If IP address cannot be found, rate limiter will not be activated. + lmt.SetIPLookup(limiter.IPLookup{ + // The name of lookup method. + // Possible options are: RemoteAddr, X-Forwarded-For, X-Real-IP, CF-Connecting-IP + // All other headers are considered unknown and will be ignored. + Name: "X-Real-IP", + + // If there are multiple of the same header, this index determines which one to use. + // The index goes from right to left. + HeaderIndexFromRight: 0, + + // The index position to pick the ip address from a comma separated list. + // The index goes from right to left. + IndexFromRight: 0, + }) + + // In version >= 7, lmt.SetIPLookups and lmt.GetIPLookups are removed. // Limit only GET and POST requests. lmt.SetMethods([]string{"GET", "POST"}) @@ -87,8 +113,7 @@ func main() { lmt.RemoveHeaderEntries("X-Access-Token", []string{"limitless-token"}) // By the way, the setters are chainable. Example: - lmt.SetIPLookups([]string{"RemoteAddr", "X-Forwarded-For", "X-Real-IP"}). - SetMethods([]string{"GET", "POST"}). + lmt.SetMethods([]string{"GET", "POST"}). SetBasicAuthUsers([]string{"sansa"}). SetBasicAuthUsers([]string{"tyrion"}) ``` @@ -128,6 +153,13 @@ func main() { ```go lmt := tollbooth.NewLimiter(1, nil) + // New in version >= 7, you must explicitly define how to pick the IP address. + lmt.SetIPLookup(limiter.IPLookup{ + Name: "X-Forwarded-For", + HeaderIndexFromRight: 0, + IndexFromRight: 0, + }) + // Set a custom message. lmt.SetMessage("You have reached maximum request limit.") diff --git a/limiter/limiter.go b/limiter/limiter.go index 1c48a17..45cfc1c 100644 --- a/limiter/limiter.go +++ b/limiter/limiter.go @@ -45,6 +45,7 @@ func New(generalExpirableOptions *ExpirableOptions) *Limiter { type IPLookup struct { // The name of lookup method. // Possible options are: RemoteAddr, X-Forwarded-For, X-Real-IP, CF-Connecting-IP + // All other headers are considered unknown and will be ignored. Name string // If there are multiple of the same header, this index determines which one to use. From 4a4cde1f25587b0d0afa480b2e36e8e1273d5064 Mon Sep 17 00:00:00 2001 From: Didip Kerabat Date: Fri, 4 Mar 2022 22:04:09 -0800 Subject: [PATCH 4/5] =?UTF-8?q?We=20don=E2=80=99t=20need=20the=20ability?= =?UTF-8?q?=20to=20pick=20which=20header=20to=20use.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 22 +++++++------- libstring/libstring.go | 7 +---- libstring/libstring_test.go | 52 ++++++++------------------------ limiter/limiter.go | 4 --- tollbooth_benchmark_test.go | 5 ++-- tollbooth_test.go | 60 +++++++++++++++---------------------- 6 files changed, 49 insertions(+), 101 deletions(-) diff --git a/README.md b/README.md index 9affaa6..c6c5fb5 100644 --- a/README.md +++ b/README.md @@ -47,9 +47,8 @@ func main() { // New in version >= 7, you must explicitly define how to pick the IP address. lmt.SetIPLookup(limiter.IPLookup{ - Name: "X-Real-IP", - HeaderIndexFromRight: 0, - IndexFromRight: 0, + Name: "X-Real-IP", + IndexFromRight: 0, }) http.Handle("/", tollbooth.LimitFuncHandler(lmt, HelloHandler)) @@ -82,15 +81,15 @@ func main() { // The name of lookup method. // Possible options are: RemoteAddr, X-Forwarded-For, X-Real-IP, CF-Connecting-IP // All other headers are considered unknown and will be ignored. - Name: "X-Real-IP", - - // If there are multiple of the same header, this index determines which one to use. - // The index goes from right to left. - HeaderIndexFromRight: 0, + Name: "X-Real-IP", // The index position to pick the ip address from a comma separated list. // The index goes from right to left. - IndexFromRight: 0, + // + // When there are multiple of the same headers, + // we will concat them together in the order of first to last seen. + // And then we pick the IP using this index position. + IndexFromRight: 0, }) // In version >= 7, lmt.SetIPLookups and lmt.GetIPLookups are removed. @@ -155,9 +154,8 @@ func main() { // New in version >= 7, you must explicitly define how to pick the IP address. lmt.SetIPLookup(limiter.IPLookup{ - Name: "X-Forwarded-For", - HeaderIndexFromRight: 0, - IndexFromRight: 0, + Name: "X-Forwarded-For", + IndexFromRight: 0, }) // Set a custom message. diff --git a/libstring/libstring.go b/libstring/libstring.go index 2928481..453f254 100644 --- a/libstring/libstring.go +++ b/libstring/libstring.go @@ -35,12 +35,7 @@ func RemoteIPFromIPLookup(ipLookup limiter.IPLookup, r *http.Request) string { case "X-Forwarded-For", "X-Real-IP", "CF-Connecting-IP": ipAddrListCommaSeparated := r.Header.Values(ipLookup.Name) - headerIndex := len(ipAddrListCommaSeparated) - 1 - ipLookup.HeaderIndexFromRight - if headerIndex < 0 { - headerIndex = 0 - } - - ipAddrCommaSeparated := ipAddrListCommaSeparated[headerIndex] + ipAddrCommaSeparated := strings.Join(ipAddrListCommaSeparated, ",") ips := strings.Split(ipAddrCommaSeparated, ",") for i, p := range ips { diff --git a/libstring/libstring_test.go b/libstring/libstring_test.go index 7f6d4b8..70f9060 100644 --- a/libstring/libstring_test.go +++ b/libstring/libstring_test.go @@ -26,9 +26,8 @@ func TestRemoteIPForwardedFor(t *testing.T) { request.Header.Set("X-Real-IP", ipv6) ip := RemoteIPFromIPLookup(limiter.IPLookup{ - Name: "X-Forwarded-For", - HeaderIndexFromRight: 0, - IndexFromRight: 0, + Name: "X-Forwarded-For", + IndexFromRight: 0, }, request) if ip != "10.10.10.10" { @@ -51,9 +50,8 @@ func TestRemoteIPRealIP(t *testing.T) { request.Header.Set("X-Real-IP", ipv6) ip := RemoteIPFromIPLookup(limiter.IPLookup{ - Name: "X-Real-IP", - HeaderIndexFromRight: 0, - IndexFromRight: 0, + Name: "X-Real-IP", + IndexFromRight: 0, }, request) if ip != ipv6 { @@ -73,9 +71,8 @@ func TestRemoteIPMultipleForwardedForIPAddresses(t *testing.T) { request.Header.Set("X-Forwarded-For", "10.10.10.10,10.10.10.11") ip := RemoteIPFromIPLookup(limiter.IPLookup{ - Name: "X-Forwarded-For", - HeaderIndexFromRight: 0, - IndexFromRight: 0, + Name: "X-Forwarded-For", + IndexFromRight: 0, }, request) // Should get the last one @@ -84,9 +81,8 @@ func TestRemoteIPMultipleForwardedForIPAddresses(t *testing.T) { } ip = RemoteIPFromIPLookup(limiter.IPLookup{ - Name: "X-Forwarded-For", - HeaderIndexFromRight: 0, - IndexFromRight: 1, + Name: "X-Forwarded-For", + IndexFromRight: 1, }, request) // Should get the 2nd from last @@ -96,9 +92,8 @@ func TestRemoteIPMultipleForwardedForIPAddresses(t *testing.T) { // What about index out of bound? RemoteIP should simply choose index 0. ip = RemoteIPFromIPLookup(limiter.IPLookup{ - Name: "X-Forwarded-For", - HeaderIndexFromRight: 0, - IndexFromRight: 2, + Name: "X-Forwarded-For", + IndexFromRight: 2, }, request) if ip != "10.10.10.10" { @@ -116,37 +111,14 @@ func TestRemoteIPMultipleForwardedForHeaders(t *testing.T) { request.Header.Add("X-Forwarded-For", "10.10.10.10,10.10.10.11") ip := RemoteIPFromIPLookup(limiter.IPLookup{ - Name: "X-Forwarded-For", - HeaderIndexFromRight: 0, - IndexFromRight: 0, + Name: "X-Forwarded-For", + IndexFromRight: 0, }, request) // Should get the last header and the last IP if ip != "10.10.10.11" { t.Errorf("Did not get the right IP. IP: %v", ip) } - - ip = RemoteIPFromIPLookup(limiter.IPLookup{ - Name: "X-Forwarded-For", - HeaderIndexFromRight: 1, - IndexFromRight: 0, - }, request) - - // Should get the last IP from the first header - if ip != "8.8.4.4" { - t.Errorf("Did not get the right IP. IP: %v", ip) - } - - ip = RemoteIPFromIPLookup(limiter.IPLookup{ - Name: "X-Forwarded-For", - HeaderIndexFromRight: 1, - IndexFromRight: 1, - }, request) - - // Should get the first IP from the first header - if ip != "8.8.8.8" { - t.Errorf("Did not get the right IP. IP: %v", ip) - } } func TestCanonicalizeIP(t *testing.T) { tests := []struct { diff --git a/limiter/limiter.go b/limiter/limiter.go index 45cfc1c..fbf2e6e 100644 --- a/limiter/limiter.go +++ b/limiter/limiter.go @@ -48,10 +48,6 @@ type IPLookup struct { // All other headers are considered unknown and will be ignored. Name string - // If there are multiple of the same header, this index determines which one to use. - // The index goes from right to left. - HeaderIndexFromRight int - // The index position to pick the ip address from a comma separated list. // The index goes from right to left. IndexFromRight int diff --git a/tollbooth_benchmark_test.go b/tollbooth_benchmark_test.go index c8ce42c..16ef3d9 100644 --- a/tollbooth_benchmark_test.go +++ b/tollbooth_benchmark_test.go @@ -31,9 +31,8 @@ func BenchmarkLimitByKeysWithExpiringBuckets(b *testing.B) { func BenchmarkBuildKeys(b *testing.B) { lmt := limiter.New(nil).SetMax(1) // Only 1 request per second is allowed. lmt.SetIPLookup(limiter.IPLookup{ - Name: "X-Real-IP", - HeaderIndexFromRight: 0, - IndexFromRight: 0, + Name: "X-Real-IP", + IndexFromRight: 0, }). SetHeaders(make(map[string][]string)). SetHeader("X-Real-IP", []string{"2601:7:1c82:4097:59a0:a80b:2841:b8c8"}) diff --git a/tollbooth_test.go b/tollbooth_test.go index 4fd11ec..b85a7e5 100644 --- a/tollbooth_test.go +++ b/tollbooth_test.go @@ -33,9 +33,8 @@ func TestLimitByKeys(t *testing.T) { func TestDefaultBuildKeys(t *testing.T) { lmt := NewLimiter(1, nil).SetIPLookup(limiter.IPLookup{ - Name: "X-Real-IP", - HeaderIndexFromRight: 0, - IndexFromRight: 0, + Name: "X-Real-IP", + IndexFromRight: 0, }) request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!")) @@ -64,9 +63,8 @@ func TestDefaultBuildKeys(t *testing.T) { func TestIgnoreURLBuildKeys(t *testing.T) { lmt := NewLimiter(1, nil). SetIPLookup(limiter.IPLookup{ - Name: "X-Real-IP", - HeaderIndexFromRight: 0, - IndexFromRight: 0, + Name: "X-Real-IP", + IndexFromRight: 0, }). SetIgnoreURL(true) @@ -89,9 +87,8 @@ func TestIgnoreURLBuildKeys(t *testing.T) { func TestBasicAuthBuildKeys(t *testing.T) { lmt := NewLimiter(1, nil). SetIPLookup(limiter.IPLookup{ - Name: "X-Real-IP", - HeaderIndexFromRight: 0, - IndexFromRight: 0, + Name: "X-Real-IP", + IndexFromRight: 0, }). SetBasicAuthUsers([]string{"bro"}) @@ -124,9 +121,8 @@ func TestBasicAuthBuildKeys(t *testing.T) { func TestCustomHeadersBuildKeys(t *testing.T) { lmt := NewLimiter(1, nil). SetIPLookup(limiter.IPLookup{ - Name: "X-Real-IP", - HeaderIndexFromRight: 0, - IndexFromRight: 0, + Name: "X-Real-IP", + IndexFromRight: 0, }). SetHeader("X-Auth-Token", []string{"totally-top-secret", "another-secret"}) @@ -158,9 +154,8 @@ func TestCustomHeadersBuildKeys(t *testing.T) { func TestRequestMethodBuildKeys(t *testing.T) { lmt := NewLimiter(1, nil). SetIPLookup(limiter.IPLookup{ - Name: "X-Real-IP", - HeaderIndexFromRight: 0, - IndexFromRight: 0, + Name: "X-Real-IP", + IndexFromRight: 0, }). SetMethods([]string{"GET"}) @@ -191,9 +186,8 @@ func TestRequestMethodBuildKeys(t *testing.T) { func TestContextValueBuildKeys(t *testing.T) { lmt := NewLimiter(1, nil). SetIPLookup(limiter.IPLookup{ - Name: "X-Real-IP", - HeaderIndexFromRight: 0, - IndexFromRight: 0, + Name: "X-Real-IP", + IndexFromRight: 0, }). SetContextValue("API-access-level", []string{"basic"}) @@ -226,9 +220,8 @@ func TestContextValueBuildKeys(t *testing.T) { func TestRequestMethodAndCustomHeadersBuildKeys(t *testing.T) { lmt := NewLimiter(1, nil). SetIPLookup(limiter.IPLookup{ - Name: "X-Real-IP", - HeaderIndexFromRight: 0, - IndexFromRight: 0, + Name: "X-Real-IP", + IndexFromRight: 0, }). SetMethods([]string{"GET"}). SetHeader("X-Auth-Token", []string{"totally-top-secret", "another-secret"}) @@ -263,9 +256,8 @@ func TestRequestMethodAndCustomHeadersBuildKeys(t *testing.T) { func TestRequestMethodAndBasicAuthUsersBuildKeys(t *testing.T) { lmt := NewLimiter(1, nil). SetIPLookup(limiter.IPLookup{ - Name: "X-Real-IP", - HeaderIndexFromRight: 0, - IndexFromRight: 0, + Name: "X-Real-IP", + IndexFromRight: 0, }). SetMethods([]string{"GET"}). SetBasicAuthUsers([]string{"bro"}) @@ -298,9 +290,8 @@ func TestRequestMethodAndBasicAuthUsersBuildKeys(t *testing.T) { func TestRequestMethodCustomHeadersAndBasicAuthUsersBuildKeys(t *testing.T) { lmt := NewLimiter(1, nil). SetIPLookup(limiter.IPLookup{ - Name: "X-Real-IP", - HeaderIndexFromRight: 0, - IndexFromRight: 0, + Name: "X-Real-IP", + IndexFromRight: 0, }). SetMethods([]string{"GET"}). SetHeader("X-Auth-Token", []string{"totally-top-secret", "another-secret"}). @@ -338,9 +329,8 @@ func TestRequestMethodCustomHeadersAndBasicAuthUsersBuildKeys(t *testing.T) { func TestRequestMethodCustomHeadersAndBasicAuthUsersAndContextValuesBuildKeys(t *testing.T) { lmt := NewLimiter(1, nil). SetIPLookup(limiter.IPLookup{ - Name: "X-Real-IP", - HeaderIndexFromRight: 0, - IndexFromRight: 0, + Name: "X-Real-IP", + IndexFromRight: 0, }). SetMethods([]string{"GET"}). SetHeader("X-Auth-Token", []string{"totally-top-secret", "another-secret"}). @@ -382,9 +372,8 @@ func TestRequestMethodCustomHeadersAndBasicAuthUsersAndContextValuesBuildKeys(t func TestLimitHandler(t *testing.T) { lmt := limiter.New(nil).SetMax(1).SetBurst(1). SetIPLookup(limiter.IPLookup{ - Name: "X-Real-IP", - HeaderIndexFromRight: 0, - IndexFromRight: 0, + Name: "X-Real-IP", + IndexFromRight: 0, }). SetMethods([]string{"POST"}) @@ -432,9 +421,8 @@ func TestLimitHandler(t *testing.T) { func TestOverrideForResponseWriter(t *testing.T) { lmt := limiter.New(nil).SetMax(1).SetBurst(1). SetIPLookup(limiter.IPLookup{ - Name: "X-Real-IP", - HeaderIndexFromRight: 0, - IndexFromRight: 0, + Name: "X-Real-IP", + IndexFromRight: 0, }). SetMethods([]string{"POST"}). SetOverrideDefaultResponseWriter(true) From 5a782cc5c3fd873abf9334356f34a86bced4ab13 Mon Sep 17 00:00:00 2001 From: Didip Kerabat Date: Wed, 9 Oct 2024 12:43:26 -0700 Subject: [PATCH 5/5] Fix tests. --- README.md | 8 ++++---- libstring/libstring.go | 2 +- tollbooth_test.go | 5 ++++- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index d9d0c70..e13a0d5 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,7 @@ func main() { // Create a request limiter per handler. lmt := tollbooth.NewLimiter(1, nil) - // New in version >= 7, you must explicitly define how to pick the IP address. + // New in version >= 8, you must explicitly define how to pick the IP address. lmt.SetIPLookup(limiter.IPLookup{ Name: "X-Real-IP", IndexFromRight: 0, @@ -78,7 +78,7 @@ func main() { // every token bucket in it will expire 1 hour after it was initially set. lmt = tollbooth.NewLimiter(1, &limiter.ExpirableOptions{DefaultExpirationTTL: time.Hour}) - // New in version >= 7, you must explicitly define how to pick the IP address. + // New in version >= 8, you must explicitly define how to pick the IP address. // If IP address cannot be found, rate limiter will not be activated. lmt.SetIPLookup(limiter.IPLookup{ // The name of lookup method. @@ -95,7 +95,7 @@ func main() { IndexFromRight: 0, }) - // In version >= 7, lmt.SetIPLookups and lmt.GetIPLookups are removed. + // In version >= 8, lmt.SetIPLookups and lmt.GetIPLookups are removed. // Limit only GET and POST requests. lmt.SetMethods([]string{"GET", "POST"}) @@ -162,7 +162,7 @@ func main() { ```go lmt := tollbooth.NewLimiter(1, nil) - // New in version >= 7, you must explicitly define how to pick the IP address. + // New in version >= 8, you must explicitly define how to pick the IP address. lmt.SetIPLookup(limiter.IPLookup{ Name: "X-Forwarded-For", IndexFromRight: 0, diff --git a/libstring/libstring.go b/libstring/libstring.go index 453f254..0e6b334 100644 --- a/libstring/libstring.go +++ b/libstring/libstring.go @@ -6,7 +6,7 @@ import ( "net/http" "strings" - "github.com/didip/tollbooth/v6/limiter" + "github.com/didip/tollbooth/v7/limiter" ) // StringInSlice finds needle in a slice of strings. diff --git a/tollbooth_test.go b/tollbooth_test.go index b023f4b..e5865df 100644 --- a/tollbooth_test.go +++ b/tollbooth_test.go @@ -564,7 +564,10 @@ func (lm *LockMap) Add(key string, incr int64) { func TestLimitHandlerEmptyHeader(t *testing.T) { lmt := limiter.New(nil).SetMax(1).SetBurst(1) - lmt.SetIPLookups([]string{"X-Real-IP", "RemoteAddr", "X-Forwarded-For"}) + lmt.SetIPLookup(limiter.IPLookup{ + Name: "X-Real-IP", + IndexFromRight: 0, + }) lmt.SetMethods([]string{"POST"}) lmt.SetHeader("user_id", []string{})