421 lines
9.7 KiB
Go
421 lines
9.7 KiB
Go
//go:build !js
|
|
// +build !js
|
|
|
|
package websocket_test
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/rand"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/coder/websocket"
|
|
"github.com/coder/websocket/internal/test/assert"
|
|
"github.com/coder/websocket/internal/util"
|
|
"github.com/coder/websocket/internal/xsync"
|
|
)
|
|
|
|
func TestBadDials(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("badReq", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testCases := []struct {
|
|
name string
|
|
url string
|
|
opts *websocket.DialOptions
|
|
rand util.ReaderFunc
|
|
nilCtx bool
|
|
}{
|
|
{
|
|
name: "badURL",
|
|
url: "://noscheme",
|
|
},
|
|
{
|
|
name: "badURLScheme",
|
|
url: "ftp://nhooyr.io",
|
|
},
|
|
{
|
|
name: "badTLS",
|
|
url: "wss://totallyfake.nhooyr.io",
|
|
},
|
|
{
|
|
name: "badReader",
|
|
rand: func(p []byte) (int, error) {
|
|
return 0, io.EOF
|
|
},
|
|
},
|
|
{
|
|
name: "nilContext",
|
|
url: "http://localhost",
|
|
nilCtx: true,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var ctx context.Context
|
|
var cancel func()
|
|
if !tc.nilCtx {
|
|
ctx, cancel = context.WithTimeout(context.Background(), time.Second*5)
|
|
defer cancel()
|
|
}
|
|
|
|
if tc.rand == nil {
|
|
tc.rand = rand.Reader.Read
|
|
}
|
|
|
|
_, _, err := websocket.ExportedDial(ctx, tc.url, tc.opts, tc.rand)
|
|
assert.Error(t, err)
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("badResponse", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
|
defer cancel()
|
|
|
|
_, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
|
|
HTTPClient: mockHTTPClient(func(*http.Request) (*http.Response, error) {
|
|
return &http.Response{
|
|
Body: io.NopCloser(strings.NewReader("hi")),
|
|
}, nil
|
|
}),
|
|
})
|
|
assert.Contains(t, err, "failed to WebSocket dial: expected handshake response status code 101 but got 0")
|
|
})
|
|
|
|
t.Run("badBody", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
|
defer cancel()
|
|
|
|
rt := func(r *http.Request) (*http.Response, error) {
|
|
h := http.Header{}
|
|
h.Set("Connection", "Upgrade")
|
|
h.Set("Upgrade", "websocket")
|
|
h.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")))
|
|
|
|
return &http.Response{
|
|
StatusCode: http.StatusSwitchingProtocols,
|
|
Header: h,
|
|
Body: io.NopCloser(strings.NewReader("hi")),
|
|
}, nil
|
|
}
|
|
|
|
_, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
|
|
HTTPClient: mockHTTPClient(rt),
|
|
})
|
|
assert.Contains(t, err, "response body is not a io.ReadWriteCloser")
|
|
})
|
|
}
|
|
|
|
func Test_verifyHostOverride(t *testing.T) {
|
|
testCases := []struct {
|
|
name string
|
|
host string
|
|
exp string
|
|
}{
|
|
{
|
|
name: "noOverride",
|
|
host: "",
|
|
exp: "example.com",
|
|
},
|
|
{
|
|
name: "hostOverride",
|
|
host: "example.net",
|
|
exp: "example.net",
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
|
defer cancel()
|
|
|
|
rt := func(r *http.Request) (*http.Response, error) {
|
|
assert.Equal(t, "Host", tc.exp, r.Host)
|
|
|
|
h := http.Header{}
|
|
h.Set("Connection", "Upgrade")
|
|
h.Set("Upgrade", "websocket")
|
|
h.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")))
|
|
|
|
return &http.Response{
|
|
StatusCode: http.StatusSwitchingProtocols,
|
|
Header: h,
|
|
Body: mockBody{bytes.NewBufferString("hi")},
|
|
}, nil
|
|
}
|
|
|
|
c, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
|
|
HTTPClient: mockHTTPClient(rt),
|
|
Host: tc.host,
|
|
})
|
|
assert.Success(t, err)
|
|
c.CloseNow()
|
|
})
|
|
}
|
|
|
|
}
|
|
|
|
// mockBody is an in-memory response body. The embedded *bytes.Buffer provides
// Read and Write, and Close completes io.ReadWriteCloser.
type mockBody struct {
	*bytes.Buffer
}

