3395 lines
74 KiB
Go
3395 lines
74 KiB
Go
package fasthttp
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"crypto/tls"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/url"
|
|
"os"
|
|
"regexp"
|
|
"runtime"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/valyala/fasthttp/fasthttputil"
|
|
)
|
|
|
|
func TestCloseIdleConnections(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
},
|
|
}
|
|
go func() {
|
|
if err := s.Serve(ln); err != nil {
|
|
t.Error(err)
|
|
}
|
|
}()
|
|
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
|
|
if _, _, err := c.Get(nil, "http://google.com"); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
connsLen := func() int {
|
|
c.mLock.Lock()
|
|
defer c.mLock.Unlock()
|
|
|
|
if _, ok := c.m["google.com"]; !ok {
|
|
return 0
|
|
}
|
|
|
|
c.m["google.com"].connsLock.Lock()
|
|
defer c.m["google.com"].connsLock.Unlock()
|
|
|
|
return len(c.m["google.com"].conns)
|
|
}
|
|
|
|
if conns := connsLen(); conns > 1 {
|
|
t.Errorf("expected 1 conns got %d", conns)
|
|
}
|
|
|
|
c.CloseIdleConnections()
|
|
|
|
if conns := connsLen(); conns > 0 {
|
|
t.Errorf("expected 0 conns got %d", conns)
|
|
}
|
|
}
|
|
|
|
func TestPipelineClientSetUserAgent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testPipelineClientSetUserAgent(t, 0)
|
|
}
|
|
|
|
func TestPipelineClientSetUserAgentTimeout(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testPipelineClientSetUserAgent(t, time.Second)
|
|
}
|
|
|
|
func testPipelineClientSetUserAgent(t *testing.T, timeout time.Duration) {
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
|
|
userAgentSeen := ""
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
userAgentSeen = string(ctx.UserAgent())
|
|
},
|
|
}
|
|
go s.Serve(ln) //nolint:errcheck
|
|
|
|
userAgent := "I'm not fasthttp"
|
|
c := &HostClient{
|
|
Name: userAgent,
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
req := AcquireRequest()
|
|
res := AcquireResponse()
|
|
|
|
req.SetRequestURI("http://example.com")
|
|
|
|
var err error
|
|
if timeout <= 0 {
|
|
err = c.Do(req, res)
|
|
} else {
|
|
err = c.DoTimeout(req, res, timeout)
|
|
}
|
|
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if userAgentSeen != userAgent {
|
|
t.Fatalf("User-Agent defers %q != %q", userAgentSeen, userAgent)
|
|
}
|
|
}
|
|
|
|
func TestHostClientNegativeTimeout(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
},
|
|
}
|
|
go s.Serve(ln) //nolint:errcheck
|
|
c := &HostClient{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
req := AcquireRequest()
|
|
req.Header.SetMethod(MethodGet)
|
|
req.SetRequestURI("http://example.com")
|
|
if err := c.DoTimeout(req, nil, -time.Second); err != ErrTimeout {
|
|
t.Fatalf("expected ErrTimeout error got: %+v", err)
|
|
}
|
|
if err := c.DoDeadline(req, nil, time.Now().Add(-time.Second)); err != ErrTimeout {
|
|
t.Fatalf("expected ErrTimeout error got: %+v", err)
|
|
}
|
|
ln.Close()
|
|
}
|
|
|
|
func TestDoDeadlineRetry(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tries := 0
|
|
done := make(chan struct{})
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
go func() {
|
|
for {
|
|
c, err := ln.Accept()
|
|
if err != nil {
|
|
close(done)
|
|
break
|
|
}
|
|
tries++
|
|
br := bufio.NewReader(c)
|
|
(&RequestHeader{}).Read(br) //nolint:errcheck
|
|
(&Request{}).readBodyStream(br, 0, false, false) //nolint:errcheck
|
|
time.Sleep(time.Millisecond * 60)
|
|
c.Close()
|
|
}
|
|
}()
|
|
c := &HostClient{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
req := AcquireRequest()
|
|
req.Header.SetMethod(MethodGet)
|
|
req.SetRequestURI("http://example.com")
|
|
if err := c.DoDeadline(req, nil, time.Now().Add(time.Millisecond*100)); err != ErrTimeout {
|
|
t.Fatalf("expected ErrTimeout error got: %+v", err)
|
|
}
|
|
ln.Close()
|
|
<-done
|
|
if tries != 2 {
|
|
t.Fatalf("expected 2 tries got %d", tries)
|
|
}
|
|
}
|
|
|
|
func TestPipelineClientIssue832(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
|
|
req := AcquireRequest()
|
|
// Don't defer ReleaseRequest as we use it in a goroutine that might not be done at the end.
|
|
|
|
req.SetHost("example.com")
|
|
|
|
res := AcquireResponse()
|
|
// Don't defer ReleaseResponse as we use it in a goroutine that might not be done at the end.
|
|
|
|
client := PipelineClient{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
ReadTimeout: time.Millisecond * 10,
|
|
Logger: &testLogger{}, // Ignore log output.
|
|
}
|
|
|
|
attempts := 10
|
|
go func() {
|
|
for i := 0; i < attempts; i++ {
|
|
c, err := ln.Accept()
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
if c != nil {
|
|
go func() {
|
|
time.Sleep(time.Millisecond * 50)
|
|
c.Close()
|
|
}()
|
|
}
|
|
}
|
|
}()
|
|
|
|
done := make(chan int)
|
|
go func() {
|
|
defer close(done)
|
|
|
|
for i := 0; i < attempts; i++ {
|
|
if err := client.Do(req, res); err == nil {
|
|
t.Error("error expected")
|
|
}
|
|
}
|
|
}()
|
|
|
|
select {
|
|
case <-time.After(time.Second * 2):
|
|
t.Fatal("PipelineClient did not restart worker")
|
|
case <-done:
|
|
}
|
|
}
|
|
|
|
func TestClientInvalidURI(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
requests := int64(0)
|
|
s := &Server{
|
|
Handler: func(_ *RequestCtx) {
|
|
atomic.AddInt64(&requests, 1)
|
|
},
|
|
}
|
|
go s.Serve(ln) //nolint:errcheck
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
req, res := AcquireRequest(), AcquireResponse()
|
|
defer func() {
|
|
ReleaseRequest(req)
|
|
ReleaseResponse(res)
|
|
}()
|
|
req.Header.SetMethod(MethodGet)
|
|
req.SetRequestURI("http://example.com\r\n\r\nGET /\r\n\r\n")
|
|
err := c.Do(req, res)
|
|
if err == nil {
|
|
t.Fatal("expected error (missing required Host header in request)")
|
|
}
|
|
if n := atomic.LoadInt64(&requests); n != 0 {
|
|
t.Fatalf("0 requests expected, got %d", n)
|
|
}
|
|
}
|
|
|
|
func TestClientGetWithBody(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
body := ctx.Request.Body()
|
|
ctx.Write(body) //nolint:errcheck
|
|
},
|
|
}
|
|
go s.Serve(ln) //nolint:errcheck
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
req, res := AcquireRequest(), AcquireResponse()
|
|
defer func() {
|
|
ReleaseRequest(req)
|
|
ReleaseResponse(res)
|
|
}()
|
|
req.Header.SetMethod(MethodGet)
|
|
req.SetRequestURI("http://example.com")
|
|
req.SetBodyString("test")
|
|
err := c.Do(req, res)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(res.Body()) == 0 {
|
|
t.Fatal("missing request body")
|
|
}
|
|
}
|
|
|
|
func TestClientURLAuth(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
cases := map[string]string{
|
|
"user:pass@": "Basic dXNlcjpwYXNz",
|
|
"foo:@": "Basic Zm9vOg==",
|
|
":@": "",
|
|
"@": "",
|
|
"": "",
|
|
}
|
|
|
|
ch := make(chan string, 1)
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
ch <- string(ctx.Request.Header.Peek(HeaderAuthorization))
|
|
},
|
|
}
|
|
go s.Serve(ln) //nolint:errcheck
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
for up, expected := range cases {
|
|
req := AcquireRequest()
|
|
req.Header.SetMethod(MethodGet)
|
|
req.SetRequestURI("http://" + up + "example.com/foo/bar")
|
|
if err := c.Do(req, nil); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
val := <-ch
|
|
|
|
if val != expected {
|
|
t.Fatalf("wrong %q header: %q expected %q", HeaderAuthorization, val, expected)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestClientNilResp(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
},
|
|
}
|
|
go s.Serve(ln) //nolint:errcheck
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
req := AcquireRequest()
|
|
req.Header.SetMethod(MethodGet)
|
|
req.SetRequestURI("http://example.com")
|
|
if err := c.Do(req, nil); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := c.DoTimeout(req, nil, time.Second); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
ln.Close()
|
|
}
|
|
|
|
func TestClientNegativeTimeout(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
},
|
|
}
|
|
go s.Serve(ln) //nolint:errcheck
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
req := AcquireRequest()
|
|
req.Header.SetMethod(MethodGet)
|
|
req.SetRequestURI("http://example.com")
|
|
if err := c.DoTimeout(req, nil, -time.Second); err != ErrTimeout {
|
|
t.Fatalf("expected ErrTimeout error got: %+v", err)
|
|
}
|
|
if err := c.DoDeadline(req, nil, time.Now().Add(-time.Second)); err != ErrTimeout {
|
|
t.Fatalf("expected ErrTimeout error got: %+v", err)
|
|
}
|
|
ln.Close()
|
|
}
|
|
|
|
func TestPipelineClientNilResp(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
},
|
|
}
|
|
go s.Serve(ln) //nolint:errcheck
|
|
c := &PipelineClient{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
req := AcquireRequest()
|
|
req.Header.SetMethod(MethodGet)
|
|
req.SetRequestURI("http://example.com")
|
|
if err := c.Do(req, nil); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := c.DoTimeout(req, nil, time.Second); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := c.DoDeadline(req, nil, time.Now().Add(time.Second)); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
func TestClientParseConn(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
network := "tcp"
|
|
ln, _ := net.Listen(network, "127.0.0.1:0")
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
},
|
|
}
|
|
go s.Serve(ln) //nolint:errcheck
|
|
host := ln.Addr().String()
|
|
c := &Client{}
|
|
req, res := AcquireRequest(), AcquireResponse()
|
|
defer func() {
|
|
ReleaseRequest(req)
|
|
ReleaseResponse(res)
|
|
}()
|
|
req.SetRequestURI("http://" + host + "")
|
|
if err := c.Do(req, res); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if res.RemoteAddr().Network() != network {
|
|
t.Fatalf("req RemoteAddr parse network fail: %q, hope: %q", res.RemoteAddr().Network(), network)
|
|
}
|
|
if host != res.RemoteAddr().String() {
|
|
t.Fatalf("req RemoteAddr parse addr fail: %q, hope: %q", res.RemoteAddr().String(), host)
|
|
}
|
|
|
|
if !regexp.MustCompile(`^127\.0\.0\.1:[0-9]{4,5}$`).MatchString(res.LocalAddr().String()) {
|
|
t.Fatalf("res LocalAddr addr match fail: %q, hope match: %q", res.LocalAddr().String(), "^127.0.0.1:[0-9]{4,5}$")
|
|
}
|
|
}
|
|
|
|
func TestClientPostArgs(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
body := ctx.Request.Body()
|
|
if len(body) == 0 {
|
|
return
|
|
}
|
|
ctx.Write(body) //nolint:errcheck
|
|
},
|
|
}
|
|
go s.Serve(ln) //nolint:errcheck
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
req, res := AcquireRequest(), AcquireResponse()
|
|
defer func() {
|
|
ReleaseRequest(req)
|
|
ReleaseResponse(res)
|
|
}()
|
|
args := req.PostArgs()
|
|
args.Add("addhttp2", "support")
|
|
args.Add("fast", "http")
|
|
req.Header.SetMethod(MethodPost)
|
|
req.SetRequestURI("http://make.fasthttp.great?again")
|
|
err := c.Do(req, res)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(res.Body()) == 0 {
|
|
t.Fatal("cannot set args as body")
|
|
}
|
|
}
|
|
|
|
func TestClientRedirectSameSchema(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
listenHTTPS1 := testClientRedirectListener(t, true)
|
|
defer listenHTTPS1.Close()
|
|
|
|
listenHTTPS2 := testClientRedirectListener(t, true)
|
|
defer listenHTTPS2.Close()
|
|
|
|
sHTTPS1 := testClientRedirectChangingSchemaServer(t, listenHTTPS1, listenHTTPS1, true)
|
|
defer sHTTPS1.Stop()
|
|
|
|
sHTTPS2 := testClientRedirectChangingSchemaServer(t, listenHTTPS2, listenHTTPS2, false)
|
|
defer sHTTPS2.Stop()
|
|
|
|
destURL := fmt.Sprintf("https://%s/baz", listenHTTPS1.Addr().String())
|
|
|
|
urlParsed, err := url.Parse(destURL)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
return
|
|
}
|
|
|
|
reqClient := &HostClient{
|
|
IsTLS: true,
|
|
Addr: urlParsed.Host,
|
|
TLSConfig: &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
},
|
|
}
|
|
|
|
statusCode, _, err := reqClient.GetTimeout(nil, destURL, 4000*time.Millisecond)
|
|
if err != nil {
|
|
t.Fatalf("HostClient error: %v", err)
|
|
return
|
|
}
|
|
|
|
if statusCode != 200 {
|
|
t.Fatalf("HostClient error code response %d", statusCode)
|
|
return
|
|
}
|
|
}
|
|
|
|
func TestClientRedirectClientChangingSchemaHttp2Https(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
listenHTTPS := testClientRedirectListener(t, true)
|
|
defer listenHTTPS.Close()
|
|
|
|
listenHTTP := testClientRedirectListener(t, false)
|
|
defer listenHTTP.Close()
|
|
|
|
sHTTPS := testClientRedirectChangingSchemaServer(t, listenHTTPS, listenHTTP, true)
|
|
defer sHTTPS.Stop()
|
|
|
|
sHTTP := testClientRedirectChangingSchemaServer(t, listenHTTPS, listenHTTP, false)
|
|
defer sHTTP.Stop()
|
|
|
|
destURL := fmt.Sprintf("http://%s/baz", listenHTTP.Addr().String())
|
|
|
|
reqClient := &Client{
|
|
TLSConfig: &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
},
|
|
}
|
|
|
|
statusCode, _, err := reqClient.GetTimeout(nil, destURL, 4000*time.Millisecond)
|
|
if err != nil {
|
|
t.Fatalf("HostClient error: %v", err)
|
|
return
|
|
}
|
|
|
|
if statusCode != 200 {
|
|
t.Fatalf("HostClient error code response %d", statusCode)
|
|
return
|
|
}
|
|
}
|
|
|
|
func TestClientRedirectHostClientChangingSchemaHttp2Https(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
listenHTTPS := testClientRedirectListener(t, true)
|
|
defer listenHTTPS.Close()
|
|
|
|
listenHTTP := testClientRedirectListener(t, false)
|
|
defer listenHTTP.Close()
|
|
|
|
sHTTPS := testClientRedirectChangingSchemaServer(t, listenHTTPS, listenHTTP, true)
|
|
defer sHTTPS.Stop()
|
|
|
|
sHTTP := testClientRedirectChangingSchemaServer(t, listenHTTPS, listenHTTP, false)
|
|
defer sHTTP.Stop()
|
|
|
|
destURL := fmt.Sprintf("http://%s/baz", listenHTTP.Addr().String())
|
|
|
|
urlParsed, err := url.Parse(destURL)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
return
|
|
}
|
|
|
|
reqClient := &HostClient{
|
|
Addr: urlParsed.Host,
|
|
TLSConfig: &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
},
|
|
}
|
|
|
|
_, _, err = reqClient.GetTimeout(nil, destURL, 4000*time.Millisecond)
|
|
if err != ErrHostClientRedirectToDifferentScheme {
|
|
t.Fatal("expected HostClient error")
|
|
}
|
|
}
|
|
|
|
func testClientRedirectListener(t *testing.T, isTLS bool) net.Listener {
|
|
var ln net.Listener
|
|
var err error
|
|
var tlsConfig *tls.Config
|
|
|
|
if isTLS {
|
|
certData, keyData, kerr := GenerateTestCertificate("localhost")
|
|
if kerr != nil {
|
|
t.Fatal(kerr)
|
|
}
|
|
|
|
cert, kerr := tls.X509KeyPair(certData, keyData)
|
|
if kerr != nil {
|
|
t.Fatal(kerr)
|
|
}
|
|
|
|
tlsConfig = &tls.Config{
|
|
Certificates: []tls.Certificate{cert},
|
|
}
|
|
ln, err = tls.Listen("tcp", "localhost:0", tlsConfig)
|
|
} else {
|
|
ln, err = net.Listen("tcp", "localhost:0")
|
|
}
|
|
|
|
if err != nil {
|
|
t.Fatalf("cannot listen isTLS %v: %v", isTLS, err)
|
|
}
|
|
|
|
return ln
|
|
}
|
|
|
|
func testClientRedirectChangingSchemaServer(t *testing.T, https, http net.Listener, isTLS bool) *testEchoServer {
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
if ctx.IsTLS() {
|
|
ctx.SetStatusCode(200)
|
|
} else {
|
|
ctx.Redirect(fmt.Sprintf("https://%s/baz", https.Addr().String()), 301)
|
|
}
|
|
},
|
|
}
|
|
|
|
var ln net.Listener
|
|
if isTLS {
|
|
ln = https
|
|
} else {
|
|
ln = http
|
|
}
|
|
|
|
ch := make(chan struct{})
|
|
go func() {
|
|
err := s.Serve(ln)
|
|
if err != nil {
|
|
t.Errorf("unexpected error returned from Serve(): %v", err)
|
|
}
|
|
close(ch)
|
|
}()
|
|
return &testEchoServer{
|
|
s: s,
|
|
ln: ln,
|
|
ch: ch,
|
|
t: t,
|
|
}
|
|
}
|
|
|
|
func TestClientHeaderCase(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
defer ln.Close()
|
|
|
|
go func() {
|
|
c, err := ln.Accept()
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
c.Write([]byte("HTTP/1.1 200 OK\r\n" + //nolint:errcheck
|
|
"content-type: text/plain\r\n" +
|
|
"transfer-encoding: chunked\r\n\r\n" +
|
|
"24\r\nThis is the data in the first chunk \r\n" +
|
|
"1B\r\nand this is the second one \r\n" +
|
|
"0\r\n\r\n",
|
|
))
|
|
}()
|
|
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
ReadTimeout: time.Millisecond * 10,
|
|
|
|
// Even without name normalizing we should parse headers correctly.
|
|
DisableHeaderNamesNormalizing: true,
|
|
}
|
|
|
|
code, body, err := c.Get(nil, "http://example.com")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if code != 200 {
|
|
t.Errorf("expected status code 200 got %d", code)
|
|
}
|
|
if string(body) != "This is the data in the first chunk and this is the second one " {
|
|
t.Errorf("wrong body: %q", body)
|
|
}
|
|
}
|
|
|
|
func TestClientReadTimeout(t *testing.T) {
|
|
if runtime.GOOS == "windows" {
|
|
t.SkipNow()
|
|
}
|
|
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
|
|
timeout := false
|
|
s := &Server{
|
|
Handler: func(_ *RequestCtx) {
|
|
if timeout {
|
|
time.Sleep(time.Second)
|
|
} else {
|
|
timeout = true
|
|
}
|
|
},
|
|
Logger: &testLogger{}, // Don't print closed pipe errors.
|
|
}
|
|
go s.Serve(ln) //nolint:errcheck
|
|
|
|
c := &HostClient{
|
|
ReadTimeout: time.Millisecond * 400,
|
|
MaxIdemponentCallAttempts: 1,
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
|
|
req := AcquireRequest()
|
|
res := AcquireResponse()
|
|
|
|
req.SetRequestURI("http://localhost")
|
|
|
|
// Setting Connection: Close will make the connection be
|
|
// returned to the pool.
|
|
req.SetConnectionClose()
|
|
|
|
if err := c.Do(req, res); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
ReleaseRequest(req)
|
|
ReleaseResponse(res)
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
req := AcquireRequest()
|
|
res := AcquireResponse()
|
|
|
|
req.SetRequestURI("http://localhost")
|
|
req.SetConnectionClose()
|
|
|
|
if err := c.Do(req, res); err != ErrTimeout {
|
|
t.Errorf("expected ErrTimeout got %#v", err)
|
|
}
|
|
|
|
ReleaseRequest(req)
|
|
ReleaseResponse(res)
|
|
close(done)
|
|
}()
|
|
|
|
select {
|
|
case <-done:
|
|
// This shouldn't take longer than the timeout times the number of requests it is going to try to do.
|
|
// Give it an extra second just to be sure.
|
|
case <-time.After(c.ReadTimeout*time.Duration(c.MaxIdemponentCallAttempts) + time.Second):
|
|
t.Fatal("Client.ReadTimeout didn't work")
|
|
}
|
|
}
|
|
|
|
func TestClientDefaultUserAgent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
|
|
userAgentSeen := ""
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
userAgentSeen = string(ctx.UserAgent())
|
|
},
|
|
}
|
|
go s.Serve(ln) //nolint:errcheck
|
|
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
req := AcquireRequest()
|
|
res := AcquireResponse()
|
|
|
|
req.SetRequestURI("http://example.com")
|
|
|
|
err := c.Do(req, res)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if userAgentSeen != defaultUserAgent {
|
|
t.Fatalf("User-Agent defers %q != %q", userAgentSeen, defaultUserAgent)
|
|
}
|
|
}
|
|
|
|
func TestClientSetUserAgent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
|
|
userAgentSeen := ""
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
userAgentSeen = string(ctx.UserAgent())
|
|
},
|
|
}
|
|
go s.Serve(ln) //nolint:errcheck
|
|
|
|
userAgent := "I'm not fasthttp"
|
|
c := &Client{
|
|
Name: userAgent,
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
req := AcquireRequest()
|
|
res := AcquireResponse()
|
|
|
|
req.SetRequestURI("http://example.com")
|
|
|
|
err := c.Do(req, res)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if userAgentSeen != userAgent {
|
|
t.Fatalf("User-Agent defers %q != %q", userAgentSeen, userAgent)
|
|
}
|
|
}
|
|
|
|
func TestClientNoUserAgent(t *testing.T) {
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
|
|
userAgentSeen := ""
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
userAgentSeen = string(ctx.UserAgent())
|
|
},
|
|
}
|
|
go s.Serve(ln) //nolint:errcheck
|
|
|
|
c := &Client{
|
|
NoDefaultUserAgentHeader: true,
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
req := AcquireRequest()
|
|
res := AcquireResponse()
|
|
|
|
req.SetRequestURI("http://example.com")
|
|
|
|
err := c.Do(req, res)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if userAgentSeen != "" {
|
|
t.Fatalf("User-Agent wrong %q != %q", userAgentSeen, "")
|
|
}
|
|
}
|
|
|
|
func TestClientDoWithCustomHeaders(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// make sure that the client sends all the request headers and body.
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
|
|
uri := "/foo/bar/baz?a=b&cd=12"
|
|
headers := map[string]string{
|
|
"Foo": "bar",
|
|
"Host": "example.com",
|
|
"Content-Type": "asdfsdf",
|
|
"a-b-c-d-f": "",
|
|
}
|
|
body := "request body"
|
|
|
|
ch := make(chan error)
|
|
go func() {
|
|
conn, err := ln.Accept()
|
|
if err != nil {
|
|
ch <- fmt.Errorf("cannot accept client connection: %w", err)
|
|
return
|
|
}
|
|
br := bufio.NewReader(conn)
|
|
|
|
var req Request
|
|
if err = req.Read(br); err != nil {
|
|
ch <- fmt.Errorf("cannot read client request: %w", err)
|
|
return
|
|
}
|
|
if string(req.Header.Method()) != MethodPost {
|
|
ch <- fmt.Errorf("unexpected request method: %q. Expecting %q", req.Header.Method(), MethodPost)
|
|
return
|
|
}
|
|
reqURI := req.RequestURI()
|
|
if string(reqURI) != uri {
|
|
ch <- fmt.Errorf("unexpected request uri: %q. Expecting %q", reqURI, uri)
|
|
return
|
|
}
|
|
for k, v := range headers {
|
|
hv := req.Header.Peek(k)
|
|
if string(hv) != v {
|
|
ch <- fmt.Errorf("unexpected value for header %q: %q. Expecting %q", k, hv, v)
|
|
return
|
|
}
|
|
}
|
|
cl := req.Header.ContentLength()
|
|
if cl != len(body) {
|
|
ch <- fmt.Errorf("unexpected content-length %d. Expecting %d", cl, len(body))
|
|
return
|
|
}
|
|
reqBody := req.Body()
|
|
if string(reqBody) != body {
|
|
ch <- fmt.Errorf("unexpected request body: %q. Expecting %q", reqBody, body)
|
|
return
|
|
}
|
|
|
|
var resp Response
|
|
bw := bufio.NewWriter(conn)
|
|
if err = resp.Write(bw); err != nil {
|
|
ch <- fmt.Errorf("cannot send response: %w", err)
|
|
return
|
|
}
|
|
if err = bw.Flush(); err != nil {
|
|
ch <- fmt.Errorf("cannot flush response: %w", err)
|
|
return
|
|
}
|
|
|
|
ch <- nil
|
|
}()
|
|
|
|
var req Request
|
|
req.Header.SetMethod(MethodPost)
|
|
req.SetRequestURI(uri)
|
|
for k, v := range headers {
|
|
req.Header.Set(k, v)
|
|
}
|
|
req.SetBodyString(body)
|
|
|
|
var resp Response
|
|
|
|
err := c.DoTimeout(&req, &resp, time.Second)
|
|
if err != nil {
|
|
t.Fatalf("error when doing request: %v", err)
|
|
}
|
|
|
|
select {
|
|
case <-ch:
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
}
|
|
|
|
func TestPipelineClientDoSerial(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testPipelineClientDoConcurrent(t, 1, 0, 0)
|
|
}
|
|
|
|
func TestPipelineClientDoConcurrent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testPipelineClientDoConcurrent(t, 10, 0, 1)
|
|
}
|
|
|
|
func TestPipelineClientDoBatchDelayConcurrent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testPipelineClientDoConcurrent(t, 10, 5*time.Millisecond, 1)
|
|
}
|
|
|
|
func TestPipelineClientDoBatchDelayConcurrentMultiConn(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testPipelineClientDoConcurrent(t, 10, 5*time.Millisecond, 3)
|
|
}
|
|
|
|
func testPipelineClientDoConcurrent(t *testing.T, concurrency int, maxBatchDelay time.Duration, maxConns int) {
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
ctx.WriteString("OK") //nolint:errcheck
|
|
},
|
|
}
|
|
|
|
serverStopCh := make(chan struct{})
|
|
go func() {
|
|
if err := s.Serve(ln); err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
close(serverStopCh)
|
|
}()
|
|
|
|
c := &PipelineClient{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
MaxConns: maxConns,
|
|
MaxPendingRequests: concurrency,
|
|
MaxBatchDelay: maxBatchDelay,
|
|
Logger: &testLogger{},
|
|
}
|
|
|
|
clientStopCh := make(chan struct{}, concurrency)
|
|
for i := 0; i < concurrency; i++ {
|
|
go func() {
|
|
testPipelineClientDo(t, c)
|
|
clientStopCh <- struct{}{}
|
|
}()
|
|
}
|
|
|
|
for i := 0; i < concurrency; i++ {
|
|
select {
|
|
case <-clientStopCh:
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
}
|
|
|
|
if c.PendingRequests() != 0 {
|
|
t.Fatalf("unexpected number of pending requests: %d. Expecting zero", c.PendingRequests())
|
|
}
|
|
|
|
if err := ln.Close(); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
select {
|
|
case <-serverStopCh:
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
}
|
|
|
|
func testPipelineClientDo(t *testing.T, c *PipelineClient) {
|
|
var err error
|
|
req := AcquireRequest()
|
|
req.SetRequestURI("http://foobar/baz")
|
|
resp := AcquireResponse()
|
|
for i := 0; i < 10; i++ {
|
|
if i&1 == 0 {
|
|
err = c.DoTimeout(req, resp, time.Second)
|
|
} else {
|
|
err = c.Do(req, resp)
|
|
}
|
|
if err != nil {
|
|
if err == ErrPipelineOverflow {
|
|
time.Sleep(10 * time.Millisecond)
|
|
continue
|
|
}
|
|
t.Errorf("unexpected error on iteration %d: %v", i, err)
|
|
}
|
|
if resp.StatusCode() != StatusOK {
|
|
t.Errorf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
|
|
}
|
|
body := string(resp.Body())
|
|
if body != "OK" {
|
|
t.Errorf("unexpected body: %q. Expecting %q", body, "OK")
|
|
}
|
|
|
|
// sleep for a while, so the connection to the host may expire.
|
|
if i%5 == 0 {
|
|
time.Sleep(30 * time.Millisecond)
|
|
}
|
|
}
|
|
ReleaseRequest(req)
|
|
ReleaseResponse(resp)
|
|
}
|
|
|
|
func TestPipelineClientDoDisableHeaderNamesNormalizing(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testPipelineClientDisableHeaderNamesNormalizing(t, 0)
|
|
}
|
|
|
|
func TestPipelineClientDoTimeoutDisableHeaderNamesNormalizing(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testPipelineClientDisableHeaderNamesNormalizing(t, time.Second)
|
|
}
|
|
|
|
func testPipelineClientDisableHeaderNamesNormalizing(t *testing.T, timeout time.Duration) {
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
ctx.Response.Header.Set("foo-BAR", "baz")
|
|
},
|
|
DisableHeaderNamesNormalizing: true,
|
|
}
|
|
|
|
serverStopCh := make(chan struct{})
|
|
go func() {
|
|
if err := s.Serve(ln); err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
close(serverStopCh)
|
|
}()
|
|
|
|
c := &PipelineClient{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
DisableHeaderNamesNormalizing: true,
|
|
}
|
|
|
|
var req Request
|
|
req.SetRequestURI("http://aaaai.com/bsdf?sddfsd")
|
|
var resp Response
|
|
for i := 0; i < 5; i++ {
|
|
if timeout > 0 {
|
|
if err := c.DoTimeout(&req, &resp, timeout); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
} else {
|
|
if err := c.Do(&req, &resp); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
}
|
|
hv := resp.Header.Peek("foo-BAR")
|
|
if string(hv) != "baz" {
|
|
t.Fatalf("unexpected header value: %q. Expecting %q", hv, "baz")
|
|
}
|
|
hv = resp.Header.Peek("Foo-Bar")
|
|
if len(hv) > 0 {
|
|
t.Fatalf("unexpected non-empty header value %q", hv)
|
|
}
|
|
}
|
|
|
|
if err := ln.Close(); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
select {
|
|
case <-serverStopCh:
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
}
|
|
|
|
func TestClientDoTimeoutDisableHeaderNamesNormalizing(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
ctx.Response.Header.Set("foo-BAR", "baz")
|
|
},
|
|
DisableHeaderNamesNormalizing: true,
|
|
}
|
|
|
|
serverStopCh := make(chan struct{})
|
|
go func() {
|
|
if err := s.Serve(ln); err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
close(serverStopCh)
|
|
}()
|
|
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
DisableHeaderNamesNormalizing: true,
|
|
}
|
|
|
|
var req Request
|
|
req.SetRequestURI("http://aaaai.com/bsdf?sddfsd")
|
|
var resp Response
|
|
for i := 0; i < 5; i++ {
|
|
if err := c.DoTimeout(&req, &resp, time.Second); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
hv := resp.Header.Peek("foo-BAR")
|
|
if string(hv) != "baz" {
|
|
t.Fatalf("unexpected header value: %q. Expecting %q", hv, "baz")
|
|
}
|
|
hv = resp.Header.Peek("Foo-Bar")
|
|
if len(hv) > 0 {
|
|
t.Fatalf("unexpected non-empty header value %q", hv)
|
|
}
|
|
}
|
|
|
|
if err := ln.Close(); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
select {
|
|
case <-serverStopCh:
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
}
|
|
|
|
func TestClientDoTimeoutDisablePathNormalizing(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
uri := ctx.URI()
|
|
uri.DisablePathNormalizing = true
|
|
ctx.Response.Header.Set("received-uri", string(uri.FullURI()))
|
|
},
|
|
}
|
|
|
|
serverStopCh := make(chan struct{})
|
|
go func() {
|
|
if err := s.Serve(ln); err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
close(serverStopCh)
|
|
}()
|
|
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
DisablePathNormalizing: true,
|
|
}
|
|
|
|
urlWithEncodedPath := "http://example.com/encoded/Y%2BY%2FY%3D/stuff"
|
|
|
|
var req Request
|
|
req.SetRequestURI(urlWithEncodedPath)
|
|
var resp Response
|
|
for i := 0; i < 5; i++ {
|
|
if err := c.DoTimeout(&req, &resp, time.Second); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
hv := resp.Header.Peek("received-uri")
|
|
if string(hv) != urlWithEncodedPath {
|
|
t.Fatalf("request uri was normalized: %q. Expecting %q", hv, urlWithEncodedPath)
|
|
}
|
|
}
|
|
|
|
if err := ln.Close(); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
select {
|
|
case <-serverStopCh:
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
}
|
|
|
|
func TestHostClientPendingRequests(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
const concurrency = 10
|
|
doneCh := make(chan struct{})
|
|
readyCh := make(chan struct{}, concurrency)
|
|
s := &Server{
|
|
Handler: func(_ *RequestCtx) {
|
|
readyCh <- struct{}{}
|
|
<-doneCh
|
|
},
|
|
}
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
serverStopCh := make(chan struct{})
|
|
go func() {
|
|
if err := s.Serve(ln); err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
close(serverStopCh)
|
|
}()
|
|
|
|
c := &HostClient{
|
|
Addr: "foobar",
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
|
|
pendingRequests := c.PendingRequests()
|
|
if pendingRequests != 0 {
|
|
t.Fatalf("non-zero pendingRequests: %d", pendingRequests)
|
|
}
|
|
|
|
resultCh := make(chan error, concurrency)
|
|
for i := 0; i < concurrency; i++ {
|
|
go func() {
|
|
req := AcquireRequest()
|
|
req.SetRequestURI("http://foobar/baz")
|
|
resp := AcquireResponse()
|
|
|
|
if err := c.DoTimeout(req, resp, 10*time.Second); err != nil {
|
|
resultCh <- fmt.Errorf("unexpected error: %w", err)
|
|
return
|
|
}
|
|
|
|
if resp.StatusCode() != StatusOK {
|
|
resultCh <- fmt.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK)
|
|
return
|
|
}
|
|
resultCh <- nil
|
|
}()
|
|
}
|
|
|
|
// wait while all the requests reach server
|
|
for i := 0; i < concurrency; i++ {
|
|
select {
|
|
case <-readyCh:
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
}
|
|
|
|
pendingRequests = c.PendingRequests()
|
|
if pendingRequests != concurrency {
|
|
t.Fatalf("unexpected pendingRequests: %d. Expecting %d", pendingRequests, concurrency)
|
|
}
|
|
|
|
// unblock request handlers on the server and wait until all the requests are finished.
|
|
close(doneCh)
|
|
for i := 0; i < concurrency; i++ {
|
|
select {
|
|
case err := <-resultCh:
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
}
|
|
|
|
pendingRequests = c.PendingRequests()
|
|
if pendingRequests != 0 {
|
|
t.Fatalf("non-zero pendingRequests: %d", pendingRequests)
|
|
}
|
|
|
|
// stop the server
|
|
if err := ln.Close(); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
select {
|
|
case <-serverStopCh:
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
}
|
|
|
|
func TestHostClientMaxConnsWithDeadline(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var (
|
|
emptyBodyCount uint8
|
|
ln = fasthttputil.NewInmemoryListener()
|
|
timeout = 200 * time.Millisecond
|
|
wg sync.WaitGroup
|
|
)
|
|
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
if len(ctx.PostBody()) == 0 {
|
|
emptyBodyCount++
|
|
}
|
|
|
|
ctx.WriteString("foo") //nolint:errcheck
|
|
},
|
|
}
|
|
serverStopCh := make(chan struct{})
|
|
go func() {
|
|
if err := s.Serve(ln); err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
close(serverStopCh)
|
|
}()
|
|
|
|
c := &HostClient{
|
|
Addr: "foobar",
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
MaxConns: 1,
|
|
}
|
|
|
|
for i := 0; i < 5; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
|
|
req := AcquireRequest()
|
|
req.SetRequestURI("http://foobar/baz")
|
|
req.Header.SetMethod(MethodPost)
|
|
req.SetBodyString("bar")
|
|
resp := AcquireResponse()
|
|
|
|
for {
|
|
if err := c.DoDeadline(req, resp, time.Now().Add(timeout)); err != nil {
|
|
if err == ErrNoFreeConns {
|
|
time.Sleep(time.Millisecond)
|
|
continue
|
|
}
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
break
|
|
}
|
|
|
|
if resp.StatusCode() != StatusOK {
|
|
t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK)
|
|
}
|
|
|
|
body := resp.Body()
|
|
if string(body) != "foo" {
|
|
t.Errorf("unexpected body %q. Expecting %q", body, "abcd")
|
|
}
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
|
|
if err := ln.Close(); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
select {
|
|
case <-serverStopCh:
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
|
|
if emptyBodyCount > 0 {
|
|
t.Fatalf("at least one request body was empty")
|
|
}
|
|
}
|
|
|
|
func TestHostClientMaxConnDuration(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
|
|
connectionCloseCount := uint32(0)
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
ctx.WriteString("abcd") //nolint:errcheck
|
|
if ctx.Request.ConnectionClose() {
|
|
atomic.AddUint32(&connectionCloseCount, 1)
|
|
}
|
|
},
|
|
}
|
|
serverStopCh := make(chan struct{})
|
|
go func() {
|
|
if err := s.Serve(ln); err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
close(serverStopCh)
|
|
}()
|
|
|
|
c := &HostClient{
|
|
Addr: "foobar",
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
MaxConnDuration: 10 * time.Millisecond,
|
|
}
|
|
|
|
for i := 0; i < 5; i++ {
|
|
statusCode, body, err := c.Get(nil, "http://aaaa.com/bbb/cc")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if statusCode != StatusOK {
|
|
t.Fatalf("unexpected status code %d. Expecting %d", statusCode, StatusOK)
|
|
}
|
|
if string(body) != "abcd" {
|
|
t.Fatalf("unexpected body %q. Expecting %q", body, "abcd")
|
|
}
|
|
time.Sleep(c.MaxConnDuration)
|
|
}
|
|
|
|
if err := ln.Close(); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
select {
|
|
case <-serverStopCh:
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
|
|
if connectionCloseCount == 0 {
|
|
t.Fatalf("expecting at least one 'Connection: close' request header")
|
|
}
|
|
}
|
|
|
|
func TestHostClientMultipleAddrs(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
ctx.Write(ctx.Host()) //nolint:errcheck
|
|
ctx.SetConnectionClose()
|
|
},
|
|
}
|
|
serverStopCh := make(chan struct{})
|
|
go func() {
|
|
if err := s.Serve(ln); err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
close(serverStopCh)
|
|
}()
|
|
|
|
dialsCount := make(map[string]int)
|
|
c := &HostClient{
|
|
Addr: "foo,bar,baz",
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
dialsCount[addr]++
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
|
|
for i := 0; i < 9; i++ {
|
|
statusCode, body, err := c.Get(nil, "http://foobar/baz/aaa?bbb=ddd")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if statusCode != StatusOK {
|
|
t.Fatalf("unexpected status code %d. Expecting %d", statusCode, StatusOK)
|
|
}
|
|
if string(body) != "foobar" {
|
|
t.Fatalf("unexpected body %q. Expecting %q", body, "foobar")
|
|
}
|
|
}
|
|
|
|
if err := ln.Close(); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
select {
|
|
case <-serverStopCh:
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
|
|
if len(dialsCount) != 3 {
|
|
t.Fatalf("unexpected dialsCount size %d. Expecting 3", len(dialsCount))
|
|
}
|
|
for _, k := range []string{"foo", "bar", "baz"} {
|
|
if dialsCount[k] != 3 {
|
|
t.Fatalf("unexpected dialsCount for %q. Expecting 3", k)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestClientFollowRedirects(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
switch string(ctx.Path()) {
|
|
case "/foo":
|
|
u := ctx.URI()
|
|
u.Update("/xy?z=wer")
|
|
ctx.Redirect(u.String(), StatusFound)
|
|
case "/xy":
|
|
u := ctx.URI()
|
|
u.Update("/bar")
|
|
ctx.Redirect(u.String(), StatusFound)
|
|
case "/abc/*/123":
|
|
u := ctx.URI()
|
|
u.Update("/xyz/*/456")
|
|
ctx.Redirect(u.String(), StatusFound)
|
|
default:
|
|
ctx.Success("text/plain", ctx.Path())
|
|
}
|
|
},
|
|
}
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
|
|
serverStopCh := make(chan struct{})
|
|
go func() {
|
|
if err := s.Serve(ln); err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
close(serverStopCh)
|
|
}()
|
|
|
|
c := &HostClient{
|
|
Addr: "xxx",
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
}
|
|
|
|
for i := 0; i < 10; i++ {
|
|
statusCode, body, err := c.GetTimeout(nil, "http://xxx/foo", time.Second)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if statusCode != StatusOK {
|
|
t.Fatalf("unexpected status code: %d", statusCode)
|
|
}
|
|
if string(body) != "/bar" {
|
|
t.Fatalf("unexpected response %q. Expecting %q", body, "/bar")
|
|
}
|
|
}
|
|
|
|
for i := 0; i < 10; i++ {
|
|
statusCode, body, err := c.Get(nil, "http://xxx/aaab/sss")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if statusCode != StatusOK {
|
|
t.Fatalf("unexpected status code: %d", statusCode)
|
|
}
|
|
if string(body) != "/aaab/sss" {
|
|
t.Fatalf("unexpected response %q. Expecting %q", body, "/aaab/sss")
|
|
}
|
|
}
|
|
|
|
for i := 0; i < 10; i++ {
|
|
req := AcquireRequest()
|
|
resp := AcquireResponse()
|
|
|
|
req.SetRequestURI("http://xxx/foo")
|
|
|
|
err := c.DoRedirects(req, resp, 16)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
|
|
if statusCode := resp.StatusCode(); statusCode != StatusOK {
|
|
t.Fatalf("unexpected status code: %d", statusCode)
|
|
}
|
|
|
|
if body := string(resp.Body()); body != "/bar" {
|
|
t.Fatalf("unexpected response %q. Expecting %q", body, "/bar")
|
|
}
|
|
|
|
ReleaseRequest(req)
|
|
ReleaseResponse(resp)
|
|
}
|
|
|
|
for i := 0; i < 10; i++ {
|
|
req := AcquireRequest()
|
|
resp := AcquireResponse()
|
|
|
|
req.SetRequestURI("http://xxx/foo")
|
|
|
|
req.SetTimeout(time.Second)
|
|
err := c.DoRedirects(req, resp, 16)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
|
|
if statusCode := resp.StatusCode(); statusCode != StatusOK {
|
|
t.Fatalf("unexpected status code: %d", statusCode)
|
|
}
|
|
|
|
if body := string(resp.Body()); body != "/bar" {
|
|
t.Fatalf("unexpected response %q. Expecting %q", body, "/bar")
|
|
}
|
|
|
|
ReleaseRequest(req)
|
|
ReleaseResponse(resp)
|
|
}
|
|
|
|
for i := 0; i < 10; i++ {
|
|
req := AcquireRequest()
|
|
resp := AcquireResponse()
|
|
|
|
req.SetRequestURI("http://xxx/foo")
|
|
|
|
testConn, _ := net.Dial("tcp", ln.Addr().String())
|
|
timeoutConn := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return &readTimeoutConn{Conn: testConn, t: time.Second}, nil
|
|
},
|
|
}
|
|
|
|
req.SetTimeout(time.Millisecond)
|
|
err := timeoutConn.DoRedirects(req, resp, 16)
|
|
if err == nil {
|
|
t.Errorf("expecting error")
|
|
}
|
|
if err != ErrTimeout {
|
|
t.Errorf("unexpected error: %v. Expecting %v", err, ErrTimeout)
|
|
}
|
|
|
|
ReleaseRequest(req)
|
|
ReleaseResponse(resp)
|
|
}
|
|
|
|
for i := 0; i < 10; i++ {
|
|
req := AcquireRequest()
|
|
resp := AcquireResponse()
|
|
|
|
req.SetRequestURI("http://xxx/abc/*/123")
|
|
req.URI().DisablePathNormalizing = true
|
|
req.DisableRedirectPathNormalizing = true
|
|
|
|
err := c.DoRedirects(req, resp, 16)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
|
|
if statusCode := resp.StatusCode(); statusCode != StatusOK {
|
|
t.Fatalf("unexpected status code: %d", statusCode)
|
|
}
|
|
|
|
if body := string(resp.Body()); body != "/xyz/*/456" {
|
|
t.Fatalf("unexpected response %q. Expecting %q", body, "/xyz/*/456")
|
|
}
|
|
|
|
ReleaseRequest(req)
|
|
ReleaseResponse(resp)
|
|
}
|
|
|
|
req := AcquireRequest()
|
|
resp := AcquireResponse()
|
|
|
|
req.SetRequestURI("http://xxx/foo")
|
|
|
|
err := c.DoRedirects(req, resp, 0)
|
|
if have, want := err, ErrTooManyRedirects; have != want {
|
|
t.Fatalf("want error: %v, have %v", want, have)
|
|
}
|
|
|
|
ReleaseRequest(req)
|
|
ReleaseResponse(resp)
|
|
}
|
|
|
|
func TestClientGetTimeoutSuccess(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
s := startEchoServer(t, "tcp", "127.0.0.1:")
|
|
defer s.Stop()
|
|
|
|
testClientGetTimeoutSuccess(t, &defaultClient, "http://"+s.Addr(), 100)
|
|
}
|
|
|
|
func TestClientGetTimeoutSuccessConcurrent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
s := startEchoServer(t, "tcp", "127.0.0.1:")
|
|
defer s.Stop()
|
|
|
|
var wg sync.WaitGroup
|
|
for i := 0; i < 10; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
testClientGetTimeoutSuccess(t, &defaultClient, "http://"+s.Addr(), 100)
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
}
|
|
|
|
func TestClientDoTimeoutSuccess(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
s := startEchoServer(t, "tcp", "127.0.0.1:")
|
|
defer s.Stop()
|
|
|
|
testClientDoTimeoutSuccess(t, &defaultClient, "http://"+s.Addr(), 100)
|
|
testClientRequestSetTimeoutSuccess(t, &defaultClient, "http://"+s.Addr(), 100)
|
|
}
|
|
|
|
func TestClientDoTimeoutSuccessConcurrent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
s := startEchoServer(t, "tcp", "127.0.0.1:")
|
|
defer s.Stop()
|
|
|
|
var wg sync.WaitGroup
|
|
for i := 0; i < 10; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
testClientDoTimeoutSuccess(t, &defaultClient, "http://"+s.Addr(), 100)
|
|
testClientRequestSetTimeoutSuccess(t, &defaultClient, "http://"+s.Addr(), 100)
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
}
|
|
|
|
func TestClientGetTimeoutError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
s := startEchoServer(t, "tcp", "127.0.0.1:")
|
|
defer s.Stop()
|
|
|
|
testConn, _ := net.Dial("tcp", s.ln.Addr().String())
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return &readTimeoutConn{Conn: testConn, t: time.Second}, nil
|
|
},
|
|
}
|
|
|
|
testClientGetTimeoutError(t, c, 100)
|
|
}
|
|
|
|
func TestClientGetTimeoutErrorConcurrent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
s := startEchoServer(t, "tcp", "127.0.0.1:")
|
|
defer s.Stop()
|
|
|
|
testConn, _ := net.Dial("tcp", s.ln.Addr().String())
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return &readTimeoutConn{Conn: testConn, t: time.Second}, nil
|
|
},
|
|
MaxConnsPerHost: 1000,
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
for i := 0; i < 10; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
testClientGetTimeoutError(t, c, 100)
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
}
|
|
|
|
func TestClientDoTimeoutError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
s := startEchoServer(t, "tcp", "127.0.0.1:")
|
|
defer s.Stop()
|
|
|
|
testConn, _ := net.Dial("tcp", s.ln.Addr().String())
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return &readTimeoutConn{Conn: testConn, t: time.Second}, nil
|
|
},
|
|
}
|
|
|
|
testClientDoTimeoutError(t, c, 100)
|
|
testClientRequestSetTimeoutError(t, c, 100)
|
|
}
|
|
|
|
func TestClientDoTimeoutErrorConcurrent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
s := startEchoServer(t, "tcp", "127.0.0.1:")
|
|
defer s.Stop()
|
|
|
|
testConn, _ := net.Dial("tcp", s.ln.Addr().String())
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return &readTimeoutConn{Conn: testConn, t: time.Second}, nil
|
|
},
|
|
MaxConnsPerHost: 1000,
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
for i := 0; i < 10; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
testClientDoTimeoutError(t, c, 100)
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
}
|
|
|
|
func testClientDoTimeoutError(t *testing.T, c *Client, n int) {
|
|
var req Request
|
|
var resp Response
|
|
req.SetRequestURI("http://foobar.com/baz")
|
|
for i := 0; i < n; i++ {
|
|
err := c.DoTimeout(&req, &resp, time.Millisecond)
|
|
if err == nil {
|
|
t.Errorf("expecting error")
|
|
}
|
|
if err != ErrTimeout {
|
|
t.Errorf("unexpected error: %v. Expecting %v", err, ErrTimeout)
|
|
}
|
|
}
|
|
}
|
|
|
|
func testClientGetTimeoutError(t *testing.T, c *Client, n int) {
|
|
buf := make([]byte, 10)
|
|
for i := 0; i < n; i++ {
|
|
statusCode, body, err := c.GetTimeout(buf, "http://foobar.com/baz", time.Millisecond)
|
|
if err == nil {
|
|
t.Errorf("expecting error")
|
|
}
|
|
if err != ErrTimeout {
|
|
t.Errorf("unexpected error: %v. Expecting %v", err, ErrTimeout)
|
|
}
|
|
if statusCode != 0 {
|
|
t.Errorf("unexpected statusCode=%d. Expecting %d", statusCode, 0)
|
|
}
|
|
if body == nil {
|
|
t.Errorf("body must be non-nil")
|
|
}
|
|
}
|
|
}
|
|
|
|
func testClientRequestSetTimeoutError(t *testing.T, c *Client, n int) {
|
|
var req Request
|
|
var resp Response
|
|
req.SetRequestURI("http://foobar.com/baz")
|
|
for i := 0; i < n; i++ {
|
|
req.SetTimeout(time.Millisecond)
|
|
err := c.Do(&req, &resp)
|
|
if err == nil {
|
|
t.Errorf("expecting error")
|
|
}
|
|
if err != ErrTimeout {
|
|
t.Errorf("unexpected error: %v. Expecting %v", err, ErrTimeout)
|
|
}
|
|
}
|
|
}
|
|
|
|
// readTimeoutConn is a fake net.Conn that never transfers data. Read and Write
// block until the deadline most recently set by SetReadDeadline or
// SetWriteDeadline passes, then fail with os.ErrDeadlineExceeded.
//
// Read must be preceded by SetReadDeadline, and Write by SetWriteDeadline.
// Otherwise they wait on a nil channel and block forever.
// NOTE(review): field t is stored by callers but not read by any method here.
type readTimeoutConn struct {
	net.Conn
	t  time.Duration
	wc chan struct{}
	rc chan struct{}
}

