package wscat import ( "bufio" "bytes" "compress/flate" "crypto/x509" "encoding/json" "errors" "fmt" "io" "math/rand" "net" "net/http" "net/http/httptest" "net/url" "os" "reflect" "strings" "sync" "sync/atomic" "testing" "testing/internal/testdeps" "testing/iotest" "time" g "gobang" ) var cstUpgrader = Upgrader{ Subprotocols: []string{"p0", "p1"}, ReadBufferSize: 1024, WriteBufferSize: 1024, EnableCompression: true, Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) { http.Error(w, reason.Error(), status) }, } type cstHandler struct { *testing.T s *cstServer } type cstServer struct { URL string Server *httptest.Server wg sync.WaitGroup } const ( cstPath = "/a/b" cstRawQuery = "x=y" cstRequestURI = cstPath + "?" + cstRawQuery ) func (s *cstServer) Close() { s.Server.Close() // Wait for handler functions to complete. s.wg.Wait() } func newServer(t *testing.T) *cstServer { var s cstServer s.Server = httptest.NewServer(cstHandler{T: t, s: &s}) s.Server.URL += cstRequestURI s.URL = makeWsProto(s.Server.URL) return &s } func newTLSServer(t *testing.T) *cstServer { var s cstServer s.Server = httptest.NewTLSServer(cstHandler{T: t, s: &s}) s.Server.URL += cstRequestURI s.URL = makeWsProto(s.Server.URL) return &s } func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Because tests wait for a response from a server, we are guaranteed that // the wait group count is incremented before the test waits on the group // in the call to (*cstServer).Close(). t.s.wg.Add(1) defer t.s.wg.Done() if r.URL.Path != cstPath { t.Logf("path=%v, want %v", r.URL.Path, cstPath) http.Error(w, "bad path", http.StatusBadRequest) return } if r.URL.RawQuery != cstRawQuery { t.Logf("query=%v, want %v", r.URL.RawQuery, cstRawQuery) http.Error(w, "bad path", http.StatusBadRequest) return } ws, err := cstUpgrader.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}}) if err != nil { t.Logf("Upgrade: %v", err) return } defer ws.Close() if ws.Subprotocol() != "p1" { t.Logf("Subprotocol() = %s, want p1", ws.Subprotocol()) ws.Close() return } op, rd, err := ws.NextReader() if err != nil { t.Logf("NextReader: %v", err) return } wr, err := ws.NextWriter(op) if err != nil { t.Logf("NextWriter: %v", err) return } if _, err = io.Copy(wr, rd); err != nil { t.Logf("NextWriter: %v", err) return } if err := wr.Close(); err != nil { t.Logf("Close: %v", err) return } } func makeWsProto(s string) string { return "ws" + strings.TrimPrefix(s, "http") } func sendRecv(t *testing.T, ws *Conn) { const message = "Hello World!" if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil { t.Fatalf("SetWriteDeadline: %v", err) } if err := ws.WriteMessage(TextMessage, []byte(message)); err != nil { t.Fatalf("WriteMessage: %v", err) } if err := ws.SetReadDeadline(time.Now().Add(time.Second)); err != nil { t.Fatalf("SetReadDeadline: %v", err) } _, p, err := ws.ReadMessage() if err != nil { t.Fatalf("ReadMessage: %v", err) } if string(p) != message { t.Fatalf("message=%s, want %s", p, message) } } func rootCAs(t *testing.T, s *httptest.Server) *x509.CertPool { certs := x509.NewCertPool() for _, c := range s.TLS.Certificates { roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1]) if err != nil { t.Fatalf("error parsing server's root cert: %v", err) } for _, root := range roots { certs.AddCert(root) } } return certs } // requireDeadlineNetConn fails the current test when Read or Write are called // with no deadline. type requireDeadlineNetConn struct { t *testing.T c net.Conn readDeadlineIsSet bool writeDeadlineIsSet bool } func (c *requireDeadlineNetConn) SetDeadline(t time.Time) error { c.writeDeadlineIsSet = !t.Equal(time.Time{}) c.readDeadlineIsSet = c.writeDeadlineIsSet return c.c.SetDeadline(t) } func (c *requireDeadlineNetConn) SetReadDeadline(t time.Time) error { c.readDeadlineIsSet = !t.Equal(time.Time{}) return c.c.SetDeadline(t) } func (c *requireDeadlineNetConn) SetWriteDeadline(t time.Time) error { c.writeDeadlineIsSet = !t.Equal(time.Time{}) return c.c.SetDeadline(t) } func (c *requireDeadlineNetConn) Write(p []byte) (int, error) { if !c.writeDeadlineIsSet { c.t.Fatalf("write with no deadline") } return c.c.Write(p) } func (c *requireDeadlineNetConn) Read(p []byte) (int, error) { if !c.readDeadlineIsSet { c.t.Fatalf("read with no deadline") } return c.c.Read(p) } func (c *requireDeadlineNetConn) Close() error { return c.c.Close() } func (c *requireDeadlineNetConn) LocalAddr() net.Addr { return c.c.LocalAddr() } func (c *requireDeadlineNetConn) RemoteAddr() net.Addr { return c.c.RemoteAddr() } func TestBadMethod(t *testing.T) { s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ws, err := cstUpgrader.Upgrade(w, r, nil) if err == nil { t.Errorf("handshake succeeded, expect fail") ws.Close() } })) defer s.Close() req, err := http.NewRequest(http.MethodPost, s.URL, strings.NewReader("")) if err != nil { t.Fatalf("NewRequest returned error %v", err) } req.Header.Set("Connection", "upgrade") req.Header.Set("Upgrade", "websocket") req.Header.Set("Sec-Websocket-Version", "13") resp, err := http.DefaultClient.Do(req) if err != nil { t.Fatalf("Do returned error %v", err) } resp.Body.Close() if resp.StatusCode != http.StatusMethodNotAllowed { t.Errorf("Status = %d, want %d", resp.StatusCode, http.StatusMethodNotAllowed) } } func TestNoUpgrade(t *testing.T) { t.Parallel() s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ws, err := cstUpgrader.Upgrade(w, r, nil) if err == nil { t.Errorf("handshake succeeded, expect fail") ws.Close() } })) defer s.Close() req, err := http.NewRequest(http.MethodGet, s.URL, strings.NewReader("")) if err != nil { t.Fatalf("NewRequest returned error %v", err) } req.Header.Set("Connection", "upgrade") req.Header.Set("Sec-Websocket-Version", "13") resp, err := http.DefaultClient.Do(req) if err != nil { t.Fatalf("Do returned error %v", err) } resp.Body.Close() if u := resp.Header.Get("Upgrade"); u != "websocket" { t.Errorf("Upgrade response header is %q, want %q", u, "websocket") } if resp.StatusCode != http.StatusUpgradeRequired { t.Errorf("Status = %d, want %d", resp.StatusCode, http.StatusUpgradeRequired) } } type testLogWriter struct { t *testing.T } func (w testLogWriter) Write(p []byte) (int, error) { w.t.Logf("%s", p) return len(p), nil } type dataBeforeHandshakeResponseWriter struct { http.ResponseWriter } type dataBeforeHandshakeConnection struct { net.Conn io.Reader } func (c *dataBeforeHandshakeConnection) Read(p []byte) (int, error) { return c.Reader.Read(p) } func (w dataBeforeHandshakeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { // Example single-frame masked text message from section 5.7 of the RFC. message := []byte{0x81, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d, 0x51, 0x58} n := len(message) / 2 c, rw, err := http.NewResponseController(w.ResponseWriter).Hijack() if rw != nil { // Load first part of message into bufio.Reader. If the websocket // connection reads more than n bytes from the bufio.Reader, then the // test will fail with an unexpected EOF error. rw.Reader.Reset(bytes.NewReader(message[:n])) rw.Reader.Peek(n) } if c != nil { // Inject second part of message before data read from the network connection. c = &dataBeforeHandshakeConnection{ Conn: c, Reader: io.MultiReader(bytes.NewReader(message[n:]), c), } } return c, rw, err } var hostPortNoPortTests = []struct { u *url.URL hostPort, hostNoPort string }{ {&url.URL{Scheme: "ws", Host: "example.com"}, "example.com:80", "example.com"}, {&url.URL{Scheme: "wss", Host: "example.com"}, "example.com:443", "example.com"}, {&url.URL{Scheme: "ws", Host: "example.com:7777"}, "example.com:7777", "example.com"}, {&url.URL{Scheme: "wss", Host: "example.com:7777"}, "example.com:7777", "example.com"}, } func TestHostPortNoPort(t *testing.T) { for _, tt := range hostPortNoPortTests { hostPort, hostNoPort := hostPortNoPort(tt.u) if hostPort != tt.hostPort { t.Errorf("hostPortNoPort(%v) returned hostPort %q, want %q", tt.u, hostPort, tt.hostPort) } if hostNoPort != tt.hostNoPort { t.Errorf("hostPortNoPort(%v) returned hostNoPort %q, want %q", tt.u, hostNoPort, tt.hostNoPort) } } } type nopCloser struct{ io.Writer } func (nopCloser) Close() error { return nil } func TestTruncWriter(t *testing.T) { const data = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijlkmnopqrstuvwxyz987654321" for n := 1; n <= 10; n++ { var b bytes.Buffer w := &truncWriter{w: nopCloser{&b}} p := []byte(data) for len(p) > 0 { m := len(p) if m > n { m = n } _, _ = w.Write(p[:m]) p = p[m:] } if b.String() != data[:len(data)-len(w.p)] { t.Errorf("%d: %q", n, b.String()) } } } func textMessages(num int) [][]byte { messages := make([][]byte, num) for i := 0; i < num; i++ { msg := fmt.Sprintf("planet: %d, country: %d, city: %d, street: %d", i, i, i, i) messages[i] = []byte(msg) } return messages } func BenchmarkWriteNoCompression(b *testing.B) { w := io.Discard c := newTestConn(nil, w, false) messages := textMessages(100) b.ResetTimer() for i := 0; i < b.N; i++ { _ = c.WriteMessage(TextMessage, messages[i%len(messages)]) } b.ReportAllocs() } func BenchmarkWriteWithCompression(b *testing.B) { w := io.Discard c := newTestConn(nil, w, false) messages := textMessages(100) c.enableWriteCompression = true c.newCompressionWriter = compressNoContextTakeover b.ResetTimer() for i := 0; i < b.N; i++ { _ = c.WriteMessage(TextMessage, messages[i%len(messages)]) } b.ReportAllocs() } func TestValidCompressionLevel(t *testing.T) { c := newTestConn(nil, nil, false) for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} { if err := c.SetCompressionLevel(level); err == nil { t.Errorf("no error for level %d", level) } } for _, level := range []int{minCompressionLevel, maxCompressionLevel} { if err := c.SetCompressionLevel(level); err != nil { t.Errorf("error for level %d", level) } } } // broadcastBench allows to run broadcast benchmarks. // In every broadcast benchmark we create many connections, then send the same // message into every connection and wait for all writes complete. This emulates // an application where many connections listen to the same data - i.e. PUB/SUB // scenarios with many subscribers in one channel. type broadcastBench struct { w io.Writer closeCh chan struct{} doneCh chan struct{} count int32 conns []*broadcastConn compression bool usePrepared bool } type broadcastMessage struct { payload []byte prepared *PreparedMessage } type broadcastConn struct { conn *Conn msgCh chan *broadcastMessage } func newBroadcastConn(c *Conn) *broadcastConn { return &broadcastConn{ conn: c, msgCh: make(chan *broadcastMessage, 1), } } func newBroadcastBench(usePrepared, compression bool) *broadcastBench { bench := &broadcastBench{ w: io.Discard, doneCh: make(chan struct{}), closeCh: make(chan struct{}), usePrepared: usePrepared, compression: compression, } bench.makeConns(10000) return bench } func (b *broadcastBench) makeConns(numConns int) { conns := make([]*broadcastConn, numConns) for i := 0; i < numConns; i++ { c := newTestConn(nil, b.w, true) if b.compression { c.enableWriteCompression = true c.newCompressionWriter = compressNoContextTakeover } conns[i] = newBroadcastConn(c) go func(c *broadcastConn) { for { select { case msg := <-c.msgCh: if msg.prepared != nil { _ = c.conn.WritePreparedMessage(msg.prepared) } else { _ = c.conn.WriteMessage(TextMessage, msg.payload) } val := atomic.AddInt32(&b.count, 1) if val%int32(numConns) == 0 { b.doneCh <- struct{}{} } case <-b.closeCh: return } } }(conns[i]) } b.conns = conns } func (b *broadcastBench) close() { close(b.closeCh) } func (b *broadcastBench) broadcastOnce(msg *broadcastMessage) { for _, c := range b.conns { c.msgCh <- msg } <-b.doneCh } func BenchmarkBroadcast(b *testing.B) { benchmarks := []struct { name string usePrepared bool compression bool }{ {"NoCompression", false, false}, {"Compression", false, true}, {"NoCompressionPrepared", true, false}, {"CompressionPrepared", true, true}, } payload := textMessages(1)[0] for _, bm := range benchmarks { b.Run(bm.name, func(b *testing.B) { bench := newBroadcastBench(bm.usePrepared, bm.compression) defer bench.close() b.ResetTimer() for i := 0; i < b.N; i++ { message := &broadcastMessage{ payload: payload, } if bench.usePrepared { pm, _ := NewPreparedMessage(TextMessage, message.payload) message.prepared = pm } bench.broadcastOnce(message) } b.ReportAllocs() }) } } var _ net.Error = errWriteTimeout type fakeNetConn struct { io.Reader io.Writer } func (c fakeNetConn) Close() error { return nil } func (c fakeNetConn) LocalAddr() net.Addr { return localAddr } func (c fakeNetConn) RemoteAddr() net.Addr { return remoteAddr } func (c fakeNetConn) SetDeadline(t time.Time) error { return nil } func (c fakeNetConn) SetReadDeadline(t time.Time) error { return nil } func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil } type fakeAddr int var ( localAddr = fakeAddr(1) remoteAddr = fakeAddr(2) ) func (a fakeAddr) Network() string { return "net" } func (a fakeAddr) String() string { return "str" } // newTestConn creates a connection backed by a fake network connection using // default values for buffering. func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn { return newConn(fakeNetConn{Reader: r, Writer: w}, isServer, 1024, 1024, nil, nil, nil) } func TestFraming(t *testing.T) { frameSizes := []int{ 0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, // 65536, 65537 } var readChunkers = []struct { name string f func(io.Reader) io.Reader }{ {"half", iotest.HalfReader}, {"one", iotest.OneByteReader}, {"asis", func(r io.Reader) io.Reader { return r }}, } writeBuf := make([]byte, 65537) for i := range writeBuf { writeBuf[i] = byte(i) } var writers = []struct { name string f func(w io.Writer, n int) (int, error) }{ {"iocopy", func(w io.Writer, n int) (int, error) { nn, err := io.Copy(w, bytes.NewReader(writeBuf[:n])) return int(nn), err }}, {"write", func(w io.Writer, n int) (int, error) { return w.Write(writeBuf[:n]) }}, {"string", func(w io.Writer, n int) (int, error) { return io.WriteString(w, string(writeBuf[:n])) }}, } for _, compress := range []bool{false, true} { for _, isServer := range []bool{true, false} { for _, chunker := range readChunkers { var connBuf bytes.Buffer wc := newTestConn(nil, &connBuf, isServer) rc := newTestConn(chunker.f(&connBuf), nil, !isServer) if compress { wc.newCompressionWriter = compressNoContextTakeover rc.newDecompressionReader = decompressNoContextTakeover } for _, n := range frameSizes { for _, writer := range writers { name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d w:%s", compress, isServer, chunker.name, n, writer.name) w, err := wc.NextWriter(TextMessage) if err != nil { t.Errorf("%s: wc.NextWriter() returned %v", name, err) continue } nn, err := writer.f(w, n) if err != nil || nn != n { t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err) continue } err = w.Close() if err != nil { t.Errorf("%s: w.Close() returned %v", name, err) continue } opCode, r, err := rc.NextReader() if err != nil || opCode != TextMessage { t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err) continue } t.Logf("frame size: %d", n) rbuf, err := io.ReadAll(r) if err != nil { t.Errorf("%s: ReadFull() returned rbuf, %v", name, err) continue } if len(rbuf) != n { t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n) continue } for i, b := range rbuf { if byte(i) != b { t.Errorf("%s: bad byte at offset %d", name, i) break } } } } } } } } func TestWriteControlDeadline(t *testing.T) { t.Parallel() message := []byte("hello") var connBuf bytes.Buffer c := newTestConn(nil, &connBuf, true) if err := c.WriteControl(PongMessage, message, time.Time{}); err != nil { t.Errorf("WriteControl(..., zero deadline) = %v, want nil", err) } if err := c.WriteControl(PongMessage, message, time.Now().Add(time.Second)); err != nil { t.Errorf("WriteControl(..., future deadline) = %v, want nil", err) } if err := c.WriteControl(PongMessage, message, time.Now().Add(-time.Second)); err == nil { t.Errorf("WriteControl(..., past deadline) = nil, want timeout error") } } func TestConcurrencyWriteControl(t *testing.T) { const message = "this is a ping/pong messsage" loop := 10 workers := 10 for i := 0; i < loop; i++ { var connBuf bytes.Buffer wg := sync.WaitGroup{} wc := newTestConn(nil, &connBuf, true) for i := 0; i < workers; i++ { wg.Add(1) go func() { defer wg.Done() if err := wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)); err != nil { t.Errorf("concurrently wc.WriteControl() returned %v", err) } }() } wg.Wait() wc.Close() } } func TestControl(t *testing.T) { const message = "this is a ping/pong message" for _, isServer := range []bool{true, false} { for _, isWriteControl := range []bool{true, false} { name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl) var connBuf bytes.Buffer wc := newTestConn(nil, &connBuf, isServer) rc := newTestConn(&connBuf, nil, !isServer) if isWriteControl { _ = wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)) } else { w, err := wc.NextWriter(PongMessage) if err != nil { t.Errorf("%s: wc.NextWriter() returned %v", name, err) continue } if _, err := w.Write([]byte(message)); err != nil { t.Errorf("%s: w.Write() returned %v", name, err) continue } if err := w.Close(); err != nil { t.Errorf("%s: w.Close() returned %v", name, err) continue } var actualMessage string rc.SetPongHandler(func(s string) error { actualMessage = s; return nil }) _, _, _ = rc.NextReader() if actualMessage != message { t.Errorf("%s: pong=%q, want %q", name, actualMessage, message) continue } } } } } // simpleBufferPool is an implementation of BufferPool for TestWriteBufferPool. type simpleBufferPool struct { v interface{} } func (p *simpleBufferPool) Get() interface{} { v := p.v p.v = nil return v } func (p *simpleBufferPool) Put(v interface{}) { p.v = v } func TestWriteBufferPool(t *testing.T) { const message = "Now is the time for all good people to come to the aid of the party." var buf bytes.Buffer var pool simpleBufferPool rc := newTestConn(&buf, nil, false) // Specify writeBufferSize smaller than message size to ensure that pooling // works with fragmented messages. wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, len(message)-1, &pool, nil, nil) if wc.writeBuf != nil { t.Fatal("writeBuf not nil after create") } // Part 1: test NextWriter/Write/Close w, err := wc.NextWriter(TextMessage) if err != nil { t.Fatalf("wc.NextWriter() returned %v", err) } if wc.writeBuf == nil { t.Fatal("writeBuf is nil after NextWriter") } writeBufAddr := &wc.writeBuf[0] if _, err := io.WriteString(w, message); err != nil { t.Fatalf("io.WriteString(w, message) returned %v", err) } if err := w.Close(); err != nil { t.Fatalf("w.Close() returned %v", err) } if wc.writeBuf != nil { t.Fatal("writeBuf not nil after w.Close()") } if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { t.Fatal("writeBuf not returned to pool") } opCode, p, err := rc.ReadMessage() if opCode != TextMessage || err != nil { t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) } if s := string(p); s != message { t.Fatalf("message is %s, want %s", s, message) } // Part 2: Test WriteMessage. if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil { t.Fatalf("wc.WriteMessage() returned %v", err) } if wc.writeBuf != nil { t.Fatal("writeBuf not nil after wc.WriteMessage()") } if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { t.Fatal("writeBuf not returned to pool after WriteMessage") } opCode, p, err = rc.ReadMessage() if opCode != TextMessage || err != nil { t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) } if s := string(p); s != message { t.Fatalf("message is %s, want %s", s, message) } } // TestWriteBufferPoolSync ensures that *sync.Pool works as a buffer pool. func TestWriteBufferPoolSync(t *testing.T) { var buf bytes.Buffer var pool sync.Pool wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil) rc := newTestConn(&buf, nil, false) const message = "Hello World!" for i := 0; i < 3; i++ { if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil { t.Fatalf("wc.WriteMessage() returned %v", err) } opCode, p, err := rc.ReadMessage() if opCode != TextMessage || err != nil { t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) } if s := string(p); s != message { t.Fatalf("message is %s, want %s", s, message) } } } // errorWriter is an io.Writer than returns an error on all writes. type errorWriter struct{} func (ew errorWriter) Write(p []byte) (int, error) { return 0, errors.New("error") } // TestWriteBufferPoolError ensures that buffer is returned to pool after error // on write. func TestWriteBufferPoolError(t *testing.T) { // Part 1: Test NextWriter/Write/Close var pool simpleBufferPool wc := newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil) w, err := wc.NextWriter(TextMessage) if err != nil { t.Fatalf("wc.NextWriter() returned %v", err) } if wc.writeBuf == nil { t.Fatal("writeBuf is nil after NextWriter") } writeBufAddr := &wc.writeBuf[0] if _, err := io.WriteString(w, "Hello"); err != nil { t.Fatalf("io.WriteString(w, message) returned %v", err) } if err := w.Close(); err == nil { t.Fatalf("w.Close() did not return error") } if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { t.Fatal("writeBuf not returned to pool") } // Part 2: Test WriteMessage wc = newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil) if err := wc.WriteMessage(TextMessage, []byte("Hello")); err == nil { t.Fatalf("wc.WriteMessage did not return error") } if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { t.Fatal("writeBuf not returned to pool") } } func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) { const bufSize = 512 expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"} var b1, b2 bytes.Buffer wc := newConn(&fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize, nil, nil, nil) rc := newTestConn(&b1, &b2, true) w, _ := wc.NextWriter(BinaryMessage) _, _ = w.Write(make([]byte, bufSize+bufSize/2)) _ = wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second)) w.Close() op, r, err := rc.NextReader() if op != BinaryMessage || err != nil { t.Fatalf("NextReader() returned %d, %v", op, err) } _, err = io.Copy(io.Discard, r) if !reflect.DeepEqual(err, expectedErr) { t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr) } _, _, err = rc.NextReader() if !reflect.DeepEqual(err, expectedErr) { t.Fatalf("NextReader() returned %v, want %v", err, expectedErr) } } func TestEOFWithinFrame(t *testing.T) { const bufSize = 64 for n := 0; ; n++ { var b bytes.Buffer wc := newTestConn(nil, &b, false) rc := newTestConn(&b, nil, true) w, _ := wc.NextWriter(BinaryMessage) _, _ = w.Write(make([]byte, bufSize)) w.Close() if n >= b.Len() { break } b.Truncate(n) op, r, err := rc.NextReader() if err == errUnexpectedEOF { continue } if op != BinaryMessage || err != nil { t.Fatalf("%d: NextReader() returned %d, %v", n, op, err) } _, err = io.Copy(io.Discard, r) if err != errUnexpectedEOF { t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF) } _, _, err = rc.NextReader() if err != errUnexpectedEOF { t.Fatalf("%d: NextReader() returned %v, want %v", n, err, errUnexpectedEOF) } } } func TestEOFBeforeFinalFrame(t *testing.T) { const bufSize = 512 var b1, b2 bytes.Buffer wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, bufSize, nil, nil, nil) rc := newTestConn(&b1, &b2, true) w, _ := wc.NextWriter(BinaryMessage) _, _ = w.Write(make([]byte, bufSize+bufSize/2)) op, r, err := rc.NextReader() if op != BinaryMessage || err != nil { t.Fatalf("NextReader() returned %d, %v", op, err) } _, err = io.Copy(io.Discard, r) if err != errUnexpectedEOF { t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF) } _, _, err = rc.NextReader() if err != errUnexpectedEOF { t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF) } } func TestWriteAfterMessageWriterClose(t *testing.T) { wc := newTestConn(nil, &bytes.Buffer{}, false) w, _ := wc.NextWriter(BinaryMessage) _, _ = io.WriteString(w, "hello") if err := w.Close(); err != nil { t.Fatalf("unexpected error closing message writer, %v", err) } if _, err := io.WriteString(w, "world"); err == nil { t.Fatalf("no error writing after close") } w, _ = wc.NextWriter(BinaryMessage) _, _ = io.WriteString(w, "hello") // close w by getting next writer _, err := wc.NextWriter(BinaryMessage) if err != nil { t.Fatalf("unexpected error getting next writer, %v", err) } if _, err := io.WriteString(w, "world"); err == nil { t.Fatalf("no error writing after close") } } func TestReadLimit(t *testing.T) { t.Run("Test ReadLimit is enforced", func(t *testing.T) { const readLimit = 512 message := make([]byte, readLimit+1) var b1, b2 bytes.Buffer wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil) rc := newTestConn(&b1, &b2, true) rc.SetReadLimit(readLimit) // Send message at the limit with interleaved pong. w, _ := wc.NextWriter(BinaryMessage) _, _ = w.Write(message[:readLimit-1]) _ = wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second)) _, _ = w.Write(message[:1]) w.Close() // Send message larger than the limit. _ = wc.WriteMessage(BinaryMessage, message[:readLimit+1]) op, _, err := rc.NextReader() if op != BinaryMessage || err != nil { t.Fatalf("1: NextReader() returned %d, %v", op, err) } op, r, err := rc.NextReader() if op != BinaryMessage || err != nil { t.Fatalf("2: NextReader() returned %d, %v", op, err) } _, err = io.Copy(io.Discard, r) if err != ErrReadLimit { t.Fatalf("io.Copy() returned %v", err) } }) t.Run("Test that ReadLimit cannot be overflowed", func(t *testing.T) { const readLimit = 1 var b1, b2 bytes.Buffer rc := newTestConn(&b1, &b2, true) rc.SetReadLimit(readLimit) // First, send a non-final binary message b1.Write([]byte("\x02\x81")) // Mask key b1.Write([]byte("\x00\x00\x00\x00")) // First payload b1.Write([]byte("A")) // Next, send a negative-length, non-final continuation frame b1.Write([]byte("\x00\xFF\x80\x00\x00\x00\x00\x00\x00\x00")) // Mask key b1.Write([]byte("\x00\x00\x00\x00")) // Next, send a too long, final continuation frame b1.Write([]byte("\x80\xFF\x00\x00\x00\x00\x00\x00\x00\x05")) // Mask key b1.Write([]byte("\x00\x00\x00\x00")) // Too-long payload b1.Write([]byte("BCDEF")) op, r, err := rc.NextReader() if op != BinaryMessage || err != nil { t.Fatalf("1: NextReader() returned %d, %v", op, err) } var buf [10]byte var read int n, err := r.Read(buf[:]) if err != nil && err != ErrReadLimit { t.Fatalf("unexpected error testing read limit: %v", err) } read += n n, err = r.Read(buf[:]) if err != nil && err != ErrReadLimit { t.Fatalf("unexpected error testing read limit: %v", err) } read += n if err == nil && read > readLimit { t.Fatalf("read limit exceeded: limit %d, read %d", readLimit, read) } }) } func TestAddrs(t *testing.T) { c := newTestConn(nil, nil, true) if c.LocalAddr() != localAddr { t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr) } if c.RemoteAddr() != remoteAddr { t.Errorf("RemoteAddr = %v, want %v", c.RemoteAddr(), remoteAddr) } } func TestNetConn(t *testing.T) { var b1, b2 bytes.Buffer fc := fakeNetConn{Reader: &b1, Writer: &b2} c := newConn(fc, true, 1024, 1024, nil, nil, nil) ul := c.NetConn() if ul != fc { t.Fatalf("Underlying conn is not what it should be.") } } func TestBufioReadBytes(t *testing.T) { // Test calling bufio.ReadBytes for value longer than read buffer size. m := make([]byte, 512) m[len(m)-1] = '\n' var b1, b2 bytes.Buffer wc := newConn(fakeNetConn{Writer: &b1}, false, len(m)+64, len(m)+64, nil, nil, nil) rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil) w, _ := wc.NextWriter(BinaryMessage) _, _ = w.Write(m) w.Close() op, r, err := rc.NextReader() if op != BinaryMessage || err != nil { t.Fatalf("NextReader() returned %d, %v", op, err) } br := bufio.NewReader(r) p, err := br.ReadBytes('\n') if err != nil { t.Fatalf("ReadBytes() returned %v", err) } if len(p) != len(m) { t.Fatalf("read returned %d bytes, want %d bytes", len(p), len(m)) } } var closeErrorTests = []struct { err error codes []int ok bool }{ {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, true}, {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, false}, {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, true}, {errors.New("hello"), []int{CloseNormalClosure}, false}, } func TestCloseError(t *testing.T) { for _, tt := range closeErrorTests { ok := IsCloseError(tt.err, tt.codes...) if ok != tt.ok { t.Errorf("IsCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok) } } } var unexpectedCloseErrorTests = []struct { err error codes []int ok bool }{ {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, false}, {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, true}, {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, false}, {errors.New("hello"), []int{CloseNormalClosure}, false}, } func TestUnexpectedCloseErrors(t *testing.T) { for _, tt := range unexpectedCloseErrorTests { ok := IsUnexpectedCloseError(tt.err, tt.codes...) if ok != tt.ok { t.Errorf("IsUnexpectedCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok) } } } type blockingWriter struct { c1, c2 chan struct{} } func (w blockingWriter) Write(p []byte) (int, error) { // Allow main to continue close(w.c1) // Wait for panic in main <-w.c2 return len(p), nil } func TestConcurrentWritePanic(t *testing.T) { w := blockingWriter{make(chan struct{}), make(chan struct{})} c := newTestConn(nil, w, false) go func() { _ = c.WriteMessage(TextMessage, []byte{}) }() // wait for goroutine to block in write. <-w.c1 defer func() { close(w.c2) if v := recover(); v != nil { return } }() _ = c.WriteMessage(TextMessage, []byte{}) t.Fatal("should not get here") } type failingReader struct{} func (r failingReader) Read(p []byte) (int, error) { return 0, io.EOF } func TestFailedConnectionReadPanic(t *testing.T) { c := newTestConn(failingReader{}, nil, false) defer func() { if v := recover(); v != nil { return } }() for i := 0; i < 20000; i++ { _, _, _ = c.ReadMessage() } t.Fatal("should not get here") } func TestJoinMessages(t *testing.T) { messages := []string{"a", "bc", "def", "ghij", "klmno", "0", "12", "345", "6789"} for _, readChunk := range []int{1, 2, 3, 4, 5, 6, 7} { for _, term := range []string{"", ","} { var connBuf bytes.Buffer wc := newTestConn(nil, &connBuf, true) rc := newTestConn(&connBuf, nil, false) for _, m := range messages { _ = wc.WriteMessage(BinaryMessage, []byte(m)) } var result bytes.Buffer _, err := io.CopyBuffer(&result, JoinMessages(rc, term), make([]byte, readChunk)) if IsUnexpectedCloseError(err, CloseAbnormalClosure) { t.Errorf("readChunk=%d, term=%q: unexpected error %v", readChunk, term, err) } want := strings.Join(messages, term) + term if result.String() != want { t.Errorf("readChunk=%d, term=%q, got %q, want %q", readChunk, term, result.String(), want) } } } } func TestJSON(t *testing.T) { var buf bytes.Buffer wc := newTestConn(nil, &buf, true) rc := newTestConn(&buf, nil, false) var actual, expect struct { A int B string } expect.A = 1 expect.B = "hello" if err := wc.WriteJSON(&expect); err != nil { t.Fatal("write", err) } if err := rc.ReadJSON(&actual); err != nil { t.Fatal("read", err) } if !reflect.DeepEqual(&actual, &expect) { t.Fatal("equal", actual, expect) } } func TestPartialJSONRead(t *testing.T) { var buf0, buf1 bytes.Buffer wc := newTestConn(nil, &buf0, true) rc := newTestConn(&buf0, &buf1, false) var v struct { A int B string } v.A = 1 v.B = "hello" messageCount := 0 // Partial JSON values. data, err := json.Marshal(v) if err != nil { t.Fatal(err) } for i := len(data) - 1; i >= 0; i-- { if err := wc.WriteMessage(TextMessage, data[:i]); err != nil { t.Fatal(err) } messageCount++ } // Whitespace. if err := wc.WriteMessage(TextMessage, []byte(" ")); err != nil { t.Fatal(err) } messageCount++ // Close. if err := wc.WriteMessage(CloseMessage, FormatCloseMessage(CloseNormalClosure, "")); err != nil { t.Fatal(err) } for i := 0; i < messageCount; i++ { err := rc.ReadJSON(&v) if err != io.ErrUnexpectedEOF { t.Error("read", i, err) } } err = rc.ReadJSON(&v) if _, ok := err.(*CloseError); !ok { t.Error("final", err) } } // !appengine func maskBytesByByte(key [4]byte, pos int, b []byte) int { for i := range b { b[i] ^= key[pos&3] pos++ } return pos & 3 } func notzero(b []byte) int { for i := range b { if b[i] != 0 { return i } } return -1 } func TestMaskBytes(t *testing.T) { key := [4]byte{1, 2, 3, 4} for size := 1; size <= 1024; size++ { for align := 0; align < wordSize; align++ { for pos := 0; pos < 4; pos++ { b := make([]byte, size+align)[align:] maskBytes(key, pos, b) maskBytesByByte(key, pos, b) if i := notzero(b); i >= 0 { t.Errorf("size:%d, align:%d, pos:%d, offset:%d", size, align, pos, i) } } } } } func BenchmarkMaskBytes(b *testing.B) { for _, size := range []int{2, 4, 8, 16, 32, 512, 1024} { b.Run(fmt.Sprintf("size-%d", size), func(b *testing.B) { for _, align := range []int{wordSize / 2} { b.Run(fmt.Sprintf("align-%d", align), func(b *testing.B) { for _, fn := range []struct { name string fn func(key [4]byte, pos int, b []byte) int }{ {"byte", maskBytesByByte}, {"word", maskBytes}, } { b.Run(fn.name, func(b *testing.B) { key := newMaskKey() data := make([]byte, size+align)[align:] for i := 0; i < b.N; i++ { fn.fn(key, 0, data) } b.SetBytes(int64(len(data))) }) } }) } }) } } var preparedMessageTests = []struct { messageType int isServer bool enableWriteCompression bool compressionLevel int }{ // Server {TextMessage, true, false, flate.BestSpeed}, {TextMessage, true, true, flate.BestSpeed}, {TextMessage, true, true, flate.BestCompression}, {PingMessage, true, false, flate.BestSpeed}, {PingMessage, true, true, flate.BestSpeed}, // Client {TextMessage, false, false, flate.BestSpeed}, {TextMessage, false, true, flate.BestSpeed}, {TextMessage, false, true, flate.BestCompression}, {PingMessage, false, false, flate.BestSpeed}, {PingMessage, false, true, flate.BestSpeed}, } func TestPreparedMessage(t *testing.T) { testRand := rand.New(rand.NewSource(99)) prevMaskRand := maskRand maskRand = testRand defer func() { maskRand = prevMaskRand }() for _, tt := range preparedMessageTests { var data = []byte("this is a test") var buf bytes.Buffer c := newTestConn(nil, &buf, tt.isServer) if tt.enableWriteCompression { c.newCompressionWriter = compressNoContextTakeover } if err := c.SetCompressionLevel(tt.compressionLevel); err != nil { t.Fatal(err) } // Seed random number generator for consistent frame mask. testRand.Seed(1234) if err := c.WriteMessage(tt.messageType, data); err != nil { t.Fatal(err) } want := buf.String() pm, err := NewPreparedMessage(tt.messageType, data) if err != nil { t.Fatal(err) } // Scribble on data to ensure that NewPreparedMessage takes a snapshot. copy(data, "hello world") // Seed random number generator for consistent frame mask. testRand.Seed(1234) buf.Reset() if err := c.WritePreparedMessage(pm); err != nil { t.Fatal(err) } got := buf.String() if got != want { t.Errorf("write message != prepared message for %+v", tt) } } } var subprotocolTests = []struct { h string protocols []string }{ {"", nil}, {"foo", []string{"foo"}}, {"foo,bar", []string{"foo", "bar"}}, {"foo, bar", []string{"foo", "bar"}}, {" foo, bar", []string{"foo", "bar"}}, {" foo, bar ", []string{"foo", "bar"}}, } func TestSubprotocols(t *testing.T) { for _, st := range subprotocolTests { r := http.Request{Header: http.Header{"Sec-Websocket-Protocol": {st.h}}} protocols := Subprotocols(&r) if !reflect.DeepEqual(st.protocols, protocols) { t.Errorf("SubProtocols(%q) returned %#v, want %#v", st.h, protocols, st.protocols) } } } var isWebSocketUpgradeTests = []struct { ok bool h http.Header }{ {false, http.Header{"Upgrade": {"websocket"}}}, {false, http.Header{"Connection": {"upgrade"}}}, {true, http.Header{"Connection": {"upgRade"}, "Upgrade": {"WebSocket"}}}, } func TestIsWebSocketUpgrade(t *testing.T) { for _, tt := range isWebSocketUpgradeTests { ok := IsWebSocketUpgrade(&http.Request{Header: tt.h}) if tt.ok != ok { t.Errorf("IsWebSocketUpgrade(%v) returned %v, want %v", tt.h, ok, tt.ok) } } } func TestSubProtocolSelection(t *testing.T) { upgrader := Upgrader{ Subprotocols: []string{"foo", "bar", "baz"}, } r := http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"foo", "bar"}}} s := upgrader.selectSubprotocol(&r, nil) if s != "foo" { t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "foo") } r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"bar", "foo"}}} s = upgrader.selectSubprotocol(&r, nil) if s != "bar" { t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "bar") } r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"baz"}}} s = upgrader.selectSubprotocol(&r, nil) if s != "baz" { t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "baz") } r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"quux"}}} s = upgrader.selectSubprotocol(&r, nil) if s != "" { t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "empty string") } } var checkSameOriginTests = []struct { ok bool r *http.Request }{ {false, &http.Request{Host: "example.org", Header: map[string][]string{"Origin": {"https://other.org"}}}}, {true, &http.Request{Host: "example.org", Header: map[string][]string{"Origin": {"https://example.org"}}}}, {true, &http.Request{Host: "Example.org", Header: map[string][]string{"Origin": {"https://example.org"}}}}, } func TestCheckSameOrigin(t *testing.T) { for _, tt := range checkSameOriginTests { ok := checkSameOrigin(tt.r) if tt.ok != ok { t.Errorf("checkSameOrigin(%+v) returned %v, want %v", tt.r, ok, tt.ok) } } } type reuseTestResponseWriter struct { brw *bufio.ReadWriter http.ResponseWriter } func (resp *reuseTestResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { return fakeNetConn{strings.NewReader(""), &bytes.Buffer{}}, resp.brw, nil } var bufioReuseTests = []struct { n int reuse bool }{ {4096, true}, {128, false}, } func xTestBufioReuse(t *testing.T) { for i, tt := range bufioReuseTests { br := bufio.NewReaderSize(strings.NewReader(""), tt.n) bw := bufio.NewWriterSize(&bytes.Buffer{}, tt.n) resp := &reuseTestResponseWriter{ brw: bufio.NewReadWriter(br, bw), } upgrader := Upgrader{} c, err := upgrader.Upgrade(resp, &http.Request{ Method: http.MethodGet, Header: http.Header{ "Upgrade": []string{"websocket"}, "Connection": []string{"upgrade"}, "Sec-Websocket-Key": []string{"dGhlIHNhbXBsZSBub25jZQ=="}, "Sec-Websocket-Version": []string{"13"}, }}, nil) if err != nil { t.Fatal(err) } if reuse := c.br == br; reuse != tt.reuse { t.Errorf("%d: buffered reader reuse=%v, want %v", i, reuse, tt.reuse) } writeBuf := bw.AvailableBuffer() if reuse := &c.writeBuf[0] == &writeBuf[0]; reuse != tt.reuse { t.Errorf("%d: write buffer reuse=%v, want %v", i, reuse, tt.reuse) } } } func TestHijack_NotSupported(t *testing.T) { t.Parallel() req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) req.Header.Set("Upgrade", "websocket") req.Header.Set("Connection", "upgrade") req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") req.Header.Set("Sec-Websocket-Version", "13") recorder := httptest.NewRecorder() upgrader := Upgrader{} _, err := upgrader.Upgrade(recorder, req, nil) if want := (HandshakeError{}); !errors.As(err, &want) || recorder.Code != http.StatusInternalServerError { t.Errorf("want %T and status_code=%d", want, http.StatusInternalServerError) t.Fatalf("got err=%T and status_code=%d", err, recorder.Code) } } var equalASCIIFoldTests = []struct { t, s string eq bool }{ {"WebSocket", "websocket", true}, {"websocket", "WebSocket", true}, {"Öyster", "öyster", false}, {"WebSocket", "WetSocket", false}, } func TestEqualASCIIFold(t *testing.T) { for _, tt := range equalASCIIFoldTests { eq := equalASCIIFold(tt.s, tt.t) if eq != tt.eq { t.Errorf("equalASCIIFold(%q, %q) = %v, want %v", tt.s, tt.t, eq, tt.eq) } } } var tokenListContainsValueTests = []struct { value string ok bool }{ {"WebSocket", true}, {"WEBSOCKET", true}, {"websocket", true}, {"websockets", false}, {"x websocket", false}, {"websocket x", false}, {"other,websocket,more", true}, {"other, websocket, more", true}, } func TestTokenListContainsValue(t *testing.T) { for _, tt := range tokenListContainsValueTests { h := http.Header{"Upgrade": {tt.value}} ok := tokenListContainsValue(h, "Upgrade", "websocket") if ok != tt.ok { t.Errorf("tokenListContainsValue(h, n, %q) = %v, want %v", tt.value, ok, tt.ok) } } } var isValidChallengeKeyTests = []struct { key string ok bool }{ {"dGhlIHNhbXBsZSBub25jZQ==", true}, {"", false}, {"InvalidKey", false}, {"WHQ4eXhscUtKYjBvOGN3WEdtOEQ=", false}, } func TestIsValidChallengeKey(t *testing.T) { for _, tt := range isValidChallengeKeyTests { ok := isValidChallengeKey(tt.key) if ok != tt.ok { t.Errorf("isValidChallengeKey returns %v, want %v", ok, tt.ok) } } } var parseExtensionTests = []struct { value string extensions []map[string]string }{ {`foo`, []map[string]string{{"": "foo"}}}, {`foo, bar; baz=2`, []map[string]string{ {"": "foo"}, {"": "bar", "baz": "2"}}}, {`foo; bar="b,a;z"`, []map[string]string{ {"": "foo", "bar": "b,a;z"}}}, {`foo , bar; baz = 2`, []map[string]string{ {"": "foo"}, {"": "bar", "baz": "2"}}}, {`foo, bar; baz=2 junk`, []map[string]string{ {"": "foo"}}}, {`foo junk, bar; baz=2 junk`, nil}, {`mux; max-channels=4; flow-control, deflate-stream`, []map[string]string{ {"": "mux", "max-channels": "4", "flow-control": ""}, {"": "deflate-stream"}}}, {`permessage-foo; x="10"`, []map[string]string{ {"": "permessage-foo", "x": "10"}}}, {`permessage-foo; use_y, permessage-foo`, []map[string]string{ {"": "permessage-foo", "use_y": ""}, {"": "permessage-foo"}}}, {`permessage-deflate; client_max_window_bits; server_max_window_bits=10 , permessage-deflate; client_max_window_bits`, []map[string]string{ {"": "permessage-deflate", "client_max_window_bits": "", "server_max_window_bits": "10"}, {"": "permessage-deflate", "client_max_window_bits": ""}}}, {"permessage-deflate; server_no_context_takeover; client_max_window_bits=15", []map[string]string{ {"": "permessage-deflate", "server_no_context_takeover": "", "client_max_window_bits": "15"}, }}, } func TestParseExtensions(t *testing.T) { for _, tt := range parseExtensionTests { h := http.Header{http.CanonicalHeaderKey("Sec-WebSocket-Extensions"): {tt.value}} extensions := parseExtensions(h) if !reflect.DeepEqual(extensions, tt.extensions) { t.Errorf("parseExtensions(%q)\n = %v,\nwant %v", tt.value, extensions, tt.extensions) } } } func test_parseArgs() { given := parseArgs([]string { "x", "y", "z" }) expected := _CLIArgs { fromAddr: "y", toAddr: "z", } g.AssertEqual(given, expected) } func MainTest() { test_parseArgs() tests := []testing.InternalTest { { "TestBadMethod", TestBadMethod }, { "TestNoUpgrade", TestNoUpgrade }, { "TestHostPortNoPort", TestHostPortNoPort }, { "TestTruncWriter", TestTruncWriter }, { "TestValidCompressionLevel", TestValidCompressionLevel }, { "TestFraming", TestFraming }, { "TestWriteControlDeadline", TestWriteControlDeadline }, { "TestConcurrencyWriteControl", TestConcurrencyWriteControl }, { "TestControl", TestControl }, { "TestWriteBufferPool", TestWriteBufferPool }, { "TestWriteBufferPoolSync", TestWriteBufferPoolSync }, { "TestWriteBufferPoolError", TestWriteBufferPoolError }, { "TestCloseFrameBeforeFinalMessageFrame", TestCloseFrameBeforeFinalMessageFrame }, { "TestEOFWithinFrame", TestEOFWithinFrame }, { "TestEOFBeforeFinalFrame", TestEOFBeforeFinalFrame }, { "TestWriteAfterMessageWriterClose", TestWriteAfterMessageWriterClose }, { "TestReadLimit", TestReadLimit }, { "TestAddrs", TestAddrs }, { "TestNetConn", TestNetConn }, { "TestBufioReadBytes", TestBufioReadBytes }, { "TestCloseError", TestCloseError }, { "TestUnexpectedCloseErrors", TestUnexpectedCloseErrors }, { "TestConcurrentWritePanic", TestConcurrentWritePanic }, { "TestFailedConnectionReadPanic", TestFailedConnectionReadPanic }, { "TestJoinMessages", TestJoinMessages }, { "TestJSON", TestJSON }, { "TestPartialJSONRead", TestPartialJSONRead }, { "TestMaskBytes", TestMaskBytes }, { "TestPreparedMessage", TestPreparedMessage }, { "TestSubprotocols", TestSubprotocols }, { "TestIsWebSocketUpgrade", TestIsWebSocketUpgrade }, { "TestSubProtocolSelection", TestSubProtocolSelection }, { "TestCheckSameOrigin", TestCheckSameOrigin }, { "TestHijack_NotSupported", TestHijack_NotSupported }, { "TestEqualASCIIFold", TestEqualASCIIFold }, { "TestTokenListContainsValue", TestTokenListContainsValue }, { "TestIsValidChallengeKey", TestIsValidChallengeKey }, { "TestParseExtensions", TestParseExtensions }, } benchmarks := []testing.InternalBenchmark { { "BenchmarkWriteNoCompression", BenchmarkWriteNoCompression }, { "BenchmarkWriteWithCompression", BenchmarkWriteWithCompression }, { "BenchmarkBroadcast", BenchmarkBroadcast }, { "BenchmarkMaskBytes", BenchmarkMaskBytes }, } fuzzTargets := []testing.InternalFuzzTarget {} examples := []testing.InternalExample {} m := testing.MainStart( testdeps.TestDeps {}, tests, benchmarks, fuzzTargets, examples, ) os.Exit(m.Run()) }