// Compile-time check that mockBody can stand in for a read/write/close body.
var _ io.ReadWriteCloser = mockBody{}

// Close does nothing and always returns nil; an in-memory buffer holds no
// resources to release.
func (mockBody) Close() error { return nil }
|
|
|
|
func Test_verifyServerHandshake(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testCases := []struct {
|
|
name string
|
|
response func(w http.ResponseWriter)
|
|
success bool
|
|
}{
|
|
{
|
|
name: "badStatus",
|
|
response: func(w http.ResponseWriter) {
|
|
w.WriteHeader(http.StatusOK)
|
|
},
|
|
success: false,
|
|
},
|
|
{
|
|
name: "badConnection",
|
|
response: func(w http.ResponseWriter) {
|
|
w.Header().Set("Connection", "???")
|
|
w.WriteHeader(http.StatusSwitchingProtocols)
|
|
},
|
|
success: false,
|
|
},
|
|
{
|
|
name: "badUpgrade",
|
|
response: func(w http.ResponseWriter) {
|
|
w.Header().Set("Connection", "Upgrade")
|
|
w.Header().Set("Upgrade", "???")
|
|
w.WriteHeader(http.StatusSwitchingProtocols)
|
|
},
|
|
success: false,
|
|
},
|
|
{
|
|
name: "badSecWebSocketAccept",
|
|
response: func(w http.ResponseWriter) {
|
|
w.Header().Set("Connection", "Upgrade")
|
|
w.Header().Set("Upgrade", "websocket")
|
|
w.Header().Set("Sec-WebSocket-Accept", "xd")
|
|
w.WriteHeader(http.StatusSwitchingProtocols)
|
|
},
|
|
success: false,
|
|
},
|
|
{
|
|
name: "badSecWebSocketProtocol",
|
|
response: func(w http.ResponseWriter) {
|
|
w.Header().Set("Connection", "Upgrade")
|
|
w.Header().Set("Upgrade", "websocket")
|
|
w.Header().Set("Sec-WebSocket-Protocol", "xd")
|
|
w.WriteHeader(http.StatusSwitchingProtocols)
|
|
},
|
|
success: false,
|
|
},
|
|
{
|
|
name: "unsupportedExtension",
|
|
response: func(w http.ResponseWriter) {
|
|
w.Header().Set("Connection", "Upgrade")
|
|
w.Header().Set("Upgrade", "websocket")
|
|
w.Header().Set("Sec-WebSocket-Extensions", "meow")
|
|
w.WriteHeader(http.StatusSwitchingProtocols)
|
|
},
|
|
success: false,
|
|
},
|
|
{
|
|
name: "unsupportedDeflateParam",
|
|
response: func(w http.ResponseWriter) {
|
|
w.Header().Set("Connection", "Upgrade")
|
|
w.Header().Set("Upgrade", "websocket")
|
|
w.Header().Set("Sec-WebSocket-Extensions", "permessage-deflate; meow")
|
|
w.WriteHeader(http.StatusSwitchingProtocols)
|
|
},
|
|
success: false,
|
|
},
|
|
{
|
|
name: "success",
|
|
response: func(w http.ResponseWriter) {
|
|
w.Header().Set("Connection", "Upgrade")
|
|
w.Header().Set("Upgrade", "websocket")
|
|
w.WriteHeader(http.StatusSwitchingProtocols)
|
|
},
|
|
success: true,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
w := httptest.NewRecorder()
|
|
tc.response(w)
|
|
resp := w.Result()
|
|
|
|
r := httptest.NewRequest("GET", "/", nil)
|
|
key, err := websocket.SecWebSocketKey(rand.Reader)
|
|
assert.Success(t, err)
|
|
r.Header.Set("Sec-WebSocket-Key", key)
|
|
|
|
if resp.Header.Get("Sec-WebSocket-Accept") == "" {
|
|
resp.Header.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(key))
|
|
}
|
|
|
|
opts := &websocket.DialOptions{
|
|
Subprotocols: strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ","),
|
|
}
|
|
_, err = websocket.VerifyServerResponse(opts, websocket.CompressionModeOpts(opts.CompressionMode), key, resp)
|
|
if tc.success {
|
|
assert.Success(t, err)
|
|
} else {
|
|
assert.Error(t, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// roundTripperFunc lets a plain function serve as an http.RoundTripper.
type roundTripperFunc func(*http.Request) (*http.Response, error)

// Compile-time check that roundTripperFunc implements http.RoundTripper.
var _ http.RoundTripper = roundTripperFunc(nil)

// RoundTrip calls f with r and returns its results unchanged.
func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
	return f(r)
}

// mockHTTPClient returns an *http.Client whose Transport is fn, so requests
// sent through the client are answered by fn instead of the network.
func mockHTTPClient(fn roundTripperFunc) *http.Client {
	client := &http.Client{}
	client.Transport = fn
	return client
}
|
|
|
|
func TestDialRedirect(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
|
defer cancel()
|
|
|
|
_, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
|
|
HTTPClient: mockHTTPClient(func(r *http.Request) (*http.Response, error) {
|
|
resp := &http.Response{
|
|
Header: http.Header{},
|
|
}
|
|
if r.URL.Scheme != "https" {
|
|
resp.Header.Set("Location", "wss://example.com")
|
|
resp.StatusCode = http.StatusFound
|
|
return resp, nil
|
|
}
|
|
resp.Header.Set("Connection", "Upgrade")
|
|
resp.Header.Set("Upgrade", "meow")
|
|
resp.StatusCode = http.StatusSwitchingProtocols
|
|
return resp, nil
|
|
}),
|
|
})
|
|
assert.Contains(t, err, "failed to WebSocket dial: WebSocket protocol violation: Upgrade header \"meow\" does not contain websocket")
|
|
}
|
|
|
|
// forwardProxy is a minimal HTTP forward proxy used by the tests. It relays
// each request through its own http.Client.
type forwardProxy struct {
	hc *http.Client
}

