rate limits

Egor Aristov 2025-02-10 11:41:27 +03:00
parent caaf410a70
commit d654ce7c8b
Signed by: egor3f
GPG Key ID: 40482A264AAEC85F
3 changed files with 56 additions and 7 deletions

View File

@@ -10,6 +10,7 @@ import (
 	"github.com/labstack/echo/v4/middleware"
 	"github.com/labstack/gommon/log"
 	"github.com/nats-io/nats.go"
+	"golang.org/x/time/rate"
 	"net/http"
 	"os"
 	"os/signal"
@@ -51,7 +52,12 @@ func main() {
 	e.StaticFS("/", echo.MustSubFS(wizard_vue.EmbedFS, wizard_vue.FSPrefix))
-	apiHandler := httpApi.New(na, na)
+	apiHandler := httpApi.New(
+		na,
+		na,
+		rate.Every(time.Duration(float64(time.Second)*cfg.RateLimitEvery)),
+		cfg.RateLimitBurst,
+	)
 	apiHandler.SetupRoutes(e.Group("/api/v1"))
 	go func() {

View File

@@ -15,11 +15,13 @@ import (
 	"github.com/gorilla/feeds"
 	"github.com/labstack/echo/v4"
 	"github.com/labstack/gommon/log"
+	"golang.org/x/time/rate"
 	"html"
 	"io"
 	"net/url"
 	"strconv"
 	"strings"
+	"sync"
 	"time"
 )
@@ -30,16 +32,26 @@ const (
 )
 
 type Handler struct {
 	validate  *validator.Validate
 	workQueue adapters.WorkQueue
 	cache     adapters.Cache
+	rateLimit      rate.Limit
+	rateLimitBurst int
+	limits         map[string]*rate.Limiter
+	limitsMu       sync.RWMutex
 }
 
-func New(wq adapters.WorkQueue, cache adapters.Cache) *Handler {
+func New(wq adapters.WorkQueue, cache adapters.Cache, rateLimit rate.Limit, rateLimitBurst int) *Handler {
 	if wq == nil || cache == nil {
 		panic("you fckd up with di again")
 	}
-	h := Handler{workQueue: wq, cache: cache}
+	h := Handler{
+		workQueue:      wq,
+		cache:          cache,
+		rateLimit:      rateLimit,
+		rateLimitBurst: rateLimitBurst,
+		limits:         make(map[string]*rate.Limiter),
+	}
 	h.validate = validator.New(validator.WithRequiredStructEnabled())
 	if err := h.validate.RegisterValidation("selector", validators.ValidateSelector); err != nil {
 		log.Panicf("register validation: %v", err)
@@ -110,6 +122,9 @@ func (h *Handler) handleRender(c echo.Context) error {
 		return echo.NewHTTPError(500, fmt.Errorf("cache failed: %v", err))
 	}
 	if errors.Is(err, adapters.ErrKeyNotFound) || time.Since(cachedTS) > cacheLifetime {
+		if !h.checkRateLimit(c) {
+			return echo.ErrTooManyRequests
+		}
 		taskResultBytes, err = h.workQueue.Enqueue(timeoutCtx, task.CacheKey(), encodedTask)
 		if err != nil {
 			return echo.NewHTTPError(500, fmt.Errorf("task enqueue failed: %v", err))
@@ -151,6 +166,10 @@ func (h *Handler) handlePageScreenshot(c echo.Context) error {
 		return echo.NewHTTPError(500, fmt.Errorf("task marshal error: %v", err))
 	}
 
+	if !h.checkRateLimit(c) {
+		return echo.ErrTooManyRequests
+	}
+
 	taskResultBytes, err := h.workQueue.Enqueue(timeoutCtx, task.CacheKey(), encodedTask)
 	if err != nil {
 		return echo.NewHTTPError(500, fmt.Errorf("queued cache failed: %v", err))
@@ -163,6 +182,23 @@ func (h *Handler) handlePageScreenshot(c echo.Context) error {
 	return c.Blob(200, "image/png", result.Image)
 }
 
+func (h *Handler) checkRateLimit(c echo.Context) bool {
+	h.limitsMu.RLock()
+	limiter, ok := h.limits[c.RealIP()]
+	h.limitsMu.RUnlock()
+	if !ok {
+		h.limitsMu.Lock()
+		limiter, ok = h.limits[c.RealIP()]
+		if !ok {
+			limiter = rate.NewLimiter(h.rateLimit, h.rateLimitBurst)
+			h.limits[c.RealIP()] = limiter
+		}
+		h.limitsMu.Unlock()
+	}
+	log.Debugf("Rate limiter for ip=%s tokens=%f", c.RealIP(), limiter.Tokens())
+	return limiter.Allow()
+}
+
 func (h *Handler) decodeSpecs(specsParam string) (Specs, error) {
 	var err error
 	version := 0
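
The new checkRateLimit helper keeps one rate.Limiter per client IP (keyed by c.RealIP()) in a map guarded by an RWMutex, and re-checks the map after taking the write lock so concurrent first requests from the same IP don't create duplicate limiters. As a standalone illustration of how such a token-bucket limiter behaves (a minimal sketch with made-up numbers, not code from this commit):

package main

import (
	"fmt"
	"time"

	"golang.org/x/time/rate"
)

func main() {
	// Burst of 3 tokens, refilled at one token per minute, mirroring the
	// per-IP limiters that checkRateLimit builds from h.rateLimit / h.rateLimitBurst.
	limiter := rate.NewLimiter(rate.Every(time.Minute), 3)

	for i := 1; i <= 5; i++ {
		// Allow reports whether a token is available right now: the first
		// 3 calls succeed, later ones fail until the bucket refills.
		fmt.Printf("request %d allowed=%v tokens=%.2f\n", i, limiter.Allow(), limiter.Tokens())
	}
}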

View File

@@ -10,10 +10,17 @@ import (
 )
 
 type Config struct {
+	// Format: host:port
 	WebserverAddress string `env:"WEBSERVER_ADDRESS" env-default:"0.0.0.0:5000" validate:"hostname_port"`
 	NatsUrl          string `env:"NATS_URL" env-default:"nats://localhost:4222" validate:"url"`
 	Debug            bool   `env:"DEBUG"`
+	// Format: scheme://user:pass@host:port (supported schemes: http, https, socks)
 	Proxy            string `env:"PROXY" env-default:"" validate:"omitempty,proxy"`
+	// RateLimitEvery and RateLimitBurst are parameters for the token bucket algorithm.
+	// A token is added to the bucket every RateLimitEvery seconds.
+	// Rate limits don't apply to cached responses.
+	RateLimitEvery float64 `env:"RATE_LIMIT_EVERY" env-default:"60" validate:"number,gt=0"`
+	RateLimitBurst int     `env:"RATE_LIMIT_BURST" env-default:"10" validate:"number,gte=0"`
 }
 
 func Read() (Config, error) {
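
With the defaults above (RATE_LIMIT_EVERY=60, RATE_LIMIT_BURST=10), each client IP can enqueue up to 10 uncached tasks back to back and then roughly one more per minute; cache hits are served without consuming a token. Below is a minimal sketch of how these values become a limiter, following the same conversion used in main.go (the literal values here are just the defaults, not code from this commit):

package main

import (
	"fmt"
	"time"

	"golang.org/x/time/rate"
)

func main() {
	// Defaults from the Config struct above.
	rateLimitEvery := 60.0 // seconds between token refills
	rateLimitBurst := 10   // bucket capacity

	// Same conversion as in main.go: float seconds -> time.Duration -> rate.Limit.
	limit := rate.Every(time.Duration(float64(time.Second) * rateLimitEvery))
	limiter := rate.NewLimiter(limit, rateLimitBurst)

	// The bucket starts full, so the first 10 Allow calls would succeed immediately.
	fmt.Printf("limit=%g events/s, burst=%d, tokens now=%.0f\n",
		float64(limit), rateLimitBurst, limiter.Tokens())
}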