Files
2026-03-31 20:02:01 +00:00

3395 lines
74 KiB
Go

package fasthttp
import (
"bufio"
"bytes"
"crypto/tls"
"fmt"
"io"
"net"
"net/url"
"os"
"regexp"
"runtime"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/valyala/fasthttp/fasthttputil"
)
// TestCloseIdleConnections checks that Client.CloseIdleConnections drops the
// pooled connection left behind for a host after a completed request.
func TestCloseIdleConnections(t *testing.T) {
	t.Parallel()

	ln := fasthttputil.NewInmemoryListener()
	srv := &Server{
		Handler: func(ctx *RequestCtx) {},
	}
	go func() {
		if err := srv.Serve(ln); err != nil {
			t.Error(err)
		}
	}()

	client := &Client{
		Dial: func(addr string) (net.Conn, error) {
			return ln.Dial()
		},
	}
	if _, _, err := client.Get(nil, "http://google.com"); err != nil {
		t.Fatal(err)
	}

	// idleConns reports how many connections the client keeps for
	// "google.com", holding both the client map lock and the host lock.
	idleConns := func() int {
		client.mLock.Lock()
		defer client.mLock.Unlock()
		hc, found := client.m["google.com"]
		if !found {
			return 0
		}
		hc.connsLock.Lock()
		defer hc.connsLock.Unlock()
		return len(hc.conns)
	}

	if n := idleConns(); n > 1 {
		t.Errorf("expected 1 conns got %d", n)
	}
	client.CloseIdleConnections()
	if n := idleConns(); n > 0 {
		t.Errorf("expected 0 conns got %d", n)
	}
}
// TestPipelineClientSetUserAgent checks that a custom client Name is sent as
// the User-Agent header when the request goes through Do.
func TestPipelineClientSetUserAgent(t *testing.T) {
	t.Parallel()
	const useDo time.Duration = 0
	testPipelineClientSetUserAgent(t, useDo)
}