// Read waits for the pending read deadline to fire and reports it as exceeded.
func (r *readTimeoutConn) Read(p []byte) (int, error) {
	<-r.rc
	return 0, os.ErrDeadlineExceeded
}

// Write waits for the pending write deadline to fire and reports it as exceeded.
func (r *readTimeoutConn) Write(p []byte) (int, error) {
	<-r.wc
	return 0, os.ErrDeadlineExceeded
}

// Close is a no-op; it does not close the embedded net.Conn.
func (r *readTimeoutConn) Close() error {
	return nil
}

func (r *readTimeoutConn) LocalAddr() net.Addr {
	return nil
}

func (r *readTimeoutConn) RemoteAddr() net.Addr {
	return nil
}

// SetReadDeadline arms a fresh read signal that fires once d has passed.
func (r *readTimeoutConn) SetReadDeadline(d time.Time) error {
	ch := make(chan struct{}, 1)
	r.rc = ch
	go func() {
		time.Sleep(time.Until(d))
		// Signal the channel created by this call rather than re-reading r.rc:
		// a later SetReadDeadline may have replaced it. Re-reading would race
		// with that write and could fire the newer deadline early.
		ch <- struct{}{}
	}()
	return nil
}

// SetWriteDeadline arms a fresh write signal that fires once d has passed.
func (r *readTimeoutConn) SetWriteDeadline(d time.Time) error {
	ch := make(chan struct{}, 1)
	r.wc = ch
	go func() {
		time.Sleep(time.Until(d))
		// See SetReadDeadline: signal this call's own channel only.
		ch <- struct{}{}
	}()
	return nil
}
|
|
|
|
func TestClientNonIdempotentRetry(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dialsCount := 0
|
|
c := &Client{
|
|
Dial: func(_ string) (net.Conn, error) {
|
|
dialsCount++
|
|
switch dialsCount {
|
|
case 1, 2:
|
|
return &readErrorConn{}, nil
|
|
case 3:
|
|
return &singleReadConn{
|
|
s: "HTTP/1.1 345 OK\r\nContent-Type: foobar\r\nContent-Length: 7\r\n\r\n0123456",
|
|
}, nil
|
|
default:
|
|
t.Fatalf("unexpected number of dials: %d", dialsCount)
|
|
}
|
|
panic("unreachable")
|
|
},
|
|
}
|
|
|
|
// This POST must succeed, since the readErrorConn closes
|
|
// the connection before sending any response.
|
|
// So the client must retry non-idempotent request.
|
|
dialsCount = 0
|
|
statusCode, body, err := c.Post(nil, "http://foobar/a/b", nil)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if statusCode != 345 {
|
|
t.Fatalf("unexpected status code: %d. Expecting 345", statusCode)
|
|
}
|
|
if string(body) != "0123456" {
|
|
t.Fatalf("unexpected body: %q. Expecting %q", body, "0123456")
|
|
}
|
|
|
|
// Verify that idempotent GET succeeds.
|
|
dialsCount = 0
|
|
statusCode, body, err = c.Get(nil, "http://foobar/a/b")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if statusCode != 345 {
|
|
t.Fatalf("unexpected status code: %d. Expecting 345", statusCode)
|
|
}
|
|
if string(body) != "0123456" {
|
|
t.Fatalf("unexpected body: %q. Expecting %q", body, "0123456")
|
|
}
|
|
}
|
|
|
|
func TestClientNonIdempotentRetry_BodyStream(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dialsCount := 0
|
|
c := &Client{
|
|
Dial: func(_ string) (net.Conn, error) {
|
|
dialsCount++
|
|
switch dialsCount {
|
|
case 1, 2:
|
|
return &readErrorConn{}, nil
|
|
case 3:
|
|
return &singleEchoConn{
|
|
b: []byte("HTTP/1.1 345 OK\r\nContent-Type: foobar\r\n\r\n"),
|
|
}, nil
|
|
default:
|
|
t.Fatalf("unexpected number of dials: %d", dialsCount)
|
|
}
|
|
panic("unreachable")
|
|
},
|
|
}
|
|
|
|
dialsCount = 0
|
|
|
|
req := Request{}
|
|
res := Response{}
|
|
|
|
req.SetRequestURI("http://foobar/a/b")
|
|
req.Header.SetMethod("POST")
|
|
body := bytes.NewBufferString("test")
|
|
req.SetBodyStream(body, body.Len())
|
|
|
|
err := c.Do(&req, &res)
|
|
if err == nil {
|
|
t.Fatal("expected error from being unable to retry a bodyStream")
|
|
}
|
|
}
|
|
|
|
func TestClientIdempotentRequest(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dialsCount := 0
|
|
c := &Client{
|
|
Dial: func(_ string) (net.Conn, error) {
|
|
dialsCount++
|
|
switch dialsCount {
|
|
case 1:
|
|
return &singleReadConn{
|
|
s: "invalid response",
|
|
}, nil
|
|
case 2:
|
|
return &writeErrorConn{}, nil
|
|
case 3:
|
|
return &readErrorConn{}, nil
|
|
case 4:
|
|
return &singleReadConn{
|
|
s: "HTTP/1.1 345 OK\r\nContent-Type: foobar\r\nContent-Length: 7\r\n\r\n0123456",
|
|
}, nil
|
|
default:
|
|
t.Fatalf("unexpected number of dials: %d", dialsCount)
|
|
}
|
|
panic("unreachable")
|
|
},
|
|
}
|
|
|
|
// idempotent GET must succeed.
|
|
statusCode, body, err := c.Get(nil, "http://foobar/a/b")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if statusCode != 345 {
|
|
t.Fatalf("unexpected status code: %d. Expecting 345", statusCode)
|
|
}
|
|
if string(body) != "0123456" {
|
|
t.Fatalf("unexpected body: %q. Expecting %q", body, "0123456")
|
|
}
|
|
|
|
var args Args
|
|
|
|
// non-idempotent POST must fail on incorrect singleReadConn
|
|
dialsCount = 0
|
|
_, _, err = c.Post(nil, "http://foobar/a/b", &args)
|
|
if err == nil {
|
|
t.Fatalf("expecting error")
|
|
}
|
|
|
|
// non-idempotent POST must fail on incorrect singleReadConn
|
|
dialsCount = 0
|
|
_, _, err = c.Post(nil, "http://foobar/a/b", nil)
|
|
if err == nil {
|
|
t.Fatalf("expecting error")
|
|
}
|
|
}
|
|
|
|
func TestClientRetryRequestWithCustomDecider(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dialsCount := 0
|
|
c := &Client{
|
|
Dial: func(_ string) (net.Conn, error) {
|
|
dialsCount++
|
|
switch dialsCount {
|
|
case 1:
|
|
return &singleReadConn{
|
|
s: "invalid response",
|
|
}, nil
|
|
case 2:
|
|
return &writeErrorConn{}, nil
|
|
case 3:
|
|
return &readErrorConn{}, nil
|
|
case 4:
|
|
return &singleReadConn{
|
|
s: "HTTP/1.1 345 OK\r\nContent-Type: foobar\r\nContent-Length: 7\r\n\r\n0123456",
|
|
}, nil
|
|
default:
|
|
t.Fatalf("unexpected number of dials: %d", dialsCount)
|
|
}
|
|
panic("unreachable")
|
|
},
|
|
RetryIf: func(req *Request) bool {
|
|
return req.URI().String() == "http://foobar/a/b"
|
|
},
|
|
}
|
|
|
|
var args Args
|
|
|
|
// Post must succeed for http://foobar/a/b uri.
|
|
statusCode, body, err := c.Post(nil, "http://foobar/a/b", &args)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if statusCode != 345 {
|
|
t.Fatalf("unexpected status code: %d. Expecting 345", statusCode)
|
|
}
|
|
if string(body) != "0123456" {
|
|
t.Fatalf("unexpected body: %q. Expecting %q", body, "0123456")
|
|
}
|
|
|
|
// POST must fail for http://foobar/a/b/c uri.
|
|
dialsCount = 0
|
|
_, _, err = c.Post(nil, "http://foobar/a/b/c", &args)
|
|
if err == nil {
|
|
t.Fatalf("expecting error")
|
|
}
|
|
}
|
|
|
|
type TransportDemo struct {
|
|
br *bufio.Reader
|
|
bw *bufio.Writer
|
|
}
|
|
|
|
func (t TransportDemo) RoundTrip(hc *HostClient, req *Request, res *Response) (retry bool, err error) {
|
|
if err = req.Write(t.bw); err != nil {
|
|
return false, err
|
|
}
|
|
if err = t.bw.Flush(); err != nil {
|
|
return false, err
|
|
}
|
|
err = res.Read(t.br)
|
|
return err != nil, err
|
|
}
|
|
|
|
func TestHostClientTransport(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
ctx.WriteString("abcd") //nolint:errcheck
|
|
},
|
|
}
|
|
serverStopCh := make(chan struct{})
|
|
go func() {
|
|
if err := s.Serve(ln); err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
close(serverStopCh)
|
|
}()
|
|
|
|
c := &HostClient{
|
|
Addr: "foobar",
|
|
Transport: func() RoundTripper {
|
|
c, _ := ln.Dial()
|
|
|
|
br := bufio.NewReader(c)
|
|
bw := bufio.NewWriter(c)
|
|
|
|
return TransportDemo{br: br, bw: bw}
|
|
}(),
|
|
}
|
|
|
|
for i := 0; i < 5; i++ {
|
|
statusCode, body, err := c.Get(nil, "http://aaaa.com/bbb/cc")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if statusCode != StatusOK {
|
|
t.Fatalf("unexpected status code %d. Expecting %d", statusCode, StatusOK)
|
|
}
|
|
if string(body) != "abcd" {
|
|
t.Fatalf("unexpected body %q. Expecting %q", body, "abcd")
|
|
}
|
|
}
|
|
|
|
if err := ln.Close(); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
|
|
select {
|
|
case <-serverStopCh:
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
}
|
|
|
|
// writeErrorConn is a fake net.Conn whose every Write fails after claiming one
// byte was written. Read is not overridden: with a zero-value embedded
// net.Conn (as in &writeErrorConn{}), calling it would panic.
type writeErrorConn struct {
	net.Conn
}

