forked from wgroeneveld/go-jamming
HttpClient now has a built-in max response size check
This commit is contained in:
parent
00f927886d
commit
1181b9e1fe
|
@ -103,12 +103,13 @@ func TestSendMentionIntegrationStressTest(t *testing.T) {
|
||||||
runs := 100
|
runs := 100
|
||||||
responses := make(chan bool, runs)
|
responses := make(chan bool, runs)
|
||||||
|
|
||||||
http.HandleFunc("/pingback", func(writer http.ResponseWriter, request *http.Request) {
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/pingback", func(writer http.ResponseWriter, request *http.Request) {
|
||||||
writer.WriteHeader(200)
|
writer.WriteHeader(200)
|
||||||
writer.Write([]byte("pingbacked stuff."))
|
writer.Write([]byte("pingbacked stuff."))
|
||||||
responses <- true
|
responses <- true
|
||||||
})
|
})
|
||||||
http.HandleFunc("/target", func(writer http.ResponseWriter, request *http.Request) {
|
mux.HandleFunc("/target", func(writer http.ResponseWriter, request *http.Request) {
|
||||||
target := `<html>
|
target := `<html>
|
||||||
<head>
|
<head>
|
||||||
<link rel="pingback" href="http://localhost:6666/pingback" />
|
<link rel="pingback" href="http://localhost:6666/pingback" />
|
||||||
|
@ -119,10 +120,12 @@ func TestSendMentionIntegrationStressTest(t *testing.T) {
|
||||||
writer.WriteHeader(200)
|
writer.WriteHeader(200)
|
||||||
writer.Write([]byte(target))
|
writer.Write([]byte(target))
|
||||||
})
|
})
|
||||||
|
srv := &http.Server{Addr: ":6666", Handler: mux}
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
fmt.Println("Serving stub at 6666...")
|
fmt.Println("Serving stub at 6666...")
|
||||||
http.ListenAndServe(":6666", nil)
|
srv.ListenAndServe()
|
||||||
fmt.Println("Stub stopped?")
|
fmt.Println("Stub stopped?")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
|
|
@ -1,10 +1,11 @@
|
||||||
package rest
|
package rest
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/hashicorp/go-cleanhttp"
|
"github.com/hashicorp/go-cleanhttp"
|
||||||
"github.com/hashicorp/go-retryablehttp"
|
"github.com/hashicorp/go-retryablehttp"
|
||||||
"io/ioutil"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -21,6 +22,10 @@ type Client interface {
|
||||||
type HttpClient struct {
|
type HttpClient struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
MaxBytes = 5000000 // 5 MiB
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// do not use retryablehttp default impl - inject own logger and retry policies
|
// do not use retryablehttp default impl - inject own logger and retry policies
|
||||||
jammingHttp = &retryablehttp.Client{
|
jammingHttp = &retryablehttp.Client{
|
||||||
|
@ -32,6 +37,8 @@ var (
|
||||||
CheckRetry: retryablehttp.DefaultRetryPolicy,
|
CheckRetry: retryablehttp.DefaultRetryPolicy,
|
||||||
Backoff: retryablehttp.DefaultBackoff,
|
Backoff: retryablehttp.DefaultBackoff,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ResponseAboveLimit = errors.New("response bigger than limit")
|
||||||
)
|
)
|
||||||
|
|
||||||
func (client *HttpClient) PostForm(url string, formData url.Values) error {
|
func (client *HttpClient) PostForm(url string, formData url.Values) error {
|
||||||
|
@ -56,21 +63,22 @@ func (client *HttpClient) Post(url string, contenType string, body string) error
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// something like this? https://freshman.tech/snippets/go/http-response-to-string/
|
// GetBody issues a retryable GET request and returns the header, body string, and a possible error.
|
||||||
|
// It limits response sizes to MaxBytes and returns an error if status not between [200, 299].
|
||||||
func (client *HttpClient) GetBody(url string) (http.Header, string, error) {
|
func (client *HttpClient) GetBody(url string) (http.Header, string, error) {
|
||||||
resp, geterr := client.Get(url)
|
resp, geterr := client.Get(url)
|
||||||
if geterr != nil {
|
if geterr != nil {
|
||||||
return nil, "", fmt.Errorf("GET from %s: %v", url, geterr)
|
return nil, "", fmt.Errorf("GET from %s: %w", url, geterr)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !isStatusOk(resp) {
|
if !isStatusOk(resp) {
|
||||||
return nil, "", fmt.Errorf("GET from %s: Status code is not OK (%d)", url, resp.StatusCode)
|
return nil, "", fmt.Errorf("GET from %s: Status code is not OK (%d)", url, resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
body, readerr := ioutil.ReadAll(resp.Body)
|
body, readerr := readUntilMax(resp.Body, MaxBytes)
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
if readerr != nil {
|
if readerr != nil {
|
||||||
return nil, "", fmt.Errorf("GET from %s: unable to read body: %v", url, readerr)
|
return nil, "", fmt.Errorf("GET from %s: unable to read body: %w", url, readerr)
|
||||||
}
|
}
|
||||||
return resp.Header, string(body), nil
|
return resp.Header, string(body), nil
|
||||||
}
|
}
|
||||||
|
@ -82,3 +90,27 @@ func isStatusOk(resp *http.Response) bool {
|
||||||
func (client *HttpClient) Get(url string) (*http.Response, error) {
|
func (client *HttpClient) Get(url string) (*http.Response, error) {
|
||||||
return jammingHttp.Get(url)
|
return jammingHttp.Get(url)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// readUntilMax is a duplicate of io.Read(). It behaves exactly the same.
|
||||||
|
// However, it will only read maxBytes bytes, exponentially chunked (as per append).
|
||||||
|
// Returns an error if it exceeds the limit.
|
||||||
|
func readUntilMax(r io.Reader, maxBytes int) ([]byte, error) {
|
||||||
|
b := make([]byte, 0, 512)
|
||||||
|
for {
|
||||||
|
if len(b) == cap(b) {
|
||||||
|
// Add more capacity (let append pick how much).
|
||||||
|
b = append(b, 0)[:len(b)]
|
||||||
|
}
|
||||||
|
n, err := r.Read(b[len(b):cap(b)])
|
||||||
|
b = b[:len(b)+n]
|
||||||
|
if err != nil {
|
||||||
|
if err == io.EOF {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
return b, err
|
||||||
|
}
|
||||||
|
if len(b) > maxBytes {
|
||||||
|
return nil, ResponseAboveLimit
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,72 @@
|
||||||
|
package rest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
var client = HttpClient{}
|
||||||
|
|
||||||
|
func TestGetBodyWithinLimitsReturnsHeadersAndBodyString(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("amaigat", "jup.")
|
||||||
|
w.WriteHeader(200)
|
||||||
|
data, err := ioutil.ReadFile("../mocks/samplerss.xml") // is about 1.6 MB
|
||||||
|
assert.NoError(t, err)
|
||||||
|
w.Write(data)
|
||||||
|
})
|
||||||
|
srv := &http.Server{Addr: ":6666", Handler: mux}
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
srv.ListenAndServe()
|
||||||
|
}()
|
||||||
|
headers, body, err := client.GetBody("http://localhost:6666/")
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "jup.", headers.Get("amaigat"))
|
||||||
|
assert.Contains(t, body, "<rss")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetBodyOf404ReturnsError(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(404)
|
||||||
|
})
|
||||||
|
srv := &http.Server{Addr: ":6666", Handler: mux}
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
srv.ListenAndServe()
|
||||||
|
}()
|
||||||
|
_, body, err := client.GetBody("http://localhost:6666/")
|
||||||
|
|
||||||
|
assert.Contains(t, err.Error(), "404")
|
||||||
|
assert.Equal(t, "", body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetBodyOfTooLargeContentReturnsError(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(200)
|
||||||
|
garbage := make([]byte, MaxBytes*2)
|
||||||
|
for i := 0; i < len(garbage); i++ {
|
||||||
|
garbage[i] = 'A'
|
||||||
|
}
|
||||||
|
w.Write(garbage)
|
||||||
|
})
|
||||||
|
srv := &http.Server{Addr: ":6666", Handler: mux}
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
srv.ListenAndServe()
|
||||||
|
}()
|
||||||
|
_, body, err := client.GetBody("http://localhost:6666/")
|
||||||
|
|
||||||
|
assert.Equal(t, ResponseAboveLimit, errors.Unwrap(err))
|
||||||
|
assert.Equal(t, "", body)
|
||||||
|
}
|
Loading…
Reference in New Issue