From d654ce7c8b0faa888e402b3aee9af9ac1358c4cb Mon Sep 17 00:00:00 2001 From: Egor Aristov Date: Mon, 10 Feb 2025 11:41:27 +0300 Subject: [PATCH] rate limits --- cmd/webserver/webserver.go | 8 ++++++- internal/api/http/handler.go | 46 ++++++++++++++++++++++++++++++++---- internal/config/config.go | 9 ++++++- 3 files changed, 56 insertions(+), 7 deletions(-) diff --git a/cmd/webserver/webserver.go b/cmd/webserver/webserver.go index cc5a449..56f23fd 100644 --- a/cmd/webserver/webserver.go +++ b/cmd/webserver/webserver.go @@ -10,6 +10,7 @@ import ( "github.com/labstack/echo/v4/middleware" "github.com/labstack/gommon/log" "github.com/nats-io/nats.go" + "golang.org/x/time/rate" "net/http" "os" "os/signal" @@ -51,7 +52,12 @@ func main() { e.StaticFS("/", echo.MustSubFS(wizard_vue.EmbedFS, wizard_vue.FSPrefix)) - apiHandler := httpApi.New(na, na) + apiHandler := httpApi.New( + na, + na, + rate.Every(time.Duration(float64(time.Second)*cfg.RateLimitEvery)), + cfg.RateLimitBurst, + ) apiHandler.SetupRoutes(e.Group("/api/v1")) go func() { diff --git a/internal/api/http/handler.go b/internal/api/http/handler.go index 0fff5a6..dd34ff3 100644 --- a/internal/api/http/handler.go +++ b/internal/api/http/handler.go @@ -15,11 +15,13 @@ import ( "github.com/gorilla/feeds" "github.com/labstack/echo/v4" "github.com/labstack/gommon/log" + "golang.org/x/time/rate" "html" "io" "net/url" "strconv" "strings" + "sync" "time" ) @@ -30,16 +32,26 @@ const ( ) type Handler struct { - validate *validator.Validate - workQueue adapters.WorkQueue - cache adapters.Cache + validate *validator.Validate + workQueue adapters.WorkQueue + cache adapters.Cache + rateLimit rate.Limit + rateLimitBurst int + limits map[string]*rate.Limiter + limitsMu sync.RWMutex } -func New(wq adapters.WorkQueue, cache adapters.Cache) *Handler { +func New(wq adapters.WorkQueue, cache adapters.Cache, rateLimit rate.Limit, rateLimitBurst int) *Handler { if wq == nil || cache == nil { panic("you fckd up with di again") } - h := Handler{workQueue: wq, cache: cache} + h := Handler{ + workQueue: wq, + cache: cache, + rateLimit: rateLimit, + rateLimitBurst: rateLimitBurst, + limits: make(map[string]*rate.Limiter), + } h.validate = validator.New(validator.WithRequiredStructEnabled()) if err := h.validate.RegisterValidation("selector", validators.ValidateSelector); err != nil { log.Panicf("register validation: %v", err) @@ -110,6 +122,9 @@ func (h *Handler) handleRender(c echo.Context) error { return echo.NewHTTPError(500, fmt.Errorf("cache failed: %v", err)) } if errors.Is(err, adapters.ErrKeyNotFound) || time.Since(cachedTS) > cacheLifetime { + if !h.checkRateLimit(c) { + return echo.ErrTooManyRequests + } taskResultBytes, err = h.workQueue.Enqueue(timeoutCtx, task.CacheKey(), encodedTask) if err != nil { return echo.NewHTTPError(500, fmt.Errorf("task enqueue failed: %v", err)) @@ -151,6 +166,10 @@ func (h *Handler) handlePageScreenshot(c echo.Context) error { return echo.NewHTTPError(500, fmt.Errorf("task marshal error: %v", err)) } + if !h.checkRateLimit(c) { + return echo.ErrTooManyRequests + } + taskResultBytes, err := h.workQueue.Enqueue(timeoutCtx, task.CacheKey(), encodedTask) if err != nil { return echo.NewHTTPError(500, fmt.Errorf("queued cache failed: %v", err)) @@ -163,6 +182,23 @@ func (h *Handler) handlePageScreenshot(c echo.Context) error { return c.Blob(200, "image/png", result.Image) } +func (h *Handler) checkRateLimit(c echo.Context) bool { + h.limitsMu.RLock() + limiter, ok := h.limits[c.RealIP()] + h.limitsMu.RUnlock() + if !ok { + h.limitsMu.Lock() + limiter, ok = h.limits[c.RealIP()] + if !ok { + limiter = rate.NewLimiter(h.rateLimit, h.rateLimitBurst) + h.limits[c.RealIP()] = limiter + } + h.limitsMu.Unlock() + } + log.Debugf("Rate limiter for ip=%s tokens=%f", c.RealIP(), limiter.Tokens()) + return limiter.Allow() +} + func (h *Handler) decodeSpecs(specsParam string) (Specs, error) { var err error version := 0 diff --git a/internal/config/config.go b/internal/config/config.go index f5bf732..d14e3db 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -10,10 +10,17 @@ import ( ) type Config struct { + // Format: host:port WebserverAddress string `env:"WEBSERVER_ADDRESS" env-default:"0.0.0.0:5000" validate:"hostname_port"` NatsUrl string `env:"NATS_URL" env-default:"nats://localhost:4222" validate:"url"` Debug bool `env:"DEBUG"` - Proxy string `env:"PROXY" env-default:"" validate:"omitempty,proxy"` + // Format: scheme://user:pass@host:port (supported schemes: http, https, socks) + Proxy string `env:"PROXY" env-default:"" validate:"omitempty,proxy"` + // RateLimitEvery and RateLimitBurst are parameters for Token Bucket algorithm. + // A token is added to the bucket every RateLimitEvery seconds. + // Rate limits don't apply to cache + RateLimitEvery float64 `env:"RATE_LIMIT_EVERY" env-default:"60" validate:"number,gt=0"` + RateLimitBurst int `env:"RATE_LIMIT_BURST" env-default:"10" validate:"number,gte=0"` } func Read() (Config, error) {