diff options
-rw-r--r-- | src/wscat.go | 470 | ||||
-rw-r--r-- | tests/wscat.go | 799 |
2 files changed, 13 insertions, 1256 deletions
diff --git a/src/wscat.go b/src/wscat.go index 6ed0aea..4cf63cb 100644 --- a/src/wscat.go +++ b/src/wscat.go @@ -7,7 +7,6 @@ import ( "context" "crypto/rand" "crypto/sha1" - "crypto/tls" "encoding/base64" "encoding/binary" "encoding/json" @@ -16,7 +15,6 @@ import ( "io" "net" "net/http" - "net/http/httptrace" "net/url" "os" "strconv" @@ -29,99 +27,12 @@ import ( g "gobang" ) + + // ErrBadHandshake is returned when the server response to opening handshake is // invalid. var ErrBadHandshake = errors.New("websocket: bad handshake") -var errInvalidCompression = errors.New("websocket: invalid compression negotiation") - -// NewClient creates a new client connection using the given net connection. -// The URL u specifies the host and request URI. Use requestHeader to specify -// the origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies -// (Cookie). Use the response.Header to get the selected subprotocol -// (Sec-WebSocket-Protocol) and cookies (Set-Cookie). -// -// If the WebSocket handshake fails, ErrBadHandshake is returned along with a -// non-nil *http.Response so that callers can handle redirects, authentication, -// etc. -// -// Deprecated: Use Dialer instead. -func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufSize, writeBufSize int) (c *Conn, response *http.Response, err error) { - d := Dialer{ - ReadBufferSize: readBufSize, - WriteBufferSize: writeBufSize, - NetDial: func(net, addr string) (net.Conn, error) { - return netConn, nil - }, - } - return d.Dial(u.String(), requestHeader) -} - -// A Dialer contains options for connecting to WebSocket server. -// -// It is safe to call Dialer's methods concurrently. -type Dialer struct { - // NetDial specifies the dial function for creating TCP connections. If - // NetDial is nil, net.Dialer DialContext is used. - NetDial func(network, addr string) (net.Conn, error) - - // NetDialContext specifies the dial function for creating TCP connections. If - // NetDialContext is nil, NetDial is used. - NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error) - - // NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If - // NetDialTLSContext is nil, NetDialContext is used. - // If NetDialTLSContext is set, Dial assumes the TLS handshake is done there and - // TLSClientConfig is ignored. - NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) - - // TLSClientConfig specifies the TLS configuration to use with tls.Client. - // If nil, the default configuration is used. - // If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake - // is done there and TLSClientConfig is ignored. - TLSClientConfig *tls.Config - - // HandshakeTimeout specifies the duration for the handshake to complete. - HandshakeTimeout time.Duration - - // ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer - // size is zero, then a useful default size is used. The I/O buffer sizes - // do not limit the size of the messages that can be sent or received. - ReadBufferSize, WriteBufferSize int - - // WriteBufferPool is a pool of buffers for write operations. If the value - // is not set, then write buffers are allocated to the connection for the - // lifetime of the connection. - // - // A pool is most useful when the application has a modest volume of writes - // across a large number of connections. - // - // Applications should use a single pool for each unique value of - // WriteBufferSize. - WriteBufferPool BufferPool - - // Subprotocols specifies the client's requested subprotocols. - Subprotocols []string - - // EnableCompression specifies if the client should attempt to negotiate - // per message compression (RFC 7692). Setting this value to true does not - // guarantee that compression will be supported. Currently only "no context - // takeover" modes are supported. - EnableCompression bool - - // Jar specifies the cookie jar. - // If Jar is nil, cookies are not sent in requests and ignored - // in responses. - Jar http.CookieJar -} - -// Dial creates a new client connection by calling DialContext with a background context. -func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { - return d.DialContext(context.Background(), urlStr, requestHeader) -} - -var errMalformedURL = errors.New("malformed ws or wss URL") - func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) { hostPort = u.Host hostNoPort = u.Host @@ -140,293 +51,6 @@ func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) { return hostPort, hostNoPort } -// DefaultDialer is a dialer with all fields set to the default values. -var DefaultDialer = &Dialer{ - HandshakeTimeout: 45 * time.Second, -} - -// nilDialer is dialer to use when receiver is nil. -var nilDialer = *DefaultDialer - -// DialContext creates a new client connection. Use requestHeader to specify the -// origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie). -// Use the response.Header to get the selected subprotocol -// (Sec-WebSocket-Protocol) and cookies (Set-Cookie). -// -// The context will be used in the request and in the Dialer. -// -// If the WebSocket handshake fails, ErrBadHandshake is returned along with a -// non-nil *http.Response so that callers can handle redirects, authentication, -// etcetera. The response body may not contain the entire response and does not -// need to be closed by the application. -func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { - if d == nil { - d = &nilDialer - } - - challengeKey, err := generateChallengeKey() - if err != nil { - return nil, nil, err - } - - u, err := url.Parse(urlStr) - if err != nil { - return nil, nil, err - } - - switch u.Scheme { - case "ws": - u.Scheme = "http" - case "wss": - u.Scheme = "https" - default: - return nil, nil, errMalformedURL - } - - if u.User != nil { - // User name and password are not allowed in websocket URIs. - return nil, nil, errMalformedURL - } - - req := &http.Request{ - Method: http.MethodGet, - URL: u, - Proto: "HTTP/1.1", - ProtoMajor: 1, - ProtoMinor: 1, - Header: make(http.Header), - Host: u.Host, - } - req = req.WithContext(ctx) - - // Set the cookies present in the cookie jar of the dialer - if d.Jar != nil { - for _, cookie := range d.Jar.Cookies(u) { - req.AddCookie(cookie) - } - } - - // Set the request headers using the capitalization for names and values in - // RFC examples. Although the capitalization shouldn't matter, there are - // servers that depend on it. The Header.Set method is not used because the - // method canonicalizes the header names. - req.Header["Upgrade"] = []string{"websocket"} - req.Header["Connection"] = []string{"Upgrade"} - req.Header["Sec-WebSocket-Key"] = []string{challengeKey} - req.Header["Sec-WebSocket-Version"] = []string{"13"} - if len(d.Subprotocols) > 0 { - req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")} - } - for k, vs := range requestHeader { - switch { - case k == "Host": - if len(vs) > 0 { - req.Host = vs[0] - } - case k == "Upgrade" || - k == "Connection" || - k == "Sec-Websocket-Key" || - k == "Sec-Websocket-Version" || - k == "Sec-Websocket-Extensions" || - (k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0): - return nil, nil, errors.New("websocket: duplicate header not allowed: " + k) - case k == "Sec-Websocket-Protocol": - req.Header["Sec-WebSocket-Protocol"] = vs - default: - req.Header[k] = vs - } - } - - if d.EnableCompression { - req.Header["Sec-WebSocket-Extensions"] = []string{"permessage-deflate; server_no_context_takeover; client_no_context_takeover"} - } - - if d.HandshakeTimeout != 0 { - var cancel func() - ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout) - defer cancel() - } - - var netDial netDialerFunc - switch { - case u.Scheme == "https" && d.NetDialTLSContext != nil: - netDial = d.NetDialTLSContext - case d.NetDialContext != nil: - netDial = d.NetDialContext - case d.NetDial != nil: - netDial = func(ctx context.Context, net, addr string) (net.Conn, error) { - return d.NetDial(net, addr) - } - default: - netDial = (&net.Dialer{}).DialContext - } - - // If needed, wrap the dial function to set the connection deadline. - if deadline, ok := ctx.Deadline(); ok { - forwardDial := netDial - netDial = func(ctx context.Context, network, addr string) (net.Conn, error) { - c, err := forwardDial(ctx, network, addr) - if err != nil { - return nil, err - } - err = c.SetDeadline(deadline) - if err != nil { - c.Close() - return nil, err - } - return c, nil - } - } - - hostPort, hostNoPort := hostPortNoPort(u) - trace := httptrace.ContextClientTrace(ctx) - if trace != nil && trace.GetConn != nil { - trace.GetConn(hostPort) - } - - netConn, err := netDial(ctx, "tcp", hostPort) - if err != nil { - return nil, nil, err - } - if trace != nil && trace.GotConn != nil { - trace.GotConn(httptrace.GotConnInfo{ - Conn: netConn, - }) - } - - // Close the network connection when returning an error. The variable - // netConn is set to nil before the success return at the end of the - // function. - defer func() { - if netConn != nil { - // It's safe to ignore the error from Close() because this code is - // only executed when returning a more important error to the - // application. - _ = netConn.Close() - } - }() - - if u.Scheme == "https" && d.NetDialTLSContext == nil { - // If NetDialTLSContext is set, assume that the TLS handshake has already been done - - cfg := cloneTLSConfig(d.TLSClientConfig) - if cfg.ServerName == "" { - cfg.ServerName = hostNoPort - } - tlsConn := tls.Client(netConn, cfg) - netConn = tlsConn - - if trace != nil && trace.TLSHandshakeStart != nil { - trace.TLSHandshakeStart() - } - err := doHandshake(ctx, tlsConn, cfg) - if trace != nil && trace.TLSHandshakeDone != nil { - trace.TLSHandshakeDone(tlsConn.ConnectionState(), err) - } - - if err != nil { - return nil, nil, err - } - } - - conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, nil, nil) - - if err := req.Write(netConn); err != nil { - return nil, nil, err - } - - if trace != nil && trace.GotFirstResponseByte != nil { - if peek, err := conn.br.Peek(1); err == nil && len(peek) == 1 { - trace.GotFirstResponseByte() - } - } - - resp, err := http.ReadResponse(conn.br, req) - if err != nil { - if d.TLSClientConfig != nil { - for _, proto := range d.TLSClientConfig.NextProtos { - if proto != "http/1.1" { - return nil, nil, fmt.Errorf( - "websocket: protocol %q was given but is not supported;"+ - "sharing tls.Config with net/http Transport can cause this error: %w", - proto, err, - ) - } - } - } - return nil, nil, err - } - - if d.Jar != nil { - if rc := resp.Cookies(); len(rc) > 0 { - d.Jar.SetCookies(u, rc) - } - } - - if resp.StatusCode != 101 || - !tokenListContainsValue(resp.Header, "Upgrade", "websocket") || - !tokenListContainsValue(resp.Header, "Connection", "upgrade") || - resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) { - // Before closing the network connection on return from this - // function, slurp up some of the response to aid application - // debugging. - buf := make([]byte, 1024) - n, _ := io.ReadFull(resp.Body, buf) - resp.Body = io.NopCloser(bytes.NewReader(buf[:n])) - return nil, resp, ErrBadHandshake - } - - for _, ext := range parseExtensions(resp.Header) { - if ext[""] != "permessage-deflate" { - continue - } - _, snct := ext["server_no_context_takeover"] - _, cnct := ext["client_no_context_takeover"] - if !snct || !cnct { - return nil, resp, errInvalidCompression - } - conn.newCompressionWriter = compressNoContextTakeover - conn.newDecompressionReader = decompressNoContextTakeover - break - } - - resp.Body = io.NopCloser(bytes.NewReader([]byte{})) - conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") - - if err := netConn.SetDeadline(time.Time{}); err != nil { - return nil, resp, err - } - - // Success! Set netConn to nil to stop the deferred function above from - // closing the network connection. - netConn = nil - - return conn, resp, nil -} - -func cloneTLSConfig(cfg *tls.Config) *tls.Config { - if cfg == nil { - return &tls.Config{} - } - return cfg.Clone() -} - -func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error { - if err := tlsConn.HandshakeContext(ctx); err != nil { - return err - } - if !cfg.InsecureSkipVerify { - if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { - return err - } - } - return nil -} -// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - - const ( minCompressionLevel = -2 // flate.HuffmanOnly not defined in Go < 1.6 maxCompressionLevel = flate.BestCompression @@ -565,10 +189,6 @@ func (r *flateReadWrapper) Close() error { r.fr = nil return err } -// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - const ( // Frame header byte 0 bits from Section 5.2 of RFC 6455 @@ -2287,70 +1907,6 @@ func (fn netDialerFunc) DialContext(ctx context.Context, network, addr string) ( return fn(ctx, network, addr) } -type httpProxyDialer struct { - proxyURL *url.URL - forwardDial netDialerFunc -} - -func (hpd *httpProxyDialer) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { - hostPort, _ := hostPortNoPort(hpd.proxyURL) - conn, err := hpd.forwardDial(ctx, network, hostPort) - if err != nil { - return nil, err - } - - connectHeader := make(http.Header) - if user := hpd.proxyURL.User; user != nil { - proxyUser := user.Username() - if proxyPassword, passwordSet := user.Password(); passwordSet { - credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword)) - connectHeader.Set("Proxy-Authorization", "Basic "+credential) - } - } - - connectReq := &http.Request{ - Method: http.MethodConnect, - URL: &url.URL{Opaque: addr}, - Host: addr, - Header: connectHeader, - } - - if err := connectReq.Write(conn); err != nil { - conn.Close() - return nil, err - } - - // Read response. It's OK to use and discard buffered reader here because - // the remote server does not speak until spoken to. - br := bufio.NewReader(conn) - resp, err := http.ReadResponse(br, connectReq) - if err != nil { - conn.Close() - return nil, err - } - - // Close the response body to silence false positives from linters. Reset - // the buffered reader first to ensure that Close() does not read from - // conn. - // Note: Applications must call resp.Body.Close() on a response returned - // http.ReadResponse to inspect trailers or read another response from the - // buffered reader. The call to resp.Body.Close() does not release - // resources. - br.Reset(bytes.NewReader(nil)) - _ = resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - _ = conn.Close() - f := strings.SplitN(resp.Status, " ", 2) - return nil, errors.New(f[1]) - } - return conn, nil -} -// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - - // HandshakeError describes an error with the handshake from the peer. type HandshakeError struct { message string @@ -3010,7 +2566,7 @@ var EmitActiveConnection = g.MakeGauge("active-connections") -func ParseArgs(args []string) CLIArgs { +func parseArgs(args []string) CLIArgs { if len(args) != 3 { fmt.Fprintf( os.Stderr, @@ -3038,12 +2594,11 @@ func copyData(c chan struct {}, from io.Reader, to io.WriteCloser) { } func Start(toAddr string, listener net.Listener) { - /* - upgrader := websocket.Upgrader {} + upgrader := Upgrader {} http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { connFrom, err := upgrader.Upgrade(w, r, nil) if err != nil { - g.Warning( + g.Error( "Error upgrading connection", "upgrade-connection-error", "err", err, @@ -3057,16 +2612,16 @@ func Start(toAddr string, listener net.Listener) { if err != nil { g.Error( "Error dialing connection", - "dial-connection", + "dial-connection-error", "err", err, ) - os.Exit(1) + return } defer connTo.Close() messageType, reader, err := connFrom.NextReader() if err != nil { - g.Warning( + g.Error( "Failed to get next reader from connection", "connection-next-reader-error", "err", err, @@ -3076,9 +2631,9 @@ func Start(toAddr string, listener net.Listener) { writer, err := connFrom.NextWriter(messageType) if err != nil { - g.Warning( - "Failed to get next reader from connection", - "connection-next-reader-error", + g.Error( + "Failed to get next writer from connection", + "connection-next-writer-error", "err", err, ) return @@ -3097,13 +2652,12 @@ func Start(toAddr string, listener net.Listener) { server := http.Server{} err := server.Serve(listener) g.FatalIf(err) - */ } func Main() { g.Init() - args := ParseArgs(os.Args) + args := parseArgs(os.Args) listener := Listen(args.FromAddr) Start(args.ToAddr, listener) } diff --git a/tests/wscat.go b/tests/wscat.go index 6ef1446..80317ed 100644 --- a/tests/wscat.go +++ b/tests/wscat.go @@ -4,20 +4,15 @@ import ( "bufio" "bytes" "compress/flate" - "context" - "crypto/tls" "crypto/x509" "encoding/json" "errors" "fmt" "io" - "log" "math/rand" "net" "net/http" - "net/http/cookiejar" "net/http/httptest" - "net/http/httptrace" "net/url" "os" "reflect" @@ -44,13 +39,6 @@ var cstUpgrader = Upgrader{ }, } -var cstDialer = Dialer{ - Subprotocols: []string{"p1", "p2"}, - ReadBufferSize: 1024, - WriteBufferSize: 1024, - HandshakeTimeout: 30 * time.Second, -} - type cstHandler struct { *testing.T s *cstServer @@ -107,12 +95,6 @@ func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.Error(w, "bad path", http.StatusBadRequest) return } - subprotos := Subprotocols(r) - if !reflect.DeepEqual(subprotos, cstDialer.Subprotocols) { - t.Logf("subprotols=%v, want %v", subprotos, cstDialer.Subprotocols) - http.Error(w, "bad protocol", http.StatusBadRequest) - return - } ws, err := cstUpgrader.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}}) if err != nil { t.Logf("Upgrade: %v", err) @@ -169,66 +151,6 @@ func sendRecv(t *testing.T, ws *Conn) { } } -func TestDial(t *testing.T) { - s := newServer(t) - defer s.Close() - - ws, _, err := cstDialer.Dial(s.URL, nil) - if err != nil { - t.Fatalf("Dial: %v", err) - } - defer ws.Close() - sendRecv(t, ws) -} - -func TestDialCookieJar(t *testing.T) { - s := newServer(t) - defer s.Close() - - jar, _ := cookiejar.New(nil) - d := cstDialer - d.Jar = jar - - u, _ := url.Parse(s.URL) - - switch u.Scheme { - case "ws": - u.Scheme = "http" - case "wss": - u.Scheme = "https" - } - - cookies := []*http.Cookie{{Name: "gorilla", Value: "ws", Path: "/"}} - d.Jar.SetCookies(u, cookies) - - ws, _, err := d.Dial(s.URL, nil) - if err != nil { - t.Fatalf("Dial: %v", err) - } - defer ws.Close() - - var gorilla string - var sessionID string - for _, c := range d.Jar.Cookies(u) { - if c.Name == "gorilla" { - gorilla = c.Value - } - - if c.Name == "sessionID" { - sessionID = c.Value - } - } - if gorilla != "ws" { - t.Error("Cookie not present in jar.") - } - - if sessionID != "1234" { - t.Error("Set-Cookie not received from the server.") - } - - sendRecv(t, ws) -} - func rootCAs(t *testing.T, s *httptest.Server) *x509.CertPool { certs := x509.NewCertPool() for _, c := range s.TLS.Certificates { @@ -243,33 +165,6 @@ func rootCAs(t *testing.T, s *httptest.Server) *x509.CertPool { return certs } -func TestDialTLS(t *testing.T) { - s := newTLSServer(t) - defer s.Close() - - d := cstDialer - d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)} - ws, _, err := d.Dial(s.URL, nil) - if err != nil { - t.Fatalf("Dial: %v", err) - } - defer ws.Close() - sendRecv(t, ws) -} - -func TestDialTimeout(t *testing.T) { - s := newServer(t) - defer s.Close() - - d := cstDialer - d.HandshakeTimeout = -1 - ws, _, err := d.Dial(s.URL, nil) - if err == nil { - ws.Close() - t.Fatalf("Dial: nil") - } -} - // requireDeadlineNetConn fails the current test when Read or Write are called // with no deadline. type requireDeadlineNetConn struct { @@ -313,90 +208,6 @@ func (c *requireDeadlineNetConn) Close() error { return c.c.Close() } func (c *requireDeadlineNetConn) LocalAddr() net.Addr { return c.c.LocalAddr() } func (c *requireDeadlineNetConn) RemoteAddr() net.Addr { return c.c.RemoteAddr() } -func TestHandshakeTimeout(t *testing.T) { - s := newServer(t) - defer s.Close() - - d := cstDialer - d.NetDial = func(n, a string) (net.Conn, error) { - c, err := net.Dial(n, a) - return &requireDeadlineNetConn{c: c, t: t}, err - } - ws, _, err := d.Dial(s.URL, nil) - if err != nil { - t.Fatal("Dial:", err) - } - ws.Close() -} - -func TestHandshakeTimeoutInContext(t *testing.T) { - s := newServer(t) - defer s.Close() - - d := cstDialer - d.HandshakeTimeout = 0 - d.NetDialContext = func(ctx context.Context, n, a string) (net.Conn, error) { - netDialer := &net.Dialer{} - c, err := netDialer.DialContext(ctx, n, a) - return &requireDeadlineNetConn{c: c, t: t}, err - } - - ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(30*time.Second)) - defer cancel() - ws, _, err := d.DialContext(ctx, s.URL, nil) - if err != nil { - t.Fatal("Dial:", err) - } - ws.Close() -} - -func TestDialBadScheme(t *testing.T) { - s := newServer(t) - defer s.Close() - - ws, _, err := cstDialer.Dial(s.Server.URL, nil) - if err == nil { - ws.Close() - t.Fatalf("Dial: nil") - } -} - -func TestDialBadOrigin(t *testing.T) { - s := newServer(t) - defer s.Close() - - ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}}) - if err == nil { - ws.Close() - t.Fatalf("Dial: nil") - } - if resp == nil { - t.Fatalf("resp=nil, err=%v", err) - } - if resp.StatusCode != http.StatusForbidden { - t.Fatalf("status=%d, want %d", resp.StatusCode, http.StatusForbidden) - } -} - -func TestDialBadHeader(t *testing.T) { - s := newServer(t) - defer s.Close() - - for _, k := range []string{"Upgrade", - "Connection", - "Sec-Websocket-Key", - "Sec-Websocket-Version", - "Sec-Websocket-Protocol"} { - h := http.Header{} - h.Set(k, "bad") - ws, _, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}}) - if err == nil { - ws.Close() - t.Errorf("Dial with header %s returned nil", k) - } - } -} - func TestBadMethod(t *testing.T) { s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ws, err := cstUpgrader.Upgrade(w, r, nil) @@ -456,83 +267,6 @@ func TestNoUpgrade(t *testing.T) { } } -func TestDialExtraTokensInRespHeaders(t *testing.T) { - s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - challengeKey := r.Header.Get("Sec-Websocket-Key") - w.Header().Set("Upgrade", "foo, websocket") - w.Header().Set("Connection", "upgrade, keep-alive") - w.Header().Set("Sec-Websocket-Accept", computeAcceptKey(challengeKey)) - w.WriteHeader(101) - })) - defer s.Close() - - ws, _, err := cstDialer.Dial(makeWsProto(s.URL), nil) - if err != nil { - t.Fatalf("Dial: %v", err) - } - defer ws.Close() -} - -func TestHandshake(t *testing.T) { - s := newServer(t) - defer s.Close() - - ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {s.URL}}) - if err != nil { - t.Fatalf("Dial: %v", err) - } - defer ws.Close() - - var sessionID string - for _, c := range resp.Cookies() { - if c.Name == "sessionID" { - sessionID = c.Value - } - } - if sessionID != "1234" { - t.Error("Set-Cookie not received from the server.") - } - - if ws.Subprotocol() != "p1" { - t.Errorf("ws.Subprotocol() = %s, want p1", ws.Subprotocol()) - } - sendRecv(t, ws) -} - -func TestRespOnBadHandshake(t *testing.T) { - const expectedStatus = http.StatusGone - const expectedBody = "This is the response body." - - s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(expectedStatus) - _, _ = io.WriteString(w, expectedBody) - })) - defer s.Close() - - ws, resp, err := cstDialer.Dial(makeWsProto(s.URL), nil) - if err == nil { - ws.Close() - t.Fatalf("Dial: nil") - } - - if resp == nil { - t.Fatalf("resp=nil, err=%v", err) - } - - if resp.StatusCode != expectedStatus { - t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus) - } - - p, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("ReadFull(resp.Body) returned error %v", err) - } - - if string(p) != expectedBody { - t.Errorf("resp.Body=%s, want %s", p, expectedBody) - } -} - type testLogWriter struct { t *testing.T } @@ -542,492 +276,6 @@ func (w testLogWriter) Write(p []byte) (int, error) { return len(p), nil } -// TestHost tests handling of host names and confirms that it matches net/http. -func TestHost(t *testing.T) { - - upgrader := Upgrader{} - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if IsWebSocketUpgrade(r) { - c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}}) - if err != nil { - t.Fatal(err) - } - c.Close() - } else { - w.Header().Set("X-Test-Host", r.Host) - } - }) - - server := httptest.NewServer(handler) - defer server.Close() - - tlsServer := httptest.NewTLSServer(handler) - defer tlsServer.Close() - - addrs := map[*httptest.Server]string{server: server.Listener.Addr().String(), tlsServer: tlsServer.Listener.Addr().String()} - wsProtos := map[*httptest.Server]string{server: "ws://", tlsServer: "wss://"} - httpProtos := map[*httptest.Server]string{server: "http://", tlsServer: "https://"} - - // Avoid log noise from net/http server by logging to testing.T - server.Config.ErrorLog = log.New(testLogWriter{t}, "", 0) - tlsServer.Config.ErrorLog = server.Config.ErrorLog - - cas := rootCAs(t, tlsServer) - - tests := []struct { - fail bool // true if dial / get should fail - server *httptest.Server // server to use - url string // host for request URI - header string // optional request host header - tls string // optional host for tls ServerName - wantAddr string // expected host for dial - wantHeader string // expected request header on server - insecureSkipVerify bool - }{ - { - server: server, - url: addrs[server], - wantAddr: addrs[server], - wantHeader: addrs[server], - }, - { - server: tlsServer, - url: addrs[tlsServer], - wantAddr: addrs[tlsServer], - wantHeader: addrs[tlsServer], - }, - - { - server: server, - url: addrs[server], - header: "badhost.com", - wantAddr: addrs[server], - wantHeader: "badhost.com", - }, - { - server: tlsServer, - url: addrs[tlsServer], - header: "badhost.com", - wantAddr: addrs[tlsServer], - wantHeader: "badhost.com", - }, - - { - server: server, - url: "example.com", - header: "badhost.com", - wantAddr: "example.com:80", - wantHeader: "badhost.com", - }, - { - server: tlsServer, - url: "example.com", - header: "badhost.com", - wantAddr: "example.com:443", - wantHeader: "badhost.com", - }, - - { - server: server, - url: "badhost.com", - header: "example.com", - wantAddr: "badhost.com:80", - wantHeader: "example.com", - }, - { - fail: true, - server: tlsServer, - url: "badhost.com", - header: "example.com", - wantAddr: "badhost.com:443", - }, - { - server: tlsServer, - url: "badhost.com", - insecureSkipVerify: true, - wantAddr: "badhost.com:443", - wantHeader: "badhost.com", - }, - { - server: tlsServer, - url: "badhost.com", - tls: "example.com", - wantAddr: "badhost.com:443", - wantHeader: "badhost.com", - }, - } - - for i, tt := range tests { - - tls := &tls.Config{ - RootCAs: cas, - ServerName: tt.tls, - InsecureSkipVerify: tt.insecureSkipVerify, - } - - var gotAddr string - dialer := Dialer{ - NetDial: func(network, addr string) (net.Conn, error) { - gotAddr = addr - return net.Dial(network, addrs[tt.server]) - }, - TLSClientConfig: tls, - } - - // Test websocket dial - - h := http.Header{} - if tt.header != "" { - h.Set("Host", tt.header) - } - c, resp, err := dialer.Dial(wsProtos[tt.server]+tt.url+"/", h) - if err == nil { - c.Close() - } - - check := func(protos map[*httptest.Server]string) { - name := fmt.Sprintf("%d: %s%s/ header[Host]=%q, tls.ServerName=%q", i+1, protos[tt.server], tt.url, tt.header, tt.tls) - if gotAddr != tt.wantAddr { - t.Errorf("%s: got addr %s, want %s", name, gotAddr, tt.wantAddr) - } - switch { - case tt.fail && err == nil: - t.Errorf("%s: unexpected success", name) - case !tt.fail && err != nil: - t.Errorf("%s: unexpected error %v", name, err) - case !tt.fail && err == nil: - if gotHost := resp.Header.Get("X-Test-Host"); gotHost != tt.wantHeader { - t.Errorf("%s: got host %s, want %s", name, gotHost, tt.wantHeader) - } - } - } - - check(wsProtos) - - // Confirm that net/http has same result - - transport := &http.Transport{ - Dial: dialer.NetDial, - TLSClientConfig: dialer.TLSClientConfig, - } - req, _ := http.NewRequest(http.MethodGet, httpProtos[tt.server]+tt.url+"/", nil) - if tt.header != "" { - req.Host = tt.header - } - client := &http.Client{Transport: transport} - resp, err = client.Do(req) - if err == nil { - resp.Body.Close() - } - transport.CloseIdleConnections() - check(httpProtos) - } -} - -func TestDialCompression(t *testing.T) { - s := newServer(t) - defer s.Close() - - dialer := cstDialer - dialer.EnableCompression = true - ws, _, err := dialer.Dial(s.URL, nil) - if err != nil { - t.Fatalf("Dial: %v", err) - } - defer ws.Close() - sendRecv(t, ws) -} - -func TestTracingDialWithContext(t *testing.T) { - - var headersWrote, requestWrote, getConn, gotConn, connectDone, gotFirstResponseByte bool - trace := &httptrace.ClientTrace{ - WroteHeaders: func() { - headersWrote = true - }, - WroteRequest: func(httptrace.WroteRequestInfo) { - requestWrote = true - }, - GetConn: func(hostPort string) { - getConn = true - }, - GotConn: func(info httptrace.GotConnInfo) { - gotConn = true - }, - ConnectDone: func(network, addr string, err error) { - connectDone = true - }, - GotFirstResponseByte: func() { - gotFirstResponseByte = true - }, - } - ctx := httptrace.WithClientTrace(context.Background(), trace) - - s := newTLSServer(t) - defer s.Close() - - d := cstDialer - d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)} - - ws, _, err := d.DialContext(ctx, s.URL, nil) - if err != nil { - t.Fatalf("Dial: %v", err) - } - - if !headersWrote { - t.Fatal("Headers was not written") - } - if !requestWrote { - t.Fatal("Request was not written") - } - if !getConn { - t.Fatal("getConn was not called") - } - if !gotConn { - t.Fatal("gotConn was not called") - } - if !connectDone { - t.Fatal("connectDone was not called") - } - if !gotFirstResponseByte { - t.Fatal("GotFirstResponseByte was not called") - } - - defer ws.Close() - sendRecv(t, ws) -} - -func TestEmptyTracingDialWithContext(t *testing.T) { - - trace := &httptrace.ClientTrace{} - ctx := httptrace.WithClientTrace(context.Background(), trace) - - s := newTLSServer(t) - defer s.Close() - - d := cstDialer - d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)} - - ws, _, err := d.DialContext(ctx, s.URL, nil) - if err != nil { - t.Fatalf("Dial: %v", err) - } - - defer ws.Close() - sendRecv(t, ws) -} - -// TestNetDialConnect tests selection of dial method between NetDial, NetDialContext, NetDialTLS or NetDialTLSContext -func TestNetDialConnect(t *testing.T) { - - upgrader := Upgrader{} - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if IsWebSocketUpgrade(r) { - c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}}) - if err != nil { - t.Fatal(err) - } - c.Close() - } else { - w.Header().Set("X-Test-Host", r.Host) - } - }) - - server := httptest.NewServer(handler) - defer server.Close() - - tlsServer := httptest.NewTLSServer(handler) - defer tlsServer.Close() - - testUrls := map[*httptest.Server]string{ - server: "ws://" + server.Listener.Addr().String() + "/", - tlsServer: "wss://" + tlsServer.Listener.Addr().String() + "/", - } - - cas := rootCAs(t, tlsServer) - tlsConfig := &tls.Config{ - RootCAs: cas, - ServerName: "example.com", - InsecureSkipVerify: false, - } - - tests := []struct { - name string - server *httptest.Server // server to use - netDial func(network, addr string) (net.Conn, error) - netDialContext func(ctx context.Context, network, addr string) (net.Conn, error) - netDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) - tlsClientConfig *tls.Config - }{ - - { - name: "HTTP server, all NetDial* defined, shall use NetDialContext", - server: server, - netDial: func(network, addr string) (net.Conn, error) { - return nil, errors.New("NetDial should not be called") - }, - netDialContext: func(_ context.Context, network, addr string) (net.Conn, error) { - return net.Dial(network, addr) - }, - netDialTLSContext: func(_ context.Context, network, addr string) (net.Conn, error) { - return nil, errors.New("NetDialTLSContext should not be called") - }, - tlsClientConfig: nil, - }, - { - name: "HTTP server, all NetDial* undefined", - server: server, - netDial: nil, - netDialContext: nil, - netDialTLSContext: nil, - tlsClientConfig: nil, - }, - { - name: "HTTP server, NetDialContext undefined, shall fallback to NetDial", - server: server, - netDial: func(network, addr string) (net.Conn, error) { - return net.Dial(network, addr) - }, - netDialContext: nil, - netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return nil, errors.New("NetDialTLSContext should not be called") - }, - tlsClientConfig: nil, - }, - { - name: "HTTPS server, all NetDial* defined, shall use NetDialTLSContext", - server: tlsServer, - netDial: func(network, addr string) (net.Conn, error) { - return nil, errors.New("NetDial should not be called") - }, - netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return nil, errors.New("NetDialContext should not be called") - }, - netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - netConn, err := net.Dial(network, addr) - if err != nil { - return nil, err - } - tlsConn := tls.Client(netConn, tlsConfig) - err = tlsConn.Handshake() - if err != nil { - return nil, err - } - return tlsConn, nil - }, - tlsClientConfig: nil, - }, - { - name: "HTTPS server, NetDialTLSContext undefined, shall fallback to NetDialContext and do handshake", - server: tlsServer, - netDial: func(network, addr string) (net.Conn, error) { - return nil, errors.New("NetDial should not be called") - }, - netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return net.Dial(network, addr) - }, - netDialTLSContext: nil, - tlsClientConfig: tlsConfig, - }, - { - name: "HTTPS server, NetDialTLSContext and NetDialContext undefined, shall fallback to NetDial and do handshake", - server: tlsServer, - netDial: func(network, addr string) (net.Conn, error) { - return net.Dial(network, addr) - }, - netDialContext: nil, - netDialTLSContext: nil, - tlsClientConfig: tlsConfig, - }, - { - name: "HTTPS server, all NetDial* undefined", - server: tlsServer, - netDial: nil, - netDialContext: nil, - netDialTLSContext: nil, - tlsClientConfig: tlsConfig, - }, - { - name: "HTTPS server, all NetDialTLSContext defined, dummy TlsClientConfig defined, shall not do handshake", - server: tlsServer, - netDial: func(network, addr string) (net.Conn, error) { - return nil, errors.New("NetDial should not be called") - }, - netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return nil, errors.New("NetDialContext should not be called") - }, - netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - netConn, err := net.Dial(network, addr) - if err != nil { - return nil, err - } - tlsConn := tls.Client(netConn, tlsConfig) - err = tlsConn.Handshake() - if err != nil { - return nil, err - } - return tlsConn, nil - }, - tlsClientConfig: &tls.Config{ - RootCAs: nil, - ServerName: "badserver.com", - InsecureSkipVerify: false, - }, - }, - } - - for _, tc := range tests { - dialer := Dialer{ - NetDial: tc.netDial, - NetDialContext: tc.netDialContext, - NetDialTLSContext: tc.netDialTLSContext, - TLSClientConfig: tc.tlsClientConfig, - } - - // Test websocket dial - c, _, err := dialer.Dial(testUrls[tc.server], nil) - if err != nil { - t.Errorf("FAILED %s, err: %s", tc.name, err.Error()) - } else { - c.Close() - } - } -} -func TestNextProtos(t *testing.T) { - ts := httptest.NewUnstartedServer( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), - ) - ts.EnableHTTP2 = true - ts.StartTLS() - defer ts.Close() - - d := Dialer{ - TLSClientConfig: ts.Client().Transport.(*http.Transport).TLSClientConfig, - } - - r, err := ts.Client().Get(ts.URL) - if err != nil { - t.Fatalf("Get: %v", err) - } - r.Body.Close() - - // Asserts that Dialer.TLSClientConfig.NextProtos contains "h2" - // after the Client.Get call from net/http above. - var containsHTTP2 bool = false - for _, proto := range d.TLSClientConfig.NextProtos { - if proto == "h2" { - containsHTTP2 = true - } - } - if !containsHTTP2 { - t.Fatalf("Dialer.TLSClientConfig.NextProtos does not contain \"h2\"") - } - - _, _, err = d.Dial(makeWsProto(ts.URL), nil) - if err == nil { - t.Fatalf("Dial succeeded, expect fail ") - } -} type dataBeforeHandshakeResponseWriter struct { http.ResponseWriter @@ -1065,32 +313,6 @@ func (w dataBeforeHandshakeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter return c, rw, err } -func TestDataReceivedBeforeHandshake(t *testing.T) { - s := newServer(t) - defer s.Close() - - origHandler := s.Server.Config.Handler - s.Server.Config.Handler = http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - origHandler.ServeHTTP(dataBeforeHandshakeResponseWriter{w}, r) - }) - - for _, readBufferSize := range []int{0, 1024} { - t.Run(fmt.Sprintf("ReadBufferSize=%d", readBufferSize), func(t *testing.T) { - dialer := cstDialer - dialer.ReadBufferSize = readBufferSize - ws, _, err := cstDialer.Dial(s.URL, nil) - if err != nil { - t.Fatalf("Dial: %v", err) - } - defer ws.Close() - _, m, err := ws.ReadMessage() - if err != nil || string(m) != "Hello" { - t.Fatalf("ReadMessage() = %q, %v, want \"Hello\", nil", m, err) - } - }) - } -} // Copyright 2014 The Gorilla WebSocket Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. @@ -2584,7 +1806,7 @@ func TestParseExtensions(t *testing.T) { } func TestParseArgs(t *testing.T) { - given := ParseArgs([]string { "x", "y", "z" }) + given := parseArgs([]string { "x", "y", "z" }) expected := CLIArgs { FromAddr: "y", ToAddr: "z", @@ -2597,27 +1819,8 @@ func TestParseArgs(t *testing.T) { func MainTest() { tests := []testing.InternalTest { - { "TestDial", TestDial }, - { "TestDialCookieJar", TestDialCookieJar }, - { "TestDialTLS", TestDialTLS }, - { "TestDialTimeout", TestDialTimeout }, - { "TestHandshakeTimeout", TestHandshakeTimeout }, - { "TestHandshakeTimeoutInContext", TestHandshakeTimeoutInContext }, - { "TestDialBadScheme", TestDialBadScheme }, - { "TestDialBadOrigin", TestDialBadOrigin }, - { "TestDialBadHeader", TestDialBadHeader }, { "TestBadMethod", TestBadMethod }, { "TestNoUpgrade", TestNoUpgrade }, - { "TestDialExtraTokensInRespHeaders", TestDialExtraTokensInRespHeaders }, - { "TestHandshake", TestHandshake }, - { "TestRespOnBadHandshake", TestRespOnBadHandshake }, - { "TestHost", TestHost }, - { "TestDialCompression", TestDialCompression }, - { "TestTracingDialWithContext", TestTracingDialWithContext }, - { "TestEmptyTracingDialWithContext", TestEmptyTracingDialWithContext }, - { "TestNetDialConnect", TestNetDialConnect }, - { "TestNextProtos", TestNextProtos }, - { "TestDataReceivedBeforeHandshake", TestDataReceivedBeforeHandshake }, { "TestHostPortNoPort", TestHostPortNoPort }, { "TestTruncWriter", TestTruncWriter }, { "TestValidCompressionLevel", TestValidCompressionLevel }, |