rate limits
This commit is contained in:
parent
caaf410a70
commit
d654ce7c8b
@ -10,6 +10,7 @@ import (
|
|||||||
"github.com/labstack/echo/v4/middleware"
|
"github.com/labstack/echo/v4/middleware"
|
||||||
"github.com/labstack/gommon/log"
|
"github.com/labstack/gommon/log"
|
||||||
"github.com/nats-io/nats.go"
|
"github.com/nats-io/nats.go"
|
||||||
|
"golang.org/x/time/rate"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
@ -51,7 +52,12 @@ func main() {
|
|||||||
|
|
||||||
e.StaticFS("/", echo.MustSubFS(wizard_vue.EmbedFS, wizard_vue.FSPrefix))
|
e.StaticFS("/", echo.MustSubFS(wizard_vue.EmbedFS, wizard_vue.FSPrefix))
|
||||||
|
|
||||||
apiHandler := httpApi.New(na, na)
|
apiHandler := httpApi.New(
|
||||||
|
na,
|
||||||
|
na,
|
||||||
|
rate.Every(time.Duration(float64(time.Second)*cfg.RateLimitEvery)),
|
||||||
|
cfg.RateLimitBurst,
|
||||||
|
)
|
||||||
apiHandler.SetupRoutes(e.Group("/api/v1"))
|
apiHandler.SetupRoutes(e.Group("/api/v1"))
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
|||||||
@ -15,11 +15,13 @@ import (
|
|||||||
"github.com/gorilla/feeds"
|
"github.com/gorilla/feeds"
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
"github.com/labstack/gommon/log"
|
"github.com/labstack/gommon/log"
|
||||||
|
"golang.org/x/time/rate"
|
||||||
"html"
|
"html"
|
||||||
"io"
|
"io"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -30,16 +32,26 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Handler struct {
|
type Handler struct {
|
||||||
validate *validator.Validate
|
validate *validator.Validate
|
||||||
workQueue adapters.WorkQueue
|
workQueue adapters.WorkQueue
|
||||||
cache adapters.Cache
|
cache adapters.Cache
|
||||||
|
rateLimit rate.Limit
|
||||||
|
rateLimitBurst int
|
||||||
|
limits map[string]*rate.Limiter
|
||||||
|
limitsMu sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(wq adapters.WorkQueue, cache adapters.Cache) *Handler {
|
func New(wq adapters.WorkQueue, cache adapters.Cache, rateLimit rate.Limit, rateLimitBurst int) *Handler {
|
||||||
if wq == nil || cache == nil {
|
if wq == nil || cache == nil {
|
||||||
panic("you fckd up with di again")
|
panic("you fckd up with di again")
|
||||||
}
|
}
|
||||||
h := Handler{workQueue: wq, cache: cache}
|
h := Handler{
|
||||||
|
workQueue: wq,
|
||||||
|
cache: cache,
|
||||||
|
rateLimit: rateLimit,
|
||||||
|
rateLimitBurst: rateLimitBurst,
|
||||||
|
limits: make(map[string]*rate.Limiter),
|
||||||
|
}
|
||||||
h.validate = validator.New(validator.WithRequiredStructEnabled())
|
h.validate = validator.New(validator.WithRequiredStructEnabled())
|
||||||
if err := h.validate.RegisterValidation("selector", validators.ValidateSelector); err != nil {
|
if err := h.validate.RegisterValidation("selector", validators.ValidateSelector); err != nil {
|
||||||
log.Panicf("register validation: %v", err)
|
log.Panicf("register validation: %v", err)
|
||||||
@ -110,6 +122,9 @@ func (h *Handler) handleRender(c echo.Context) error {
|
|||||||
return echo.NewHTTPError(500, fmt.Errorf("cache failed: %v", err))
|
return echo.NewHTTPError(500, fmt.Errorf("cache failed: %v", err))
|
||||||
}
|
}
|
||||||
if errors.Is(err, adapters.ErrKeyNotFound) || time.Since(cachedTS) > cacheLifetime {
|
if errors.Is(err, adapters.ErrKeyNotFound) || time.Since(cachedTS) > cacheLifetime {
|
||||||
|
if !h.checkRateLimit(c) {
|
||||||
|
return echo.ErrTooManyRequests
|
||||||
|
}
|
||||||
taskResultBytes, err = h.workQueue.Enqueue(timeoutCtx, task.CacheKey(), encodedTask)
|
taskResultBytes, err = h.workQueue.Enqueue(timeoutCtx, task.CacheKey(), encodedTask)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return echo.NewHTTPError(500, fmt.Errorf("task enqueue failed: %v", err))
|
return echo.NewHTTPError(500, fmt.Errorf("task enqueue failed: %v", err))
|
||||||
@ -151,6 +166,10 @@ func (h *Handler) handlePageScreenshot(c echo.Context) error {
|
|||||||
return echo.NewHTTPError(500, fmt.Errorf("task marshal error: %v", err))
|
return echo.NewHTTPError(500, fmt.Errorf("task marshal error: %v", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !h.checkRateLimit(c) {
|
||||||
|
return echo.ErrTooManyRequests
|
||||||
|
}
|
||||||
|
|
||||||
taskResultBytes, err := h.workQueue.Enqueue(timeoutCtx, task.CacheKey(), encodedTask)
|
taskResultBytes, err := h.workQueue.Enqueue(timeoutCtx, task.CacheKey(), encodedTask)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return echo.NewHTTPError(500, fmt.Errorf("queued cache failed: %v", err))
|
return echo.NewHTTPError(500, fmt.Errorf("queued cache failed: %v", err))
|
||||||
@ -163,6 +182,23 @@ func (h *Handler) handlePageScreenshot(c echo.Context) error {
|
|||||||
return c.Blob(200, "image/png", result.Image)
|
return c.Blob(200, "image/png", result.Image)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *Handler) checkRateLimit(c echo.Context) bool {
|
||||||
|
h.limitsMu.RLock()
|
||||||
|
limiter, ok := h.limits[c.RealIP()]
|
||||||
|
h.limitsMu.RUnlock()
|
||||||
|
if !ok {
|
||||||
|
h.limitsMu.Lock()
|
||||||
|
limiter, ok = h.limits[c.RealIP()]
|
||||||
|
if !ok {
|
||||||
|
limiter = rate.NewLimiter(h.rateLimit, h.rateLimitBurst)
|
||||||
|
h.limits[c.RealIP()] = limiter
|
||||||
|
}
|
||||||
|
h.limitsMu.Unlock()
|
||||||
|
}
|
||||||
|
log.Debugf("Rate limiter for ip=%s tokens=%f", c.RealIP(), limiter.Tokens())
|
||||||
|
return limiter.Allow()
|
||||||
|
}
|
||||||
|
|
||||||
func (h *Handler) decodeSpecs(specsParam string) (Specs, error) {
|
func (h *Handler) decodeSpecs(specsParam string) (Specs, error) {
|
||||||
var err error
|
var err error
|
||||||
version := 0
|
version := 0
|
||||||
|
|||||||
@ -10,10 +10,17 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
|
// Format: host:port
|
||||||
WebserverAddress string `env:"WEBSERVER_ADDRESS" env-default:"0.0.0.0:5000" validate:"hostname_port"`
|
WebserverAddress string `env:"WEBSERVER_ADDRESS" env-default:"0.0.0.0:5000" validate:"hostname_port"`
|
||||||
NatsUrl string `env:"NATS_URL" env-default:"nats://localhost:4222" validate:"url"`
|
NatsUrl string `env:"NATS_URL" env-default:"nats://localhost:4222" validate:"url"`
|
||||||
Debug bool `env:"DEBUG"`
|
Debug bool `env:"DEBUG"`
|
||||||
Proxy string `env:"PROXY" env-default:"" validate:"omitempty,proxy"`
|
// Format: scheme://user:pass@host:port (supported schemes: http, https, socks)
|
||||||
|
Proxy string `env:"PROXY" env-default:"" validate:"omitempty,proxy"`
|
||||||
|
// RateLimitEvery and RateLimitBurst are parameters for Token Bucket algorithm.
|
||||||
|
// A token is added to the bucket every RateLimitEvery seconds.
|
||||||
|
// Rate limits don't apply to cache
|
||||||
|
RateLimitEvery float64 `env:"RATE_LIMIT_EVERY" env-default:"60" validate:"number,gt=0"`
|
||||||
|
RateLimitBurst int `env:"RATE_LIMIT_BURST" env-default:"10" validate:"number,gte=0"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func Read() (Config, error) {
|
func Read() (Config, error) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user