// Write always reports a single written byte together with an error.
func (w *writeErrorConn) Write(p []byte) (int, error) {
	return 1, fmt.Errorf("error")
}

// Close is a no-op.
func (w *writeErrorConn) Close() error {
	return nil
}

func (w *writeErrorConn) LocalAddr() net.Addr {
	return nil
}

func (w *writeErrorConn) RemoteAddr() net.Addr {
	return nil
}

// SetReadDeadline ignores the deadline.
func (w *writeErrorConn) SetReadDeadline(_ time.Time) error {
	return nil
}

// SetWriteDeadline ignores the deadline.
func (w *writeErrorConn) SetWriteDeadline(_ time.Time) error {
	return nil
}
|
|
|
|
// readErrorConn is a fake net.Conn that accepts every write but fails every
// read, simulating a peer that drops the connection before responding.
type readErrorConn struct {
	net.Conn
}

// Read never returns data; it always fails.
func (r *readErrorConn) Read(_ []byte) (int, error) {
	return 0, fmt.Errorf("error")
}

// Write discards buf and reports it as fully written.
func (r *readErrorConn) Write(buf []byte) (int, error) {
	n := len(buf)
	return n, nil
}

// Close is a no-op.
func (r *readErrorConn) Close() error { return nil }

// LocalAddr has no address to report.
func (r *readErrorConn) LocalAddr() net.Addr { return nil }

