From 56b84125c4756eb57efdc24eb1af9fb289c7bbad Mon Sep 17 00:00:00 2001 From: wgroeneveld Date: Sat, 7 May 2022 14:49:19 +0200 Subject: [PATCH] loosen up ratelimiter to take requestUri into account --- app/limiter.go | 18 +++++++------- app/limiter_test.go | 57 +++++++++++++++++++++++++++++++++++++++++++++ app/server.go | 5 +++- 3 files changed, 71 insertions(+), 9 deletions(-) create mode 100644 app/limiter_test.go diff --git a/app/limiter.go b/app/limiter.go index cbea588..76e8031 100644 --- a/app/limiter.go +++ b/app/limiter.go @@ -3,6 +3,7 @@ package app import ( "brainbaking.com/go-jamming/common" "brainbaking.com/go-jamming/rest" + "fmt" "github.com/rs/zerolog/log" "golang.org/x/time/rate" "net/http" @@ -43,14 +44,15 @@ const ( cleanupCron = 2 * time.Minute ) -func (rl *RateLimiter) getVisitor(ip string) *rate.Limiter { +func (rl *RateLimiter) limiterFor(ip string, uri string) *rate.Limiter { rl.mu.Lock() defer rl.mu.Unlock() - v, exists := rl.visitors[ip] + key := fmt.Sprintf("%s-%s", ip, uri) + v, exists := rl.visitors[key] if !exists { limiter := rate.NewLimiter(rate.Limit(rl.rateLimitPerSec), rl.rateBurst) - rl.visitors[ip] = &visitor{limiter, common.Now()} + rl.visitors[key] = &visitor{limiter, common.Now()} return limiter } @@ -63,10 +65,10 @@ func (rl *RateLimiter) cleanupVisitors() { time.Sleep(cleanupCron) rl.mu.Lock() - for ip, v := range rl.visitors { + for key, v := range rl.visitors { if time.Since(v.lastSeen) > ttl { - log.Debug().Str("ip", ip).Msg("Cleaning up rate limiter visitor") - delete(rl.visitors, ip) + log.Debug().Str("key", key).Msg("Cleaning up rate limiter visitor") + delete(rl.visitors, key) } } rl.mu.Unlock() @@ -77,10 +79,10 @@ func (rl *RateLimiter) cleanupVisitors() { func (rl *RateLimiter) limiterMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ip := ipFrom(r) - limiter := rl.getVisitor(ip) + limiter := rl.limiterFor(ip, r.RequestURI) if !limiter.Allow() { - log.Error().Str("ip", ip).Msg("Someone spamming? Rate limit hit!") + log.Error().Str("ip", ip).Str("uri", r.RequestURI).Msg("Someone spamming? Rate limit hit!") rest.TooManyRequests(w) return } diff --git a/app/limiter_test.go b/app/limiter_test.go new file mode 100644 index 0000000..25f4794 --- /dev/null +++ b/app/limiter_test.go @@ -0,0 +1,57 @@ +package app + +import ( + "fmt" + "github.com/gorilla/mux" + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "testing" +) + +func TestHitsRateLimitAfterSlammingRequests(t *testing.T) { + r := mux.NewRouter() + r.HandleFunc("/endpoint", testFn).Methods("GET") + r.Use(NewRateLimiter(5, 10).Middleware) + ts := httptest.NewServer(r) + + t.Cleanup(ts.Close) + statusCodes := []int{} + + for i := 0; i <= 10; i++ { + client := &http.Client{} + req, _ := http.NewRequest("GET", fmt.Sprintf("%s/endpoint", ts.URL), nil) + + resp, err := client.Do(req) + assert.NoError(t, err) + statusCodes = append(statusCodes, resp.StatusCode) + } + assert.Contains(t, statusCodes, 429) +} + +func TestDoesNotHitRateLimitOfSecondEndpointAfterSlammingFirst(t *testing.T) { + r := mux.NewRouter() + r.HandleFunc("/endpoint1", testFn).Methods("GET") + r.HandleFunc("/endpoint2", testFn).Methods("GET") + r.Use(NewRateLimiter(5, 10).Middleware) + ts := httptest.NewServer(r) + + t.Cleanup(ts.Close) + + for i := 0; i <= 10; i++ { + client := &http.Client{} + req, _ := http.NewRequest("GET", fmt.Sprintf("%s/endpoint1", ts.URL), nil) + client.Do(req) + } + for i := 0; i <= 5; i++ { + client := &http.Client{} + req, _ := http.NewRequest("GET", fmt.Sprintf("%s/endpoint2", ts.URL), nil) + resp, err := client.Do(req) + assert.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + } +} + +func testFn(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("ok")) +} diff --git a/app/server.go b/app/server.go index 12c9d71..daaf60c 100644 --- a/app/server.go +++ b/app/server.go @@ -55,7 +55,10 @@ func ipFrom(r *http.Request) string { if forwardedFor != "" { // in case of proxy. Could be: clientip, proxy1, proxy2, ... return strings.Split(forwardedFor, ",")[0] } - return r.RemoteAddr // also contains port, but don't care + if strings.Contains(r.RemoteAddr, ":") { // in case of 127.0.0.1:12345 + return strings.Split(r.RemoteAddr, ":")[0] + } + return r.RemoteAddr } func Start() {