// newForwardProxy returns a forwardProxy backed by a zero-value http.Client.
func newForwardProxy() *forwardProxy {
	fp := new(forwardProxy)
	fp.hc = new(http.Client)
	return fp
}
|
|
|
|
func (fc *forwardProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
ctx, cancel := context.WithTimeout(r.Context(), time.Second*10)
|
|
defer cancel()
|
|
|
|
r = r.WithContext(ctx)
|
|
r.RequestURI = ""
|
|
resp, err := fc.hc.Do(r)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
for k, v := range resp.Header {
|
|
w.Header()[k] = v
|
|
}
|
|
w.Header().Set("PROXIED", "true")
|
|
w.WriteHeader(resp.StatusCode)
|
|
if resprw, ok := resp.Body.(io.ReadWriter); ok {
|
|
c, brw, err := w.(http.Hijacker).Hijack()
|
|
if err != nil {
|
|
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
brw.Flush()
|
|
|
|
errc1 := xsync.Go(func() error {
|
|
_, err := io.Copy(c, resprw)
|
|
return err
|
|
})
|
|
errc2 := xsync.Go(func() error {
|
|
_, err := io.Copy(resprw, c)
|
|
return err
|
|
})
|
|
select {
|
|
case <-errc1:
|
|
case <-errc2:
|
|
case <-r.Context().Done():
|
|
}
|
|
} else {
|
|
io.Copy(w, resp.Body)
|
|
}
|
|
}
|
|
|
|
func TestDialViaProxy(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ps := httptest.NewServer(newForwardProxy())
|
|
defer ps.Close()
|
|
|
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
err := echoServer(w, r, nil)
|
|
assert.Success(t, err)
|
|
}))
|
|
defer s.Close()
|
|
|
|
psu, err := url.Parse(ps.URL)
|
|
assert.Success(t, err)
|
|
proxyTransport := http.DefaultTransport.(*http.Transport).Clone()
|
|
proxyTransport.Proxy = http.ProxyURL(psu)
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
|
defer cancel()
|
|
c, resp, err := websocket.Dial(ctx, s.URL, &websocket.DialOptions{
|
|
HTTPClient: &http.Client{
|
|
Transport: proxyTransport,
|
|
},
|
|
})
|
|
assert.Success(t, err)
|
|
assert.Equal(t, "", "true", resp.Header.Get("PROXIED"))
|
|
|
|
assertEcho(t, ctx, c)
|
|
assertClose(t, c)
|
|
}
|