// TestPipelineClientSetUserAgentTimeout is the DoTimeout variant of
// TestPipelineClientSetUserAgent.
func TestPipelineClientSetUserAgentTimeout(t *testing.T) {
	t.Parallel()
	const perRequest = 1 * time.Second
	testPipelineClientSetUserAgent(t, perRequest)
}
// testPipelineClientSetUserAgent verifies that PipelineClient.Name is sent as
// the User-Agent request header. A non-positive timeout exercises Do; a
// positive one exercises DoTimeout with that timeout.
func testPipelineClientSetUserAgent(t *testing.T, timeout time.Duration) {
	ln := fasthttputil.NewInmemoryListener()
	userAgentSeen := ""
	s := &Server{
		Handler: func(ctx *RequestCtx) {
			userAgentSeen = string(ctx.UserAgent())
		},
	}
	go s.Serve(ln) //nolint:errcheck

	userAgent := "I'm not fasthttp"
	// The TestPipelineClient* tests rely on this helper, so it must drive a
	// PipelineClient rather than a HostClient.
	c := &PipelineClient{
		Name: userAgent,
		Dial: func(addr string) (net.Conn, error) {
			return ln.Dial()
		},
	}
	// req/res are not released: a pipeline worker may still hold them if the
	// call fails part-way.
	req := AcquireRequest()
	res := AcquireResponse()
	req.SetRequestURI("http://example.com")

	var err error
	if timeout <= 0 {
		err = c.Do(req, res)
	} else {
		err = c.DoTimeout(req, res, timeout)
	}
	if err != nil {
		t.Fatal(err)
	}
	if userAgentSeen != userAgent {
		t.Fatalf("User-Agent defers %q != %q", userAgentSeen, userAgent)
	}
}
// TestHostClientNegativeTimeout ensures HostClient rejects a negative timeout
// and an already-passed deadline with ErrTimeout.
func TestHostClientNegativeTimeout(t *testing.T) {
	t.Parallel()

	ln := fasthttputil.NewInmemoryListener()
	srv := &Server{Handler: func(ctx *RequestCtx) {}}
	go srv.Serve(ln) //nolint:errcheck

	hc := &HostClient{
		Dial: func(addr string) (net.Conn, error) {
			return ln.Dial()
		},
	}

	req := AcquireRequest()
	req.Header.SetMethod(MethodGet)
	req.SetRequestURI("http://example.com")

	expired := []func() error{
		func() error { return hc.DoTimeout(req, nil, -time.Second) },
		func() error { return hc.DoDeadline(req, nil, time.Now().Add(-time.Second)) },
	}
	for _, call := range expired {
		if err := call(); err != ErrTimeout {
			t.Fatalf("expected ErrTimeout error got: %+v", err)
		}
	}
	ln.Close()
}
// TestDoDeadlineRetry checks that HostClient.DoDeadline retries the GET when
// the server hangs up without answering, and gives up with ErrTimeout once the
// deadline passes. With a 100ms deadline and a 60ms stall per attempt, exactly
// two attempts are expected.
func TestDoDeadlineRetry(t *testing.T) {
	t.Parallel()

	var tries int
	done := make(chan struct{})
	ln := fasthttputil.NewInmemoryListener()

	// Fake server: read one request, stall, then drop the connection without
	// responding. tries is only read after done has been closed.
	go func() {
		defer close(done)
		for {
			conn, err := ln.Accept()
			if err != nil {
				return
			}
			tries++
			br := bufio.NewReader(conn)
			(&RequestHeader{}).Read(br)                      //nolint:errcheck
			(&Request{}).readBodyStream(br, 0, false, false) //nolint:errcheck
			time.Sleep(60 * time.Millisecond)
			conn.Close()
		}
	}()

	hc := &HostClient{
		Dial: func(addr string) (net.Conn, error) {
			return ln.Dial()
		},
	}
	req := AcquireRequest()
	req.Header.SetMethod(MethodGet)
	req.SetRequestURI("http://example.com")

	deadline := time.Now().Add(100 * time.Millisecond)
	if err := hc.DoDeadline(req, nil, deadline); err != ErrTimeout {
		t.Fatalf("expected ErrTimeout error got: %+v", err)
	}

	ln.Close()
	<-done
	if tries != 2 {
		t.Fatalf("expected 2 tries got %d", tries)
	}
}
// TestPipelineClientIssue832 is a regression test: the server keeps accepting
// connections and dropping them without a reply, and PipelineClient must keep
// restarting its worker so every Do call returns an error rather than hanging.
func TestPipelineClientIssue832(t *testing.T) {
	t.Parallel()

	const attempts = 10
	ln := fasthttputil.NewInmemoryListener()

	// req and res are deliberately never released: a pipeline worker may
	// still reference them after the test returns.
	req := AcquireRequest()
	req.SetHost("example.com")
	res := AcquireResponse()

	client := PipelineClient{
		Dial: func(addr string) (net.Conn, error) {
			return ln.Dial()
		},
		ReadTimeout: 10 * time.Millisecond,
		Logger:      &testLogger{}, // Ignore log output.
	}

	// Accept each connection, keep it open for 50ms, then close it unanswered.
	go func() {
		for i := 0; i < attempts; i++ {
			conn, err := ln.Accept()
			if err != nil {
				t.Error(err)
			}
			if conn == nil {
				continue
			}
			go func() {
				time.Sleep(50 * time.Millisecond)
				conn.Close()
			}()
		}
	}()

	done := make(chan struct{})
	go func() {
		defer close(done)
		for i := 0; i < attempts; i++ {
			if err := client.Do(req, res); err == nil {
				t.Error("error expected")
			}
		}
	}()

	select {
	case <-done:
	case <-time.After(2 * time.Second):
		t.Fatal("PipelineClient did not restart worker")
	}
}
// TestClientInvalidURI ensures a request URI containing CRLF sequences is
// rejected by the client and never reaches the server handler.
func TestClientInvalidURI(t *testing.T) {
	t.Parallel()

	ln := fasthttputil.NewInmemoryListener()
	var served int64
	srv := &Server{
		Handler: func(_ *RequestCtx) {
			atomic.AddInt64(&served, 1)
		},
	}
	go srv.Serve(ln) //nolint:errcheck

	client := &Client{
		Dial: func(addr string) (net.Conn, error) {
			return ln.Dial()
		},
	}

	req := AcquireRequest()
	defer ReleaseRequest(req)
	res := AcquireResponse()
	defer ReleaseResponse(res)

	req.Header.SetMethod(MethodGet)
	req.SetRequestURI("http://example.com\r\n\r\nGET /\r\n\r\n")
	if err := client.Do(req, res); err == nil {
		t.Fatal("expected error (missing required Host header in request)")
	}
	if got := atomic.LoadInt64(&served); got != 0 {
		t.Fatalf("0 requests expected, got %d", got)
	}
}
// TestClientGetWithBody ensures a body attached to a GET request is still
// transmitted: the server echoes the request body and the test expects it back.
func TestClientGetWithBody(t *testing.T) {
	t.Parallel()

	ln := fasthttputil.NewInmemoryListener()
	echo := func(ctx *RequestCtx) {
		ctx.Write(ctx.Request.Body()) //nolint:errcheck
	}
	srv := &Server{Handler: echo}
	go srv.Serve(ln) //nolint:errcheck

	client := &Client{
		Dial: func(addr string) (net.Conn, error) {
			return ln.Dial()
		},
	}

	req := AcquireRequest()
	res := AcquireResponse()
	defer func() {
		ReleaseResponse(res)
		ReleaseRequest(req)
	}()

	req.Header.SetMethod(MethodGet)
	req.SetRequestURI("http://example.com")
	req.SetBodyString("test")
	if err := client.Do(req, res); err != nil {
		t.Fatal(err)
	}
	if len(res.Body()) == 0 {
		t.Fatal("missing request body")
	}
}
// TestClientURLAuth checks that userinfo embedded in the request URL becomes a
// Basic Authorization header, and that empty userinfo sends no header.
func TestClientURLAuth(t *testing.T) {
	t.Parallel()

	wantAuth := map[string]string{
		"user:pass@": "Basic dXNlcjpwYXNz",
		"foo:@":      "Basic Zm9vOg==",
		":@":         "",
		"@":          "",
		"":           "",
	}

	authSeen := make(chan string, 1)
	ln := fasthttputil.NewInmemoryListener()
	srv := &Server{
		Handler: func(ctx *RequestCtx) {
			authSeen <- string(ctx.Request.Header.Peek(HeaderAuthorization))
		},
	}
	go srv.Serve(ln) //nolint:errcheck

	client := &Client{
		Dial: func(addr string) (net.Conn, error) {
			return ln.Dial()
		},
	}

	for userinfo, want := range wantAuth {
		req := AcquireRequest()
		req.Header.SetMethod(MethodGet)
		req.SetRequestURI("http://" + userinfo + "example.com/foo/bar")
		if err := client.Do(req, nil); err != nil {
			t.Fatal(err)
		}
		if got := <-authSeen; got != want {
			t.Fatalf("wrong %q header: %q expected %q", HeaderAuthorization, got, want)
		}
	}
}
// TestClientNilResp checks that Client.Do and Client.DoTimeout accept a nil
// *Response.
func TestClientNilResp(t *testing.T) {
	t.Parallel()

	ln := fasthttputil.NewInmemoryListener()
	srv := &Server{Handler: func(ctx *RequestCtx) {}}
	go srv.Serve(ln) //nolint:errcheck

	client := &Client{
		Dial: func(addr string) (net.Conn, error) {
			return ln.Dial()
		},
	}

	req := AcquireRequest()
	req.Header.SetMethod(MethodGet)
	req.SetRequestURI("http://example.com")

	calls := []func() error{
		func() error { return client.Do(req, nil) },
		func() error { return client.DoTimeout(req, nil, time.Second) },
	}
	for _, call := range calls {
		if err := call(); err != nil {
			t.Fatal(err)
		}
	}
	ln.Close()
}
// TestClientNegativeTimeout ensures Client rejects a negative timeout and an
// already-passed deadline with ErrTimeout.
func TestClientNegativeTimeout(t *testing.T) {
	t.Parallel()

	ln := fasthttputil.NewInmemoryListener()
	srv := &Server{Handler: func(ctx *RequestCtx) {}}
	go srv.Serve(ln) //nolint:errcheck

	client := &Client{
		Dial: func(addr string) (net.Conn, error) {
			return ln.Dial()
		},
	}

	req := AcquireRequest()
	req.Header.SetMethod(MethodGet)
	req.SetRequestURI("http://example.com")

	expired := []func() error{
		func() error { return client.DoTimeout(req, nil, -time.Second) },
		func() error { return client.DoDeadline(req, nil, time.Now().Add(-time.Second)) },
	}
	for _, call := range expired {
		if err := call(); err != ErrTimeout {
			t.Fatalf("expected ErrTimeout error got: %+v", err)
		}
	}
	ln.Close()
}
// TestPipelineClientNilResp checks that PipelineClient.Do, DoTimeout and
// DoDeadline all accept a nil *Response.
func TestPipelineClientNilResp(t *testing.T) {
	t.Parallel()

	ln := fasthttputil.NewInmemoryListener()
	srv := &Server{Handler: func(ctx *RequestCtx) {}}
	go srv.Serve(ln) //nolint:errcheck

	pc := &PipelineClient{
		Dial: func(addr string) (net.Conn, error) {
			return ln.Dial()
		},
	}

	req := AcquireRequest()
	req.Header.SetMethod(MethodGet)
	req.SetRequestURI("http://example.com")

	calls := []func() error{
		func() error { return pc.Do(req, nil) },
		func() error { return pc.DoTimeout(req, nil, time.Second) },
		func() error { return pc.DoDeadline(req, nil, time.Now().Add(time.Second)) },
	}
	for _, call := range calls {
		if err := call(); err != nil {
			t.Fatal(err)
		}
	}
}
// TestClientParseConn checks that a response records the remote and local
// addresses of the TCP connection it was read from.
func TestClientParseConn(t *testing.T) {
	t.Parallel()
	network := "tcp"
	ln, err := net.Listen(network, "127.0.0.1:0")
	if err != nil {
		// Without this check ln would be nil and ln.Addr() below would panic.
		t.Fatalf("cannot listen: %v", err)
	}
	// Closing the listener makes Serve return, so the server does not outlive the test.
	defer ln.Close()

	s := &Server{
		Handler: func(ctx *RequestCtx) {
		},
	}
	go s.Serve(ln) //nolint:errcheck

	host := ln.Addr().String()
	c := &Client{}
	req, res := AcquireRequest(), AcquireResponse()
	defer func() {
		ReleaseRequest(req)
		ReleaseResponse(res)
	}()
	req.SetRequestURI("http://" + host)
	if err := c.Do(req, res); err != nil {
		t.Fatal(err)
	}
	if res.RemoteAddr().Network() != network {
		t.Fatalf("req RemoteAddr parse network fail: %q, hope: %q", res.RemoteAddr().Network(), network)
	}
	if host != res.RemoteAddr().String() {
		t.Fatalf("req RemoteAddr parse addr fail: %q, hope: %q", res.RemoteAddr().String(), host)
	}
	if !regexp.MustCompile(`^127\.0\.0\.1:[0-9]{4,5}$`).MatchString(res.LocalAddr().String()) {
		t.Fatalf("res LocalAddr addr match fail: %q, hope match: %q", res.LocalAddr().String(), "^127.0.0.1:[0-9]{4,5}$")
	}
}
// TestClientPostArgs ensures PostArgs set on a POST request are serialized as
// the request body; the server echoes any non-empty body back.
func TestClientPostArgs(t *testing.T) {
	t.Parallel()

	ln := fasthttputil.NewInmemoryListener()
	srv := &Server{
		Handler: func(ctx *RequestCtx) {
			if body := ctx.Request.Body(); len(body) > 0 {
				ctx.Write(body) //nolint:errcheck
			}
		},
	}
	go srv.Serve(ln) //nolint:errcheck

	client := &Client{
		Dial: func(addr string) (net.Conn, error) {
			return ln.Dial()
		},
	}

	req := AcquireRequest()
	res := AcquireResponse()
	defer func() {
		ReleaseResponse(res)
		ReleaseRequest(req)
	}()

	postArgs := req.PostArgs()
	postArgs.Add("addhttp2", "support")
	postArgs.Add("fast", "http")
	req.Header.SetMethod(MethodPost)
	req.SetRequestURI("http://make.fasthttp.great?again")

	if err := client.Do(req, res); err != nil {
		t.Fatal(err)
	}
	if len(res.Body()) == 0 {
		t.Fatal("cannot set args as body")
	}
}
// TestClientRedirectSameSchema checks that a TLS HostClient fetching a URL from
// a TLS server gets the direct 200 response (no scheme change involved).
func TestClientRedirectSameSchema(t *testing.T) {
	t.Parallel()

	tlsLn1 := testClientRedirectListener(t, true)
	defer tlsLn1.Close()
	tlsLn2 := testClientRedirectListener(t, true)
	defer tlsLn2.Close()

	srv1 := testClientRedirectChangingSchemaServer(t, tlsLn1, tlsLn1, true)
	defer srv1.Stop()
	srv2 := testClientRedirectChangingSchemaServer(t, tlsLn2, tlsLn2, false)
	defer srv2.Stop()

	destURL := fmt.Sprintf("https://%s/baz", tlsLn1.Addr().String())
	parsed, err := url.Parse(destURL)
	if err != nil {
		t.Fatal(err)
	}

	hc := &HostClient{
		IsTLS:     true,
		Addr:      parsed.Host,
		TLSConfig: &tls.Config{InsecureSkipVerify: true},
	}
	code, _, err := hc.GetTimeout(nil, destURL, 4*time.Second)
	if err != nil {
		t.Fatalf("HostClient error: %v", err)
	}
	if code != StatusOK {
		t.Fatalf("HostClient error code response %d", code)
	}
}
// TestClientRedirectClientChangingSchemaHttp2Https checks that Client follows
// the plain-HTTP server's redirect to the HTTPS server and ends with 200.
func TestClientRedirectClientChangingSchemaHttp2Https(t *testing.T) {
	t.Parallel()

	httpsLn := testClientRedirectListener(t, true)
	defer httpsLn.Close()
	httpLn := testClientRedirectListener(t, false)
	defer httpLn.Close()

	httpsSrv := testClientRedirectChangingSchemaServer(t, httpsLn, httpLn, true)
	defer httpsSrv.Stop()
	httpSrv := testClientRedirectChangingSchemaServer(t, httpsLn, httpLn, false)
	defer httpSrv.Stop()

	destURL := fmt.Sprintf("http://%s/baz", httpLn.Addr().String())
	client := &Client{
		TLSConfig: &tls.Config{InsecureSkipVerify: true},
	}
	code, _, err := client.GetTimeout(nil, destURL, 4*time.Second)
	if err != nil {
		t.Fatalf("HostClient error: %v", err)
	}
	if code != StatusOK {
		t.Fatalf("HostClient error code response %d", code)
	}
}
// TestClientRedirectHostClientChangingSchemaHttp2Https checks that HostClient
// refuses to follow a redirect that changes the scheme (http -> https),
// returning ErrHostClientRedirectToDifferentScheme.
func TestClientRedirectHostClientChangingSchemaHttp2Https(t *testing.T) {
	t.Parallel()

	httpsLn := testClientRedirectListener(t, true)
	defer httpsLn.Close()
	httpLn := testClientRedirectListener(t, false)
	defer httpLn.Close()

	httpsSrv := testClientRedirectChangingSchemaServer(t, httpsLn, httpLn, true)
	defer httpsSrv.Stop()
	httpSrv := testClientRedirectChangingSchemaServer(t, httpsLn, httpLn, false)
	defer httpSrv.Stop()

	destURL := fmt.Sprintf("http://%s/baz", httpLn.Addr().String())
	parsed, err := url.Parse(destURL)
	if err != nil {
		t.Fatal(err)
	}

	hc := &HostClient{
		Addr:      parsed.Host,
		TLSConfig: &tls.Config{InsecureSkipVerify: true},
	}
	if _, _, err := hc.GetTimeout(nil, destURL, 4*time.Second); err != ErrHostClientRedirectToDifferentScheme {
		t.Fatal("expected HostClient error")
	}
}
// testClientRedirectListener opens a TCP listener on an ephemeral localhost
// port. When isTLS is true the listener terminates TLS using a freshly
// generated certificate for "localhost". Any setup failure aborts the test.
func testClientRedirectListener(t *testing.T, isTLS bool) net.Listener {
	// Attribute t.Fatal failures to the calling test line, not this helper.
	t.Helper()

	var ln net.Listener
	var err error
	if isTLS {
		certData, keyData, kerr := GenerateTestCertificate("localhost")
		if kerr != nil {
			t.Fatal(kerr)
		}
		cert, kerr := tls.X509KeyPair(certData, keyData)
		if kerr != nil {
			t.Fatal(kerr)
		}
		tlsConfig := &tls.Config{
			Certificates: []tls.Certificate{cert},
		}
		ln, err = tls.Listen("tcp", "localhost:0", tlsConfig)
	} else {
		ln, err = net.Listen("tcp", "localhost:0")
	}
	if err != nil {
		t.Fatalf("cannot listen isTLS %v: %v", isTLS, err)
	}
	return ln
}
// testClientRedirectChangingSchemaServer serves on https when isTLS is true,
// otherwise on http. Requests arriving over TLS get 200; plain requests get a
// 301 redirect to https://<https listener addr>/baz. Serve errors are reported
// through t, and the returned testEchoServer's channel closes once Serve exits.
func testClientRedirectChangingSchemaServer(t *testing.T, https, http net.Listener, isTLS bool) *testEchoServer {
	ln := http
	if isTLS {
		ln = https
	}

	s := &Server{
		Handler: func(ctx *RequestCtx) {
			if !ctx.IsTLS() {
				ctx.Redirect(fmt.Sprintf("https://%s/baz", https.Addr().String()), StatusMovedPermanently)
				return
			}
			ctx.SetStatusCode(StatusOK)
		},
	}

	stopped := make(chan struct{})
	go func() {
		defer close(stopped)
		if err := s.Serve(ln); err != nil {
			t.Errorf("unexpected error returned from Serve(): %v", err)
		}
	}()

	return &testEchoServer{s: s, ln: ln, ch: stopped, t: t}
}
// TestClientHeaderCase checks that a Client with DisableHeaderNamesNormalizing
// still parses lowercase response header names ("content-type",
// "transfer-encoding") and decodes the chunked body correctly.
func TestClientHeaderCase(t *testing.T) {
	t.Parallel()
	ln := fasthttputil.NewInmemoryListener()
	defer ln.Close()
	go func() {
		c, err := ln.Accept()
		if err != nil {
			t.Error(err)
			// c is nil when Accept fails; writing to it would panic.
			return
		}
		c.Write([]byte("HTTP/1.1 200 OK\r\n" + //nolint:errcheck
			"content-type: text/plain\r\n" +
			"transfer-encoding: chunked\r\n\r\n" +
			"24\r\nThis is the data in the first chunk \r\n" +
			"1B\r\nand this is the second one \r\n" +
			"0\r\n\r\n",
		))
	}()
	c := &Client{
		Dial: func(addr string) (net.Conn, error) {
			return ln.Dial()
		},
		ReadTimeout: time.Millisecond * 10,
		// Even without name normalizing we should parse headers correctly.
		DisableHeaderNamesNormalizing: true,
	}
	code, body, err := c.Get(nil, "http://example.com")
	if err != nil {
		t.Fatal(err)
	}
	if code != 200 {
		t.Errorf("expected status code 200 got %d", code)
	}
	if string(body) != "This is the data in the first chunk and this is the second one " {
		t.Errorf("wrong body: %q", body)
	}
}
// TestClientReadTimeout checks that HostClient.ReadTimeout aborts a request
// whose handler stalls. The first request handled is answered immediately;
// every later one sleeps for a second, well beyond the 400ms read timeout.
func TestClientReadTimeout(t *testing.T) {
	if runtime.GOOS == "windows" {
		t.SkipNow()
	}
	t.Parallel()

	ln := fasthttputil.NewInmemoryListener()
	// stall is false only while the first request is being handled.
	stall := false
	srv := &Server{
		Handler: func(_ *RequestCtx) {
			if !stall {
				stall = true
				return
			}
			time.Sleep(time.Second)
		},
		Logger: &testLogger{}, // Don't print closed pipe errors.
	}
	go srv.Serve(ln) //nolint:errcheck

	hc := &HostClient{
		ReadTimeout:               400 * time.Millisecond,
		MaxIdemponentCallAttempts: 1,
		Dial: func(addr string) (net.Conn, error) {
			return ln.Dial()
		},
	}

	// doClose issues a GET marked "Connection: close" and releases the pooled
	// request/response before returning the client's error.
	doClose := func() error {
		req := AcquireRequest()
		res := AcquireResponse()
		defer func() {
			ReleaseRequest(req)
			ReleaseResponse(res)
		}()
		req.SetRequestURI("http://localhost")
		req.SetConnectionClose()
		return hc.Do(req, res)
	}

	if err := doClose(); err != nil {
		t.Fatal(err)
	}

	done := make(chan struct{})
	go func() {
		defer close(done)
		if err := doClose(); err != ErrTimeout {
			t.Errorf("expected ErrTimeout got %#v", err)
		}
	}()

	// Upper bound: one read timeout per allowed attempt plus a second of slack.
	limit := hc.ReadTimeout*time.Duration(hc.MaxIdemponentCallAttempts) + time.Second
	select {
	case <-done:
	case <-time.After(limit):
		t.Fatal("Client.ReadTimeout didn't work")
	}
}
// TestClientDefaultUserAgent checks that a Client without a Name sends
// defaultUserAgent as the User-Agent header.
func TestClientDefaultUserAgent(t *testing.T) {
	t.Parallel()

	ln := fasthttputil.NewInmemoryListener()
	var seen string
	srv := &Server{
		Handler: func(ctx *RequestCtx) {
			seen = string(ctx.UserAgent())
		},
	}
	go srv.Serve(ln) //nolint:errcheck

	client := &Client{
		Dial: func(addr string) (net.Conn, error) {
			return ln.Dial()
		},
	}

	req, res := AcquireRequest(), AcquireResponse()
	req.SetRequestURI("http://example.com")
	if err := client.Do(req, res); err != nil {
		t.Fatal(err)
	}
	if seen != defaultUserAgent {
		t.Fatalf("User-Agent defers %q != %q", seen, defaultUserAgent)
	}
}
// TestClientSetUserAgent checks that Client.Name is sent as the User-Agent
// header.
func TestClientSetUserAgent(t *testing.T) {
	t.Parallel()

	const userAgent = "I'm not fasthttp"

	ln := fasthttputil.NewInmemoryListener()
	var seen string
	srv := &Server{
		Handler: func(ctx *RequestCtx) {
			seen = string(ctx.UserAgent())
		},
	}
	go srv.Serve(ln) //nolint:errcheck

	client := &Client{
		Name: userAgent,
		Dial: func(addr string) (net.Conn, error) {
			return ln.Dial()
		},
	}

	req, res := AcquireRequest(), AcquireResponse()
	req.SetRequestURI("http://example.com")
	if err := client.Do(req, res); err != nil {
		t.Fatal(err)
	}
	if seen != userAgent {
		t.Fatalf("User-Agent defers %q != %q", seen, userAgent)
	}
}
// TestClientNoUserAgent checks that NoDefaultUserAgentHeader suppresses the
// User-Agent header entirely when the Client has no Name.
func TestClientNoUserAgent(t *testing.T) {
	// Uses its own in-memory listener and shares no state with other tests,
	// so it can run in parallel like the sibling User-Agent tests.
	t.Parallel()
	ln := fasthttputil.NewInmemoryListener()
	userAgentSeen := ""
	s := &Server{
		Handler: func(ctx *RequestCtx) {
			userAgentSeen = string(ctx.UserAgent())
		},
	}
	go s.Serve(ln) //nolint:errcheck
	c := &Client{
		NoDefaultUserAgentHeader: true,
		Dial: func(addr string) (net.Conn, error) {
			return ln.Dial()
		},
	}
	req := AcquireRequest()
	res := AcquireResponse()
	req.SetRequestURI("http://example.com")
	err := c.Do(req, res)
	if err != nil {
		t.Fatal(err)
	}
	if userAgentSeen != "" {
		t.Fatalf("User-Agent wrong %q != %q", userAgentSeen, "")
	}
}
// TestClientDoWithCustomHeaders makes sure that the client sends the method,
// request URI, every custom header, and the body exactly as set. A raw
// in-memory server parses the request, validates it, and reports the first
// mismatch (or nil) on ch.
func TestClientDoWithCustomHeaders(t *testing.T) {
	t.Parallel()
	ln := fasthttputil.NewInmemoryListener()
	c := &Client{
		Dial: func(addr string) (net.Conn, error) {
			return ln.Dial()
		},
	}
	uri := "/foo/bar/baz?a=b&cd=12"
	headers := map[string]string{
		"Foo":          "bar",
		"Host":         "example.com",
		"Content-Type": "asdfsdf",
		"a-b-c-d-f":    "",
	}
	body := "request body"

	// Buffered so the server goroutine can always deliver its result and exit,
	// even when the client side fails first and never reads from ch.
	ch := make(chan error, 1)
	go func() {
		conn, err := ln.Accept()
		if err != nil {
			ch <- fmt.Errorf("cannot accept client connection: %w", err)
			return
		}
		br := bufio.NewReader(conn)
		var req Request
		if err = req.Read(br); err != nil {
			ch <- fmt.Errorf("cannot read client request: %w", err)
			return
		}
		if string(req.Header.Method()) != MethodPost {
			ch <- fmt.Errorf("unexpected request method: %q. Expecting %q", req.Header.Method(), MethodPost)
			return
		}
		reqURI := req.RequestURI()
		if string(reqURI) != uri {
			ch <- fmt.Errorf("unexpected request uri: %q. Expecting %q", reqURI, uri)
			return
		}
		for k, v := range headers {
			hv := req.Header.Peek(k)
			if string(hv) != v {
				ch <- fmt.Errorf("unexpected value for header %q: %q. Expecting %q", k, hv, v)
				return
			}
		}
		cl := req.Header.ContentLength()
		if cl != len(body) {
			ch <- fmt.Errorf("unexpected content-length %d. Expecting %d", cl, len(body))
			return
		}
		reqBody := req.Body()
		if string(reqBody) != body {
			ch <- fmt.Errorf("unexpected request body: %q. Expecting %q", reqBody, body)
			return
		}
		var resp Response
		bw := bufio.NewWriter(conn)
		if err = resp.Write(bw); err != nil {
			ch <- fmt.Errorf("cannot send response: %w", err)
			return
		}
		if err = bw.Flush(); err != nil {
			ch <- fmt.Errorf("cannot flush response: %w", err)
			return
		}
		ch <- nil
	}()

	var req Request
	req.Header.SetMethod(MethodPost)
	req.SetRequestURI(uri)
	for k, v := range headers {
		req.Header.Set(k, v)
	}
	req.SetBodyString(body)
	var resp Response
	err := c.DoTimeout(&req, &resp, time.Second)
	if err != nil {
		t.Fatalf("error when doing request: %v", err)
	}
	select {
	case serverErr := <-ch:
		// The server-side validation result is what this test is about;
		// any mismatch it found must fail the test.
		if serverErr != nil {
			t.Fatalf("server-side check failed: %v", serverErr)
		}
	case <-time.After(5 * time.Second):
		t.Fatalf("timeout")
	}
}
// TestPipelineClientDoSerial drives testPipelineClientDoConcurrent with a
// single caller, no batch delay, and MaxConns left at 0.
func TestPipelineClientDoSerial(t *testing.T) {
	t.Parallel()
	const callers, batchDelay, conns = 1, time.Duration(0), 0
	testPipelineClientDoConcurrent(t, callers, batchDelay, conns)
}