// RemoteAddr has no address to report.
func (r *readErrorConn) RemoteAddr() net.Addr { return nil }

// SetReadDeadline ignores the deadline.
func (r *readErrorConn) SetReadDeadline(time.Time) error { return nil }

// SetWriteDeadline ignores the deadline.
func (r *readErrorConn) SetWriteDeadline(time.Time) error { return nil }
|
|
|
|
// singleReadConn is a fake net.Conn that serves the fixed string s to readers
// and then reports io.EOF. Writes are accepted and discarded. n tracks how
// many bytes of s have been read so far.
type singleReadConn struct {
	net.Conn
	s string
	n int
}

// Read copies the next unread part of s into p, or returns io.EOF once all of
// s has been consumed.
func (r *singleReadConn) Read(p []byte) (int, error) {
	if len(r.s) == r.n {
		return 0, io.EOF
	}
	// copy accepts a string source directly; no []byte conversion (and
	// allocation) is needed.
	n := copy(p, r.s[r.n:])
	r.n += n
	return n, nil
}

// Write discards p and reports it as fully written.
func (r *singleReadConn) Write(p []byte) (int, error) {
	return len(p), nil
}

// Close is a no-op.
func (r *singleReadConn) Close() error {
	return nil
}

func (r *singleReadConn) LocalAddr() net.Addr {
	return nil
}

