loosen up ratelimiter to take requestUri into account
parent
8d415d022c
commit
56b84125c4
|
@ -3,6 +3,7 @@ package app
|
|||
import (
|
||||
"brainbaking.com/go-jamming/common"
|
||||
"brainbaking.com/go-jamming/rest"
|
||||
"fmt"
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/time/rate"
|
||||
"net/http"
|
||||
|
@ -43,14 +44,15 @@ const (
|
|||
cleanupCron = 2 * time.Minute
|
||||
)
|
||||
|
||||
func (rl *RateLimiter) getVisitor(ip string) *rate.Limiter {
|
||||
func (rl *RateLimiter) limiterFor(ip string, uri string) *rate.Limiter {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
v, exists := rl.visitors[ip]
|
||||
key := fmt.Sprintf("%s-%s", ip, uri)
|
||||
v, exists := rl.visitors[key]
|
||||
if !exists {
|
||||
limiter := rate.NewLimiter(rate.Limit(rl.rateLimitPerSec), rl.rateBurst)
|
||||
rl.visitors[ip] = &visitor{limiter, common.Now()}
|
||||
rl.visitors[key] = &visitor{limiter, common.Now()}
|
||||
return limiter
|
||||
}
|
||||
|
||||
|
@ -63,10 +65,10 @@ func (rl *RateLimiter) cleanupVisitors() {
|
|||
time.Sleep(cleanupCron)
|
||||
|
||||
rl.mu.Lock()
|
||||
for ip, v := range rl.visitors {
|
||||
for key, v := range rl.visitors {
|
||||
if time.Since(v.lastSeen) > ttl {
|
||||
log.Debug().Str("ip", ip).Msg("Cleaning up rate limiter visitor")
|
||||
delete(rl.visitors, ip)
|
||||
log.Debug().Str("key", key).Msg("Cleaning up rate limiter visitor")
|
||||
delete(rl.visitors, key)
|
||||
}
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
|
@ -77,10 +79,10 @@ func (rl *RateLimiter) cleanupVisitors() {
|
|||
func (rl *RateLimiter) limiterMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ip := ipFrom(r)
|
||||
limiter := rl.getVisitor(ip)
|
||||
limiter := rl.limiterFor(ip, r.RequestURI)
|
||||
|
||||
if !limiter.Allow() {
|
||||
log.Error().Str("ip", ip).Msg("Someone spamming? Rate limit hit!")
|
||||
log.Error().Str("ip", ip).Str("uri", r.RequestURI).Msg("Someone spamming? Rate limit hit!")
|
||||
rest.TooManyRequests(w)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
package app
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHitsRateLimitAfterSlammingRequests(t *testing.T) {
|
||||
r := mux.NewRouter()
|
||||
r.HandleFunc("/endpoint", testFn).Methods("GET")
|
||||
r.Use(NewRateLimiter(5, 10).Middleware)
|
||||
ts := httptest.NewServer(r)
|
||||
|
||||
t.Cleanup(ts.Close)
|
||||
statusCodes := []int{}
|
||||
|
||||
for i := 0; i <= 10; i++ {
|
||||
client := &http.Client{}
|
||||
req, _ := http.NewRequest("GET", fmt.Sprintf("%s/endpoint", ts.URL), nil)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
assert.NoError(t, err)
|
||||
statusCodes = append(statusCodes, resp.StatusCode)
|
||||
}
|
||||
assert.Contains(t, statusCodes, 429)
|
||||
}
|
||||
|
||||
func TestDoesNotHitRateLimitOfSecondEndpointAfterSlammingFirst(t *testing.T) {
|
||||
r := mux.NewRouter()
|
||||
r.HandleFunc("/endpoint1", testFn).Methods("GET")
|
||||
r.HandleFunc("/endpoint2", testFn).Methods("GET")
|
||||
r.Use(NewRateLimiter(5, 10).Middleware)
|
||||
ts := httptest.NewServer(r)
|
||||
|
||||
t.Cleanup(ts.Close)
|
||||
|
||||
for i := 0; i <= 10; i++ {
|
||||
client := &http.Client{}
|
||||
req, _ := http.NewRequest("GET", fmt.Sprintf("%s/endpoint1", ts.URL), nil)
|
||||
client.Do(req)
|
||||
}
|
||||
for i := 0; i <= 5; i++ {
|
||||
client := &http.Client{}
|
||||
req, _ := http.NewRequest("GET", fmt.Sprintf("%s/endpoint2", ts.URL), nil)
|
||||
resp, err := client.Do(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func testFn(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("ok"))
|
||||
}
|
|
@ -55,7 +55,10 @@ func ipFrom(r *http.Request) string {
|
|||
if forwardedFor != "" { // in case of proxy. Could be: clientip, proxy1, proxy2, ...
|
||||
return strings.Split(forwardedFor, ",")[0]
|
||||
}
|
||||
return r.RemoteAddr // also contains port, but don't care
|
||||
if strings.Contains(r.RemoteAddr, ":") { // in case of 127.0.0.1:12345
|
||||
return strings.Split(r.RemoteAddr, ":")[0]
|
||||
}
|
||||
return r.RemoteAddr
|
||||
}
|
||||
|
||||
func Start() {
|
||||
|
|
Loading…
Reference in New Issue