some refactoring (sol_I_d)

This commit is contained in:
Egor Aristov 2025-02-08 23:01:16 +03:00
parent fcf4bc2231
commit af4c94ac24
4 changed files with 81 additions and 57 deletions

View File

@ -40,7 +40,7 @@ func main() {
} }
}() }()
cq, err := natsadapter.New(natsc, "RENDER_TASKS") na, err := natsadapter.New(natsc, "RENDER_TASKS")
if err != nil { if err != nil {
log.Panicf("create nats adapter: %v", err) log.Panicf("create nats adapter: %v", err)
} }
@ -51,7 +51,7 @@ func main() {
e.StaticFS("/", echo.MustSubFS(wizard_vue.EmbedFS, wizard_vue.FSPrefix)) e.StaticFS("/", echo.MustSubFS(wizard_vue.EmbedFS, wizard_vue.FSPrefix))
apiHandler := httpApi.New(cq) apiHandler := httpApi.New(na, na)
apiHandler.SetupRoutes(e.Group("/api/v1")) apiHandler.SetupRoutes(e.Group("/api/v1"))
go func() { go func() {

View File

@ -2,16 +2,19 @@ package adapters
import ( import (
"context" "context"
"fmt"
"time" "time"
) )
type CachedWorkQueue interface { type WorkQueue interface {
ProcessWorkCached( Enqueue(ctx context.Context, key string, payload []byte) (result []byte, err error)
ctx context.Context, }
cacheLifetime time.Duration,
cacheKey string, var ErrKeyNotFound = fmt.Errorf("key not found")
taskPayload []byte,
) (result []byte, err error) type Cache interface {
Get(key string) (result []byte, ts time.Time, err error)
Set(key string, payload []byte) (err error)
} }
type QueueConsumer interface { type QueueConsumer interface {

View File

@ -2,7 +2,9 @@ package natsadapter
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"github.com/egor3f/rssalchemy/internal/adapters"
"github.com/labstack/gommon/log" "github.com/labstack/gommon/log"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
"github.com/nats-io/nats.go/jetstream" "github.com/nats-io/nats.go/jetstream"
@ -56,65 +58,74 @@ func New(natsc *nats.Conn, streamName string) (*NatsAdapter, error) {
return &na, nil return &na, nil
} }
func (na *NatsAdapter) ProcessWorkCached( func (na *NatsAdapter) Enqueue(ctx context.Context, key string, payload []byte) ([]byte, error) {
ctx context.Context,
cacheLifetime time.Duration,
cacheKey string,
taskPayload []byte,
) (result []byte, err error) {
// prevent resubmitting already running task // prevent resubmitting already running task
na.runningMu.Lock() na.runningMu.Lock()
_, alreadyRunning := na.running[cacheKey] _, alreadyRunning := na.running[key]
na.running[cacheKey] = struct{}{} na.running[key] = struct{}{}
na.runningMu.Unlock() na.runningMu.Unlock()
defer func() { defer func() {
na.runningMu.Lock() na.runningMu.Lock()
delete(na.running, cacheKey) delete(na.running, key)
na.runningMu.Unlock() na.runningMu.Unlock()
}() }()
watcher, err := na.kv.Watch(ctx, cacheKey) watcher, err := na.kv.Watch(ctx, key)
if err != nil { if err != nil {
return nil, fmt.Errorf("cache watch failed: %w", err) return nil, fmt.Errorf("nats watch failed: %w", err)
} }
defer watcher.Stop() defer watcher.Stop()
var lastUpdate jetstream.KeyValueEntry var taskEnqueued bool
for { for {
select { select {
case upd := <-watcher.Updates(): case upd := <-watcher.Updates():
if upd != nil { if upd != nil {
lastUpdate = upd if !taskEnqueued {
if time.Since(upd.Created()) <= cacheLifetime { // old value from cache, skipping
log.Infof("using cached value for task: %s, payload=%.100s", cacheKey, lastUpdate.Value()) continue
return lastUpdate.Value(), nil
} }
} else { log.Infof("got value for task: %s, payload=%.100s", key, upd.Value())
return upd.Value(), nil
}
taskEnqueued = true
if alreadyRunning { if alreadyRunning {
log.Infof("already running: %s", cacheKey) log.Infof("already running: %s", key)
} else { continue
log.Infof("sending task to queue: %s", cacheKey) }
log.Infof("sending task to queue: %s", key)
_, err = na.jets.Publish( _, err = na.jets.Publish(
ctx, ctx,
fmt.Sprintf("%s.%s", na.streamName, cacheKey), fmt.Sprintf("%s.%s", na.streamName, key),
taskPayload, payload,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("nats publish error: %v", err) return nil, fmt.Errorf("nats publish error: %v", err)
} }
}
}
case <-ctx.Done(): case <-ctx.Done():
log.Warnf("task cancelled by context: %s", cacheKey) log.Warnf("task cancelled by context: %s", key)
// anyway, using cached lastUpdate
if lastUpdate != nil {
return lastUpdate.Value(), ctx.Err()
} else {
return nil, ctx.Err() return nil, ctx.Err()
} }
} }
} }
// Get returns the payload stored in the NATS key/value bucket under key,
// together with the entry's creation timestamp (used by callers to judge
// cache freshness).
// A missing key is translated from the jetstream-specific sentinel into the
// adapter-level adapters.ErrKeyNotFound, so callers compare with errors.Is
// against the adapters package and stay decoupled from nats types.
// NOTE(review): uses context.TODO() — presumably a ctx should be threaded
// through from the caller; confirm against the Cache interface design.
func (na *NatsAdapter) Get(key string) (result []byte, ts time.Time, err error) {
entry, err := na.kv.Get(context.TODO(), key)
if err != nil {
// map jetstream's not-found sentinel to the adapter-level error
if errors.Is(err, jetstream.ErrKeyNotFound) {
return nil, time.Time{}, adapters.ErrKeyNotFound
}
// any other KV failure is wrapped with the subsystem name for context
return nil, time.Time{}, fmt.Errorf("nats: %w", err)
}
return entry.Value(), entry.Created(), nil
}
func (na *NatsAdapter) Set(key string, payload []byte) error {
_, err := na.kv.Put(context.TODO(), key, payload)
if err != nil {
return fmt.Errorf("nats: %w", err)
}
return nil
} }
func (na *NatsAdapter) ConsumeQueue( func (na *NatsAdapter) ConsumeQueue(

View File

@ -6,6 +6,7 @@ import (
"context" "context"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"github.com/egor3f/rssalchemy/internal/adapters" "github.com/egor3f/rssalchemy/internal/adapters"
"github.com/egor3f/rssalchemy/internal/models" "github.com/egor3f/rssalchemy/internal/models"
@ -23,18 +24,22 @@ import (
) )
const ( const (
taskTimeout = 45 * time.Second taskTimeout = 1 * time.Minute
minLifetime = taskTimeout minLifetime = taskTimeout
maxLifetime = 24 * time.Hour maxLifetime = 24 * time.Hour
) )
type Handler struct { type Handler struct {
validate *validator.Validate validate *validator.Validate
CachedQueue adapters.CachedWorkQueue workQueue adapters.WorkQueue
cache adapters.Cache
} }
func New(cq adapters.CachedWorkQueue) *Handler { func New(wq adapters.WorkQueue, cache adapters.Cache) *Handler {
h := Handler{CachedQueue: cq} if wq == nil || cache == nil {
panic("you fckd up with di again")
}
h := Handler{workQueue: wq, cache: cache}
h.validate = validator.New(validator.WithRequiredStructEnabled()) h.validate = validator.New(validator.WithRequiredStructEnabled())
if err := h.validate.RegisterValidation("selector", validators.ValidateSelector); err != nil { if err := h.validate.RegisterValidation("selector", validators.ValidateSelector); err != nil {
log.Panicf("register validation: %v", err) log.Panicf("register validation: %v", err)
@ -100,9 +105,15 @@ func (h *Handler) handleRender(c echo.Context) error {
return echo.NewHTTPError(500, fmt.Errorf("task marshal error: %v", err)) return echo.NewHTTPError(500, fmt.Errorf("task marshal error: %v", err))
} }
taskResultBytes, err := h.CachedQueue.ProcessWorkCached(timeoutCtx, cacheLifetime, task.CacheKey(), encodedTask) taskResultBytes, cachedTS, err := h.cache.Get(task.CacheKey())
if err != nil && !errors.Is(err, adapters.ErrKeyNotFound) {
return echo.NewHTTPError(500, fmt.Errorf("cache failed: %v", err))
}
if errors.Is(err, adapters.ErrKeyNotFound) || time.Since(cachedTS) > cacheLifetime {
taskResultBytes, err = h.workQueue.Enqueue(timeoutCtx, task.CacheKey(), encodedTask)
if err != nil { if err != nil {
return echo.NewHTTPError(500, fmt.Errorf("queued cache failed: %v", err)) return echo.NewHTTPError(500, fmt.Errorf("task enqueue failed: %v", err))
}
} }
var result models.TaskResult var result models.TaskResult
@ -140,8 +151,7 @@ func (h *Handler) handlePageScreenshot(c echo.Context) error {
return echo.NewHTTPError(500, fmt.Errorf("task marshal error: %v", err)) return echo.NewHTTPError(500, fmt.Errorf("task marshal error: %v", err))
} }
cacheLifetime := minLifetime taskResultBytes, err := h.workQueue.Enqueue(timeoutCtx, task.CacheKey(), encodedTask)
taskResultBytes, err := h.CachedQueue.ProcessWorkCached(timeoutCtx, cacheLifetime, task.CacheKey(), encodedTask)
if err != nil { if err != nil {
return echo.NewHTTPError(500, fmt.Errorf("queued cache failed: %v", err)) return echo.NewHTTPError(500, fmt.Errorf("queued cache failed: %v", err))
} }