HttpClient now has a built-in max response size check
This commit is contained in:
parent
00f927886d
commit
1181b9e1fe
|
@ -103,12 +103,13 @@ func TestSendMentionIntegrationStressTest(t *testing.T) {
|
|||
runs := 100
|
||||
responses := make(chan bool, runs)
|
||||
|
||||
http.HandleFunc("/pingback", func(writer http.ResponseWriter, request *http.Request) {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/pingback", func(writer http.ResponseWriter, request *http.Request) {
|
||||
writer.WriteHeader(200)
|
||||
writer.Write([]byte("pingbacked stuff."))
|
||||
responses <- true
|
||||
})
|
||||
http.HandleFunc("/target", func(writer http.ResponseWriter, request *http.Request) {
|
||||
mux.HandleFunc("/target", func(writer http.ResponseWriter, request *http.Request) {
|
||||
target := `<html>
|
||||
<head>
|
||||
<link rel="pingback" href="http://localhost:6666/pingback" />
|
||||
|
@ -119,10 +120,12 @@ func TestSendMentionIntegrationStressTest(t *testing.T) {
|
|||
writer.WriteHeader(200)
|
||||
writer.Write([]byte(target))
|
||||
})
|
||||
srv := &http.Server{Addr: ":6666", Handler: mux}
|
||||
defer srv.Close()
|
||||
|
||||
go func() {
|
||||
fmt.Println("Serving stub at 6666...")
|
||||
http.ListenAndServe(":6666", nil)
|
||||
srv.ListenAndServe()
|
||||
fmt.Println("Stub stopped?")
|
||||
}()
|
||||
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
package rest
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/hashicorp/go-cleanhttp"
|
||||
"github.com/hashicorp/go-retryablehttp"
|
||||
"io/ioutil"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
@ -21,6 +22,10 @@ type Client interface {
|
|||
type HttpClient struct {
|
||||
}
|
||||
|
||||
const (
|
||||
MaxBytes = 5000000 // 5 MiB
|
||||
)
|
||||
|
||||
var (
|
||||
// do not use retryablehttp default impl - inject own logger and retry policies
|
||||
jammingHttp = &retryablehttp.Client{
|
||||
|
@ -32,6 +37,8 @@ var (
|
|||
CheckRetry: retryablehttp.DefaultRetryPolicy,
|
||||
Backoff: retryablehttp.DefaultBackoff,
|
||||
}
|
||||
|
||||
ResponseAboveLimit = errors.New("response bigger than limit")
|
||||
)
|
||||
|
||||
func (client *HttpClient) PostForm(url string, formData url.Values) error {
|
||||
|
@ -56,21 +63,22 @@ func (client *HttpClient) Post(url string, contenType string, body string) error
|
|||
return nil
|
||||
}
|
||||
|
||||
// something like this? https://freshman.tech/snippets/go/http-response-to-string/
|
||||
// GetBody issues a retryable GET request and returns the header, body string, and a possible error.
|
||||
// It limits response sizes to MaxBytes and returns an error if status not between [200, 299].
|
||||
func (client *HttpClient) GetBody(url string) (http.Header, string, error) {
|
||||
resp, geterr := client.Get(url)
|
||||
if geterr != nil {
|
||||
return nil, "", fmt.Errorf("GET from %s: %v", url, geterr)
|
||||
return nil, "", fmt.Errorf("GET from %s: %w", url, geterr)
|
||||
}
|
||||
|
||||
if !isStatusOk(resp) {
|
||||
return nil, "", fmt.Errorf("GET from %s: Status code is not OK (%d)", url, resp.StatusCode)
|
||||
}
|
||||
|
||||
body, readerr := ioutil.ReadAll(resp.Body)
|
||||
body, readerr := readUntilMax(resp.Body, MaxBytes)
|
||||
defer resp.Body.Close()
|
||||
if readerr != nil {
|
||||
return nil, "", fmt.Errorf("GET from %s: unable to read body: %v", url, readerr)
|
||||
return nil, "", fmt.Errorf("GET from %s: unable to read body: %w", url, readerr)
|
||||
}
|
||||
return resp.Header, string(body), nil
|
||||
}
|
||||
|
@ -82,3 +90,27 @@ func isStatusOk(resp *http.Response) bool {
|
|||
func (client *HttpClient) Get(url string) (*http.Response, error) {
|
||||
return jammingHttp.Get(url)
|
||||
}
|
||||
|
||||
// readUntilMax is a duplicate of io.Read(). It behaves exactly the same.
|
||||
// However, it will only read maxBytes bytes, exponentially chunked (as per append).
|
||||
// Returns an error if it exceeds the limit.
|
||||
func readUntilMax(r io.Reader, maxBytes int) ([]byte, error) {
|
||||
b := make([]byte, 0, 512)
|
||||
for {
|
||||
if len(b) == cap(b) {
|
||||
// Add more capacity (let append pick how much).
|
||||
b = append(b, 0)[:len(b)]
|
||||
}
|
||||
n, err := r.Read(b[len(b):cap(b)])
|
||||
b = b[:len(b)+n]
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
return b, err
|
||||
}
|
||||
if len(b) > maxBytes {
|
||||
return nil, ResponseAboveLimit
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
package rest
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var client = HttpClient{}
|
||||
|
||||
func TestGetBodyWithinLimitsReturnsHeadersAndBodyString(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("amaigat", "jup.")
|
||||
w.WriteHeader(200)
|
||||
data, err := ioutil.ReadFile("../mocks/samplerss.xml") // is about 1.6 MB
|
||||
assert.NoError(t, err)
|
||||
w.Write(data)
|
||||
})
|
||||
srv := &http.Server{Addr: ":6666", Handler: mux}
|
||||
defer srv.Close()
|
||||
|
||||
go func() {
|
||||
srv.ListenAndServe()
|
||||
}()
|
||||
headers, body, err := client.GetBody("http://localhost:6666/")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "jup.", headers.Get("amaigat"))
|
||||
assert.Contains(t, body, "<rss")
|
||||
}
|
||||
|
||||
func TestGetBodyOf404ReturnsError(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(404)
|
||||
})
|
||||
srv := &http.Server{Addr: ":6666", Handler: mux}
|
||||
defer srv.Close()
|
||||
|
||||
go func() {
|
||||
srv.ListenAndServe()
|
||||
}()
|
||||
_, body, err := client.GetBody("http://localhost:6666/")
|
||||
|
||||
assert.Contains(t, err.Error(), "404")
|
||||
assert.Equal(t, "", body)
|
||||
}
|
||||
|
||||
func TestGetBodyOfTooLargeContentReturnsError(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
garbage := make([]byte, MaxBytes*2)
|
||||
for i := 0; i < len(garbage); i++ {
|
||||
garbage[i] = 'A'
|
||||
}
|
||||
w.Write(garbage)
|
||||
})
|
||||
srv := &http.Server{Addr: ":6666", Handler: mux}
|
||||
defer srv.Close()
|
||||
|
||||
go func() {
|
||||
srv.ListenAndServe()
|
||||
}()
|
||||
_, body, err := client.GetBody("http://localhost:6666/")
|
||||
|
||||
assert.Equal(t, ResponseAboveLimit, errors.Unwrap(err))
|
||||
assert.Equal(t, "", body)
|
||||
}
|
Loading…
Reference in New Issue