diff options
Diffstat (limited to 'tests/wscat.go')
-rw-r--r-- | tests/wscat.go | 2842 |
1 files changed, 2842 insertions, 0 deletions
diff --git a/tests/wscat.go b/tests/wscat.go new file mode 100644 index 0000000..da16f66 --- /dev/null +++ b/tests/wscat.go @@ -0,0 +1,2842 @@ +package wscat + +import ( + "bufio" + "bytes" + "compress/flate" + "context" + "crypto/tls" + "crypto/x509" + "encoding/base64" + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "math/rand" + "net" + "net/http" + "net/http/cookiejar" + "net/http/httptest" + "net/http/httptrace" + "net/url" + "os" + "reflect" + "strings" + "sync" + "sync/atomic" + "testing" + "testing/internal/testdeps" + "testing/iotest" + "time" + + g "gobang" +) + + + +var cstUpgrader = Upgrader{ + Subprotocols: []string{"p0", "p1"}, + ReadBufferSize: 1024, + WriteBufferSize: 1024, + EnableCompression: true, + Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) { + http.Error(w, reason.Error(), status) + }, +} + +var cstDialer = Dialer{ + Subprotocols: []string{"p1", "p2"}, + ReadBufferSize: 1024, + WriteBufferSize: 1024, + HandshakeTimeout: 30 * time.Second, +} + +type cstHandler struct { + *testing.T + s *cstServer +} + +type cstServer struct { + URL string + Server *httptest.Server + wg sync.WaitGroup +} + +const ( + cstPath = "/a/b" + cstRawQuery = "x=y" + cstRequestURI = cstPath + "?" + cstRawQuery +) + +func (s *cstServer) Close() { + s.Server.Close() + // Wait for handler functions to complete. + s.wg.Wait() +} + +func newServer(t *testing.T) *cstServer { + var s cstServer + s.Server = httptest.NewServer(cstHandler{T: t, s: &s}) + s.Server.URL += cstRequestURI + s.URL = makeWsProto(s.Server.URL) + return &s +} + +func newTLSServer(t *testing.T) *cstServer { + var s cstServer + s.Server = httptest.NewTLSServer(cstHandler{T: t, s: &s}) + s.Server.URL += cstRequestURI + s.URL = makeWsProto(s.Server.URL) + return &s +} + +func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Because tests wait for a response from a server, we are guaranteed that + // the wait group count is incremented before the test waits on the group + // in the call to (*cstServer).Close(). + t.s.wg.Add(1) + defer t.s.wg.Done() + + if r.URL.Path != cstPath { + t.Logf("path=%v, want %v", r.URL.Path, cstPath) + http.Error(w, "bad path", http.StatusBadRequest) + return + } + if r.URL.RawQuery != cstRawQuery { + t.Logf("query=%v, want %v", r.URL.RawQuery, cstRawQuery) + http.Error(w, "bad path", http.StatusBadRequest) + return + } + subprotos := Subprotocols(r) + if !reflect.DeepEqual(subprotos, cstDialer.Subprotocols) { + t.Logf("subprotols=%v, want %v", subprotos, cstDialer.Subprotocols) + http.Error(w, "bad protocol", http.StatusBadRequest) + return + } + ws, err := cstUpgrader.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}}) + if err != nil { + t.Logf("Upgrade: %v", err) + return + } + defer ws.Close() + + if ws.Subprotocol() != "p1" { + t.Logf("Subprotocol() = %s, want p1", ws.Subprotocol()) + ws.Close() + return + } + op, rd, err := ws.NextReader() + if err != nil { + t.Logf("NextReader: %v", err) + return + } + wr, err := ws.NextWriter(op) + if err != nil { + t.Logf("NextWriter: %v", err) + return + } + if _, err = io.Copy(wr, rd); err != nil { + t.Logf("NextWriter: %v", err) + return + } + if err := wr.Close(); err != nil { + t.Logf("Close: %v", err) + return + } +} + +func makeWsProto(s string) string { + return "ws" + strings.TrimPrefix(s, "http") +} + +func sendRecv(t *testing.T, ws *Conn) { + const message = "Hello World!" + if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil { + t.Fatalf("SetWriteDeadline: %v", err) + } + if err := ws.WriteMessage(TextMessage, []byte(message)); err != nil { + t.Fatalf("WriteMessage: %v", err) + } + if err := ws.SetReadDeadline(time.Now().Add(time.Second)); err != nil { + t.Fatalf("SetReadDeadline: %v", err) + } + _, p, err := ws.ReadMessage() + if err != nil { + t.Fatalf("ReadMessage: %v", err) + } + if string(p) != message { + t.Fatalf("message=%s, want %s", p, message) + } +} + +func TestProxyDial(t *testing.T) { + + s := newServer(t) + defer s.Close() + + surl, _ := url.Parse(s.Server.URL) + + cstDialer := cstDialer // make local copy for modification on next line. + cstDialer.Proxy = http.ProxyURL(surl) + + connect := false + origHandler := s.Server.Config.Handler + + // Capture the request Host header. + s.Server.Config.Handler = http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodConnect { + connect = true + w.WriteHeader(http.StatusOK) + return + } + + if !connect { + t.Log("connect not received") + http.Error(w, "connect not received", http.StatusMethodNotAllowed) + return + } + origHandler.ServeHTTP(w, r) + }) + + ws, _, err := cstDialer.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + sendRecv(t, ws) +} + +func TestProxyAuthorizationDial(t *testing.T) { + s := newServer(t) + defer s.Close() + + surl, _ := url.Parse(s.Server.URL) + surl.User = url.UserPassword("username", "password") + + cstDialer := cstDialer // make local copy for modification on next line. + cstDialer.Proxy = http.ProxyURL(surl) + + connect := false + origHandler := s.Server.Config.Handler + + // Capture the request Host header. + s.Server.Config.Handler = http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + proxyAuth := r.Header.Get("Proxy-Authorization") + expectedProxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("username:password")) + if r.Method == http.MethodConnect && proxyAuth == expectedProxyAuth { + connect = true + w.WriteHeader(http.StatusOK) + return + } + + if !connect { + t.Log("connect with proxy authorization not received") + http.Error(w, "connect with proxy authorization not received", http.StatusMethodNotAllowed) + return + } + origHandler.ServeHTTP(w, r) + }) + + ws, _, err := cstDialer.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + sendRecv(t, ws) +} + +func TestDial(t *testing.T) { + s := newServer(t) + defer s.Close() + + ws, _, err := cstDialer.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + sendRecv(t, ws) +} + +func TestDialCookieJar(t *testing.T) { + s := newServer(t) + defer s.Close() + + jar, _ := cookiejar.New(nil) + d := cstDialer + d.Jar = jar + + u, _ := url.Parse(s.URL) + + switch u.Scheme { + case "ws": + u.Scheme = "http" + case "wss": + u.Scheme = "https" + } + + cookies := []*http.Cookie{{Name: "gorilla", Value: "ws", Path: "/"}} + d.Jar.SetCookies(u, cookies) + + ws, _, err := d.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + + var gorilla string + var sessionID string + for _, c := range d.Jar.Cookies(u) { + if c.Name == "gorilla" { + gorilla = c.Value + } + + if c.Name == "sessionID" { + sessionID = c.Value + } + } + if gorilla != "ws" { + t.Error("Cookie not present in jar.") + } + + if sessionID != "1234" { + t.Error("Set-Cookie not received from the server.") + } + + sendRecv(t, ws) +} + +func rootCAs(t *testing.T, s *httptest.Server) *x509.CertPool { + certs := x509.NewCertPool() + for _, c := range s.TLS.Certificates { + roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1]) + if err != nil { + t.Fatalf("error parsing server's root cert: %v", err) + } + for _, root := range roots { + certs.AddCert(root) + } + } + return certs +} + +func TestDialTLS(t *testing.T) { + s := newTLSServer(t) + defer s.Close() + + d := cstDialer + d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)} + ws, _, err := d.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + sendRecv(t, ws) +} + +func TestDialTimeout(t *testing.T) { + s := newServer(t) + defer s.Close() + + d := cstDialer + d.HandshakeTimeout = -1 + ws, _, err := d.Dial(s.URL, nil) + if err == nil { + ws.Close() + t.Fatalf("Dial: nil") + } +} + +// requireDeadlineNetConn fails the current test when Read or Write are called +// with no deadline. +type requireDeadlineNetConn struct { + t *testing.T + c net.Conn + readDeadlineIsSet bool + writeDeadlineIsSet bool +} + +func (c *requireDeadlineNetConn) SetDeadline(t time.Time) error { + c.writeDeadlineIsSet = !t.Equal(time.Time{}) + c.readDeadlineIsSet = c.writeDeadlineIsSet + return c.c.SetDeadline(t) +} + +func (c *requireDeadlineNetConn) SetReadDeadline(t time.Time) error { + c.readDeadlineIsSet = !t.Equal(time.Time{}) + return c.c.SetDeadline(t) +} + +func (c *requireDeadlineNetConn) SetWriteDeadline(t time.Time) error { + c.writeDeadlineIsSet = !t.Equal(time.Time{}) + return c.c.SetDeadline(t) +} + +func (c *requireDeadlineNetConn) Write(p []byte) (int, error) { + if !c.writeDeadlineIsSet { + c.t.Fatalf("write with no deadline") + } + return c.c.Write(p) +} + +func (c *requireDeadlineNetConn) Read(p []byte) (int, error) { + if !c.readDeadlineIsSet { + c.t.Fatalf("read with no deadline") + } + return c.c.Read(p) +} + +func (c *requireDeadlineNetConn) Close() error { return c.c.Close() } +func (c *requireDeadlineNetConn) LocalAddr() net.Addr { return c.c.LocalAddr() } +func (c *requireDeadlineNetConn) RemoteAddr() net.Addr { return c.c.RemoteAddr() } + +func TestHandshakeTimeout(t *testing.T) { + s := newServer(t) + defer s.Close() + + d := cstDialer + d.NetDial = func(n, a string) (net.Conn, error) { + c, err := net.Dial(n, a) + return &requireDeadlineNetConn{c: c, t: t}, err + } + ws, _, err := d.Dial(s.URL, nil) + if err != nil { + t.Fatal("Dial:", err) + } + ws.Close() +} + +func TestHandshakeTimeoutInContext(t *testing.T) { + s := newServer(t) + defer s.Close() + + d := cstDialer + d.HandshakeTimeout = 0 + d.NetDialContext = func(ctx context.Context, n, a string) (net.Conn, error) { + netDialer := &net.Dialer{} + c, err := netDialer.DialContext(ctx, n, a) + return &requireDeadlineNetConn{c: c, t: t}, err + } + + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(30*time.Second)) + defer cancel() + ws, _, err := d.DialContext(ctx, s.URL, nil) + if err != nil { + t.Fatal("Dial:", err) + } + ws.Close() +} + +func TestDialBadScheme(t *testing.T) { + s := newServer(t) + defer s.Close() + + ws, _, err := cstDialer.Dial(s.Server.URL, nil) + if err == nil { + ws.Close() + t.Fatalf("Dial: nil") + } +} + +func TestDialBadOrigin(t *testing.T) { + s := newServer(t) + defer s.Close() + + ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}}) + if err == nil { + ws.Close() + t.Fatalf("Dial: nil") + } + if resp == nil { + t.Fatalf("resp=nil, err=%v", err) + } + if resp.StatusCode != http.StatusForbidden { + t.Fatalf("status=%d, want %d", resp.StatusCode, http.StatusForbidden) + } +} + +func TestDialBadHeader(t *testing.T) { + s := newServer(t) + defer s.Close() + + for _, k := range []string{"Upgrade", + "Connection", + "Sec-Websocket-Key", + "Sec-Websocket-Version", + "Sec-Websocket-Protocol"} { + h := http.Header{} + h.Set(k, "bad") + ws, _, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}}) + if err == nil { + ws.Close() + t.Errorf("Dial with header %s returned nil", k) + } + } +} + +func TestBadMethod(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ws, err := cstUpgrader.Upgrade(w, r, nil) + if err == nil { + t.Errorf("handshake succeeded, expect fail") + ws.Close() + } + })) + defer s.Close() + + req, err := http.NewRequest(http.MethodPost, s.URL, strings.NewReader("")) + if err != nil { + t.Fatalf("NewRequest returned error %v", err) + } + req.Header.Set("Connection", "upgrade") + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Sec-Websocket-Version", "13") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Do returned error %v", err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusMethodNotAllowed { + t.Errorf("Status = %d, want %d", resp.StatusCode, http.StatusMethodNotAllowed) + } +} + +func TestNoUpgrade(t *testing.T) { + t.Parallel() + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ws, err := cstUpgrader.Upgrade(w, r, nil) + if err == nil { + t.Errorf("handshake succeeded, expect fail") + ws.Close() + } + })) + defer s.Close() + + req, err := http.NewRequest(http.MethodGet, s.URL, strings.NewReader("")) + if err != nil { + t.Fatalf("NewRequest returned error %v", err) + } + req.Header.Set("Connection", "upgrade") + req.Header.Set("Sec-Websocket-Version", "13") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Do returned error %v", err) + } + resp.Body.Close() + if u := resp.Header.Get("Upgrade"); u != "websocket" { + t.Errorf("Upgrade response header is %q, want %q", u, "websocket") + } + if resp.StatusCode != http.StatusUpgradeRequired { + t.Errorf("Status = %d, want %d", resp.StatusCode, http.StatusUpgradeRequired) + } +} + +func TestDialExtraTokensInRespHeaders(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + challengeKey := r.Header.Get("Sec-Websocket-Key") + w.Header().Set("Upgrade", "foo, websocket") + w.Header().Set("Connection", "upgrade, keep-alive") + w.Header().Set("Sec-Websocket-Accept", computeAcceptKey(challengeKey)) + w.WriteHeader(101) + })) + defer s.Close() + + ws, _, err := cstDialer.Dial(makeWsProto(s.URL), nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() +} + +func TestHandshake(t *testing.T) { + s := newServer(t) + defer s.Close() + + ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {s.URL}}) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + + var sessionID string + for _, c := range resp.Cookies() { + if c.Name == "sessionID" { + sessionID = c.Value + } + } + if sessionID != "1234" { + t.Error("Set-Cookie not received from the server.") + } + + if ws.Subprotocol() != "p1" { + t.Errorf("ws.Subprotocol() = %s, want p1", ws.Subprotocol()) + } + sendRecv(t, ws) +} + +func TestRespOnBadHandshake(t *testing.T) { + const expectedStatus = http.StatusGone + const expectedBody = "This is the response body." + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(expectedStatus) + _, _ = io.WriteString(w, expectedBody) + })) + defer s.Close() + + ws, resp, err := cstDialer.Dial(makeWsProto(s.URL), nil) + if err == nil { + ws.Close() + t.Fatalf("Dial: nil") + } + + if resp == nil { + t.Fatalf("resp=nil, err=%v", err) + } + + if resp.StatusCode != expectedStatus { + t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus) + } + + p, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("ReadFull(resp.Body) returned error %v", err) + } + + if string(p) != expectedBody { + t.Errorf("resp.Body=%s, want %s", p, expectedBody) + } +} + +type testLogWriter struct { + t *testing.T +} + +func (w testLogWriter) Write(p []byte) (int, error) { + w.t.Logf("%s", p) + return len(p), nil +} + +// TestHost tests handling of host names and confirms that it matches net/http. +func TestHost(t *testing.T) { + + upgrader := Upgrader{} + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if IsWebSocketUpgrade(r) { + c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}}) + if err != nil { + t.Fatal(err) + } + c.Close() + } else { + w.Header().Set("X-Test-Host", r.Host) + } + }) + + server := httptest.NewServer(handler) + defer server.Close() + + tlsServer := httptest.NewTLSServer(handler) + defer tlsServer.Close() + + addrs := map[*httptest.Server]string{server: server.Listener.Addr().String(), tlsServer: tlsServer.Listener.Addr().String()} + wsProtos := map[*httptest.Server]string{server: "ws://", tlsServer: "wss://"} + httpProtos := map[*httptest.Server]string{server: "http://", tlsServer: "https://"} + + // Avoid log noise from net/http server by logging to testing.T + server.Config.ErrorLog = log.New(testLogWriter{t}, "", 0) + tlsServer.Config.ErrorLog = server.Config.ErrorLog + + cas := rootCAs(t, tlsServer) + + tests := []struct { + fail bool // true if dial / get should fail + server *httptest.Server // server to use + url string // host for request URI + header string // optional request host header + tls string // optional host for tls ServerName + wantAddr string // expected host for dial + wantHeader string // expected request header on server + insecureSkipVerify bool + }{ + { + server: server, + url: addrs[server], + wantAddr: addrs[server], + wantHeader: addrs[server], + }, + { + server: tlsServer, + url: addrs[tlsServer], + wantAddr: addrs[tlsServer], + wantHeader: addrs[tlsServer], + }, + + { + server: server, + url: addrs[server], + header: "badhost.com", + wantAddr: addrs[server], + wantHeader: "badhost.com", + }, + { + server: tlsServer, + url: addrs[tlsServer], + header: "badhost.com", + wantAddr: addrs[tlsServer], + wantHeader: "badhost.com", + }, + + { + server: server, + url: "example.com", + header: "badhost.com", + wantAddr: "example.com:80", + wantHeader: "badhost.com", + }, + { + server: tlsServer, + url: "example.com", + header: "badhost.com", + wantAddr: "example.com:443", + wantHeader: "badhost.com", + }, + + { + server: server, + url: "badhost.com", + header: "example.com", + wantAddr: "badhost.com:80", + wantHeader: "example.com", + }, + { + fail: true, + server: tlsServer, + url: "badhost.com", + header: "example.com", + wantAddr: "badhost.com:443", + }, + { + server: tlsServer, + url: "badhost.com", + insecureSkipVerify: true, + wantAddr: "badhost.com:443", + wantHeader: "badhost.com", + }, + { + server: tlsServer, + url: "badhost.com", + tls: "example.com", + wantAddr: "badhost.com:443", + wantHeader: "badhost.com", + }, + } + + for i, tt := range tests { + + tls := &tls.Config{ + RootCAs: cas, + ServerName: tt.tls, + InsecureSkipVerify: tt.insecureSkipVerify, + } + + var gotAddr string + dialer := Dialer{ + NetDial: func(network, addr string) (net.Conn, error) { + gotAddr = addr + return net.Dial(network, addrs[tt.server]) + }, + TLSClientConfig: tls, + } + + // Test websocket dial + + h := http.Header{} + if tt.header != "" { + h.Set("Host", tt.header) + } + c, resp, err := dialer.Dial(wsProtos[tt.server]+tt.url+"/", h) + if err == nil { + c.Close() + } + + check := func(protos map[*httptest.Server]string) { + name := fmt.Sprintf("%d: %s%s/ header[Host]=%q, tls.ServerName=%q", i+1, protos[tt.server], tt.url, tt.header, tt.tls) + if gotAddr != tt.wantAddr { + t.Errorf("%s: got addr %s, want %s", name, gotAddr, tt.wantAddr) + } + switch { + case tt.fail && err == nil: + t.Errorf("%s: unexpected success", name) + case !tt.fail && err != nil: + t.Errorf("%s: unexpected error %v", name, err) + case !tt.fail && err == nil: + if gotHost := resp.Header.Get("X-Test-Host"); gotHost != tt.wantHeader { + t.Errorf("%s: got host %s, want %s", name, gotHost, tt.wantHeader) + } + } + } + + check(wsProtos) + + // Confirm that net/http has same result + + transport := &http.Transport{ + Dial: dialer.NetDial, + TLSClientConfig: dialer.TLSClientConfig, + } + req, _ := http.NewRequest(http.MethodGet, httpProtos[tt.server]+tt.url+"/", nil) + if tt.header != "" { + req.Host = tt.header + } + client := &http.Client{Transport: transport} + resp, err = client.Do(req) + if err == nil { + resp.Body.Close() + } + transport.CloseIdleConnections() + check(httpProtos) + } +} + +func TestDialCompression(t *testing.T) { + s := newServer(t) + defer s.Close() + + dialer := cstDialer + dialer.EnableCompression = true + ws, _, err := dialer.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + sendRecv(t, ws) +} + +func TestSocksProxyDial(t *testing.T) { + s := newServer(t) + defer s.Close() + + proxyListener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen failed: %v", err) + } + defer proxyListener.Close() + go func() { + c1, err := proxyListener.Accept() + if err != nil { + t.Errorf("proxy accept failed: %v", err) + return + } + defer c1.Close() + + _ = c1.SetDeadline(time.Now().Add(30 * time.Second)) + + buf := make([]byte, 32) + if _, err := io.ReadFull(c1, buf[:3]); err != nil { + t.Errorf("read failed: %v", err) + return + } + if want := []byte{5, 1, 0}; !bytes.Equal(want, buf[:len(want)]) { + t.Errorf("read %x, want %x", buf[:len(want)], want) + } + if _, err := c1.Write([]byte{5, 0}); err != nil { + t.Errorf("write failed: %v", err) + return + } + if _, err := io.ReadFull(c1, buf[:10]); err != nil { + t.Errorf("read failed: %v", err) + return + } + if want := []byte{5, 1, 0, 1}; !bytes.Equal(want, buf[:len(want)]) { + t.Errorf("read %x, want %x", buf[:len(want)], want) + return + } + buf[1] = 0 + if _, err := c1.Write(buf[:10]); err != nil { + t.Errorf("write failed: %v", err) + return + } + + ip := net.IP(buf[4:8]) + port := binary.BigEndian.Uint16(buf[8:10]) + + c2, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: ip, Port: int(port)}) + if err != nil { + t.Errorf("dial failed; %v", err) + return + } + defer c2.Close() + done := make(chan struct{}) + go func() { + _, _ = io.Copy(c1, c2) + close(done) + }() + _, _ = io.Copy(c2, c1) + <-done + }() + + purl, err := url.Parse("socks5://" + proxyListener.Addr().String()) + if err != nil { + t.Fatalf("parse failed: %v", err) + } + + cstDialer := cstDialer // make local copy for modification on next line. + cstDialer.Proxy = http.ProxyURL(purl) + + ws, _, err := cstDialer.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + sendRecv(t, ws) +} + +func TestTracingDialWithContext(t *testing.T) { + + var headersWrote, requestWrote, getConn, gotConn, connectDone, gotFirstResponseByte bool + trace := &httptrace.ClientTrace{ + WroteHeaders: func() { + headersWrote = true + }, + WroteRequest: func(httptrace.WroteRequestInfo) { + requestWrote = true + }, + GetConn: func(hostPort string) { + getConn = true + }, + GotConn: func(info httptrace.GotConnInfo) { + gotConn = true + }, + ConnectDone: func(network, addr string, err error) { + connectDone = true + }, + GotFirstResponseByte: func() { + gotFirstResponseByte = true + }, + } + ctx := httptrace.WithClientTrace(context.Background(), trace) + + s := newTLSServer(t) + defer s.Close() + + d := cstDialer + d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)} + + ws, _, err := d.DialContext(ctx, s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + + if !headersWrote { + t.Fatal("Headers was not written") + } + if !requestWrote { + t.Fatal("Request was not written") + } + if !getConn { + t.Fatal("getConn was not called") + } + if !gotConn { + t.Fatal("gotConn was not called") + } + if !connectDone { + t.Fatal("connectDone was not called") + } + if !gotFirstResponseByte { + t.Fatal("GotFirstResponseByte was not called") + } + + defer ws.Close() + sendRecv(t, ws) +} + +func TestEmptyTracingDialWithContext(t *testing.T) { + + trace := &httptrace.ClientTrace{} + ctx := httptrace.WithClientTrace(context.Background(), trace) + + s := newTLSServer(t) + defer s.Close() + + d := cstDialer + d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)} + + ws, _, err := d.DialContext(ctx, s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + + defer ws.Close() + sendRecv(t, ws) +} + +// TestNetDialConnect tests selection of dial method between NetDial, NetDialContext, NetDialTLS or NetDialTLSContext +func TestNetDialConnect(t *testing.T) { + + upgrader := Upgrader{} + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if IsWebSocketUpgrade(r) { + c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}}) + if err != nil { + t.Fatal(err) + } + c.Close() + } else { + w.Header().Set("X-Test-Host", r.Host) + } + }) + + server := httptest.NewServer(handler) + defer server.Close() + + tlsServer := httptest.NewTLSServer(handler) + defer tlsServer.Close() + + testUrls := map[*httptest.Server]string{ + server: "ws://" + server.Listener.Addr().String() + "/", + tlsServer: "wss://" + tlsServer.Listener.Addr().String() + "/", + } + + cas := rootCAs(t, tlsServer) + tlsConfig := &tls.Config{ + RootCAs: cas, + ServerName: "example.com", + InsecureSkipVerify: false, + } + + tests := []struct { + name string + server *httptest.Server // server to use + netDial func(network, addr string) (net.Conn, error) + netDialContext func(ctx context.Context, network, addr string) (net.Conn, error) + netDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) + tlsClientConfig *tls.Config + }{ + + { + name: "HTTP server, all NetDial* defined, shall use NetDialContext", + server: server, + netDial: func(network, addr string) (net.Conn, error) { + return nil, errors.New("NetDial should not be called") + }, + netDialContext: func(_ context.Context, network, addr string) (net.Conn, error) { + return net.Dial(network, addr) + }, + netDialTLSContext: func(_ context.Context, network, addr string) (net.Conn, error) { + return nil, errors.New("NetDialTLSContext should not be called") + }, + tlsClientConfig: nil, + }, + { + name: "HTTP server, all NetDial* undefined", + server: server, + netDial: nil, + netDialContext: nil, + netDialTLSContext: nil, + tlsClientConfig: nil, + }, + { + name: "HTTP server, NetDialContext undefined, shall fallback to NetDial", + server: server, + netDial: func(network, addr string) (net.Conn, error) { + return net.Dial(network, addr) + }, + netDialContext: nil, + netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, errors.New("NetDialTLSContext should not be called") + }, + tlsClientConfig: nil, + }, + { + name: "HTTPS server, all NetDial* defined, shall use NetDialTLSContext", + server: tlsServer, + netDial: func(network, addr string) (net.Conn, error) { + return nil, errors.New("NetDial should not be called") + }, + netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, errors.New("NetDialContext should not be called") + }, + netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + netConn, err := net.Dial(network, addr) + if err != nil { + return nil, err + } + tlsConn := tls.Client(netConn, tlsConfig) + err = tlsConn.Handshake() + if err != nil { + return nil, err + } + return tlsConn, nil + }, + tlsClientConfig: nil, + }, + { + name: "HTTPS server, NetDialTLSContext undefined, shall fallback to NetDialContext and do handshake", + server: tlsServer, + netDial: func(network, addr string) (net.Conn, error) { + return nil, errors.New("NetDial should not be called") + }, + netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return net.Dial(network, addr) + }, + netDialTLSContext: nil, + tlsClientConfig: tlsConfig, + }, + { + name: "HTTPS server, NetDialTLSContext and NetDialContext undefined, shall fallback to NetDial and do handshake", + server: tlsServer, + netDial: func(network, addr string) (net.Conn, error) { + return net.Dial(network, addr) + }, + netDialContext: nil, + netDialTLSContext: nil, + tlsClientConfig: tlsConfig, + }, + { + name: "HTTPS server, all NetDial* undefined", + server: tlsServer, + netDial: nil, + netDialContext: nil, + netDialTLSContext: nil, + tlsClientConfig: tlsConfig, + }, + { + name: "HTTPS server, all NetDialTLSContext defined, dummy TlsClientConfig defined, shall not do handshake", + server: tlsServer, + netDial: func(network, addr string) (net.Conn, error) { + return nil, errors.New("NetDial should not be called") + }, + netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, errors.New("NetDialContext should not be called") + }, + netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + netConn, err := net.Dial(network, addr) + if err != nil { + return nil, err + } + tlsConn := tls.Client(netConn, tlsConfig) + err = tlsConn.Handshake() + if err != nil { + return nil, err + } + return tlsConn, nil + }, + tlsClientConfig: &tls.Config{ + RootCAs: nil, + ServerName: "badserver.com", + InsecureSkipVerify: false, + }, + }, + } + + for _, tc := range tests { + dialer := Dialer{ + NetDial: tc.netDial, + NetDialContext: tc.netDialContext, + NetDialTLSContext: tc.netDialTLSContext, + TLSClientConfig: tc.tlsClientConfig, + } + + // Test websocket dial + c, _, err := dialer.Dial(testUrls[tc.server], nil) + if err != nil { + t.Errorf("FAILED %s, err: %s", tc.name, err.Error()) + } else { + c.Close() + } + } +} +func TestNextProtos(t *testing.T) { + ts := httptest.NewUnstartedServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), + ) + ts.EnableHTTP2 = true + ts.StartTLS() + defer ts.Close() + + d := Dialer{ + TLSClientConfig: ts.Client().Transport.(*http.Transport).TLSClientConfig, + } + + r, err := ts.Client().Get(ts.URL) + if err != nil { + t.Fatalf("Get: %v", err) + } + r.Body.Close() + + // Asserts that Dialer.TLSClientConfig.NextProtos contains "h2" + // after the Client.Get call from net/http above. + var containsHTTP2 bool = false + for _, proto := range d.TLSClientConfig.NextProtos { + if proto == "h2" { + containsHTTP2 = true + } + } + if !containsHTTP2 { + t.Fatalf("Dialer.TLSClientConfig.NextProtos does not contain \"h2\"") + } + + _, _, err = d.Dial(makeWsProto(ts.URL), nil) + if err == nil { + t.Fatalf("Dial succeeded, expect fail ") + } +} + +type dataBeforeHandshakeResponseWriter struct { + http.ResponseWriter +} + +type dataBeforeHandshakeConnection struct { + net.Conn + io.Reader +} + +func (c *dataBeforeHandshakeConnection) Read(p []byte) (int, error) { + return c.Reader.Read(p) +} + +func (w dataBeforeHandshakeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + // Example single-frame masked text message from section 5.7 of the RFC. + message := []byte{0x81, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d, 0x51, 0x58} + n := len(message) / 2 + + c, rw, err := http.NewResponseController(w.ResponseWriter).Hijack() + if rw != nil { + // Load first part of message into bufio.Reader. If the websocket + // connection reads more than n bytes from the bufio.Reader, then the + // test will fail with an unexpected EOF error. + rw.Reader.Reset(bytes.NewReader(message[:n])) + rw.Reader.Peek(n) + } + if c != nil { + // Inject second part of message before data read from the network connection. + c = &dataBeforeHandshakeConnection{ + Conn: c, + Reader: io.MultiReader(bytes.NewReader(message[n:]), c), + } + } + return c, rw, err +} + +func TestDataReceivedBeforeHandshake(t *testing.T) { + s := newServer(t) + defer s.Close() + + origHandler := s.Server.Config.Handler + s.Server.Config.Handler = http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + origHandler.ServeHTTP(dataBeforeHandshakeResponseWriter{w}, r) + }) + + for _, readBufferSize := range []int{0, 1024} { + t.Run(fmt.Sprintf("ReadBufferSize=%d", readBufferSize), func(t *testing.T) { + dialer := cstDialer + dialer.ReadBufferSize = readBufferSize + ws, _, err := cstDialer.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + _, m, err := ws.ReadMessage() + if err != nil || string(m) != "Hello" { + t.Fatalf("ReadMessage() = %q, %v, want \"Hello\", nil", m, err) + } + }) + } +} +// Copyright 2014 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + + +var hostPortNoPortTests = []struct { + u *url.URL + hostPort, hostNoPort string +}{ + {&url.URL{Scheme: "ws", Host: "example.com"}, "example.com:80", "example.com"}, + {&url.URL{Scheme: "wss", Host: "example.com"}, "example.com:443", "example.com"}, + {&url.URL{Scheme: "ws", Host: "example.com:7777"}, "example.com:7777", "example.com"}, + {&url.URL{Scheme: "wss", Host: "example.com:7777"}, "example.com:7777", "example.com"}, +} + +func TestHostPortNoPort(t *testing.T) { + for _, tt := range hostPortNoPortTests { + hostPort, hostNoPort := hostPortNoPort(tt.u) + if hostPort != tt.hostPort { + t.Errorf("hostPortNoPort(%v) returned hostPort %q, want %q", tt.u, hostPort, tt.hostPort) + } + if hostNoPort != tt.hostNoPort { + t.Errorf("hostPortNoPort(%v) returned hostNoPort %q, want %q", tt.u, hostNoPort, tt.hostNoPort) + } + } +} + +type nopCloser struct{ io.Writer } + +func (nopCloser) Close() error { return nil } + +func TestTruncWriter(t *testing.T) { + const data = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijlkmnopqrstuvwxyz987654321" + for n := 1; n <= 10; n++ { + var b bytes.Buffer + w := &truncWriter{w: nopCloser{&b}} + p := []byte(data) + for len(p) > 0 { + m := len(p) + if m > n { + m = n + } + _, _ = w.Write(p[:m]) + p = p[m:] + } + if b.String() != data[:len(data)-len(w.p)] { + t.Errorf("%d: %q", n, b.String()) + } + } +} + +func textMessages(num int) [][]byte { + messages := make([][]byte, num) + for i := 0; i < num; i++ { + msg := fmt.Sprintf("planet: %d, country: %d, city: %d, street: %d", i, i, i, i) + messages[i] = []byte(msg) + } + return messages +} + +func BenchmarkWriteNoCompression(b *testing.B) { + w := io.Discard + c := newTestConn(nil, w, false) + messages := textMessages(100) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = c.WriteMessage(TextMessage, messages[i%len(messages)]) + } + b.ReportAllocs() +} + +func BenchmarkWriteWithCompression(b *testing.B) { + w := io.Discard + c := newTestConn(nil, w, false) + messages := textMessages(100) + c.enableWriteCompression = true + c.newCompressionWriter = compressNoContextTakeover + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = c.WriteMessage(TextMessage, messages[i%len(messages)]) + } + b.ReportAllocs() +} + +func TestValidCompressionLevel(t *testing.T) { + c := newTestConn(nil, nil, false) + for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} { + if err := c.SetCompressionLevel(level); err == nil { + t.Errorf("no error for level %d", level) + } + } + for _, level := range []int{minCompressionLevel, maxCompressionLevel} { + if err := c.SetCompressionLevel(level); err != nil { + t.Errorf("error for level %d", level) + } + } +} +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + + +// broadcastBench allows to run broadcast benchmarks. +// In every broadcast benchmark we create many connections, then send the same +// message into every connection and wait for all writes complete. This emulates +// an application where many connections listen to the same data - i.e. PUB/SUB +// scenarios with many subscribers in one channel. +type broadcastBench struct { + w io.Writer + closeCh chan struct{} + doneCh chan struct{} + count int32 + conns []*broadcastConn + compression bool + usePrepared bool +} + +type broadcastMessage struct { + payload []byte + prepared *PreparedMessage +} + +type broadcastConn struct { + conn *Conn + msgCh chan *broadcastMessage +} + +func newBroadcastConn(c *Conn) *broadcastConn { + return &broadcastConn{ + conn: c, + msgCh: make(chan *broadcastMessage, 1), + } +} + +func newBroadcastBench(usePrepared, compression bool) *broadcastBench { + bench := &broadcastBench{ + w: io.Discard, + doneCh: make(chan struct{}), + closeCh: make(chan struct{}), + usePrepared: usePrepared, + compression: compression, + } + bench.makeConns(10000) + return bench +} + +func (b *broadcastBench) makeConns(numConns int) { + conns := make([]*broadcastConn, numConns) + + for i := 0; i < numConns; i++ { + c := newTestConn(nil, b.w, true) + if b.compression { + c.enableWriteCompression = true + c.newCompressionWriter = compressNoContextTakeover + } + conns[i] = newBroadcastConn(c) + go func(c *broadcastConn) { + for { + select { + case msg := <-c.msgCh: + if msg.prepared != nil { + _ = c.conn.WritePreparedMessage(msg.prepared) + } else { + _ = c.conn.WriteMessage(TextMessage, msg.payload) + } + val := atomic.AddInt32(&b.count, 1) + if val%int32(numConns) == 0 { + b.doneCh <- struct{}{} + } + case <-b.closeCh: + return + } + } + }(conns[i]) + } + b.conns = conns +} + +func (b *broadcastBench) close() { + close(b.closeCh) +} + +func (b *broadcastBench) broadcastOnce(msg *broadcastMessage) { + for _, c := range b.conns { + c.msgCh <- msg + } + <-b.doneCh +} + +func BenchmarkBroadcast(b *testing.B) { + benchmarks := []struct { + name string + usePrepared bool + compression bool + }{ + {"NoCompression", false, false}, + {"Compression", false, true}, + {"NoCompressionPrepared", true, false}, + {"CompressionPrepared", true, true}, + } + payload := textMessages(1)[0] + for _, bm := range benchmarks { + b.Run(bm.name, func(b *testing.B) { + bench := newBroadcastBench(bm.usePrepared, bm.compression) + defer bench.close() + b.ResetTimer() + for i := 0; i < b.N; i++ { + message := &broadcastMessage{ + payload: payload, + } + if bench.usePrepared { + pm, _ := NewPreparedMessage(TextMessage, message.payload) + message.prepared = pm + } + bench.broadcastOnce(message) + } + b.ReportAllocs() + }) + } +} +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + + +var _ net.Error = errWriteTimeout + +type fakeNetConn struct { + io.Reader + io.Writer +} + +func (c fakeNetConn) Close() error { return nil } +func (c fakeNetConn) LocalAddr() net.Addr { return localAddr } +func (c fakeNetConn) RemoteAddr() net.Addr { return remoteAddr } +func (c fakeNetConn) SetDeadline(t time.Time) error { return nil } +func (c fakeNetConn) SetReadDeadline(t time.Time) error { return nil } +func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil } + +type fakeAddr int + +var ( + localAddr = fakeAddr(1) + remoteAddr = fakeAddr(2) +) + +func (a fakeAddr) Network() string { + return "net" +} + +func (a fakeAddr) String() string { + return "str" +} + +// newTestConn creates a connection backed by a fake network connection using +// default values for buffering. +func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn { + return newConn(fakeNetConn{Reader: r, Writer: w}, isServer, 1024, 1024, nil, nil, nil) +} + +func TestFraming(t *testing.T) { + frameSizes := []int{ + 0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, + // 65536, 65537 + } + var readChunkers = []struct { + name string + f func(io.Reader) io.Reader + }{ + {"half", iotest.HalfReader}, + {"one", iotest.OneByteReader}, + {"asis", func(r io.Reader) io.Reader { return r }}, + } + writeBuf := make([]byte, 65537) + for i := range writeBuf { + writeBuf[i] = byte(i) + } + var writers = []struct { + name string + f func(w io.Writer, n int) (int, error) + }{ + {"iocopy", func(w io.Writer, n int) (int, error) { + nn, err := io.Copy(w, bytes.NewReader(writeBuf[:n])) + return int(nn), err + }}, + {"write", func(w io.Writer, n int) (int, error) { + return w.Write(writeBuf[:n]) + }}, + {"string", func(w io.Writer, n int) (int, error) { + return io.WriteString(w, string(writeBuf[:n])) + }}, + } + + for _, compress := range []bool{false, true} { + for _, isServer := range []bool{true, false} { + for _, chunker := range readChunkers { + + var connBuf bytes.Buffer + wc := newTestConn(nil, &connBuf, isServer) + rc := newTestConn(chunker.f(&connBuf), nil, !isServer) + if compress { + wc.newCompressionWriter = compressNoContextTakeover + rc.newDecompressionReader = decompressNoContextTakeover + } + for _, n := range frameSizes { + for _, writer := range writers { + name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d w:%s", compress, isServer, chunker.name, n, writer.name) + + w, err := wc.NextWriter(TextMessage) + if err != nil { + t.Errorf("%s: wc.NextWriter() returned %v", name, err) + continue + } + nn, err := writer.f(w, n) + if err != nil || nn != n { + t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err) + continue + } + err = w.Close() + if err != nil { + t.Errorf("%s: w.Close() returned %v", name, err) + continue + } + + opCode, r, err := rc.NextReader() + if err != nil || opCode != TextMessage { + t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err) + continue + } + + t.Logf("frame size: %d", n) + rbuf, err := io.ReadAll(r) + if err != nil { + t.Errorf("%s: ReadFull() returned rbuf, %v", name, err) + continue + } + + if len(rbuf) != n { + t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n) + continue + } + + for i, b := range rbuf { + if byte(i) != b { + t.Errorf("%s: bad byte at offset %d", name, i) + break + } + } + } + } + } + } + } +} + +func TestWriteControlDeadline(t *testing.T) { + t.Parallel() + message := []byte("hello") + var connBuf bytes.Buffer + c := newTestConn(nil, &connBuf, true) + if err := c.WriteControl(PongMessage, message, time.Time{}); err != nil { + t.Errorf("WriteControl(..., zero deadline) = %v, want nil", err) + } + if err := c.WriteControl(PongMessage, message, time.Now().Add(time.Second)); err != nil { + t.Errorf("WriteControl(..., future deadline) = %v, want nil", err) + } + if err := c.WriteControl(PongMessage, message, time.Now().Add(-time.Second)); err == nil { + t.Errorf("WriteControl(..., past deadline) = nil, want timeout error") + } +} + +func TestConcurrencyWriteControl(t *testing.T) { + const message = "this is a ping/pong messsage" + loop := 10 + workers := 10 + for i := 0; i < loop; i++ { + var connBuf bytes.Buffer + + wg := sync.WaitGroup{} + wc := newTestConn(nil, &connBuf, true) + + for i := 0; i < workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if err := wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)); err != nil { + t.Errorf("concurrently wc.WriteControl() returned %v", err) + } + }() + } + + wg.Wait() + wc.Close() + } +} + +func TestControl(t *testing.T) { + const message = "this is a ping/pong message" + for _, isServer := range []bool{true, false} { + for _, isWriteControl := range []bool{true, false} { + name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl) + var connBuf bytes.Buffer + wc := newTestConn(nil, &connBuf, isServer) + rc := newTestConn(&connBuf, nil, !isServer) + if isWriteControl { + _ = wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)) + } else { + w, err := wc.NextWriter(PongMessage) + if err != nil { + t.Errorf("%s: wc.NextWriter() returned %v", name, err) + continue + } + if _, err := w.Write([]byte(message)); err != nil { + t.Errorf("%s: w.Write() returned %v", name, err) + continue + } + if err := w.Close(); err != nil { + t.Errorf("%s: w.Close() returned %v", name, err) + continue + } + var actualMessage string + rc.SetPongHandler(func(s string) error { actualMessage = s; return nil }) + _, _, _ = rc.NextReader() + if actualMessage != message { + t.Errorf("%s: pong=%q, want %q", name, actualMessage, message) + continue + } + } + } + } +} + +// simpleBufferPool is an implementation of BufferPool for TestWriteBufferPool. +type simpleBufferPool struct { + v interface{} +} + +func (p *simpleBufferPool) Get() interface{} { + v := p.v + p.v = nil + return v +} + +func (p *simpleBufferPool) Put(v interface{}) { + p.v = v +} + +func TestWriteBufferPool(t *testing.T) { + const message = "Now is the time for all good people to come to the aid of the party." + + var buf bytes.Buffer + var pool simpleBufferPool + rc := newTestConn(&buf, nil, false) + + // Specify writeBufferSize smaller than message size to ensure that pooling + // works with fragmented messages. + wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, len(message)-1, &pool, nil, nil) + + if wc.writeBuf != nil { + t.Fatal("writeBuf not nil after create") + } + + // Part 1: test NextWriter/Write/Close + + w, err := wc.NextWriter(TextMessage) + if err != nil { + t.Fatalf("wc.NextWriter() returned %v", err) + } + + if wc.writeBuf == nil { + t.Fatal("writeBuf is nil after NextWriter") + } + + writeBufAddr := &wc.writeBuf[0] + + if _, err := io.WriteString(w, message); err != nil { + t.Fatalf("io.WriteString(w, message) returned %v", err) + } + + if err := w.Close(); err != nil { + t.Fatalf("w.Close() returned %v", err) + } + + if wc.writeBuf != nil { + t.Fatal("writeBuf not nil after w.Close()") + } + + if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { + t.Fatal("writeBuf not returned to pool") + } + + opCode, p, err := rc.ReadMessage() + if opCode != TextMessage || err != nil { + t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) + } + + if s := string(p); s != message { + t.Fatalf("message is %s, want %s", s, message) + } + + // Part 2: Test WriteMessage. + + if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil { + t.Fatalf("wc.WriteMessage() returned %v", err) + } + + if wc.writeBuf != nil { + t.Fatal("writeBuf not nil after wc.WriteMessage()") + } + + if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { + t.Fatal("writeBuf not returned to pool after WriteMessage") + } + + opCode, p, err = rc.ReadMessage() + if opCode != TextMessage || err != nil { + t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) + } + + if s := string(p); s != message { + t.Fatalf("message is %s, want %s", s, message) + } +} + +// TestWriteBufferPoolSync ensures that *sync.Pool works as a buffer pool. +func TestWriteBufferPoolSync(t *testing.T) { + var buf bytes.Buffer + var pool sync.Pool + wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil) + rc := newTestConn(&buf, nil, false) + + const message = "Hello World!" + for i := 0; i < 3; i++ { + if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil { + t.Fatalf("wc.WriteMessage() returned %v", err) + } + opCode, p, err := rc.ReadMessage() + if opCode != TextMessage || err != nil { + t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) + } + if s := string(p); s != message { + t.Fatalf("message is %s, want %s", s, message) + } + } +} + +// errorWriter is an io.Writer than returns an error on all writes. +type errorWriter struct{} + +func (ew errorWriter) Write(p []byte) (int, error) { return 0, errors.New("error") } + +// TestWriteBufferPoolError ensures that buffer is returned to pool after error +// on write. +func TestWriteBufferPoolError(t *testing.T) { + + // Part 1: Test NextWriter/Write/Close + + var pool simpleBufferPool + wc := newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil) + + w, err := wc.NextWriter(TextMessage) + if err != nil { + t.Fatalf("wc.NextWriter() returned %v", err) + } + + if wc.writeBuf == nil { + t.Fatal("writeBuf is nil after NextWriter") + } + + writeBufAddr := &wc.writeBuf[0] + + if _, err := io.WriteString(w, "Hello"); err != nil { + t.Fatalf("io.WriteString(w, message) returned %v", err) + } + + if err := w.Close(); err == nil { + t.Fatalf("w.Close() did not return error") + } + + if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { + t.Fatal("writeBuf not returned to pool") + } + + // Part 2: Test WriteMessage + + wc = newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil) + + if err := wc.WriteMessage(TextMessage, []byte("Hello")); err == nil { + t.Fatalf("wc.WriteMessage did not return error") + } + + if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { + t.Fatal("writeBuf not returned to pool") + } +} + +func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) { + const bufSize = 512 + + expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"} + + var b1, b2 bytes.Buffer + wc := newConn(&fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize, nil, nil, nil) + rc := newTestConn(&b1, &b2, true) + + w, _ := wc.NextWriter(BinaryMessage) + _, _ = w.Write(make([]byte, bufSize+bufSize/2)) + _ = wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second)) + w.Close() + + op, r, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("NextReader() returned %d, %v", op, err) + } + _, err = io.Copy(io.Discard, r) + if !reflect.DeepEqual(err, expectedErr) { + t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr) + } + _, _, err = rc.NextReader() + if !reflect.DeepEqual(err, expectedErr) { + t.Fatalf("NextReader() returned %v, want %v", err, expectedErr) + } +} + +func TestEOFWithinFrame(t *testing.T) { + const bufSize = 64 + + for n := 0; ; n++ { + var b bytes.Buffer + wc := newTestConn(nil, &b, false) + rc := newTestConn(&b, nil, true) + + w, _ := wc.NextWriter(BinaryMessage) + _, _ = w.Write(make([]byte, bufSize)) + w.Close() + + if n >= b.Len() { + break + } + b.Truncate(n) + + op, r, err := rc.NextReader() + if err == errUnexpectedEOF { + continue + } + if op != BinaryMessage || err != nil { + t.Fatalf("%d: NextReader() returned %d, %v", n, op, err) + } + _, err = io.Copy(io.Discard, r) + if err != errUnexpectedEOF { + t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF) + } + _, _, err = rc.NextReader() + if err != errUnexpectedEOF { + t.Fatalf("%d: NextReader() returned %v, want %v", n, err, errUnexpectedEOF) + } + } +} + +func TestEOFBeforeFinalFrame(t *testing.T) { + const bufSize = 512 + + var b1, b2 bytes.Buffer + wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, bufSize, nil, nil, nil) + rc := newTestConn(&b1, &b2, true) + + w, _ := wc.NextWriter(BinaryMessage) + _, _ = w.Write(make([]byte, bufSize+bufSize/2)) + + op, r, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("NextReader() returned %d, %v", op, err) + } + _, err = io.Copy(io.Discard, r) + if err != errUnexpectedEOF { + t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF) + } + _, _, err = rc.NextReader() + if err != errUnexpectedEOF { + t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF) + } +} + +func TestWriteAfterMessageWriterClose(t *testing.T) { + wc := newTestConn(nil, &bytes.Buffer{}, false) + w, _ := wc.NextWriter(BinaryMessage) + _, _ = io.WriteString(w, "hello") + if err := w.Close(); err != nil { + t.Fatalf("unexpected error closing message writer, %v", err) + } + + if _, err := io.WriteString(w, "world"); err == nil { + t.Fatalf("no error writing after close") + } + + w, _ = wc.NextWriter(BinaryMessage) + _, _ = io.WriteString(w, "hello") + + // close w by getting next writer + _, err := wc.NextWriter(BinaryMessage) + if err != nil { + t.Fatalf("unexpected error getting next writer, %v", err) + } + + if _, err := io.WriteString(w, "world"); err == nil { + t.Fatalf("no error writing after close") + } +} + +func TestReadLimit(t *testing.T) { + t.Run("Test ReadLimit is enforced", func(t *testing.T) { + const readLimit = 512 + message := make([]byte, readLimit+1) + + var b1, b2 bytes.Buffer + wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil) + rc := newTestConn(&b1, &b2, true) + rc.SetReadLimit(readLimit) + + // Send message at the limit with interleaved pong. + w, _ := wc.NextWriter(BinaryMessage) + _, _ = w.Write(message[:readLimit-1]) + _ = wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second)) + _, _ = w.Write(message[:1]) + w.Close() + + // Send message larger than the limit. + _ = wc.WriteMessage(BinaryMessage, message[:readLimit+1]) + + op, _, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("1: NextReader() returned %d, %v", op, err) + } + op, r, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("2: NextReader() returned %d, %v", op, err) + } + _, err = io.Copy(io.Discard, r) + if err != ErrReadLimit { + t.Fatalf("io.Copy() returned %v", err) + } + }) + + t.Run("Test that ReadLimit cannot be overflowed", func(t *testing.T) { + const readLimit = 1 + + var b1, b2 bytes.Buffer + rc := newTestConn(&b1, &b2, true) + rc.SetReadLimit(readLimit) + + // First, send a non-final binary message + b1.Write([]byte("\x02\x81")) + + // Mask key + b1.Write([]byte("\x00\x00\x00\x00")) + + // First payload + b1.Write([]byte("A")) + + // Next, send a negative-length, non-final continuation frame + b1.Write([]byte("\x00\xFF\x80\x00\x00\x00\x00\x00\x00\x00")) + + // Mask key + b1.Write([]byte("\x00\x00\x00\x00")) + + // Next, send a too long, final continuation frame + b1.Write([]byte("\x80\xFF\x00\x00\x00\x00\x00\x00\x00\x05")) + + // Mask key + b1.Write([]byte("\x00\x00\x00\x00")) + + // Too-long payload + b1.Write([]byte("BCDEF")) + + op, r, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("1: NextReader() returned %d, %v", op, err) + } + + var buf [10]byte + var read int + n, err := r.Read(buf[:]) + if err != nil && err != ErrReadLimit { + t.Fatalf("unexpected error testing read limit: %v", err) + } + read += n + + n, err = r.Read(buf[:]) + if err != nil && err != ErrReadLimit { + t.Fatalf("unexpected error testing read limit: %v", err) + } + read += n + + if err == nil && read > readLimit { + t.Fatalf("read limit exceeded: limit %d, read %d", readLimit, read) + } + }) +} + +func TestAddrs(t *testing.T) { + c := newTestConn(nil, nil, true) + if c.LocalAddr() != localAddr { + t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr) + } + if c.RemoteAddr() != remoteAddr { + t.Errorf("RemoteAddr = %v, want %v", c.RemoteAddr(), remoteAddr) + } +} + +func TestDeprecatedUnderlyingConn(t *testing.T) { + var b1, b2 bytes.Buffer + fc := fakeNetConn{Reader: &b1, Writer: &b2} + c := newConn(fc, true, 1024, 1024, nil, nil, nil) + ul := c.UnderlyingConn() + if ul != fc { + t.Fatalf("Underlying conn is not what it should be.") + } +} + +func TestNetConn(t *testing.T) { + var b1, b2 bytes.Buffer + fc := fakeNetConn{Reader: &b1, Writer: &b2} + c := newConn(fc, true, 1024, 1024, nil, nil, nil) + ul := c.NetConn() + if ul != fc { + t.Fatalf("Underlying conn is not what it should be.") + } +} + +func TestBufioReadBytes(t *testing.T) { + // Test calling bufio.ReadBytes for value longer than read buffer size. + + m := make([]byte, 512) + m[len(m)-1] = '\n' + + var b1, b2 bytes.Buffer + wc := newConn(fakeNetConn{Writer: &b1}, false, len(m)+64, len(m)+64, nil, nil, nil) + rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil) + + w, _ := wc.NextWriter(BinaryMessage) + _, _ = w.Write(m) + w.Close() + + op, r, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("NextReader() returned %d, %v", op, err) + } + + br := bufio.NewReader(r) + p, err := br.ReadBytes('\n') + if err != nil { + t.Fatalf("ReadBytes() returned %v", err) + } + if len(p) != len(m) { + t.Fatalf("read returned %d bytes, want %d bytes", len(p), len(m)) + } +} + +var closeErrorTests = []struct { + err error + codes []int + ok bool +}{ + {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, true}, + {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, false}, + {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, true}, + {errors.New("hello"), []int{CloseNormalClosure}, false}, +} + +func TestCloseError(t *testing.T) { + for _, tt := range closeErrorTests { + ok := IsCloseError(tt.err, tt.codes...) + if ok != tt.ok { + t.Errorf("IsCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok) + } + } +} + +var unexpectedCloseErrorTests = []struct { + err error + codes []int + ok bool +}{ + {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, false}, + {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, true}, + {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, false}, + {errors.New("hello"), []int{CloseNormalClosure}, false}, +} + +func TestUnexpectedCloseErrors(t *testing.T) { + for _, tt := range unexpectedCloseErrorTests { + ok := IsUnexpectedCloseError(tt.err, tt.codes...) + if ok != tt.ok { + t.Errorf("IsUnexpectedCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok) + } + } +} + +type blockingWriter struct { + c1, c2 chan struct{} +} + +func (w blockingWriter) Write(p []byte) (int, error) { + // Allow main to continue + close(w.c1) + // Wait for panic in main + <-w.c2 + return len(p), nil +} + +func TestConcurrentWritePanic(t *testing.T) { + w := blockingWriter{make(chan struct{}), make(chan struct{})} + c := newTestConn(nil, w, false) + go func() { + _ = c.WriteMessage(TextMessage, []byte{}) + }() + + // wait for goroutine to block in write. + <-w.c1 + + defer func() { + close(w.c2) + if v := recover(); v != nil { + return + } + }() + + _ = c.WriteMessage(TextMessage, []byte{}) + t.Fatal("should not get here") +} + +type failingReader struct{} + +func (r failingReader) Read(p []byte) (int, error) { + return 0, io.EOF +} + +func TestFailedConnectionReadPanic(t *testing.T) { + c := newTestConn(failingReader{}, nil, false) + + defer func() { + if v := recover(); v != nil { + return + } + }() + + for i := 0; i < 20000; i++ { + _, _, _ = c.ReadMessage() + } + t.Fatal("should not get here") +} +// Copyright 2019 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + + +func TestJoinMessages(t *testing.T) { + messages := []string{"a", "bc", "def", "ghij", "klmno", "0", "12", "345", "6789"} + for _, readChunk := range []int{1, 2, 3, 4, 5, 6, 7} { + for _, term := range []string{"", ","} { + var connBuf bytes.Buffer + wc := newTestConn(nil, &connBuf, true) + rc := newTestConn(&connBuf, nil, false) + for _, m := range messages { + _ = wc.WriteMessage(BinaryMessage, []byte(m)) + } + + var result bytes.Buffer + _, err := io.CopyBuffer(&result, JoinMessages(rc, term), make([]byte, readChunk)) + if IsUnexpectedCloseError(err, CloseAbnormalClosure) { + t.Errorf("readChunk=%d, term=%q: unexpected error %v", readChunk, term, err) + } + want := strings.Join(messages, term) + term + if result.String() != want { + t.Errorf("readChunk=%d, term=%q, got %q, want %q", readChunk, term, result.String(), want) + } + } + } +} +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + + +func TestJSON(t *testing.T) { + var buf bytes.Buffer + wc := newTestConn(nil, &buf, true) + rc := newTestConn(&buf, nil, false) + + var actual, expect struct { + A int + B string + } + expect.A = 1 + expect.B = "hello" + + if err := wc.WriteJSON(&expect); err != nil { + t.Fatal("write", err) + } + + if err := rc.ReadJSON(&actual); err != nil { + t.Fatal("read", err) + } + + if !reflect.DeepEqual(&actual, &expect) { + t.Fatal("equal", actual, expect) + } +} + +func TestPartialJSONRead(t *testing.T) { + var buf0, buf1 bytes.Buffer + wc := newTestConn(nil, &buf0, true) + rc := newTestConn(&buf0, &buf1, false) + + var v struct { + A int + B string + } + v.A = 1 + v.B = "hello" + + messageCount := 0 + + // Partial JSON values. + + data, err := json.Marshal(v) + if err != nil { + t.Fatal(err) + } + for i := len(data) - 1; i >= 0; i-- { + if err := wc.WriteMessage(TextMessage, data[:i]); err != nil { + t.Fatal(err) + } + messageCount++ + } + + // Whitespace. + + if err := wc.WriteMessage(TextMessage, []byte(" ")); err != nil { + t.Fatal(err) + } + messageCount++ + + // Close. + + if err := wc.WriteMessage(CloseMessage, FormatCloseMessage(CloseNormalClosure, "")); err != nil { + t.Fatal(err) + } + + for i := 0; i < messageCount; i++ { + err := rc.ReadJSON(&v) + if err != io.ErrUnexpectedEOF { + t.Error("read", i, err) + } + } + + err = rc.ReadJSON(&v) + if _, ok := err.(*CloseError); !ok { + t.Error("final", err) + } +} + +func TestDeprecatedJSON(t *testing.T) { + var buf bytes.Buffer + wc := newTestConn(nil, &buf, true) + rc := newTestConn(&buf, nil, false) + + var actual, expect struct { + A int + B string + } + expect.A = 1 + expect.B = "hello" + + if err := WriteJSON(wc, &expect); err != nil { + t.Fatal("write", err) + } + + if err := ReadJSON(rc, &actual); err != nil { + t.Fatal("read", err) + } + + if !reflect.DeepEqual(&actual, &expect) { + t.Fatal("equal", actual, expect) + } +} +// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of +// this source code is governed by a BSD-style license that can be found in the +// LICENSE file. + +// !appengine + + +func maskBytesByByte(key [4]byte, pos int, b []byte) int { + for i := range b { + b[i] ^= key[pos&3] + pos++ + } + return pos & 3 +} + +func notzero(b []byte) int { + for i := range b { + if b[i] != 0 { + return i + } + } + return -1 +} + +func TestMaskBytes(t *testing.T) { + key := [4]byte{1, 2, 3, 4} + for size := 1; size <= 1024; size++ { + for align := 0; align < wordSize; align++ { + for pos := 0; pos < 4; pos++ { + b := make([]byte, size+align)[align:] + maskBytes(key, pos, b) + maskBytesByByte(key, pos, b) + if i := notzero(b); i >= 0 { + t.Errorf("size:%d, align:%d, pos:%d, offset:%d", size, align, pos, i) + } + } + } + } +} + +func BenchmarkMaskBytes(b *testing.B) { + for _, size := range []int{2, 4, 8, 16, 32, 512, 1024} { + b.Run(fmt.Sprintf("size-%d", size), func(b *testing.B) { + for _, align := range []int{wordSize / 2} { + b.Run(fmt.Sprintf("align-%d", align), func(b *testing.B) { + for _, fn := range []struct { + name string + fn func(key [4]byte, pos int, b []byte) int + }{ + {"byte", maskBytesByByte}, + {"word", maskBytes}, + } { + b.Run(fn.name, func(b *testing.B) { + key := newMaskKey() + data := make([]byte, size+align)[align:] + for i := 0; i < b.N; i++ { + fn.fn(key, 0, data) + } + b.SetBytes(int64(len(data))) + }) + } + }) + } + }) + } +} +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + + +var preparedMessageTests = []struct { + messageType int + isServer bool + enableWriteCompression bool + compressionLevel int +}{ + // Server + {TextMessage, true, false, flate.BestSpeed}, + {TextMessage, true, true, flate.BestSpeed}, + {TextMessage, true, true, flate.BestCompression}, + {PingMessage, true, false, flate.BestSpeed}, + {PingMessage, true, true, flate.BestSpeed}, + + // Client + {TextMessage, false, false, flate.BestSpeed}, + {TextMessage, false, true, flate.BestSpeed}, + {TextMessage, false, true, flate.BestCompression}, + {PingMessage, false, false, flate.BestSpeed}, + {PingMessage, false, true, flate.BestSpeed}, +} + +func TestPreparedMessage(t *testing.T) { + testRand := rand.New(rand.NewSource(99)) + prevMaskRand := maskRand + maskRand = testRand + defer func() { maskRand = prevMaskRand }() + + for _, tt := range preparedMessageTests { + var data = []byte("this is a test") + var buf bytes.Buffer + c := newTestConn(nil, &buf, tt.isServer) + if tt.enableWriteCompression { + c.newCompressionWriter = compressNoContextTakeover + } + if err := c.SetCompressionLevel(tt.compressionLevel); err != nil { + t.Fatal(err) + } + + // Seed random number generator for consistent frame mask. + testRand.Seed(1234) + + if err := c.WriteMessage(tt.messageType, data); err != nil { + t.Fatal(err) + } + want := buf.String() + + pm, err := NewPreparedMessage(tt.messageType, data) + if err != nil { + t.Fatal(err) + } + + // Scribble on data to ensure that NewPreparedMessage takes a snapshot. + copy(data, "hello world") + + // Seed random number generator for consistent frame mask. + testRand.Seed(1234) + + buf.Reset() + if err := c.WritePreparedMessage(pm); err != nil { + t.Fatal(err) + } + got := buf.String() + + if got != want { + t.Errorf("write message != prepared message for %+v", tt) + } + } +} +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + + +var subprotocolTests = []struct { + h string + protocols []string +}{ + {"", nil}, + {"foo", []string{"foo"}}, + {"foo,bar", []string{"foo", "bar"}}, + {"foo, bar", []string{"foo", "bar"}}, + {" foo, bar", []string{"foo", "bar"}}, + {" foo, bar ", []string{"foo", "bar"}}, +} + +func TestSubprotocols(t *testing.T) { + for _, st := range subprotocolTests { + r := http.Request{Header: http.Header{"Sec-Websocket-Protocol": {st.h}}} + protocols := Subprotocols(&r) + if !reflect.DeepEqual(st.protocols, protocols) { + t.Errorf("SubProtocols(%q) returned %#v, want %#v", st.h, protocols, st.protocols) + } + } +} + +var isWebSocketUpgradeTests = []struct { + ok bool + h http.Header +}{ + {false, http.Header{"Upgrade": {"websocket"}}}, + {false, http.Header{"Connection": {"upgrade"}}}, + {true, http.Header{"Connection": {"upgRade"}, "Upgrade": {"WebSocket"}}}, +} + +func TestIsWebSocketUpgrade(t *testing.T) { + for _, tt := range isWebSocketUpgradeTests { + ok := IsWebSocketUpgrade(&http.Request{Header: tt.h}) + if tt.ok != ok { + t.Errorf("IsWebSocketUpgrade(%v) returned %v, want %v", tt.h, ok, tt.ok) + } + } +} + +func TestSubProtocolSelection(t *testing.T) { + upgrader := Upgrader{ + Subprotocols: []string{"foo", "bar", "baz"}, + } + + r := http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"foo", "bar"}}} + s := upgrader.selectSubprotocol(&r, nil) + if s != "foo" { + t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "foo") + } + + r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"bar", "foo"}}} + s = upgrader.selectSubprotocol(&r, nil) + if s != "bar" { + t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "bar") + } + + r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"baz"}}} + s = upgrader.selectSubprotocol(&r, nil) + if s != "baz" { + t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "baz") + } + + r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"quux"}}} + s = upgrader.selectSubprotocol(&r, nil) + if s != "" { + t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "empty string") + } +} + +var checkSameOriginTests = []struct { + ok bool + r *http.Request +}{ + {false, &http.Request{Host: "example.org", Header: map[string][]string{"Origin": {"https://other.org"}}}}, + {true, &http.Request{Host: "example.org", Header: map[string][]string{"Origin": {"https://example.org"}}}}, + {true, &http.Request{Host: "Example.org", Header: map[string][]string{"Origin": {"https://example.org"}}}}, +} + +func TestCheckSameOrigin(t *testing.T) { + for _, tt := range checkSameOriginTests { + ok := checkSameOrigin(tt.r) + if tt.ok != ok { + t.Errorf("checkSameOrigin(%+v) returned %v, want %v", tt.r, ok, tt.ok) + } + } +} + +type reuseTestResponseWriter struct { + brw *bufio.ReadWriter + http.ResponseWriter +} + +func (resp *reuseTestResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return fakeNetConn{strings.NewReader(""), &bytes.Buffer{}}, resp.brw, nil +} + +var bufioReuseTests = []struct { + n int + reuse bool +}{ + {4096, true}, + {128, false}, +} + +func xTestBufioReuse(t *testing.T) { + for i, tt := range bufioReuseTests { + br := bufio.NewReaderSize(strings.NewReader(""), tt.n) + bw := bufio.NewWriterSize(&bytes.Buffer{}, tt.n) + resp := &reuseTestResponseWriter{ + brw: bufio.NewReadWriter(br, bw), + } + upgrader := Upgrader{} + c, err := upgrader.Upgrade(resp, &http.Request{ + Method: http.MethodGet, + Header: http.Header{ + "Upgrade": []string{"websocket"}, + "Connection": []string{"upgrade"}, + "Sec-Websocket-Key": []string{"dGhlIHNhbXBsZSBub25jZQ=="}, + "Sec-Websocket-Version": []string{"13"}, + }}, nil) + if err != nil { + t.Fatal(err) + } + if reuse := c.br == br; reuse != tt.reuse { + t.Errorf("%d: buffered reader reuse=%v, want %v", i, reuse, tt.reuse) + } + writeBuf := bw.AvailableBuffer() + if reuse := &c.writeBuf[0] == &writeBuf[0]; reuse != tt.reuse { + t.Errorf("%d: write buffer reuse=%v, want %v", i, reuse, tt.reuse) + } + } +} + +func TestHijack_NotSupported(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Connection", "upgrade") + req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") + req.Header.Set("Sec-Websocket-Version", "13") + + recorder := httptest.NewRecorder() + + upgrader := Upgrader{} + _, err := upgrader.Upgrade(recorder, req, nil) + + if want := (HandshakeError{}); !errors.As(err, &want) || recorder.Code != http.StatusInternalServerError { + t.Errorf("want %T and status_code=%d", want, http.StatusInternalServerError) + t.Fatalf("got err=%T and status_code=%d", err, recorder.Code) + } +} +// Copyright 2014 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + + +var equalASCIIFoldTests = []struct { + t, s string + eq bool +}{ + {"WebSocket", "websocket", true}, + {"websocket", "WebSocket", true}, + {"Öyster", "öyster", false}, + {"WebSocket", "WetSocket", false}, +} + +func TestEqualASCIIFold(t *testing.T) { + for _, tt := range equalASCIIFoldTests { + eq := equalASCIIFold(tt.s, tt.t) + if eq != tt.eq { + t.Errorf("equalASCIIFold(%q, %q) = %v, want %v", tt.s, tt.t, eq, tt.eq) + } + } +} + +var tokenListContainsValueTests = []struct { + value string + ok bool +}{ + {"WebSocket", true}, + {"WEBSOCKET", true}, + {"websocket", true}, + {"websockets", false}, + {"x websocket", false}, + {"websocket x", false}, + {"other,websocket,more", true}, + {"other, websocket, more", true}, +} + +func TestTokenListContainsValue(t *testing.T) { + for _, tt := range tokenListContainsValueTests { + h := http.Header{"Upgrade": {tt.value}} + ok := tokenListContainsValue(h, "Upgrade", "websocket") + if ok != tt.ok { + t.Errorf("tokenListContainsValue(h, n, %q) = %v, want %v", tt.value, ok, tt.ok) + } + } +} + +var isValidChallengeKeyTests = []struct { + key string + ok bool +}{ + {"dGhlIHNhbXBsZSBub25jZQ==", true}, + {"", false}, + {"InvalidKey", false}, + {"WHQ4eXhscUtKYjBvOGN3WEdtOEQ=", false}, +} + +func TestIsValidChallengeKey(t *testing.T) { + for _, tt := range isValidChallengeKeyTests { + ok := isValidChallengeKey(tt.key) + if ok != tt.ok { + t.Errorf("isValidChallengeKey returns %v, want %v", ok, tt.ok) + } + } +} + +var parseExtensionTests = []struct { + value string + extensions []map[string]string +}{ + {`foo`, []map[string]string{{"": "foo"}}}, + {`foo, bar; baz=2`, []map[string]string{ + {"": "foo"}, + {"": "bar", "baz": "2"}}}, + {`foo; bar="b,a;z"`, []map[string]string{ + {"": "foo", "bar": "b,a;z"}}}, + {`foo , bar; baz = 2`, []map[string]string{ + {"": "foo"}, + {"": "bar", "baz": "2"}}}, + {`foo, bar; baz=2 junk`, []map[string]string{ + {"": "foo"}}}, + {`foo junk, bar; baz=2 junk`, nil}, + {`mux; max-channels=4; flow-control, deflate-stream`, []map[string]string{ + {"": "mux", "max-channels": "4", "flow-control": ""}, + {"": "deflate-stream"}}}, + {`permessage-foo; x="10"`, []map[string]string{ + {"": "permessage-foo", "x": "10"}}}, + {`permessage-foo; use_y, permessage-foo`, []map[string]string{ + {"": "permessage-foo", "use_y": ""}, + {"": "permessage-foo"}}}, + {`permessage-deflate; client_max_window_bits; server_max_window_bits=10 , permessage-deflate; client_max_window_bits`, []map[string]string{ + {"": "permessage-deflate", "client_max_window_bits": "", "server_max_window_bits": "10"}, + {"": "permessage-deflate", "client_max_window_bits": ""}}}, + {"permessage-deflate; server_no_context_takeover; client_max_window_bits=15", []map[string]string{ + {"": "permessage-deflate", "server_no_context_takeover": "", "client_max_window_bits": "15"}, + }}, +} + +func TestParseExtensions(t *testing.T) { + for _, tt := range parseExtensionTests { + h := http.Header{http.CanonicalHeaderKey("Sec-WebSocket-Extensions"): {tt.value}} + extensions := parseExtensions(h) + if !reflect.DeepEqual(extensions, tt.extensions) { + t.Errorf("parseExtensions(%q)\n = %v,\nwant %v", tt.value, extensions, tt.extensions) + } + } +} + +func TestParseArgs(t *testing.T) { + given := ParseArgs([]string { "x", "y", "z" }) + expected := CLIArgs { + FromAddr: "y", + ToAddr: "z", + } + + g.AssertEqual(given, expected) +} + + + +func MainTest() { + tests := []testing.InternalTest { + { "TestProxyDial", TestProxyDial }, + { "TestProxyAuthorizationDial", TestProxyAuthorizationDial }, + { "TestDial", TestDial }, + { "TestDialCookieJar", TestDialCookieJar }, + { "TestDialTLS", TestDialTLS }, + { "TestDialTimeout", TestDialTimeout }, + { "TestHandshakeTimeout", TestHandshakeTimeout }, + { "TestHandshakeTimeoutInContext", TestHandshakeTimeoutInContext }, + { "TestDialBadScheme", TestDialBadScheme }, + { "TestDialBadOrigin", TestDialBadOrigin }, + { "TestDialBadHeader", TestDialBadHeader }, + { "TestBadMethod", TestBadMethod }, + { "TestNoUpgrade", TestNoUpgrade }, + { "TestDialExtraTokensInRespHeaders", TestDialExtraTokensInRespHeaders }, + { "TestHandshake", TestHandshake }, + { "TestRespOnBadHandshake", TestRespOnBadHandshake }, + { "TestHost", TestHost }, + { "TestDialCompression", TestDialCompression }, + { "TestSocksProxyDial", TestSocksProxyDial }, + { "TestTracingDialWithContext", TestTracingDialWithContext }, + { "TestEmptyTracingDialWithContext", TestEmptyTracingDialWithContext }, + { "TestNetDialConnect", TestNetDialConnect }, + { "TestNextProtos", TestNextProtos }, + { "TestDataReceivedBeforeHandshake", TestDataReceivedBeforeHandshake }, + { "TestHostPortNoPort", TestHostPortNoPort }, + { "TestTruncWriter", TestTruncWriter }, + { "TestValidCompressionLevel", TestValidCompressionLevel }, + { "TestFraming", TestFraming }, + { "TestWriteControlDeadline", TestWriteControlDeadline }, + { "TestConcurrencyWriteControl", TestConcurrencyWriteControl }, + { "TestControl", TestControl }, + { "TestWriteBufferPool", TestWriteBufferPool }, + { "TestWriteBufferPoolSync", TestWriteBufferPoolSync }, + { "TestWriteBufferPoolError", TestWriteBufferPoolError }, + { "TestCloseFrameBeforeFinalMessageFrame", TestCloseFrameBeforeFinalMessageFrame }, + { "TestEOFWithinFrame", TestEOFWithinFrame }, + { "TestEOFBeforeFinalFrame", TestEOFBeforeFinalFrame }, + { "TestWriteAfterMessageWriterClose", TestWriteAfterMessageWriterClose }, + { "TestReadLimit", TestReadLimit }, + { "TestAddrs", TestAddrs }, + { "TestDeprecatedUnderlyingConn", TestDeprecatedUnderlyingConn }, + { "TestNetConn", TestNetConn }, + { "TestBufioReadBytes", TestBufioReadBytes }, + { "TestCloseError", TestCloseError }, + { "TestUnexpectedCloseErrors", TestUnexpectedCloseErrors }, + { "TestConcurrentWritePanic", TestConcurrentWritePanic }, + { "TestFailedConnectionReadPanic", TestFailedConnectionReadPanic }, + { "TestJoinMessages", TestJoinMessages }, + { "TestJSON", TestJSON }, + { "TestPartialJSONRead", TestPartialJSONRead }, + { "TestDeprecatedJSON", TestDeprecatedJSON }, + { "TestMaskBytes", TestMaskBytes }, + { "TestPreparedMessage", TestPreparedMessage }, + { "TestSubprotocols", TestSubprotocols }, + { "TestIsWebSocketUpgrade", TestIsWebSocketUpgrade }, + { "TestSubProtocolSelection", TestSubProtocolSelection }, + { "TestCheckSameOrigin", TestCheckSameOrigin }, + { "TestHijack_NotSupported", TestHijack_NotSupported }, + { "TestEqualASCIIFold", TestEqualASCIIFold }, + { "TestTokenListContainsValue", TestTokenListContainsValue }, + { "TestIsValidChallengeKey", TestIsValidChallengeKey }, + { "TestParseExtensions", TestParseExtensions }, + { "TestParseArgs", TestParseArgs }, + } + + // FIXME: run benchmarks + benchmarks := []testing.InternalBenchmark { + { "BenchmarkWriteNoCompression", BenchmarkWriteNoCompression }, + { "BenchmarkWriteWithCompression", BenchmarkWriteWithCompression }, + { "BenchmarkBroadcast", BenchmarkBroadcast }, + { "BenchmarkMaskBytes", BenchmarkMaskBytes }, + } + + fuzzTargets := []testing.InternalFuzzTarget {} + examples := []testing.InternalExample {} + m := testing.MainStart( + testdeps.TestDeps {}, + tests, + benchmarks, + fuzzTargets, + examples, + ) + os.Exit(m.Run()) +} |