From 1181b9e1fea6e31532d696c294d81d2e275afca4 Mon Sep 17 00:00:00 2001 From: wgroeneveld Date: Tue, 20 Apr 2021 09:53:17 +0200 Subject: [PATCH] HttpClient now has a built-in max response size check --- app/webmention/send/send_test.go | 9 ++-- rest/client.go | 42 ++++++++++++++++--- rest/client_test.go | 72 ++++++++++++++++++++++++++++++++ 3 files changed, 115 insertions(+), 8 deletions(-) create mode 100644 rest/client_test.go diff --git a/app/webmention/send/send_test.go b/app/webmention/send/send_test.go index aa1c157..a2a4594 100644 --- a/app/webmention/send/send_test.go +++ b/app/webmention/send/send_test.go @@ -103,12 +103,13 @@ func TestSendMentionIntegrationStressTest(t *testing.T) { runs := 100 responses := make(chan bool, runs) - http.HandleFunc("/pingback", func(writer http.ResponseWriter, request *http.Request) { + mux := http.NewServeMux() + mux.HandleFunc("/pingback", func(writer http.ResponseWriter, request *http.Request) { writer.WriteHeader(200) writer.Write([]byte("pingbacked stuff.")) responses <- true }) - http.HandleFunc("/target", func(writer http.ResponseWriter, request *http.Request) { + mux.HandleFunc("/target", func(writer http.ResponseWriter, request *http.Request) { target := ` @@ -119,10 +120,12 @@ func TestSendMentionIntegrationStressTest(t *testing.T) { writer.WriteHeader(200) writer.Write([]byte(target)) }) + srv := &http.Server{Addr: ":6666", Handler: mux} + defer srv.Close() go func() { fmt.Println("Serving stub at 6666...") - http.ListenAndServe(":6666", nil) + srv.ListenAndServe() fmt.Println("Stub stopped?") }() diff --git a/rest/client.go b/rest/client.go index b682b24..4a45957 100644 --- a/rest/client.go +++ b/rest/client.go @@ -1,10 +1,11 @@ package rest import ( + "errors" "fmt" "github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/go-retryablehttp" - "io/ioutil" + "io" "net/http" "net/url" "strings" @@ -21,6 +22,10 @@ type Client interface { type HttpClient struct { } +const ( + MaxBytes = 5000000 // 5 MiB +) + var ( // do not use retryablehttp default impl - inject own logger and retry policies jammingHttp = &retryablehttp.Client{ @@ -32,6 +37,8 @@ var ( CheckRetry: retryablehttp.DefaultRetryPolicy, Backoff: retryablehttp.DefaultBackoff, } + + ResponseAboveLimit = errors.New("response bigger than limit") ) func (client *HttpClient) PostForm(url string, formData url.Values) error { @@ -56,21 +63,22 @@ func (client *HttpClient) Post(url string, contenType string, body string) error return nil } -// something like this? https://freshman.tech/snippets/go/http-response-to-string/ +// GetBody issues a retryable GET request and returns the header, body string, and a possible error. +// It limits response sizes to MaxBytes and returns an error if status not between [200, 299]. func (client *HttpClient) GetBody(url string) (http.Header, string, error) { resp, geterr := client.Get(url) if geterr != nil { - return nil, "", fmt.Errorf("GET from %s: %v", url, geterr) + return nil, "", fmt.Errorf("GET from %s: %w", url, geterr) } if !isStatusOk(resp) { return nil, "", fmt.Errorf("GET from %s: Status code is not OK (%d)", url, resp.StatusCode) } - body, readerr := ioutil.ReadAll(resp.Body) + body, readerr := readUntilMax(resp.Body, MaxBytes) defer resp.Body.Close() if readerr != nil { - return nil, "", fmt.Errorf("GET from %s: unable to read body: %v", url, readerr) + return nil, "", fmt.Errorf("GET from %s: unable to read body: %w", url, readerr) } return resp.Header, string(body), nil } @@ -82,3 +90,27 @@ func isStatusOk(resp *http.Response) bool { func (client *HttpClient) Get(url string) (*http.Response, error) { return jammingHttp.Get(url) } + +// readUntilMax is a duplicate of io.Read(). It behaves exactly the same. +// However, it will only read maxBytes bytes, exponentially chunked (as per append). +// Returns an error if it exceeds the limit. +func readUntilMax(r io.Reader, maxBytes int) ([]byte, error) { + b := make([]byte, 0, 512) + for { + if len(b) == cap(b) { + // Add more capacity (let append pick how much). + b = append(b, 0)[:len(b)] + } + n, err := r.Read(b[len(b):cap(b)]) + b = b[:len(b)+n] + if err != nil { + if err == io.EOF { + err = nil + } + return b, err + } + if len(b) > maxBytes { + return nil, ResponseAboveLimit + } + } +} diff --git a/rest/client_test.go b/rest/client_test.go new file mode 100644 index 0000000..771587f --- /dev/null +++ b/rest/client_test.go @@ -0,0 +1,72 @@ +package rest + +import ( + "errors" + "github.com/stretchr/testify/assert" + "io/ioutil" + "net/http" + "testing" +) + +var client = HttpClient{} + +func TestGetBodyWithinLimitsReturnsHeadersAndBodyString(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("amaigat", "jup.") + w.WriteHeader(200) + data, err := ioutil.ReadFile("../mocks/samplerss.xml") // is about 1.6 MB + assert.NoError(t, err) + w.Write(data) + }) + srv := &http.Server{Addr: ":6666", Handler: mux} + defer srv.Close() + + go func() { + srv.ListenAndServe() + }() + headers, body, err := client.GetBody("http://localhost:6666/") + + assert.NoError(t, err) + assert.Equal(t, "jup.", headers.Get("amaigat")) + assert.Contains(t, body, "