summaryrefslogtreecommitdiff
path: root/tests/wscat.go
diff options
context:
space:
mode:
Diffstat (limited to 'tests/wscat.go')
-rw-r--r--tests/wscat.go2842
1 files changed, 2842 insertions, 0 deletions
diff --git a/tests/wscat.go b/tests/wscat.go
new file mode 100644
index 0000000..da16f66
--- /dev/null
+++ b/tests/wscat.go
@@ -0,0 +1,2842 @@
+package wscat
+
+import (
+ "bufio"
+ "bytes"
+ "compress/flate"
+ "context"
+ "crypto/tls"
+ "crypto/x509"
+ "encoding/base64"
+ "encoding/binary"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "log"
+ "math/rand"
+ "net"
+ "net/http"
+ "net/http/cookiejar"
+ "net/http/httptest"
+ "net/http/httptrace"
+ "net/url"
+ "os"
+ "reflect"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "testing/internal/testdeps"
+ "testing/iotest"
+ "time"
+
+ g "gobang"
+)
+
+
+
+var cstUpgrader = Upgrader{
+ Subprotocols: []string{"p0", "p1"},
+ ReadBufferSize: 1024,
+ WriteBufferSize: 1024,
+ EnableCompression: true,
+ Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) {
+ http.Error(w, reason.Error(), status)
+ },
+}
+
+var cstDialer = Dialer{
+ Subprotocols: []string{"p1", "p2"},
+ ReadBufferSize: 1024,
+ WriteBufferSize: 1024,
+ HandshakeTimeout: 30 * time.Second,
+}
+
+type cstHandler struct {
+ *testing.T
+ s *cstServer
+}
+
+type cstServer struct {
+ URL string
+ Server *httptest.Server
+ wg sync.WaitGroup
+}
+
+const (
+ cstPath = "/a/b"
+ cstRawQuery = "x=y"
+ cstRequestURI = cstPath + "?" + cstRawQuery
+)
+
+func (s *cstServer) Close() {
+ s.Server.Close()
+ // Wait for handler functions to complete.
+ s.wg.Wait()
+}
+
+func newServer(t *testing.T) *cstServer {
+ var s cstServer
+ s.Server = httptest.NewServer(cstHandler{T: t, s: &s})
+ s.Server.URL += cstRequestURI
+ s.URL = makeWsProto(s.Server.URL)
+ return &s
+}
+
+func newTLSServer(t *testing.T) *cstServer {
+ var s cstServer
+ s.Server = httptest.NewTLSServer(cstHandler{T: t, s: &s})
+ s.Server.URL += cstRequestURI
+ s.URL = makeWsProto(s.Server.URL)
+ return &s
+}
+
+func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ // Because tests wait for a response from a server, we are guaranteed that
+ // the wait group count is incremented before the test waits on the group
+ // in the call to (*cstServer).Close().
+ t.s.wg.Add(1)
+ defer t.s.wg.Done()
+
+ if r.URL.Path != cstPath {
+ t.Logf("path=%v, want %v", r.URL.Path, cstPath)
+ http.Error(w, "bad path", http.StatusBadRequest)
+ return
+ }
+ if r.URL.RawQuery != cstRawQuery {
+ t.Logf("query=%v, want %v", r.URL.RawQuery, cstRawQuery)
+ http.Error(w, "bad path", http.StatusBadRequest)
+ return
+ }
+ subprotos := Subprotocols(r)
+ if !reflect.DeepEqual(subprotos, cstDialer.Subprotocols) {
+ t.Logf("subprotols=%v, want %v", subprotos, cstDialer.Subprotocols)
+ http.Error(w, "bad protocol", http.StatusBadRequest)
+ return
+ }
+ ws, err := cstUpgrader.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}})
+ if err != nil {
+ t.Logf("Upgrade: %v", err)
+ return
+ }
+ defer ws.Close()
+
+ if ws.Subprotocol() != "p1" {
+ t.Logf("Subprotocol() = %s, want p1", ws.Subprotocol())
+ ws.Close()
+ return
+ }
+ op, rd, err := ws.NextReader()
+ if err != nil {
+ t.Logf("NextReader: %v", err)
+ return
+ }
+ wr, err := ws.NextWriter(op)
+ if err != nil {
+ t.Logf("NextWriter: %v", err)
+ return
+ }
+ if _, err = io.Copy(wr, rd); err != nil {
+ t.Logf("NextWriter: %v", err)
+ return
+ }
+ if err := wr.Close(); err != nil {
+ t.Logf("Close: %v", err)
+ return
+ }
+}
+
+func makeWsProto(s string) string {
+ return "ws" + strings.TrimPrefix(s, "http")
+}
+
+func sendRecv(t *testing.T, ws *Conn) {
+ const message = "Hello World!"
+ if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil {
+ t.Fatalf("SetWriteDeadline: %v", err)
+ }
+ if err := ws.WriteMessage(TextMessage, []byte(message)); err != nil {
+ t.Fatalf("WriteMessage: %v", err)
+ }
+ if err := ws.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
+ t.Fatalf("SetReadDeadline: %v", err)
+ }
+ _, p, err := ws.ReadMessage()
+ if err != nil {
+ t.Fatalf("ReadMessage: %v", err)
+ }
+ if string(p) != message {
+ t.Fatalf("message=%s, want %s", p, message)
+ }
+}
+
+func TestProxyDial(t *testing.T) {
+
+ s := newServer(t)
+ defer s.Close()
+
+ surl, _ := url.Parse(s.Server.URL)
+
+ cstDialer := cstDialer // make local copy for modification on next line.
+ cstDialer.Proxy = http.ProxyURL(surl)
+
+ connect := false
+ origHandler := s.Server.Config.Handler
+
+ // Capture the request Host header.
+ s.Server.Config.Handler = http.HandlerFunc(
+ func(w http.ResponseWriter, r *http.Request) {
+ if r.Method == http.MethodConnect {
+ connect = true
+ w.WriteHeader(http.StatusOK)
+ return
+ }
+
+ if !connect {
+ t.Log("connect not received")
+ http.Error(w, "connect not received", http.StatusMethodNotAllowed)
+ return
+ }
+ origHandler.ServeHTTP(w, r)
+ })
+
+ ws, _, err := cstDialer.Dial(s.URL, nil)
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer ws.Close()
+ sendRecv(t, ws)
+}
+
+func TestProxyAuthorizationDial(t *testing.T) {
+ s := newServer(t)
+ defer s.Close()
+
+ surl, _ := url.Parse(s.Server.URL)
+ surl.User = url.UserPassword("username", "password")
+
+ cstDialer := cstDialer // make local copy for modification on next line.
+ cstDialer.Proxy = http.ProxyURL(surl)
+
+ connect := false
+ origHandler := s.Server.Config.Handler
+
+ // Capture the request Host header.
+ s.Server.Config.Handler = http.HandlerFunc(
+ func(w http.ResponseWriter, r *http.Request) {
+ proxyAuth := r.Header.Get("Proxy-Authorization")
+ expectedProxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("username:password"))
+ if r.Method == http.MethodConnect && proxyAuth == expectedProxyAuth {
+ connect = true
+ w.WriteHeader(http.StatusOK)
+ return
+ }
+
+ if !connect {
+ t.Log("connect with proxy authorization not received")
+ http.Error(w, "connect with proxy authorization not received", http.StatusMethodNotAllowed)
+ return
+ }
+ origHandler.ServeHTTP(w, r)
+ })
+
+ ws, _, err := cstDialer.Dial(s.URL, nil)
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer ws.Close()
+ sendRecv(t, ws)
+}
+
+func TestDial(t *testing.T) {
+ s := newServer(t)
+ defer s.Close()
+
+ ws, _, err := cstDialer.Dial(s.URL, nil)
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer ws.Close()
+ sendRecv(t, ws)
+}
+
+func TestDialCookieJar(t *testing.T) {
+ s := newServer(t)
+ defer s.Close()
+
+ jar, _ := cookiejar.New(nil)
+ d := cstDialer
+ d.Jar = jar
+
+ u, _ := url.Parse(s.URL)
+
+ switch u.Scheme {
+ case "ws":
+ u.Scheme = "http"
+ case "wss":
+ u.Scheme = "https"
+ }
+
+ cookies := []*http.Cookie{{Name: "gorilla", Value: "ws", Path: "/"}}
+ d.Jar.SetCookies(u, cookies)
+
+ ws, _, err := d.Dial(s.URL, nil)
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer ws.Close()
+
+ var gorilla string
+ var sessionID string
+ for _, c := range d.Jar.Cookies(u) {
+ if c.Name == "gorilla" {
+ gorilla = c.Value
+ }
+
+ if c.Name == "sessionID" {
+ sessionID = c.Value
+ }
+ }
+ if gorilla != "ws" {
+ t.Error("Cookie not present in jar.")
+ }
+
+ if sessionID != "1234" {
+ t.Error("Set-Cookie not received from the server.")
+ }
+
+ sendRecv(t, ws)
+}
+
+func rootCAs(t *testing.T, s *httptest.Server) *x509.CertPool {
+ certs := x509.NewCertPool()
+ for _, c := range s.TLS.Certificates {
+ roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
+ if err != nil {
+ t.Fatalf("error parsing server's root cert: %v", err)
+ }
+ for _, root := range roots {
+ certs.AddCert(root)
+ }
+ }
+ return certs
+}
+
+func TestDialTLS(t *testing.T) {
+ s := newTLSServer(t)
+ defer s.Close()
+
+ d := cstDialer
+ d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
+ ws, _, err := d.Dial(s.URL, nil)
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer ws.Close()
+ sendRecv(t, ws)
+}
+
+func TestDialTimeout(t *testing.T) {
+ s := newServer(t)
+ defer s.Close()
+
+ d := cstDialer
+ d.HandshakeTimeout = -1
+ ws, _, err := d.Dial(s.URL, nil)
+ if err == nil {
+ ws.Close()
+ t.Fatalf("Dial: nil")
+ }
+}
+
+// requireDeadlineNetConn fails the current test when Read or Write are called
+// with no deadline.
+type requireDeadlineNetConn struct {
+ t *testing.T
+ c net.Conn
+ readDeadlineIsSet bool
+ writeDeadlineIsSet bool
+}
+
+func (c *requireDeadlineNetConn) SetDeadline(t time.Time) error {
+ c.writeDeadlineIsSet = !t.Equal(time.Time{})
+ c.readDeadlineIsSet = c.writeDeadlineIsSet
+ return c.c.SetDeadline(t)
+}
+
+func (c *requireDeadlineNetConn) SetReadDeadline(t time.Time) error {
+ c.readDeadlineIsSet = !t.Equal(time.Time{})
+ return c.c.SetDeadline(t)
+}
+
+func (c *requireDeadlineNetConn) SetWriteDeadline(t time.Time) error {
+ c.writeDeadlineIsSet = !t.Equal(time.Time{})
+ return c.c.SetDeadline(t)
+}
+
+func (c *requireDeadlineNetConn) Write(p []byte) (int, error) {
+ if !c.writeDeadlineIsSet {
+ c.t.Fatalf("write with no deadline")
+ }
+ return c.c.Write(p)
+}
+
+func (c *requireDeadlineNetConn) Read(p []byte) (int, error) {
+ if !c.readDeadlineIsSet {
+ c.t.Fatalf("read with no deadline")
+ }
+ return c.c.Read(p)
+}
+
+func (c *requireDeadlineNetConn) Close() error { return c.c.Close() }
+func (c *requireDeadlineNetConn) LocalAddr() net.Addr { return c.c.LocalAddr() }
+func (c *requireDeadlineNetConn) RemoteAddr() net.Addr { return c.c.RemoteAddr() }
+
+func TestHandshakeTimeout(t *testing.T) {
+ s := newServer(t)
+ defer s.Close()
+
+ d := cstDialer
+ d.NetDial = func(n, a string) (net.Conn, error) {
+ c, err := net.Dial(n, a)
+ return &requireDeadlineNetConn{c: c, t: t}, err
+ }
+ ws, _, err := d.Dial(s.URL, nil)
+ if err != nil {
+ t.Fatal("Dial:", err)
+ }
+ ws.Close()
+}
+
+func TestHandshakeTimeoutInContext(t *testing.T) {
+ s := newServer(t)
+ defer s.Close()
+
+ d := cstDialer
+ d.HandshakeTimeout = 0
+ d.NetDialContext = func(ctx context.Context, n, a string) (net.Conn, error) {
+ netDialer := &net.Dialer{}
+ c, err := netDialer.DialContext(ctx, n, a)
+ return &requireDeadlineNetConn{c: c, t: t}, err
+ }
+
+ ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(30*time.Second))
+ defer cancel()
+ ws, _, err := d.DialContext(ctx, s.URL, nil)
+ if err != nil {
+ t.Fatal("Dial:", err)
+ }
+ ws.Close()
+}
+
+func TestDialBadScheme(t *testing.T) {
+ s := newServer(t)
+ defer s.Close()
+
+ ws, _, err := cstDialer.Dial(s.Server.URL, nil)
+ if err == nil {
+ ws.Close()
+ t.Fatalf("Dial: nil")
+ }
+}
+
+func TestDialBadOrigin(t *testing.T) {
+ s := newServer(t)
+ defer s.Close()
+
+ ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}})
+ if err == nil {
+ ws.Close()
+ t.Fatalf("Dial: nil")
+ }
+ if resp == nil {
+ t.Fatalf("resp=nil, err=%v", err)
+ }
+ if resp.StatusCode != http.StatusForbidden {
+ t.Fatalf("status=%d, want %d", resp.StatusCode, http.StatusForbidden)
+ }
+}
+
+func TestDialBadHeader(t *testing.T) {
+ s := newServer(t)
+ defer s.Close()
+
+ for _, k := range []string{"Upgrade",
+ "Connection",
+ "Sec-Websocket-Key",
+ "Sec-Websocket-Version",
+ "Sec-Websocket-Protocol"} {
+ h := http.Header{}
+ h.Set(k, "bad")
+ ws, _, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}})
+ if err == nil {
+ ws.Close()
+ t.Errorf("Dial with header %s returned nil", k)
+ }
+ }
+}
+
+func TestBadMethod(t *testing.T) {
+ s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ ws, err := cstUpgrader.Upgrade(w, r, nil)
+ if err == nil {
+ t.Errorf("handshake succeeded, expect fail")
+ ws.Close()
+ }
+ }))
+ defer s.Close()
+
+ req, err := http.NewRequest(http.MethodPost, s.URL, strings.NewReader(""))
+ if err != nil {
+ t.Fatalf("NewRequest returned error %v", err)
+ }
+ req.Header.Set("Connection", "upgrade")
+ req.Header.Set("Upgrade", "websocket")
+ req.Header.Set("Sec-Websocket-Version", "13")
+
+ resp, err := http.DefaultClient.Do(req)
+ if err != nil {
+ t.Fatalf("Do returned error %v", err)
+ }
+ resp.Body.Close()
+ if resp.StatusCode != http.StatusMethodNotAllowed {
+ t.Errorf("Status = %d, want %d", resp.StatusCode, http.StatusMethodNotAllowed)
+ }
+}
+
+func TestNoUpgrade(t *testing.T) {
+ t.Parallel()
+ s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ ws, err := cstUpgrader.Upgrade(w, r, nil)
+ if err == nil {
+ t.Errorf("handshake succeeded, expect fail")
+ ws.Close()
+ }
+ }))
+ defer s.Close()
+
+ req, err := http.NewRequest(http.MethodGet, s.URL, strings.NewReader(""))
+ if err != nil {
+ t.Fatalf("NewRequest returned error %v", err)
+ }
+ req.Header.Set("Connection", "upgrade")
+ req.Header.Set("Sec-Websocket-Version", "13")
+
+ resp, err := http.DefaultClient.Do(req)
+ if err != nil {
+ t.Fatalf("Do returned error %v", err)
+ }
+ resp.Body.Close()
+ if u := resp.Header.Get("Upgrade"); u != "websocket" {
+ t.Errorf("Upgrade response header is %q, want %q", u, "websocket")
+ }
+ if resp.StatusCode != http.StatusUpgradeRequired {
+ t.Errorf("Status = %d, want %d", resp.StatusCode, http.StatusUpgradeRequired)
+ }
+}
+
+func TestDialExtraTokensInRespHeaders(t *testing.T) {
+ s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ challengeKey := r.Header.Get("Sec-Websocket-Key")
+ w.Header().Set("Upgrade", "foo, websocket")
+ w.Header().Set("Connection", "upgrade, keep-alive")
+ w.Header().Set("Sec-Websocket-Accept", computeAcceptKey(challengeKey))
+ w.WriteHeader(101)
+ }))
+ defer s.Close()
+
+ ws, _, err := cstDialer.Dial(makeWsProto(s.URL), nil)
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer ws.Close()
+}
+
+func TestHandshake(t *testing.T) {
+ s := newServer(t)
+ defer s.Close()
+
+ ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {s.URL}})
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer ws.Close()
+
+ var sessionID string
+ for _, c := range resp.Cookies() {
+ if c.Name == "sessionID" {
+ sessionID = c.Value
+ }
+ }
+ if sessionID != "1234" {
+ t.Error("Set-Cookie not received from the server.")
+ }
+
+ if ws.Subprotocol() != "p1" {
+ t.Errorf("ws.Subprotocol() = %s, want p1", ws.Subprotocol())
+ }
+ sendRecv(t, ws)
+}
+
+func TestRespOnBadHandshake(t *testing.T) {
+ const expectedStatus = http.StatusGone
+ const expectedBody = "This is the response body."
+
+ s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(expectedStatus)
+ _, _ = io.WriteString(w, expectedBody)
+ }))
+ defer s.Close()
+
+ ws, resp, err := cstDialer.Dial(makeWsProto(s.URL), nil)
+ if err == nil {
+ ws.Close()
+ t.Fatalf("Dial: nil")
+ }
+
+ if resp == nil {
+ t.Fatalf("resp=nil, err=%v", err)
+ }
+
+ if resp.StatusCode != expectedStatus {
+ t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus)
+ }
+
+ p, err := io.ReadAll(resp.Body)
+ if err != nil {
+ t.Fatalf("ReadFull(resp.Body) returned error %v", err)
+ }
+
+ if string(p) != expectedBody {
+ t.Errorf("resp.Body=%s, want %s", p, expectedBody)
+ }
+}
+
+type testLogWriter struct {
+ t *testing.T
+}
+
+func (w testLogWriter) Write(p []byte) (int, error) {
+ w.t.Logf("%s", p)
+ return len(p), nil
+}
+
+// TestHost tests handling of host names and confirms that it matches net/http.
+func TestHost(t *testing.T) {
+
+ upgrader := Upgrader{}
+ handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if IsWebSocketUpgrade(r) {
+ c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}})
+ if err != nil {
+ t.Fatal(err)
+ }
+ c.Close()
+ } else {
+ w.Header().Set("X-Test-Host", r.Host)
+ }
+ })
+
+ server := httptest.NewServer(handler)
+ defer server.Close()
+
+ tlsServer := httptest.NewTLSServer(handler)
+ defer tlsServer.Close()
+
+ addrs := map[*httptest.Server]string{server: server.Listener.Addr().String(), tlsServer: tlsServer.Listener.Addr().String()}
+ wsProtos := map[*httptest.Server]string{server: "ws://", tlsServer: "wss://"}
+ httpProtos := map[*httptest.Server]string{server: "http://", tlsServer: "https://"}
+
+ // Avoid log noise from net/http server by logging to testing.T
+ server.Config.ErrorLog = log.New(testLogWriter{t}, "", 0)
+ tlsServer.Config.ErrorLog = server.Config.ErrorLog
+
+ cas := rootCAs(t, tlsServer)
+
+ tests := []struct {
+ fail bool // true if dial / get should fail
+ server *httptest.Server // server to use
+ url string // host for request URI
+ header string // optional request host header
+ tls string // optional host for tls ServerName
+ wantAddr string // expected host for dial
+ wantHeader string // expected request header on server
+ insecureSkipVerify bool
+ }{
+ {
+ server: server,
+ url: addrs[server],
+ wantAddr: addrs[server],
+ wantHeader: addrs[server],
+ },
+ {
+ server: tlsServer,
+ url: addrs[tlsServer],
+ wantAddr: addrs[tlsServer],
+ wantHeader: addrs[tlsServer],
+ },
+
+ {
+ server: server,
+ url: addrs[server],
+ header: "badhost.com",
+ wantAddr: addrs[server],
+ wantHeader: "badhost.com",
+ },
+ {
+ server: tlsServer,
+ url: addrs[tlsServer],
+ header: "badhost.com",
+ wantAddr: addrs[tlsServer],
+ wantHeader: "badhost.com",
+ },
+
+ {
+ server: server,
+ url: "example.com",
+ header: "badhost.com",
+ wantAddr: "example.com:80",
+ wantHeader: "badhost.com",
+ },
+ {
+ server: tlsServer,
+ url: "example.com",
+ header: "badhost.com",
+ wantAddr: "example.com:443",
+ wantHeader: "badhost.com",
+ },
+
+ {
+ server: server,
+ url: "badhost.com",
+ header: "example.com",
+ wantAddr: "badhost.com:80",
+ wantHeader: "example.com",
+ },
+ {
+ fail: true,
+ server: tlsServer,
+ url: "badhost.com",
+ header: "example.com",
+ wantAddr: "badhost.com:443",
+ },
+ {
+ server: tlsServer,
+ url: "badhost.com",
+ insecureSkipVerify: true,
+ wantAddr: "badhost.com:443",
+ wantHeader: "badhost.com",
+ },
+ {
+ server: tlsServer,
+ url: "badhost.com",
+ tls: "example.com",
+ wantAddr: "badhost.com:443",
+ wantHeader: "badhost.com",
+ },
+ }
+
+ for i, tt := range tests {
+
+ tls := &tls.Config{
+ RootCAs: cas,
+ ServerName: tt.tls,
+ InsecureSkipVerify: tt.insecureSkipVerify,
+ }
+
+ var gotAddr string
+ dialer := Dialer{
+ NetDial: func(network, addr string) (net.Conn, error) {
+ gotAddr = addr
+ return net.Dial(network, addrs[tt.server])
+ },
+ TLSClientConfig: tls,
+ }
+
+ // Test websocket dial
+
+ h := http.Header{}
+ if tt.header != "" {
+ h.Set("Host", tt.header)
+ }
+ c, resp, err := dialer.Dial(wsProtos[tt.server]+tt.url+"/", h)
+ if err == nil {
+ c.Close()
+ }
+
+ check := func(protos map[*httptest.Server]string) {
+ name := fmt.Sprintf("%d: %s%s/ header[Host]=%q, tls.ServerName=%q", i+1, protos[tt.server], tt.url, tt.header, tt.tls)
+ if gotAddr != tt.wantAddr {
+ t.Errorf("%s: got addr %s, want %s", name, gotAddr, tt.wantAddr)
+ }
+ switch {
+ case tt.fail && err == nil:
+ t.Errorf("%s: unexpected success", name)
+ case !tt.fail && err != nil:
+ t.Errorf("%s: unexpected error %v", name, err)
+ case !tt.fail && err == nil:
+ if gotHost := resp.Header.Get("X-Test-Host"); gotHost != tt.wantHeader {
+ t.Errorf("%s: got host %s, want %s", name, gotHost, tt.wantHeader)
+ }
+ }
+ }
+
+ check(wsProtos)
+
+ // Confirm that net/http has same result
+
+ transport := &http.Transport{
+ Dial: dialer.NetDial,
+ TLSClientConfig: dialer.TLSClientConfig,
+ }
+ req, _ := http.NewRequest(http.MethodGet, httpProtos[tt.server]+tt.url+"/", nil)
+ if tt.header != "" {
+ req.Host = tt.header
+ }
+ client := &http.Client{Transport: transport}
+ resp, err = client.Do(req)
+ if err == nil {
+ resp.Body.Close()
+ }
+ transport.CloseIdleConnections()
+ check(httpProtos)
+ }
+}
+
+func TestDialCompression(t *testing.T) {
+ s := newServer(t)
+ defer s.Close()
+
+ dialer := cstDialer
+ dialer.EnableCompression = true
+ ws, _, err := dialer.Dial(s.URL, nil)
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer ws.Close()
+ sendRecv(t, ws)
+}
+
+func TestSocksProxyDial(t *testing.T) {
+ s := newServer(t)
+ defer s.Close()
+
+ proxyListener, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("listen failed: %v", err)
+ }
+ defer proxyListener.Close()
+ go func() {
+ c1, err := proxyListener.Accept()
+ if err != nil {
+ t.Errorf("proxy accept failed: %v", err)
+ return
+ }
+ defer c1.Close()
+
+ _ = c1.SetDeadline(time.Now().Add(30 * time.Second))
+
+ buf := make([]byte, 32)
+ if _, err := io.ReadFull(c1, buf[:3]); err != nil {
+ t.Errorf("read failed: %v", err)
+ return
+ }
+ if want := []byte{5, 1, 0}; !bytes.Equal(want, buf[:len(want)]) {
+ t.Errorf("read %x, want %x", buf[:len(want)], want)
+ }
+ if _, err := c1.Write([]byte{5, 0}); err != nil {
+ t.Errorf("write failed: %v", err)
+ return
+ }
+ if _, err := io.ReadFull(c1, buf[:10]); err != nil {
+ t.Errorf("read failed: %v", err)
+ return
+ }
+ if want := []byte{5, 1, 0, 1}; !bytes.Equal(want, buf[:len(want)]) {
+ t.Errorf("read %x, want %x", buf[:len(want)], want)
+ return
+ }
+ buf[1] = 0
+ if _, err := c1.Write(buf[:10]); err != nil {
+ t.Errorf("write failed: %v", err)
+ return
+ }
+
+ ip := net.IP(buf[4:8])
+ port := binary.BigEndian.Uint16(buf[8:10])
+
+ c2, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: ip, Port: int(port)})
+ if err != nil {
+ t.Errorf("dial failed; %v", err)
+ return
+ }
+ defer c2.Close()
+ done := make(chan struct{})
+ go func() {
+ _, _ = io.Copy(c1, c2)
+ close(done)
+ }()
+ _, _ = io.Copy(c2, c1)
+ <-done
+ }()
+
+ purl, err := url.Parse("socks5://" + proxyListener.Addr().String())
+ if err != nil {
+ t.Fatalf("parse failed: %v", err)
+ }
+
+ cstDialer := cstDialer // make local copy for modification on next line.
+ cstDialer.Proxy = http.ProxyURL(purl)
+
+ ws, _, err := cstDialer.Dial(s.URL, nil)
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer ws.Close()
+ sendRecv(t, ws)
+}
+
+func TestTracingDialWithContext(t *testing.T) {
+
+ var headersWrote, requestWrote, getConn, gotConn, connectDone, gotFirstResponseByte bool
+ trace := &httptrace.ClientTrace{
+ WroteHeaders: func() {
+ headersWrote = true
+ },
+ WroteRequest: func(httptrace.WroteRequestInfo) {
+ requestWrote = true
+ },
+ GetConn: func(hostPort string) {
+ getConn = true
+ },
+ GotConn: func(info httptrace.GotConnInfo) {
+ gotConn = true
+ },
+ ConnectDone: func(network, addr string, err error) {
+ connectDone = true
+ },
+ GotFirstResponseByte: func() {
+ gotFirstResponseByte = true
+ },
+ }
+ ctx := httptrace.WithClientTrace(context.Background(), trace)
+
+ s := newTLSServer(t)
+ defer s.Close()
+
+ d := cstDialer
+ d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
+
+ ws, _, err := d.DialContext(ctx, s.URL, nil)
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+
+ if !headersWrote {
+ t.Fatal("Headers was not written")
+ }
+ if !requestWrote {
+ t.Fatal("Request was not written")
+ }
+ if !getConn {
+ t.Fatal("getConn was not called")
+ }
+ if !gotConn {
+ t.Fatal("gotConn was not called")
+ }
+ if !connectDone {
+ t.Fatal("connectDone was not called")
+ }
+ if !gotFirstResponseByte {
+ t.Fatal("GotFirstResponseByte was not called")
+ }
+
+ defer ws.Close()
+ sendRecv(t, ws)
+}
+
+func TestEmptyTracingDialWithContext(t *testing.T) {
+
+ trace := &httptrace.ClientTrace{}
+ ctx := httptrace.WithClientTrace(context.Background(), trace)
+
+ s := newTLSServer(t)
+ defer s.Close()
+
+ d := cstDialer
+ d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
+
+ ws, _, err := d.DialContext(ctx, s.URL, nil)
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+
+ defer ws.Close()
+ sendRecv(t, ws)
+}
+
+// TestNetDialConnect tests selection of dial method between NetDial, NetDialContext, NetDialTLS or NetDialTLSContext
+func TestNetDialConnect(t *testing.T) {
+
+ upgrader := Upgrader{}
+ handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if IsWebSocketUpgrade(r) {
+ c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}})
+ if err != nil {
+ t.Fatal(err)
+ }
+ c.Close()
+ } else {
+ w.Header().Set("X-Test-Host", r.Host)
+ }
+ })
+
+ server := httptest.NewServer(handler)
+ defer server.Close()
+
+ tlsServer := httptest.NewTLSServer(handler)
+ defer tlsServer.Close()
+
+ testUrls := map[*httptest.Server]string{
+ server: "ws://" + server.Listener.Addr().String() + "/",
+ tlsServer: "wss://" + tlsServer.Listener.Addr().String() + "/",
+ }
+
+ cas := rootCAs(t, tlsServer)
+ tlsConfig := &tls.Config{
+ RootCAs: cas,
+ ServerName: "example.com",
+ InsecureSkipVerify: false,
+ }
+
+ tests := []struct {
+ name string
+ server *httptest.Server // server to use
+ netDial func(network, addr string) (net.Conn, error)
+ netDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
+ netDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
+ tlsClientConfig *tls.Config
+ }{
+
+ {
+ name: "HTTP server, all NetDial* defined, shall use NetDialContext",
+ server: server,
+ netDial: func(network, addr string) (net.Conn, error) {
+ return nil, errors.New("NetDial should not be called")
+ },
+ netDialContext: func(_ context.Context, network, addr string) (net.Conn, error) {
+ return net.Dial(network, addr)
+ },
+ netDialTLSContext: func(_ context.Context, network, addr string) (net.Conn, error) {
+ return nil, errors.New("NetDialTLSContext should not be called")
+ },
+ tlsClientConfig: nil,
+ },
+ {
+ name: "HTTP server, all NetDial* undefined",
+ server: server,
+ netDial: nil,
+ netDialContext: nil,
+ netDialTLSContext: nil,
+ tlsClientConfig: nil,
+ },
+ {
+ name: "HTTP server, NetDialContext undefined, shall fallback to NetDial",
+ server: server,
+ netDial: func(network, addr string) (net.Conn, error) {
+ return net.Dial(network, addr)
+ },
+ netDialContext: nil,
+ netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ return nil, errors.New("NetDialTLSContext should not be called")
+ },
+ tlsClientConfig: nil,
+ },
+ {
+ name: "HTTPS server, all NetDial* defined, shall use NetDialTLSContext",
+ server: tlsServer,
+ netDial: func(network, addr string) (net.Conn, error) {
+ return nil, errors.New("NetDial should not be called")
+ },
+ netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ return nil, errors.New("NetDialContext should not be called")
+ },
+ netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ netConn, err := net.Dial(network, addr)
+ if err != nil {
+ return nil, err
+ }
+ tlsConn := tls.Client(netConn, tlsConfig)
+ err = tlsConn.Handshake()
+ if err != nil {
+ return nil, err
+ }
+ return tlsConn, nil
+ },
+ tlsClientConfig: nil,
+ },
+ {
+ name: "HTTPS server, NetDialTLSContext undefined, shall fallback to NetDialContext and do handshake",
+ server: tlsServer,
+ netDial: func(network, addr string) (net.Conn, error) {
+ return nil, errors.New("NetDial should not be called")
+ },
+ netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ return net.Dial(network, addr)
+ },
+ netDialTLSContext: nil,
+ tlsClientConfig: tlsConfig,
+ },
+ {
+ name: "HTTPS server, NetDialTLSContext and NetDialContext undefined, shall fallback to NetDial and do handshake",
+ server: tlsServer,
+ netDial: func(network, addr string) (net.Conn, error) {
+ return net.Dial(network, addr)
+ },
+ netDialContext: nil,
+ netDialTLSContext: nil,
+ tlsClientConfig: tlsConfig,
+ },
+ {
+ name: "HTTPS server, all NetDial* undefined",
+ server: tlsServer,
+ netDial: nil,
+ netDialContext: nil,
+ netDialTLSContext: nil,
+ tlsClientConfig: tlsConfig,
+ },
+ {
+ name: "HTTPS server, all NetDialTLSContext defined, dummy TlsClientConfig defined, shall not do handshake",
+ server: tlsServer,
+ netDial: func(network, addr string) (net.Conn, error) {
+ return nil, errors.New("NetDial should not be called")
+ },
+ netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ return nil, errors.New("NetDialContext should not be called")
+ },
+ netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ netConn, err := net.Dial(network, addr)
+ if err != nil {
+ return nil, err
+ }
+ tlsConn := tls.Client(netConn, tlsConfig)
+ err = tlsConn.Handshake()
+ if err != nil {
+ return nil, err
+ }
+ return tlsConn, nil
+ },
+ tlsClientConfig: &tls.Config{
+ RootCAs: nil,
+ ServerName: "badserver.com",
+ InsecureSkipVerify: false,
+ },
+ },
+ }
+
+ for _, tc := range tests {
+ dialer := Dialer{
+ NetDial: tc.netDial,
+ NetDialContext: tc.netDialContext,
+ NetDialTLSContext: tc.netDialTLSContext,
+ TLSClientConfig: tc.tlsClientConfig,
+ }
+
+ // Test websocket dial
+ c, _, err := dialer.Dial(testUrls[tc.server], nil)
+ if err != nil {
+ t.Errorf("FAILED %s, err: %s", tc.name, err.Error())
+ } else {
+ c.Close()
+ }
+ }
+}
+func TestNextProtos(t *testing.T) {
+ ts := httptest.NewUnstartedServer(
+ http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
+ )
+ ts.EnableHTTP2 = true
+ ts.StartTLS()
+ defer ts.Close()
+
+ d := Dialer{
+ TLSClientConfig: ts.Client().Transport.(*http.Transport).TLSClientConfig,
+ }
+
+ r, err := ts.Client().Get(ts.URL)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ r.Body.Close()
+
+ // Asserts that Dialer.TLSClientConfig.NextProtos contains "h2"
+ // after the Client.Get call from net/http above.
+ var containsHTTP2 bool = false
+ for _, proto := range d.TLSClientConfig.NextProtos {
+ if proto == "h2" {
+ containsHTTP2 = true
+ }
+ }
+ if !containsHTTP2 {
+ t.Fatalf("Dialer.TLSClientConfig.NextProtos does not contain \"h2\"")
+ }
+
+ _, _, err = d.Dial(makeWsProto(ts.URL), nil)
+ if err == nil {
+ t.Fatalf("Dial succeeded, expect fail ")
+ }
+}
+
+type dataBeforeHandshakeResponseWriter struct {
+ http.ResponseWriter
+}
+
+type dataBeforeHandshakeConnection struct {
+ net.Conn
+ io.Reader
+}
+
+func (c *dataBeforeHandshakeConnection) Read(p []byte) (int, error) {
+ return c.Reader.Read(p)
+}
+
+func (w dataBeforeHandshakeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+ // Example single-frame masked text message from section 5.7 of the RFC.
+ message := []byte{0x81, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d, 0x51, 0x58}
+ n := len(message) / 2
+
+ c, rw, err := http.NewResponseController(w.ResponseWriter).Hijack()
+ if rw != nil {
+ // Load first part of message into bufio.Reader. If the websocket
+ // connection reads more than n bytes from the bufio.Reader, then the
+ // test will fail with an unexpected EOF error.
+ rw.Reader.Reset(bytes.NewReader(message[:n]))
+ rw.Reader.Peek(n)
+ }
+ if c != nil {
+ // Inject second part of message before data read from the network connection.
+ c = &dataBeforeHandshakeConnection{
+ Conn: c,
+ Reader: io.MultiReader(bytes.NewReader(message[n:]), c),
+ }
+ }
+ return c, rw, err
+}
+
+func TestDataReceivedBeforeHandshake(t *testing.T) {
+ s := newServer(t)
+ defer s.Close()
+
+ origHandler := s.Server.Config.Handler
+ s.Server.Config.Handler = http.HandlerFunc(
+ func(w http.ResponseWriter, r *http.Request) {
+ origHandler.ServeHTTP(dataBeforeHandshakeResponseWriter{w}, r)
+ })
+
+ for _, readBufferSize := range []int{0, 1024} {
+ t.Run(fmt.Sprintf("ReadBufferSize=%d", readBufferSize), func(t *testing.T) {
+ dialer := cstDialer
+ dialer.ReadBufferSize = readBufferSize
+ ws, _, err := cstDialer.Dial(s.URL, nil)
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer ws.Close()
+ _, m, err := ws.ReadMessage()
+ if err != nil || string(m) != "Hello" {
+ t.Fatalf("ReadMessage() = %q, %v, want \"Hello\", nil", m, err)
+ }
+ })
+ }
+}
+// Copyright 2014 The Gorilla WebSocket Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+
+var hostPortNoPortTests = []struct {
+ u *url.URL
+ hostPort, hostNoPort string
+}{
+ {&url.URL{Scheme: "ws", Host: "example.com"}, "example.com:80", "example.com"},
+ {&url.URL{Scheme: "wss", Host: "example.com"}, "example.com:443", "example.com"},
+ {&url.URL{Scheme: "ws", Host: "example.com:7777"}, "example.com:7777", "example.com"},
+ {&url.URL{Scheme: "wss", Host: "example.com:7777"}, "example.com:7777", "example.com"},
+}
+
+func TestHostPortNoPort(t *testing.T) {
+ for _, tt := range hostPortNoPortTests {
+ hostPort, hostNoPort := hostPortNoPort(tt.u)
+ if hostPort != tt.hostPort {
+ t.Errorf("hostPortNoPort(%v) returned hostPort %q, want %q", tt.u, hostPort, tt.hostPort)
+ }
+ if hostNoPort != tt.hostNoPort {
+ t.Errorf("hostPortNoPort(%v) returned hostNoPort %q, want %q", tt.u, hostNoPort, tt.hostNoPort)
+ }
+ }
+}
+
+type nopCloser struct{ io.Writer }
+
+func (nopCloser) Close() error { return nil }
+
+func TestTruncWriter(t *testing.T) {
+ const data = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijlkmnopqrstuvwxyz987654321"
+ for n := 1; n <= 10; n++ {
+ var b bytes.Buffer
+ w := &truncWriter{w: nopCloser{&b}}
+ p := []byte(data)
+ for len(p) > 0 {
+ m := len(p)
+ if m > n {
+ m = n
+ }
+ _, _ = w.Write(p[:m])
+ p = p[m:]
+ }
+ if b.String() != data[:len(data)-len(w.p)] {
+ t.Errorf("%d: %q", n, b.String())
+ }
+ }
+}
+
+func textMessages(num int) [][]byte {
+ messages := make([][]byte, num)
+ for i := 0; i < num; i++ {
+ msg := fmt.Sprintf("planet: %d, country: %d, city: %d, street: %d", i, i, i, i)
+ messages[i] = []byte(msg)
+ }
+ return messages
+}
+
+func BenchmarkWriteNoCompression(b *testing.B) {
+ w := io.Discard
+ c := newTestConn(nil, w, false)
+ messages := textMessages(100)
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ _ = c.WriteMessage(TextMessage, messages[i%len(messages)])
+ }
+ b.ReportAllocs()
+}
+
+func BenchmarkWriteWithCompression(b *testing.B) {
+ w := io.Discard
+ c := newTestConn(nil, w, false)
+ messages := textMessages(100)
+ c.enableWriteCompression = true
+ c.newCompressionWriter = compressNoContextTakeover
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ _ = c.WriteMessage(TextMessage, messages[i%len(messages)])
+ }
+ b.ReportAllocs()
+}
+
+func TestValidCompressionLevel(t *testing.T) {
+ c := newTestConn(nil, nil, false)
+ for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} {
+ if err := c.SetCompressionLevel(level); err == nil {
+ t.Errorf("no error for level %d", level)
+ }
+ }
+ for _, level := range []int{minCompressionLevel, maxCompressionLevel} {
+ if err := c.SetCompressionLevel(level); err != nil {
+ t.Errorf("error for level %d", level)
+ }
+ }
+}
+// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+
+// broadcastBench allows to run broadcast benchmarks.
+// In every broadcast benchmark we create many connections, then send the same
+// message into every connection and wait for all writes complete. This emulates
+// an application where many connections listen to the same data - i.e. PUB/SUB
+// scenarios with many subscribers in one channel.
+type broadcastBench struct {
+ w io.Writer
+ closeCh chan struct{}
+ doneCh chan struct{}
+ count int32
+ conns []*broadcastConn
+ compression bool
+ usePrepared bool
+}
+
+type broadcastMessage struct {
+ payload []byte
+ prepared *PreparedMessage
+}
+
+type broadcastConn struct {
+ conn *Conn
+ msgCh chan *broadcastMessage
+}
+
+func newBroadcastConn(c *Conn) *broadcastConn {
+ return &broadcastConn{
+ conn: c,
+ msgCh: make(chan *broadcastMessage, 1),
+ }
+}
+
+func newBroadcastBench(usePrepared, compression bool) *broadcastBench {
+ bench := &broadcastBench{
+ w: io.Discard,
+ doneCh: make(chan struct{}),
+ closeCh: make(chan struct{}),
+ usePrepared: usePrepared,
+ compression: compression,
+ }
+ bench.makeConns(10000)
+ return bench
+}
+
+func (b *broadcastBench) makeConns(numConns int) {
+ conns := make([]*broadcastConn, numConns)
+
+ for i := 0; i < numConns; i++ {
+ c := newTestConn(nil, b.w, true)
+ if b.compression {
+ c.enableWriteCompression = true
+ c.newCompressionWriter = compressNoContextTakeover
+ }
+ conns[i] = newBroadcastConn(c)
+ go func(c *broadcastConn) {
+ for {
+ select {
+ case msg := <-c.msgCh:
+ if msg.prepared != nil {
+ _ = c.conn.WritePreparedMessage(msg.prepared)
+ } else {
+ _ = c.conn.WriteMessage(TextMessage, msg.payload)
+ }
+ val := atomic.AddInt32(&b.count, 1)
+ if val%int32(numConns) == 0 {
+ b.doneCh <- struct{}{}
+ }
+ case <-b.closeCh:
+ return
+ }
+ }
+ }(conns[i])
+ }
+ b.conns = conns
+}
+
+func (b *broadcastBench) close() {
+ close(b.closeCh)
+}
+
+func (b *broadcastBench) broadcastOnce(msg *broadcastMessage) {
+ for _, c := range b.conns {
+ c.msgCh <- msg
+ }
+ <-b.doneCh
+}
+
+func BenchmarkBroadcast(b *testing.B) {
+ benchmarks := []struct {
+ name string
+ usePrepared bool
+ compression bool
+ }{
+ {"NoCompression", false, false},
+ {"Compression", false, true},
+ {"NoCompressionPrepared", true, false},
+ {"CompressionPrepared", true, true},
+ }
+ payload := textMessages(1)[0]
+ for _, bm := range benchmarks {
+ b.Run(bm.name, func(b *testing.B) {
+ bench := newBroadcastBench(bm.usePrepared, bm.compression)
+ defer bench.close()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ message := &broadcastMessage{
+ payload: payload,
+ }
+ if bench.usePrepared {
+ pm, _ := NewPreparedMessage(TextMessage, message.payload)
+ message.prepared = pm
+ }
+ bench.broadcastOnce(message)
+ }
+ b.ReportAllocs()
+ })
+ }
+}
+// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+
+var _ net.Error = errWriteTimeout
+
+type fakeNetConn struct {
+ io.Reader
+ io.Writer
+}
+
+func (c fakeNetConn) Close() error { return nil }
+func (c fakeNetConn) LocalAddr() net.Addr { return localAddr }
+func (c fakeNetConn) RemoteAddr() net.Addr { return remoteAddr }
+func (c fakeNetConn) SetDeadline(t time.Time) error { return nil }
+func (c fakeNetConn) SetReadDeadline(t time.Time) error { return nil }
+func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil }
+
+type fakeAddr int
+
+var (
+ localAddr = fakeAddr(1)
+ remoteAddr = fakeAddr(2)
+)
+
+func (a fakeAddr) Network() string {
+ return "net"
+}
+
+func (a fakeAddr) String() string {
+ return "str"
+}
+
+// newTestConn creates a connection backed by a fake network connection using
+// default values for buffering.
+func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn {
+ return newConn(fakeNetConn{Reader: r, Writer: w}, isServer, 1024, 1024, nil, nil, nil)
+}
+
+func TestFraming(t *testing.T) {
+ frameSizes := []int{
+ 0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535,
+ // 65536, 65537
+ }
+ var readChunkers = []struct {
+ name string
+ f func(io.Reader) io.Reader
+ }{
+ {"half", iotest.HalfReader},
+ {"one", iotest.OneByteReader},
+ {"asis", func(r io.Reader) io.Reader { return r }},
+ }
+ writeBuf := make([]byte, 65537)
+ for i := range writeBuf {
+ writeBuf[i] = byte(i)
+ }
+ var writers = []struct {
+ name string
+ f func(w io.Writer, n int) (int, error)
+ }{
+ {"iocopy", func(w io.Writer, n int) (int, error) {
+ nn, err := io.Copy(w, bytes.NewReader(writeBuf[:n]))
+ return int(nn), err
+ }},
+ {"write", func(w io.Writer, n int) (int, error) {
+ return w.Write(writeBuf[:n])
+ }},
+ {"string", func(w io.Writer, n int) (int, error) {
+ return io.WriteString(w, string(writeBuf[:n]))
+ }},
+ }
+
+ for _, compress := range []bool{false, true} {
+ for _, isServer := range []bool{true, false} {
+ for _, chunker := range readChunkers {
+
+ var connBuf bytes.Buffer
+ wc := newTestConn(nil, &connBuf, isServer)
+ rc := newTestConn(chunker.f(&connBuf), nil, !isServer)
+ if compress {
+ wc.newCompressionWriter = compressNoContextTakeover
+ rc.newDecompressionReader = decompressNoContextTakeover
+ }
+ for _, n := range frameSizes {
+ for _, writer := range writers {
+ name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d w:%s", compress, isServer, chunker.name, n, writer.name)
+
+ w, err := wc.NextWriter(TextMessage)
+ if err != nil {
+ t.Errorf("%s: wc.NextWriter() returned %v", name, err)
+ continue
+ }
+ nn, err := writer.f(w, n)
+ if err != nil || nn != n {
+ t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err)
+ continue
+ }
+ err = w.Close()
+ if err != nil {
+ t.Errorf("%s: w.Close() returned %v", name, err)
+ continue
+ }
+
+ opCode, r, err := rc.NextReader()
+ if err != nil || opCode != TextMessage {
+ t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err)
+ continue
+ }
+
+ t.Logf("frame size: %d", n)
+ rbuf, err := io.ReadAll(r)
+ if err != nil {
+ t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
+ continue
+ }
+
+ if len(rbuf) != n {
+ t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n)
+ continue
+ }
+
+ for i, b := range rbuf {
+ if byte(i) != b {
+ t.Errorf("%s: bad byte at offset %d", name, i)
+ break
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
+func TestWriteControlDeadline(t *testing.T) {
+ t.Parallel()
+ message := []byte("hello")
+ var connBuf bytes.Buffer
+ c := newTestConn(nil, &connBuf, true)
+ if err := c.WriteControl(PongMessage, message, time.Time{}); err != nil {
+ t.Errorf("WriteControl(..., zero deadline) = %v, want nil", err)
+ }
+ if err := c.WriteControl(PongMessage, message, time.Now().Add(time.Second)); err != nil {
+ t.Errorf("WriteControl(..., future deadline) = %v, want nil", err)
+ }
+ if err := c.WriteControl(PongMessage, message, time.Now().Add(-time.Second)); err == nil {
+ t.Errorf("WriteControl(..., past deadline) = nil, want timeout error")
+ }
+}
+
+func TestConcurrencyWriteControl(t *testing.T) {
+ const message = "this is a ping/pong messsage"
+ loop := 10
+ workers := 10
+ for i := 0; i < loop; i++ {
+ var connBuf bytes.Buffer
+
+ wg := sync.WaitGroup{}
+ wc := newTestConn(nil, &connBuf, true)
+
+ for i := 0; i < workers; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ if err := wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)); err != nil {
+ t.Errorf("concurrently wc.WriteControl() returned %v", err)
+ }
+ }()
+ }
+
+ wg.Wait()
+ wc.Close()
+ }
+}
+
+func TestControl(t *testing.T) {
+ const message = "this is a ping/pong message"
+ for _, isServer := range []bool{true, false} {
+ for _, isWriteControl := range []bool{true, false} {
+ name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl)
+ var connBuf bytes.Buffer
+ wc := newTestConn(nil, &connBuf, isServer)
+ rc := newTestConn(&connBuf, nil, !isServer)
+ if isWriteControl {
+ _ = wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second))
+ } else {
+ w, err := wc.NextWriter(PongMessage)
+ if err != nil {
+ t.Errorf("%s: wc.NextWriter() returned %v", name, err)
+ continue
+ }
+ if _, err := w.Write([]byte(message)); err != nil {
+ t.Errorf("%s: w.Write() returned %v", name, err)
+ continue
+ }
+ if err := w.Close(); err != nil {
+ t.Errorf("%s: w.Close() returned %v", name, err)
+ continue
+ }
+ var actualMessage string
+ rc.SetPongHandler(func(s string) error { actualMessage = s; return nil })
+ _, _, _ = rc.NextReader()
+ if actualMessage != message {
+ t.Errorf("%s: pong=%q, want %q", name, actualMessage, message)
+ continue
+ }
+ }
+ }
+ }
+}
+
+// simpleBufferPool is an implementation of BufferPool for TestWriteBufferPool.
+type simpleBufferPool struct {
+ v interface{}
+}
+
+func (p *simpleBufferPool) Get() interface{} {
+ v := p.v
+ p.v = nil
+ return v
+}
+
+func (p *simpleBufferPool) Put(v interface{}) {
+ p.v = v
+}
+
+func TestWriteBufferPool(t *testing.T) {
+ const message = "Now is the time for all good people to come to the aid of the party."
+
+ var buf bytes.Buffer
+ var pool simpleBufferPool
+ rc := newTestConn(&buf, nil, false)
+
+ // Specify writeBufferSize smaller than message size to ensure that pooling
+ // works with fragmented messages.
+ wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, len(message)-1, &pool, nil, nil)
+
+ if wc.writeBuf != nil {
+ t.Fatal("writeBuf not nil after create")
+ }
+
+ // Part 1: test NextWriter/Write/Close
+
+ w, err := wc.NextWriter(TextMessage)
+ if err != nil {
+ t.Fatalf("wc.NextWriter() returned %v", err)
+ }
+
+ if wc.writeBuf == nil {
+ t.Fatal("writeBuf is nil after NextWriter")
+ }
+
+ writeBufAddr := &wc.writeBuf[0]
+
+ if _, err := io.WriteString(w, message); err != nil {
+ t.Fatalf("io.WriteString(w, message) returned %v", err)
+ }
+
+ if err := w.Close(); err != nil {
+ t.Fatalf("w.Close() returned %v", err)
+ }
+
+ if wc.writeBuf != nil {
+ t.Fatal("writeBuf not nil after w.Close()")
+ }
+
+ if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
+ t.Fatal("writeBuf not returned to pool")
+ }
+
+ opCode, p, err := rc.ReadMessage()
+ if opCode != TextMessage || err != nil {
+ t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
+ }
+
+ if s := string(p); s != message {
+ t.Fatalf("message is %s, want %s", s, message)
+ }
+
+ // Part 2: Test WriteMessage.
+
+ if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil {
+ t.Fatalf("wc.WriteMessage() returned %v", err)
+ }
+
+ if wc.writeBuf != nil {
+ t.Fatal("writeBuf not nil after wc.WriteMessage()")
+ }
+
+ if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
+ t.Fatal("writeBuf not returned to pool after WriteMessage")
+ }
+
+ opCode, p, err = rc.ReadMessage()
+ if opCode != TextMessage || err != nil {
+ t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
+ }
+
+ if s := string(p); s != message {
+ t.Fatalf("message is %s, want %s", s, message)
+ }
+}
+
+// TestWriteBufferPoolSync ensures that *sync.Pool works as a buffer pool.
+func TestWriteBufferPoolSync(t *testing.T) {
+ var buf bytes.Buffer
+ var pool sync.Pool
+ wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil)
+ rc := newTestConn(&buf, nil, false)
+
+ const message = "Hello World!"
+ for i := 0; i < 3; i++ {
+ if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil {
+ t.Fatalf("wc.WriteMessage() returned %v", err)
+ }
+ opCode, p, err := rc.ReadMessage()
+ if opCode != TextMessage || err != nil {
+ t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
+ }
+ if s := string(p); s != message {
+ t.Fatalf("message is %s, want %s", s, message)
+ }
+ }
+}
+
+// errorWriter is an io.Writer than returns an error on all writes.
+type errorWriter struct{}
+
+func (ew errorWriter) Write(p []byte) (int, error) { return 0, errors.New("error") }
+
+// TestWriteBufferPoolError ensures that buffer is returned to pool after error
+// on write.
+func TestWriteBufferPoolError(t *testing.T) {
+
+ // Part 1: Test NextWriter/Write/Close
+
+ var pool simpleBufferPool
+ wc := newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil)
+
+ w, err := wc.NextWriter(TextMessage)
+ if err != nil {
+ t.Fatalf("wc.NextWriter() returned %v", err)
+ }
+
+ if wc.writeBuf == nil {
+ t.Fatal("writeBuf is nil after NextWriter")
+ }
+
+ writeBufAddr := &wc.writeBuf[0]
+
+ if _, err := io.WriteString(w, "Hello"); err != nil {
+ t.Fatalf("io.WriteString(w, message) returned %v", err)
+ }
+
+ if err := w.Close(); err == nil {
+ t.Fatalf("w.Close() did not return error")
+ }
+
+ if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
+ t.Fatal("writeBuf not returned to pool")
+ }
+
+ // Part 2: Test WriteMessage
+
+ wc = newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil)
+
+ if err := wc.WriteMessage(TextMessage, []byte("Hello")); err == nil {
+ t.Fatalf("wc.WriteMessage did not return error")
+ }
+
+ if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
+ t.Fatal("writeBuf not returned to pool")
+ }
+}
+
+func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
+ const bufSize = 512
+
+ expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"}
+
+ var b1, b2 bytes.Buffer
+ wc := newConn(&fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize, nil, nil, nil)
+ rc := newTestConn(&b1, &b2, true)
+
+ w, _ := wc.NextWriter(BinaryMessage)
+ _, _ = w.Write(make([]byte, bufSize+bufSize/2))
+ _ = wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second))
+ w.Close()
+
+ op, r, err := rc.NextReader()
+ if op != BinaryMessage || err != nil {
+ t.Fatalf("NextReader() returned %d, %v", op, err)
+ }
+ _, err = io.Copy(io.Discard, r)
+ if !reflect.DeepEqual(err, expectedErr) {
+ t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr)
+ }
+ _, _, err = rc.NextReader()
+ if !reflect.DeepEqual(err, expectedErr) {
+ t.Fatalf("NextReader() returned %v, want %v", err, expectedErr)
+ }
+}
+
+func TestEOFWithinFrame(t *testing.T) {
+ const bufSize = 64
+
+ for n := 0; ; n++ {
+ var b bytes.Buffer
+ wc := newTestConn(nil, &b, false)
+ rc := newTestConn(&b, nil, true)
+
+ w, _ := wc.NextWriter(BinaryMessage)
+ _, _ = w.Write(make([]byte, bufSize))
+ w.Close()
+
+ if n >= b.Len() {
+ break
+ }
+ b.Truncate(n)
+
+ op, r, err := rc.NextReader()
+ if err == errUnexpectedEOF {
+ continue
+ }
+ if op != BinaryMessage || err != nil {
+ t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
+ }
+ _, err = io.Copy(io.Discard, r)
+ if err != errUnexpectedEOF {
+ t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
+ }
+ _, _, err = rc.NextReader()
+ if err != errUnexpectedEOF {
+ t.Fatalf("%d: NextReader() returned %v, want %v", n, err, errUnexpectedEOF)
+ }
+ }
+}
+
+func TestEOFBeforeFinalFrame(t *testing.T) {
+ const bufSize = 512
+
+ var b1, b2 bytes.Buffer
+ wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, bufSize, nil, nil, nil)
+ rc := newTestConn(&b1, &b2, true)
+
+ w, _ := wc.NextWriter(BinaryMessage)
+ _, _ = w.Write(make([]byte, bufSize+bufSize/2))
+
+ op, r, err := rc.NextReader()
+ if op != BinaryMessage || err != nil {
+ t.Fatalf("NextReader() returned %d, %v", op, err)
+ }
+ _, err = io.Copy(io.Discard, r)
+ if err != errUnexpectedEOF {
+ t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
+ }
+ _, _, err = rc.NextReader()
+ if err != errUnexpectedEOF {
+ t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF)
+ }
+}
+
+func TestWriteAfterMessageWriterClose(t *testing.T) {
+ wc := newTestConn(nil, &bytes.Buffer{}, false)
+ w, _ := wc.NextWriter(BinaryMessage)
+ _, _ = io.WriteString(w, "hello")
+ if err := w.Close(); err != nil {
+ t.Fatalf("unexpected error closing message writer, %v", err)
+ }
+
+ if _, err := io.WriteString(w, "world"); err == nil {
+ t.Fatalf("no error writing after close")
+ }
+
+ w, _ = wc.NextWriter(BinaryMessage)
+ _, _ = io.WriteString(w, "hello")
+
+ // close w by getting next writer
+ _, err := wc.NextWriter(BinaryMessage)
+ if err != nil {
+ t.Fatalf("unexpected error getting next writer, %v", err)
+ }
+
+ if _, err := io.WriteString(w, "world"); err == nil {
+ t.Fatalf("no error writing after close")
+ }
+}
+
+func TestReadLimit(t *testing.T) {
+ t.Run("Test ReadLimit is enforced", func(t *testing.T) {
+ const readLimit = 512
+ message := make([]byte, readLimit+1)
+
+ var b1, b2 bytes.Buffer
+ wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil)
+ rc := newTestConn(&b1, &b2, true)
+ rc.SetReadLimit(readLimit)
+
+ // Send message at the limit with interleaved pong.
+ w, _ := wc.NextWriter(BinaryMessage)
+ _, _ = w.Write(message[:readLimit-1])
+ _ = wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
+ _, _ = w.Write(message[:1])
+ w.Close()
+
+ // Send message larger than the limit.
+ _ = wc.WriteMessage(BinaryMessage, message[:readLimit+1])
+
+ op, _, err := rc.NextReader()
+ if op != BinaryMessage || err != nil {
+ t.Fatalf("1: NextReader() returned %d, %v", op, err)
+ }
+ op, r, err := rc.NextReader()
+ if op != BinaryMessage || err != nil {
+ t.Fatalf("2: NextReader() returned %d, %v", op, err)
+ }
+ _, err = io.Copy(io.Discard, r)
+ if err != ErrReadLimit {
+ t.Fatalf("io.Copy() returned %v", err)
+ }
+ })
+
+ t.Run("Test that ReadLimit cannot be overflowed", func(t *testing.T) {
+ const readLimit = 1
+
+ var b1, b2 bytes.Buffer
+ rc := newTestConn(&b1, &b2, true)
+ rc.SetReadLimit(readLimit)
+
+ // First, send a non-final binary message
+ b1.Write([]byte("\x02\x81"))
+
+ // Mask key
+ b1.Write([]byte("\x00\x00\x00\x00"))
+
+ // First payload
+ b1.Write([]byte("A"))
+
+ // Next, send a negative-length, non-final continuation frame
+ b1.Write([]byte("\x00\xFF\x80\x00\x00\x00\x00\x00\x00\x00"))
+
+ // Mask key
+ b1.Write([]byte("\x00\x00\x00\x00"))
+
+ // Next, send a too long, final continuation frame
+ b1.Write([]byte("\x80\xFF\x00\x00\x00\x00\x00\x00\x00\x05"))
+
+ // Mask key
+ b1.Write([]byte("\x00\x00\x00\x00"))
+
+ // Too-long payload
+ b1.Write([]byte("BCDEF"))
+
+ op, r, err := rc.NextReader()
+ if op != BinaryMessage || err != nil {
+ t.Fatalf("1: NextReader() returned %d, %v", op, err)
+ }
+
+ var buf [10]byte
+ var read int
+ n, err := r.Read(buf[:])
+ if err != nil && err != ErrReadLimit {
+ t.Fatalf("unexpected error testing read limit: %v", err)
+ }
+ read += n
+
+ n, err = r.Read(buf[:])
+ if err != nil && err != ErrReadLimit {
+ t.Fatalf("unexpected error testing read limit: %v", err)
+ }
+ read += n
+
+ if err == nil && read > readLimit {
+ t.Fatalf("read limit exceeded: limit %d, read %d", readLimit, read)
+ }
+ })
+}
+
+func TestAddrs(t *testing.T) {
+ c := newTestConn(nil, nil, true)
+ if c.LocalAddr() != localAddr {
+ t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr)
+ }
+ if c.RemoteAddr() != remoteAddr {
+ t.Errorf("RemoteAddr = %v, want %v", c.RemoteAddr(), remoteAddr)
+ }
+}
+
+func TestDeprecatedUnderlyingConn(t *testing.T) {
+ var b1, b2 bytes.Buffer
+ fc := fakeNetConn{Reader: &b1, Writer: &b2}
+ c := newConn(fc, true, 1024, 1024, nil, nil, nil)
+ ul := c.UnderlyingConn()
+ if ul != fc {
+ t.Fatalf("Underlying conn is not what it should be.")
+ }
+}
+
+func TestNetConn(t *testing.T) {
+ var b1, b2 bytes.Buffer
+ fc := fakeNetConn{Reader: &b1, Writer: &b2}
+ c := newConn(fc, true, 1024, 1024, nil, nil, nil)
+ ul := c.NetConn()
+ if ul != fc {
+ t.Fatalf("Underlying conn is not what it should be.")
+ }
+}
+
+func TestBufioReadBytes(t *testing.T) {
+ // Test calling bufio.ReadBytes for value longer than read buffer size.
+
+ m := make([]byte, 512)
+ m[len(m)-1] = '\n'
+
+ var b1, b2 bytes.Buffer
+ wc := newConn(fakeNetConn{Writer: &b1}, false, len(m)+64, len(m)+64, nil, nil, nil)
+ rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil)
+
+ w, _ := wc.NextWriter(BinaryMessage)
+ _, _ = w.Write(m)
+ w.Close()
+
+ op, r, err := rc.NextReader()
+ if op != BinaryMessage || err != nil {
+ t.Fatalf("NextReader() returned %d, %v", op, err)
+ }
+
+ br := bufio.NewReader(r)
+ p, err := br.ReadBytes('\n')
+ if err != nil {
+ t.Fatalf("ReadBytes() returned %v", err)
+ }
+ if len(p) != len(m) {
+ t.Fatalf("read returned %d bytes, want %d bytes", len(p), len(m))
+ }
+}
+
+var closeErrorTests = []struct {
+ err error
+ codes []int
+ ok bool
+}{
+ {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, true},
+ {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, false},
+ {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, true},
+ {errors.New("hello"), []int{CloseNormalClosure}, false},
+}
+
+func TestCloseError(t *testing.T) {
+ for _, tt := range closeErrorTests {
+ ok := IsCloseError(tt.err, tt.codes...)
+ if ok != tt.ok {
+ t.Errorf("IsCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
+ }
+ }
+}
+
+var unexpectedCloseErrorTests = []struct {
+ err error
+ codes []int
+ ok bool
+}{
+ {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, false},
+ {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, true},
+ {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, false},
+ {errors.New("hello"), []int{CloseNormalClosure}, false},
+}
+
+func TestUnexpectedCloseErrors(t *testing.T) {
+ for _, tt := range unexpectedCloseErrorTests {
+ ok := IsUnexpectedCloseError(tt.err, tt.codes...)
+ if ok != tt.ok {
+ t.Errorf("IsUnexpectedCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
+ }
+ }
+}
+
+type blockingWriter struct {
+ c1, c2 chan struct{}
+}
+
+func (w blockingWriter) Write(p []byte) (int, error) {
+ // Allow main to continue
+ close(w.c1)
+ // Wait for panic in main
+ <-w.c2
+ return len(p), nil
+}
+
+func TestConcurrentWritePanic(t *testing.T) {
+ w := blockingWriter{make(chan struct{}), make(chan struct{})}
+ c := newTestConn(nil, w, false)
+ go func() {
+ _ = c.WriteMessage(TextMessage, []byte{})
+ }()
+
+ // wait for goroutine to block in write.
+ <-w.c1
+
+ defer func() {
+ close(w.c2)
+ if v := recover(); v != nil {
+ return
+ }
+ }()
+
+ _ = c.WriteMessage(TextMessage, []byte{})
+ t.Fatal("should not get here")
+}
+
+type failingReader struct{}
+
+func (r failingReader) Read(p []byte) (int, error) {
+ return 0, io.EOF
+}
+
+func TestFailedConnectionReadPanic(t *testing.T) {
+ c := newTestConn(failingReader{}, nil, false)
+
+ defer func() {
+ if v := recover(); v != nil {
+ return
+ }
+ }()
+
+ for i := 0; i < 20000; i++ {
+ _, _, _ = c.ReadMessage()
+ }
+ t.Fatal("should not get here")
+}
+// Copyright 2019 The Gorilla WebSocket Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+
+func TestJoinMessages(t *testing.T) {
+ messages := []string{"a", "bc", "def", "ghij", "klmno", "0", "12", "345", "6789"}
+ for _, readChunk := range []int{1, 2, 3, 4, 5, 6, 7} {
+ for _, term := range []string{"", ","} {
+ var connBuf bytes.Buffer
+ wc := newTestConn(nil, &connBuf, true)
+ rc := newTestConn(&connBuf, nil, false)
+ for _, m := range messages {
+ _ = wc.WriteMessage(BinaryMessage, []byte(m))
+ }
+
+ var result bytes.Buffer
+ _, err := io.CopyBuffer(&result, JoinMessages(rc, term), make([]byte, readChunk))
+ if IsUnexpectedCloseError(err, CloseAbnormalClosure) {
+ t.Errorf("readChunk=%d, term=%q: unexpected error %v", readChunk, term, err)
+ }
+ want := strings.Join(messages, term) + term
+ if result.String() != want {
+ t.Errorf("readChunk=%d, term=%q, got %q, want %q", readChunk, term, result.String(), want)
+ }
+ }
+ }
+}
+// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+
+func TestJSON(t *testing.T) {
+ var buf bytes.Buffer
+ wc := newTestConn(nil, &buf, true)
+ rc := newTestConn(&buf, nil, false)
+
+ var actual, expect struct {
+ A int
+ B string
+ }
+ expect.A = 1
+ expect.B = "hello"
+
+ if err := wc.WriteJSON(&expect); err != nil {
+ t.Fatal("write", err)
+ }
+
+ if err := rc.ReadJSON(&actual); err != nil {
+ t.Fatal("read", err)
+ }
+
+ if !reflect.DeepEqual(&actual, &expect) {
+ t.Fatal("equal", actual, expect)
+ }
+}
+
+func TestPartialJSONRead(t *testing.T) {
+ var buf0, buf1 bytes.Buffer
+ wc := newTestConn(nil, &buf0, true)
+ rc := newTestConn(&buf0, &buf1, false)
+
+ var v struct {
+ A int
+ B string
+ }
+ v.A = 1
+ v.B = "hello"
+
+ messageCount := 0
+
+ // Partial JSON values.
+
+ data, err := json.Marshal(v)
+ if err != nil {
+ t.Fatal(err)
+ }
+ for i := len(data) - 1; i >= 0; i-- {
+ if err := wc.WriteMessage(TextMessage, data[:i]); err != nil {
+ t.Fatal(err)
+ }
+ messageCount++
+ }
+
+ // Whitespace.
+
+ if err := wc.WriteMessage(TextMessage, []byte(" ")); err != nil {
+ t.Fatal(err)
+ }
+ messageCount++
+
+ // Close.
+
+ if err := wc.WriteMessage(CloseMessage, FormatCloseMessage(CloseNormalClosure, "")); err != nil {
+ t.Fatal(err)
+ }
+
+ for i := 0; i < messageCount; i++ {
+ err := rc.ReadJSON(&v)
+ if err != io.ErrUnexpectedEOF {
+ t.Error("read", i, err)
+ }
+ }
+
+ err = rc.ReadJSON(&v)
+ if _, ok := err.(*CloseError); !ok {
+ t.Error("final", err)
+ }
+}
+
+func TestDeprecatedJSON(t *testing.T) {
+ var buf bytes.Buffer
+ wc := newTestConn(nil, &buf, true)
+ rc := newTestConn(&buf, nil, false)
+
+ var actual, expect struct {
+ A int
+ B string
+ }
+ expect.A = 1
+ expect.B = "hello"
+
+ if err := WriteJSON(wc, &expect); err != nil {
+ t.Fatal("write", err)
+ }
+
+ if err := ReadJSON(rc, &actual); err != nil {
+ t.Fatal("read", err)
+ }
+
+ if !reflect.DeepEqual(&actual, &expect) {
+ t.Fatal("equal", actual, expect)
+ }
+}
+// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of
+// this source code is governed by a BSD-style license that can be found in the
+// LICENSE file.
+
+// !appengine
+
+
+func maskBytesByByte(key [4]byte, pos int, b []byte) int {
+ for i := range b {
+ b[i] ^= key[pos&3]
+ pos++
+ }
+ return pos & 3
+}
+
+func notzero(b []byte) int {
+ for i := range b {
+ if b[i] != 0 {
+ return i
+ }
+ }
+ return -1
+}
+
+func TestMaskBytes(t *testing.T) {
+ key := [4]byte{1, 2, 3, 4}
+ for size := 1; size <= 1024; size++ {
+ for align := 0; align < wordSize; align++ {
+ for pos := 0; pos < 4; pos++ {
+ b := make([]byte, size+align)[align:]
+ maskBytes(key, pos, b)
+ maskBytesByByte(key, pos, b)
+ if i := notzero(b); i >= 0 {
+ t.Errorf("size:%d, align:%d, pos:%d, offset:%d", size, align, pos, i)
+ }
+ }
+ }
+ }
+}
+
+func BenchmarkMaskBytes(b *testing.B) {
+ for _, size := range []int{2, 4, 8, 16, 32, 512, 1024} {
+ b.Run(fmt.Sprintf("size-%d", size), func(b *testing.B) {
+ for _, align := range []int{wordSize / 2} {
+ b.Run(fmt.Sprintf("align-%d", align), func(b *testing.B) {
+ for _, fn := range []struct {
+ name string
+ fn func(key [4]byte, pos int, b []byte) int
+ }{
+ {"byte", maskBytesByByte},
+ {"word", maskBytes},
+ } {
+ b.Run(fn.name, func(b *testing.B) {
+ key := newMaskKey()
+ data := make([]byte, size+align)[align:]
+ for i := 0; i < b.N; i++ {
+ fn.fn(key, 0, data)
+ }
+ b.SetBytes(int64(len(data)))
+ })
+ }
+ })
+ }
+ })
+ }
+}
+// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+
+var preparedMessageTests = []struct {
+ messageType int
+ isServer bool
+ enableWriteCompression bool
+ compressionLevel int
+}{
+ // Server
+ {TextMessage, true, false, flate.BestSpeed},
+ {TextMessage, true, true, flate.BestSpeed},
+ {TextMessage, true, true, flate.BestCompression},
+ {PingMessage, true, false, flate.BestSpeed},
+ {PingMessage, true, true, flate.BestSpeed},
+
+ // Client
+ {TextMessage, false, false, flate.BestSpeed},
+ {TextMessage, false, true, flate.BestSpeed},
+ {TextMessage, false, true, flate.BestCompression},
+ {PingMessage, false, false, flate.BestSpeed},
+ {PingMessage, false, true, flate.BestSpeed},
+}
+
+func TestPreparedMessage(t *testing.T) {
+ testRand := rand.New(rand.NewSource(99))
+ prevMaskRand := maskRand
+ maskRand = testRand
+ defer func() { maskRand = prevMaskRand }()
+
+ for _, tt := range preparedMessageTests {
+ var data = []byte("this is a test")
+ var buf bytes.Buffer
+ c := newTestConn(nil, &buf, tt.isServer)
+ if tt.enableWriteCompression {
+ c.newCompressionWriter = compressNoContextTakeover
+ }
+ if err := c.SetCompressionLevel(tt.compressionLevel); err != nil {
+ t.Fatal(err)
+ }
+
+ // Seed random number generator for consistent frame mask.
+ testRand.Seed(1234)
+
+ if err := c.WriteMessage(tt.messageType, data); err != nil {
+ t.Fatal(err)
+ }
+ want := buf.String()
+
+ pm, err := NewPreparedMessage(tt.messageType, data)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Scribble on data to ensure that NewPreparedMessage takes a snapshot.
+ copy(data, "hello world")
+
+ // Seed random number generator for consistent frame mask.
+ testRand.Seed(1234)
+
+ buf.Reset()
+ if err := c.WritePreparedMessage(pm); err != nil {
+ t.Fatal(err)
+ }
+ got := buf.String()
+
+ if got != want {
+ t.Errorf("write message != prepared message for %+v", tt)
+ }
+ }
+}
+// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+
+var subprotocolTests = []struct {
+ h string
+ protocols []string
+}{
+ {"", nil},
+ {"foo", []string{"foo"}},
+ {"foo,bar", []string{"foo", "bar"}},
+ {"foo, bar", []string{"foo", "bar"}},
+ {" foo, bar", []string{"foo", "bar"}},
+ {" foo, bar ", []string{"foo", "bar"}},
+}
+
+func TestSubprotocols(t *testing.T) {
+ for _, st := range subprotocolTests {
+ r := http.Request{Header: http.Header{"Sec-Websocket-Protocol": {st.h}}}
+ protocols := Subprotocols(&r)
+ if !reflect.DeepEqual(st.protocols, protocols) {
+ t.Errorf("SubProtocols(%q) returned %#v, want %#v", st.h, protocols, st.protocols)
+ }
+ }
+}
+
+var isWebSocketUpgradeTests = []struct {
+ ok bool
+ h http.Header
+}{
+ {false, http.Header{"Upgrade": {"websocket"}}},
+ {false, http.Header{"Connection": {"upgrade"}}},
+ {true, http.Header{"Connection": {"upgRade"}, "Upgrade": {"WebSocket"}}},
+}
+
+func TestIsWebSocketUpgrade(t *testing.T) {
+ for _, tt := range isWebSocketUpgradeTests {
+ ok := IsWebSocketUpgrade(&http.Request{Header: tt.h})
+ if tt.ok != ok {
+ t.Errorf("IsWebSocketUpgrade(%v) returned %v, want %v", tt.h, ok, tt.ok)
+ }
+ }
+}
+
+func TestSubProtocolSelection(t *testing.T) {
+ upgrader := Upgrader{
+ Subprotocols: []string{"foo", "bar", "baz"},
+ }
+
+ r := http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"foo", "bar"}}}
+ s := upgrader.selectSubprotocol(&r, nil)
+ if s != "foo" {
+ t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "foo")
+ }
+
+ r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"bar", "foo"}}}
+ s = upgrader.selectSubprotocol(&r, nil)
+ if s != "bar" {
+ t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "bar")
+ }
+
+ r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"baz"}}}
+ s = upgrader.selectSubprotocol(&r, nil)
+ if s != "baz" {
+ t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "baz")
+ }
+
+ r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"quux"}}}
+ s = upgrader.selectSubprotocol(&r, nil)
+ if s != "" {
+ t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "empty string")
+ }
+}
+
+var checkSameOriginTests = []struct {
+ ok bool
+ r *http.Request
+}{
+ {false, &http.Request{Host: "example.org", Header: map[string][]string{"Origin": {"https://other.org"}}}},
+ {true, &http.Request{Host: "example.org", Header: map[string][]string{"Origin": {"https://example.org"}}}},
+ {true, &http.Request{Host: "Example.org", Header: map[string][]string{"Origin": {"https://example.org"}}}},
+}
+
+func TestCheckSameOrigin(t *testing.T) {
+ for _, tt := range checkSameOriginTests {
+ ok := checkSameOrigin(tt.r)
+ if tt.ok != ok {
+ t.Errorf("checkSameOrigin(%+v) returned %v, want %v", tt.r, ok, tt.ok)
+ }
+ }
+}
+
+type reuseTestResponseWriter struct {
+ brw *bufio.ReadWriter
+ http.ResponseWriter
+}
+
+func (resp *reuseTestResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+ return fakeNetConn{strings.NewReader(""), &bytes.Buffer{}}, resp.brw, nil
+}
+
+var bufioReuseTests = []struct {
+ n int
+ reuse bool
+}{
+ {4096, true},
+ {128, false},
+}
+
+func xTestBufioReuse(t *testing.T) {
+ for i, tt := range bufioReuseTests {
+ br := bufio.NewReaderSize(strings.NewReader(""), tt.n)
+ bw := bufio.NewWriterSize(&bytes.Buffer{}, tt.n)
+ resp := &reuseTestResponseWriter{
+ brw: bufio.NewReadWriter(br, bw),
+ }
+ upgrader := Upgrader{}
+ c, err := upgrader.Upgrade(resp, &http.Request{
+ Method: http.MethodGet,
+ Header: http.Header{
+ "Upgrade": []string{"websocket"},
+ "Connection": []string{"upgrade"},
+ "Sec-Websocket-Key": []string{"dGhlIHNhbXBsZSBub25jZQ=="},
+ "Sec-Websocket-Version": []string{"13"},
+ }}, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if reuse := c.br == br; reuse != tt.reuse {
+ t.Errorf("%d: buffered reader reuse=%v, want %v", i, reuse, tt.reuse)
+ }
+ writeBuf := bw.AvailableBuffer()
+ if reuse := &c.writeBuf[0] == &writeBuf[0]; reuse != tt.reuse {
+ t.Errorf("%d: write buffer reuse=%v, want %v", i, reuse, tt.reuse)
+ }
+ }
+}
+
+func TestHijack_NotSupported(t *testing.T) {
+ t.Parallel()
+
+ req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
+ req.Header.Set("Upgrade", "websocket")
+ req.Header.Set("Connection", "upgrade")
+ req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
+ req.Header.Set("Sec-Websocket-Version", "13")
+
+ recorder := httptest.NewRecorder()
+
+ upgrader := Upgrader{}
+ _, err := upgrader.Upgrade(recorder, req, nil)
+
+ if want := (HandshakeError{}); !errors.As(err, &want) || recorder.Code != http.StatusInternalServerError {
+ t.Errorf("want %T and status_code=%d", want, http.StatusInternalServerError)
+ t.Fatalf("got err=%T and status_code=%d", err, recorder.Code)
+ }
+}
+// Copyright 2014 The Gorilla WebSocket Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+
+var equalASCIIFoldTests = []struct {
+ t, s string
+ eq bool
+}{
+ {"WebSocket", "websocket", true},
+ {"websocket", "WebSocket", true},
+ {"Öyster", "öyster", false},
+ {"WebSocket", "WetSocket", false},
+}
+
+func TestEqualASCIIFold(t *testing.T) {
+ for _, tt := range equalASCIIFoldTests {
+ eq := equalASCIIFold(tt.s, tt.t)
+ if eq != tt.eq {
+ t.Errorf("equalASCIIFold(%q, %q) = %v, want %v", tt.s, tt.t, eq, tt.eq)
+ }
+ }
+}
+
+var tokenListContainsValueTests = []struct {
+ value string
+ ok bool
+}{
+ {"WebSocket", true},
+ {"WEBSOCKET", true},
+ {"websocket", true},
+ {"websockets", false},
+ {"x websocket", false},
+ {"websocket x", false},
+ {"other,websocket,more", true},
+ {"other, websocket, more", true},
+}
+
+func TestTokenListContainsValue(t *testing.T) {
+ for _, tt := range tokenListContainsValueTests {
+ h := http.Header{"Upgrade": {tt.value}}
+ ok := tokenListContainsValue(h, "Upgrade", "websocket")
+ if ok != tt.ok {
+ t.Errorf("tokenListContainsValue(h, n, %q) = %v, want %v", tt.value, ok, tt.ok)
+ }
+ }
+}
+
+var isValidChallengeKeyTests = []struct {
+ key string
+ ok bool
+}{
+ {"dGhlIHNhbXBsZSBub25jZQ==", true},
+ {"", false},
+ {"InvalidKey", false},
+ {"WHQ4eXhscUtKYjBvOGN3WEdtOEQ=", false},
+}
+
+func TestIsValidChallengeKey(t *testing.T) {
+ for _, tt := range isValidChallengeKeyTests {
+ ok := isValidChallengeKey(tt.key)
+ if ok != tt.ok {
+ t.Errorf("isValidChallengeKey returns %v, want %v", ok, tt.ok)
+ }
+ }
+}
+
+var parseExtensionTests = []struct {
+ value string
+ extensions []map[string]string
+}{
+ {`foo`, []map[string]string{{"": "foo"}}},
+ {`foo, bar; baz=2`, []map[string]string{
+ {"": "foo"},
+ {"": "bar", "baz": "2"}}},
+ {`foo; bar="b,a;z"`, []map[string]string{
+ {"": "foo", "bar": "b,a;z"}}},
+ {`foo , bar; baz = 2`, []map[string]string{
+ {"": "foo"},
+ {"": "bar", "baz": "2"}}},
+ {`foo, bar; baz=2 junk`, []map[string]string{
+ {"": "foo"}}},
+ {`foo junk, bar; baz=2 junk`, nil},
+ {`mux; max-channels=4; flow-control, deflate-stream`, []map[string]string{
+ {"": "mux", "max-channels": "4", "flow-control": ""},
+ {"": "deflate-stream"}}},
+ {`permessage-foo; x="10"`, []map[string]string{
+ {"": "permessage-foo", "x": "10"}}},
+ {`permessage-foo; use_y, permessage-foo`, []map[string]string{
+ {"": "permessage-foo", "use_y": ""},
+ {"": "permessage-foo"}}},
+ {`permessage-deflate; client_max_window_bits; server_max_window_bits=10 , permessage-deflate; client_max_window_bits`, []map[string]string{
+ {"": "permessage-deflate", "client_max_window_bits": "", "server_max_window_bits": "10"},
+ {"": "permessage-deflate", "client_max_window_bits": ""}}},
+ {"permessage-deflate; server_no_context_takeover; client_max_window_bits=15", []map[string]string{
+ {"": "permessage-deflate", "server_no_context_takeover": "", "client_max_window_bits": "15"},
+ }},
+}
+
+func TestParseExtensions(t *testing.T) {
+ for _, tt := range parseExtensionTests {
+ h := http.Header{http.CanonicalHeaderKey("Sec-WebSocket-Extensions"): {tt.value}}
+ extensions := parseExtensions(h)
+ if !reflect.DeepEqual(extensions, tt.extensions) {
+ t.Errorf("parseExtensions(%q)\n = %v,\nwant %v", tt.value, extensions, tt.extensions)
+ }
+ }
+}
+
+func TestParseArgs(t *testing.T) {
+ given := ParseArgs([]string { "x", "y", "z" })
+ expected := CLIArgs {
+ FromAddr: "y",
+ ToAddr: "z",
+ }
+
+ g.AssertEqual(given, expected)
+}
+
+
+
+func MainTest() {
+ tests := []testing.InternalTest {
+ { "TestProxyDial", TestProxyDial },
+ { "TestProxyAuthorizationDial", TestProxyAuthorizationDial },
+ { "TestDial", TestDial },
+ { "TestDialCookieJar", TestDialCookieJar },
+ { "TestDialTLS", TestDialTLS },
+ { "TestDialTimeout", TestDialTimeout },
+ { "TestHandshakeTimeout", TestHandshakeTimeout },
+ { "TestHandshakeTimeoutInContext", TestHandshakeTimeoutInContext },
+ { "TestDialBadScheme", TestDialBadScheme },
+ { "TestDialBadOrigin", TestDialBadOrigin },
+ { "TestDialBadHeader", TestDialBadHeader },
+ { "TestBadMethod", TestBadMethod },
+ { "TestNoUpgrade", TestNoUpgrade },
+ { "TestDialExtraTokensInRespHeaders", TestDialExtraTokensInRespHeaders },
+ { "TestHandshake", TestHandshake },
+ { "TestRespOnBadHandshake", TestRespOnBadHandshake },
+ { "TestHost", TestHost },
+ { "TestDialCompression", TestDialCompression },
+ { "TestSocksProxyDial", TestSocksProxyDial },
+ { "TestTracingDialWithContext", TestTracingDialWithContext },
+ { "TestEmptyTracingDialWithContext", TestEmptyTracingDialWithContext },
+ { "TestNetDialConnect", TestNetDialConnect },
+ { "TestNextProtos", TestNextProtos },
+ { "TestDataReceivedBeforeHandshake", TestDataReceivedBeforeHandshake },
+ { "TestHostPortNoPort", TestHostPortNoPort },
+ { "TestTruncWriter", TestTruncWriter },
+ { "TestValidCompressionLevel", TestValidCompressionLevel },
+ { "TestFraming", TestFraming },
+ { "TestWriteControlDeadline", TestWriteControlDeadline },
+ { "TestConcurrencyWriteControl", TestConcurrencyWriteControl },
+ { "TestControl", TestControl },
+ { "TestWriteBufferPool", TestWriteBufferPool },
+ { "TestWriteBufferPoolSync", TestWriteBufferPoolSync },
+ { "TestWriteBufferPoolError", TestWriteBufferPoolError },
+ { "TestCloseFrameBeforeFinalMessageFrame", TestCloseFrameBeforeFinalMessageFrame },
+ { "TestEOFWithinFrame", TestEOFWithinFrame },
+ { "TestEOFBeforeFinalFrame", TestEOFBeforeFinalFrame },
+ { "TestWriteAfterMessageWriterClose", TestWriteAfterMessageWriterClose },
+ { "TestReadLimit", TestReadLimit },
+ { "TestAddrs", TestAddrs },
+ { "TestDeprecatedUnderlyingConn", TestDeprecatedUnderlyingConn },
+ { "TestNetConn", TestNetConn },
+ { "TestBufioReadBytes", TestBufioReadBytes },
+ { "TestCloseError", TestCloseError },
+ { "TestUnexpectedCloseErrors", TestUnexpectedCloseErrors },
+ { "TestConcurrentWritePanic", TestConcurrentWritePanic },
+ { "TestFailedConnectionReadPanic", TestFailedConnectionReadPanic },
+ { "TestJoinMessages", TestJoinMessages },
+ { "TestJSON", TestJSON },
+ { "TestPartialJSONRead", TestPartialJSONRead },
+ { "TestDeprecatedJSON", TestDeprecatedJSON },
+ { "TestMaskBytes", TestMaskBytes },
+ { "TestPreparedMessage", TestPreparedMessage },
+ { "TestSubprotocols", TestSubprotocols },
+ { "TestIsWebSocketUpgrade", TestIsWebSocketUpgrade },
+ { "TestSubProtocolSelection", TestSubProtocolSelection },
+ { "TestCheckSameOrigin", TestCheckSameOrigin },
+ { "TestHijack_NotSupported", TestHijack_NotSupported },
+ { "TestEqualASCIIFold", TestEqualASCIIFold },
+ { "TestTokenListContainsValue", TestTokenListContainsValue },
+ { "TestIsValidChallengeKey", TestIsValidChallengeKey },
+ { "TestParseExtensions", TestParseExtensions },
+ { "TestParseArgs", TestParseArgs },
+ }
+
+ // FIXME: run benchmarks
+ benchmarks := []testing.InternalBenchmark {
+ { "BenchmarkWriteNoCompression", BenchmarkWriteNoCompression },
+ { "BenchmarkWriteWithCompression", BenchmarkWriteWithCompression },
+ { "BenchmarkBroadcast", BenchmarkBroadcast },
+ { "BenchmarkMaskBytes", BenchmarkMaskBytes },
+ }
+
+ fuzzTargets := []testing.InternalFuzzTarget {}
+ examples := []testing.InternalExample {}
+ m := testing.MainStart(
+ testdeps.TestDeps {},
+ tests,
+ benchmarks,
+ fuzzTargets,
+ examples,
+ )
+ os.Exit(m.Run())
+}