HttpClient now has a built-in max response size check

This commit is contained in:
Wouter Groeneveld 2021-04-20 09:53:17 +02:00
parent 00f927886d
commit 1181b9e1fe
3 changed files with 115 additions and 8 deletions

View File

@ -103,12 +103,13 @@ func TestSendMentionIntegrationStressTest(t *testing.T) {
runs := 100
responses := make(chan bool, runs)
http.HandleFunc("/pingback", func(writer http.ResponseWriter, request *http.Request) {
mux := http.NewServeMux()
mux.HandleFunc("/pingback", func(writer http.ResponseWriter, request *http.Request) {
writer.WriteHeader(200)
writer.Write([]byte("pingbacked stuff."))
responses <- true
})
http.HandleFunc("/target", func(writer http.ResponseWriter, request *http.Request) {
mux.HandleFunc("/target", func(writer http.ResponseWriter, request *http.Request) {
target := `<html>
<head>
<link rel="pingback" href="http://localhost:6666/pingback" />
@ -119,10 +120,12 @@ func TestSendMentionIntegrationStressTest(t *testing.T) {
writer.WriteHeader(200)
writer.Write([]byte(target))
})
srv := &http.Server{Addr: ":6666", Handler: mux}
defer srv.Close()
go func() {
fmt.Println("Serving stub at 6666...")
http.ListenAndServe(":6666", nil)
srv.ListenAndServe()
fmt.Println("Stub stopped?")
}()

View File

@ -1,10 +1,11 @@
package rest
import (
"errors"
"fmt"
"github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/go-retryablehttp"
"io/ioutil"
"io"
"net/http"
"net/url"
"strings"
@ -21,6 +22,10 @@ type Client interface {
type HttpClient struct {
}
const (
// MaxBytes caps the response body size accepted by GetBody / readUntilMax.
// 5 000 000 bytes is 5 MB (decimal); 5 MiB would be 5242880.
MaxBytes = 5000000 // 5 MB
)
var (
// do not use retryablehttp default impl - inject own logger and retry policies
jammingHttp = &retryablehttp.Client{
@ -32,6 +37,8 @@ var (
CheckRetry: retryablehttp.DefaultRetryPolicy,
Backoff: retryablehttp.DefaultBackoff,
}
ResponseAboveLimit = errors.New("response bigger than limit")
)
func (client *HttpClient) PostForm(url string, formData url.Values) error {
@ -56,21 +63,22 @@ func (client *HttpClient) Post(url string, contenType string, body string) error
return nil
}
// GetBody issues a retryable GET request and returns the header, body string, and a possible error.
// It limits response sizes to MaxBytes and returns an error if status not between [200, 299].
// NOTE(review): the !isStatusOk branch returns without closing resp.Body — possible
// connection leak; verify against the transport's reuse behavior.
func (client *HttpClient) GetBody(url string) (http.Header, string, error) {
resp, geterr := client.Get(url)
if geterr != nil {
return nil, "", fmt.Errorf("GET from %s: %v", url, geterr)
return nil, "", fmt.Errorf("GET from %s: %w", url, geterr)
}
if !isStatusOk(resp) {
return nil, "", fmt.Errorf("GET from %s: Status code is not OK (%d)", url, resp.StatusCode)
}
// Read at most MaxBytes from the body; anything beyond that is an error.
body, readerr := readUntilMax(resp.Body, MaxBytes)
defer resp.Body.Close()
if readerr != nil {
return nil, "", fmt.Errorf("GET from %s: unable to read body: %v", url, readerr)
return nil, "", fmt.Errorf("GET from %s: unable to read body: %w", url, readerr)
}
return resp.Header, string(body), nil
}
@ -82,3 +90,27 @@ func isStatusOk(resp *http.Response) bool {
// Get issues a GET request through the shared retryable client (jammingHttp),
// so retry policy and backoff are applied uniformly.
func (client *HttpClient) Get(url string) (*http.Response, error) {
return jammingHttp.Get(url)
}
// readUntilMax reads from r using the same growth strategy as io.ReadAll
// (exponential chunking via append), but caps the total at maxBytes bytes.
// It returns ResponseAboveLimit as soon as the accumulated data exceeds the limit.
func readUntilMax(r io.Reader, maxBytes int) ([]byte, error) {
	b := make([]byte, 0, 512)
	for {
		if len(b) == cap(b) {
			// Add more capacity (let append pick how much).
			b = append(b, 0)[:len(b)]
		}
		n, err := r.Read(b[len(b):cap(b)])
		b = b[:len(b)+n]
		// Check the limit BEFORE handling err: a Read is allowed to return
		// n > 0 together with io.EOF, and that final chunk must not be able
		// to sneak past the cap (the original only checked on the no-error path).
		if len(b) > maxBytes {
			return nil, ResponseAboveLimit
		}
		if err != nil {
			if err == io.EOF {
				err = nil
			}
			return b, err
		}
	}
}

72
rest/client_test.go Normal file
View File

@ -0,0 +1,72 @@
package rest
import (
	"errors"
	"io/ioutil"
	"net"
	"net/http"
	"testing"

	"github.com/stretchr/testify/assert"
)
var client = HttpClient{}
// TestGetBodyWithinLimitsReturnsHeadersAndBodyString serves a ~1.6 MB RSS fixture
// (below MaxBytes) and checks GetBody returns both the headers and the body string.
func TestGetBodyWithinLimitsReturnsHeadersAndBodyString(t *testing.T) {
	mux := http.NewServeMux()
	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("amaigat", "jup.")
		w.WriteHeader(200)
		data, err := ioutil.ReadFile("../mocks/samplerss.xml") // is about 1.6 MB
		assert.NoError(t, err)
		w.Write(data)
	})
	// Bind the port synchronously so the request below cannot race the
	// goroutine that serves it (go srv.ListenAndServe() gives no such guarantee).
	ln, lnerr := net.Listen("tcp", ":6666")
	assert.NoError(t, lnerr)
	srv := &http.Server{Handler: mux}
	defer srv.Close()
	go func() {
		srv.Serve(ln)
	}()

	headers, body, err := client.GetBody("http://localhost:6666/")

	assert.NoError(t, err)
	assert.Equal(t, "jup.", headers.Get("amaigat"))
	assert.Contains(t, body, "<rss")
}
func TestGetBodyOf404ReturnsError(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(404)
})
srv := &http.Server{Addr: ":6666", Handler: mux}
defer srv.Close()
go func() {
srv.ListenAndServe()
}()
_, body, err := client.GetBody("http://localhost:6666/")
assert.Contains(t, err.Error(), "404")
assert.Equal(t, "", body)
}
// TestGetBodyOfTooLargeContentReturnsError serves a payload twice the size of
// MaxBytes and checks that GetBody refuses it with ResponseAboveLimit.
func TestGetBodyOfTooLargeContentReturnsError(t *testing.T) {
	mux := http.NewServeMux()
	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(200)
		// Twice the allowed size, so the client must bail out mid-read.
		garbage := make([]byte, MaxBytes*2)
		for i := range garbage {
			garbage[i] = 'A'
		}
		w.Write(garbage)
	})
	// Bind synchronously to avoid racing the serving goroutine.
	ln, lnerr := net.Listen("tcp", ":6666")
	assert.NoError(t, lnerr)
	srv := &http.Server{Handler: mux}
	defer srv.Close()
	go func() {
		srv.Serve(ln)
	}()

	_, body, err := client.GetBody("http://localhost:6666/")

	assert.Equal(t, ResponseAboveLimit, errors.Unwrap(err))
	assert.Equal(t, "", body)
}