From f5ced3b7175e2e8dc5272ebfc9e256b1f471a346 Mon Sep 17 00:00:00 2001 From: Stavros Date: Sat, 13 Jun 2026 16:41:00 +0300 Subject: [PATCH] refactor: use cache store from tinyauth --- cache.go | 85 --------------------- cache_store.go | 197 ++++++++++++++++++++++++++++++++++++++++++++++++ main.go | 3 +- rate_limiter.go | 82 ++++++++++---------- 4 files changed, 241 insertions(+), 126 deletions(-) delete mode 100644 cache.go create mode 100644 cache_store.go diff --git a/cache.go b/cache.go deleted file mode 100644 index 8e1864e..0000000 --- a/cache.go +++ /dev/null @@ -1,85 +0,0 @@ -package main - -import ( - "sync" - "time" -) - -type cacheField struct { - value any - expire int64 -} - -type Cache struct { - cache map[string]cacheField - mutex sync.RWMutex -} - -func NewCache() *Cache { - cache := &Cache{ - cache: make(map[string]cacheField), - } - cache.cleanup() - return cache -} - -func (c *Cache) Set(key string, value any, ttl int64) { - c.mutex.Lock() - defer c.mutex.Unlock() - - expire := time.Now().Add(time.Duration(ttl) * time.Second).Unix() - - c.cache[key] = cacheField{ - value: value, - expire: expire, - } -} - -func (c *Cache) Get(key string) (any, bool) { - c.mutex.RLock() - - field, ok := c.cache[key] - - if !ok { - c.mutex.RUnlock() - return nil, false - } - - if time.Now().Unix() > field.expire { - c.mutex.RUnlock() - c.Delete(key) - return nil, false - } - - c.mutex.RUnlock() - return field.value, true -} - -func (c *Cache) Delete(key string) { - c.mutex.Lock() - defer c.mutex.Unlock() - delete(c.cache, key) -} - -func (c *Cache) Flush() { - c.mutex.Lock() - defer c.mutex.Unlock() - c.cache = make(map[string]cacheField, 0) -} - -func (c *Cache) cleanup() { - go func() { - ticker := time.NewTicker(24 * time.Hour) - defer ticker.Stop() - - for range ticker.C { - c.mutex.Lock() - for key, field := range c.cache { - if time.Now().Unix() > field.expire { - delete(c.cache, key) - } - } - c.mutex.Unlock() - } - }() -} diff --git a/cache_store.go b/cache_store.go new file mode 100644 index 0000000..bf16239 --- /dev/null +++ b/cache_store.go @@ -0,0 +1,197 @@ +package main + +import ( + "slices" + "sync" + "time" +) + +type CacheStoreActions[T any] struct { + Set func(key string, value T, ttl time.Duration) + Get func(key string) (T, bool) + Delete func(key string) + Update func(key string, value T, ttl time.Duration) bool +} + +type cacheEntry[T any] struct { + value T + expiresAt *time.Time +} + +type CacheStore[T any] struct { + cache map[string]cacheEntry[T] + order []string + mu sync.RWMutex + maxSize int +} + +func NewCacheStore[T any](maxSize int) *CacheStore[T] { + return &CacheStore[T]{ + cache: make(map[string]cacheEntry[T]), + order: make([]string, 0), + maxSize: maxSize, + } +} + +// With lock allows performing multiple operations on the cache store atomically. +// The provided mutate function receives a set of actions (Set, Get, Delete) that +// can be used to manipulate the cache store within the locked context. +func (cs *CacheStore[T]) WithLock(mutate func(actions CacheStoreActions[T])) { + cs.mu.Lock() + defer cs.mu.Unlock() + actions := CacheStoreActions[T]{ + Set: cs.setCallback, + Get: cs.getCallback, + Delete: cs.deleteCallback, + Update: cs.updateCallback, + } + mutate(actions) +} + +func (cs *CacheStore[T]) updateCallback(key string, value T, ttl time.Duration) bool { + if currentEntry, exists := cs.cache[key]; exists { + if currentEntry.expiresAt != nil && time.Now().After(*currentEntry.expiresAt) { + return false + } + + entry := cacheEntry[T]{ + value: value, + expiresAt: currentEntry.expiresAt, + } + + if ttl > 0 { + expiration := time.Now().Add(ttl) + entry.expiresAt = &expiration + } + + cs.cache[key] = entry + + return true + } + + return false +} + +func (cs *CacheStore[T]) Update(key string, value T, ttl time.Duration) bool { + cs.mu.Lock() + defer cs.mu.Unlock() + return cs.updateCallback(key, value, ttl) +} + +func (cs *CacheStore[T]) setCallback(key string, value T, ttl time.Duration) { + if cs.maxSize > 0 { + if _, exists := cs.cache[key]; !exists && len(cs.cache) >= cs.maxSize { + cs.evictOne() + } + } + + var expiresAt *time.Time + + if ttl > 0 { + expiration := time.Now().Add(ttl) + expiresAt = &expiration + } + + cs.cache[key] = cacheEntry[T]{ + value: value, + expiresAt: expiresAt, + } + + if !slices.Contains(cs.order, key) { + cs.order = append(cs.order, key) + } +} + +func (cs *CacheStore[T]) Set(key string, value T, ttl time.Duration) { + cs.mu.Lock() + defer cs.mu.Unlock() + cs.setCallback(key, value, ttl) +} + +func (cs *CacheStore[T]) getCallback(key string) (T, bool) { + entry, exists := cs.cache[key] + + if !exists { + var zero T + return zero, false + } + + if entry.expiresAt != nil && time.Now().After(*entry.expiresAt) { + var zero T + return zero, false + } + + return entry.value, true +} + +func (cs *CacheStore[T]) Get(key string) (T, bool) { + cs.mu.RLock() + defer cs.mu.RUnlock() + return cs.getCallback(key) +} + +func (cs *CacheStore[T]) deleteCallback(key string) { + delete(cs.cache, key) + keyIdx := slices.Index(cs.order, key) + if keyIdx != -1 { + cs.order = append(cs.order[:keyIdx], cs.order[keyIdx+1:]...) + } +} + +func (cs *CacheStore[T]) Delete(key string) { + cs.mu.Lock() + defer cs.mu.Unlock() + cs.deleteCallback(key) +} + +func (cs *CacheStore[T]) Sweep() { + cs.mu.Lock() + for key, entry := range cs.cache { + if entry.expiresAt != nil && time.Now().After(*entry.expiresAt) { + cs.deleteCallback(key) + } + } + cs.mu.Unlock() +} + +func (cs *CacheStore[T]) evictOne() bool { + now := time.Now() + var oldestKey string + var oldestExp *time.Time + + for k, e := range cs.cache { + if e.expiresAt != nil && now.After(*e.expiresAt) { + cs.deleteCallback(k) + return true + } + if e.expiresAt != nil && (oldestExp == nil || e.expiresAt.Before(*oldestExp)) { + oldestKey, oldestExp = k, e.expiresAt + } + } + + // If we found an oldest key, evict it else we delete the first key in the order list + if oldestKey != "" { + cs.deleteCallback(oldestKey) + return true + } else { + if len(cs.order) > 0 { + cs.deleteCallback(cs.order[0]) + return true + } + } + + return false +} + +func (cs *CacheStore[T]) Size() int { + cs.mu.RLock() + defer cs.mu.RUnlock() + return len(cs.cache) +} + +func (cs *CacheStore[T]) Clear() { + cs.mu.Lock() + defer cs.mu.Unlock() + cs.cache = make(map[string]cacheEntry[T]) + cs.order = make([]string, 0) +} diff --git a/main.go b/main.go index edee0d4..04785eb 100644 --- a/main.go +++ b/main.go @@ -79,7 +79,6 @@ func main() { );`) queries := queries.New(sqlDb) - cache := NewCache() router := chi.NewRouter() router.Use(middleware.Logger) router.Use(middleware.Recoverer) @@ -87,7 +86,7 @@ func main() { rateLimiter := NewRateLimiter(RateLimitConfig{ RateLimitCount: config.RateLimitCount, TrustedProxies: config.TrustedProxies, - }, cache) + }) instancesHandler := NewInstancesHandler(queries) healthHandler := NewHealthHandler() diff --git a/rate_limiter.go b/rate_limiter.go index 68a3e00..aac6efa 100644 --- a/rate_limiter.go +++ b/rate_limiter.go @@ -2,11 +2,10 @@ package main import ( "fmt" - "log/slog" "net" "net/http" "slices" - "sync" + "strings" "time" ) @@ -17,63 +16,65 @@ type RateLimitConfig struct { type RateLimiter struct { config RateLimitConfig - cache *Cache - mutex sync.RWMutex + caches struct { + ratelimit *CacheStore[int] + } } -func NewRateLimiter(config RateLimitConfig, cache *Cache) *RateLimiter { - return &RateLimiter{ +func NewRateLimiter(config RateLimitConfig) *RateLimiter { + rl := &RateLimiter{ config: config, - cache: cache, } + + ratelimitCache := NewCacheStore[int](0) + rl.caches.ratelimit = ratelimitCache + + go func() { + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + + for range ticker.C { + rl.caches.ratelimit.Sweep() + } + }() + + return rl } func (rl *RateLimiter) limit(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - rl.mutex.Lock() - defer rl.mutex.Unlock() - clientIP := rl.getClientIP(r) - if clientIP == "" { http.Error(w, "failed to determine client ip", http.StatusInternalServerError) return } - value, exists := rl.cache.Get(clientIP) + var used int + rl.caches.ratelimit.WithLock(func(actions CacheStoreActions[int]) { + current, exists := actions.Get(clientIP) + if !exists { + actions.Set(clientIP, 1, 12*time.Hour) + used = 1 + return + } + current++ + used = current + if current > rl.config.RateLimitCount { + return + } + actions.Update(clientIP, current, 0) + }) w.Header().Set("x-ratelimit-limit", fmt.Sprint(rl.config.RateLimitCount)) - w.Header().Set("x-ratelimit-reset", fmt.Sprint(time.Now().Add(12*time.Hour).Unix())) - - if !exists { - rl.cache.Set(clientIP, 1, 43200) // 12 hours TTL - w.Header().Set("x-ratelimit-remaining", fmt.Sprint(rl.config.RateLimitCount-1)) - w.Header().Set("x-ratelimit-used", fmt.Sprint(1)) - next.ServeHTTP(w, r) - return - } - - used, ok := value.(int) - - if !ok { - slog.Error("failed to assert rate limit cache value type") - http.Error(w, "internal server error", http.StatusInternalServerError) - return - } - - used++ + w.Header().Set("x-ratelimit-used", fmt.Sprint(used)) if used > rl.config.RateLimitCount { - w.Header().Set("x-ratelimit-remaining", fmt.Sprint(0)) - w.Header().Set("x-ratelimit-used", fmt.Sprint(used)) + w.Header().Set("x-ratelimit-remaining", "0") http.Error(w, "rate limit exceeded", http.StatusTooManyRequests) return } - rl.cache.Set(clientIP, used, 43200) // 12 hours TTL - w.Header().Set("x-ratelimit-remaining", fmt.Sprint(rl.config.RateLimitCount-used)) - w.Header().Set("x-ratelimit-used", fmt.Sprint(used)) next.ServeHTTP(w, r) }) } @@ -92,10 +93,13 @@ func (rl *RateLimiter) getClientIP(r *http.Request) string { } if slices.Contains(rl.config.TrustedProxies, ip) { - xForwardedFor := r.Header.Values("x-forwarded-for") + xForwardedFor := r.Header.Get("x-forwarded-for") - if len(xForwardedFor) > 0 { - return xForwardedFor[0] + if xForwardedFor != "" { + firstIp := strings.SplitN(xForwardedFor, ",", 2)[0] + if firstIp != "" { + return firstIp + } } }