diff --git a/api/v1/api.go b/api/v1/api.go index 5a63a567..c444147f 100644 --- a/api/v1/api.go +++ b/api/v1/api.go @@ -92,7 +92,7 @@ func NewAPIV1( // @securityDefinitions.Bearer.name Authorization func (s *apiV1) RegisterRoutes(e *echo.Echo) { - e.Use(middleware.RateLimiterWithConfig(configureRateLimiter(s.cfg.RateLimit))) + e.Use(middleware.RateLimiterWithConfig(util.ConfigureRateLimiter(s.cfg.RateLimit))) e.POST("/register", s.handleRegisterUser) e.POST("/login", s.handleLoginUser) e.GET("/health", s.handleHealth) diff --git a/api/v1/handlers.go b/api/v1/handlers.go index a560dfe7..5e1e5740 100644 --- a/api/v1/handlers.go +++ b/api/v1/handlers.go @@ -7,8 +7,6 @@ import ( "encoding/hex" "encoding/json" "fmt" - "github.com/labstack/echo/v4/middleware" - "golang.org/x/time/rate" "io" "io/ioutil" "math/rand" @@ -121,25 +119,6 @@ func withUser(f func(echo.Context, *util.User) error) func(echo.Context) error { } } -func configureRateLimiter(rateLimit rate.Limit) middleware.RateLimiterConfig { - config := middleware.DefaultRateLimiterConfig - - config.IdentifierExtractor = func(ctx echo.Context) (string, error) { - var id string - user, ok := ctx.Get("user").(*util.User) - if !ok { - id = ctx.RealIP() - } else { - id = user.UUID - } - return id, nil - } - - config.Store = middleware.NewRateLimiterMemoryStore(rateLimit) - - return config -} - // handleStats godoc // @Summary Get content statistics // @Description This endpoint is used to get content statistics. Every content stored in the network (estuary) is tracked by a unique ID which can be used to get information about the content. This endpoint will allow the consumer to get the collected stats of a content diff --git a/cmd/estuary-shuttle/main.go b/cmd/estuary-shuttle/main.go index 949493c4..d8aae378 100644 --- a/cmd/estuary-shuttle/main.go +++ b/cmd/estuary-shuttle/main.go @@ -6,6 +6,7 @@ import ( "encoding/json" "flag" "fmt" + "golang.org/x/time/rate" "io" "io/ioutil" "net/http" @@ -86,7 +87,7 @@ const ( ColDir = "dir" ) -//#nosec G104 - it's not common to treat SetLogLevel error return +// #nosec G104 - it's not common to treat SetLogLevel error return func before(cctx *cli.Context) error { level := util.LogLevel @@ -196,6 +197,8 @@ func overrideSetOptions(flags []cli.Flag, cctx *cli.Context, cfg *config.Shuttle cfg.RpcEngine.Queue.Enabled = cctx.Bool("queue-eng-enabled") case "queue-eng-consumers": cfg.RpcEngine.Queue.Consumers = cctx.Int("queue-eng-consumers") + case "rate-limit": + cfg.RateLimit = rate.Limit(cctx.Float64("rate-limit")) default: } } @@ -325,7 +328,7 @@ func main() { Value: cfg.Dev, }, &cli.StringSliceFlag{ - Name: "announce-addr", + Name: "announce-addr", Usage: "specify multiaddrs that this node can be connected to ", Value: cli.NewStringSlice(cfg.Node.AnnounceAddrs...), }, @@ -1154,6 +1157,8 @@ func (s *Shuttle) ServeAPI() error { e.Use(middleware.Logger()) } + e.Use(middleware.RateLimiterWithConfig(util.ConfigureRateLimiter(s.shuttleConfig.RateLimit))) + e.Use(s.tracingMiddleware) e.Use(util.AppVersionMiddleware(s.shuttleConfig.AppVersion)) e.HTTPErrorHandler = util.ErrorHandler diff --git a/config/shuttle.go b/config/shuttle.go index 509b508f..3da866e8 100644 --- a/config/shuttle.go +++ b/config/shuttle.go @@ -2,6 +2,7 @@ package config import ( "errors" + "golang.org/x/time/rate" "path/filepath" rcmgr "github.com/libp2p/go-libp2p/p2p/host/resource-manager" @@ -28,6 +29,7 @@ type Shuttle struct { Private bool `json:"private"` Dev bool `json:"dev"` NoReloadPinQueue bool `json:"no_reload_pin_queue"` + RateLimit rate.Limit `json:"rate_limit"` Node Node `json:"node"` Jaeger Jaeger `json:"jaeger"` Content Content `json:"content"` @@ -82,6 +84,7 @@ func NewShuttle(appVersion string) *Shuttle { Private: false, Dev: false, NoReloadPinQueue: false, + RateLimit: rate.Limit(20), Content: Content{ DisableLocalAdding: false, diff --git a/util/rate_limiter.go b/util/rate_limiter.go new file mode 100644 index 00000000..dcfbbba3 --- /dev/null +++ b/util/rate_limiter.go @@ -0,0 +1,26 @@ +package util + +import ( + "github.com/labstack/echo/v4" + "github.com/labstack/echo/v4/middleware" + "golang.org/x/time/rate" +) + +func ConfigureRateLimiter(rateLimit rate.Limit) middleware.RateLimiterConfig { + config := middleware.DefaultRateLimiterConfig + + config.IdentifierExtractor = func(ctx echo.Context) (string, error) { + var id string + user, ok := ctx.Get("user").(User) + if !ok { + id = ctx.RealIP() + } else { + id = user.UUID + } + return id, nil + } + + config.Store = middleware.NewRateLimiterMemoryStore(rateLimit) + + return config +}