loosen up ratelimiter to take requestUri into account

This commit is contained in:
Wouter Groeneveld 2022-05-07 14:49:19 +02:00
parent 8d415d022c
commit 56b84125c4
3 changed files with 71 additions and 9 deletions

View File

@ -3,6 +3,7 @@ package app
import (
"brainbaking.com/go-jamming/common"
"brainbaking.com/go-jamming/rest"
"fmt"
"github.com/rs/zerolog/log"
"golang.org/x/time/rate"
"net/http"
@ -43,14 +44,15 @@ const (
cleanupCron = 2 * time.Minute
)
func (rl *RateLimiter) getVisitor(ip string) *rate.Limiter {
func (rl *RateLimiter) limiterFor(ip string, uri string) *rate.Limiter {
rl.mu.Lock()
defer rl.mu.Unlock()
v, exists := rl.visitors[ip]
key := fmt.Sprintf("%s-%s", ip, uri)
v, exists := rl.visitors[key]
if !exists {
limiter := rate.NewLimiter(rate.Limit(rl.rateLimitPerSec), rl.rateBurst)
rl.visitors[ip] = &visitor{limiter, common.Now()}
rl.visitors[key] = &visitor{limiter, common.Now()}
return limiter
}
@ -63,10 +65,10 @@ func (rl *RateLimiter) cleanupVisitors() {
time.Sleep(cleanupCron)
rl.mu.Lock()
for ip, v := range rl.visitors {
for key, v := range rl.visitors {
if time.Since(v.lastSeen) > ttl {
log.Debug().Str("ip", ip).Msg("Cleaning up rate limiter visitor")
delete(rl.visitors, ip)
log.Debug().Str("key", key).Msg("Cleaning up rate limiter visitor")
delete(rl.visitors, key)
}
}
rl.mu.Unlock()
@ -77,10 +79,10 @@ func (rl *RateLimiter) cleanupVisitors() {
func (rl *RateLimiter) limiterMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ip := ipFrom(r)
limiter := rl.getVisitor(ip)
limiter := rl.limiterFor(ip, r.RequestURI)
if !limiter.Allow() {
log.Error().Str("ip", ip).Msg("Someone spamming? Rate limit hit!")
log.Error().Str("ip", ip).Str("uri", r.RequestURI).Msg("Someone spamming? Rate limit hit!")
rest.TooManyRequests(w)
return
}

57
app/limiter_test.go Normal file
View File

@ -0,0 +1,57 @@
package app
import (
"fmt"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"testing"
)
func TestHitsRateLimitAfterSlammingRequests(t *testing.T) {
r := mux.NewRouter()
r.HandleFunc("/endpoint", testFn).Methods("GET")
r.Use(NewRateLimiter(5, 10).Middleware)
ts := httptest.NewServer(r)
t.Cleanup(ts.Close)
statusCodes := []int{}
for i := 0; i <= 10; i++ {
client := &http.Client{}
req, _ := http.NewRequest("GET", fmt.Sprintf("%s/endpoint", ts.URL), nil)
resp, err := client.Do(req)
assert.NoError(t, err)
statusCodes = append(statusCodes, resp.StatusCode)
}
assert.Contains(t, statusCodes, 429)
}
func TestDoesNotHitRateLimitOfSecondEndpointAfterSlammingFirst(t *testing.T) {
r := mux.NewRouter()
r.HandleFunc("/endpoint1", testFn).Methods("GET")
r.HandleFunc("/endpoint2", testFn).Methods("GET")
r.Use(NewRateLimiter(5, 10).Middleware)
ts := httptest.NewServer(r)
t.Cleanup(ts.Close)
for i := 0; i <= 10; i++ {
client := &http.Client{}
req, _ := http.NewRequest("GET", fmt.Sprintf("%s/endpoint1", ts.URL), nil)
client.Do(req)
}
for i := 0; i <= 5; i++ {
client := &http.Client{}
req, _ := http.NewRequest("GET", fmt.Sprintf("%s/endpoint2", ts.URL), nil)
resp, err := client.Do(req)
assert.NoError(t, err)
assert.Equal(t, 200, resp.StatusCode)
}
}
func testFn(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("ok"))
}

View File

@ -55,7 +55,10 @@ func ipFrom(r *http.Request) string {
if forwardedFor != "" { // in case of proxy. Could be: clientip, proxy1, proxy2, ...
return strings.Split(forwardedFor, ",")[0]
}
return r.RemoteAddr // also contains port, but don't care
if strings.Contains(r.RemoteAddr, ":") { // in case of 127.0.0.1:12345
return strings.Split(r.RemoteAddr, ":")[0]
}
return r.RemoteAddr
}
func Start() {