Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

We received a vulnerability disclosure due to how we pick a remote IP address. #99

Merged
merged 6 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 38 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ This is a generic middleware to rate-limit HTTP requests.

**v7.x.x:** Replaced `time/rate` with `embedded time/rate` so that we can support more rate limit headers.

**v8.x.x:** Address `RemoteIP` vulnerability concern by replacing it with `RemoteIPFromIPLookup`, an explicit way to pick the IP address.


## Five Minute Tutorial

```go
Expand All @@ -34,6 +37,7 @@ import (
"net/http"

"github.com/didip/tollbooth/v7"
"github.com/didip/tollbooth/v7/limiter"
)

func HelloHandler(w http.ResponseWriter, req *http.Request) {
Expand All @@ -42,7 +46,15 @@ func HelloHandler(w http.ResponseWriter, req *http.Request) {

func main() {
// Create a request limiter per handler.
http.Handle("/", tollbooth.LimitFuncHandler(tollbooth.NewLimiter(1, nil), HelloHandler))
lmt := tollbooth.NewLimiter(1, nil)

// New in version >= 8, you must explicitly define how to pick the IP address.
lmt.SetIPLookup(limiter.IPLookup{
Name: "X-Real-IP",
IndexFromRight: 0,
})

http.Handle("/", tollbooth.LimitFuncHandler(lmt, HelloHandler))
http.ListenAndServe(":12345", nil)
}
```
Expand All @@ -66,10 +78,24 @@ func main() {
// every token bucket in it will expire 1 hour after it was initially set.
lmt = tollbooth.NewLimiter(1, &limiter.ExpirableOptions{DefaultExpirationTTL: time.Hour})

// Configure list of places to look for IP address.
// By default it's: "RemoteAddr", "X-Forwarded-For", "X-Real-IP"
// If your application is behind a proxy, set "X-Forwarded-For" first.
lmt.SetIPLookups([]string{"RemoteAddr", "X-Forwarded-For", "X-Real-IP"})
// New in version >= 8, you must explicitly define how to pick the IP address.
// If IP address cannot be found, rate limiter will not be activated.
lmt.SetIPLookup(limiter.IPLookup{
// The name of lookup method.
// Possible options are: RemoteAddr, X-Forwarded-For, X-Real-IP, CF-Connecting-IP
// All other headers are considered unknown and will be ignored.
Name: "X-Real-IP",

// The index position to pick the ip address from a comma separated list.
// The index goes from right to left.
//
// When there are multiple of the same headers,
// we will concat them together in the order of first to last seen.
// And then we pick the IP using this index position.
IndexFromRight: 0,
})

// In version >= 8, lmt.SetIPLookups and lmt.GetIPLookups are removed.

// Limit only GET and POST requests.
lmt.SetMethods([]string{"GET", "POST"})
Expand All @@ -89,8 +115,7 @@ func main() {
lmt.RemoveHeaderEntries("X-Access-Token", []string{"limitless-token"})

// By the way, the setters are chainable. Example:
lmt.SetIPLookups([]string{"RemoteAddr", "X-Forwarded-For", "X-Real-IP"}).
SetMethods([]string{"GET", "POST"}).
lmt.SetMethods([]string{"GET", "POST"}).
SetBasicAuthUsers([]string{"sansa"}).
SetBasicAuthUsers([]string{"tyrion"})
```
Expand Down Expand Up @@ -137,6 +162,12 @@ func main() {
```go
lmt := tollbooth.NewLimiter(1, nil)

// New in version >= 8, you must explicitly define how to pick the IP address.
lmt.SetIPLookup(limiter.IPLookup{
Name: "X-Forwarded-For",
IndexFromRight: 0,
})

// Set a custom message.
lmt.SetMessage("You have reached maximum request limit.")

Expand Down
53 changes: 26 additions & 27 deletions libstring/libstring.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"net"
"net/http"
"strings"

"github.com/didip/tollbooth/v7/limiter"
)

// StringInSlice finds needle in a slice of strings.
Expand All @@ -17,38 +19,35 @@ func StringInSlice(sliceString []string, needle string) bool {
return false
}

// RemoteIP finds IP Address given http.Request struct.
func RemoteIP(ipLookups []string, forwardedForIndexFromBehind int, r *http.Request) string {
realIP := r.Header.Get("X-Real-IP")
forwardedFor := r.Header.Get("X-Forwarded-For")

for _, lookup := range ipLookups {
if lookup == "RemoteAddr" {
// 1. Cover the basic use cases for both ipv4 and ipv6
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
// 2. Upon error, just return the remote addr.
return r.RemoteAddr
}
return ip
// RemoteIPFromIPLookup picks an ip address explicitly from limiter.IPLookup criteria.
// This function is intended to replace RemoteIP function.
func RemoteIPFromIPLookup(ipLookup limiter.IPLookup, r *http.Request) string {
switch ipLookup.Name {
case "RemoteAddr":
// 1. Cover the basic use cases for both ipv4 and ipv6
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
// 2. Upon error, just return the remote addr.
return r.RemoteAddr
}
if lookup == "X-Forwarded-For" && forwardedFor != "" {
// X-Forwarded-For is potentially a list of addresses separated with ","
parts := strings.Split(forwardedFor, ",")
for i, p := range parts {
parts[i] = strings.TrimSpace(p)
}
return ip

partIndex := len(parts) - 1 - forwardedForIndexFromBehind
if partIndex < 0 {
partIndex = 0
}
case "X-Forwarded-For", "X-Real-IP", "CF-Connecting-IP":
ipAddrListCommaSeparated := r.Header.Values(ipLookup.Name)

return parts[partIndex]
ipAddrCommaSeparated := strings.Join(ipAddrListCommaSeparated, ",")

ips := strings.Split(ipAddrCommaSeparated, ",")
for i, p := range ips {
ips[i] = strings.TrimSpace(p)
}
if lookup == "X-Real-IP" && realIP != "" {
return realIP

ipIndex := len(ips) - 1 - ipLookup.IndexFromRight
if ipIndex < 0 {
ipIndex = 0
}

return ips[ipIndex]
}

return ""
Expand Down
103 changes: 52 additions & 51 deletions libstring/libstring_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"net/http"
"strings"
"testing"

"github.com/didip/tollbooth/v7/limiter"
)

func TestStringInSlice(t *testing.T) {
Expand All @@ -12,28 +14,7 @@ func TestStringInSlice(t *testing.T) {
}
}

func TestRemoteIPDefault(t *testing.T) {
ipLookups := []string{"RemoteAddr", "X-Real-IP"}
ipv6 := "2601:7:1c82:4097:59a0:a80b:2841:b8c8"

request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!"))
if err != nil {
t.Errorf("Unable to create new HTTP request. Error: %v", err)
}

request.Header.Set("X-Real-IP", ipv6)

ip := RemoteIP(ipLookups, 0, request)
if ip != request.RemoteAddr {
t.Errorf("Did not get the right IP. IP: %v", ip)
}
if ip == ipv6 {
t.Errorf("X-Real-IP should have been skipped. IP: %v", ip)
}
}

func TestRemoteIPForwardedFor(t *testing.T) {
ipLookups := []string{"X-Forwarded-For", "X-Real-IP", "RemoteAddr"}
ipv6 := "2601:7:1c82:4097:59a0:a80b:2841:b8c8"

request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!"))
Expand All @@ -44,7 +25,11 @@ func TestRemoteIPForwardedFor(t *testing.T) {
request.Header.Set("X-Forwarded-For", "10.10.10.10")
request.Header.Set("X-Real-IP", ipv6)

ip := RemoteIP(ipLookups, 0, request)
ip := RemoteIPFromIPLookup(limiter.IPLookup{
Name: "X-Forwarded-For",
IndexFromRight: 0,
}, request)

if ip != "10.10.10.10" {
t.Errorf("Did not get the right IP. IP: %v", ip)
}
Expand All @@ -54,7 +39,6 @@ func TestRemoteIPForwardedFor(t *testing.T) {
}

func TestRemoteIPRealIP(t *testing.T) {
ipLookups := []string{"X-Real-IP", "X-Forwarded-For", "RemoteAddr"}
ipv6 := "2601:7:1c82:4097:59a0:a80b:2841:b8c8"

request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!"))
Expand All @@ -65,7 +49,11 @@ func TestRemoteIPRealIP(t *testing.T) {
request.Header.Set("X-Forwarded-For", "10.10.10.10")
request.Header.Set("X-Real-IP", ipv6)

ip := RemoteIP(ipLookups, 0, request)
ip := RemoteIPFromIPLookup(limiter.IPLookup{
Name: "X-Real-IP",
IndexFromRight: 0,
}, request)

if ip != ipv6 {
t.Errorf("Did not get the right IP. IP: %v", ip)
}
Expand All @@ -74,53 +62,64 @@ func TestRemoteIPRealIP(t *testing.T) {
}
}

func TestRemoteIPMultipleForwardedFor(t *testing.T) {
ipLookups := []string{"X-Forwarded-For", "X-Real-IP", "RemoteAddr"}
ipv6 := "2601:7:1c82:4097:59a0:a80b:2841:b8c8"

func TestRemoteIPMultipleForwardedForIPAddresses(t *testing.T) {
request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!"))
if err != nil {
t.Errorf("Unable to create new HTTP request. Error: %v", err)
}

request.Header.Set("X-Real-IP", ipv6)

// Missing X-Forwarded-For should not break things
ip := RemoteIP(ipLookups, 0, request)
if ip != ipv6 {
t.Errorf("X-Real-IP should have been chosen because X-Forwarded-For is missing. IP: %v", ip)
}

request.Header.Set("X-Forwarded-For", "10.10.10.10,10.10.10.11")

ip := RemoteIPFromIPLookup(limiter.IPLookup{
Name: "X-Forwarded-For",
IndexFromRight: 0,
}, request)

// Should get the last one
ip = RemoteIP(ipLookups, 0, request)
if ip != "10.10.10.11" {
t.Errorf("Did not get the right IP. IP: %v", ip)
}
if ip == ipv6 {
t.Errorf("X-Real-IP should have been skipped. IP: %v", ip)
}

ip = RemoteIPFromIPLookup(limiter.IPLookup{
Name: "X-Forwarded-For",
IndexFromRight: 1,
}, request)

// Should get the 2nd from last
ip = RemoteIP(ipLookups, 1, request)
if ip != "10.10.10.10" {
t.Errorf("Did not get the right IP. IP: %v", ip)
}
if ip == ipv6 {
t.Errorf("X-Real-IP should have been skipped. IP: %v", ip)
}

// What about index out of bound? RemoteIP should simply choose index 0.
ip = RemoteIP(ipLookups, 2, request)
ip = RemoteIPFromIPLookup(limiter.IPLookup{
Name: "X-Forwarded-For",
IndexFromRight: 2,
}, request)

if ip != "10.10.10.10" {
t.Errorf("Did not get the right IP. IP: %v", ip)
}
if ip == ipv6 {
t.Errorf("X-Real-IP should have been skipped. IP: %v", ip)
}
}

func TestRemoteIPMultipleForwardedForHeaders(t *testing.T) {
request, err := http.NewRequest("GET", "/", strings.NewReader("Hello, world!"))
if err != nil {
t.Errorf("Unable to create new HTTP request. Error: %v", err)
}

request.Header.Add("X-Forwarded-For", "8.8.8.8,8.8.4.4")
request.Header.Add("X-Forwarded-For", "10.10.10.10,10.10.10.11")

ip := RemoteIPFromIPLookup(limiter.IPLookup{
Name: "X-Forwarded-For",
IndexFromRight: 0,
}, request)

// Should get the last header and the last IP
if ip != "10.10.10.11" {
t.Errorf("Did not get the right IP. IP: %v", ip)
}
}
func TestCanonicalizeIP(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -169,10 +168,12 @@ func TestCanonicalizeIP(t *testing.T) {
},
}
for _, tt := range tests {
tt := tt
ip := tt.ip
want := tt.want

t.Run(tt.name, func(t *testing.T) {
if got := CanonicalizeIP(tt.ip); got != tt.want {
t.Errorf("CanonicalizeIP() = %v, want %v", got, tt.want)
if got := CanonicalizeIP(ip); got != want {
t.Errorf("CanonicalizeIP() = %v, want %v", got, want)
}
})
}
Expand Down
Loading
Loading