diff --git a/README.md b/README.md index a0fab65..dd17d06 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ Go module `brainbaking.com/go-jamming`: > A minimalistic Go-powered jamstack-augmented microservice for webmentions etc -✅️ **This is a fork of [https://github.com/wgroeneveld/serve-my-jams](serve-my-jams)**, the Node-powered original microservice, which is no longer being maintained. +✅️ **This is a fork of [serve-my-jams](https://github.com/wgroeneveld/serve-my-jams)**, the Node-powered original microservice, which is no longer being maintained. **Are you looking for a way to DO something with this?** See https://github.com/wgroeneveld/jam-my-stack ! diff --git a/app/limiter.go b/app/limiter.go new file mode 100644 index 0000000..4238b42 --- /dev/null +++ b/app/limiter.go @@ -0,0 +1,90 @@ +package app + +import ( + "brainbaking.com/go-jamming/common" + "brainbaking.com/go-jamming/rest" + "github.com/rs/zerolog/log" + "golang.org/x/time/rate" + "net/http" + "sync" + "time" +) + +type visitor struct { + limiter *rate.Limiter + lastSeen time.Time +} + +type RateLimiter struct { + visitors map[string]*visitor + mu sync.RWMutex + rateLimitPerSec int + rateBurst int + Middleware func(next http.Handler) http.Handler +} + +func NewRateLimiter(rateLimitPerSec int, rateBurst int) *RateLimiter { + rl := &RateLimiter{ + visitors: make(map[string]*visitor), + mu: sync.RWMutex{}, + rateBurst: rateBurst, + rateLimitPerSec: rateLimitPerSec, + } + rl.Middleware = func(next http.Handler) http.Handler { + return rl.limiterMiddleware(next) + } + + go rl.cleanupVisitors() + return rl +} + +const ( + ttl = 5 * time.Minute + cleanupCron = 2 * time.Minute +) + +func (rl *RateLimiter) getVisitor(ip string) *rate.Limiter { + rl.mu.Lock() + defer rl.mu.Unlock() + + v, exists := rl.visitors[ip] + if !exists { + limiter := rate.NewLimiter(rate.Limit(rl.rateLimitPerSec), rl.rateBurst) + rl.visitors[ip] = &visitor{limiter, common.Now()} + return limiter + } + + v.lastSeen = common.Now() + return v.limiter +} + +func (rl *RateLimiter) cleanupVisitors() { + for { + time.Sleep(cleanupCron) + + rl.mu.Lock() + for ip, v := range rl.visitors { + if time.Since(v.lastSeen) > ttl { + log.Debug().Str("ip", ip).Msg("Cleaning up rate limiter visitor") + delete(rl.visitors, ip) + } + } + rl.mu.Unlock() + } +} + +// with the help of https://www.alexedwards.net/blog/how-to-rate-limit-http-requests, TY! +func (rl *RateLimiter) limiterMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ip := r.RemoteAddr // also contains port, but don't care + limiter := rl.getVisitor(ip) + + if limiter.Allow() == false { + log.Error().Str("ip", ip).Msg("Someone spamming? Rate limit hit!") + rest.TooManyRequests(w) + return + } + + next.ServeHTTP(w, r) + }) +} diff --git a/app/logging.go b/app/logging.go index 85f0d55..793fd9b 100644 --- a/app/logging.go +++ b/app/logging.go @@ -17,7 +17,7 @@ func (lrw *loggingResponseWriter) WriteHeader(code int) { lrw.ResponseWriter.WriteHeader(code) } -func loggingMiddleware(next http.Handler) http.Handler { +func LoggingMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { logWriter := &loggingResponseWriter{w, http.StatusOK} next.ServeHTTP(logWriter, r) diff --git a/app/server.go b/app/server.go index 5218f57..e4c79ae 100644 --- a/app/server.go +++ b/app/server.go @@ -35,7 +35,8 @@ func Start() { server.routes() http.Handle("/", r) - r.Use(loggingMiddleware) + r.Use(LoggingMiddleware) + r.Use(NewRateLimiter(5, 10).Middleware) log.Info().Int("port", server.conf.Port).Msg("Serving...") http.ListenAndServe(":"+strconv.Itoa(server.conf.Port), nil) diff --git a/go.mod b/go.mod index 6962d02..56b3c83 100644 --- a/go.mod +++ b/go.mod @@ -8,5 +8,6 @@ require ( github.com/hashicorp/go-retryablehttp v0.6.8 github.com/rs/zerolog v1.21.0 github.com/stretchr/testify v1.7.0 + golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba willnorris.com/go/microformats v1.1.1 ) diff --git a/go.sum b/go.sum index 3ee2b39..113509a 100644 --- a/go.sum +++ b/go.sum @@ -46,6 +46,8 @@ golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba h1:O8mE0/t419eoIwhTFpKVkHiTs/Igowgfkj25AcZrtiE= +golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= diff --git a/playground.go b/playground.go index 1c74b44..a34932b 100644 --- a/playground.go +++ b/playground.go @@ -1,24 +1,5 @@ package main -import ( - "fmt" - "io/ioutil" - "log" - "net/http" -) - func mainz() { - fmt.Println("Hello, playground") - resp, err := http.Get("https://brainbaking.com/notes") - if err != nil { - log.Fatalln(err) - } - - body, err2 := ioutil.ReadAll(resp.Body) - if err2 != nil { - log.Fatalln(err) - } - - fmt.Printf("tis ditte") - fmt.Printf("%s", body) + //time.Tick() } diff --git a/rest/utils.go b/rest/utils.go index 46ed3e4..99011f5 100644 --- a/rest/utils.go +++ b/rest/utils.go @@ -8,11 +8,15 @@ import ( // mimicing NotFound: https://golang.org/src/net/http/server.go?s=64787:64830#L2076 func BadRequest(w http.ResponseWriter) { - http.Error(w, "400 bad request", http.StatusBadRequest) + http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) +} + +func TooManyRequests(w http.ResponseWriter) { + http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests) } func Unauthorized(w http.ResponseWriter) { - http.Error(w, "401 unauthorized", http.StatusUnauthorized) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) } func Json(w http.ResponseWriter, data interface{}) {