summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/wscat.go470
-rw-r--r--tests/wscat.go799
2 files changed, 13 insertions, 1256 deletions
diff --git a/src/wscat.go b/src/wscat.go
index 6ed0aea..4cf63cb 100644
--- a/src/wscat.go
+++ b/src/wscat.go
@@ -7,7 +7,6 @@ import (
"context"
"crypto/rand"
"crypto/sha1"
- "crypto/tls"
"encoding/base64"
"encoding/binary"
"encoding/json"
@@ -16,7 +15,6 @@ import (
"io"
"net"
"net/http"
- "net/http/httptrace"
"net/url"
"os"
"strconv"
@@ -29,99 +27,12 @@ import (
g "gobang"
)
+
+
// ErrBadHandshake is returned when the server response to opening handshake is
// invalid.
var ErrBadHandshake = errors.New("websocket: bad handshake")
-var errInvalidCompression = errors.New("websocket: invalid compression negotiation")
-
-// NewClient creates a new client connection using the given net connection.
-// The URL u specifies the host and request URI. Use requestHeader to specify
-// the origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies
-// (Cookie). Use the response.Header to get the selected subprotocol
-// (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
-//
-// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
-// non-nil *http.Response so that callers can handle redirects, authentication,
-// etc.
-//
-// Deprecated: Use Dialer instead.
-func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufSize, writeBufSize int) (c *Conn, response *http.Response, err error) {
- d := Dialer{
- ReadBufferSize: readBufSize,
- WriteBufferSize: writeBufSize,
- NetDial: func(net, addr string) (net.Conn, error) {
- return netConn, nil
- },
- }
- return d.Dial(u.String(), requestHeader)
-}
-
-// A Dialer contains options for connecting to WebSocket server.
-//
-// It is safe to call Dialer's methods concurrently.
-type Dialer struct {
- // NetDial specifies the dial function for creating TCP connections. If
- // NetDial is nil, net.Dialer DialContext is used.
- NetDial func(network, addr string) (net.Conn, error)
-
- // NetDialContext specifies the dial function for creating TCP connections. If
- // NetDialContext is nil, NetDial is used.
- NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
-
- // NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If
- // NetDialTLSContext is nil, NetDialContext is used.
- // If NetDialTLSContext is set, Dial assumes the TLS handshake is done there and
- // TLSClientConfig is ignored.
- NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
-
- // TLSClientConfig specifies the TLS configuration to use with tls.Client.
- // If nil, the default configuration is used.
- // If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake
- // is done there and TLSClientConfig is ignored.
- TLSClientConfig *tls.Config
-
- // HandshakeTimeout specifies the duration for the handshake to complete.
- HandshakeTimeout time.Duration
-
- // ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer
- // size is zero, then a useful default size is used. The I/O buffer sizes
- // do not limit the size of the messages that can be sent or received.
- ReadBufferSize, WriteBufferSize int
-
- // WriteBufferPool is a pool of buffers for write operations. If the value
- // is not set, then write buffers are allocated to the connection for the
- // lifetime of the connection.
- //
- // A pool is most useful when the application has a modest volume of writes
- // across a large number of connections.
- //
- // Applications should use a single pool for each unique value of
- // WriteBufferSize.
- WriteBufferPool BufferPool
-
- // Subprotocols specifies the client's requested subprotocols.
- Subprotocols []string
-
- // EnableCompression specifies if the client should attempt to negotiate
- // per message compression (RFC 7692). Setting this value to true does not
- // guarantee that compression will be supported. Currently only "no context
- // takeover" modes are supported.
- EnableCompression bool
-
- // Jar specifies the cookie jar.
- // If Jar is nil, cookies are not sent in requests and ignored
- // in responses.
- Jar http.CookieJar
-}
-
-// Dial creates a new client connection by calling DialContext with a background context.
-func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
- return d.DialContext(context.Background(), urlStr, requestHeader)
-}
-
-var errMalformedURL = errors.New("malformed ws or wss URL")
-
func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
hostPort = u.Host
hostNoPort = u.Host
@@ -140,293 +51,6 @@ func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
return hostPort, hostNoPort
}
-// DefaultDialer is a dialer with all fields set to the default values.
-var DefaultDialer = &Dialer{
- HandshakeTimeout: 45 * time.Second,
-}
-
-// nilDialer is dialer to use when receiver is nil.
-var nilDialer = *DefaultDialer
-
-// DialContext creates a new client connection. Use requestHeader to specify the
-// origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie).
-// Use the response.Header to get the selected subprotocol
-// (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
-//
-// The context will be used in the request and in the Dialer.
-//
-// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
-// non-nil *http.Response so that callers can handle redirects, authentication,
-// etcetera. The response body may not contain the entire response and does not
-// need to be closed by the application.
-func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
- if d == nil {
- d = &nilDialer
- }
-
- challengeKey, err := generateChallengeKey()
- if err != nil {
- return nil, nil, err
- }
-
- u, err := url.Parse(urlStr)
- if err != nil {
- return nil, nil, err
- }
-
- switch u.Scheme {
- case "ws":
- u.Scheme = "http"
- case "wss":
- u.Scheme = "https"
- default:
- return nil, nil, errMalformedURL
- }
-
- if u.User != nil {
- // User name and password are not allowed in websocket URIs.
- return nil, nil, errMalformedURL
- }
-
- req := &http.Request{
- Method: http.MethodGet,
- URL: u,
- Proto: "HTTP/1.1",
- ProtoMajor: 1,
- ProtoMinor: 1,
- Header: make(http.Header),
- Host: u.Host,
- }
- req = req.WithContext(ctx)
-
- // Set the cookies present in the cookie jar of the dialer
- if d.Jar != nil {
- for _, cookie := range d.Jar.Cookies(u) {
- req.AddCookie(cookie)
- }
- }
-
- // Set the request headers using the capitalization for names and values in
- // RFC examples. Although the capitalization shouldn't matter, there are
- // servers that depend on it. The Header.Set method is not used because the
- // method canonicalizes the header names.
- req.Header["Upgrade"] = []string{"websocket"}
- req.Header["Connection"] = []string{"Upgrade"}
- req.Header["Sec-WebSocket-Key"] = []string{challengeKey}
- req.Header["Sec-WebSocket-Version"] = []string{"13"}
- if len(d.Subprotocols) > 0 {
- req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")}
- }
- for k, vs := range requestHeader {
- switch {
- case k == "Host":
- if len(vs) > 0 {
- req.Host = vs[0]
- }
- case k == "Upgrade" ||
- k == "Connection" ||
- k == "Sec-Websocket-Key" ||
- k == "Sec-Websocket-Version" ||
- k == "Sec-Websocket-Extensions" ||
- (k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0):
- return nil, nil, errors.New("websocket: duplicate header not allowed: " + k)
- case k == "Sec-Websocket-Protocol":
- req.Header["Sec-WebSocket-Protocol"] = vs
- default:
- req.Header[k] = vs
- }
- }
-
- if d.EnableCompression {
- req.Header["Sec-WebSocket-Extensions"] = []string{"permessage-deflate; server_no_context_takeover; client_no_context_takeover"}
- }
-
- if d.HandshakeTimeout != 0 {
- var cancel func()
- ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout)
- defer cancel()
- }
-
- var netDial netDialerFunc
- switch {
- case u.Scheme == "https" && d.NetDialTLSContext != nil:
- netDial = d.NetDialTLSContext
- case d.NetDialContext != nil:
- netDial = d.NetDialContext
- case d.NetDial != nil:
- netDial = func(ctx context.Context, net, addr string) (net.Conn, error) {
- return d.NetDial(net, addr)
- }
- default:
- netDial = (&net.Dialer{}).DialContext
- }
-
- // If needed, wrap the dial function to set the connection deadline.
- if deadline, ok := ctx.Deadline(); ok {
- forwardDial := netDial
- netDial = func(ctx context.Context, network, addr string) (net.Conn, error) {
- c, err := forwardDial(ctx, network, addr)
- if err != nil {
- return nil, err
- }
- err = c.SetDeadline(deadline)
- if err != nil {
- c.Close()
- return nil, err
- }
- return c, nil
- }
- }
-
- hostPort, hostNoPort := hostPortNoPort(u)
- trace := httptrace.ContextClientTrace(ctx)
- if trace != nil && trace.GetConn != nil {
- trace.GetConn(hostPort)
- }
-
- netConn, err := netDial(ctx, "tcp", hostPort)
- if err != nil {
- return nil, nil, err
- }
- if trace != nil && trace.GotConn != nil {
- trace.GotConn(httptrace.GotConnInfo{
- Conn: netConn,
- })
- }
-
- // Close the network connection when returning an error. The variable
- // netConn is set to nil before the success return at the end of the
- // function.
- defer func() {
- if netConn != nil {
- // It's safe to ignore the error from Close() because this code is
- // only executed when returning a more important error to the
- // application.
- _ = netConn.Close()
- }
- }()
-
- if u.Scheme == "https" && d.NetDialTLSContext == nil {
- // If NetDialTLSContext is set, assume that the TLS handshake has already been done
-
- cfg := cloneTLSConfig(d.TLSClientConfig)
- if cfg.ServerName == "" {
- cfg.ServerName = hostNoPort
- }
- tlsConn := tls.Client(netConn, cfg)
- netConn = tlsConn
-
- if trace != nil && trace.TLSHandshakeStart != nil {
- trace.TLSHandshakeStart()
- }
- err := doHandshake(ctx, tlsConn, cfg)
- if trace != nil && trace.TLSHandshakeDone != nil {
- trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
- }
-
- if err != nil {
- return nil, nil, err
- }
- }
-
- conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, nil, nil)
-
- if err := req.Write(netConn); err != nil {
- return nil, nil, err
- }
-
- if trace != nil && trace.GotFirstResponseByte != nil {
- if peek, err := conn.br.Peek(1); err == nil && len(peek) == 1 {
- trace.GotFirstResponseByte()
- }
- }
-
- resp, err := http.ReadResponse(conn.br, req)
- if err != nil {
- if d.TLSClientConfig != nil {
- for _, proto := range d.TLSClientConfig.NextProtos {
- if proto != "http/1.1" {
- return nil, nil, fmt.Errorf(
- "websocket: protocol %q was given but is not supported;"+
- "sharing tls.Config with net/http Transport can cause this error: %w",
- proto, err,
- )
- }
- }
- }
- return nil, nil, err
- }
-
- if d.Jar != nil {
- if rc := resp.Cookies(); len(rc) > 0 {
- d.Jar.SetCookies(u, rc)
- }
- }
-
- if resp.StatusCode != 101 ||
- !tokenListContainsValue(resp.Header, "Upgrade", "websocket") ||
- !tokenListContainsValue(resp.Header, "Connection", "upgrade") ||
- resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
- // Before closing the network connection on return from this
- // function, slurp up some of the response to aid application
- // debugging.
- buf := make([]byte, 1024)
- n, _ := io.ReadFull(resp.Body, buf)
- resp.Body = io.NopCloser(bytes.NewReader(buf[:n]))
- return nil, resp, ErrBadHandshake
- }
-
- for _, ext := range parseExtensions(resp.Header) {
- if ext[""] != "permessage-deflate" {
- continue
- }
- _, snct := ext["server_no_context_takeover"]
- _, cnct := ext["client_no_context_takeover"]
- if !snct || !cnct {
- return nil, resp, errInvalidCompression
- }
- conn.newCompressionWriter = compressNoContextTakeover
- conn.newDecompressionReader = decompressNoContextTakeover
- break
- }
-
- resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
- conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
-
- if err := netConn.SetDeadline(time.Time{}); err != nil {
- return nil, resp, err
- }
-
- // Success! Set netConn to nil to stop the deferred function above from
- // closing the network connection.
- netConn = nil
-
- return conn, resp, nil
-}
-
-func cloneTLSConfig(cfg *tls.Config) *tls.Config {
- if cfg == nil {
- return &tls.Config{}
- }
- return cfg.Clone()
-}
-
-func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error {
- if err := tlsConn.HandshakeContext(ctx); err != nil {
- return err
- }
- if !cfg.InsecureSkipVerify {
- if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
- return err
- }
- }
- return nil
-}
-// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-
const (
minCompressionLevel = -2 // flate.HuffmanOnly not defined in Go < 1.6
maxCompressionLevel = flate.BestCompression
@@ -565,10 +189,6 @@ func (r *flateReadWrapper) Close() error {
r.fr = nil
return err
}
-// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
const (
// Frame header byte 0 bits from Section 5.2 of RFC 6455
@@ -2287,70 +1907,6 @@ func (fn netDialerFunc) DialContext(ctx context.Context, network, addr string) (
return fn(ctx, network, addr)
}
-type httpProxyDialer struct {
- proxyURL *url.URL
- forwardDial netDialerFunc
-}
-
-func (hpd *httpProxyDialer) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
- hostPort, _ := hostPortNoPort(hpd.proxyURL)
- conn, err := hpd.forwardDial(ctx, network, hostPort)
- if err != nil {
- return nil, err
- }
-
- connectHeader := make(http.Header)
- if user := hpd.proxyURL.User; user != nil {
- proxyUser := user.Username()
- if proxyPassword, passwordSet := user.Password(); passwordSet {
- credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword))
- connectHeader.Set("Proxy-Authorization", "Basic "+credential)
- }
- }
-
- connectReq := &http.Request{
- Method: http.MethodConnect,
- URL: &url.URL{Opaque: addr},
- Host: addr,
- Header: connectHeader,
- }
-
- if err := connectReq.Write(conn); err != nil {
- conn.Close()
- return nil, err
- }
-
- // Read response. It's OK to use and discard buffered reader here because
- // the remote server does not speak until spoken to.
- br := bufio.NewReader(conn)
- resp, err := http.ReadResponse(br, connectReq)
- if err != nil {
- conn.Close()
- return nil, err
- }
-
- // Close the response body to silence false positives from linters. Reset
- // the buffered reader first to ensure that Close() does not read from
- // conn.
- // Note: Applications must call resp.Body.Close() on a response returned
- // http.ReadResponse to inspect trailers or read another response from the
- // buffered reader. The call to resp.Body.Close() does not release
- // resources.
- br.Reset(bytes.NewReader(nil))
- _ = resp.Body.Close()
-
- if resp.StatusCode != http.StatusOK {
- _ = conn.Close()
- f := strings.SplitN(resp.Status, " ", 2)
- return nil, errors.New(f[1])
- }
- return conn, nil
-}
-// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-
// HandshakeError describes an error with the handshake from the peer.
type HandshakeError struct {
message string
@@ -3010,7 +2566,7 @@ var EmitActiveConnection = g.MakeGauge("active-connections")
-func ParseArgs(args []string) CLIArgs {
+func parseArgs(args []string) CLIArgs {
if len(args) != 3 {
fmt.Fprintf(
os.Stderr,
@@ -3038,12 +2594,11 @@ func copyData(c chan struct {}, from io.Reader, to io.WriteCloser) {
}
func Start(toAddr string, listener net.Listener) {
- /*
- upgrader := websocket.Upgrader {}
+ upgrader := Upgrader {}
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
connFrom, err := upgrader.Upgrade(w, r, nil)
if err != nil {
- g.Warning(
+ g.Error(
"Error upgrading connection",
"upgrade-connection-error",
"err", err,
@@ -3057,16 +2612,16 @@ func Start(toAddr string, listener net.Listener) {
if err != nil {
g.Error(
"Error dialing connection",
- "dial-connection",
+ "dial-connection-error",
"err", err,
)
- os.Exit(1)
+ return
}
defer connTo.Close()
messageType, reader, err := connFrom.NextReader()
if err != nil {
- g.Warning(
+ g.Error(
"Failed to get next reader from connection",
"connection-next-reader-error",
"err", err,
@@ -3076,9 +2631,9 @@ func Start(toAddr string, listener net.Listener) {
writer, err := connFrom.NextWriter(messageType)
if err != nil {
- g.Warning(
- "Failed to get next reader from connection",
- "connection-next-reader-error",
+ g.Error(
+ "Failed to get next writer from connection",
+ "connection-next-writer-error",
"err", err,
)
return
@@ -3097,13 +2652,12 @@ func Start(toAddr string, listener net.Listener) {
server := http.Server{}
err := server.Serve(listener)
g.FatalIf(err)
- */
}
func Main() {
g.Init()
- args := ParseArgs(os.Args)
+ args := parseArgs(os.Args)
listener := Listen(args.FromAddr)
Start(args.ToAddr, listener)
}
diff --git a/tests/wscat.go b/tests/wscat.go
index 6ef1446..80317ed 100644
--- a/tests/wscat.go
+++ b/tests/wscat.go
@@ -4,20 +4,15 @@ import (
"bufio"
"bytes"
"compress/flate"
- "context"
- "crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"io"
- "log"
"math/rand"
"net"
"net/http"
- "net/http/cookiejar"
"net/http/httptest"
- "net/http/httptrace"
"net/url"
"os"
"reflect"
@@ -44,13 +39,6 @@ var cstUpgrader = Upgrader{
},
}
-var cstDialer = Dialer{
- Subprotocols: []string{"p1", "p2"},
- ReadBufferSize: 1024,
- WriteBufferSize: 1024,
- HandshakeTimeout: 30 * time.Second,
-}
-
type cstHandler struct {
*testing.T
s *cstServer
@@ -107,12 +95,6 @@ func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
http.Error(w, "bad path", http.StatusBadRequest)
return
}
- subprotos := Subprotocols(r)
- if !reflect.DeepEqual(subprotos, cstDialer.Subprotocols) {
- t.Logf("subprotols=%v, want %v", subprotos, cstDialer.Subprotocols)
- http.Error(w, "bad protocol", http.StatusBadRequest)
- return
- }
ws, err := cstUpgrader.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}})
if err != nil {
t.Logf("Upgrade: %v", err)
@@ -169,66 +151,6 @@ func sendRecv(t *testing.T, ws *Conn) {
}
}
-func TestDial(t *testing.T) {
- s := newServer(t)
- defer s.Close()
-
- ws, _, err := cstDialer.Dial(s.URL, nil)
- if err != nil {
- t.Fatalf("Dial: %v", err)
- }
- defer ws.Close()
- sendRecv(t, ws)
-}
-
-func TestDialCookieJar(t *testing.T) {
- s := newServer(t)
- defer s.Close()
-
- jar, _ := cookiejar.New(nil)
- d := cstDialer
- d.Jar = jar
-
- u, _ := url.Parse(s.URL)
-
- switch u.Scheme {
- case "ws":
- u.Scheme = "http"
- case "wss":
- u.Scheme = "https"
- }
-
- cookies := []*http.Cookie{{Name: "gorilla", Value: "ws", Path: "/"}}
- d.Jar.SetCookies(u, cookies)
-
- ws, _, err := d.Dial(s.URL, nil)
- if err != nil {
- t.Fatalf("Dial: %v", err)
- }
- defer ws.Close()
-
- var gorilla string
- var sessionID string
- for _, c := range d.Jar.Cookies(u) {
- if c.Name == "gorilla" {
- gorilla = c.Value
- }
-
- if c.Name == "sessionID" {
- sessionID = c.Value
- }
- }
- if gorilla != "ws" {
- t.Error("Cookie not present in jar.")
- }
-
- if sessionID != "1234" {
- t.Error("Set-Cookie not received from the server.")
- }
-
- sendRecv(t, ws)
-}
-
func rootCAs(t *testing.T, s *httptest.Server) *x509.CertPool {
certs := x509.NewCertPool()
for _, c := range s.TLS.Certificates {
@@ -243,33 +165,6 @@ func rootCAs(t *testing.T, s *httptest.Server) *x509.CertPool {
return certs
}
-func TestDialTLS(t *testing.T) {
- s := newTLSServer(t)
- defer s.Close()
-
- d := cstDialer
- d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
- ws, _, err := d.Dial(s.URL, nil)
- if err != nil {
- t.Fatalf("Dial: %v", err)
- }
- defer ws.Close()
- sendRecv(t, ws)
-}
-
-func TestDialTimeout(t *testing.T) {
- s := newServer(t)
- defer s.Close()
-
- d := cstDialer
- d.HandshakeTimeout = -1
- ws, _, err := d.Dial(s.URL, nil)
- if err == nil {
- ws.Close()
- t.Fatalf("Dial: nil")
- }
-}
-
// requireDeadlineNetConn fails the current test when Read or Write are called
// with no deadline.
type requireDeadlineNetConn struct {
@@ -313,90 +208,6 @@ func (c *requireDeadlineNetConn) Close() error { return c.c.Close() }
func (c *requireDeadlineNetConn) LocalAddr() net.Addr { return c.c.LocalAddr() }
func (c *requireDeadlineNetConn) RemoteAddr() net.Addr { return c.c.RemoteAddr() }
-func TestHandshakeTimeout(t *testing.T) {
- s := newServer(t)
- defer s.Close()
-
- d := cstDialer
- d.NetDial = func(n, a string) (net.Conn, error) {
- c, err := net.Dial(n, a)
- return &requireDeadlineNetConn{c: c, t: t}, err
- }
- ws, _, err := d.Dial(s.URL, nil)
- if err != nil {
- t.Fatal("Dial:", err)
- }
- ws.Close()
-}
-
-func TestHandshakeTimeoutInContext(t *testing.T) {
- s := newServer(t)
- defer s.Close()
-
- d := cstDialer
- d.HandshakeTimeout = 0
- d.NetDialContext = func(ctx context.Context, n, a string) (net.Conn, error) {
- netDialer := &net.Dialer{}
- c, err := netDialer.DialContext(ctx, n, a)
- return &requireDeadlineNetConn{c: c, t: t}, err
- }
-
- ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(30*time.Second))
- defer cancel()
- ws, _, err := d.DialContext(ctx, s.URL, nil)
- if err != nil {
- t.Fatal("Dial:", err)
- }
- ws.Close()
-}
-
-func TestDialBadScheme(t *testing.T) {
- s := newServer(t)
- defer s.Close()
-
- ws, _, err := cstDialer.Dial(s.Server.URL, nil)
- if err == nil {
- ws.Close()
- t.Fatalf("Dial: nil")
- }
-}
-
-func TestDialBadOrigin(t *testing.T) {
- s := newServer(t)
- defer s.Close()
-
- ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}})
- if err == nil {
- ws.Close()
- t.Fatalf("Dial: nil")
- }
- if resp == nil {
- t.Fatalf("resp=nil, err=%v", err)
- }
- if resp.StatusCode != http.StatusForbidden {
- t.Fatalf("status=%d, want %d", resp.StatusCode, http.StatusForbidden)
- }
-}
-
-func TestDialBadHeader(t *testing.T) {
- s := newServer(t)
- defer s.Close()
-
- for _, k := range []string{"Upgrade",
- "Connection",
- "Sec-Websocket-Key",
- "Sec-Websocket-Version",
- "Sec-Websocket-Protocol"} {
- h := http.Header{}
- h.Set(k, "bad")
- ws, _, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}})
- if err == nil {
- ws.Close()
- t.Errorf("Dial with header %s returned nil", k)
- }
- }
-}
-
func TestBadMethod(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ws, err := cstUpgrader.Upgrade(w, r, nil)
@@ -456,83 +267,6 @@ func TestNoUpgrade(t *testing.T) {
}
}
-func TestDialExtraTokensInRespHeaders(t *testing.T) {
- s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- challengeKey := r.Header.Get("Sec-Websocket-Key")
- w.Header().Set("Upgrade", "foo, websocket")
- w.Header().Set("Connection", "upgrade, keep-alive")
- w.Header().Set("Sec-Websocket-Accept", computeAcceptKey(challengeKey))
- w.WriteHeader(101)
- }))
- defer s.Close()
-
- ws, _, err := cstDialer.Dial(makeWsProto(s.URL), nil)
- if err != nil {
- t.Fatalf("Dial: %v", err)
- }
- defer ws.Close()
-}
-
-func TestHandshake(t *testing.T) {
- s := newServer(t)
- defer s.Close()
-
- ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {s.URL}})
- if err != nil {
- t.Fatalf("Dial: %v", err)
- }
- defer ws.Close()
-
- var sessionID string
- for _, c := range resp.Cookies() {
- if c.Name == "sessionID" {
- sessionID = c.Value
- }
- }
- if sessionID != "1234" {
- t.Error("Set-Cookie not received from the server.")
- }
-
- if ws.Subprotocol() != "p1" {
- t.Errorf("ws.Subprotocol() = %s, want p1", ws.Subprotocol())
- }
- sendRecv(t, ws)
-}
-
-func TestRespOnBadHandshake(t *testing.T) {
- const expectedStatus = http.StatusGone
- const expectedBody = "This is the response body."
-
- s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(expectedStatus)
- _, _ = io.WriteString(w, expectedBody)
- }))
- defer s.Close()
-
- ws, resp, err := cstDialer.Dial(makeWsProto(s.URL), nil)
- if err == nil {
- ws.Close()
- t.Fatalf("Dial: nil")
- }
-
- if resp == nil {
- t.Fatalf("resp=nil, err=%v", err)
- }
-
- if resp.StatusCode != expectedStatus {
- t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus)
- }
-
- p, err := io.ReadAll(resp.Body)
- if err != nil {
- t.Fatalf("ReadFull(resp.Body) returned error %v", err)
- }
-
- if string(p) != expectedBody {
- t.Errorf("resp.Body=%s, want %s", p, expectedBody)
- }
-}
-
type testLogWriter struct {
t *testing.T
}
@@ -542,492 +276,6 @@ func (w testLogWriter) Write(p []byte) (int, error) {
return len(p), nil
}
-// TestHost tests handling of host names and confirms that it matches net/http.
-func TestHost(t *testing.T) {
-
- upgrader := Upgrader{}
- handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if IsWebSocketUpgrade(r) {
- c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}})
- if err != nil {
- t.Fatal(err)
- }
- c.Close()
- } else {
- w.Header().Set("X-Test-Host", r.Host)
- }
- })
-
- server := httptest.NewServer(handler)
- defer server.Close()
-
- tlsServer := httptest.NewTLSServer(handler)
- defer tlsServer.Close()
-
- addrs := map[*httptest.Server]string{server: server.Listener.Addr().String(), tlsServer: tlsServer.Listener.Addr().String()}
- wsProtos := map[*httptest.Server]string{server: "ws://", tlsServer: "wss://"}
- httpProtos := map[*httptest.Server]string{server: "http://", tlsServer: "https://"}
-
- // Avoid log noise from net/http server by logging to testing.T
- server.Config.ErrorLog = log.New(testLogWriter{t}, "", 0)
- tlsServer.Config.ErrorLog = server.Config.ErrorLog
-
- cas := rootCAs(t, tlsServer)
-
- tests := []struct {
- fail bool // true if dial / get should fail
- server *httptest.Server // server to use
- url string // host for request URI
- header string // optional request host header
- tls string // optional host for tls ServerName
- wantAddr string // expected host for dial
- wantHeader string // expected request header on server
- insecureSkipVerify bool
- }{
- {
- server: server,
- url: addrs[server],
- wantAddr: addrs[server],
- wantHeader: addrs[server],
- },
- {
- server: tlsServer,
- url: addrs[tlsServer],
- wantAddr: addrs[tlsServer],
- wantHeader: addrs[tlsServer],
- },
-
- {
- server: server,
- url: addrs[server],
- header: "badhost.com",
- wantAddr: addrs[server],
- wantHeader: "badhost.com",
- },
- {
- server: tlsServer,
- url: addrs[tlsServer],
- header: "badhost.com",
- wantAddr: addrs[tlsServer],
- wantHeader: "badhost.com",
- },
-
- {
- server: server,
- url: "example.com",
- header: "badhost.com",
- wantAddr: "example.com:80",
- wantHeader: "badhost.com",
- },
- {
- server: tlsServer,
- url: "example.com",
- header: "badhost.com",
- wantAddr: "example.com:443",
- wantHeader: "badhost.com",
- },
-
- {
- server: server,
- url: "badhost.com",
- header: "example.com",
- wantAddr: "badhost.com:80",
- wantHeader: "example.com",
- },
- {
- fail: true,
- server: tlsServer,
- url: "badhost.com",
- header: "example.com",
- wantAddr: "badhost.com:443",
- },
- {
- server: tlsServer,
- url: "badhost.com",
- insecureSkipVerify: true,
- wantAddr: "badhost.com:443",
- wantHeader: "badhost.com",
- },
- {
- server: tlsServer,
- url: "badhost.com",
- tls: "example.com",
- wantAddr: "badhost.com:443",
- wantHeader: "badhost.com",
- },
- }
-
- for i, tt := range tests {
-
- tls := &tls.Config{
- RootCAs: cas,
- ServerName: tt.tls,
- InsecureSkipVerify: tt.insecureSkipVerify,
- }
-
- var gotAddr string
- dialer := Dialer{
- NetDial: func(network, addr string) (net.Conn, error) {
- gotAddr = addr
- return net.Dial(network, addrs[tt.server])
- },
- TLSClientConfig: tls,
- }
-
- // Test websocket dial
-
- h := http.Header{}
- if tt.header != "" {
- h.Set("Host", tt.header)
- }
- c, resp, err := dialer.Dial(wsProtos[tt.server]+tt.url+"/", h)
- if err == nil {
- c.Close()
- }
-
- check := func(protos map[*httptest.Server]string) {
- name := fmt.Sprintf("%d: %s%s/ header[Host]=%q, tls.ServerName=%q", i+1, protos[tt.server], tt.url, tt.header, tt.tls)
- if gotAddr != tt.wantAddr {
- t.Errorf("%s: got addr %s, want %s", name, gotAddr, tt.wantAddr)
- }
- switch {
- case tt.fail && err == nil:
- t.Errorf("%s: unexpected success", name)
- case !tt.fail && err != nil:
- t.Errorf("%s: unexpected error %v", name, err)
- case !tt.fail && err == nil:
- if gotHost := resp.Header.Get("X-Test-Host"); gotHost != tt.wantHeader {
- t.Errorf("%s: got host %s, want %s", name, gotHost, tt.wantHeader)
- }
- }
- }
-
- check(wsProtos)
-
- // Confirm that net/http has same result
-
- transport := &http.Transport{
- Dial: dialer.NetDial,
- TLSClientConfig: dialer.TLSClientConfig,
- }
- req, _ := http.NewRequest(http.MethodGet, httpProtos[tt.server]+tt.url+"/", nil)
- if tt.header != "" {
- req.Host = tt.header
- }
- client := &http.Client{Transport: transport}
- resp, err = client.Do(req)
- if err == nil {
- resp.Body.Close()
- }
- transport.CloseIdleConnections()
- check(httpProtos)
- }
-}
-
-func TestDialCompression(t *testing.T) {
- s := newServer(t)
- defer s.Close()
-
- dialer := cstDialer
- dialer.EnableCompression = true
- ws, _, err := dialer.Dial(s.URL, nil)
- if err != nil {
- t.Fatalf("Dial: %v", err)
- }
- defer ws.Close()
- sendRecv(t, ws)
-}
-
-func TestTracingDialWithContext(t *testing.T) {
-
- var headersWrote, requestWrote, getConn, gotConn, connectDone, gotFirstResponseByte bool
- trace := &httptrace.ClientTrace{
- WroteHeaders: func() {
- headersWrote = true
- },
- WroteRequest: func(httptrace.WroteRequestInfo) {
- requestWrote = true
- },
- GetConn: func(hostPort string) {
- getConn = true
- },
- GotConn: func(info httptrace.GotConnInfo) {
- gotConn = true
- },
- ConnectDone: func(network, addr string, err error) {
- connectDone = true
- },
- GotFirstResponseByte: func() {
- gotFirstResponseByte = true
- },
- }
- ctx := httptrace.WithClientTrace(context.Background(), trace)
-
- s := newTLSServer(t)
- defer s.Close()
-
- d := cstDialer
- d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
-
- ws, _, err := d.DialContext(ctx, s.URL, nil)
- if err != nil {
- t.Fatalf("Dial: %v", err)
- }
-
- if !headersWrote {
- t.Fatal("Headers was not written")
- }
- if !requestWrote {
- t.Fatal("Request was not written")
- }
- if !getConn {
- t.Fatal("getConn was not called")
- }
- if !gotConn {
- t.Fatal("gotConn was not called")
- }
- if !connectDone {
- t.Fatal("connectDone was not called")
- }
- if !gotFirstResponseByte {
- t.Fatal("GotFirstResponseByte was not called")
- }
-
- defer ws.Close()
- sendRecv(t, ws)
-}
-
-func TestEmptyTracingDialWithContext(t *testing.T) {
-
- trace := &httptrace.ClientTrace{}
- ctx := httptrace.WithClientTrace(context.Background(), trace)
-
- s := newTLSServer(t)
- defer s.Close()
-
- d := cstDialer
- d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
-
- ws, _, err := d.DialContext(ctx, s.URL, nil)
- if err != nil {
- t.Fatalf("Dial: %v", err)
- }
-
- defer ws.Close()
- sendRecv(t, ws)
-}
-
-// TestNetDialConnect tests selection of dial method between NetDial, NetDialContext, NetDialTLS or NetDialTLSContext
-func TestNetDialConnect(t *testing.T) {
-
- upgrader := Upgrader{}
- handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if IsWebSocketUpgrade(r) {
- c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}})
- if err != nil {
- t.Fatal(err)
- }
- c.Close()
- } else {
- w.Header().Set("X-Test-Host", r.Host)
- }
- })
-
- server := httptest.NewServer(handler)
- defer server.Close()
-
- tlsServer := httptest.NewTLSServer(handler)
- defer tlsServer.Close()
-
- testUrls := map[*httptest.Server]string{
- server: "ws://" + server.Listener.Addr().String() + "/",
- tlsServer: "wss://" + tlsServer.Listener.Addr().String() + "/",
- }
-
- cas := rootCAs(t, tlsServer)
- tlsConfig := &tls.Config{
- RootCAs: cas,
- ServerName: "example.com",
- InsecureSkipVerify: false,
- }
-
- tests := []struct {
- name string
- server *httptest.Server // server to use
- netDial func(network, addr string) (net.Conn, error)
- netDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
- netDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
- tlsClientConfig *tls.Config
- }{
-
- {
- name: "HTTP server, all NetDial* defined, shall use NetDialContext",
- server: server,
- netDial: func(network, addr string) (net.Conn, error) {
- return nil, errors.New("NetDial should not be called")
- },
- netDialContext: func(_ context.Context, network, addr string) (net.Conn, error) {
- return net.Dial(network, addr)
- },
- netDialTLSContext: func(_ context.Context, network, addr string) (net.Conn, error) {
- return nil, errors.New("NetDialTLSContext should not be called")
- },
- tlsClientConfig: nil,
- },
- {
- name: "HTTP server, all NetDial* undefined",
- server: server,
- netDial: nil,
- netDialContext: nil,
- netDialTLSContext: nil,
- tlsClientConfig: nil,
- },
- {
- name: "HTTP server, NetDialContext undefined, shall fallback to NetDial",
- server: server,
- netDial: func(network, addr string) (net.Conn, error) {
- return net.Dial(network, addr)
- },
- netDialContext: nil,
- netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
- return nil, errors.New("NetDialTLSContext should not be called")
- },
- tlsClientConfig: nil,
- },
- {
- name: "HTTPS server, all NetDial* defined, shall use NetDialTLSContext",
- server: tlsServer,
- netDial: func(network, addr string) (net.Conn, error) {
- return nil, errors.New("NetDial should not be called")
- },
- netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
- return nil, errors.New("NetDialContext should not be called")
- },
- netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
- netConn, err := net.Dial(network, addr)
- if err != nil {
- return nil, err
- }
- tlsConn := tls.Client(netConn, tlsConfig)
- err = tlsConn.Handshake()
- if err != nil {
- return nil, err
- }
- return tlsConn, nil
- },
- tlsClientConfig: nil,
- },
- {
- name: "HTTPS server, NetDialTLSContext undefined, shall fallback to NetDialContext and do handshake",
- server: tlsServer,
- netDial: func(network, addr string) (net.Conn, error) {
- return nil, errors.New("NetDial should not be called")
- },
- netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
- return net.Dial(network, addr)
- },
- netDialTLSContext: nil,
- tlsClientConfig: tlsConfig,
- },
- {
- name: "HTTPS server, NetDialTLSContext and NetDialContext undefined, shall fallback to NetDial and do handshake",
- server: tlsServer,
- netDial: func(network, addr string) (net.Conn, error) {
- return net.Dial(network, addr)
- },
- netDialContext: nil,
- netDialTLSContext: nil,
- tlsClientConfig: tlsConfig,
- },
- {
- name: "HTTPS server, all NetDial* undefined",
- server: tlsServer,
- netDial: nil,
- netDialContext: nil,
- netDialTLSContext: nil,
- tlsClientConfig: tlsConfig,
- },
- {
- name: "HTTPS server, all NetDialTLSContext defined, dummy TlsClientConfig defined, shall not do handshake",
- server: tlsServer,
- netDial: func(network, addr string) (net.Conn, error) {
- return nil, errors.New("NetDial should not be called")
- },
- netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
- return nil, errors.New("NetDialContext should not be called")
- },
- netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
- netConn, err := net.Dial(network, addr)
- if err != nil {
- return nil, err
- }
- tlsConn := tls.Client(netConn, tlsConfig)
- err = tlsConn.Handshake()
- if err != nil {
- return nil, err
- }
- return tlsConn, nil
- },
- tlsClientConfig: &tls.Config{
- RootCAs: nil,
- ServerName: "badserver.com",
- InsecureSkipVerify: false,
- },
- },
- }
-
- for _, tc := range tests {
- dialer := Dialer{
- NetDial: tc.netDial,
- NetDialContext: tc.netDialContext,
- NetDialTLSContext: tc.netDialTLSContext,
- TLSClientConfig: tc.tlsClientConfig,
- }
-
- // Test websocket dial
- c, _, err := dialer.Dial(testUrls[tc.server], nil)
- if err != nil {
- t.Errorf("FAILED %s, err: %s", tc.name, err.Error())
- } else {
- c.Close()
- }
- }
-}
-func TestNextProtos(t *testing.T) {
- ts := httptest.NewUnstartedServer(
- http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
- )
- ts.EnableHTTP2 = true
- ts.StartTLS()
- defer ts.Close()
-
- d := Dialer{
- TLSClientConfig: ts.Client().Transport.(*http.Transport).TLSClientConfig,
- }
-
- r, err := ts.Client().Get(ts.URL)
- if err != nil {
- t.Fatalf("Get: %v", err)
- }
- r.Body.Close()
-
- // Asserts that Dialer.TLSClientConfig.NextProtos contains "h2"
- // after the Client.Get call from net/http above.
- var containsHTTP2 bool = false
- for _, proto := range d.TLSClientConfig.NextProtos {
- if proto == "h2" {
- containsHTTP2 = true
- }
- }
- if !containsHTTP2 {
- t.Fatalf("Dialer.TLSClientConfig.NextProtos does not contain \"h2\"")
- }
-
- _, _, err = d.Dial(makeWsProto(ts.URL), nil)
- if err == nil {
- t.Fatalf("Dial succeeded, expect fail ")
- }
-}
type dataBeforeHandshakeResponseWriter struct {
http.ResponseWriter
@@ -1065,32 +313,6 @@ func (w dataBeforeHandshakeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter
return c, rw, err
}
-func TestDataReceivedBeforeHandshake(t *testing.T) {
- s := newServer(t)
- defer s.Close()
-
- origHandler := s.Server.Config.Handler
- s.Server.Config.Handler = http.HandlerFunc(
- func(w http.ResponseWriter, r *http.Request) {
- origHandler.ServeHTTP(dataBeforeHandshakeResponseWriter{w}, r)
- })
-
- for _, readBufferSize := range []int{0, 1024} {
- t.Run(fmt.Sprintf("ReadBufferSize=%d", readBufferSize), func(t *testing.T) {
- dialer := cstDialer
- dialer.ReadBufferSize = readBufferSize
- ws, _, err := cstDialer.Dial(s.URL, nil)
- if err != nil {
- t.Fatalf("Dial: %v", err)
- }
- defer ws.Close()
- _, m, err := ws.ReadMessage()
- if err != nil || string(m) != "Hello" {
- t.Fatalf("ReadMessage() = %q, %v, want \"Hello\", nil", m, err)
- }
- })
- }
-}
// Copyright 2014 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
@@ -2584,7 +1806,7 @@ func TestParseExtensions(t *testing.T) {
}
func TestParseArgs(t *testing.T) {
- given := ParseArgs([]string { "x", "y", "z" })
+ given := parseArgs([]string { "x", "y", "z" })
expected := CLIArgs {
FromAddr: "y",
ToAddr: "z",
@@ -2597,27 +1819,8 @@ func TestParseArgs(t *testing.T) {
func MainTest() {
tests := []testing.InternalTest {
- { "TestDial", TestDial },
- { "TestDialCookieJar", TestDialCookieJar },
- { "TestDialTLS", TestDialTLS },
- { "TestDialTimeout", TestDialTimeout },
- { "TestHandshakeTimeout", TestHandshakeTimeout },
- { "TestHandshakeTimeoutInContext", TestHandshakeTimeoutInContext },
- { "TestDialBadScheme", TestDialBadScheme },
- { "TestDialBadOrigin", TestDialBadOrigin },
- { "TestDialBadHeader", TestDialBadHeader },
{ "TestBadMethod", TestBadMethod },
{ "TestNoUpgrade", TestNoUpgrade },
- { "TestDialExtraTokensInRespHeaders", TestDialExtraTokensInRespHeaders },
- { "TestHandshake", TestHandshake },
- { "TestRespOnBadHandshake", TestRespOnBadHandshake },
- { "TestHost", TestHost },
- { "TestDialCompression", TestDialCompression },
- { "TestTracingDialWithContext", TestTracingDialWithContext },
- { "TestEmptyTracingDialWithContext", TestEmptyTracingDialWithContext },
- { "TestNetDialConnect", TestNetDialConnect },
- { "TestNextProtos", TestNextProtos },
- { "TestDataReceivedBeforeHandshake", TestDataReceivedBeforeHandshake },
{ "TestHostPortNoPort", TestHostPortNoPort },
{ "TestTruncWriter", TestTruncWriter },
{ "TestValidCompressionLevel", TestValidCompressionLevel },