// TestPipelineClientDoConcurrent runs ten concurrent callers over one
// connection with no batch delay.
func TestPipelineClientDoConcurrent(t *testing.T) {
	t.Parallel()
	const callers, batchDelay, conns = 10, time.Duration(0), 1
	testPipelineClientDoConcurrent(t, callers, batchDelay, conns)
}

// TestPipelineClientDoBatchDelayConcurrent runs ten concurrent callers over
// one connection with a 5ms MaxBatchDelay.
func TestPipelineClientDoBatchDelayConcurrent(t *testing.T) {
	t.Parallel()
	const callers, batchDelay, conns = 10, 5 * time.Millisecond, 1
	testPipelineClientDoConcurrent(t, callers, batchDelay, conns)
}

// TestPipelineClientDoBatchDelayConcurrentMultiConn runs ten concurrent callers
// over up to three connections with a 5ms MaxBatchDelay.
func TestPipelineClientDoBatchDelayConcurrentMultiConn(t *testing.T) {
	t.Parallel()
	const callers, batchDelay, conns = 10, 5 * time.Millisecond, 3
	testPipelineClientDoConcurrent(t, callers, batchDelay, conns)
}
// testPipelineClientDoConcurrent starts an in-memory server answering "OK" and
// runs `concurrency` goroutines, each calling testPipelineClientDo against one
// shared PipelineClient. Afterwards it requires zero pending requests and that
// the server stops within a second of its listener being closed.
func testPipelineClientDoConcurrent(t *testing.T, concurrency int, maxBatchDelay time.Duration, maxConns int) {
	ln := fasthttputil.NewInmemoryListener()

	serverStopped := make(chan struct{})
	srv := &Server{
		Handler: func(ctx *RequestCtx) {
			ctx.WriteString("OK") //nolint:errcheck
		},
	}
	go func() {
		defer close(serverStopped)
		if err := srv.Serve(ln); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
	}()

	pc := &PipelineClient{
		Dial: func(addr string) (net.Conn, error) {
			return ln.Dial()
		},
		MaxConns:           maxConns,
		MaxPendingRequests: concurrency,
		MaxBatchDelay:      maxBatchDelay,
		Logger:             &testLogger{},
	}

	finished := make(chan struct{}, concurrency)
	for n := concurrency; n > 0; n-- {
		go func() {
			testPipelineClientDo(t, pc)
			finished <- struct{}{}
		}()
	}
	for n := concurrency; n > 0; n-- {
		select {
		case <-finished:
		case <-time.After(3 * time.Second):
			t.Fatalf("timeout")
		}
	}

	if pending := pc.PendingRequests(); pending != 0 {
		t.Fatalf("unexpected number of pending requests: %d. Expecting zero", pending)
	}

	if err := ln.Close(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	select {
	case <-serverStopped:
	case <-time.After(time.Second):
		t.Fatalf("timeout")
	}
}
// testPipelineClientDo issues ten GET requests through c, alternating between
// DoTimeout (even iterations) and Do (odd iterations). Each request must return
// 200 with body "OK". ErrPipelineOverflow is tolerated: that iteration is
// skipped after a 10ms back-off. Every fifth iteration sleeps 30ms so idle
// connections to the host may expire between requests.
func testPipelineClientDo(t *testing.T, c *PipelineClient) {
	req := AcquireRequest()
	req.SetRequestURI("http://foobar/baz")
	resp := AcquireResponse()
	for i := 0; i < 10; i++ {
		var err error
		if i&1 == 0 {
			err = c.DoTimeout(req, resp, time.Second)
		} else {
			err = c.Do(req, resp)
		}
		if err != nil {
			if err == ErrPipelineOverflow {
				time.Sleep(10 * time.Millisecond)
				continue
			}
			t.Errorf("unexpected error on iteration %d: %v", i, err)
			// resp holds no valid reply for this iteration; inspecting it
			// would only add misleading follow-up failures.
			continue
		}
		if resp.StatusCode() != StatusOK {
			t.Errorf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
		}
		body := string(resp.Body())
		if body != "OK" {
			t.Errorf("unexpected body: %q. Expecting %q", body, "OK")
		}
		// sleep for a while, so the connection to the host may expire.
		if i%5 == 0 {
			time.Sleep(30 * time.Millisecond)
		}
	}
	ReleaseRequest(req)
	ReleaseResponse(resp)
}
// TestPipelineClientDoDisableHeaderNamesNormalizing runs the header-name
// normalizing check through PipelineClient.Do.
func TestPipelineClientDoDisableHeaderNamesNormalizing(t *testing.T) {
	t.Parallel()
	const useDo time.Duration = 0
	testPipelineClientDisableHeaderNamesNormalizing(t, useDo)
}

// TestPipelineClientDoTimeoutDisableHeaderNamesNormalizing runs the same check
// through PipelineClient.DoTimeout with a one-second timeout.
func TestPipelineClientDoTimeoutDisableHeaderNamesNormalizing(t *testing.T) {
	t.Parallel()
	const perRequest = 1 * time.Second
	testPipelineClientDisableHeaderNamesNormalizing(t, perRequest)
}
// testPipelineClientDisableHeaderNamesNormalizing checks that, with header-name
// normalizing disabled on both server and PipelineClient, a response header
// set as "foo-BAR" is found only under that exact spelling. A positive timeout
// selects DoTimeout; otherwise Do is used.
func testPipelineClientDisableHeaderNamesNormalizing(t *testing.T, timeout time.Duration) {
	ln := fasthttputil.NewInmemoryListener()
	srv := &Server{
		Handler: func(ctx *RequestCtx) {
			ctx.Response.Header.Set("foo-BAR", "baz")
		},
		DisableHeaderNamesNormalizing: true,
	}

	serverStopped := make(chan struct{})
	go func() {
		defer close(serverStopped)
		if err := srv.Serve(ln); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
	}()

	pc := &PipelineClient{
		Dial: func(addr string) (net.Conn, error) {
			return ln.Dial()
		},
		DisableHeaderNamesNormalizing: true,
	}
	do := func(req *Request, resp *Response) error {
		if timeout > 0 {
			return pc.DoTimeout(req, resp, timeout)
		}
		return pc.Do(req, resp)
	}

	var (
		req  Request
		resp Response
	)
	req.SetRequestURI("http://aaaai.com/bsdf?sddfsd")
	for attempt := 0; attempt < 5; attempt++ {
		if err := do(&req, &resp); err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if got := resp.Header.Peek("foo-BAR"); string(got) != "baz" {
			t.Fatalf("unexpected header value: %q. Expecting %q", got, "baz")
		}
		if got := resp.Header.Peek("Foo-Bar"); len(got) > 0 {
			t.Fatalf("unexpected non-empty header value %q", got)
		}
	}

	if err := ln.Close(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	select {
	case <-serverStopped:
	case <-time.After(time.Second):
		t.Fatalf("timeout")
	}
}
// TestClientDoTimeoutDisableHeaderNamesNormalizing checks that, with header-name
// normalizing disabled on both sides, Client.DoTimeout exposes the response
// header "foo-BAR" only under its exact original spelling.
func TestClientDoTimeoutDisableHeaderNamesNormalizing(t *testing.T) {
	t.Parallel()

	ln := fasthttputil.NewInmemoryListener()
	srv := &Server{
		Handler: func(ctx *RequestCtx) {
			ctx.Response.Header.Set("foo-BAR", "baz")
		},
		DisableHeaderNamesNormalizing: true,
	}

	serverStopped := make(chan struct{})
	go func() {
		defer close(serverStopped)
		if err := srv.Serve(ln); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
	}()

	client := &Client{
		Dial: func(addr string) (net.Conn, error) {
			return ln.Dial()
		},
		DisableHeaderNamesNormalizing: true,
	}

	var (
		req  Request
		resp Response
	)
	req.SetRequestURI("http://aaaai.com/bsdf?sddfsd")
	for attempt := 0; attempt < 5; attempt++ {
		if err := client.DoTimeout(&req, &resp, time.Second); err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if got := resp.Header.Peek("foo-BAR"); string(got) != "baz" {
			t.Fatalf("unexpected header value: %q. Expecting %q", got, "baz")
		}
		if got := resp.Header.Peek("Foo-Bar"); len(got) > 0 {
			t.Fatalf("unexpected non-empty header value %q", got)
		}
	}

	if err := ln.Close(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	select {
	case <-serverStopped:
	case <-time.After(time.Second):
		t.Fatalf("timeout")
	}
}
// TestClientDoTimeoutDisablePathNormalizing checks that with
// DisablePathNormalizing set, percent-encoded path segments reach the server
// unchanged. The server reflects the full URI it received in the
// "received-uri" response header.
func TestClientDoTimeoutDisablePathNormalizing(t *testing.T) {
	t.Parallel()

	ln := fasthttputil.NewInmemoryListener()
	srv := &Server{
		Handler: func(ctx *RequestCtx) {
			u := ctx.URI()
			u.DisablePathNormalizing = true
			ctx.Response.Header.Set("received-uri", string(u.FullURI()))
		},
	}

	serverStopped := make(chan struct{})
	go func() {
		defer close(serverStopped)
		if err := srv.Serve(ln); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
	}()

	client := &Client{
		Dial: func(addr string) (net.Conn, error) {
			return ln.Dial()
		},
		DisablePathNormalizing: true,
	}

	const encodedURL = "http://example.com/encoded/Y%2BY%2FY%3D/stuff"
	var (
		req  Request
		resp Response
	)
	req.SetRequestURI(encodedURL)
	for attempt := 0; attempt < 5; attempt++ {
		if err := client.DoTimeout(&req, &resp, time.Second); err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if got := resp.Header.Peek("received-uri"); string(got) != encodedURL {
			t.Fatalf("request uri was normalized: %q. Expecting %q", got, encodedURL)
		}
	}

	if err := ln.Close(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	select {
	case <-serverStopped:
	case <-time.After(time.Second):
		t.Fatalf("timeout")
	}
}
// TestHostClientPendingRequests checks HostClient.PendingRequests. It must be
// zero before any request, equal the number of in-flight requests while every
// server handler is blocked, and return to zero once they all complete.
func TestHostClientPendingRequests(t *testing.T) {
	t.Parallel()

	const concurrency = 10
	release := make(chan struct{})
	arrived := make(chan struct{}, concurrency)
	srv := &Server{
		Handler: func(_ *RequestCtx) {
			arrived <- struct{}{}
			<-release
		},
	}

	ln := fasthttputil.NewInmemoryListener()
	serverStopped := make(chan struct{})
	go func() {
		defer close(serverStopped)
		if err := srv.Serve(ln); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
	}()

	hc := &HostClient{
		Addr: "foobar",
		Dial: func(addr string) (net.Conn, error) {
			return ln.Dial()
		},
	}

	if n := hc.PendingRequests(); n != 0 {
		t.Fatalf("non-zero pendingRequests: %d", n)
	}

	// issue performs one request and reports any failure as an error value.
	issue := func() error {
		req := AcquireRequest()
		req.SetRequestURI("http://foobar/baz")
		resp := AcquireResponse()
		if err := hc.DoTimeout(req, resp, 10*time.Second); err != nil {
			return fmt.Errorf("unexpected error: %w", err)
		}
		if code := resp.StatusCode(); code != StatusOK {
			return fmt.Errorf("unexpected status code %d. Expecting %d", code, StatusOK)
		}
		return nil
	}

	results := make(chan error, concurrency)
	for k := 0; k < concurrency; k++ {
		go func() { results <- issue() }()
	}

	// Block until every request is parked inside a handler.
	for k := 0; k < concurrency; k++ {
		select {
		case <-arrived:
		case <-time.After(time.Second):
			t.Fatalf("timeout")
		}
	}
	if n := hc.PendingRequests(); n != concurrency {
		t.Fatalf("unexpected pendingRequests: %d. Expecting %d", n, concurrency)
	}

	// Let the handlers return, then collect every result.
	close(release)
	for k := 0; k < concurrency; k++ {
		select {
		case err := <-results:
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
		case <-time.After(time.Second):
			t.Fatalf("timeout")
		}
	}
	if n := hc.PendingRequests(); n != 0 {
		t.Fatalf("non-zero pendingRequests: %d", n)
	}

	// Shut the server down.
	if err := ln.Close(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	select {
	case <-serverStopped:
	case <-time.After(time.Second):
		t.Fatalf("timeout")
	}
}
// TestHostClientMaxConnsWithDeadline sends 5 concurrent POSTs through a
// HostClient limited to a single connection. Callers retry on ErrNoFreeConns
// until a connection frees up. Every request must reach the server with a
// non-empty body and receive "foo" back.
func TestHostClientMaxConnsWithDeadline(t *testing.T) {
	t.Parallel()
	var (
		// Incremented from server handler goroutines, so it is accessed atomically.
		emptyBodyCount uint32
		ln             = fasthttputil.NewInmemoryListener()
		timeout        = 200 * time.Millisecond
		wg             sync.WaitGroup
	)
	s := &Server{
		Handler: func(ctx *RequestCtx) {
			if len(ctx.PostBody()) == 0 {
				atomic.AddUint32(&emptyBodyCount, 1)
			}
			ctx.WriteString("foo") //nolint:errcheck
		},
	}
	serverStopCh := make(chan struct{})
	go func() {
		if err := s.Serve(ln); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		close(serverStopCh)
	}()
	c := &HostClient{
		Addr: "foobar",
		Dial: func(addr string) (net.Conn, error) {
			return ln.Dial()
		},
		MaxConns: 1,
	}
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			req := AcquireRequest()
			req.SetRequestURI("http://foobar/baz")
			req.Header.SetMethod(MethodPost)
			req.SetBodyString("bar")
			resp := AcquireResponse()
			for {
				err := c.DoDeadline(req, resp, time.Now().Add(timeout))
				if err == ErrNoFreeConns {
					time.Sleep(time.Millisecond)
					continue
				}
				if err != nil {
					t.Errorf("unexpected error: %v", err)
					// No valid response to inspect.
					return
				}
				break
			}
			if resp.StatusCode() != StatusOK {
				t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK)
			}
			body := resp.Body()
			if string(body) != "foo" {
				t.Errorf("unexpected body %q. Expecting %q", body, "foo")
			}
		}()
	}
	wg.Wait()
	if err := ln.Close(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	select {
	case <-serverStopCh:
	case <-time.After(time.Second):
		t.Fatalf("timeout")
	}
	if atomic.LoadUint32(&emptyBodyCount) > 0 {
		t.Fatalf("at least one request body was empty")
	}
}
// TestHostClientMaxConnDuration checks that once a connection is older than
// MaxConnDuration, the HostClient marks a request "Connection: close". The
// server counts requests carrying that header, and at least one is expected.
func TestHostClientMaxConnDuration(t *testing.T) {
	t.Parallel()
	ln := fasthttputil.NewInmemoryListener()
	// Written from server handler goroutines; always accessed atomically.
	connectionCloseCount := uint32(0)
	s := &Server{
		Handler: func(ctx *RequestCtx) {
			ctx.WriteString("abcd") //nolint:errcheck
			if ctx.Request.ConnectionClose() {
				atomic.AddUint32(&connectionCloseCount, 1)
			}
		},
	}
	serverStopCh := make(chan struct{})
	go func() {
		if err := s.Serve(ln); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		close(serverStopCh)
	}()
	c := &HostClient{
		Addr: "foobar",
		Dial: func(addr string) (net.Conn, error) {
			return ln.Dial()
		},
		MaxConnDuration: 10 * time.Millisecond,
	}
	for i := 0; i < 5; i++ {
		statusCode, body, err := c.Get(nil, "http://aaaa.com/bbb/cc")
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if statusCode != StatusOK {
			t.Fatalf("unexpected status code %d. Expecting %d", statusCode, StatusOK)
		}
		if string(body) != "abcd" {
			t.Fatalf("unexpected body %q. Expecting %q", body, "abcd")
		}
		time.Sleep(c.MaxConnDuration)
	}
	if err := ln.Close(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	select {
	case <-serverStopCh:
	case <-time.After(time.Second):
		t.Fatalf("timeout")
	}
	if atomic.LoadUint32(&connectionCloseCount) == 0 {
		t.Fatalf("expecting at least one 'Connection: close' request header")
	}
}
// TestHostClientMultipleAddrs checks that a HostClient whose Addr lists several
// comma-separated addresses spreads dials evenly across them. The server closes
// every connection, so each of the 9 requests dials anew, and each of
// "foo", "bar" and "baz" is expected to be dialed exactly 3 times.
func TestHostClientMultipleAddrs(t *testing.T) {
	t.Parallel()
	ln := fasthttputil.NewInmemoryListener()
	s := &Server{
		Handler: func(ctx *RequestCtx) {
			ctx.Write(ctx.Host()) //nolint:errcheck
			ctx.SetConnectionClose()
		},
	}
	serverStopCh := make(chan struct{})
	go func() {
		if err := s.Serve(ln); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		close(serverStopCh)
	}()
	// Requests below are issued sequentially, so Dial is never called
	// concurrently and a plain map suffices.
	dialsCount := make(map[string]int)
	c := &HostClient{
		Addr: "foo,bar,baz",
		Dial: func(addr string) (net.Conn, error) {
			dialsCount[addr]++
			return ln.Dial()
		},
	}
	for i := 0; i < 9; i++ {
		statusCode, body, err := c.Get(nil, "http://foobar/baz/aaa?bbb=ddd")
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if statusCode != StatusOK {
			t.Fatalf("unexpected status code %d. Expecting %d", statusCode, StatusOK)
		}
		if string(body) != "foobar" {
			t.Fatalf("unexpected body %q. Expecting %q", body, "foobar")
		}
	}
	if err := ln.Close(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	select {
	case <-serverStopCh:
	case <-time.After(time.Second):
		t.Fatalf("timeout")
	}
	if len(dialsCount) != 3 {
		t.Fatalf("unexpected dialsCount size %d. Expecting 3", len(dialsCount))
	}
	for _, k := range []string{"foo", "bar", "baz"} {
		if got := dialsCount[k]; got != 3 {
			t.Fatalf("unexpected dialsCount for %q: %d. Expecting 3", k, got)
		}
	}
}
// TestClientFollowRedirects exercises redirect handling on a HostClient:
//   - Get/GetTimeout follow the /foo -> /xy?z=wer -> /bar chain,
//   - a non-redirecting path is returned as-is,
//   - DoRedirects follows the same chain with and without Request.SetTimeout,
//   - DoRedirects reports ErrTimeout when the connection never answers,
//   - path normalizing can be disabled for both the request and the redirect,
//   - DoRedirects with a redirect limit of 0 fails with ErrTooManyRedirects.
func TestClientFollowRedirects(t *testing.T) {
	t.Parallel()

	s := &Server{
		Handler: func(ctx *RequestCtx) {
			switch string(ctx.Path()) {
			case "/foo":
				u := ctx.URI()
				u.Update("/xy?z=wer")
				ctx.Redirect(u.String(), StatusFound)
			case "/xy":
				u := ctx.URI()
				u.Update("/bar")
				ctx.Redirect(u.String(), StatusFound)
			case "/abc/*/123":
				u := ctx.URI()
				u.Update("/xyz/*/456")
				ctx.Redirect(u.String(), StatusFound)
			default:
				// Any other path is echoed back as the body.
				ctx.Success("text/plain", ctx.Path())
			}
		},
	}
	ln := fasthttputil.NewInmemoryListener()
	serverStopCh := make(chan struct{})
	go func() {
		if err := s.Serve(ln); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		close(serverStopCh)
	}()

	c := &HostClient{
		Addr: "xxx",
		Dial: func(addr string) (net.Conn, error) {
			return ln.Dial()
		},
	}

	for i := 0; i < 10; i++ {
		statusCode, body, err := c.GetTimeout(nil, "http://xxx/foo", time.Second)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if statusCode != StatusOK {
			t.Fatalf("unexpected status code: %d", statusCode)
		}
		if string(body) != "/bar" {
			t.Fatalf("unexpected response %q. Expecting %q", body, "/bar")
		}
	}

	for i := 0; i < 10; i++ {
		statusCode, body, err := c.Get(nil, "http://xxx/aaab/sss")
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if statusCode != StatusOK {
			t.Fatalf("unexpected status code: %d", statusCode)
		}
		if string(body) != "/aaab/sss" {
			t.Fatalf("unexpected response %q. Expecting %q", body, "/aaab/sss")
		}
	}

	for i := 0; i < 10; i++ {
		req := AcquireRequest()
		resp := AcquireResponse()
		req.SetRequestURI("http://xxx/foo")
		err := c.DoRedirects(req, resp, 16)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if statusCode := resp.StatusCode(); statusCode != StatusOK {
			t.Fatalf("unexpected status code: %d", statusCode)
		}
		if body := string(resp.Body()); body != "/bar" {
			t.Fatalf("unexpected response %q. Expecting %q", body, "/bar")
		}
		ReleaseRequest(req)
		ReleaseResponse(resp)
	}

	for i := 0; i < 10; i++ {
		req := AcquireRequest()
		resp := AcquireResponse()
		req.SetRequestURI("http://xxx/foo")
		req.SetTimeout(time.Second)
		err := c.DoRedirects(req, resp, 16)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if statusCode := resp.StatusCode(); statusCode != StatusOK {
			t.Fatalf("unexpected status code: %d", statusCode)
		}
		if body := string(resp.Body()); body != "/bar" {
			t.Fatalf("unexpected response %q. Expecting %q", body, "/bar")
		}
		ReleaseRequest(req)
		ReleaseResponse(resp)
	}

	// A 1ms request timeout against a readTimeoutConn (whose Read/Write block
	// until the deadline) must surface as ErrTimeout. readTimeoutConn overrides
	// Read, Write, Close, the address getters and both deadline setters, so
	// no underlying connection is embedded. ln is an in-memory listener, so
	// dialing its address over TCP could not provide one anyway.
	for i := 0; i < 10; i++ {
		req := AcquireRequest()
		resp := AcquireResponse()
		req.SetRequestURI("http://xxx/foo")
		timeoutClient := &Client{
			Dial: func(addr string) (net.Conn, error) {
				return &readTimeoutConn{t: time.Second}, nil
			},
		}
		req.SetTimeout(time.Millisecond)
		// Report a failure exactly once, whichever way it goes wrong.
		switch err := timeoutClient.DoRedirects(req, resp, 16); {
		case err == nil:
			t.Errorf("expecting error")
		case err != ErrTimeout:
			t.Errorf("unexpected error: %v. Expecting %v", err, ErrTimeout)
		}
		ReleaseRequest(req)
		ReleaseResponse(resp)
	}

	// With normalizing disabled, "*" path segments survive both the request
	// and the redirect Location.
	for i := 0; i < 10; i++ {
		req := AcquireRequest()
		resp := AcquireResponse()
		req.SetRequestURI("http://xxx/abc/*/123")
		req.URI().DisablePathNormalizing = true
		req.DisableRedirectPathNormalizing = true
		err := c.DoRedirects(req, resp, 16)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if statusCode := resp.StatusCode(); statusCode != StatusOK {
			t.Fatalf("unexpected status code: %d", statusCode)
		}
		if body := string(resp.Body()); body != "/xyz/*/456" {
			t.Fatalf("unexpected response %q. Expecting %q", body, "/xyz/*/456")
		}
		ReleaseRequest(req)
		ReleaseResponse(resp)
	}

	req := AcquireRequest()
	resp := AcquireResponse()
	req.SetRequestURI("http://xxx/foo")
	err := c.DoRedirects(req, resp, 0)
	if have, want := err, ErrTooManyRedirects; have != want {
		t.Fatalf("want error: %v, have %v", want, have)
	}
	ReleaseRequest(req)
	ReleaseResponse(resp)

	// Shut the server down so its goroutine does not outlive the test.
	if err := ln.Close(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	select {
	case <-serverStopCh:
	case <-time.After(time.Second):
		t.Fatalf("timeout")
	}
}
// TestClientGetTimeoutSuccess performs 100 GetTimeout round trips against a
// local TCP echo server using the shared defaultClient.
func TestClientGetTimeoutSuccess(t *testing.T) {
	t.Parallel()

	srv := startEchoServer(t, "tcp", "127.0.0.1:")
	defer srv.Stop()

	baseURL := "http://" + srv.Addr()
	testClientGetTimeoutSuccess(t, &defaultClient, baseURL, 100)
}
// TestClientGetTimeoutSuccessConcurrent runs the GetTimeout success scenario
// from 10 goroutines that share defaultClient and one echo server.
func TestClientGetTimeoutSuccessConcurrent(t *testing.T) {
	t.Parallel()

	srv := startEchoServer(t, "tcp", "127.0.0.1:")
	defer srv.Stop()

	baseURL := "http://" + srv.Addr()
	const workers = 10
	var wg sync.WaitGroup
	wg.Add(workers)
	for w := 0; w < workers; w++ {
		go func() {
			defer wg.Done()
			testClientGetTimeoutSuccess(t, &defaultClient, baseURL, 100)
		}()
	}
	wg.Wait()
}
// TestClientDoTimeoutSuccess checks both per-call timeouts (DoTimeout) and
// per-request timeouts (Request.SetTimeout + Do) against a local echo server.
func TestClientDoTimeoutSuccess(t *testing.T) {
	t.Parallel()

	srv := startEchoServer(t, "tcp", "127.0.0.1:")
	defer srv.Stop()

	baseURL := "http://" + srv.Addr()
	testClientDoTimeoutSuccess(t, &defaultClient, baseURL, 100)
	testClientRequestSetTimeoutSuccess(t, &defaultClient, baseURL, 100)
}
// TestClientDoTimeoutSuccessConcurrent runs both timeout success helpers from
// 10 goroutines that share defaultClient and one echo server.
func TestClientDoTimeoutSuccessConcurrent(t *testing.T) {
	t.Parallel()

	srv := startEchoServer(t, "tcp", "127.0.0.1:")
	defer srv.Stop()

	baseURL := "http://" + srv.Addr()
	const workers = 10
	var wg sync.WaitGroup
	wg.Add(workers)
	for w := 0; w < workers; w++ {
		go func() {
			defer wg.Done()
			testClientDoTimeoutSuccess(t, &defaultClient, baseURL, 100)
			testClientRequestSetTimeoutSuccess(t, &defaultClient, baseURL, 100)
		}()
	}
	wg.Wait()
}
// TestClientGetTimeoutError checks that GetTimeout reports ErrTimeout when
// every dialed connection is a readTimeoutConn, whose Read/Write block until
// the deadline expires.
func TestClientGetTimeoutError(t *testing.T) {
	t.Parallel()

	s := startEchoServer(t, "tcp", "127.0.0.1:")
	defer s.Stop()

	// A real connection embedded in readTimeoutConn. It was previously dialed
	// with the error ignored and never closed.
	testConn, err := net.Dial("tcp", s.ln.Addr().String())
	if err != nil {
		t.Fatalf("cannot dial echo server: %v", err)
	}
	defer testConn.Close()

	c := &Client{
		Dial: func(addr string) (net.Conn, error) {
			return &readTimeoutConn{Conn: testConn, t: time.Second}, nil
		},
	}
	testClientGetTimeoutError(t, c, 100)
}
// TestClientGetTimeoutErrorConcurrent runs the GetTimeout error scenario from
// 10 goroutines sharing one Client whose dials all produce readTimeoutConn.
func TestClientGetTimeoutErrorConcurrent(t *testing.T) {
	t.Parallel()

	s := startEchoServer(t, "tcp", "127.0.0.1:")
	defer s.Stop()

	// A real connection embedded in every readTimeoutConn; closed once all
	// workers have finished (the deferred Close runs after wg.Wait).
	testConn, err := net.Dial("tcp", s.ln.Addr().String())
	if err != nil {
		t.Fatalf("cannot dial echo server: %v", err)
	}
	defer testConn.Close()

	c := &Client{
		Dial: func(addr string) (net.Conn, error) {
			return &readTimeoutConn{Conn: testConn, t: time.Second}, nil
		},
		// Presumably large so concurrent workers never hit the per-host limit.
		MaxConnsPerHost: 1000,
	}

	var wg sync.WaitGroup
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			testClientGetTimeoutError(t, c, 100)
		}()
	}
	wg.Wait()
}
// TestClientDoTimeoutError checks that both DoTimeout and
// Request.SetTimeout+Do report ErrTimeout when every dialed connection is a
// readTimeoutConn that never answers.
func TestClientDoTimeoutError(t *testing.T) {
	t.Parallel()

	s := startEchoServer(t, "tcp", "127.0.0.1:")
	defer s.Stop()

	// A real connection embedded in readTimeoutConn; closed at test end
	// instead of being leaked.
	testConn, err := net.Dial("tcp", s.ln.Addr().String())
	if err != nil {
		t.Fatalf("cannot dial echo server: %v", err)
	}
	defer testConn.Close()

	c := &Client{
		Dial: func(addr string) (net.Conn, error) {
			return &readTimeoutConn{Conn: testConn, t: time.Second}, nil
		},
	}
	testClientDoTimeoutError(t, c, 100)
	testClientRequestSetTimeoutError(t, c, 100)
}
// TestClientDoTimeoutErrorConcurrent runs the DoTimeout error scenario from
// 10 goroutines sharing one Client whose dials all produce readTimeoutConn.
func TestClientDoTimeoutErrorConcurrent(t *testing.T) {
	t.Parallel()

	s := startEchoServer(t, "tcp", "127.0.0.1:")
	defer s.Stop()

	// A real connection embedded in every readTimeoutConn; closed once all
	// workers have finished (the deferred Close runs after wg.Wait).
	testConn, err := net.Dial("tcp", s.ln.Addr().String())
	if err != nil {
		t.Fatalf("cannot dial echo server: %v", err)
	}
	defer testConn.Close()

	c := &Client{
		Dial: func(addr string) (net.Conn, error) {
			return &readTimeoutConn{Conn: testConn, t: time.Second}, nil
		},
		// Presumably large so concurrent workers never hit the per-host limit.
		MaxConnsPerHost: 1000,
	}

	var wg sync.WaitGroup
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			testClientDoTimeoutError(t, c, 100)
		}()
	}
	wg.Wait()
}
// testClientDoTimeoutError performs n DoTimeout calls with a 1ms timeout and
// expects each to fail with ErrTimeout. Each failing iteration is reported
// exactly once: a nil error no longer also triggers the "unexpected error"
// message.
func testClientDoTimeoutError(t *testing.T, c *Client, n int) {
	var req Request
	var resp Response
	req.SetRequestURI("http://foobar.com/baz")
	for i := 0; i < n; i++ {
		switch err := c.DoTimeout(&req, &resp, time.Millisecond); {
		case err == nil:
			t.Errorf("expecting error")
		case err != ErrTimeout:
			t.Errorf("unexpected error: %v. Expecting %v", err, ErrTimeout)
		}
	}
}
// testClientGetTimeoutError performs n GetTimeout calls with a 1ms timeout and
// expects each to fail with ErrTimeout, a zero status code and a non-nil body
// (the 10-byte buf is passed as dst). A nil error is reported once rather
// than twice.
func testClientGetTimeoutError(t *testing.T, c *Client, n int) {
	buf := make([]byte, 10)
	for i := 0; i < n; i++ {
		statusCode, body, err := c.GetTimeout(buf, "http://foobar.com/baz", time.Millisecond)
		switch {
		case err == nil:
			t.Errorf("expecting error")
		case err != ErrTimeout:
			t.Errorf("unexpected error: %v. Expecting %v", err, ErrTimeout)
		}
		if statusCode != 0 {
			t.Errorf("unexpected statusCode=%d. Expecting %d", statusCode, 0)
		}
		if body == nil {
			t.Errorf("body must be non-nil")
		}
	}
}
// testClientRequestSetTimeoutError performs n Do calls on a request whose own
// timeout (Request.SetTimeout) is 1ms and expects each to fail with
// ErrTimeout. Each failing iteration is reported exactly once.
func testClientRequestSetTimeoutError(t *testing.T, c *Client, n int) {
	var req Request
	var resp Response
	req.SetRequestURI("http://foobar.com/baz")
	for i := 0; i < n; i++ {
		req.SetTimeout(time.Millisecond)
		switch err := c.Do(&req, &resp); {
		case err == nil:
			t.Errorf("expecting error")
		case err != ErrTimeout:
			t.Errorf("unexpected error: %v. Expecting %v", err, ErrTimeout)
		}
	}
}
// readTimeoutConn is a net.Conn stub that never transfers data. Read and Write
// block until the most recently armed read/write deadline expires, then fail
// with os.ErrDeadlineExceeded. It simulates a peer that never answers. None of
// the methods below use the embedded net.Conn.
type readTimeoutConn struct {
	net.Conn
	t  time.Duration // NOTE(review): not read by any method of this type.
	wc chan struct{} // receives a value once the current write deadline expires
	rc chan struct{} // receives a value once the current read deadline expires
}

// Read waits for the deadline armed by SetReadDeadline and then returns
// os.ErrDeadlineExceeded. If no read deadline was ever set, r.rc is nil and
// Read blocks forever.
func (r *readTimeoutConn) Read(p []byte) (int, error) {
	<-r.rc
	return 0, os.ErrDeadlineExceeded
}

// Write waits for the deadline armed by SetWriteDeadline and then returns
// os.ErrDeadlineExceeded without writing anything. If no write deadline was
// ever set, it blocks forever.
func (r *readTimeoutConn) Write(p []byte) (int, error) {
	<-r.wc
	return 0, os.ErrDeadlineExceeded
}

// Close is a no-op.
func (r *readTimeoutConn) Close() error {
	return nil
}

// LocalAddr returns nil.
func (r *readTimeoutConn) LocalAddr() net.Addr {
	return nil
}

// RemoteAddr returns nil.
func (r *readTimeoutConn) RemoteAddr() net.Addr {
	return nil
}

// SetReadDeadline arms a one-shot signal that unblocks Read at d. A zero or
// past d fires immediately.
func (r *readTimeoutConn) SetReadDeadline(d time.Time) error {
	// Capture the channel locally. The timer must signal the channel created
	// by this call, not whatever r.rc holds when it fires. Reading the field
	// from the timer goroutine would also race with later calls.
	ch := make(chan struct{}, 1)
	r.rc = ch
	time.AfterFunc(time.Until(d), func() { ch <- struct{}{} })
	return nil
}

// SetWriteDeadline arms a one-shot signal that unblocks Write at d. A zero or
// past d fires immediately.
func (r *readTimeoutConn) SetWriteDeadline(d time.Time) error {
	// See SetReadDeadline for why the channel is captured locally.
	ch := make(chan struct{}, 1)
	r.wc = ch
	time.AfterFunc(time.Until(d), func() { ch <- struct{}{} })
	return nil
}
// TestClientNonIdempotentRetry checks that both a POST and a GET succeed when
// the first two dialed connections fail on Read (readErrorConn) and the third
// returns a valid canned response.
func TestClientNonIdempotentRetry(t *testing.T) {
	t.Parallel()

	dialsCount := 0
	c := &Client{
		Dial: func(_ string) (net.Conn, error) {
			dialsCount++
			if dialsCount <= 2 {
				return &readErrorConn{}, nil
			}
			if dialsCount == 3 {
				return &singleReadConn{
					s: "HTTP/1.1 345 OK\r\nContent-Type: foobar\r\nContent-Length: 7\r\n\r\n0123456",
				}, nil
			}
			t.Fatalf("unexpected number of dials: %d", dialsCount)
			panic("unreachable")
		},
	}

	// expectOK asserts the canned 345/"0123456" response came back.
	expectOK := func(statusCode int, body []byte, err error) {
		t.Helper()
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if statusCode != 345 {
			t.Fatalf("unexpected status code: %d. Expecting 345", statusCode)
		}
		if string(body) != "0123456" {
			t.Fatalf("unexpected body: %q. Expecting %q", body, "0123456")
		}
	}

	// readErrorConn fails before any response byte arrives, so even this
	// non-idempotent POST must be retried and succeed.
	expectOK(c.Post(nil, "http://foobar/a/b", nil))

	// The idempotent GET must succeed too.
	dialsCount = 0
	expectOK(c.Get(nil, "http://foobar/a/b"))
}
// TestClientNonIdempotentRetry_BodyStream checks that a POST whose body is an
// io.Reader stream fails when its first connection errors on Read. Per the
// assertion message, the stream cannot be replayed for a retry.
func TestClientNonIdempotentRetry_BodyStream(t *testing.T) {
	t.Parallel()

	dialsCount := 0
	c := &Client{
		Dial: func(_ string) (net.Conn, error) {
			dialsCount++
			if dialsCount <= 2 {
				return &readErrorConn{}, nil
			}
			if dialsCount == 3 {
				return &singleEchoConn{
					b: []byte("HTTP/1.1 345 OK\r\nContent-Type: foobar\r\n\r\n"),
				}, nil
			}
			t.Fatalf("unexpected number of dials: %d", dialsCount)
			panic("unreachable")
		},
	}

	var (
		req Request
		res Response
	)
	req.SetRequestURI("http://foobar/a/b")
	req.Header.SetMethod("POST")
	stream := bytes.NewBufferString("test")
	req.SetBodyStream(stream, stream.Len())

	if err := c.Do(&req, &res); err == nil {
		t.Fatal("expected error from being unable to retry a bodyStream")
	}
}
func TestClientIdempotentRequest(t *testing.T) {
t.Parallel()
dialsCount := 0
c := &Client{
Dial: func(_ string) (net.Conn, error) {
dialsCount++
switch dialsCount {
case 1:
return &singleReadConn{
s: "invalid response",
}, nil
case 2:
return &writeErrorConn{}, nil
case 3:
return &readErrorConn{}, nil
case 4:
return &singleReadConn{
s: "HTTP/1.1 345 OK\r\nContent-Type: foobar\r\nContent-Length: 7\r\n\r\n0123456",
}, nil
default:
t.Fatalf("unexpected number of dials: %d", dialsCount)
}
panic("unreachable")
},
}
// idempotent GET must succeed.
statusCode, body, err := c.Get(nil, "http://foobar/a/b")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if statusCode != 345 {
t.Fatalf("unexpected status code: %d. Expecting 345", statusCode)
}
if string(body) != "0123456" {
t.Fatalf("unexpected body: %q. Expecting %q", body, "0123456")
}
var args Args
// non-idempotent POST must fail on incorrect singleReadConn
dialsCount = 0
_, _, err = c.Post(nil, "http://foobar/a/b", &args)
if err == nil {
t.Fatalf("expecting error")
}
// non-idempotent POST must fail on incorrect singleReadConn
dialsCount = 0
_, _, err = c.Post(nil, "http://foobar/a/b", nil)
if err == nil {
t.Fatalf("expecting error")
}
}
// TestClientRetryRequestWithCustomDecider checks that Client.RetryIf controls
// retries of a non-idempotent POST. Requests to exactly "http://foobar/a/b"
// are retried through three failing conns until the valid one answers. A
// request to any other URI is not retried and fails.
func TestClientRetryRequestWithCustomDecider(t *testing.T) {
	t.Parallel()

	dialsCount := 0
	c := &Client{
		Dial: func(_ string) (net.Conn, error) {
			dialsCount++
			if dialsCount == 1 {
				return &singleReadConn{s: "invalid response"}, nil
			}
			if dialsCount == 2 {
				return &writeErrorConn{}, nil
			}
			if dialsCount == 3 {
				return &readErrorConn{}, nil
			}
			if dialsCount == 4 {
				return &singleReadConn{
					s: "HTTP/1.1 345 OK\r\nContent-Type: foobar\r\nContent-Length: 7\r\n\r\n0123456",
				}, nil
			}
			t.Fatalf("unexpected number of dials: %d", dialsCount)
			panic("unreachable")
		},
		RetryIf: func(req *Request) bool {
			return req.URI().String() == "http://foobar/a/b"
		},
	}

	var args Args

	// The decider allows retries for this exact URI, so the POST succeeds.
	statusCode, body, err := c.Post(nil, "http://foobar/a/b", &args)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if statusCode != 345 {
		t.Fatalf("unexpected status code: %d. Expecting 345", statusCode)
	}
	if string(body) != "0123456" {
		t.Fatalf("unexpected body: %q. Expecting %q", body, "0123456")
	}

	// The decider rejects a retry for a different URI, so the POST fails.
	dialsCount = 0
	if _, _, err = c.Post(nil, "http://foobar/a/b/c", &args); err == nil {
		t.Fatalf("expecting error")
	}
}
// TransportDemo is a minimal RoundTripper that serializes each request to bw
// and parses the reply from br. Both are buffered views of one connection.
type TransportDemo struct {
	br *bufio.Reader
	bw *bufio.Writer
}

// RoundTrip writes and flushes req, then reads the reply into res. A write or
// flush failure is returned as non-retryable. A read failure is returned with
// retry=true. hc is unused.
func (t TransportDemo) RoundTrip(hc *HostClient, req *Request, res *Response) (retry bool, err error) {
	if werr := req.Write(t.bw); werr != nil {
		return false, werr
	}
	if ferr := t.bw.Flush(); ferr != nil {
		return false, ferr
	}
	if rerr := res.Read(t.br); rerr != nil {
		return true, rerr
	}
	return false, nil
}
// TestHostClientTransport verifies that a HostClient with a custom Transport
// routes its requests through that RoundTripper. TransportDemo here wraps a
// single in-memory connection that is reused for all 5 requests.
func TestHostClientTransport(t *testing.T) {
	t.Parallel()

	ln := fasthttputil.NewInmemoryListener()
	s := &Server{
		Handler: func(ctx *RequestCtx) {
			ctx.WriteString("abcd") //nolint:errcheck
		},
	}
	serverStopCh := make(chan struct{})
	go func() {
		if err := s.Serve(ln); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		close(serverStopCh)
	}()

	// The dial error used to be ignored, which would surface later as a
	// confusing nil-conn failure inside the transport.
	conn, err := ln.Dial()
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer conn.Close()

	c := &HostClient{
		Addr: "foobar",
		Transport: TransportDemo{
			br: bufio.NewReader(conn),
			bw: bufio.NewWriter(conn),
		},
	}
	for i := 0; i < 5; i++ {
		statusCode, body, err := c.Get(nil, "http://aaaa.com/bbb/cc")
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if statusCode != StatusOK {
			t.Fatalf("unexpected status code %d. Expecting %d", statusCode, StatusOK)
		}
		if string(body) != "abcd" {
			t.Fatalf("unexpected body %q. Expecting %q", body, "abcd")
		}
	}

	if err := ln.Close(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	select {
	case <-serverStopCh:
	case <-time.After(time.Second):
		t.Fatalf("timeout")
	}
}
// writeErrorConn is a net.Conn stub whose Write always fails after reporting
// one byte written. Read is not overridden: it falls through to the embedded
// net.Conn, which is nil when constructed as &writeErrorConn{}.
type writeErrorConn struct {
	net.Conn
}

// Write reports a one-byte partial write together with an error.
func (w *writeErrorConn) Write(p []byte) (int, error) {
	return 1, fmt.Errorf("error")
}

// Close is a no-op.
func (w *writeErrorConn) Close() error {
	return nil
}

// LocalAddr returns nil.
func (w *writeErrorConn) LocalAddr() net.Addr {
	return nil
}

// RemoteAddr returns nil.
func (w *writeErrorConn) RemoteAddr() net.Addr {
	return nil
}

// SetReadDeadline ignores the deadline. The receiver is named w, like every
// other method of this type (it was previously r).
func (w *writeErrorConn) SetReadDeadline(_ time.Time) error {
	return nil
}

// SetWriteDeadline ignores the deadline.
func (w *writeErrorConn) SetWriteDeadline(_ time.Time) error {
	return nil
}
// readErrorConn is a net.Conn stub that accepts every Write but fails every
// Read, i.e. a peer that drops the connection before sending a response.
type readErrorConn struct {
	net.Conn
}

// Read always fails without returning data.
func (r *readErrorConn) Read(_ []byte) (int, error) {
	return 0, fmt.Errorf("error")
}

// Write pretends to accept all of p.
func (r *readErrorConn) Write(p []byte) (n int, err error) {
	n = len(p)
	return n, nil
}

// Close is a no-op.
func (r *readErrorConn) Close() error {
	return nil
}

// LocalAddr returns nil.
func (r *readErrorConn) LocalAddr() net.Addr {
	return nil
}

// RemoteAddr returns nil.
func (r *readErrorConn) RemoteAddr() net.Addr {
	return nil
}

// SetReadDeadline ignores the deadline.
func (r *readErrorConn) SetReadDeadline(_ time.Time) error {
	return nil
}

// SetWriteDeadline ignores the deadline.
func (r *readErrorConn) SetWriteDeadline(_ time.Time) error {
	return nil
}
// singleReadConn is a net.Conn stub that serves the fixed string s to Read,
// then reports io.EOF. Writes are accepted and discarded.
type singleReadConn struct {
	net.Conn
	s string // canned data returned by Read
	n int    // number of bytes of s already returned
}

// Read copies the unread remainder of s into p, or returns io.EOF once s has
// been fully consumed.
func (r *singleReadConn) Read(p []byte) (int, error) {
	if len(r.s) == r.n {
		return 0, io.EOF
	}
	// copy accepts a string source directly. Converting to []byte first would
	// allocate a throwaway copy of the remaining data on every call.
	n := copy(p, r.s[r.n:])
	r.n += n
	return n, nil
}

// Write pretends to accept all of p.
func (r *singleReadConn) Write(p []byte) (int, error) {
	return len(p), nil
}

// Close is a no-op.
func (r *singleReadConn) Close() error {
	return nil
}

// LocalAddr returns nil.
func (r *singleReadConn) LocalAddr() net.Addr {
	return nil
}

// RemoteAddr returns nil.
func (r *singleReadConn) RemoteAddr() net.Addr {
	return nil
}

// SetReadDeadline ignores the deadline.
func (r *singleReadConn) SetReadDeadline(_ time.Time) error {
	return nil
}

// SetWriteDeadline ignores the deadline.
func (r *singleReadConn) SetWriteDeadline(_ time.Time) error {
	return nil
}
// singleEchoConn is an in-memory net.Conn stub. Read serves the bytes in b,
// and every Write appends to b, so anything written is echoed to later reads
// after whatever b initially held. Read returns io.EOF once b is drained.
type singleEchoConn struct {
	net.Conn
	b []byte // initial content followed by all written data
	n int    // number of bytes of b already returned by Read
}

// Read copies the unread part of b into p, or returns io.EOF if none is left.
func (r *singleEchoConn) Read(p []byte) (int, error) {
	unread := r.b[r.n:]
	if len(unread) == 0 {
		return 0, io.EOF
	}
	copied := copy(p, unread)
	r.n += copied
	return copied, nil
}

// Write appends p to the buffer served by Read. It always succeeds.
func (r *singleEchoConn) Write(p []byte) (int, error) {
	r.b = append(r.b, p...)
	return len(p), nil
}

// Close is a no-op.
func (r *singleEchoConn) Close() error {
	return nil
}

// LocalAddr returns nil.
func (r *singleEchoConn) LocalAddr() net.Addr {
	return nil
}

// RemoteAddr returns nil.
func (r *singleEchoConn) RemoteAddr() net.Addr {
	return nil
}

// SetReadDeadline ignores the deadline.
func (r *singleEchoConn) SetReadDeadline(_ time.Time) error {
	return nil
}

// SetWriteDeadline ignores the deadline.
func (r *singleEchoConn) SetWriteDeadline(_ time.Time) error {
	return nil
}
// TestSingleEchoConn sends a streamed POST through a singleEchoConn. The
// response body must be the serialized request: the canned status line and
// headers carry no Content-Length, so everything written follows as body.
func TestSingleEchoConn(t *testing.T) {
	t.Parallel()

	c := &Client{
		Dial: func(addr string) (net.Conn, error) {
			conn := &singleEchoConn{
				b: []byte("HTTP/1.1 345 OK\r\nContent-Type: foobar\r\n\r\n"),
			}
			return conn, nil
		},
	}

	var (
		req Request
		res Response
	)
	req.SetRequestURI("http://foobar/a/b")
	req.Header.SetMethod("POST")
	req.Header.Set("Content-Type", "text/plain")
	payload := bytes.NewBufferString("test")
	req.SetBodyStream(payload, payload.Len())

	if err := c.Do(&req, &res); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if code := res.StatusCode(); code != 345 {
		t.Fatalf("unexpected status code: %d. Expecting 345", code)
	}
	const expected = "POST /a/b HTTP/1.1\r\nUser-Agent: fasthttp\r\nHost: foobar\r\nContent-Type: text/plain\r\nContent-Length: 4\r\n\r\ntest"
	if got := res.Body(); string(got) != expected {
		t.Fatalf("unexpected body: %q. Expecting %q", got, expected)
	}
}
// TestClientHTTPSInvalidServerName connects a zero-value Client (no custom
// TLSConfig) to the TLS echo server by IP address. The server's certificate
// is generated for "localhost", so every request is expected to fail.
func TestClientHTTPSInvalidServerName(t *testing.T) {
	t.Parallel()

	srv := startEchoServerTLS(t, "tcp", "127.0.0.1:")
	defer srv.Stop()

	var c Client
	target := "https://" + srv.Addr()
	for attempt := 0; attempt < 10; attempt++ {
		if _, _, err := c.GetTimeout(nil, target, time.Second); err == nil {
			t.Fatalf("expecting TLS error")
		}
	}
}
// TestClientHTTPSConcurrent drives a plain-HTTP and a TLS echo server
// concurrently through one Client: even-numbered workers use HTTP,
// odd-numbered ones use HTTPS.
func TestClientHTTPSConcurrent(t *testing.T) {
	t.Parallel()

	sHTTP := startEchoServer(t, "tcp", "127.0.0.1:")
	defer sHTTP.Stop()
	sHTTPS := startEchoServerTLS(t, "tcp", "127.0.0.1:")
	defer sHTTPS.Stop()

	// The TLS server's certificate comes from GenerateTestCertificate, so
	// verification is skipped.
	c := &Client{
		TLSConfig: &tls.Config{
			InsecureSkipVerify: true,
		},
	}

	var wg sync.WaitGroup
	for i := 0; i < 4; i++ {
		var addr string
		if i%2 == 1 {
			addr = "https://" + sHTTPS.Addr()
		} else {
			addr = "http://" + sHTTP.Addr()
		}
		wg.Add(1)
		go func(addr string) {
			defer wg.Done()
			testClientGet(t, c, addr, 20)
			testClientPost(t, c, addr, 10)
		}(addr)
	}
	wg.Wait()
}
// TestClientManyServers starts 10 echo servers and drives the first 4 of them
// concurrently through the shared defaultClient.
func TestClientManyServers(t *testing.T) {
	t.Parallel()

	addrs := make([]string, 0, 10)
	for i := 0; i < 10; i++ {
		srv := startEchoServer(t, "tcp", "127.0.0.1:")
		// Deferred on purpose: every server must live until the test returns.
		defer srv.Stop()
		addrs = append(addrs, srv.Addr())
	}

	var wg sync.WaitGroup
	for _, hostPort := range addrs[:4] {
		wg.Add(1)
		go func(addr string) {
			defer wg.Done()
			testClientGet(t, &defaultClient, addr, 20)
			testClientPost(t, &defaultClient, addr, 10)
		}("http://" + hostPort)
	}
	wg.Wait()
}
// TestClientGet performs 100 GET round trips against a local echo server.
func TestClientGet(t *testing.T) {
	t.Parallel()

	srv := startEchoServer(t, "tcp", "127.0.0.1:")
	defer srv.Stop()

	baseURL := "http://" + srv.Addr()
	testClientGet(t, &defaultClient, baseURL, 100)
}
// TestClientPost performs 100 POST round trips against a local echo server.
func TestClientPost(t *testing.T) {
	t.Parallel()

	srv := startEchoServer(t, "tcp", "127.0.0.1:")
	defer srv.Stop()

	baseURL := "http://" + srv.Addr()
	testClientPost(t, &defaultClient, baseURL, 100)
}
// TestClientConcurrent runs GET and POST round trips from 10 goroutines that
// share defaultClient and one echo server.
func TestClientConcurrent(t *testing.T) {
	t.Parallel()

	srv := startEchoServer(t, "tcp", "127.0.0.1:")
	defer srv.Stop()

	baseURL := "http://" + srv.Addr()
	const workers = 10
	var wg sync.WaitGroup
	wg.Add(workers)
	for w := 0; w < workers; w++ {
		go func() {
			defer wg.Done()
			testClientGet(t, &defaultClient, baseURL, 30)
			testClientPost(t, &defaultClient, baseURL, 10)
		}()
	}
	wg.Wait()
}
// skipIfNotUnix skips tb on platforms where the tests cannot use unix-domain
// sockets.
//
// Since Go 1.16, iOS has its own GOOS ("ios"), and darwin/arm64 means
// Apple-silicon macOS, where unix sockets work. Matching darwin/arm64 would
// wrongly skip the unix-socket tests on those Macs.
func skipIfNotUnix(tb testing.TB) {
	switch runtime.GOOS {
	case "android", "nacl", "plan9", "windows":
		tb.Skipf("%s does not support unix sockets", runtime.GOOS)
	case "ios":
		tb.Skip("iOS does not support unix, unixgram")
	}
}
// TestHostClientGet drives a HostClient with 100 GETs over a unix-socket echo
// server.
func TestHostClientGet(t *testing.T) {
	t.Parallel()
	skipIfNotUnix(t)

	const sock = "TestHostClientGet.unix"
	srv := startEchoServer(t, "unix", sock)
	defer srv.Stop()

	testHostClientGet(t, createEchoClient(t, "unix", sock), 100)
}
// TestHostClientPost drives a HostClient with 100 POSTs over a unix-socket
// echo server.
func TestHostClientPost(t *testing.T) {
	t.Parallel()
	skipIfNotUnix(t)

	const sock = "./TestHostClientPost.unix"
	srv := startEchoServer(t, "unix", sock)
	defer srv.Stop()

	testHostClientPost(t, createEchoClient(t, "unix", sock), 100)
}
// TestHostClientConcurrent shares one HostClient between 10 goroutines that
// each run GETs and POSTs against a unix-socket echo server.
func TestHostClientConcurrent(t *testing.T) {
	t.Parallel()
	skipIfNotUnix(t)

	const sock = "./TestHostClientConcurrent.unix"
	srv := startEchoServer(t, "unix", sock)
	defer srv.Stop()

	hc := createEchoClient(t, "unix", sock)
	const workers = 10
	var wg sync.WaitGroup
	wg.Add(workers)
	for w := 0; w < workers; w++ {
		go func() {
			defer wg.Done()
			testHostClientGet(t, hc, 30)
			testHostClientPost(t, hc, 10)
		}()
	}
	wg.Wait()
}
// testClientGet performs n GETs of addr/foo/<i>?bar=baz. Each must return
// status 200 with the full request URI echoed as the body. Each response
// body is passed as dst to the next call for buffer reuse.
func testClientGet(t *testing.T, c clientGetter, addr string, n int) {
	var dst []byte
	for i := 0; i < n; i++ {
		wantURI := fmt.Sprintf("%s/foo/%d?bar=baz", addr, i)
		statusCode, body, err := c.Get(dst, wantURI)
		dst = body
		if err != nil {
			t.Errorf("unexpected error when doing http request: %v", err)
		}
		if statusCode != StatusOK {
			t.Errorf("unexpected status code: %d. Expecting %d", statusCode, StatusOK)
		}
		if gotURI := string(body); gotURI != wantURI {
			t.Errorf("unexpected uri %q. Expecting %q", gotURI, wantURI)
		}
	}
}
// testClientDoTimeoutSuccess issues n DoTimeout requests (1s timeout) to
// addr/foo/<i>?bar=baz. Each must return status 200 with the request URI
// echoed as the body.
func testClientDoTimeoutSuccess(t *testing.T, c *Client, addr string, n int) {
	var req Request
	var resp Response
	for i := 0; i < n; i++ {
		uri := fmt.Sprintf("%s/foo/%d?bar=baz", addr, i)
		req.SetRequestURI(uri)
		if err := c.DoTimeout(&req, &resp, time.Second); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		if resp.StatusCode() != StatusOK {
			t.Errorf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
		}
		resultURI := string(resp.Body())
		if strings.HasPrefix(uri, "https") {
			// The echoed URI is expected to carry an "http" scheme; restore
			// "https" before comparing. Unlike slicing [4:], TrimPrefix cannot
			// panic when a failed request leaves a body shorter than 4 bytes.
			resultURI = "https" + strings.TrimPrefix(resultURI, "http")
		}
		if resultURI != uri {
			t.Errorf("unexpected uri %q. Expecting %q", resultURI, uri)
		}
	}
}
// testClientRequestSetTimeoutSuccess issues n Do requests, each carrying its
// own 1s timeout via Request.SetTimeout, to addr/foo/<i>?bar=baz. Each must
// return status 200 with the request URI echoed as the body.
func testClientRequestSetTimeoutSuccess(t *testing.T, c *Client, addr string, n int) {
	var req Request
	var resp Response
	for i := 0; i < n; i++ {
		uri := fmt.Sprintf("%s/foo/%d?bar=baz", addr, i)
		req.SetRequestURI(uri)
		req.SetTimeout(time.Second)
		if err := c.Do(&req, &resp); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		if resp.StatusCode() != StatusOK {
			t.Errorf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
		}
		resultURI := string(resp.Body())
		if strings.HasPrefix(uri, "https") {
			// Restore the "https" scheme without slicing, so a short body from
			// a failed request cannot panic the test binary.
			resultURI = "https" + strings.TrimPrefix(resultURI, "http")
		}
		if resultURI != uri {
			t.Errorf("unexpected uri %q. Expecting %q", resultURI, uri)
		}
	}
}
// testClientGetTimeoutSuccess performs n GetTimeout calls (1s timeout) to
// addr/foo/<i>?bar=baz. Each must return status 200 with the request URI
// echoed as the body. Each body is reused as dst for the next call.
func testClientGetTimeoutSuccess(t *testing.T, c *Client, addr string, n int) {
	var buf []byte
	for i := 0; i < n; i++ {
		uri := fmt.Sprintf("%s/foo/%d?bar=baz", addr, i)
		statusCode, body, err := c.GetTimeout(buf, uri, time.Second)
		buf = body
		if err != nil {
			t.Errorf("unexpected error when doing http request: %v", err)
		}
		if statusCode != StatusOK {
			t.Errorf("unexpected status code: %d. Expecting %d", statusCode, StatusOK)
		}
		resultURI := string(body)
		if strings.HasPrefix(uri, "https") {
			// Restore the "https" scheme without slicing, so a short body from
			// a failed request cannot panic the test binary.
			resultURI = "https" + strings.TrimPrefix(resultURI, "http")
		}
		if resultURI != uri {
			t.Errorf("unexpected uri %q. Expecting %q", resultURI, uri)
		}
	}
}
// testClientPost performs n POSTs to addr/foo/<i>?bar=baz with form args
// xx=yy<i> and zzz=qwe_<i>. The echo server must return the encoded args as
// the body with status 200. Each body is reused as dst for the next call.
func testClientPost(t *testing.T, c clientPoster, addr string, n int) {
	var (
		dst  []byte
		args Args
	)
	for i := 0; i < n; i++ {
		uri := fmt.Sprintf("%s/foo/%d?bar=baz", addr, i)
		args.Set("xx", fmt.Sprintf("yy%d", i))
		args.Set("zzz", fmt.Sprintf("qwe_%d", i))
		want := args.String()

		statusCode, body, err := c.Post(dst, uri, &args)
		dst = body
		if err != nil {
			t.Errorf("unexpected error when doing http request: %v", err)
		}
		if statusCode != StatusOK {
			t.Errorf("unexpected status code: %d. Expecting %d", statusCode, StatusOK)
		}
		if got := string(body); got != want {
			t.Errorf("unexpected response %q. Expecting %q", got, want)
		}
	}
}
// testHostClientGet runs n GET round trips through c. The URL host only
// appears in the request; the HostClient dials its own configured Addr.
func testHostClientGet(t *testing.T, c *HostClient, n int) {
	const baseURL = "http://google.com"
	testClientGet(t, c, baseURL, n)
}

// testHostClientPost runs n POST round trips through c with a placeholder host.
func testHostClientPost(t *testing.T, c *HostClient, n int) {
	const baseURL = "http://post-host.com"
	testClientPost(t, c, baseURL, n)
}

// clientPoster is the Post method shared by Client and HostClient, as used by
// testClientPost.
type clientPoster interface {
	Post(dst []byte, uri string, postArgs *Args) (int, []byte, error)
}

// clientGetter is the Get method shared by Client and HostClient, as used by
// testClientGet.
type clientGetter interface {
	Get(dst []byte, uri string) (int, []byte, error)
}
// createEchoClient returns a HostClient for addr whose Dial connects over the
// given network (e.g. "unix"). t is currently unused.
func createEchoClient(t *testing.T, network, addr string) *HostClient {
	dial := func(target string) (net.Conn, error) {
		return net.Dial(network, target)
	}
	return &HostClient{Addr: addr, Dial: dial}
}
// testEchoServer bundles a running echo Server with its listener so a test can
// query its address and shut it down.
type testEchoServer struct {
	s  *Server
	ln net.Listener
	ch chan struct{} // closed by startEchoServerExt's goroutine when Serve returns
	t  *testing.T
}

// Stop closes the listener and waits up to one second for Serve to return,
// failing the owning test on timeout. The Close error is ignored.
func (s *testEchoServer) Stop() {
	_ = s.ln.Close()
	deadline := time.After(time.Second)
	select {
	case <-s.ch:
	case <-deadline:
		s.t.Fatalf("timeout when waiting for server close")
	}
}

// Addr returns the listener's address as a string.
func (s *testEchoServer) Addr() string {
	addr := s.ln.Addr()
	return addr.String()
}

// startEchoServerTLS starts an echo server serving TLS on network/addr.
func startEchoServerTLS(t *testing.T, network, addr string) *testEchoServer {
	const useTLS = true
	return startEchoServerExt(t, network, addr, useTLS)
}

// startEchoServer starts a plain-text echo server on network/addr.
func startEchoServer(t *testing.T, network, addr string) *testEchoServer {
	const useTLS = false
	return startEchoServerExt(t, network, addr, useTLS)
}
// startEchoServerExt listens on network/addr, optionally with TLS using a
// certificate generated for "localhost", and serves an echo handler in a
// background goroutine:
//   - GET: responds 200 text/plain with the full request URI,
//   - POST: writes the encoded POST args back,
//   - other methods: empty response.
//
// Listener or certificate failures abort the test.
func startEchoServerExt(t *testing.T, network, addr string, isTLS bool) *testEchoServer {
	if network == "unix" {
		// Presumably clears a stale socket file; the error is ignored.
		os.Remove(addr)
	}

	listen := func() (net.Listener, error) {
		if !isTLS {
			return net.Listen(network, addr)
		}
		certData, keyData, kerr := GenerateTestCertificate("localhost")
		if kerr != nil {
			t.Fatal(kerr)
		}
		cert, kerr := tls.X509KeyPair(certData, keyData)
		if kerr != nil {
			t.Fatal(kerr)
		}
		return tls.Listen(network, addr, &tls.Config{
			Certificates: []tls.Certificate{cert},
		})
	}
	ln, err := listen()
	if err != nil {
		t.Fatalf("cannot listen %q: %v", addr, err)
	}

	s := &Server{
		Handler: func(ctx *RequestCtx) {
			switch {
			case ctx.IsGet():
				ctx.Success("text/plain", ctx.URI().FullURI())
			case ctx.IsPost():
				ctx.PostArgs().WriteTo(ctx) //nolint:errcheck
			}
		},
		Logger: &testLogger{}, // Ignore log output.
	}

	done := make(chan struct{})
	go func() {
		if serveErr := s.Serve(ln); serveErr != nil {
			t.Errorf("unexpected error returned from Serve(): %v", serveErr)
		}
		close(done)
	}()

	return &testEchoServer{s: s, ln: ln, ch: done, t: t}
}
// TestClientTLSHandshakeTimeout points an https request at a raw TCP listener
// that accepts the connection but never speaks TLS. With 100ms read/write
// timeouts, the client must fail with ErrTLSHandshakeTimeout.
func TestClientTLSHandshakeTimeout(t *testing.T) {
	t.Parallel()

	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	addr := ln.Addr().String()
	defer ln.Close()

	// Closed when the test returns, which lets the acceptor close its conn.
	release := make(chan struct{})
	defer close(release)
	go func() {
		conn, acceptErr := ln.Accept()
		if acceptErr != nil {
			t.Error(acceptErr)
			return
		}
		<-release
		conn.Close()
	}()

	client := Client{
		WriteTimeout: 100 * time.Millisecond,
		ReadTimeout:  100 * time.Millisecond,
	}
	_, _, err = client.Get(nil, "https://"+addr)
	if err == nil {
		t.Fatal("tlsClientHandshake completed successfully")
	}
	if err != ErrTLSHandshakeTimeout {
		t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err)
	}
}
// TestClientConfigureClientFailed checks that an error from the
// ConfigureClient hook is returned by Do, and that clearing the hook lets the
// same request succeed.
//
// NOTE(review): the second Do uses the Client's default dialer against
// http://example.com, so this test needs outbound network access.
func TestClientConfigureClientFailed(t *testing.T) {
	t.Parallel()

	c := &Client{
		ConfigureClient: func(hc *HostClient) error {
			return fmt.Errorf("failed to configure")
		},
	}

	var req Request
	req.SetRequestURI("http://example.com")

	if err := c.Do(&req, &Response{}); err == nil {
		t.Fatal("expected error (failed to configure)")
	}

	c.ConfigureClient = nil
	if err := c.Do(&req, &Response{}); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
}
// TestHostClientMaxConnWaitTimeoutSuccess sends 5 concurrent POSTs through a
// HostClient limited to one connection, with a 2s MaxConnWaitTimeout.
// Each handler sleeps only 5ms, so every request is expected to get the
// connection in time and succeed. Afterwards:
//   - no waiter may remain in connsWait,
//   - the server must never have seen an empty body.
func TestHostClientMaxConnWaitTimeoutSuccess(t *testing.T) {
	t.Parallel()

	var (
		// Incremented from server handler goroutines, so accessed atomically.
		emptyBodyCount uint32
		ln             = fasthttputil.NewInmemoryListener()
		wg             sync.WaitGroup
	)

	s := &Server{
		Handler: func(ctx *RequestCtx) {
			if len(ctx.PostBody()) == 0 {
				atomic.AddUint32(&emptyBodyCount, 1)
			}
			time.Sleep(5 * time.Millisecond)
			ctx.WriteString("foo") //nolint:errcheck
		},
	}
	serverStopCh := make(chan struct{})
	go func() {
		if err := s.Serve(ln); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		close(serverStopCh)
	}()

	c := &HostClient{
		Addr: "foobar",
		Dial: func(addr string) (net.Conn, error) {
			return ln.Dial()
		},
		MaxConns:           1,
		MaxConnWaitTimeout: time.Second * 2,
	}

	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()

			req := AcquireRequest()
			req.SetRequestURI("http://foobar/baz")
			req.Header.SetMethod(MethodPost)
			req.SetBodyString("bar")
			resp := AcquireResponse()

			if err := c.Do(req, resp); err != nil {
				t.Errorf("unexpected error: %v", err)
			}
			if resp.StatusCode() != StatusOK {
				t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK)
			}
			body := resp.Body()
			if string(body) != "foo" {
				// The message previously claimed "abcd" was expected.
				t.Errorf("unexpected body %q. Expecting %q", body, "foo")
			}
		}()
	}
	wg.Wait()

	// connsWait is guarded by connsLock (as in the sibling tests). Read it
	// under the lock so a still-running conns cleaner cannot race with us.
	c.connsLock.Lock()
	pending := c.connsWait.len()
	c.connsLock.Unlock()
	if pending > 0 {
		t.Errorf("connsWait has %v items remaining", pending)
	}

	if err := ln.Close(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	select {
	case <-serverStopCh:
	case <-time.After(time.Second * 5):
		t.Fatalf("timeout")
	}

	if atomic.LoadUint32(&emptyBodyCount) > 0 {
		t.Fatalf("at least one request body was empty")
	}
}
// TestHostClientMaxConnWaitTimeoutError sends 5 concurrent POSTs through a
// HostClient limited to one connection, with a 10ms MaxConnWaitTimeout.
// Requests are serialized on that connection and each handler sleeps 5ms, so
// at least one waiter is expected to time out with ErrNoFreeConns. Any other
// error, a leftover waiter in connsWait, or an empty body seen by the server
// fails the test.
func TestHostClientMaxConnWaitTimeoutError(t *testing.T) {
	var (
		// Incremented from server handler goroutines, so accessed atomically.
		emptyBodyCount uint32
		ln             = fasthttputil.NewInmemoryListener()
		wg             sync.WaitGroup
	)

	s := &Server{
		Handler: func(ctx *RequestCtx) {
			if len(ctx.PostBody()) == 0 {
				atomic.AddUint32(&emptyBodyCount, 1)
			}
			time.Sleep(5 * time.Millisecond)
			ctx.WriteString("foo") //nolint:errcheck
		},
	}
	serverStopCh := make(chan struct{})
	go func() {
		if err := s.Serve(ln); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		close(serverStopCh)
	}()

	c := &HostClient{
		Addr: "foobar",
		Dial: func(addr string) (net.Conn, error) {
			return ln.Dial()
		},
		MaxConns:           1,
		MaxConnWaitTimeout: 10 * time.Millisecond,
	}

	var errNoFreeConnsCount uint32
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()

			req := AcquireRequest()
			req.SetRequestURI("http://foobar/baz")
			req.Header.SetMethod(MethodPost)
			req.SetBodyString("bar")
			resp := AcquireResponse()

			if err := c.Do(req, resp); err != nil {
				if err != ErrNoFreeConns {
					t.Errorf("unexpected error: %v. Expecting %v", err, ErrNoFreeConns)
				}
				atomic.AddUint32(&errNoFreeConnsCount, 1)
				return
			}
			if resp.StatusCode() != StatusOK {
				t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK)
			}
			body := resp.Body()
			if string(body) != "foo" {
				// The message previously claimed "abcd" was expected.
				t.Errorf("unexpected body %q. Expecting %q", body, "foo")
			}
		}()
	}
	wg.Wait()

	// Presumably gives the client's background goroutines time to settle.
	time.Sleep(time.Millisecond * 100)
	// Prevent a race condition with the conns cleaner that might still be running.
	c.connsLock.Lock()
	defer c.connsLock.Unlock()
	if c.connsWait.len() > 0 {
		t.Errorf("connsWait has %v items remaining", c.connsWait.len())
	}
	if n := atomic.LoadUint32(&errNoFreeConnsCount); n == 0 {
		t.Errorf("unexpected errorCount: %d. Expecting > 0", n)
	}

	if err := ln.Close(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	select {
	case <-serverStopCh:
	case <-time.After(time.Second):
		t.Fatalf("timeout")
	}

	if atomic.LoadUint32(&emptyBodyCount) > 0 {
		t.Fatalf("at least one request body was empty")
	}
}
// TestHostClientMaxConnWaitTimeoutWithEarlierDeadline checks that a per-call
// deadline (10ms) shorter than MaxConnWaitTimeout (50ms) wins. Requests
// waiting for the single connection must fail with ErrTimeout rather than
// ErrNoFreeConns, and any waiter left in connsWait must carry no error other
// than ErrTimeout.
func TestHostClientMaxConnWaitTimeoutWithEarlierDeadline(t *testing.T) {
	t.Parallel()

	var (
		// Incremented from server handler goroutines. Timed-out requests can
		// cause several connections (and handlers) to run concurrently, so
		// access it atomically.
		emptyBodyCount uint32
		ln             = fasthttputil.NewInmemoryListener()
		wg             sync.WaitGroup

		// make deadline reach earlier than conns wait timeout
		sleep              = 100 * time.Millisecond
		timeout            = 10 * time.Millisecond
		maxConnWaitTimeout = 50 * time.Millisecond
	)

	s := &Server{
		Handler: func(ctx *RequestCtx) {
			if len(ctx.PostBody()) == 0 {
				atomic.AddUint32(&emptyBodyCount, 1)
			}
			time.Sleep(sleep)
			ctx.WriteString("foo") //nolint:errcheck
		},
		Logger: &testLogger{}, // Don't print connection closed errors.
	}
	serverStopCh := make(chan struct{})
	go func() {
		if err := s.Serve(ln); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		close(serverStopCh)
	}()

	c := &HostClient{
		Addr: "foobar",
		Dial: func(addr string) (net.Conn, error) {
			return ln.Dial()
		},
		MaxConns:           1,
		MaxConnWaitTimeout: maxConnWaitTimeout,
	}

	var errTimeoutCount uint32
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()

			req := AcquireRequest()
			req.SetRequestURI("http://foobar/baz")
			req.Header.SetMethod(MethodPost)
			req.SetBodyString("bar")
			resp := AcquireResponse()

			if err := c.DoDeadline(req, resp, time.Now().Add(timeout)); err != nil {
				if err != ErrTimeout {
					t.Errorf("unexpected error: %v. Expecting %v", err, ErrTimeout)
				}
				atomic.AddUint32(&errTimeoutCount, 1)
				return
			}
			if resp.StatusCode() != StatusOK {
				t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK)
			}
			body := resp.Body()
			if string(body) != "foo" {
				// The message previously claimed "abcd" was expected.
				t.Errorf("unexpected body %q. Expecting %q", body, "foo")
			}
		}()
	}
	wg.Wait()

	// Drain the remaining waiters under connsLock. Any error one recorded
	// must be ErrTimeout.
	c.connsLock.Lock()
	for {
		w := c.connsWait.popFront()
		if w == nil {
			break
		}
		w.mu.Lock()
		if w.err != nil && w.err != ErrTimeout {
			t.Errorf("unexpected error: %v. Expecting %v", w.err, ErrTimeout)
		}
		w.mu.Unlock()
	}
	c.connsLock.Unlock()

	if n := atomic.LoadUint32(&errTimeoutCount); n == 0 {
		t.Errorf("unexpected errTimeoutCount: %d. Expecting > 0", n)
	}

	if err := ln.Close(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	select {
	case <-serverStopCh:
	case <-time.After(time.Second):
		t.Fatalf("timeout")
	}

	if atomic.LoadUint32(&emptyBodyCount) > 0 {
		t.Fatalf("at least one request body was empty")
	}
}
// TransportEmpty is a no-op RoundTripper: it never sends anything, leaves the
// response untouched and never asks the caller to retry.
type TransportEmpty struct{}

