From 5339ff2c0dd64a014f0c9199bf2783f436b86ffe Mon Sep 17 00:00:00 2001 From: LTLA Date: Tue, 17 Sep 2024 09:34:05 -0700 Subject: [PATCH] Only process request files if they are no less than a minute old. This improves the protection against replay attacks by ruling out old request files. Only new request files are potentially dangerous, and these are handled by locks on the active request registry. Also moved all of the request-related functionality into their own file for easier development and testing. --- main.go | 107 ++++----------------------- request.go | 123 ++++++++++++++++++++++++++++++++ main_test.go => request_test.go | 66 ++++++++++++++--- 3 files changed, 192 insertions(+), 104 deletions(-) create mode 100644 request.go rename main_test.go => request_test.go (58%) diff --git a/main.go b/main.go index cd4b951..796ee49 100644 --- a/main.go +++ b/main.go @@ -8,13 +8,9 @@ import ( "os" "errors" "strings" - "fmt" "encoding/json" "net/http" "strconv" - "io/fs" - "syscall" - "sync" ) func dumpJsonResponse(w http.ResponseWriter, status int, v interface{}, path string) { @@ -59,94 +55,6 @@ func configureCors(w http.ResponseWriter, r *http.Request) bool { /***************************************************/ -func checkRequestFile(path, staging string) (string, error) { - if !strings.HasPrefix(path, "request-") { - return "", newHttpError(http.StatusBadRequest, errors.New("file name should start with \"request-\"")) - } - - if !filepath.IsLocal(path) { - return "", newHttpError(http.StatusBadRequest, errors.New("path should be local to the staging directory")) - } - reqpath := filepath.Join(staging, path) - - info, err := os.Lstat(reqpath) - if err != nil { - return "", newHttpError(http.StatusBadRequest, fmt.Errorf("failed to access path; %v", err)) - } - - if info.IsDir() { - return "", newHttpError(http.StatusBadRequest, errors.New("path is a directory")) - } - - if info.Mode() & fs.ModeSymlink != 0 { - return "", newHttpError(http.StatusBadRequest, errors.New("path is a symbolic link")) - } - - s, ok := info.Sys().(*syscall.Stat_t) - if !ok { - return "", fmt.Errorf("failed to convert to a syscall.Stat_t; %w", err) - } - if uint32(s.Nlink) > 1 { - return "", newHttpError(http.StatusBadRequest, errors.New("path seems to have multiple hard links")) - } - - return reqpath, nil -} - -/***************************************************/ - -// This tracks the requests that are currently being processed, to prevent the -// same request being processed multiple times at the same time. We use a -// multi-pool approach to improve parallelism across requests. -type activeRegistry struct { - NumPools int - Locks []sync.Mutex - Active []map[string]bool -} - -func newActiveRegistry(num_pools int) *activeRegistry { - return &activeRegistry { - NumPools: num_pools, - Locks: make([]sync.Mutex, num_pools), - Active: make([]map[string]bool, num_pools), - } -} - -func (a *activeRegistry) choosePool(path string) int { - sum := 0 - for _, r := range path { - sum += int(r) - } - return sum % a.NumPools -} - -func (a *activeRegistry) Add(path string) bool { - i := a.choosePool(path) - a.Locks[i].Lock() - defer a.Locks[i].Unlock() - - if a.Active[i] == nil { - a.Active[i] = map[string]bool{} - } else { - _, ok := a.Active[i][path] - if ok { - return false - } - } - - a.Active[i][path] = true - return true -} - -func (a *activeRegistry) Remove(path string) { - i := a.choosePool(path) - a.Locks[i].Lock() - defer a.Locks[i].Unlock() - delete(a.Active[i], path) -} - -/***************************************************/ - func main() { spath := flag.String("staging", "", "Path to the staging directory") rpath := flag.String("registry", "", "Path to the registry") @@ -174,7 +82,12 @@ func main() { } } - actreg := newActiveRegistry(11) + actreg := newActiveRequestRegistry(11) + const request_expiry = time.Minute + err := prefillActiveRequestRegistry(actreg, staging, request_expiry) + if err != nil { + log.Fatalf("failed to prefill active request registry; %v", err) + } endpt_prefix := *prefix if endpt_prefix != "" { @@ -186,7 +99,7 @@ func main() { path := r.PathValue("path") log.Println("processing " + path) - reqpath, err := checkRequestFile(path, staging) + reqpath, err := checkRequestFile(path, staging, request_expiry) if err != nil { dumpHttpErrorResponse(w, err, path) return @@ -244,10 +157,14 @@ func main() { // Purge the request file once it's processed, to reduce the potential // for replay attacks. For safety's sake, we only remove it from the - // registry if the request file was properly deleted. + // registry if the request file was properly deleted or it expired. err = os.Remove(reqpath) if err != nil { log.Printf("failed to purge the request file at %q; %v", path, err) + go func() { + time.Sleep(request_expiry) + actreg.Remove(path) + }() } else { actreg.Remove(path) } diff --git a/request.go b/request.go new file mode 100644 index 0000000..77c37af --- /dev/null +++ b/request.go @@ -0,0 +1,123 @@ +package main + +import ( + "sync" + "strings" + "path/filepath" + "net/http" + "time" + "os" + "fmt" + "errors" + "io/fs" + "syscall" +) + +func chooseLockPool(path string, num_pools int) int { + sum := 0 + for _, r := range path { + sum += int(r) + } + return sum % num_pools +} + +// This tracks the requests that are currently being processed, to prevent the +// same request being processed multiple times at the same time. We use a +// multi-pool approach to improve parallelism across requests. +type activeRequestRegistry struct { + NumPools int + Locks []sync.Mutex + Active []map[string]bool +} + +func newActiveRequestRegistry(num_pools int) *activeRequestRegistry { + return &activeRequestRegistry { + NumPools: num_pools, + Locks: make([]sync.Mutex, num_pools), + Active: make([]map[string]bool, num_pools), + } +} + +func prefillActiveRequestRegistry(a *activeRequestRegistry, staging string, expiry time.Duration) error { + // Prefilling the registry ensures that a user can't replay requests after a restart of the service. + entries, err := os.ReadDir(staging) + if err != nil { + return fmt.Errorf("failed to list existing request files in '%s'", staging) + } + + // This is only necessary until the expiry time is exceeded, after which we can evict those entries. + // Technically we only need to do this for files that weren't already expired, but this doesn't hurt. + for _, e := range entries { + path := e.Name() + a.Add(path) + go func(p string) { + time.Sleep(expiry) + a.Remove(p) + }(path) + } + return nil +} + +func (a *activeRequestRegistry) Add(path string) bool { + i := chooseLockPool(path, a.NumPools) + a.Locks[i].Lock() + defer a.Locks[i].Unlock() + + if a.Active[i] == nil { + a.Active[i] = map[string]bool{} + } else { + _, ok := a.Active[i][path] + if ok { + return false + } + } + + a.Active[i][path] = true + return true +} + +func (a *activeRequestRegistry) Remove(path string) { + i := chooseLockPool(path, a.NumPools) + a.Locks[i].Lock() + defer a.Locks[i].Unlock() + delete(a.Active[i], path) +} + +func checkRequestFile(path, staging string, expiry time.Duration) (string, error) { + if !strings.HasPrefix(path, "request-") { + return "", newHttpError(http.StatusBadRequest, errors.New("file name should start with \"request-\"")) + } + + if !filepath.IsLocal(path) { + return "", newHttpError(http.StatusBadRequest, errors.New("path should be local to the staging directory")) + } + reqpath := filepath.Join(staging, path) + + info, err := os.Lstat(reqpath) + if err != nil { + return "", newHttpError(http.StatusBadRequest, fmt.Errorf("failed to access path; %v", err)) + } + + if info.IsDir() { + return "", newHttpError(http.StatusBadRequest, errors.New("path is a directory")) + } + + if info.Mode() & fs.ModeSymlink != 0 { + return "", newHttpError(http.StatusBadRequest, errors.New("path is a symbolic link")) + } + + s, ok := info.Sys().(*syscall.Stat_t) + if !ok { + return "", fmt.Errorf("failed to convert to a syscall.Stat_t; %w", err) + } + if uint32(s.Nlink) > 1 { + return "", newHttpError(http.StatusBadRequest, errors.New("path seems to have multiple hard links")) + } + + current := time.Now() + if current.Sub(info.ModTime()) >= expiry { + return "", newHttpError(http.StatusBadRequest, errors.New("request file is expired")) + } + + return reqpath, nil +} diff --git a/main_test.go b/request_test.go similarity index 58% rename from main_test.go rename to request_test.go index bb2c8b6..e99cecc 100644 --- a/main_test.go +++ b/request_test.go @@ -5,6 +5,7 @@ import ( "strings" "os" "path/filepath" + "time" ) func TestCheckRequestFile(t *testing.T) { @@ -19,7 +20,7 @@ func TestCheckRequestFile(t *testing.T) { } t.Run("success", func(t *testing.T) { - out, err := checkRequestFile("request-foo", staging) + out, err := checkRequestFile("request-foo", staging, time.Minute) if err != nil { t.Fatal(err) } @@ -29,21 +30,21 @@ func TestCheckRequestFile(t *testing.T) { }) t.Run("name failure", func(t *testing.T) { - _, err := checkRequestFile("foo", staging) + _, err := checkRequestFile("foo", staging, time.Minute) if err == nil || !strings.Contains(err.Error(), "request-") { t.Fatal("should have failed") } }) t.Run("locality failure", func(t *testing.T) { - _, err := checkRequestFile("request-blah/../../foo", staging) + _, err := checkRequestFile("request-blah/../../foo", staging, time.Minute) if err == nil || !strings.Contains(err.Error(), "local") { t.Fatal("should have failed") } }) t.Run("not present", func(t *testing.T) { - _, err := checkRequestFile("request-blah", staging) + _, err := checkRequestFile("request-blah", staging, time.Minute) if err == nil || !strings.Contains(err.Error(), "failed to access") { t.Fatal("should have failed") } @@ -55,7 +56,7 @@ func TestCheckRequestFile(t *testing.T) { } t.Run("directory", func(t *testing.T) { - _, err := checkRequestFile("request-blah", staging) + _, err := checkRequestFile("request-blah", staging, time.Minute) if err == nil || !strings.Contains(err.Error(), "directory") { t.Fatal("should have failed") } @@ -67,7 +68,7 @@ func TestCheckRequestFile(t *testing.T) { } t.Run("symlink", func(t *testing.T) { - _, err := checkRequestFile("request-symlink", staging) + _, err := checkRequestFile("request-symlink", staging, time.Minute) if err == nil || !strings.Contains(err.Error(), "symbolic link") { t.Fatal("should have failed") } @@ -79,15 +80,28 @@ func TestCheckRequestFile(t *testing.T) { } t.Run("hard link", func(t *testing.T) { - _, err := checkRequestFile("request-hardlink", staging) + _, err := checkRequestFile("request-hardlink", staging, time.Minute) if err == nil || !strings.Contains(err.Error(), "hard link") { t.Fatal("should have failed") } }) + + err = os.Remove(filepath.Join(staging, "request-hardlink")) // removing the hardlink to test the rest. + if err != nil { + t.Fatal(err) + } + + t.Run("expired", func(t *testing.T) { + time.Sleep(time.Millisecond) + _, err := checkRequestFile("request-foo", staging, 0) + if err == nil || !strings.Contains(err.Error(), "expired") { + t.Fatal("should have failed") + } + }) } -func TestActiveRegistry(t *testing.T) { - a := newActiveRegistry(3) +func TestActiveRequestRegistry(t *testing.T) { + a := newActiveRequestRegistry(3) path := "adasdasdasd" ok := a.Add(path) @@ -111,3 +125,37 @@ func TestActiveRegistry(t *testing.T) { t.Fatal("expected a successful addition again") } } + +func TestPrefillActiveRequestRegistry(t *testing.T) { + staging, err := os.MkdirTemp("", "") + if err != nil { + t.Fatal(err) + } + + names := []string{ "foo", "bar", "whee" } + for _, f := range names { + err = os.WriteFile(filepath.Join(staging, f), []byte{}, 0644) + if err != nil { + t.Fatal(err) + } + } + + a := newActiveRequestRegistry(3) + err = prefillActiveRequestRegistry(a, staging, time.Millisecond * 100) + if err != nil { + t.Fatal(err) + } + + for _, f := range names { + if a.Add(f) { + t.Fatalf("%s should already be present in the registry", f) + } + } + + time.Sleep(time.Millisecond * 200) + for _, f := range names { + if !a.Add(f) { + t.Fatalf("%s should have been removed from the registry", f) + } + } +}