func (r *singleReadConn) RemoteAddr() net.Addr {
	return nil
}

// SetReadDeadline ignores the deadline.
func (r *singleReadConn) SetReadDeadline(_ time.Time) error {
	return nil
}

// SetWriteDeadline ignores the deadline.
func (r *singleReadConn) SetWriteDeadline(_ time.Time) error {
	return nil
}
|
|
|
|
// singleEchoConn is a fake net.Conn backed by one growing buffer b. Writes
// append to b, and reads drain b from offset n, returning io.EOF once the
// reader catches up. With b pre-seeded, the reader sees that prefix followed
// by everything written.
type singleEchoConn struct {
	net.Conn
	b []byte
	n int
}

// Read copies the not-yet-read tail of b into p, or returns io.EOF if none.
func (r *singleEchoConn) Read(p []byte) (int, error) {
	rest := r.b[r.n:]
	if len(rest) == 0 {
		return 0, io.EOF
	}
	copied := copy(p, rest)
	r.n += copied
	return copied, nil
}

// Write appends p to the buffer so a later Read returns it.
func (r *singleEchoConn) Write(p []byte) (int, error) {
	r.b = append(r.b, p...)
	return len(p), nil
}

// Close is a no-op.
func (r *singleEchoConn) Close() error { return nil }