// RoundTrip reports an immediate success without performing any I/O.
func (TransportEmpty) RoundTrip(_ *HostClient, _ *Request, _ *Response) (bool, error) {
	const retry = false
	return retry, nil
}
// TestHttpsRequestWithoutParsedURL checks that an https:// request whose URI
// was only set via SetRequestURI is accepted by doNonNilReqResp on an IsTLS
// HostClient. TransportEmpty keeps the test off the network.
func TestHttpsRequestWithoutParsedURL(t *testing.T) {
	t.Parallel()

	client := HostClient{
		IsTLS:     true,
		Transport: TransportEmpty{},
	}

	req := &Request{}
	req.SetRequestURI("https://foo.com/bar")

	// Report the actual error so a failure explains itself.
	if _, err := client.doNonNilReqResp(req, &Response{}); err != nil {
		t.Fatalf("https requests with IsTLS client must succeed, got error: %v", err)
	}
}
// TestHostClientErrConnPoolStrategyNotImpl verifies how a HostClient behaves
// when it is configured with an unknown ConnPoolStrategy value.
//
// The first request is expected to succeed. Presumably this is because no
// pooled connection exists yet, so the strategy is never consulted. The
// following requests must fail with ErrConnPoolStrategyNotImpl.
func TestHostClientErrConnPoolStrategyNotImpl(t *testing.T) {
	t.Parallel()

	ln := fasthttputil.NewInmemoryListener()
	server := &Server{
		Handler: func(ctx *RequestCtx) {},
	}
	serverStopCh := make(chan struct{})
	go func() {
		if err := server.Serve(ln); err != nil {
			t.Errorf("unexpected error: %v", err)
		}
		close(serverStopCh)
	}()

	client := &HostClient{
		Addr: "foobar",
		Dial: func(addr string) (net.Conn, error) {
			return ln.Dial()
		},
		// Deliberately not one of the defined strategies.
		ConnPoolStrategy: ConnPoolStrategyType(100),
	}

	req := AcquireRequest()
	defer ReleaseRequest(req)
	req.SetRequestURI("http://foobar/baz")

	resp := AcquireResponse()
	defer ReleaseResponse(resp)

	if err := client.Do(req, resp); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if err := client.Do(req, resp); err != ErrConnPoolStrategyNotImpl {
		t.Errorf("expected ErrConnPoolStrategyNotImpl error, got %v", err)
	}
	if err := client.Do(req, resp); err != ErrConnPoolStrategyNotImpl {
		t.Errorf("expected ErrConnPoolStrategyNotImpl error, got %v", err)
	}

	if err := ln.Close(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	// Wait for Serve to return. Otherwise its goroutine could outlive the
	// test and call t.Errorf after the test has already completed.
	select {
	case <-serverStopCh:
	case <-time.After(time.Second):
		t.Fatalf("timeout")
	}
}
// Test_AddMissingPort checks the behavior of AddMissingPort:
//   - a host-only address gets the default port (80 plain, 443 with TLS);
//   - an address that already has an explicit port is left unchanged;
//   - a bare, unbracketed IPv6 literal such as "::1" is returned as is.
func Test_AddMissingPort(t *testing.T) {
	t.Parallel()

	type args struct {
		addr  string
		isTLS bool
	}
	tests := []struct {
		args args
		want string
	}{
		{
			args: args{"127.1", false}, // 127.1 is a short form of 127.0.0.1
			want: "127.1:80",
		},
		{
			args: args{"127.0.0.1", false},
			want: "127.0.0.1:80",
		},
		{
			args: args{"127.0.0.1", true},
			want: "127.0.0.1:443",
		},
		{
			args: args{"[::1]", false},
			want: "[::1]:80",
		},
		{
			args: args{"::1", false},
			want: "::1", // keep as is
		},
		{
			args: args{"[::1]", true},
			want: "[::1]:443",
		},
		{
			args: args{"127.0.0.1:8080", false},
			want: "127.0.0.1:8080",
		},
		{
			args: args{"127.0.0.1:8443", true},
			want: "127.0.0.1:8443",
		},
		{
			args: args{"[::1]:8080", false},
			want: "[::1]:8080",
		},
		{
			args: args{"[::1]:8443", true},
			want: "[::1]:8443",
		},
	}
	for _, tt := range tests {
		// Every expected value is unique, so it doubles as the subtest name.
		t.Run(tt.want, func(t *testing.T) {
			if got := AddMissingPort(tt.args.addr, tt.args.isTLS); got != tt.want {
				t.Errorf("AddMissingPort(%q, %v) = %q, want %q", tt.args.addr, tt.args.isTLS, got, tt.want)
			}
		})
	}
}
// TransportWrapper is a RoundTripper used by TestClientTransportEx. It is
// installed through Client.ConfigureClient, wraps each HostClient's
// transport, tags traffic with trace ids and counts the round trips.
type TransportWrapper struct {
	// base is the transport calls are delegated to; nil means DefaultTransport.
	base RoundTripper
	// count points at a counter incremented once per RoundTrip call.
	count *int
	// t receives the wrapper's assertion failures.
	t *testing.T
}
// RoundTrip does the following, in order:
//  1. stamps the outgoing request with a trace id and checks that the id
//     shows up in the serialized request;
//  2. hands the exchange to the wrapped transport;
//  3. stamps and checks the response the same way;
//  4. bumps the shared call counter.
//
// The delegate's retry flag and error are passed through untouched.
func (tw *TransportWrapper) RoundTrip(hc *HostClient, req *Request, resp *Response) (bool, error) {
	const (
		reqTraceID  = "123"
		respTraceID = "124"
	)

	// Outbound side.
	req.Header.Set("trace-id", reqTraceID)
	tw.assertRequestLog(req.String())

	next := tw.transport()
	retry, err := next.RoundTrip(hc, req, resp)

	// Inbound side: stamped even when the delegate reported an error.
	resp.Header.Set("trace-id", respTraceID)
	tw.assertResponseLog(resp.String())

	(*tw.count)++
	return retry, err
}
// transport returns the wrapped RoundTripper. It falls back to
// DefaultTransport when no base transport was configured.
func (tw *TransportWrapper) transport() RoundTripper {
	if base := tw.base; base != nil {
		return base
	}
	return DefaultTransport
}
// assertRequestLog reports a test error unless the serialized request
// contains the "Trace-Id: 123" header that RoundTrip sets before delegating.
func (tw *TransportWrapper) assertRequestLog(reqLog string) {
	tw.t.Helper()
	tw.assertLogContains("request", reqLog, "Trace-Id: 123")
}

// assertResponseLog reports a test error unless the serialized response
// contains the "Trace-Id: 124" header that RoundTrip sets after delegating.
func (tw *TransportWrapper) assertResponseLog(respLog string) {
	tw.t.Helper()
	tw.assertLogContains("response", respLog, "Trace-Id: 124")
}

// assertLogContains reports a test error naming kind ("request" or
// "response") when the serialized message got does not contain want.
// It includes the full message so the failure can be diagnosed.
func (tw *TransportWrapper) assertLogContains(kind, got, want string) {
	tw.t.Helper()
	if !strings.Contains(got, want) {
		tw.t.Errorf("%s log should contain %q, got:\n%s", kind, want, got)
	}
}
// TestClientTransportEx checks that a Transport installed via
// Client.ConfigureClient sees every round trip, over both plain HTTP and
// HTTPS. It does this by counting the calls made through TransportWrapper.
func TestClientTransportEx(t *testing.T) {
	sHTTP := startEchoServer(t, "tcp", "127.0.0.1:")
	defer sHTTP.Stop()

	sHTTPS := startEchoServerTLS(t, "tcp", "127.0.0.1:")
	defer sHTTPS.Stop()

	count := 0
	c := &Client{
		TLSConfig: &tls.Config{
			InsecureSkipVerify: true,
		},
		ConfigureClient: func(hc *HostClient) error {
			hc.Transport = &TransportWrapper{base: hc.Transport, count: &count, t: t}
			return nil
		},
	}

	const (
		loopCount = 4
		getCount  = 20
		postCount = 10
	)
	for i := 0; i < loopCount; i++ {
		// Alternate between the plain HTTP and the HTTPS server.
		addr := "http://" + sHTTP.Addr()
		if i&1 != 0 {
			addr = "https://" + sHTTPS.Addr()
		}
		testClientGet(t, c, addr, getCount)
		testClientPost(t, c, addr, postCount)
	}

	// The wrapper must have been invoked once per GET and POST issued above.
	if want := loopCount * (getCount + postCount); count != want {
		t.Errorf("round trip count should be: %v, got: %v", want, count)
	}
}
// Test_getRedirectURL checks getRedirectURL with absolute redirect locations.
// With path normalizing enabled, a '*' in the path is percent-encoded as %2A.
// With it disabled, the location's path is kept verbatim.
func Test_getRedirectURL(t *testing.T) {
	// getRedirectURL is a pure function, so this test can run alongside the
	// other parallel table tests.
	t.Parallel()

	type args struct {
		baseURL                string
		location               []byte
		disablePathNormalizing bool
	}
	tests := []struct {
		name string
		args args
		want string
	}{
		{
			name: "Path normalizing enabled, no special characters in path",
			args: args{
				baseURL:                "http://foo.example.com/abc",
				location:               []byte("http://bar.example.com/def"),
				disablePathNormalizing: false,
			},
			want: "http://bar.example.com/def",
		},
		{
			name: "Path normalizing enabled, special characters in path",
			args: args{
				baseURL:                "http://foo.example.com/abc/*/def",
				location:               []byte("http://bar.example.com/123/*/456"),
				disablePathNormalizing: false,
			},
			want: "http://bar.example.com/123/%2A/456",
		},
		{
			name: "Path normalizing disabled, no special characters in path",
			args: args{
				baseURL:                "http://foo.example.com/abc",
				location:               []byte("http://bar.example.com/def"),
				disablePathNormalizing: true,
			},
			want: "http://bar.example.com/def",
		},
		{
			name: "Path normalizing disabled, special characters in path",
			args: args{
				baseURL:                "http://foo.example.com/abc/*/def",
				location:               []byte("http://bar.example.com/123/*/456"),
				disablePathNormalizing: true,
			},
			want: "http://bar.example.com/123/*/456",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := getRedirectURL(tt.args.baseURL, tt.args.location, tt.args.disablePathNormalizing); got != tt.want {
				t.Errorf("getRedirectURL(%q, %q, %v) = %v, want %v",
					tt.args.baseURL, tt.args.location, tt.args.disablePathNormalizing, got, tt.want)
			}
		})
	}
}