loosen up ratelimiter to take requestUri into account
This commit is contained in:
parent
8d415d022c
commit
56b84125c4
|
@ -3,6 +3,7 @@ package app
|
||||||
import (
|
import (
|
||||||
"brainbaking.com/go-jamming/common"
|
"brainbaking.com/go-jamming/common"
|
||||||
"brainbaking.com/go-jamming/rest"
|
"brainbaking.com/go-jamming/rest"
|
||||||
|
"fmt"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"golang.org/x/time/rate"
|
"golang.org/x/time/rate"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -43,14 +44,15 @@ const (
|
||||||
cleanupCron = 2 * time.Minute
|
cleanupCron = 2 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
func (rl *RateLimiter) getVisitor(ip string) *rate.Limiter {
|
func (rl *RateLimiter) limiterFor(ip string, uri string) *rate.Limiter {
|
||||||
rl.mu.Lock()
|
rl.mu.Lock()
|
||||||
defer rl.mu.Unlock()
|
defer rl.mu.Unlock()
|
||||||
|
|
||||||
v, exists := rl.visitors[ip]
|
key := fmt.Sprintf("%s-%s", ip, uri)
|
||||||
|
v, exists := rl.visitors[key]
|
||||||
if !exists {
|
if !exists {
|
||||||
limiter := rate.NewLimiter(rate.Limit(rl.rateLimitPerSec), rl.rateBurst)
|
limiter := rate.NewLimiter(rate.Limit(rl.rateLimitPerSec), rl.rateBurst)
|
||||||
rl.visitors[ip] = &visitor{limiter, common.Now()}
|
rl.visitors[key] = &visitor{limiter, common.Now()}
|
||||||
return limiter
|
return limiter
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -63,10 +65,10 @@ func (rl *RateLimiter) cleanupVisitors() {
|
||||||
time.Sleep(cleanupCron)
|
time.Sleep(cleanupCron)
|
||||||
|
|
||||||
rl.mu.Lock()
|
rl.mu.Lock()
|
||||||
for ip, v := range rl.visitors {
|
for key, v := range rl.visitors {
|
||||||
if time.Since(v.lastSeen) > ttl {
|
if time.Since(v.lastSeen) > ttl {
|
||||||
log.Debug().Str("ip", ip).Msg("Cleaning up rate limiter visitor")
|
log.Debug().Str("key", key).Msg("Cleaning up rate limiter visitor")
|
||||||
delete(rl.visitors, ip)
|
delete(rl.visitors, key)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
rl.mu.Unlock()
|
rl.mu.Unlock()
|
||||||
|
@ -77,10 +79,10 @@ func (rl *RateLimiter) cleanupVisitors() {
|
||||||
func (rl *RateLimiter) limiterMiddleware(next http.Handler) http.Handler {
|
func (rl *RateLimiter) limiterMiddleware(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
ip := ipFrom(r)
|
ip := ipFrom(r)
|
||||||
limiter := rl.getVisitor(ip)
|
limiter := rl.limiterFor(ip, r.RequestURI)
|
||||||
|
|
||||||
if !limiter.Allow() {
|
if !limiter.Allow() {
|
||||||
log.Error().Str("ip", ip).Msg("Someone spamming? Rate limit hit!")
|
log.Error().Str("ip", ip).Str("uri", r.RequestURI).Msg("Someone spamming? Rate limit hit!")
|
||||||
rest.TooManyRequests(w)
|
rest.TooManyRequests(w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,57 @@
|
||||||
|
package app
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHitsRateLimitAfterSlammingRequests(t *testing.T) {
|
||||||
|
r := mux.NewRouter()
|
||||||
|
r.HandleFunc("/endpoint", testFn).Methods("GET")
|
||||||
|
r.Use(NewRateLimiter(5, 10).Middleware)
|
||||||
|
ts := httptest.NewServer(r)
|
||||||
|
|
||||||
|
t.Cleanup(ts.Close)
|
||||||
|
statusCodes := []int{}
|
||||||
|
|
||||||
|
for i := 0; i <= 10; i++ {
|
||||||
|
client := &http.Client{}
|
||||||
|
req, _ := http.NewRequest("GET", fmt.Sprintf("%s/endpoint", ts.URL), nil)
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
statusCodes = append(statusCodes, resp.StatusCode)
|
||||||
|
}
|
||||||
|
assert.Contains(t, statusCodes, 429)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDoesNotHitRateLimitOfSecondEndpointAfterSlammingFirst(t *testing.T) {
|
||||||
|
r := mux.NewRouter()
|
||||||
|
r.HandleFunc("/endpoint1", testFn).Methods("GET")
|
||||||
|
r.HandleFunc("/endpoint2", testFn).Methods("GET")
|
||||||
|
r.Use(NewRateLimiter(5, 10).Middleware)
|
||||||
|
ts := httptest.NewServer(r)
|
||||||
|
|
||||||
|
t.Cleanup(ts.Close)
|
||||||
|
|
||||||
|
for i := 0; i <= 10; i++ {
|
||||||
|
client := &http.Client{}
|
||||||
|
req, _ := http.NewRequest("GET", fmt.Sprintf("%s/endpoint1", ts.URL), nil)
|
||||||
|
client.Do(req)
|
||||||
|
}
|
||||||
|
for i := 0; i <= 5; i++ {
|
||||||
|
client := &http.Client{}
|
||||||
|
req, _ := http.NewRequest("GET", fmt.Sprintf("%s/endpoint2", ts.URL), nil)
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 200, resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testFn(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("ok"))
|
||||||
|
}
|
|
@ -55,7 +55,10 @@ func ipFrom(r *http.Request) string {
|
||||||
if forwardedFor != "" { // in case of proxy. Could be: clientip, proxy1, proxy2, ...
|
if forwardedFor != "" { // in case of proxy. Could be: clientip, proxy1, proxy2, ...
|
||||||
return strings.Split(forwardedFor, ",")[0]
|
return strings.Split(forwardedFor, ",")[0]
|
||||||
}
|
}
|
||||||
return r.RemoteAddr // also contains port, but don't care
|
if strings.Contains(r.RemoteAddr, ":") { // in case of 127.0.0.1:12345
|
||||||
|
return strings.Split(r.RemoteAddr, ":")[0]
|
||||||
|
}
|
||||||
|
return r.RemoteAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
func Start() {
|
func Start() {
|
||||||
|
|
Loading…
Reference in New Issue