// LocalAddr has no address to report.
func (r *singleEchoConn) LocalAddr() net.Addr { return nil }

// RemoteAddr has no address to report.
func (r *singleEchoConn) RemoteAddr() net.Addr { return nil }

// SetReadDeadline ignores the deadline.
func (r *singleEchoConn) SetReadDeadline(time.Time) error { return nil }

// SetWriteDeadline ignores the deadline.
func (r *singleEchoConn) SetWriteDeadline(time.Time) error { return nil }
|
|
|
|
func TestSingleEchoConn(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
c := &Client{
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return &singleEchoConn{
|
|
b: []byte("HTTP/1.1 345 OK\r\nContent-Type: foobar\r\n\r\n"),
|
|
}, nil
|
|
},
|
|
}
|
|
|
|
req := Request{}
|
|
res := Response{}
|
|
|
|
req.SetRequestURI("http://foobar/a/b")
|
|
req.Header.SetMethod("POST")
|
|
req.Header.Set("Content-Type", "text/plain")
|
|
body := bytes.NewBufferString("test")
|
|
req.SetBodyStream(body, body.Len())
|
|
|
|
err := c.Do(&req, &res)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if res.StatusCode() != 345 {
|
|
t.Fatalf("unexpected status code: %d. Expecting 345", res.StatusCode())
|
|
}
|
|
expected := "POST /a/b HTTP/1.1\r\nUser-Agent: fasthttp\r\nHost: foobar\r\nContent-Type: text/plain\r\nContent-Length: 4\r\n\r\ntest"
|
|
if string(res.Body()) != expected {
|
|
t.Fatalf("unexpected body: %q. Expecting %q", res.Body(), expected)
|
|
}
|
|
}
|
|
|
|
func TestClientHTTPSInvalidServerName(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
sHTTPS := startEchoServerTLS(t, "tcp", "127.0.0.1:")
|
|
defer sHTTPS.Stop()
|
|
|
|
var c Client
|
|
|
|
for i := 0; i < 10; i++ {
|
|
_, _, err := c.GetTimeout(nil, "https://"+sHTTPS.Addr(), time.Second)
|
|
if err == nil {
|
|
t.Fatalf("expecting TLS error")
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestClientHTTPSConcurrent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
sHTTP := startEchoServer(t, "tcp", "127.0.0.1:")
|
|
defer sHTTP.Stop()
|
|
|
|
sHTTPS := startEchoServerTLS(t, "tcp", "127.0.0.1:")
|
|
defer sHTTPS.Stop()
|
|
|
|
c := &Client{
|
|
TLSConfig: &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
},
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
for i := 0; i < 4; i++ {
|
|
wg.Add(1)
|
|
addr := "http://" + sHTTP.Addr()
|
|
if i&1 != 0 {
|
|
addr = "https://" + sHTTPS.Addr()
|
|
}
|
|
go func() {
|
|
defer wg.Done()
|
|
testClientGet(t, c, addr, 20)
|
|
testClientPost(t, c, addr, 10)
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
}
|
|
|
|
func TestClientManyServers(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var addrs []string
|
|
for i := 0; i < 10; i++ {
|
|
s := startEchoServer(t, "tcp", "127.0.0.1:")
|
|
defer s.Stop()
|
|
addrs = append(addrs, s.Addr())
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
for i := 0; i < 4; i++ {
|
|
wg.Add(1)
|
|
addr := "http://" + addrs[i]
|
|
go func() {
|
|
defer wg.Done()
|
|
testClientGet(t, &defaultClient, addr, 20)
|
|
testClientPost(t, &defaultClient, addr, 10)
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
}
|
|
|
|
func TestClientGet(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
s := startEchoServer(t, "tcp", "127.0.0.1:")
|
|
defer s.Stop()
|
|
|
|
testClientGet(t, &defaultClient, "http://"+s.Addr(), 100)
|
|
}
|
|
|
|
func TestClientPost(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
s := startEchoServer(t, "tcp", "127.0.0.1:")
|
|
defer s.Stop()
|
|
|
|
testClientPost(t, &defaultClient, "http://"+s.Addr(), 100)
|
|
}
|
|
|
|
func TestClientConcurrent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
s := startEchoServer(t, "tcp", "127.0.0.1:")
|
|
defer s.Stop()
|
|
|
|
addr := "http://" + s.Addr()
|
|
var wg sync.WaitGroup
|
|
for i := 0; i < 10; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
testClientGet(t, &defaultClient, addr, 30)
|
|
testClientPost(t, &defaultClient, addr, 10)
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
}
|
|
|
|
// skipIfNotUnix skips tb on platforms without unix-domain socket support.
//
// iOS is identified by GOOS "ios" (Go 1.16+). Before that, iOS builds
// reported darwin/arm or darwin/arm64. darwin/arm64 now means macOS on Apple
// Silicon, which does support unix sockets and must not be skipped.
func skipIfNotUnix(tb testing.TB) {
	switch runtime.GOOS {
	case "android", "nacl", "plan9", "windows":
		tb.Skipf("%s does not support unix sockets", runtime.GOOS)
	case "ios":
		tb.Skip("iOS does not support unix, unixgram")
	}
}
|
|
|
|
func TestHostClientGet(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
skipIfNotUnix(t)
|
|
addr := "TestHostClientGet.unix"
|
|
s := startEchoServer(t, "unix", addr)
|
|
defer s.Stop()
|
|
c := createEchoClient(t, "unix", addr)
|
|
|
|
testHostClientGet(t, c, 100)
|
|
}
|
|
|
|
func TestHostClientPost(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
skipIfNotUnix(t)
|
|
addr := "./TestHostClientPost.unix"
|
|
s := startEchoServer(t, "unix", addr)
|
|
defer s.Stop()
|
|
c := createEchoClient(t, "unix", addr)
|
|
|
|
testHostClientPost(t, c, 100)
|
|
}
|
|
|
|
func TestHostClientConcurrent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
skipIfNotUnix(t)
|
|
addr := "./TestHostClientConcurrent.unix"
|
|
s := startEchoServer(t, "unix", addr)
|
|
defer s.Stop()
|
|
c := createEchoClient(t, "unix", addr)
|
|
|
|
var wg sync.WaitGroup
|
|
for i := 0; i < 10; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
testHostClientGet(t, c, 30)
|
|
testHostClientPost(t, c, 10)
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
}
|
|
|
|
func testClientGet(t *testing.T, c clientGetter, addr string, n int) {
|
|
var buf []byte
|
|
for i := 0; i < n; i++ {
|
|
uri := fmt.Sprintf("%s/foo/%d?bar=baz", addr, i)
|
|
statusCode, body, err := c.Get(buf, uri)
|
|
buf = body
|
|
if err != nil {
|
|
t.Errorf("unexpected error when doing http request: %v", err)
|
|
}
|
|
if statusCode != StatusOK {
|
|
t.Errorf("unexpected status code: %d. Expecting %d", statusCode, StatusOK)
|
|
}
|
|
resultURI := string(body)
|
|
if resultURI != uri {
|
|
t.Errorf("unexpected uri %q. Expecting %q", resultURI, uri)
|
|
}
|
|
}
|
|
}
|
|
|
|
func testClientDoTimeoutSuccess(t *testing.T, c *Client, addr string, n int) {
|
|
var req Request
|
|
var resp Response
|
|
|
|
for i := 0; i < n; i++ {
|
|
uri := fmt.Sprintf("%s/foo/%d?bar=baz", addr, i)
|
|
req.SetRequestURI(uri)
|
|
if err := c.DoTimeout(&req, &resp, time.Second); err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
if resp.StatusCode() != StatusOK {
|
|
t.Errorf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
|
|
}
|
|
resultURI := string(resp.Body())
|
|
if strings.HasPrefix(uri, "https") {
|
|
resultURI = uri[:5] + resultURI[4:]
|
|
}
|
|
if resultURI != uri {
|
|
t.Errorf("unexpected uri %q. Expecting %q", resultURI, uri)
|
|
}
|
|
}
|
|
}
|
|
|
|
func testClientRequestSetTimeoutSuccess(t *testing.T, c *Client, addr string, n int) {
|
|
var req Request
|
|
var resp Response
|
|
|
|
for i := 0; i < n; i++ {
|
|
uri := fmt.Sprintf("%s/foo/%d?bar=baz", addr, i)
|
|
req.SetRequestURI(uri)
|
|
req.SetTimeout(time.Second)
|
|
if err := c.Do(&req, &resp); err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
if resp.StatusCode() != StatusOK {
|
|
t.Errorf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
|
|
}
|
|
resultURI := string(resp.Body())
|
|
if strings.HasPrefix(uri, "https") {
|
|
resultURI = uri[:5] + resultURI[4:]
|
|
}
|
|
if resultURI != uri {
|
|
t.Errorf("unexpected uri %q. Expecting %q", resultURI, uri)
|
|
}
|
|
}
|
|
}
|
|
|
|
func testClientGetTimeoutSuccess(t *testing.T, c *Client, addr string, n int) {
|
|
var buf []byte
|
|
for i := 0; i < n; i++ {
|
|
uri := fmt.Sprintf("%s/foo/%d?bar=baz", addr, i)
|
|
statusCode, body, err := c.GetTimeout(buf, uri, time.Second)
|
|
buf = body
|
|
if err != nil {
|
|
t.Errorf("unexpected error when doing http request: %v", err)
|
|
}
|
|
if statusCode != StatusOK {
|
|
t.Errorf("unexpected status code: %d. Expecting %d", statusCode, StatusOK)
|
|
}
|
|
resultURI := string(body)
|
|
if strings.HasPrefix(uri, "https") {
|
|
resultURI = uri[:5] + resultURI[4:]
|
|
}
|
|
if resultURI != uri {
|
|
t.Errorf("unexpected uri %q. Expecting %q", resultURI, uri)
|
|
}
|
|
}
|
|
}
|
|
|
|
func testClientPost(t *testing.T, c clientPoster, addr string, n int) {
|
|
var buf []byte
|
|
var args Args
|
|
for i := 0; i < n; i++ {
|
|
uri := fmt.Sprintf("%s/foo/%d?bar=baz", addr, i)
|
|
args.Set("xx", fmt.Sprintf("yy%d", i))
|
|
args.Set("zzz", fmt.Sprintf("qwe_%d", i))
|
|
argsS := args.String()
|
|
statusCode, body, err := c.Post(buf, uri, &args)
|
|
buf = body
|
|
if err != nil {
|
|
t.Errorf("unexpected error when doing http request: %v", err)
|
|
}
|
|
if statusCode != StatusOK {
|
|
t.Errorf("unexpected status code: %d. Expecting %d", statusCode, StatusOK)
|
|
}
|
|
s := string(body)
|
|
if s != argsS {
|
|
t.Errorf("unexpected response %q. Expecting %q", s, argsS)
|
|
}
|
|
}
|
|
}
|
|
|
|
func testHostClientGet(t *testing.T, c *HostClient, n int) {
|
|
testClientGet(t, c, "http://google.com", n)
|
|
}
|
|
|
|
func testHostClientPost(t *testing.T, c *HostClient, n int) {
|
|
testClientPost(t, c, "http://post-host.com", n)
|
|
}
|
|
|
|
type clientPoster interface {
|
|
Post(dst []byte, uri string, postArgs *Args) (int, []byte, error)
|
|
}
|
|
|
|
// clientGetter is the slice of the client API that testClientGet needs. It
// lets that helper drive both *Client and *HostClient.
type clientGetter interface {
	// Get fetches uri and returns the status code and body.
	// dst may be reused as the buffer for the returned body.
	Get(dst []byte, uri string) (statusCode int, body []byte, err error)
}
|
|
|
|
func createEchoClient(t *testing.T, network, addr string) *HostClient {
|
|
return &HostClient{
|
|
Addr: addr,
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return net.Dial(network, addr)
|
|
},
|
|
}
|
|
}
|
|
|
|
type testEchoServer struct {
|
|
s *Server
|
|
ln net.Listener
|
|
ch chan struct{}
|
|
t *testing.T
|
|
}
|
|
|
|
func (s *testEchoServer) Stop() {
|
|
s.ln.Close()
|
|
select {
|
|
case <-s.ch:
|
|
case <-time.After(time.Second):
|
|
s.t.Fatalf("timeout when waiting for server close")
|
|
}
|
|
}
|
|
|
|
func (s *testEchoServer) Addr() string {
|
|
return s.ln.Addr().String()
|
|
}
|
|
|
|
func startEchoServerTLS(t *testing.T, network, addr string) *testEchoServer {
|
|
return startEchoServerExt(t, network, addr, true)
|
|
}
|
|
|
|
func startEchoServer(t *testing.T, network, addr string) *testEchoServer {
|
|
return startEchoServerExt(t, network, addr, false)
|
|
}
|
|
|
|
func startEchoServerExt(t *testing.T, network, addr string, isTLS bool) *testEchoServer {
|
|
if network == "unix" {
|
|
os.Remove(addr)
|
|
}
|
|
var ln net.Listener
|
|
var err error
|
|
if isTLS {
|
|
certData, keyData, kerr := GenerateTestCertificate("localhost")
|
|
if kerr != nil {
|
|
t.Fatal(kerr)
|
|
}
|
|
|
|
cert, kerr := tls.X509KeyPair(certData, keyData)
|
|
if kerr != nil {
|
|
t.Fatal(kerr)
|
|
}
|
|
|
|
tlsConfig := &tls.Config{
|
|
Certificates: []tls.Certificate{cert},
|
|
}
|
|
ln, err = tls.Listen(network, addr, tlsConfig)
|
|
} else {
|
|
ln, err = net.Listen(network, addr)
|
|
}
|
|
if err != nil {
|
|
t.Fatalf("cannot listen %q: %v", addr, err)
|
|
}
|
|
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
if ctx.IsGet() {
|
|
ctx.Success("text/plain", ctx.URI().FullURI())
|
|
} else if ctx.IsPost() {
|
|
ctx.PostArgs().WriteTo(ctx) //nolint:errcheck
|
|
}
|
|
},
|
|
Logger: &testLogger{}, // Ignore log output.
|
|
}
|
|
ch := make(chan struct{})
|
|
go func() {
|
|
err := s.Serve(ln)
|
|
if err != nil {
|
|
t.Errorf("unexpected error returned from Serve(): %v", err)
|
|
}
|
|
close(ch)
|
|
}()
|
|
return &testEchoServer{
|
|
s: s,
|
|
ln: ln,
|
|
ch: ch,
|
|
t: t,
|
|
}
|
|
}
|
|
|
|
func TestClientTLSHandshakeTimeout(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
addr := listener.Addr().String()
|
|
defer listener.Close()
|
|
|
|
complete := make(chan bool)
|
|
defer close(complete)
|
|
|
|
go func() {
|
|
conn, err := listener.Accept()
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
<-complete
|
|
conn.Close()
|
|
}()
|
|
|
|
client := Client{
|
|
WriteTimeout: 100 * time.Millisecond,
|
|
ReadTimeout: 100 * time.Millisecond,
|
|
}
|
|
|
|
_, _, err = client.Get(nil, "https://"+addr)
|
|
if err == nil {
|
|
t.Fatal("tlsClientHandshake completed successfully")
|
|
}
|
|
|
|
if err != ErrTLSHandshakeTimeout {
|
|
t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err)
|
|
}
|
|
}
|
|
|
|
func TestClientConfigureClientFailed(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
c := &Client{
|
|
ConfigureClient: func(hc *HostClient) error {
|
|
return fmt.Errorf("failed to configure")
|
|
},
|
|
}
|
|
|
|
req := Request{}
|
|
req.SetRequestURI("http://example.com")
|
|
|
|
err := c.Do(&req, &Response{})
|
|
if err == nil {
|
|
t.Fatal("expected error (failed to configure)")
|
|
}
|
|
|
|
c.ConfigureClient = nil
|
|
err = c.Do(&req, &Response{})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestHostClientMaxConnWaitTimeoutSuccess(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var (
|
|
emptyBodyCount uint8
|
|
ln = fasthttputil.NewInmemoryListener()
|
|
wg sync.WaitGroup
|
|
)
|
|
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
if len(ctx.PostBody()) == 0 {
|
|
emptyBodyCount++
|
|
}
|
|
time.Sleep(5 * time.Millisecond)
|
|
ctx.WriteString("foo") //nolint:errcheck
|
|
},
|
|
}
|
|
serverStopCh := make(chan struct{})
|
|
go func() {
|
|
if err := s.Serve(ln); err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
close(serverStopCh)
|
|
}()
|
|
|
|
c := &HostClient{
|
|
Addr: "foobar",
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
MaxConns: 1,
|
|
MaxConnWaitTimeout: time.Second * 2,
|
|
}
|
|
|
|
for i := 0; i < 5; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
|
|
req := AcquireRequest()
|
|
req.SetRequestURI("http://foobar/baz")
|
|
req.Header.SetMethod(MethodPost)
|
|
req.SetBodyString("bar")
|
|
resp := AcquireResponse()
|
|
|
|
if err := c.Do(req, resp); err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
|
|
if resp.StatusCode() != StatusOK {
|
|
t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK)
|
|
}
|
|
|
|
body := resp.Body()
|
|
if string(body) != "foo" {
|
|
t.Errorf("unexpected body %q. Expecting %q", body, "abcd")
|
|
}
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
|
|
if c.connsWait.len() > 0 {
|
|
t.Errorf("connsWait has %v items remaining", c.connsWait.len())
|
|
}
|
|
if err := ln.Close(); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
select {
|
|
case <-serverStopCh:
|
|
case <-time.After(time.Second * 5):
|
|
t.Fatalf("timeout")
|
|
}
|
|
|
|
if emptyBodyCount > 0 {
|
|
t.Fatalf("at least one request body was empty")
|
|
}
|
|
}
|
|
|
|
func TestHostClientMaxConnWaitTimeoutError(t *testing.T) {
|
|
var (
|
|
emptyBodyCount uint8
|
|
ln = fasthttputil.NewInmemoryListener()
|
|
wg sync.WaitGroup
|
|
)
|
|
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
if len(ctx.PostBody()) == 0 {
|
|
emptyBodyCount++
|
|
}
|
|
time.Sleep(5 * time.Millisecond)
|
|
ctx.WriteString("foo") //nolint:errcheck
|
|
},
|
|
}
|
|
serverStopCh := make(chan struct{})
|
|
go func() {
|
|
if err := s.Serve(ln); err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
close(serverStopCh)
|
|
}()
|
|
|
|
c := &HostClient{
|
|
Addr: "foobar",
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
MaxConns: 1,
|
|
MaxConnWaitTimeout: 10 * time.Millisecond,
|
|
}
|
|
|
|
var errNoFreeConnsCount uint32
|
|
for i := 0; i < 5; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
|
|
req := AcquireRequest()
|
|
req.SetRequestURI("http://foobar/baz")
|
|
req.Header.SetMethod(MethodPost)
|
|
req.SetBodyString("bar")
|
|
resp := AcquireResponse()
|
|
|
|
if err := c.Do(req, resp); err != nil {
|
|
if err != ErrNoFreeConns {
|
|
t.Errorf("unexpected error: %v. Expecting %v", err, ErrNoFreeConns)
|
|
}
|
|
atomic.AddUint32(&errNoFreeConnsCount, 1)
|
|
} else {
|
|
if resp.StatusCode() != StatusOK {
|
|
t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK)
|
|
}
|
|
|
|
body := resp.Body()
|
|
if string(body) != "foo" {
|
|
t.Errorf("unexpected body %q. Expecting %q", body, "abcd")
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
|
|
time.Sleep(time.Millisecond * 100)
|
|
|
|
// Prevent a race condition with the conns cleaner that might still be running.
|
|
c.connsLock.Lock()
|
|
defer c.connsLock.Unlock()
|
|
|
|
if c.connsWait.len() > 0 {
|
|
t.Errorf("connsWait has %v items remaining", c.connsWait.len())
|
|
}
|
|
if errNoFreeConnsCount == 0 {
|
|
t.Errorf("unexpected errorCount: %d. Expecting > 0", errNoFreeConnsCount)
|
|
}
|
|
if err := ln.Close(); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
select {
|
|
case <-serverStopCh:
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
|
|
if emptyBodyCount > 0 {
|
|
t.Fatalf("at least one request body was empty")
|
|
}
|
|
}
|
|
|
|
func TestHostClientMaxConnWaitTimeoutWithEarlierDeadline(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var (
|
|
emptyBodyCount uint8
|
|
ln = fasthttputil.NewInmemoryListener()
|
|
wg sync.WaitGroup
|
|
// make deadline reach earlier than conns wait timeout
|
|
sleep = 100 * time.Millisecond
|
|
timeout = 10 * time.Millisecond
|
|
maxConnWaitTimeout = 50 * time.Millisecond
|
|
)
|
|
|
|
s := &Server{
|
|
Handler: func(ctx *RequestCtx) {
|
|
if len(ctx.PostBody()) == 0 {
|
|
emptyBodyCount++
|
|
}
|
|
time.Sleep(sleep)
|
|
ctx.WriteString("foo") //nolint:errcheck
|
|
},
|
|
Logger: &testLogger{}, // Don't print connection closed errors.
|
|
}
|
|
serverStopCh := make(chan struct{})
|
|
go func() {
|
|
if err := s.Serve(ln); err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
close(serverStopCh)
|
|
}()
|
|
|
|
c := &HostClient{
|
|
Addr: "foobar",
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
MaxConns: 1,
|
|
MaxConnWaitTimeout: maxConnWaitTimeout,
|
|
}
|
|
|
|
var errTimeoutCount uint32
|
|
for i := 0; i < 5; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
|
|
req := AcquireRequest()
|
|
req.SetRequestURI("http://foobar/baz")
|
|
req.Header.SetMethod(MethodPost)
|
|
req.SetBodyString("bar")
|
|
resp := AcquireResponse()
|
|
|
|
if err := c.DoDeadline(req, resp, time.Now().Add(timeout)); err != nil {
|
|
if err != ErrTimeout {
|
|
t.Errorf("unexpected error: %v. Expecting %v", err, ErrTimeout)
|
|
}
|
|
atomic.AddUint32(&errTimeoutCount, 1)
|
|
} else {
|
|
if resp.StatusCode() != StatusOK {
|
|
t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK)
|
|
}
|
|
|
|
body := resp.Body()
|
|
if string(body) != "foo" {
|
|
t.Errorf("unexpected body %q. Expecting %q", body, "abcd")
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
|
|
c.connsLock.Lock()
|
|
for {
|
|
w := c.connsWait.popFront()
|
|
if w == nil {
|
|
break
|
|
}
|
|
w.mu.Lock()
|
|
if w.err != nil && w.err != ErrTimeout {
|
|
t.Errorf("unexpected error: %v. Expecting %v", w.err, ErrTimeout)
|
|
}
|
|
w.mu.Unlock()
|
|
}
|
|
c.connsLock.Unlock()
|
|
if errTimeoutCount == 0 {
|
|
t.Errorf("unexpected errTimeoutCount: %d. Expecting > 0", errTimeoutCount)
|
|
}
|
|
if err := ln.Close(); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
select {
|
|
case <-serverStopCh:
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout")
|
|
}
|
|
|
|
if emptyBodyCount > 0 {
|
|
t.Fatalf("at least one request body was empty")
|
|
}
|
|
}
|
|
|
|
type TransportEmpty struct{}
|
|
|
|
func (t TransportEmpty) RoundTrip(hc *HostClient, req *Request, res *Response) (retry bool, err error) {
|
|
return false, nil
|
|
}
|
|
|
|
func TestHttpsRequestWithoutParsedURL(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
client := HostClient{
|
|
IsTLS: true,
|
|
Transport: TransportEmpty{},
|
|
}
|
|
|
|
req := &Request{}
|
|
|
|
req.SetRequestURI("https://foo.com/bar")
|
|
|
|
_, err := client.doNonNilReqResp(req, &Response{})
|
|
if err != nil {
|
|
t.Fatal("https requests with IsTLS client must succeed")
|
|
}
|
|
}
|
|
|
|
func TestHostClientErrConnPoolStrategyNotImpl(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ln := fasthttputil.NewInmemoryListener()
|
|
server := &Server{
|
|
Handler: func(ctx *RequestCtx) {},
|
|
}
|
|
serverStopCh := make(chan struct{})
|
|
go func() {
|
|
if err := server.Serve(ln); err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
close(serverStopCh)
|
|
}()
|
|
|
|
client := &HostClient{
|
|
Addr: "foobar",
|
|
Dial: func(addr string) (net.Conn, error) {
|
|
return ln.Dial()
|
|
},
|
|
ConnPoolStrategy: ConnPoolStrategyType(100),
|
|
}
|
|
|
|
req := AcquireRequest()
|
|
req.SetRequestURI("http://foobar/baz")
|
|
|
|
if err := client.Do(req, AcquireResponse()); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if err := client.Do(req, &Response{}); err != ErrConnPoolStrategyNotImpl {
|
|
t.Errorf("expected ErrConnPoolStrategyNotImpl error, got %v", err)
|
|
}
|
|
if err := client.Do(req, &Response{}); err != ErrConnPoolStrategyNotImpl {
|
|
t.Errorf("expected ErrConnPoolStrategyNotImpl error, got %v", err)
|
|
}
|
|
|
|
if err := ln.Close(); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
}
|
|
|
|
func Test_AddMissingPort(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
type args struct {
|
|
addr string
|
|
isTLS bool
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
args args
|
|
want string
|
|
}{
|
|
{
|
|
args: args{"127.1", false}, // 127.1 is a short form of 127.0.0.1
|
|
want: "127.1:80",
|
|
},
|
|
{
|
|
args: args{"127.0.0.1", false},
|
|
want: "127.0.0.1:80",
|
|
},
|
|
{
|
|
args: args{"127.0.0.1", true},
|
|
want: "127.0.0.1:443",
|
|
},
|
|
{
|
|
args: args{"[::1]", false},
|
|
want: "[::1]:80",
|
|
},
|
|
{
|
|
args: args{"::1", false},
|
|
want: "::1", // keep as is
|
|
},
|
|
{
|
|
args: args{"[::1]", true},
|
|
want: "[::1]:443",
|
|
},
|
|
{
|
|
args: args{"127.0.0.1:8080", false},
|
|
want: "127.0.0.1:8080",
|
|
},
|
|
{
|
|
args: args{"127.0.0.1:8443", true},
|
|
want: "127.0.0.1:8443",
|
|
},
|
|
{
|
|
args: args{"[::1]:8080", false},
|
|
want: "[::1]:8080",
|
|
},
|
|
{
|
|
args: args{"[::1]:8443", true},
|
|
want: "[::1]:8443",
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.want, func(t *testing.T) {
|
|
if got := AddMissingPort(tt.args.addr, tt.args.isTLS); got != tt.want {
|
|
t.Errorf("AddMissingPort() = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
type TransportWrapper struct {
|
|
base RoundTripper
|
|
count *int
|
|
t *testing.T
|
|
}
|
|
|
|
func (tw *TransportWrapper) RoundTrip(hc *HostClient, req *Request, resp *Response) (bool, error) {
|
|
req.Header.Set("trace-id", "123")
|
|
tw.assertRequestLog(req.String())
|
|
retry, err := tw.transport().RoundTrip(hc, req, resp)
|
|
resp.Header.Set("trace-id", "124")
|
|
tw.assertResponseLog(resp.String())
|
|
*tw.count++
|
|
return retry, err
|
|
}
|
|
|
|
func (tw *TransportWrapper) transport() RoundTripper {
|
|
if tw.base == nil {
|
|
return DefaultTransport
|
|
}
|
|
return tw.base
|
|
}
|
|
|
|
func (tw *TransportWrapper) assertRequestLog(reqLog string) {
|
|
if !strings.Contains(reqLog, "Trace-Id: 123") {
|
|
tw.t.Errorf("request log should contains: %v", "Trace-Id: 123")
|
|
}
|
|
}
|
|
|
|
func (tw *TransportWrapper) assertResponseLog(respLog string) {
|
|
if !strings.Contains(respLog, "Trace-Id: 124") {
|
|
tw.t.Errorf("response log should contains: %v", "Trace-Id: 124")
|
|
}
|
|
}
|
|
|
|
func TestClientTransportEx(t *testing.T) {
|
|
sHTTP := startEchoServer(t, "tcp", "127.0.0.1:")
|
|
defer sHTTP.Stop()
|
|
|
|
sHTTPS := startEchoServerTLS(t, "tcp", "127.0.0.1:")
|
|
defer sHTTPS.Stop()
|
|
|
|
count := 0
|
|
c := &Client{
|
|
TLSConfig: &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
},
|
|
ConfigureClient: func(hc *HostClient) error {
|
|
hc.Transport = &TransportWrapper{base: hc.Transport, count: &count, t: t}
|
|
return nil
|
|
},
|
|
}
|
|
// test transport
|
|
const loopCount = 4
|
|
const getCount = 20
|
|
const postCount = 10
|
|
for i := 0; i < loopCount; i++ {
|
|
addr := "http://" + sHTTP.Addr()
|
|
if i&1 != 0 {
|
|
addr = "https://" + sHTTPS.Addr()
|
|
}
|
|
// test get
|
|
testClientGet(t, c, addr, getCount)
|
|
// test post
|
|
testClientPost(t, c, addr, postCount)
|
|
}
|
|
roundTripCount := loopCount * (getCount + postCount)
|
|
if count != roundTripCount {
|
|
t.Errorf("round trip count should be: %v", roundTripCount)
|
|
}
|
|
}
|
|
|
|
func Test_getRedirectURL(t *testing.T) {
|
|
type args struct {
|
|
baseURL string
|
|
location []byte
|
|
disablePathNormalizing bool
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
args args
|
|
want string
|
|
}{
|
|
{
|
|
name: "Path normalizing enabled, no special characters in path",
|
|
args: args{
|
|
baseURL: "http://foo.example.com/abc",
|
|
location: []byte("http://bar.example.com/def"),
|
|
disablePathNormalizing: false,
|
|
},
|
|
want: "http://bar.example.com/def",
|
|
},
|
|
{
|
|
name: "Path normalizing enabled, special characters in path",
|
|
args: args{
|
|
baseURL: "http://foo.example.com/abc/*/def",
|
|
location: []byte("http://bar.example.com/123/*/456"),
|
|
disablePathNormalizing: false,
|
|
},
|
|
want: "http://bar.example.com/123/%2A/456",
|
|
},
|
|
{
|
|
name: "Path normalizing disabled, no special characters in path",
|
|
args: args{
|
|
baseURL: "http://foo.example.com/abc",
|
|
location: []byte("http://bar.example.com/def"),
|
|
disablePathNormalizing: true,
|
|
},
|
|
want: "http://bar.example.com/def",
|
|
},
|
|
{
|
|
name: "Path normalizing disabled, special characters in path",
|
|
args: args{
|
|
baseURL: "http://foo.example.com/abc/*/def",
|
|
location: []byte("http://bar.example.com/123/*/456"),
|
|
disablePathNormalizing: true,
|
|
},
|
|
want: "http://bar.example.com/123/*/456",
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
if got := getRedirectURL(tt.args.baseURL, tt.args.location, tt.args.disablePathNormalizing); got != tt.want {
|
|
t.Errorf("getRedirectURL() = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|