diff options
-rw-r--r-- | .gitignore | 3 | ||||
-rw-r--r-- | Makefile | 72 | ||||
-rw-r--r-- | go.mod | 16 | ||||
-rw-r--r-- | src/lib.go | 118 | ||||
-rw-r--r-- | src/main.go (renamed from src/cmd/main.go) | 2 | ||||
-rw-r--r-- | src/proxy.go | 405 | ||||
-rw-r--r-- | src/socks.go | 473 | ||||
-rw-r--r-- | src/wscat.go | 3148 | ||||
-rw-r--r-- | tests/lib_test.go | 24 | ||||
-rw-r--r-- | tests/main.go | 7 | ||||
-rw-r--r-- | tests/wscat.go | 2842 |
11 files changed, 6933 insertions, 177 deletions
@@ -1,2 +1,5 @@ /*.bin +/src/*.a +/src/*.bin +/tests/*.a /tests/*.bin @@ -8,6 +8,7 @@ LANGUAGES = en PREFIX = /usr BINDIR = $(PREFIX)/bin LIBDIR = $(PREFIX)/lib +GOLIBDIR = $(LIBDIR)/go INCLUDEDIR = $(PREFIX)/include SRCDIR = $(PREFIX)/src/$(NAME) SHAREDIR = $(PREFIX)/share @@ -17,12 +18,11 @@ EXEC = ./ ## Where to store the installation. Empty by default. DESTDIR = LDLIBS = -GOFLAGS = .SUFFIXES: -.SUFFIXES: .go .bin +.SUFFIXES: .go .a .bin .bin-check @@ -30,14 +30,24 @@ all: include deps.mk +objects = \ + src/$(NAME).a \ + src/proxy.a \ + src/socks.a \ + src/main.a \ + tests/$(NAME).a \ + tests/main.a \ + sources = \ - src/lib.go \ - src/cmd/main.go \ + src/$(NAME).go \ + src/main.go \ derived-assets = \ + $(objects) \ + src/main.bin \ + tests/main.bin \ $(NAME).bin \ - tests/lib_test.bin \ side-assets = \ $(NAME).socket \ @@ -49,22 +59,48 @@ side-assets = \ all: $(derived-assets) -$(NAME).bin: src/lib.go src/cmd/main.go Makefile - go build $(GOFLAGS) -v -o $@ src/cmd/main.go +$(objects): Makefile + +src/$(NAME).a: src/$(NAME).go src/proxy.a +src/proxy.a: src/proxy.go src/socks.a +src/socks.a: src/socks.go +src/$(NAME).a src/proxy.a src/socks.a: + go tool compile $(GOCFLAGS) -o $@ -p $(*F) -I src $*.go + +tests/$(NAME).a: tests/$(NAME).go src/$(NAME).go src/proxy.a +tests/$(NAME).a: + go tool compile $(GOCFLAGS) -o $@ -p $(*F) -I src $*.go src/$(*F).go + +src/main.a: src/main.go src/$(NAME).a +tests/main.a: tests/main.go tests/$(NAME).a +src/main.a tests/main.a: + go tool compile $(GOCFLAGS) -o $@ -I $(@D) $*.go + +src/main.bin: src/main.a +tests/main.bin: tests/main.a +src/main.bin tests/main.bin: + go tool link $(GOLDFLAGS) -o $@ -L $(@D) -L src $*.a + +$(NAME).bin: src/main.bin + ln -fs $? $@ + -tests/lib_test.bin: src/lib.go tests/lib_test.go Makefile - go test $(GOFLAGS) -v -o $@ -c $*.go +tests.bin-check = \ + tests/main.bin-check \ +tests/main.bin-check: tests/main.bin +$(tests.bin-check): + $(EXEC)$*.bin -check-unit: tests/lib_test.bin - ./tests/lib_test.bin +check-unit: $(tests.bin-check) integration-tests = \ tests/cli-opts.sh \ tests/integration.sh \ +.PRECIOUS: $(integration-tests) $(integration-tests): $(NAME).bin $(integration-tests): ALWAYS sh $@ @@ -90,21 +126,21 @@ clean: install: all mkdir -p \ '$(DESTDIR)$(BINDIR)' \ + '$(DESTDIR)$(GOLIBDIR)' \ + '$(DESTDIR)$(SRCDIR)' \ cp $(NAME).bin '$(DESTDIR)$(BINDIR)'/$(NAME) - for f in $(sources); do \ - dir='$(DESTDIR)$(SRCDIR)'/"`dirname "$${f#src/}"`"; \ - mkdir -p "$$dir"; \ - cp -P "$$f" "$$dir"; \ - done + cp src/$(NAME).a '$(DESTDIR)$(GOLIBDIR)' + cp $(sources) '$(DESTDIR)$(SRCDIR)' ## Uninstalls from $(DESTDIR)$(PREFIX). This is a perfect mirror ## of the "install" target, and removes *all* that was installed. ## A dedicated test asserts that this is always true. uninstall: rm -rf \ - '$(DESTDIR)$(BINDIR)'/$(NAME) \ - '$(DESTDIR)$(SRCDIR)' \ + '$(DESTDIR)$(BINDIR)'/$(NAME) \ + '$(DESTDIR)$(GOLIBDIR)'/$(NAME).a \ + '$(DESTDIR)$(SRCDIR)' \ @@ -1,16 +0,0 @@ -module euandre.org/wscat - -go 1.21.5 - -require ( - euandre.org/gobang v0.1.0 - github.com/gorilla/websocket v1.5.3 -) - -require golang.org/x/net v0.26.0 // indirect - -replace ( - euandre.org/gobang => ../gobang - github.com/gorilla/websocket => ../websocket - golang.org/x/net => ../netx -) diff --git a/src/lib.go b/src/lib.go deleted file mode 100644 index 3733475..0000000 --- a/src/lib.go +++ /dev/null @@ -1,118 +0,0 @@ -package wscat - -import ( - "fmt" - "net" - "net/http" - "os" - - g "euandre.org/gobang/src" - - "github.com/gorilla/websocket" -) - - - -type CLIArgs struct { - FromAddr string - ToAddr string -} - - - -const X = 1 - -var EmitActiveConnection = g.MakeGauge("active-connections") - - - -func ParseArgs(args []string) CLIArgs { - if len(args) != 3 { - fmt.Fprintf( - os.Stderr, - "Usage: %s FROM.socket TO.socket\n", - args[0], - ) - os.Exit(2) - } - return CLIArgs { - FromAddr: args[1], - ToAddr: args[2], - } -} - -func Listen(fromAddr string) net.Listener { - listener, err := net.Listen("unix", fromAddr) - g.FatalIf(err) - g.Info("Started listening", "listen-start", "from-address", fromAddr) - return listener -} - -func Start(toAddr string, listener net.Listener) { - upgrader := websocket.Upgrader {} - http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - connFrom, err := upgrader.Upgrade(w, r, nil) - if err != nil { - g.Warning( - "Error upgrading connection", - "upgrade-connection-error", - "err", err, - ) - return - } - defer connFrom.Close() - EmitActiveConnection.Inc() - - connTo, err := net.Dial("unix", toAddr) - if err != nil { - g.Error( - "Error dialing connection", - "dial-connection", - "err", err, - ) - os.Exit(1) - } - defer connTo.Close() - - messageType, reader, err := connFrom.NextReader() - if err != nil { - g.Warning( - "Failed to get next reader from connection", - "connection-next-reader-error", - "err", err, - ) - return - } - - writer, err := connFrom.NextWriter(messageType) - if err != nil { - g.Warning( - "Failed to get next reader from connection", - "connection-next-reader-error", - "err", err, - ) - return - } - - - c := make(chan g.CopyResult) - go g.CopyData(c, "c2s", connTo, writer) - go g.CopyData(c, "s2c", reader, connTo) - go func() { - <- c - EmitActiveConnection.Dec() - }() - }); - - server := http.Server{} - err := server.Serve(listener) - g.FatalIf(err) -} - - -func Main() { - g.Init() - args := ParseArgs(os.Args) - listener := Listen(args.FromAddr) - Start(args.ToAddr, listener) -} diff --git a/src/cmd/main.go b/src/main.go index e6ae309..ddc8f19 100644 --- a/src/cmd/main.go +++ b/src/main.go @@ -1,6 +1,6 @@ package main -import "euandre.org/wscat/src" +import "wscat" func main() { wscat.Main() diff --git a/src/proxy.go b/src/proxy.go new file mode 100644 index 0000000..7ef4be3 --- /dev/null +++ b/src/proxy.go @@ -0,0 +1,405 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package proxy + +import ( + "context" + "errors" + "net" + "net/url" + "os" + "strings" + "sync" + + "socks" +) + +// A ContextDialer dials using a context. +type ContextDialer interface { + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} + +// Dial works like DialContext on net.Dialer but using a dialer returned by FromEnvironment. +// +// The passed ctx is only used for returning the Conn, not the lifetime of the Conn. +// +// Custom dialers (registered via RegisterDialerType) that do not implement ContextDialer +// can leak a goroutine for as long as it takes the underlying Dialer implementation to timeout. +// +// A Conn returned from a successful Dial after the context has been cancelled will be immediately closed. +func Dial(ctx context.Context, network, address string) (net.Conn, error) { + d := FromEnvironment() + if xd, ok := d.(ContextDialer); ok { + return xd.DialContext(ctx, network, address) + } + return dialContext(ctx, d, network, address) +} + +// WARNING: this can leak a goroutine for as long as the underlying Dialer implementation takes to timeout +// A Conn returned from a successful Dial after the context has been cancelled will be immediately closed. +func dialContext(ctx context.Context, d Dialer, network, address string) (net.Conn, error) { + var ( + conn net.Conn + done = make(chan struct{}, 1) + err error + ) + go func() { + conn, err = d.Dial(network, address) + close(done) + if conn != nil && ctx.Err() != nil { + conn.Close() + } + }() + select { + case <-ctx.Done(): + err = ctx.Err() + case <-done: + } + return conn, err +} +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + + +type direct struct{} + +// Direct implements Dialer by making network connections directly using net.Dial or net.DialContext. +var Direct = direct{} + +var ( + _ Dialer = Direct + _ ContextDialer = Direct +) + +// Dial directly invokes net.Dial with the supplied parameters. +func (direct) Dial(network, addr string) (net.Conn, error) { + return net.Dial(network, addr) +} + +// DialContext instantiates a net.Dialer and invokes its DialContext receiver with the supplied parameters. +func (direct) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, network, addr) +} +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + + +// A PerHost directs connections to a default Dialer unless the host name +// requested matches one of a number of exceptions. +type PerHost struct { + def, bypass Dialer + + bypassNetworks []*net.IPNet + bypassIPs []net.IP + bypassZones []string + bypassHosts []string +} + +// NewPerHost returns a PerHost Dialer that directs connections to either +// defaultDialer or bypass, depending on whether the connection matches one of +// the configured rules. +func NewPerHost(defaultDialer, bypass Dialer) *PerHost { + return &PerHost{ + def: defaultDialer, + bypass: bypass, + } +} + +// Dial connects to the address addr on the given network through either +// defaultDialer or bypass. +func (p *PerHost) Dial(network, addr string) (c net.Conn, err error) { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + + return p.dialerForRequest(host).Dial(network, addr) +} + +// DialContext connects to the address addr on the given network through either +// defaultDialer or bypass. +func (p *PerHost) DialContext(ctx context.Context, network, addr string) (c net.Conn, err error) { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + d := p.dialerForRequest(host) + if x, ok := d.(ContextDialer); ok { + return x.DialContext(ctx, network, addr) + } + return dialContext(ctx, d, network, addr) +} + +func (p *PerHost) dialerForRequest(host string) Dialer { + if ip := net.ParseIP(host); ip != nil { + for _, net := range p.bypassNetworks { + if net.Contains(ip) { + return p.bypass + } + } + for _, bypassIP := range p.bypassIPs { + if bypassIP.Equal(ip) { + return p.bypass + } + } + return p.def + } + + for _, zone := range p.bypassZones { + if strings.HasSuffix(host, zone) { + return p.bypass + } + if host == zone[1:] { + // For a zone ".example.com", we match "example.com" + // too. + return p.bypass + } + } + for _, bypassHost := range p.bypassHosts { + if bypassHost == host { + return p.bypass + } + } + return p.def +} + +// AddFromString parses a string that contains comma-separated values +// specifying hosts that should use the bypass proxy. Each value is either an +// IP address, a CIDR range, a zone (*.example.com) or a host name +// (localhost). A best effort is made to parse the string and errors are +// ignored. +func (p *PerHost) AddFromString(s string) { + hosts := strings.Split(s, ",") + for _, host := range hosts { + host = strings.TrimSpace(host) + if len(host) == 0 { + continue + } + if strings.Contains(host, "/") { + // We assume that it's a CIDR address like 127.0.0.0/8 + if _, net, err := net.ParseCIDR(host); err == nil { + p.AddNetwork(net) + } + continue + } + if ip := net.ParseIP(host); ip != nil { + p.AddIP(ip) + continue + } + if strings.HasPrefix(host, "*.") { + p.AddZone(host[1:]) + continue + } + p.AddHost(host) + } +} + +// AddIP specifies an IP address that will use the bypass proxy. Note that +// this will only take effect if a literal IP address is dialed. A connection +// to a named host will never match an IP. +func (p *PerHost) AddIP(ip net.IP) { + p.bypassIPs = append(p.bypassIPs, ip) +} + +// AddNetwork specifies an IP range that will use the bypass proxy. Note that +// this will only take effect if a literal IP address is dialed. A connection +// to a named host will never match. +func (p *PerHost) AddNetwork(net *net.IPNet) { + p.bypassNetworks = append(p.bypassNetworks, net) +} + +// AddZone specifies a DNS suffix that will use the bypass proxy. A zone of +// "example.com" matches "example.com" and all of its subdomains. +func (p *PerHost) AddZone(zone string) { + zone = strings.TrimSuffix(zone, ".") + if !strings.HasPrefix(zone, ".") { + zone = "." + zone + } + p.bypassZones = append(p.bypassZones, zone) +} + +// AddHost specifies a host name that will use the bypass proxy. +func (p *PerHost) AddHost(host string) { + host = strings.TrimSuffix(host, ".") + p.bypassHosts = append(p.bypassHosts, host) +} +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package proxy provides support for a variety of protocols to proxy network +// data. +// package proxy // import "golang.org/x/net/proxy" + +// A Dialer is a means to establish a connection. +// Custom dialers should also implement ContextDialer. +type Dialer interface { + // Dial connects to the given address via the proxy. + Dial(network, addr string) (c net.Conn, err error) +} + +// Auth contains authentication parameters that specific Dialers may require. +type Auth struct { + User, Password string +} + +// FromEnvironment returns the dialer specified by the proxy-related +// variables in the environment and makes underlying connections +// directly. +func FromEnvironment() Dialer { + return FromEnvironmentUsing(Direct) +} + +// FromEnvironmentUsing returns the dialer specify by the proxy-related +// variables in the environment and makes underlying connections +// using the provided forwarding Dialer (for instance, a *net.Dialer +// with desired configuration). +func FromEnvironmentUsing(forward Dialer) Dialer { + allProxy := allProxyEnv.Get() + if len(allProxy) == 0 { + return forward + } + + proxyURL, err := url.Parse(allProxy) + if err != nil { + return forward + } + proxy, err := FromURL(proxyURL, forward) + if err != nil { + return forward + } + + noProxy := noProxyEnv.Get() + if len(noProxy) == 0 { + return proxy + } + + perHost := NewPerHost(proxy, forward) + perHost.AddFromString(noProxy) + return perHost +} + +// proxySchemes is a map from URL schemes to a function that creates a Dialer +// from a URL with such a scheme. +var proxySchemes map[string]func(*url.URL, Dialer) (Dialer, error) + +// RegisterDialerType takes a URL scheme and a function to generate Dialers from +// a URL with that scheme and a forwarding Dialer. Registered schemes are used +// by FromURL. +func RegisterDialerType(scheme string, f func(*url.URL, Dialer) (Dialer, error)) { + if proxySchemes == nil { + proxySchemes = make(map[string]func(*url.URL, Dialer) (Dialer, error)) + } + proxySchemes[scheme] = f +} + +// FromURL returns a Dialer given a URL specification and an underlying +// Dialer for it to make network requests. +func FromURL(u *url.URL, forward Dialer) (Dialer, error) { + var auth *Auth + if u.User != nil { + auth = new(Auth) + auth.User = u.User.Username() + if p, ok := u.User.Password(); ok { + auth.Password = p + } + } + + switch u.Scheme { + case "socks5", "socks5h": + addr := u.Hostname() + port := u.Port() + if port == "" { + port = "1080" + } + return SOCKS5("tcp", net.JoinHostPort(addr, port), auth, forward) + } + + // If the scheme doesn't match any of the built-in schemes, see if it + // was registered by another package. + if proxySchemes != nil { + if f, ok := proxySchemes[u.Scheme]; ok { + return f(u, forward) + } + } + + return nil, errors.New("proxy: unknown scheme: " + u.Scheme) +} + +var ( + allProxyEnv = &envOnce{ + names: []string{"ALL_PROXY", "all_proxy"}, + } + noProxyEnv = &envOnce{ + names: []string{"NO_PROXY", "no_proxy"}, + } +) + +// envOnce looks up an environment variable (optionally by multiple +// names) once. It mitigates expensive lookups on some platforms +// (e.g. Windows). +// (Borrowed from net/http/transport.go) +type envOnce struct { + names []string + once sync.Once + val string +} + +func (e *envOnce) Get() string { + e.once.Do(e.init) + return e.val +} + +func (e *envOnce) init() { + for _, n := range e.names { + e.val = os.Getenv(n) + if e.val != "" { + return + } + } +} + +// reset is used by tests +func (e *envOnce) reset() { + e.once = sync.Once{} + e.val = "" +} +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + + +// SOCKS5 returns a Dialer that makes SOCKSv5 connections to the given +// address with an optional username and password. +// See RFC 1928 and RFC 1929. +func SOCKS5(network, address string, auth *Auth, forward Dialer) (Dialer, error) { + d := socks.NewDialer(network, address) + if forward != nil { + if f, ok := forward.(ContextDialer); ok { + d.ProxyDial = func(ctx context.Context, network string, address string) (net.Conn, error) { + return f.DialContext(ctx, network, address) + } + } else { + d.ProxyDial = func(ctx context.Context, network string, address string) (net.Conn, error) { + return dialContext(ctx, forward, network, address) + } + } + } + if auth != nil { + up := socks.UsernamePassword{ + Username: auth.User, + Password: auth.Password, + } + d.AuthMethods = []socks.AuthMethod{ + socks.AuthMethodNotRequired, + socks.AuthMethodUsernamePassword, + } + d.Authenticate = up.Authenticate + } + return d, nil +} diff --git a/src/socks.go b/src/socks.go new file mode 100644 index 0000000..d506505 --- /dev/null +++ b/src/socks.go @@ -0,0 +1,473 @@ +// Package socks provides a SOCKS version 5 client implementation. +// +// SOCKS protocol version 5 is defined in RFC 1928. +// Username/Password authentication for SOCKS version 5 is defined in +// RFC 1929. +package socks + +import ( + "context" + "errors" + "io" + "net" + "strconv" + "time" +) + + + +var ( + noDeadline = time.Time{} + aLongTimeAgo = time.Unix(1, 0) +) + +func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net.Addr, ctxErr error) { + host, port, err := splitHostPort(address) + if err != nil { + return nil, err + } + if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() { + c.SetDeadline(deadline) + defer c.SetDeadline(noDeadline) + } + if ctx != context.Background() { + errCh := make(chan error, 1) + done := make(chan struct{}) + defer func() { + close(done) + if ctxErr == nil { + ctxErr = <-errCh + } + }() + go func() { + select { + case <-ctx.Done(): + c.SetDeadline(aLongTimeAgo) + errCh <- ctx.Err() + case <-done: + errCh <- nil + } + }() + } + + b := make([]byte, 0, 6+len(host)) // the size here is just an estimate + b = append(b, Version5) + if len(d.AuthMethods) == 0 || d.Authenticate == nil { + b = append(b, 1, byte(AuthMethodNotRequired)) + } else { + ams := d.AuthMethods + if len(ams) > 255 { + return nil, errors.New("too many authentication methods") + } + b = append(b, byte(len(ams))) + for _, am := range ams { + b = append(b, byte(am)) + } + } + if _, ctxErr = c.Write(b); ctxErr != nil { + return + } + + if _, ctxErr = io.ReadFull(c, b[:2]); ctxErr != nil { + return + } + if b[0] != Version5 { + return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0]))) + } + am := AuthMethod(b[1]) + if am == AuthMethodNoAcceptableMethods { + return nil, errors.New("no acceptable authentication methods") + } + if d.Authenticate != nil { + if ctxErr = d.Authenticate(ctx, c, am); ctxErr != nil { + return + } + } + + b = b[:0] + b = append(b, Version5, byte(d.cmd), 0) + if ip := net.ParseIP(host); ip != nil { + if ip4 := ip.To4(); ip4 != nil { + b = append(b, AddrTypeIPv4) + b = append(b, ip4...) + } else if ip6 := ip.To16(); ip6 != nil { + b = append(b, AddrTypeIPv6) + b = append(b, ip6...) + } else { + return nil, errors.New("unknown address type") + } + } else { + if len(host) > 255 { + return nil, errors.New("FQDN too long") + } + b = append(b, AddrTypeFQDN) + b = append(b, byte(len(host))) + b = append(b, host...) + } + b = append(b, byte(port>>8), byte(port)) + if _, ctxErr = c.Write(b); ctxErr != nil { + return + } + + if _, ctxErr = io.ReadFull(c, b[:4]); ctxErr != nil { + return + } + if b[0] != Version5 { + return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0]))) + } + if cmdErr := Reply(b[1]); cmdErr != StatusSucceeded { + return nil, errors.New("unknown error " + cmdErr.String()) + } + if b[2] != 0 { + return nil, errors.New("non-zero reserved field") + } + l := 2 + var a Addr + switch b[3] { + case AddrTypeIPv4: + l += net.IPv4len + a.IP = make(net.IP, net.IPv4len) + case AddrTypeIPv6: + l += net.IPv6len + a.IP = make(net.IP, net.IPv6len) + case AddrTypeFQDN: + if _, err := io.ReadFull(c, b[:1]); err != nil { + return nil, err + } + l += int(b[0]) + default: + return nil, errors.New("unknown address type " + strconv.Itoa(int(b[3]))) + } + if cap(b) < l { + b = make([]byte, l) + } else { + b = b[:l] + } + if _, ctxErr = io.ReadFull(c, b); ctxErr != nil { + return + } + if a.IP != nil { + copy(a.IP, b) + } else { + a.Name = string(b[:len(b)-2]) + } + a.Port = int(b[len(b)-2])<<8 | int(b[len(b)-1]) + return &a, nil +} + +func splitHostPort(address string) (string, int, error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + return "", 0, err + } + portnum, err := strconv.Atoi(port) + if err != nil { + return "", 0, err + } + if 1 > portnum || portnum > 0xffff { + return "", 0, errors.New("port number out of range " + port) + } + return host, portnum, nil +} +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// A Command represents a SOCKS command. +type Command int + +func (cmd Command) String() string { + switch cmd { + case CmdConnect: + return "socks connect" + case cmdBind: + return "socks bind" + default: + return "socks " + strconv.Itoa(int(cmd)) + } +} + +// An AuthMethod represents a SOCKS authentication method. +type AuthMethod int + +// A Reply represents a SOCKS command reply code. +type Reply int + +func (code Reply) String() string { + switch code { + case StatusSucceeded: + return "succeeded" + case 0x01: + return "general SOCKS server failure" + case 0x02: + return "connection not allowed by ruleset" + case 0x03: + return "network unreachable" + case 0x04: + return "host unreachable" + case 0x05: + return "connection refused" + case 0x06: + return "TTL expired" + case 0x07: + return "command not supported" + case 0x08: + return "address type not supported" + default: + return "unknown code: " + strconv.Itoa(int(code)) + } +} + +// Wire protocol constants. +const ( + Version5 = 0x05 + + AddrTypeIPv4 = 0x01 + AddrTypeFQDN = 0x03 + AddrTypeIPv6 = 0x04 + + CmdConnect Command = 0x01 // establishes an active-open forward proxy connection + cmdBind Command = 0x02 // establishes a passive-open forward proxy connection + + AuthMethodNotRequired AuthMethod = 0x00 // no authentication required + AuthMethodUsernamePassword AuthMethod = 0x02 // use username/password + AuthMethodNoAcceptableMethods AuthMethod = 0xff // no acceptable authentication methods + + StatusSucceeded Reply = 0x00 +) + +// An Addr represents a SOCKS-specific address. +// Either Name or IP is used exclusively. +type Addr struct { + Name string // fully-qualified domain name + IP net.IP + Port int +} + +func (a *Addr) Network() string { return "socks" } + +func (a *Addr) String() string { + if a == nil { + return "<nil>" + } + port := strconv.Itoa(a.Port) + if a.IP == nil { + return net.JoinHostPort(a.Name, port) + } + return net.JoinHostPort(a.IP.String(), port) +} + +// A Conn represents a forward proxy connection. +type Conn struct { + net.Conn + + boundAddr net.Addr +} + +// BoundAddr returns the address assigned by the proxy server for +// connecting to the command target address from the proxy server. +func (c *Conn) BoundAddr() net.Addr { + if c == nil { + return nil + } + return c.boundAddr +} + +// A Dialer holds SOCKS-specific options. +type Dialer struct { + cmd Command // either CmdConnect or cmdBind + proxyNetwork string // network between a proxy server and a client + proxyAddress string // proxy server address + + // ProxyDial specifies the optional dial function for + // establishing the transport connection. + ProxyDial func(context.Context, string, string) (net.Conn, error) + + // AuthMethods specifies the list of request authentication + // methods. + // If empty, SOCKS client requests only AuthMethodNotRequired. + AuthMethods []AuthMethod + + // Authenticate specifies the optional authentication + // function. It must be non-nil when AuthMethods is not empty. + // It must return an error when the authentication is failed. + Authenticate func(context.Context, io.ReadWriter, AuthMethod) error +} + +// DialContext connects to the provided address on the provided +// network. +// +// The returned error value may be a net.OpError. When the Op field of +// net.OpError contains "socks", the Source field contains a proxy +// server address and the Addr field contains a command target +// address. +// +// See func Dial of the net package of standard library for a +// description of the network and address parameters. +func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + if err := d.validateTarget(network, address); err != nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + } + if ctx == nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")} + } + var err error + var c net.Conn + if d.ProxyDial != nil { + c, err = d.ProxyDial(ctx, d.proxyNetwork, d.proxyAddress) + } else { + var dd net.Dialer + c, err = dd.DialContext(ctx, d.proxyNetwork, d.proxyAddress) + } + if err != nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + } + a, err := d.connect(ctx, c, address) + if err != nil { + c.Close() + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + } + return &Conn{Conn: c, boundAddr: a}, nil +} + +// DialWithConn initiates a connection from SOCKS server to the target +// network and address using the connection c that is already +// connected to the SOCKS server. +// +// It returns the connection's local address assigned by the SOCKS +// server. +func (d *Dialer) DialWithConn(ctx context.Context, c net.Conn, network, address string) (net.Addr, error) { + if err := d.validateTarget(network, address); err != nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + } + if ctx == nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")} + } + a, err := d.connect(ctx, c, address) + if err != nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + } + return a, nil +} + +// Dial connects to the provided address on the provided network. +// +// Unlike DialContext, it returns a raw transport connection instead +// of a forward proxy connection. +// +// Deprecated: Use DialContext or DialWithConn instead. +func (d *Dialer) Dial(network, address string) (net.Conn, error) { + if err := d.validateTarget(network, address); err != nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + } + var err error + var c net.Conn + if d.ProxyDial != nil { + c, err = d.ProxyDial(context.Background(), d.proxyNetwork, d.proxyAddress) + } else { + c, err = net.Dial(d.proxyNetwork, d.proxyAddress) + } + if err != nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + } + if _, err := d.DialWithConn(context.Background(), c, network, address); err != nil { + c.Close() + return nil, err + } + return c, nil +} + +func (d *Dialer) validateTarget(network, address string) error { + switch network { + case "tcp", "tcp6", "tcp4": + default: + return errors.New("network not implemented") + } + switch d.cmd { + case CmdConnect, cmdBind: + default: + return errors.New("command not implemented") + } + return nil +} + +func (d *Dialer) pathAddrs(address string) (proxy, dst net.Addr, err error) { + for i, s := range []string{d.proxyAddress, address} { + host, port, err := splitHostPort(s) + if err != nil { + return nil, nil, err + } + a := &Addr{Port: port} + a.IP = net.ParseIP(host) + if a.IP == nil { + a.Name = host + } + if i == 0 { + proxy = a + } else { + dst = a + } + } + return +} + +// NewDialer returns a new Dialer that dials through the provided +// proxy server's network and address. +func NewDialer(network, address string) *Dialer { + return &Dialer{proxyNetwork: network, proxyAddress: address, cmd: CmdConnect} +} + +const ( + authUsernamePasswordVersion = 0x01 + authStatusSucceeded = 0x00 +) + +// UsernamePassword are the credentials for the username/password +// authentication method. +type UsernamePassword struct { + Username string + Password string +} + +// Authenticate authenticates a pair of username and password with the +// proxy server. +func (up *UsernamePassword) Authenticate(ctx context.Context, rw io.ReadWriter, auth AuthMethod) error { + switch auth { + case AuthMethodNotRequired: + return nil + case AuthMethodUsernamePassword: + if len(up.Username) == 0 || len(up.Username) > 255 || len(up.Password) > 255 { + return errors.New("invalid username/password") + } + b := []byte{authUsernamePasswordVersion} + b = append(b, byte(len(up.Username))) + b = append(b, up.Username...) + b = append(b, byte(len(up.Password))) + b = append(b, up.Password...) + // TODO(mikio): handle IO deadlines and cancelation if + // necessary + if _, err := rw.Write(b); err != nil { + return err + } + if _, err := io.ReadFull(rw, b[:2]); err != nil { + return err + } + if b[0] != authUsernamePasswordVersion { + return errors.New("invalid username/password version") + } + if b[1] != authStatusSucceeded { + return errors.New("username/password authentication failed") + } + return nil + } + return errors.New("unsupported authentication method " + strconv.Itoa(int(auth))) +} diff --git a/src/wscat.go b/src/wscat.go new file mode 100644 index 0000000..b9a17a6 --- /dev/null +++ b/src/wscat.go @@ -0,0 +1,3148 @@ +package wscat + +import ( + "bufio" + "bytes" + "compress/flate" + "context" + "crypto/rand" + "crypto/sha1" + "crypto/tls" + "encoding/base64" + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/http/httptrace" + "net/url" + "os" + "strconv" + "strings" + "sync" + "time" + "unicode/utf8" + "unsafe" + + "proxy" + + g "gobang" +) + +// ErrBadHandshake is returned when the server response to opening handshake is +// invalid. +var ErrBadHandshake = errors.New("websocket: bad handshake") + +var errInvalidCompression = errors.New("websocket: invalid compression negotiation") + +// NewClient creates a new client connection using the given net connection. +// The URL u specifies the host and request URI. Use requestHeader to specify +// the origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies +// (Cookie). Use the response.Header to get the selected subprotocol +// (Sec-WebSocket-Protocol) and cookies (Set-Cookie). +// +// If the WebSocket handshake fails, ErrBadHandshake is returned along with a +// non-nil *http.Response so that callers can handle redirects, authentication, +// etc. +// +// Deprecated: Use Dialer instead. +func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufSize, writeBufSize int) (c *Conn, response *http.Response, err error) { + d := Dialer{ + ReadBufferSize: readBufSize, + WriteBufferSize: writeBufSize, + NetDial: func(net, addr string) (net.Conn, error) { + return netConn, nil + }, + } + return d.Dial(u.String(), requestHeader) +} + +// A Dialer contains options for connecting to WebSocket server. +// +// It is safe to call Dialer's methods concurrently. +type Dialer struct { + // NetDial specifies the dial function for creating TCP connections. If + // NetDial is nil, net.Dialer DialContext is used. + NetDial func(network, addr string) (net.Conn, error) + + // NetDialContext specifies the dial function for creating TCP connections. If + // NetDialContext is nil, NetDial is used. + NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error) + + // NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If + // NetDialTLSContext is nil, NetDialContext is used. + // If NetDialTLSContext is set, Dial assumes the TLS handshake is done there and + // TLSClientConfig is ignored. + NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) + + // Proxy specifies a function to return a proxy for a given + // Request. If the function returns a non-nil error, the + // request is aborted with the provided error. + // If Proxy is nil or returns a nil *URL, no proxy is used. + Proxy func(*http.Request) (*url.URL, error) + + // TLSClientConfig specifies the TLS configuration to use with tls.Client. + // If nil, the default configuration is used. + // If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake + // is done there and TLSClientConfig is ignored. + TLSClientConfig *tls.Config + + // HandshakeTimeout specifies the duration for the handshake to complete. + HandshakeTimeout time.Duration + + // ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer + // size is zero, then a useful default size is used. The I/O buffer sizes + // do not limit the size of the messages that can be sent or received. + ReadBufferSize, WriteBufferSize int + + // WriteBufferPool is a pool of buffers for write operations. If the value + // is not set, then write buffers are allocated to the connection for the + // lifetime of the connection. + // + // A pool is most useful when the application has a modest volume of writes + // across a large number of connections. + // + // Applications should use a single pool for each unique value of + // WriteBufferSize. + WriteBufferPool BufferPool + + // Subprotocols specifies the client's requested subprotocols. + Subprotocols []string + + // EnableCompression specifies if the client should attempt to negotiate + // per message compression (RFC 7692). Setting this value to true does not + // guarantee that compression will be supported. Currently only "no context + // takeover" modes are supported. + EnableCompression bool + + // Jar specifies the cookie jar. + // If Jar is nil, cookies are not sent in requests and ignored + // in responses. + Jar http.CookieJar +} + +// Dial creates a new client connection by calling DialContext with a background context. +func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { + return d.DialContext(context.Background(), urlStr, requestHeader) +} + +var errMalformedURL = errors.New("malformed ws or wss URL") + +func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) { + hostPort = u.Host + hostNoPort = u.Host + if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") { + hostNoPort = hostNoPort[:i] + } else { + switch u.Scheme { + case "wss": + hostPort += ":443" + case "https": + hostPort += ":443" + default: + hostPort += ":80" + } + } + return hostPort, hostNoPort +} + +// DefaultDialer is a dialer with all fields set to the default values. +var DefaultDialer = &Dialer{ + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: 45 * time.Second, +} + +// nilDialer is dialer to use when receiver is nil. +var nilDialer = *DefaultDialer + +// DialContext creates a new client connection. Use requestHeader to specify the +// origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie). +// Use the response.Header to get the selected subprotocol +// (Sec-WebSocket-Protocol) and cookies (Set-Cookie). +// +// The context will be used in the request and in the Dialer. +// +// If the WebSocket handshake fails, ErrBadHandshake is returned along with a +// non-nil *http.Response so that callers can handle redirects, authentication, +// etcetera. The response body may not contain the entire response and does not +// need to be closed by the application. +func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { + if d == nil { + d = &nilDialer + } + + challengeKey, err := generateChallengeKey() + if err != nil { + return nil, nil, err + } + + u, err := url.Parse(urlStr) + if err != nil { + return nil, nil, err + } + + switch u.Scheme { + case "ws": + u.Scheme = "http" + case "wss": + u.Scheme = "https" + default: + return nil, nil, errMalformedURL + } + + if u.User != nil { + // User name and password are not allowed in websocket URIs. + return nil, nil, errMalformedURL + } + + req := &http.Request{ + Method: http.MethodGet, + URL: u, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + Host: u.Host, + } + req = req.WithContext(ctx) + + // Set the cookies present in the cookie jar of the dialer + if d.Jar != nil { + for _, cookie := range d.Jar.Cookies(u) { + req.AddCookie(cookie) + } + } + + // Set the request headers using the capitalization for names and values in + // RFC examples. Although the capitalization shouldn't matter, there are + // servers that depend on it. The Header.Set method is not used because the + // method canonicalizes the header names. + req.Header["Upgrade"] = []string{"websocket"} + req.Header["Connection"] = []string{"Upgrade"} + req.Header["Sec-WebSocket-Key"] = []string{challengeKey} + req.Header["Sec-WebSocket-Version"] = []string{"13"} + if len(d.Subprotocols) > 0 { + req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")} + } + for k, vs := range requestHeader { + switch { + case k == "Host": + if len(vs) > 0 { + req.Host = vs[0] + } + case k == "Upgrade" || + k == "Connection" || + k == "Sec-Websocket-Key" || + k == "Sec-Websocket-Version" || + k == "Sec-Websocket-Extensions" || + (k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0): + return nil, nil, errors.New("websocket: duplicate header not allowed: " + k) + case k == "Sec-Websocket-Protocol": + req.Header["Sec-WebSocket-Protocol"] = vs + default: + req.Header[k] = vs + } + } + + if d.EnableCompression { + req.Header["Sec-WebSocket-Extensions"] = []string{"permessage-deflate; server_no_context_takeover; client_no_context_takeover"} + } + + if d.HandshakeTimeout != 0 { + var cancel func() + ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout) + defer cancel() + } + + var netDial netDialerFunc + switch { + case u.Scheme == "https" && d.NetDialTLSContext != nil: + netDial = d.NetDialTLSContext + case d.NetDialContext != nil: + netDial = d.NetDialContext + case d.NetDial != nil: + netDial = func(ctx context.Context, net, addr string) (net.Conn, error) { + return d.NetDial(net, addr) + } + default: + netDial = (&net.Dialer{}).DialContext + } + + // If needed, wrap the dial function to set the connection deadline. + if deadline, ok := ctx.Deadline(); ok { + forwardDial := netDial + netDial = func(ctx context.Context, network, addr string) (net.Conn, error) { + c, err := forwardDial(ctx, network, addr) + if err != nil { + return nil, err + } + err = c.SetDeadline(deadline) + if err != nil { + c.Close() + return nil, err + } + return c, nil + } + } + + // If needed, wrap the dial function to connect through a proxy. + if d.Proxy != nil { + proxyURL, err := d.Proxy(req) + if err != nil { + return nil, nil, err + } + if proxyURL != nil { + netDial, err = proxyFromURL(proxyURL, netDial) + if err != nil { + return nil, nil, err + } + } + } + + hostPort, hostNoPort := hostPortNoPort(u) + trace := httptrace.ContextClientTrace(ctx) + if trace != nil && trace.GetConn != nil { + trace.GetConn(hostPort) + } + + netConn, err := netDial(ctx, "tcp", hostPort) + if err != nil { + return nil, nil, err + } + if trace != nil && trace.GotConn != nil { + trace.GotConn(httptrace.GotConnInfo{ + Conn: netConn, + }) + } + + // Close the network connection when returning an error. The variable + // netConn is set to nil before the success return at the end of the + // function. + defer func() { + if netConn != nil { + // It's safe to ignore the error from Close() because this code is + // only executed when returning a more important error to the + // application. + _ = netConn.Close() + } + }() + + if u.Scheme == "https" && d.NetDialTLSContext == nil { + // If NetDialTLSContext is set, assume that the TLS handshake has already been done + + cfg := cloneTLSConfig(d.TLSClientConfig) + if cfg.ServerName == "" { + cfg.ServerName = hostNoPort + } + tlsConn := tls.Client(netConn, cfg) + netConn = tlsConn + + if trace != nil && trace.TLSHandshakeStart != nil { + trace.TLSHandshakeStart() + } + err := doHandshake(ctx, tlsConn, cfg) + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(tlsConn.ConnectionState(), err) + } + + if err != nil { + return nil, nil, err + } + } + + conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, nil, nil) + + if err := req.Write(netConn); err != nil { + return nil, nil, err + } + + if trace != nil && trace.GotFirstResponseByte != nil { + if peek, err := conn.br.Peek(1); err == nil && len(peek) == 1 { + trace.GotFirstResponseByte() + } + } + + resp, err := http.ReadResponse(conn.br, req) + if err != nil { + if d.TLSClientConfig != nil { + for _, proto := range d.TLSClientConfig.NextProtos { + if proto != "http/1.1" { + return nil, nil, fmt.Errorf( + "websocket: protocol %q was given but is not supported;"+ + "sharing tls.Config with net/http Transport can cause this error: %w", + proto, err, + ) + } + } + } + return nil, nil, err + } + + if d.Jar != nil { + if rc := resp.Cookies(); len(rc) > 0 { + d.Jar.SetCookies(u, rc) + } + } + + if resp.StatusCode != 101 || + !tokenListContainsValue(resp.Header, "Upgrade", "websocket") || + !tokenListContainsValue(resp.Header, "Connection", "upgrade") || + resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) { + // Before closing the network connection on return from this + // function, slurp up some of the response to aid application + // debugging. + buf := make([]byte, 1024) + n, _ := io.ReadFull(resp.Body, buf) + resp.Body = io.NopCloser(bytes.NewReader(buf[:n])) + return nil, resp, ErrBadHandshake + } + + for _, ext := range parseExtensions(resp.Header) { + if ext[""] != "permessage-deflate" { + continue + } + _, snct := ext["server_no_context_takeover"] + _, cnct := ext["client_no_context_takeover"] + if !snct || !cnct { + return nil, resp, errInvalidCompression + } + conn.newCompressionWriter = compressNoContextTakeover + conn.newDecompressionReader = decompressNoContextTakeover + break + } + + resp.Body = io.NopCloser(bytes.NewReader([]byte{})) + conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") + + if err := netConn.SetDeadline(time.Time{}); err != nil { + return nil, resp, err + } + + // Success! Set netConn to nil to stop the deferred function above from + // closing the network connection. + netConn = nil + + return conn, resp, nil +} + +func cloneTLSConfig(cfg *tls.Config) *tls.Config { + if cfg == nil { + return &tls.Config{} + } + return cfg.Clone() +} + +func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error { + if err := tlsConn.HandshakeContext(ctx); err != nil { + return err + } + if !cfg.InsecureSkipVerify { + if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { + return err + } + } + return nil +} +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + + +const ( + minCompressionLevel = -2 // flate.HuffmanOnly not defined in Go < 1.6 + maxCompressionLevel = flate.BestCompression + defaultCompressionLevel = 1 +) + +var ( + flateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool + flateReaderPool = sync.Pool{New: func() interface{} { + return flate.NewReader(nil) + }} +) + +func decompressNoContextTakeover(r io.Reader) io.ReadCloser { + const tail = + // Add four bytes as specified in RFC + "\x00\x00\xff\xff" + + // Add final block to squelch unexpected EOF error from flate reader. + "\x01\x00\x00\xff\xff" + + fr, _ := flateReaderPool.Get().(io.ReadCloser) + mr := io.MultiReader(r, strings.NewReader(tail)) + if err := fr.(flate.Resetter).Reset(mr, nil); err != nil { + // Reset never fails, but handle error in case that changes. + fr = flate.NewReader(mr) + } + return &flateReadWrapper{fr} +} + +func isValidCompressionLevel(level int) bool { + return minCompressionLevel <= level && level <= maxCompressionLevel +} + +func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser { + p := &flateWriterPools[level-minCompressionLevel] + tw := &truncWriter{w: w} + fw, _ := p.Get().(*flate.Writer) + if fw == nil { + fw, _ = flate.NewWriter(tw, level) + } else { + fw.Reset(tw) + } + return &flateWriteWrapper{fw: fw, tw: tw, p: p} +} + +// truncWriter is an io.Writer that writes all but the last four bytes of the +// stream to another io.Writer. +type truncWriter struct { + w io.WriteCloser + n int + p [4]byte +} + +func (w *truncWriter) Write(p []byte) (int, error) { + n := 0 + + // fill buffer first for simplicity. + if w.n < len(w.p) { + n = copy(w.p[w.n:], p) + p = p[n:] + w.n += n + if len(p) == 0 { + return n, nil + } + } + + m := len(p) + if m > len(w.p) { + m = len(w.p) + } + + if nn, err := w.w.Write(w.p[:m]); err != nil { + return n + nn, err + } + + copy(w.p[:], w.p[m:]) + copy(w.p[len(w.p)-m:], p[len(p)-m:]) + nn, err := w.w.Write(p[:len(p)-m]) + return n + nn, err +} + +type flateWriteWrapper struct { + fw *flate.Writer + tw *truncWriter + p *sync.Pool +} + +func (w *flateWriteWrapper) Write(p []byte) (int, error) { + if w.fw == nil { + return 0, errWriteClosed + } + return w.fw.Write(p) +} + +func (w *flateWriteWrapper) Close() error { + if w.fw == nil { + return errWriteClosed + } + err1 := w.fw.Flush() + w.p.Put(w.fw) + w.fw = nil + if w.tw.p != [4]byte{0, 0, 0xff, 0xff} { + return errors.New("websocket: internal error, unexpected bytes at end of flate stream") + } + err2 := w.tw.w.Close() + if err1 != nil { + return err1 + } + return err2 +} + +type flateReadWrapper struct { + fr io.ReadCloser +} + +func (r *flateReadWrapper) Read(p []byte) (int, error) { + if r.fr == nil { + return 0, io.ErrClosedPipe + } + n, err := r.fr.Read(p) + if err == io.EOF { + // Preemptively place the reader back in the pool. This helps with + // scenarios where the application does not call NextReader() soon after + // this final read. + r.Close() + } + return n, err +} + +func (r *flateReadWrapper) Close() error { + if r.fr == nil { + return io.ErrClosedPipe + } + err := r.fr.Close() + flateReaderPool.Put(r.fr) + r.fr = nil + return err +} +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + + +const ( + // Frame header byte 0 bits from Section 5.2 of RFC 6455 + finalBit = 1 << 7 + rsv1Bit = 1 << 6 + rsv2Bit = 1 << 5 + rsv3Bit = 1 << 4 + + // Frame header byte 1 bits from Section 5.2 of RFC 6455 + maskBit = 1 << 7 + + maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask + maxControlFramePayloadSize = 125 + + writeWait = time.Second + + defaultReadBufferSize = 4096 + defaultWriteBufferSize = 4096 + + continuationFrame = 0 + noFrame = -1 +) + +// Close codes defined in RFC 6455, section 11.7. +const ( + CloseNormalClosure = 1000 + CloseGoingAway = 1001 + CloseProtocolError = 1002 + CloseUnsupportedData = 1003 + CloseNoStatusReceived = 1005 + CloseAbnormalClosure = 1006 + CloseInvalidFramePayloadData = 1007 + ClosePolicyViolation = 1008 + CloseMessageTooBig = 1009 + CloseMandatoryExtension = 1010 + CloseInternalServerErr = 1011 + CloseServiceRestart = 1012 + CloseTryAgainLater = 1013 + CloseTLSHandshake = 1015 +) + +// The message types are defined in RFC 6455, section 11.8. +const ( + // TextMessage denotes a text data message. The text message payload is + // interpreted as UTF-8 encoded text data. + TextMessage = 1 + + // BinaryMessage denotes a binary data message. + BinaryMessage = 2 + + // CloseMessage denotes a close control message. The optional message + // payload contains a numeric code and text. Use the FormatCloseMessage + // function to format a close message payload. + CloseMessage = 8 + + // PingMessage denotes a ping control message. The optional message payload + // is UTF-8 encoded text. + PingMessage = 9 + + // PongMessage denotes a pong control message. The optional message payload + // is UTF-8 encoded text. + PongMessage = 10 +) + +// ErrCloseSent is returned when the application writes a message to the +// connection after sending a close message. +var ErrCloseSent = errors.New("websocket: close sent") + +// ErrReadLimit is returned when reading a message that is larger than the +// read limit set for the connection. +var ErrReadLimit = errors.New("websocket: read limit exceeded") + +// netError satisfies the net Error interface. +type netError struct { + msg string + temporary bool + timeout bool +} + +func (e *netError) Error() string { return e.msg } +func (e *netError) Temporary() bool { return e.temporary } +func (e *netError) Timeout() bool { return e.timeout } + +// CloseError represents a close message. +type CloseError struct { + // Code is defined in RFC 6455, section 11.7. + Code int + + // Text is the optional text payload. + Text string +} + +func (e *CloseError) Error() string { + s := []byte("websocket: close ") + s = strconv.AppendInt(s, int64(e.Code), 10) + switch e.Code { + case CloseNormalClosure: + s = append(s, " (normal)"...) + case CloseGoingAway: + s = append(s, " (going away)"...) + case CloseProtocolError: + s = append(s, " (protocol error)"...) + case CloseUnsupportedData: + s = append(s, " (unsupported data)"...) + case CloseNoStatusReceived: + s = append(s, " (no status)"...) + case CloseAbnormalClosure: + s = append(s, " (abnormal closure)"...) + case CloseInvalidFramePayloadData: + s = append(s, " (invalid payload data)"...) + case ClosePolicyViolation: + s = append(s, " (policy violation)"...) + case CloseMessageTooBig: + s = append(s, " (message too big)"...) + case CloseMandatoryExtension: + s = append(s, " (mandatory extension missing)"...) + case CloseInternalServerErr: + s = append(s, " (internal server error)"...) + case CloseTLSHandshake: + s = append(s, " (TLS handshake error)"...) + } + if e.Text != "" { + s = append(s, ": "...) + s = append(s, e.Text...) + } + return string(s) +} + +// IsCloseError returns boolean indicating whether the error is a *CloseError +// with one of the specified codes. +func IsCloseError(err error, codes ...int) bool { + if e, ok := err.(*CloseError); ok { + for _, code := range codes { + if e.Code == code { + return true + } + } + } + return false +} + +// IsUnexpectedCloseError returns boolean indicating whether the error is a +// *CloseError with a code not in the list of expected codes. +func IsUnexpectedCloseError(err error, expectedCodes ...int) bool { + if e, ok := err.(*CloseError); ok { + for _, code := range expectedCodes { + if e.Code == code { + return false + } + } + return true + } + return false +} + +var ( + errWriteTimeout = &netError{msg: "websocket: write timeout", timeout: true, temporary: true} + errUnexpectedEOF = &CloseError{Code: CloseAbnormalClosure, Text: io.ErrUnexpectedEOF.Error()} + errBadWriteOpCode = errors.New("websocket: bad write message type") + errWriteClosed = errors.New("websocket: write closed") + errInvalidControlFrame = errors.New("websocket: invalid control frame") +) + +// maskRand is an io.Reader for generating mask bytes. The reader is initialized +// to crypto/rand Reader. Tests swap the reader to a math/rand reader for +// reproducible results. +var maskRand = rand.Reader + +// newMaskKey returns a new 32 bit value for masking client frames. +func newMaskKey() [4]byte { + var k [4]byte + _, _ = io.ReadFull(maskRand, k[:]) + return k +} + +func isControl(frameType int) bool { + return frameType == CloseMessage || frameType == PingMessage || frameType == PongMessage +} + +func isData(frameType int) bool { + return frameType == TextMessage || frameType == BinaryMessage +} + +var validReceivedCloseCodes = map[int]bool{ + // see http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number + + CloseNormalClosure: true, + CloseGoingAway: true, + CloseProtocolError: true, + CloseUnsupportedData: true, + CloseNoStatusReceived: false, + CloseAbnormalClosure: false, + CloseInvalidFramePayloadData: true, + ClosePolicyViolation: true, + CloseMessageTooBig: true, + CloseMandatoryExtension: true, + CloseInternalServerErr: true, + CloseServiceRestart: true, + CloseTryAgainLater: true, + CloseTLSHandshake: false, +} + +func isValidReceivedCloseCode(code int) bool { + return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999) +} + +// BufferPool represents a pool of buffers. The *sync.Pool type satisfies this +// interface. The type of the value stored in a pool is not specified. +type BufferPool interface { + // Get gets a value from the pool or returns nil if the pool is empty. + Get() interface{} + // Put adds a value to the pool. + Put(interface{}) +} + +// writePoolData is the type added to the write buffer pool. This wrapper is +// used to prevent applications from peeking at and depending on the values +// added to the pool. +type writePoolData struct{ buf []byte } + +// The Conn type represents a WebSocket connection. +type Conn struct { + conn net.Conn + isServer bool + subprotocol string + + // Write fields + mu chan struct{} // used as mutex to protect write to conn + writeBuf []byte // frame is constructed in this buffer. + writePool BufferPool + writeBufSize int + writeDeadline time.Time + writer io.WriteCloser // the current writer returned to the application + isWriting bool // for best-effort concurrent write detection + + writeErrMu sync.Mutex + writeErr error + + enableWriteCompression bool + compressionLevel int + newCompressionWriter func(io.WriteCloser, int) io.WriteCloser + + // Read fields + reader io.ReadCloser // the current reader returned to the application + readErr error + br *bufio.Reader + // bytes remaining in current frame. + // set setReadRemaining to safely update this value and prevent overflow + readRemaining int64 + readFinal bool // true the current message has more frames. + readLength int64 // Message size. + readLimit int64 // Maximum message size. + readMaskPos int + readMaskKey [4]byte + handlePong func(string) error + handlePing func(string) error + handleClose func(int, string) error + readErrCount int + messageReader *messageReader // the current low-level reader + + readDecompress bool // whether last read frame had RSV1 set + newDecompressionReader func(io.Reader) io.ReadCloser +} + +func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, writeBufferPool BufferPool, br *bufio.Reader, writeBuf []byte) *Conn { + + if br == nil { + if readBufferSize == 0 { + readBufferSize = defaultReadBufferSize + } else if readBufferSize < maxControlFramePayloadSize { + // must be large enough for control frame + readBufferSize = maxControlFramePayloadSize + } + br = bufio.NewReaderSize(conn, readBufferSize) + } + + if writeBufferSize <= 0 { + writeBufferSize = defaultWriteBufferSize + } + writeBufferSize += maxFrameHeaderSize + + if writeBuf == nil && writeBufferPool == nil { + writeBuf = make([]byte, writeBufferSize) + } + + mu := make(chan struct{}, 1) + mu <- struct{}{} + c := &Conn{ + isServer: isServer, + br: br, + conn: conn, + mu: mu, + readFinal: true, + writeBuf: writeBuf, + writePool: writeBufferPool, + writeBufSize: writeBufferSize, + enableWriteCompression: true, + compressionLevel: defaultCompressionLevel, + } + c.SetCloseHandler(nil) + c.SetPingHandler(nil) + c.SetPongHandler(nil) + return c +} + +// setReadRemaining tracks the number of bytes remaining on the connection. If n +// overflows, an ErrReadLimit is returned. +func (c *Conn) setReadRemaining(n int64) error { + if n < 0 { + return ErrReadLimit + } + + c.readRemaining = n + return nil +} + +// Subprotocol returns the negotiated protocol for the connection. +func (c *Conn) Subprotocol() string { + return c.subprotocol +} + +// Close closes the underlying network connection without sending or waiting +// for a close message. +func (c *Conn) Close() error { + return c.conn.Close() +} + +// LocalAddr returns the local network address. +func (c *Conn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +// RemoteAddr returns the remote network address. +func (c *Conn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +// Write methods + +func (c *Conn) writeFatal(err error) error { + c.writeErrMu.Lock() + if c.writeErr == nil { + c.writeErr = err + } + c.writeErrMu.Unlock() + return err +} + +func (c *Conn) read(n int) ([]byte, error) { + p, err := c.br.Peek(n) + if err == io.EOF { + err = errUnexpectedEOF + } + // Discard is guaranteed to succeed because the number of bytes to discard + // is less than or equal to the number of bytes buffered. + _, _ = c.br.Discard(len(p)) + return p, err +} + +func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error { + <-c.mu + defer func() { c.mu <- struct{}{} }() + + c.writeErrMu.Lock() + err := c.writeErr + c.writeErrMu.Unlock() + if err != nil { + return err + } + + if err := c.conn.SetWriteDeadline(deadline); err != nil { + return c.writeFatal(err) + } + if len(buf1) == 0 { + _, err = c.conn.Write(buf0) + } else { + err = c.writeBufs(buf0, buf1) + } + if err != nil { + return c.writeFatal(err) + } + if frameType == CloseMessage { + _ = c.writeFatal(ErrCloseSent) + } + return nil +} + +func (c *Conn) writeBufs(bufs ...[]byte) error { + b := net.Buffers(bufs) + _, err := b.WriteTo(c.conn) + return err +} + +// WriteControl writes a control message with the given deadline. The allowed +// message types are CloseMessage, PingMessage and PongMessage. +func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error { + if !isControl(messageType) { + return errBadWriteOpCode + } + if len(data) > maxControlFramePayloadSize { + return errInvalidControlFrame + } + + b0 := byte(messageType) | finalBit + b1 := byte(len(data)) + if !c.isServer { + b1 |= maskBit + } + + buf := make([]byte, 0, maxFrameHeaderSize+maxControlFramePayloadSize) + buf = append(buf, b0, b1) + + if c.isServer { + buf = append(buf, data...) + } else { + key := newMaskKey() + buf = append(buf, key[:]...) + buf = append(buf, data...) + maskBytes(key, 0, buf[6:]) + } + + if deadline.IsZero() { + // No timeout for zero time. + <-c.mu + } else { + d := time.Until(deadline) + if d < 0 { + return errWriteTimeout + } + select { + case <-c.mu: + default: + timer := time.NewTimer(d) + select { + case <-c.mu: + timer.Stop() + case <-timer.C: + return errWriteTimeout + } + } + } + + defer func() { c.mu <- struct{}{} }() + + c.writeErrMu.Lock() + err := c.writeErr + c.writeErrMu.Unlock() + if err != nil { + return err + } + + if err := c.conn.SetWriteDeadline(deadline); err != nil { + return c.writeFatal(err) + } + if _, err = c.conn.Write(buf); err != nil { + return c.writeFatal(err) + } + if messageType == CloseMessage { + _ = c.writeFatal(ErrCloseSent) + } + return err +} + +// beginMessage prepares a connection and message writer for a new message. +func (c *Conn) beginMessage(mw *messageWriter, messageType int) error { + // Close previous writer if not already closed by the application. It's + // probably better to return an error in this situation, but we cannot + // change this without breaking existing applications. + if c.writer != nil { + c.writer.Close() + c.writer = nil + } + + if !isControl(messageType) && !isData(messageType) { + return errBadWriteOpCode + } + + c.writeErrMu.Lock() + err := c.writeErr + c.writeErrMu.Unlock() + if err != nil { + return err + } + + mw.c = c + mw.frameType = messageType + mw.pos = maxFrameHeaderSize + + if c.writeBuf == nil { + wpd, ok := c.writePool.Get().(writePoolData) + if ok { + c.writeBuf = wpd.buf + } else { + c.writeBuf = make([]byte, c.writeBufSize) + } + } + return nil +} + +// NextWriter returns a writer for the next message to send. The writer's Close +// method flushes the complete message to the network. +// +// There can be at most one open writer on a connection. NextWriter closes the +// previous writer if the application has not already done so. +// +// All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and +// PongMessage) are supported. +func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { + var mw messageWriter + if err := c.beginMessage(&mw, messageType); err != nil { + return nil, err + } + c.writer = &mw + if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) { + w := c.newCompressionWriter(c.writer, c.compressionLevel) + mw.compress = true + c.writer = w + } + return c.writer, nil +} + +type messageWriter struct { + c *Conn + compress bool // whether next call to flushFrame should set RSV1 + pos int // end of data in writeBuf. + frameType int // type of the current frame. + err error +} + +func (w *messageWriter) endMessage(err error) error { + if w.err != nil { + return err + } + c := w.c + w.err = err + c.writer = nil + if c.writePool != nil { + c.writePool.Put(writePoolData{buf: c.writeBuf}) + c.writeBuf = nil + } + return err +} + +// flushFrame writes buffered data and extra as a frame to the network. The +// final argument indicates that this is the last frame in the message. +func (w *messageWriter) flushFrame(final bool, extra []byte) error { + c := w.c + length := w.pos - maxFrameHeaderSize + len(extra) + + // Check for invalid control frames. + if isControl(w.frameType) && + (!final || length > maxControlFramePayloadSize) { + return w.endMessage(errInvalidControlFrame) + } + + b0 := byte(w.frameType) + if final { + b0 |= finalBit + } + if w.compress { + b0 |= rsv1Bit + } + w.compress = false + + b1 := byte(0) + if !c.isServer { + b1 |= maskBit + } + + // Assume that the frame starts at beginning of c.writeBuf. + framePos := 0 + if c.isServer { + // Adjust up if mask not included in the header. + framePos = 4 + } + + switch { + case length >= 65536: + c.writeBuf[framePos] = b0 + c.writeBuf[framePos+1] = b1 | 127 + binary.BigEndian.PutUint64(c.writeBuf[framePos+2:], uint64(length)) + case length > 125: + framePos += 6 + c.writeBuf[framePos] = b0 + c.writeBuf[framePos+1] = b1 | 126 + binary.BigEndian.PutUint16(c.writeBuf[framePos+2:], uint16(length)) + default: + framePos += 8 + c.writeBuf[framePos] = b0 + c.writeBuf[framePos+1] = b1 | byte(length) + } + + if !c.isServer { + key := newMaskKey() + copy(c.writeBuf[maxFrameHeaderSize-4:], key[:]) + maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos]) + if len(extra) > 0 { + return w.endMessage(c.writeFatal(errors.New("websocket: internal error, extra used in client mode"))) + } + } + + // Write the buffers to the connection with best-effort detection of + // concurrent writes. See the concurrency section in the package + // documentation for more info. + + if c.isWriting { + panic("concurrent write to websocket connection") + } + c.isWriting = true + + err := c.write(w.frameType, c.writeDeadline, c.writeBuf[framePos:w.pos], extra) + + if !c.isWriting { + panic("concurrent write to websocket connection") + } + c.isWriting = false + + if err != nil { + return w.endMessage(err) + } + + if final { + _ = w.endMessage(errWriteClosed) + return nil + } + + // Setup for next frame. + w.pos = maxFrameHeaderSize + w.frameType = continuationFrame + return nil +} + +func (w *messageWriter) ncopy(max int) (int, error) { + n := len(w.c.writeBuf) - w.pos + if n <= 0 { + if err := w.flushFrame(false, nil); err != nil { + return 0, err + } + n = len(w.c.writeBuf) - w.pos + } + if n > max { + n = max + } + return n, nil +} + +func (w *messageWriter) Write(p []byte) (int, error) { + if w.err != nil { + return 0, w.err + } + + if len(p) > 2*len(w.c.writeBuf) && w.c.isServer { + // Don't buffer large messages. + err := w.flushFrame(false, p) + if err != nil { + return 0, err + } + return len(p), nil + } + + nn := len(p) + for len(p) > 0 { + n, err := w.ncopy(len(p)) + if err != nil { + return 0, err + } + copy(w.c.writeBuf[w.pos:], p[:n]) + w.pos += n + p = p[n:] + } + return nn, nil +} + +func (w *messageWriter) WriteString(p string) (int, error) { + if w.err != nil { + return 0, w.err + } + + nn := len(p) + for len(p) > 0 { + n, err := w.ncopy(len(p)) + if err != nil { + return 0, err + } + copy(w.c.writeBuf[w.pos:], p[:n]) + w.pos += n + p = p[n:] + } + return nn, nil +} + +func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) { + if w.err != nil { + return 0, w.err + } + for { + if w.pos == len(w.c.writeBuf) { + err = w.flushFrame(false, nil) + if err != nil { + break + } + } + var n int + n, err = r.Read(w.c.writeBuf[w.pos:]) + w.pos += n + nn += int64(n) + if err != nil { + if err == io.EOF { + err = nil + } + break + } + } + return nn, err +} + +func (w *messageWriter) Close() error { + if w.err != nil { + return w.err + } + return w.flushFrame(true, nil) +} + +// WritePreparedMessage writes prepared message into connection. +func (c *Conn) WritePreparedMessage(pm *PreparedMessage) error { + frameType, frameData, err := pm.frame(prepareKey{ + isServer: c.isServer, + compress: c.newCompressionWriter != nil && c.enableWriteCompression && isData(pm.messageType), + compressionLevel: c.compressionLevel, + }) + if err != nil { + return err + } + if c.isWriting { + panic("concurrent write to websocket connection") + } + c.isWriting = true + err = c.write(frameType, c.writeDeadline, frameData, nil) + if !c.isWriting { + panic("concurrent write to websocket connection") + } + c.isWriting = false + return err +} + +// WriteMessage is a helper method for getting a writer using NextWriter, +// writing the message and closing the writer. +func (c *Conn) WriteMessage(messageType int, data []byte) error { + + if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) { + // Fast path with no allocations and single frame. + + var mw messageWriter + if err := c.beginMessage(&mw, messageType); err != nil { + return err + } + n := copy(c.writeBuf[mw.pos:], data) + mw.pos += n + data = data[n:] + return mw.flushFrame(true, data) + } + + w, err := c.NextWriter(messageType) + if err != nil { + return err + } + if _, err = w.Write(data); err != nil { + return err + } + return w.Close() +} + +// SetWriteDeadline sets the write deadline on the underlying network +// connection. After a write has timed out, the websocket state is corrupt and +// all future writes will return an error. A zero value for t means writes will +// not time out. +func (c *Conn) SetWriteDeadline(t time.Time) error { + c.writeDeadline = t + return nil +} + +// Read methods + +func (c *Conn) advanceFrame() (int, error) { + // 1. Skip remainder of previous frame. + + if c.readRemaining > 0 { + if _, err := io.CopyN(io.Discard, c.br, c.readRemaining); err != nil { + return noFrame, err + } + } + + // 2. Read and parse first two bytes of frame header. + // To aid debugging, collect and report all errors in the first two bytes + // of the header. + + var errors []string + + p, err := c.read(2) + if err != nil { + return noFrame, err + } + + frameType := int(p[0] & 0xf) + final := p[0]&finalBit != 0 + rsv1 := p[0]&rsv1Bit != 0 + rsv2 := p[0]&rsv2Bit != 0 + rsv3 := p[0]&rsv3Bit != 0 + mask := p[1]&maskBit != 0 + _ = c.setReadRemaining(int64(p[1] & 0x7f)) // will not fail because argument is >= 0 + + c.readDecompress = false + if rsv1 { + if c.newDecompressionReader != nil { + c.readDecompress = true + } else { + errors = append(errors, "RSV1 set") + } + } + + if rsv2 { + errors = append(errors, "RSV2 set") + } + + if rsv3 { + errors = append(errors, "RSV3 set") + } + + switch frameType { + case CloseMessage, PingMessage, PongMessage: + if c.readRemaining > maxControlFramePayloadSize { + errors = append(errors, "len > 125 for control") + } + if !final { + errors = append(errors, "FIN not set on control") + } + case TextMessage, BinaryMessage: + if !c.readFinal { + errors = append(errors, "data before FIN") + } + c.readFinal = final + case continuationFrame: + if c.readFinal { + errors = append(errors, "continuation after FIN") + } + c.readFinal = final + default: + errors = append(errors, "bad opcode "+strconv.Itoa(frameType)) + } + + if mask != c.isServer { + errors = append(errors, "bad MASK") + } + + if len(errors) > 0 { + return noFrame, c.handleProtocolError(strings.Join(errors, ", ")) + } + + // 3. Read and parse frame length as per + // https://tools.ietf.org/html/rfc6455#section-5.2 + // + // The length of the "Payload data", in bytes: if 0-125, that is the payload + // length. + // - If 126, the following 2 bytes interpreted as a 16-bit unsigned + // integer are the payload length. + // - If 127, the following 8 bytes interpreted as + // a 64-bit unsigned integer (the most significant bit MUST be 0) are the + // payload length. Multibyte length quantities are expressed in network byte + // order. + + switch c.readRemaining { + case 126: + p, err := c.read(2) + if err != nil { + return noFrame, err + } + + if err := c.setReadRemaining(int64(binary.BigEndian.Uint16(p))); err != nil { + return noFrame, err + } + case 127: + p, err := c.read(8) + if err != nil { + return noFrame, err + } + + if err := c.setReadRemaining(int64(binary.BigEndian.Uint64(p))); err != nil { + return noFrame, err + } + } + + // 4. Handle frame masking. + + if mask { + c.readMaskPos = 0 + p, err := c.read(len(c.readMaskKey)) + if err != nil { + return noFrame, err + } + copy(c.readMaskKey[:], p) + } + + // 5. For text and binary messages, enforce read limit and return. + + if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage { + + c.readLength += c.readRemaining + // Don't allow readLength to overflow in the presence of a large readRemaining + // counter. + if c.readLength < 0 { + return noFrame, ErrReadLimit + } + + if c.readLimit > 0 && c.readLength > c.readLimit { + // Make a best effort to send a close message describing the problem. + _ = c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait)) + return noFrame, ErrReadLimit + } + + return frameType, nil + } + + // 6. Read control frame payload. + + var payload []byte + if c.readRemaining > 0 { + payload, err = c.read(int(c.readRemaining)) + _ = c.setReadRemaining(0) // will not fail because argument is >= 0 + if err != nil { + return noFrame, err + } + if c.isServer { + maskBytes(c.readMaskKey, 0, payload) + } + } + + // 7. Process control frame payload. + + switch frameType { + case PongMessage: + if err := c.handlePong(string(payload)); err != nil { + return noFrame, err + } + case PingMessage: + if err := c.handlePing(string(payload)); err != nil { + return noFrame, err + } + case CloseMessage: + closeCode := CloseNoStatusReceived + closeText := "" + if len(payload) >= 2 { + closeCode = int(binary.BigEndian.Uint16(payload)) + if !isValidReceivedCloseCode(closeCode) { + return noFrame, c.handleProtocolError("bad close code " + strconv.Itoa(closeCode)) + } + closeText = string(payload[2:]) + if !utf8.ValidString(closeText) { + return noFrame, c.handleProtocolError("invalid utf8 payload in close frame") + } + } + if err := c.handleClose(closeCode, closeText); err != nil { + return noFrame, err + } + return noFrame, &CloseError{Code: closeCode, Text: closeText} + } + + return frameType, nil +} + +func (c *Conn) handleProtocolError(message string) error { + data := FormatCloseMessage(CloseProtocolError, message) + if len(data) > maxControlFramePayloadSize { + data = data[:maxControlFramePayloadSize] + } + // Make a best effor to send a close message describing the problem. + _ = c.WriteControl(CloseMessage, data, time.Now().Add(writeWait)) + return errors.New("websocket: " + message) +} + +// NextReader returns the next data message received from the peer. The +// returned messageType is either TextMessage or BinaryMessage. +// +// There can be at most one open reader on a connection. NextReader discards +// the previous message if the application has not already consumed it. +// +// Applications must break out of the application's read loop when this method +// returns a non-nil error value. Errors returned from this method are +// permanent. Once this method returns a non-nil error, all subsequent calls to +// this method return the same error. +func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { + // Close previous reader, only relevant for decompression. + if c.reader != nil { + c.reader.Close() + c.reader = nil + } + + c.messageReader = nil + c.readLength = 0 + + for c.readErr == nil { + frameType, err := c.advanceFrame() + if err != nil { + c.readErr = err + break + } + + if frameType == TextMessage || frameType == BinaryMessage { + c.messageReader = &messageReader{c} + c.reader = c.messageReader + if c.readDecompress { + c.reader = c.newDecompressionReader(c.reader) + } + return frameType, c.reader, nil + } + } + + // Applications that do handle the error returned from this method spin in + // tight loop on connection failure. To help application developers detect + // this error, panic on repeated reads to the failed connection. + c.readErrCount++ + if c.readErrCount >= 1000 { + panic("repeated read on failed websocket connection") + } + + return noFrame, nil, c.readErr +} + +type messageReader struct{ c *Conn } + +func (r *messageReader) Read(b []byte) (int, error) { + c := r.c + if c.messageReader != r { + return 0, io.EOF + } + + for c.readErr == nil { + + if c.readRemaining > 0 { + if int64(len(b)) > c.readRemaining { + b = b[:c.readRemaining] + } + n, err := c.br.Read(b) + c.readErr = err + if c.isServer { + c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n]) + } + rem := c.readRemaining + rem -= int64(n) + _ = c.setReadRemaining(rem) // rem is guaranteed to be >= 0 + if c.readRemaining > 0 && c.readErr == io.EOF { + c.readErr = errUnexpectedEOF + } + return n, c.readErr + } + + if c.readFinal { + c.messageReader = nil + return 0, io.EOF + } + + frameType, err := c.advanceFrame() + switch { + case err != nil: + c.readErr = err + case frameType == TextMessage || frameType == BinaryMessage: + c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader") + } + } + + err := c.readErr + if err == io.EOF && c.messageReader == r { + err = errUnexpectedEOF + } + return 0, err +} + +func (r *messageReader) Close() error { + return nil +} + +// ReadMessage is a helper method for getting a reader using NextReader and +// reading from that reader to a buffer. +func (c *Conn) ReadMessage() (messageType int, p []byte, err error) { + var r io.Reader + messageType, r, err = c.NextReader() + if err != nil { + return messageType, nil, err + } + p, err = io.ReadAll(r) + return messageType, p, err +} + +// SetReadDeadline sets the read deadline on the underlying network connection. +// After a read has timed out, the websocket connection state is corrupt and +// all future reads will return an error. A zero value for t means reads will +// not time out. +func (c *Conn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +// SetReadLimit sets the maximum size in bytes for a message read from the peer. If a +// message exceeds the limit, the connection sends a close message to the peer +// and returns ErrReadLimit to the application. +func (c *Conn) SetReadLimit(limit int64) { + c.readLimit = limit +} + +// CloseHandler returns the current close handler +func (c *Conn) CloseHandler() func(code int, text string) error { + return c.handleClose +} + +// SetCloseHandler sets the handler for close messages received from the peer. +// The code argument to h is the received close code or CloseNoStatusReceived +// if the close message is empty. The default close handler sends a close +// message back to the peer. +// +// The handler function is called from the NextReader, ReadMessage and message +// reader Read methods. The application must read the connection to process +// close messages as described in the section on Control Messages above. +// +// The connection read methods return a CloseError when a close message is +// received. Most applications should handle close messages as part of their +// normal error handling. Applications should only set a close handler when the +// application must perform some action before sending a close message back to +// the peer. +func (c *Conn) SetCloseHandler(h func(code int, text string) error) { + if h == nil { + h = func(code int, text string) error { + message := FormatCloseMessage(code, "") + // Make a best effor to send the close message. + _ = c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)) + return nil + } + } + c.handleClose = h +} + +// PingHandler returns the current ping handler +func (c *Conn) PingHandler() func(appData string) error { + return c.handlePing +} + +// SetPingHandler sets the handler for ping messages received from the peer. +// The appData argument to h is the PING message application data. The default +// ping handler sends a pong to the peer. +// +// The handler function is called from the NextReader, ReadMessage and message +// reader Read methods. The application must read the connection to process +// ping messages as described in the section on Control Messages above. +func (c *Conn) SetPingHandler(h func(appData string) error) { + if h == nil { + h = func(message string) error { + // Make a best effort to send the pong message. + _ = c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait)) + return nil + } + } + c.handlePing = h +} + +// PongHandler returns the current pong handler +func (c *Conn) PongHandler() func(appData string) error { + return c.handlePong +} + +// SetPongHandler sets the handler for pong messages received from the peer. +// The appData argument to h is the PONG message application data. The default +// pong handler does nothing. +// +// The handler function is called from the NextReader, ReadMessage and message +// reader Read methods. The application must read the connection to process +// pong messages as described in the section on Control Messages above. +func (c *Conn) SetPongHandler(h func(appData string) error) { + if h == nil { + h = func(string) error { return nil } + } + c.handlePong = h +} + +// NetConn returns the underlying connection that is wrapped by c. +// Note that writing to or reading from this connection directly will corrupt the +// WebSocket connection. +func (c *Conn) NetConn() net.Conn { + return c.conn +} + +// UnderlyingConn returns the internal net.Conn. This can be used to further +// modifications to connection specific flags. +// Deprecated: Use the NetConn method. +func (c *Conn) UnderlyingConn() net.Conn { + return c.conn +} + +// EnableWriteCompression enables and disables write compression of +// subsequent text and binary messages. This function is a noop if +// compression was not negotiated with the peer. +func (c *Conn) EnableWriteCompression(enable bool) { + c.enableWriteCompression = enable +} + +// SetCompressionLevel sets the flate compression level for subsequent text and +// binary messages. This function is a noop if compression was not negotiated +// with the peer. See the compress/flate package for a description of +// compression levels. +func (c *Conn) SetCompressionLevel(level int) error { + if !isValidCompressionLevel(level) { + return errors.New("websocket: invalid compression level") + } + c.compressionLevel = level + return nil +} + +// FormatCloseMessage formats closeCode and text as a WebSocket close message. +// An empty message is returned for code CloseNoStatusReceived. +func FormatCloseMessage(closeCode int, text string) []byte { + if closeCode == CloseNoStatusReceived { + // Return empty message because it's illegal to send + // CloseNoStatusReceived. Return non-nil value in case application + // checks for nil. + return []byte{} + } + buf := make([]byte, 2+len(text)) + binary.BigEndian.PutUint16(buf, uint16(closeCode)) + copy(buf[2:], text) + return buf +} +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package websocket implements the WebSocket protocol defined in RFC 6455. +// +// Overview +// +// The Conn type represents a WebSocket connection. A server application calls +// the Upgrader.Upgrade method from an HTTP request handler to get a *Conn: +// +// var upgrader = websocket.Upgrader{ +// ReadBufferSize: 1024, +// WriteBufferSize: 1024, +// } +// +// func handler(w http.ResponseWriter, r *http.Request) { +// conn, err := upgrader.Upgrade(w, r, nil) +// if err != nil { +// log.Println(err) +// return +// } +// ... Use conn to send and receive messages. +// } +// +// Call the connection's WriteMessage and ReadMessage methods to send and +// receive messages as a slice of bytes. This snippet of code shows how to echo +// messages using these methods: +// +// for { +// messageType, p, err := conn.ReadMessage() +// if err != nil { +// log.Println(err) +// return +// } +// if err := conn.WriteMessage(messageType, p); err != nil { +// log.Println(err) +// return +// } +// } +// +// In above snippet of code, p is a []byte and messageType is an int with value +// websocket.BinaryMessage or websocket.TextMessage. +// +// An application can also send and receive messages using the io.WriteCloser +// and io.Reader interfaces. To send a message, call the connection NextWriter +// method to get an io.WriteCloser, write the message to the writer and close +// the writer when done. To receive a message, call the connection NextReader +// method to get an io.Reader and read until io.EOF is returned. This snippet +// shows how to echo messages using the NextWriter and NextReader methods: +// +// for { +// messageType, r, err := conn.NextReader() +// if err != nil { +// return +// } +// w, err := conn.NextWriter(messageType) +// if err != nil { +// return err +// } +// if _, err := io.Copy(w, r); err != nil { +// return err +// } +// if err := w.Close(); err != nil { +// return err +// } +// } +// +// Data Messages +// +// The WebSocket protocol distinguishes between text and binary data messages. +// Text messages are interpreted as UTF-8 encoded text. The interpretation of +// binary messages is left to the application. +// +// This package uses the TextMessage and BinaryMessage integer constants to +// identify the two data message types. The ReadMessage and NextReader methods +// return the type of the received message. The messageType argument to the +// WriteMessage and NextWriter methods specifies the type of a sent message. +// +// It is the application's responsibility to ensure that text messages are +// valid UTF-8 encoded text. +// +// Control Messages +// +// The WebSocket protocol defines three types of control messages: close, ping +// and pong. Call the connection WriteControl, WriteMessage or NextWriter +// methods to send a control message to the peer. +// +// Connections handle received close messages by calling the handler function +// set with the SetCloseHandler method and by returning a *CloseError from the +// NextReader, ReadMessage or the message Read method. The default close +// handler sends a close message to the peer. +// +// Connections handle received ping messages by calling the handler function +// set with the SetPingHandler method. The default ping handler sends a pong +// message to the peer. +// +// Connections handle received pong messages by calling the handler function +// set with the SetPongHandler method. The default pong handler does nothing. +// If an application sends ping messages, then the application should set a +// pong handler to receive the corresponding pong. +// +// The control message handler functions are called from the NextReader, +// ReadMessage and message reader Read methods. The default close and ping +// handlers can block these methods for a short time when the handler writes to +// the connection. +// +// The application must read the connection to process close, ping and pong +// messages sent from the peer. If the application is not otherwise interested +// in messages from the peer, then the application should start a goroutine to +// read and discard messages from the peer. A simple example is: +// +// func readLoop(c *websocket.Conn) { +// for { +// if _, _, err := c.NextReader(); err != nil { +// c.Close() +// break +// } +// } +// } +// +// Concurrency +// +// Connections support one concurrent reader and one concurrent writer. +// +// Applications are responsible for ensuring that no more than one goroutine +// calls the write methods (NextWriter, SetWriteDeadline, WriteMessage, +// WriteJSON, EnableWriteCompression, SetCompressionLevel) concurrently and +// that no more than one goroutine calls the read methods (NextReader, +// SetReadDeadline, ReadMessage, ReadJSON, SetPongHandler, SetPingHandler) +// concurrently. +// +// The Close and WriteControl methods can be called concurrently with all other +// methods. +// +// Origin Considerations +// +// Web browsers allow Javascript applications to open a WebSocket connection to +// any host. It's up to the server to enforce an origin policy using the Origin +// request header sent by the browser. +// +// The Upgrader calls the function specified in the CheckOrigin field to check +// the origin. If the CheckOrigin function returns false, then the Upgrade +// method fails the WebSocket handshake with HTTP status 403. +// +// If the CheckOrigin field is nil, then the Upgrader uses a safe default: fail +// the handshake if the Origin request header is present and the Origin host is +// not equal to the Host request header. +// +// The deprecated package-level Upgrade function does not perform origin +// checking. The application is responsible for checking the Origin header +// before calling the Upgrade function. +// +// Buffers +// +// Connections buffer network input and output to reduce the number +// of system calls when reading or writing messages. +// +// Write buffers are also used for constructing WebSocket frames. See RFC 6455, +// Section 5 for a discussion of message framing. A WebSocket frame header is +// written to the network each time a write buffer is flushed to the network. +// Decreasing the size of the write buffer can increase the amount of framing +// overhead on the connection. +// +// The buffer sizes in bytes are specified by the ReadBufferSize and +// WriteBufferSize fields in the Dialer and Upgrader. The Dialer uses a default +// size of 4096 when a buffer size field is set to zero. The Upgrader reuses +// buffers created by the HTTP server when a buffer size field is set to zero. +// The HTTP server buffers have a size of 4096 at the time of this writing. +// +// The buffer sizes do not limit the size of a message that can be read or +// written by a connection. +// +// Buffers are held for the lifetime of the connection by default. If the +// Dialer or Upgrader WriteBufferPool field is set, then a connection holds the +// write buffer only when writing a message. +// +// Applications should tune the buffer sizes to balance memory use and +// performance. Increasing the buffer size uses more memory, but can reduce the +// number of system calls to read or write the network. In the case of writing, +// increasing the buffer size can reduce the number of frame headers written to +// the network. +// +// Some guidelines for setting buffer parameters are: +// +// Limit the buffer sizes to the maximum expected message size. Buffers larger +// than the largest message do not provide any benefit. +// +// Depending on the distribution of message sizes, setting the buffer size to +// a value less than the maximum expected message size can greatly reduce memory +// use with a small impact on performance. Here's an example: If 99% of the +// messages are smaller than 256 bytes and the maximum message size is 512 +// bytes, then a buffer size of 256 bytes will result in 1.01 more system calls +// than a buffer size of 512 bytes. The memory savings is 50%. +// +// A write buffer pool is useful when the application has a modest number +// writes over a large number of connections. when buffers are pooled, a larger +// buffer size has a reduced impact on total memory use and has the benefit of +// reducing system calls and frame overhead. +// +// Compression EXPERIMENTAL +// +// Per message compression extensions (RFC 7692) are experimentally supported +// by this package in a limited capacity. Setting the EnableCompression option +// to true in Dialer or Upgrader will attempt to negotiate per message deflate +// support. +// +// var upgrader = websocket.Upgrader{ +// EnableCompression: true, +// } +// +// If compression was successfully negotiated with the connection's peer, any +// message received in compressed form will be automatically decompressed. +// All Read methods will return uncompressed bytes. +// +// Per message compression of messages written to a connection can be enabled +// or disabled by calling the corresponding Conn method: +// +// conn.EnableWriteCompression(false) +// +// Currently this package does not support compression with "context takeover". +// This means that messages must be compressed and decompressed in isolation, +// without retaining sliding window or dictionary state across messages. For +// more details refer to RFC 7692. +// +// Use of compression is experimental and may result in decreased performance. + +// Copyright 2019 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + + +// JoinMessages concatenates received messages to create a single io.Reader. +// The string term is appended to each message. The returned reader does not +// support concurrent calls to the Read method. +func JoinMessages(c *Conn, term string) io.Reader { + return &joinReader{c: c, term: term} +} + +type joinReader struct { + c *Conn + term string + r io.Reader +} + +func (r *joinReader) Read(p []byte) (int, error) { + if r.r == nil { + var err error + _, r.r, err = r.c.NextReader() + if err != nil { + return 0, err + } + if r.term != "" { + r.r = io.MultiReader(r.r, strings.NewReader(r.term)) + } + } + n, err := r.r.Read(p) + if err == io.EOF { + err = nil + r.r = nil + } + return n, err +} +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + + +// WriteJSON writes the JSON encoding of v as a message. +// +// Deprecated: Use c.WriteJSON instead. +func WriteJSON(c *Conn, v interface{}) error { + return c.WriteJSON(v) +} + +// WriteJSON writes the JSON encoding of v as a message. +// +// See the documentation for encoding/json Marshal for details about the +// conversion of Go values to JSON. +func (c *Conn) WriteJSON(v interface{}) error { + w, err := c.NextWriter(TextMessage) + if err != nil { + return err + } + err1 := json.NewEncoder(w).Encode(v) + err2 := w.Close() + if err1 != nil { + return err1 + } + return err2 +} + +// ReadJSON reads the next JSON-encoded message from the connection and stores +// it in the value pointed to by v. +// +// Deprecated: Use c.ReadJSON instead. +func ReadJSON(c *Conn, v interface{}) error { + return c.ReadJSON(v) +} + +// ReadJSON reads the next JSON-encoded message from the connection and stores +// it in the value pointed to by v. +// +// See the documentation for the encoding/json Unmarshal function for details +// about the conversion of JSON to a Go value. +func (c *Conn) ReadJSON(v interface{}) error { + _, r, err := c.NextReader() + if err != nil { + return err + } + err = json.NewDecoder(r).Decode(v) + if err == io.EOF { + // One value is expected in the message. + err = io.ErrUnexpectedEOF + } + return err +} +// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of +// this source code is governed by a BSD-style license that can be found in the +// LICENSE file. + +/// go:build !appengine // +build !appengine + + +const wordSize = int(unsafe.Sizeof(uintptr(0))) + +func maskBytesUnsafe(key [4]byte, pos int, b []byte) int { + // Mask one byte at a time for small buffers. + if len(b) < 2*wordSize { + for i := range b { + b[i] ^= key[pos&3] + pos++ + } + return pos & 3 + } + + // Mask one byte at a time to word boundary. + if n := int(uintptr(unsafe.Pointer(&b[0]))) % wordSize; n != 0 { + n = wordSize - n + for i := range b[:n] { + b[i] ^= key[pos&3] + pos++ + } + b = b[n:] + } + + // Create aligned word size key. + var k [wordSize]byte + for i := range k { + k[i] = key[(pos+i)&3] + } + kw := *(*uintptr)(unsafe.Pointer(&k)) + + // Mask one word at a time. + n := (len(b) / wordSize) * wordSize + for i := 0; i < n; i += wordSize { + *(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&b[0])) + uintptr(i))) ^= kw + } + + // Mask one byte at a time for remaining bytes. + b = b[n:] + for i := range b { + b[i] ^= key[pos&3] + pos++ + } + + return pos & 3 +} +// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of +// this source code is governed by a BSD-style license that can be found in the +// LICENSE file. + +/// go:build appengine // +build appengine + + +func maskBytes(key [4]byte, pos int, b []byte) int { + for i := range b { + b[i] ^= key[pos&3] + pos++ + } + return pos & 3 +} +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + + +// PreparedMessage caches on the wire representations of a message payload. +// Use PreparedMessage to efficiently send a message payload to multiple +// connections. PreparedMessage is especially useful when compression is used +// because the CPU and memory expensive compression operation can be executed +// once for a given set of compression options. +type PreparedMessage struct { + messageType int + data []byte + mu sync.Mutex + frames map[prepareKey]*preparedFrame +} + +// prepareKey defines a unique set of options to cache prepared frames in PreparedMessage. +type prepareKey struct { + isServer bool + compress bool + compressionLevel int +} + +// preparedFrame contains data in wire representation. +type preparedFrame struct { + once sync.Once + data []byte +} + +// NewPreparedMessage returns an initialized PreparedMessage. You can then send +// it to connection using WritePreparedMessage method. Valid wire +// representation will be calculated lazily only once for a set of current +// connection options. +func NewPreparedMessage(messageType int, data []byte) (*PreparedMessage, error) { + pm := &PreparedMessage{ + messageType: messageType, + frames: make(map[prepareKey]*preparedFrame), + data: data, + } + + // Prepare a plain server frame. + _, frameData, err := pm.frame(prepareKey{isServer: true, compress: false}) + if err != nil { + return nil, err + } + + // To protect against caller modifying the data argument, remember the data + // copied to the plain server frame. + pm.data = frameData[len(frameData)-len(data):] + return pm, nil +} + +func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) { + pm.mu.Lock() + frame, ok := pm.frames[key] + if !ok { + frame = &preparedFrame{} + pm.frames[key] = frame + } + pm.mu.Unlock() + + var err error + frame.once.Do(func() { + // Prepare a frame using a 'fake' connection. + // TODO: Refactor code in conn.go to allow more direct construction of + // the frame. + mu := make(chan struct{}, 1) + mu <- struct{}{} + var nc prepareConn + c := &Conn{ + conn: &nc, + mu: mu, + isServer: key.isServer, + compressionLevel: key.compressionLevel, + enableWriteCompression: true, + writeBuf: make([]byte, defaultWriteBufferSize+maxFrameHeaderSize), + } + if key.compress { + c.newCompressionWriter = compressNoContextTakeover + } + err = c.WriteMessage(pm.messageType, pm.data) + frame.data = nc.buf.Bytes() + }) + return pm.messageType, frame.data, err +} + +type prepareConn struct { + buf bytes.Buffer + net.Conn +} + +func (pc *prepareConn) Write(p []byte) (int, error) { return pc.buf.Write(p) } +func (pc *prepareConn) SetWriteDeadline(t time.Time) error { return nil } +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + + +type netDialerFunc func(ctx context.Context, network, addr string) (net.Conn, error) + +func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) { + return fn(context.Background(), network, addr) +} + +func (fn netDialerFunc) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + return fn(ctx, network, addr) +} + +func proxyFromURL(proxyURL *url.URL, forwardDial netDialerFunc) (netDialerFunc, error) { + if proxyURL.Scheme == "http" { + return (&httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDial}).DialContext, nil + } + dialer, err := proxy.FromURL(proxyURL, forwardDial) + if err != nil { + return nil, err + } + if d, ok := dialer.(proxy.ContextDialer); ok { + return d.DialContext, nil + } + return func(ctx context.Context, net, addr string) (net.Conn, error) { + return dialer.Dial(net, addr) + }, nil +} + +type httpProxyDialer struct { + proxyURL *url.URL + forwardDial netDialerFunc +} + +func (hpd *httpProxyDialer) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { + hostPort, _ := hostPortNoPort(hpd.proxyURL) + conn, err := hpd.forwardDial(ctx, network, hostPort) + if err != nil { + return nil, err + } + + connectHeader := make(http.Header) + if user := hpd.proxyURL.User; user != nil { + proxyUser := user.Username() + if proxyPassword, passwordSet := user.Password(); passwordSet { + credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword)) + connectHeader.Set("Proxy-Authorization", "Basic "+credential) + } + } + + connectReq := &http.Request{ + Method: http.MethodConnect, + URL: &url.URL{Opaque: addr}, + Host: addr, + Header: connectHeader, + } + + if err := connectReq.Write(conn); err != nil { + conn.Close() + return nil, err + } + + // Read response. It's OK to use and discard buffered reader here because + // the remote server does not speak until spoken to. + br := bufio.NewReader(conn) + resp, err := http.ReadResponse(br, connectReq) + if err != nil { + conn.Close() + return nil, err + } + + // Close the response body to silence false positives from linters. Reset + // the buffered reader first to ensure that Close() does not read from + // conn. + // Note: Applications must call resp.Body.Close() on a response returned + // http.ReadResponse to inspect trailers or read another response from the + // buffered reader. The call to resp.Body.Close() does not release + // resources. + br.Reset(bytes.NewReader(nil)) + _ = resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + _ = conn.Close() + f := strings.SplitN(resp.Status, " ", 2) + return nil, errors.New(f[1]) + } + return conn, nil +} +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + + +// HandshakeError describes an error with the handshake from the peer. +type HandshakeError struct { + message string +} + +func (e HandshakeError) Error() string { return e.message } + +// Upgrader specifies parameters for upgrading an HTTP connection to a +// WebSocket connection. +// +// It is safe to call Upgrader's methods concurrently. +type Upgrader struct { + // HandshakeTimeout specifies the duration for the handshake to complete. + HandshakeTimeout time.Duration + + // ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer + // size is zero, then buffers allocated by the HTTP server are used. The + // I/O buffer sizes do not limit the size of the messages that can be sent + // or received. + ReadBufferSize, WriteBufferSize int + + // WriteBufferPool is a pool of buffers for write operations. If the value + // is not set, then write buffers are allocated to the connection for the + // lifetime of the connection. + // + // A pool is most useful when the application has a modest volume of writes + // across a large number of connections. + // + // Applications should use a single pool for each unique value of + // WriteBufferSize. + WriteBufferPool BufferPool + + // Subprotocols specifies the server's supported protocols in order of + // preference. If this field is not nil, then the Upgrade method negotiates a + // subprotocol by selecting the first match in this list with a protocol + // requested by the client. If there's no match, then no protocol is + // negotiated (the Sec-Websocket-Protocol header is not included in the + // handshake response). + Subprotocols []string + + // Error specifies the function for generating HTTP error responses. If Error + // is nil, then http.Error is used to generate the HTTP response. + Error func(w http.ResponseWriter, r *http.Request, status int, reason error) + + // CheckOrigin returns true if the request Origin header is acceptable. If + // CheckOrigin is nil, then a safe default is used: return false if the + // Origin request header is present and the origin host is not equal to + // request Host header. + // + // A CheckOrigin function should carefully validate the request origin to + // prevent cross-site request forgery. + CheckOrigin func(r *http.Request) bool + + // EnableCompression specify if the server should attempt to negotiate per + // message compression (RFC 7692). Setting this value to true does not + // guarantee that compression will be supported. Currently only "no context + // takeover" modes are supported. + EnableCompression bool +} + +func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) { + err := HandshakeError{reason} + if u.Error != nil { + u.Error(w, r, status, err) + } else { + w.Header().Set("Sec-Websocket-Version", "13") + http.Error(w, http.StatusText(status), status) + } + return nil, err +} + +// checkSameOrigin returns true if the origin is not set or is equal to the request host. +func checkSameOrigin(r *http.Request) bool { + origin := r.Header["Origin"] + if len(origin) == 0 { + return true + } + u, err := url.Parse(origin[0]) + if err != nil { + return false + } + return equalASCIIFold(u.Host, r.Host) +} + +func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string { + if u.Subprotocols != nil { + clientProtocols := Subprotocols(r) + for _, clientProtocol := range clientProtocols { + for _, serverProtocol := range u.Subprotocols { + if clientProtocol == serverProtocol { + return clientProtocol + } + } + } + } else if responseHeader != nil { + return responseHeader.Get("Sec-Websocket-Protocol") + } + return "" +} + +// Upgrade upgrades the HTTP server connection to the WebSocket protocol. +// +// The responseHeader is included in the response to the client's upgrade +// request. Use the responseHeader to specify cookies (Set-Cookie). To specify +// subprotocols supported by the server, set Upgrader.Subprotocols directly. +// +// If the upgrade fails, then Upgrade replies to the client with an HTTP error +// response. +func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) { + const badHandshake = "websocket: the client is not using the websocket protocol: " + + if !tokenListContainsValue(r.Header, "Connection", "upgrade") { + return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'upgrade' token not found in 'Connection' header") + } + + if !tokenListContainsValue(r.Header, "Upgrade", "websocket") { + w.Header().Set("Upgrade", "websocket") + return u.returnError(w, r, http.StatusUpgradeRequired, badHandshake+"'websocket' token not found in 'Upgrade' header") + } + + if r.Method != http.MethodGet { + return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET") + } + + if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") { + return u.returnError(w, r, http.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header") + } + + if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok { + return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-WebSocket-Extensions' headers are unsupported") + } + + checkOrigin := u.CheckOrigin + if checkOrigin == nil { + checkOrigin = checkSameOrigin + } + if !checkOrigin(r) { + return u.returnError(w, r, http.StatusForbidden, "websocket: request origin not allowed by Upgrader.CheckOrigin") + } + + challengeKey := r.Header.Get("Sec-Websocket-Key") + if !isValidChallengeKey(challengeKey) { + return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header must be Base64 encoded value of 16-byte in length") + } + + subprotocol := u.selectSubprotocol(r, responseHeader) + + // Negotiate PMCE + var compress bool + if u.EnableCompression { + for _, ext := range parseExtensions(r.Header) { + if ext[""] != "permessage-deflate" { + continue + } + compress = true + break + } + } + + netConn, brw, err := http.NewResponseController(w).Hijack() + if err != nil { + return u.returnError(w, r, http.StatusInternalServerError, + "websocket: hijack: "+err.Error()) + } + + // Close the network connection when returning an error. The variable + // netConn is set to nil before the success return at the end of the + // function. + defer func() { + if netConn != nil { + // It's safe to ignore the error from Close() because this code is + // only executed when returning a more important error to the + // application. + _ = netConn.Close() + } + }() + + var br *bufio.Reader + if u.ReadBufferSize == 0 && brw.Reader.Size() > 256 { + // Use hijacked buffered reader as the connection reader. + br = brw.Reader + } else if brw.Reader.Buffered() > 0 { + // Wrap the network connection to read buffered data in brw.Reader + // before reading from the network connection. This should be rare + // because a client must not send message data before receiving the + // handshake response. + netConn = &brNetConn{br: brw.Reader, Conn: netConn} + } + + buf := brw.Writer.AvailableBuffer() + + var writeBuf []byte + if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 { + // Reuse hijacked write buffer as connection buffer. + writeBuf = buf + } + + c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, br, writeBuf) + c.subprotocol = subprotocol + + if compress { + c.newCompressionWriter = compressNoContextTakeover + c.newDecompressionReader = decompressNoContextTakeover + } + + // Use larger of hijacked buffer and connection write buffer for header. + p := buf + if len(c.writeBuf) > len(p) { + p = c.writeBuf + } + p = p[:0] + + p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...) + p = append(p, computeAcceptKey(challengeKey)...) + p = append(p, "\r\n"...) + if c.subprotocol != "" { + p = append(p, "Sec-WebSocket-Protocol: "...) + p = append(p, c.subprotocol...) + p = append(p, "\r\n"...) + } + if compress { + p = append(p, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...) + } + for k, vs := range responseHeader { + if k == "Sec-Websocket-Protocol" { + continue + } + for _, v := range vs { + p = append(p, k...) + p = append(p, ": "...) + for i := 0; i < len(v); i++ { + b := v[i] + if b <= 31 { + // prevent response splitting. + b = ' ' + } + p = append(p, b) + } + p = append(p, "\r\n"...) + } + } + p = append(p, "\r\n"...) + + if u.HandshakeTimeout > 0 { + if err := netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)); err != nil { + return nil, err + } + } else { + // Clear deadlines set by HTTP server. + if err := netConn.SetDeadline(time.Time{}); err != nil { + return nil, err + } + } + + if _, err = netConn.Write(p); err != nil { + return nil, err + } + if u.HandshakeTimeout > 0 { + if err := netConn.SetWriteDeadline(time.Time{}); err != nil { + return nil, err + } + } + + // Success! Set netConn to nil to stop the deferred function above from + // closing the network connection. + netConn = nil + + return c, nil +} + +// Upgrade upgrades the HTTP server connection to the WebSocket protocol. +// +// Deprecated: Use websocket.Upgrader instead. +// +// Upgrade does not perform origin checking. The application is responsible for +// checking the Origin header before calling Upgrade. An example implementation +// of the same origin policy check is: +// +// if req.Header.Get("Origin") != "http://"+req.Host { +// http.Error(w, "Origin not allowed", http.StatusForbidden) +// return +// } +// +// If the endpoint supports subprotocols, then the application is responsible +// for negotiating the protocol used on the connection. Use the Subprotocols() +// function to get the subprotocols requested by the client. Use the +// Sec-Websocket-Protocol response header to specify the subprotocol selected +// by the application. +// +// The responseHeader is included in the response to the client's upgrade +// request. Use the responseHeader to specify cookies (Set-Cookie) and the +// negotiated subprotocol (Sec-Websocket-Protocol). +// +// The connection buffers IO to the underlying network connection. The +// readBufSize and writeBufSize parameters specify the size of the buffers to +// use. Messages can be larger than the buffers. +// +// If the request is not a valid WebSocket handshake, then Upgrade returns an +// error of type HandshakeError. Applications should handle this error by +// replying to the client with an HTTP error response. +func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header, readBufSize, writeBufSize int) (*Conn, error) { + u := Upgrader{ReadBufferSize: readBufSize, WriteBufferSize: writeBufSize} + u.Error = func(w http.ResponseWriter, r *http.Request, status int, reason error) { + // don't return errors to maintain backwards compatibility + } + u.CheckOrigin = func(r *http.Request) bool { + // allow all connections by default + return true + } + return u.Upgrade(w, r, responseHeader) +} + +// Subprotocols returns the subprotocols requested by the client in the +// Sec-Websocket-Protocol header. +func Subprotocols(r *http.Request) []string { + h := strings.TrimSpace(r.Header.Get("Sec-Websocket-Protocol")) + if h == "" { + return nil + } + protocols := strings.Split(h, ",") + for i := range protocols { + protocols[i] = strings.TrimSpace(protocols[i]) + } + return protocols +} + +// IsWebSocketUpgrade returns true if the client requested upgrade to the +// WebSocket protocol. +func IsWebSocketUpgrade(r *http.Request) bool { + return tokenListContainsValue(r.Header, "Connection", "upgrade") && + tokenListContainsValue(r.Header, "Upgrade", "websocket") +} + +type brNetConn struct { + br *bufio.Reader + net.Conn +} + +func (b *brNetConn) Read(p []byte) (n int, err error) { + if b.br != nil { + // Limit read to buferred data. + if n := b.br.Buffered(); len(p) > n { + p = p[:n] + } + n, err = b.br.Read(p) + if b.br.Buffered() == 0 { + b.br = nil + } + return n, err + } + return b.Conn.Read(p) +} + +// NetConn returns the underlying connection that is wrapped by b. +func (b *brNetConn) NetConn() net.Conn { + return b.Conn +} + +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + + +var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") + +func computeAcceptKey(challengeKey string) string { + h := sha1.New() + h.Write([]byte(challengeKey)) + h.Write(keyGUID) + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} + +func generateChallengeKey() (string, error) { + p := make([]byte, 16) + if _, err := io.ReadFull(rand.Reader, p); err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(p), nil +} + +// Token octets per RFC 2616. +var isTokenOctet = [256]bool{ + '!': true, + '#': true, + '$': true, + '%': true, + '&': true, + '\'': true, + '*': true, + '+': true, + '-': true, + '.': true, + '0': true, + '1': true, + '2': true, + '3': true, + '4': true, + '5': true, + '6': true, + '7': true, + '8': true, + '9': true, + 'A': true, + 'B': true, + 'C': true, + 'D': true, + 'E': true, + 'F': true, + 'G': true, + 'H': true, + 'I': true, + 'J': true, + 'K': true, + 'L': true, + 'M': true, + 'N': true, + 'O': true, + 'P': true, + 'Q': true, + 'R': true, + 'S': true, + 'T': true, + 'U': true, + 'W': true, + 'V': true, + 'X': true, + 'Y': true, + 'Z': true, + '^': true, + '_': true, + '`': true, + 'a': true, + 'b': true, + 'c': true, + 'd': true, + 'e': true, + 'f': true, + 'g': true, + 'h': true, + 'i': true, + 'j': true, + 'k': true, + 'l': true, + 'm': true, + 'n': true, + 'o': true, + 'p': true, + 'q': true, + 'r': true, + 's': true, + 't': true, + 'u': true, + 'v': true, + 'w': true, + 'x': true, + 'y': true, + 'z': true, + '|': true, + '~': true, +} + +// skipSpace returns a slice of the string s with all leading RFC 2616 linear +// whitespace removed. +func skipSpace(s string) (rest string) { + i := 0 + for ; i < len(s); i++ { + if b := s[i]; b != ' ' && b != '\t' { + break + } + } + return s[i:] +} + +// nextToken returns the leading RFC 2616 token of s and the string following +// the token. +func nextToken(s string) (token, rest string) { + i := 0 + for ; i < len(s); i++ { + if !isTokenOctet[s[i]] { + break + } + } + return s[:i], s[i:] +} + +// nextTokenOrQuoted returns the leading token or quoted string per RFC 2616 +// and the string following the token or quoted string. +func nextTokenOrQuoted(s string) (value string, rest string) { + if !strings.HasPrefix(s, "\"") { + return nextToken(s) + } + s = s[1:] + for i := 0; i < len(s); i++ { + switch s[i] { + case '"': + return s[:i], s[i+1:] + case '\\': + p := make([]byte, len(s)-1) + j := copy(p, s[:i]) + escape := true + for i = i + 1; i < len(s); i++ { + b := s[i] + switch { + case escape: + escape = false + p[j] = b + j++ + case b == '\\': + escape = true + case b == '"': + return string(p[:j]), s[i+1:] + default: + p[j] = b + j++ + } + } + return "", "" + } + } + return "", "" +} + +// equalASCIIFold returns true if s is equal to t with ASCII case folding as +// defined in RFC 4790. +func equalASCIIFold(s, t string) bool { + for s != "" && t != "" { + sr, size := utf8.DecodeRuneInString(s) + s = s[size:] + tr, size := utf8.DecodeRuneInString(t) + t = t[size:] + if sr == tr { + continue + } + if 'A' <= sr && sr <= 'Z' { + sr = sr + 'a' - 'A' + } + if 'A' <= tr && tr <= 'Z' { + tr = tr + 'a' - 'A' + } + if sr != tr { + return false + } + } + return s == t +} + +// tokenListContainsValue returns true if the 1#token header with the given +// name contains a token equal to value with ASCII case folding. +func tokenListContainsValue(header http.Header, name string, value string) bool { +headers: + for _, s := range header[name] { + for { + var t string + t, s = nextToken(skipSpace(s)) + if t == "" { + continue headers + } + s = skipSpace(s) + if s != "" && s[0] != ',' { + continue headers + } + if equalASCIIFold(t, value) { + return true + } + if s == "" { + continue headers + } + s = s[1:] + } + } + return false +} + +// parseExtensions parses WebSocket extensions from a header. +func parseExtensions(header http.Header) []map[string]string { + // From RFC 6455: + // + // Sec-WebSocket-Extensions = extension-list + // extension-list = 1#extension + // extension = extension-token *( ";" extension-param ) + // extension-token = registered-token + // registered-token = token + // extension-param = token [ "=" (token | quoted-string) ] + // ;When using the quoted-string syntax variant, the value + // ;after quoted-string unescaping MUST conform to the + // ;'token' ABNF. + + var result []map[string]string +headers: + for _, s := range header["Sec-Websocket-Extensions"] { + for { + var t string + t, s = nextToken(skipSpace(s)) + if t == "" { + continue headers + } + ext := map[string]string{"": t} + for { + s = skipSpace(s) + if !strings.HasPrefix(s, ";") { + break + } + var k string + k, s = nextToken(skipSpace(s[1:])) + if k == "" { + continue headers + } + s = skipSpace(s) + var v string + if strings.HasPrefix(s, "=") { + v, s = nextTokenOrQuoted(skipSpace(s[1:])) + s = skipSpace(s) + } + if s != "" && s[0] != ',' && s[0] != ';' { + continue headers + } + ext[k] = v + } + if s != "" && s[0] != ',' { + continue headers + } + result = append(result, ext) + if s == "" { + continue headers + } + s = s[1:] + } + } + return result +} + +// isValidChallengeKey checks if the argument meets RFC6455 specification. +func isValidChallengeKey(s string) bool { + // From RFC6455: + // + // A |Sec-WebSocket-Key| header field with a base64-encoded (see + // Section 4 of [RFC4648]) value that, when decoded, is 16 bytes in + // length. + + if s == "" { + return false + } + decoded, err := base64.StdEncoding.DecodeString(s) + return err == nil && len(decoded) == 16 +} + +type CLIArgs struct { + FromAddr string + ToAddr string +} + + + +const X = 1 + +var EmitActiveConnection = g.MakeGauge("active-connections") + + + +func ParseArgs(args []string) CLIArgs { + if len(args) != 3 { + fmt.Fprintf( + os.Stderr, + "Usage: %s FROM.socket TO.socket\n", + args[0], + ) + os.Exit(2) + } + return CLIArgs { + FromAddr: args[1], + ToAddr: args[2], + } +} + +func Listen(fromAddr string) net.Listener { + listener, err := net.Listen("unix", fromAddr) + g.FatalIf(err) + g.Info("Started listening", "listen-start", "from-address", fromAddr) + return listener +} + +func copyData(c chan struct {}, from io.Reader, to io.WriteCloser) { + io.Copy(to, from) + c <- struct {} {} +} + +func Start(toAddr string, listener net.Listener) { + /* + upgrader := websocket.Upgrader {} + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + connFrom, err := upgrader.Upgrade(w, r, nil) + if err != nil { + g.Warning( + "Error upgrading connection", + "upgrade-connection-error", + "err", err, + ) + return + } + defer connFrom.Close() + EmitActiveConnection.Inc() + + connTo, err := net.Dial("unix", toAddr) + if err != nil { + g.Error( + "Error dialing connection", + "dial-connection", + "err", err, + ) + os.Exit(1) + } + defer connTo.Close() + + messageType, reader, err := connFrom.NextReader() + if err != nil { + g.Warning( + "Failed to get next reader from connection", + "connection-next-reader-error", + "err", err, + ) + return + } + + writer, err := connFrom.NextWriter(messageType) + if err != nil { + g.Warning( + "Failed to get next reader from connection", + "connection-next-reader-error", + "err", err, + ) + return + } + + + c := make(chan struct {}) + go copyData(c, connTo, writer) + go copyData(c, reader, connTo) + go func() { + <- c + EmitActiveConnection.Dec() + }() + }); + + server := http.Server{} + err := server.Serve(listener) + g.FatalIf(err) + */ +} + + +func Main() { + g.Init() + args := ParseArgs(os.Args) + listener := Listen(args.FromAddr) + Start(args.ToAddr, listener) +} diff --git a/tests/lib_test.go b/tests/lib_test.go deleted file mode 100644 index 1007609..0000000 --- a/tests/lib_test.go +++ /dev/null @@ -1,24 +0,0 @@ -package wscat_test - -import ( - "testing" - - g "euandre.org/gobang/src" - - "euandre.org/wscat/src" -) - - -func TestPlaceholder(t *testing.T) { - g.AssertEqual(t, wscat.X, 1) -} - -func TestParseArgs(t *testing.T) { - given := wscat.ParseArgs([]string { "x", "y", "z" }) - expected := wscat.CLIArgs { - FromAddr: "y", - ToAddr: "z", - } - - g.AssertEqual(t, given, expected) -} diff --git a/tests/main.go b/tests/main.go new file mode 100644 index 0000000..8123486 --- /dev/null +++ b/tests/main.go @@ -0,0 +1,7 @@ +package main + +import "wscat" + +func main() { + wscat.MainTest() +} diff --git a/tests/wscat.go b/tests/wscat.go new file mode 100644 index 0000000..da16f66 --- /dev/null +++ b/tests/wscat.go @@ -0,0 +1,2842 @@ +package wscat + +import ( + "bufio" + "bytes" + "compress/flate" + "context" + "crypto/tls" + "crypto/x509" + "encoding/base64" + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "math/rand" + "net" + "net/http" + "net/http/cookiejar" + "net/http/httptest" + "net/http/httptrace" + "net/url" + "os" + "reflect" + "strings" + "sync" + "sync/atomic" + "testing" + "testing/internal/testdeps" + "testing/iotest" + "time" + + g "gobang" +) + + + +var cstUpgrader = Upgrader{ + Subprotocols: []string{"p0", "p1"}, + ReadBufferSize: 1024, + WriteBufferSize: 1024, + EnableCompression: true, + Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) { + http.Error(w, reason.Error(), status) + }, +} + +var cstDialer = Dialer{ + Subprotocols: []string{"p1", "p2"}, + ReadBufferSize: 1024, + WriteBufferSize: 1024, + HandshakeTimeout: 30 * time.Second, +} + +type cstHandler struct { + *testing.T + s *cstServer +} + +type cstServer struct { + URL string + Server *httptest.Server + wg sync.WaitGroup +} + +const ( + cstPath = "/a/b" + cstRawQuery = "x=y" + cstRequestURI = cstPath + "?" + cstRawQuery +) + +func (s *cstServer) Close() { + s.Server.Close() + // Wait for handler functions to complete. + s.wg.Wait() +} + +func newServer(t *testing.T) *cstServer { + var s cstServer + s.Server = httptest.NewServer(cstHandler{T: t, s: &s}) + s.Server.URL += cstRequestURI + s.URL = makeWsProto(s.Server.URL) + return &s +} + +func newTLSServer(t *testing.T) *cstServer { + var s cstServer + s.Server = httptest.NewTLSServer(cstHandler{T: t, s: &s}) + s.Server.URL += cstRequestURI + s.URL = makeWsProto(s.Server.URL) + return &s +} + +func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Because tests wait for a response from a server, we are guaranteed that + // the wait group count is incremented before the test waits on the group + // in the call to (*cstServer).Close(). + t.s.wg.Add(1) + defer t.s.wg.Done() + + if r.URL.Path != cstPath { + t.Logf("path=%v, want %v", r.URL.Path, cstPath) + http.Error(w, "bad path", http.StatusBadRequest) + return + } + if r.URL.RawQuery != cstRawQuery { + t.Logf("query=%v, want %v", r.URL.RawQuery, cstRawQuery) + http.Error(w, "bad path", http.StatusBadRequest) + return + } + subprotos := Subprotocols(r) + if !reflect.DeepEqual(subprotos, cstDialer.Subprotocols) { + t.Logf("subprotols=%v, want %v", subprotos, cstDialer.Subprotocols) + http.Error(w, "bad protocol", http.StatusBadRequest) + return + } + ws, err := cstUpgrader.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}}) + if err != nil { + t.Logf("Upgrade: %v", err) + return + } + defer ws.Close() + + if ws.Subprotocol() != "p1" { + t.Logf("Subprotocol() = %s, want p1", ws.Subprotocol()) + ws.Close() + return + } + op, rd, err := ws.NextReader() + if err != nil { + t.Logf("NextReader: %v", err) + return + } + wr, err := ws.NextWriter(op) + if err != nil { + t.Logf("NextWriter: %v", err) + return + } + if _, err = io.Copy(wr, rd); err != nil { + t.Logf("NextWriter: %v", err) + return + } + if err := wr.Close(); err != nil { + t.Logf("Close: %v", err) + return + } +} + +func makeWsProto(s string) string { + return "ws" + strings.TrimPrefix(s, "http") +} + +func sendRecv(t *testing.T, ws *Conn) { + const message = "Hello World!" + if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil { + t.Fatalf("SetWriteDeadline: %v", err) + } + if err := ws.WriteMessage(TextMessage, []byte(message)); err != nil { + t.Fatalf("WriteMessage: %v", err) + } + if err := ws.SetReadDeadline(time.Now().Add(time.Second)); err != nil { + t.Fatalf("SetReadDeadline: %v", err) + } + _, p, err := ws.ReadMessage() + if err != nil { + t.Fatalf("ReadMessage: %v", err) + } + if string(p) != message { + t.Fatalf("message=%s, want %s", p, message) + } +} + +func TestProxyDial(t *testing.T) { + + s := newServer(t) + defer s.Close() + + surl, _ := url.Parse(s.Server.URL) + + cstDialer := cstDialer // make local copy for modification on next line. + cstDialer.Proxy = http.ProxyURL(surl) + + connect := false + origHandler := s.Server.Config.Handler + + // Capture the request Host header. + s.Server.Config.Handler = http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodConnect { + connect = true + w.WriteHeader(http.StatusOK) + return + } + + if !connect { + t.Log("connect not received") + http.Error(w, "connect not received", http.StatusMethodNotAllowed) + return + } + origHandler.ServeHTTP(w, r) + }) + + ws, _, err := cstDialer.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + sendRecv(t, ws) +} + +func TestProxyAuthorizationDial(t *testing.T) { + s := newServer(t) + defer s.Close() + + surl, _ := url.Parse(s.Server.URL) + surl.User = url.UserPassword("username", "password") + + cstDialer := cstDialer // make local copy for modification on next line. + cstDialer.Proxy = http.ProxyURL(surl) + + connect := false + origHandler := s.Server.Config.Handler + + // Capture the request Host header. + s.Server.Config.Handler = http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + proxyAuth := r.Header.Get("Proxy-Authorization") + expectedProxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("username:password")) + if r.Method == http.MethodConnect && proxyAuth == expectedProxyAuth { + connect = true + w.WriteHeader(http.StatusOK) + return + } + + if !connect { + t.Log("connect with proxy authorization not received") + http.Error(w, "connect with proxy authorization not received", http.StatusMethodNotAllowed) + return + } + origHandler.ServeHTTP(w, r) + }) + + ws, _, err := cstDialer.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + sendRecv(t, ws) +} + +func TestDial(t *testing.T) { + s := newServer(t) + defer s.Close() + + ws, _, err := cstDialer.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + sendRecv(t, ws) +} + +func TestDialCookieJar(t *testing.T) { + s := newServer(t) + defer s.Close() + + jar, _ := cookiejar.New(nil) + d := cstDialer + d.Jar = jar + + u, _ := url.Parse(s.URL) + + switch u.Scheme { + case "ws": + u.Scheme = "http" + case "wss": + u.Scheme = "https" + } + + cookies := []*http.Cookie{{Name: "gorilla", Value: "ws", Path: "/"}} + d.Jar.SetCookies(u, cookies) + + ws, _, err := d.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + + var gorilla string + var sessionID string + for _, c := range d.Jar.Cookies(u) { + if c.Name == "gorilla" { + gorilla = c.Value + } + + if c.Name == "sessionID" { + sessionID = c.Value + } + } + if gorilla != "ws" { + t.Error("Cookie not present in jar.") + } + + if sessionID != "1234" { + t.Error("Set-Cookie not received from the server.") + } + + sendRecv(t, ws) +} + +func rootCAs(t *testing.T, s *httptest.Server) *x509.CertPool { + certs := x509.NewCertPool() + for _, c := range s.TLS.Certificates { + roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1]) + if err != nil { + t.Fatalf("error parsing server's root cert: %v", err) + } + for _, root := range roots { + certs.AddCert(root) + } + } + return certs +} + +func TestDialTLS(t *testing.T) { + s := newTLSServer(t) + defer s.Close() + + d := cstDialer + d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)} + ws, _, err := d.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + sendRecv(t, ws) +} + +func TestDialTimeout(t *testing.T) { + s := newServer(t) + defer s.Close() + + d := cstDialer + d.HandshakeTimeout = -1 + ws, _, err := d.Dial(s.URL, nil) + if err == nil { + ws.Close() + t.Fatalf("Dial: nil") + } +} + +// requireDeadlineNetConn fails the current test when Read or Write are called +// with no deadline. +type requireDeadlineNetConn struct { + t *testing.T + c net.Conn + readDeadlineIsSet bool + writeDeadlineIsSet bool +} + +func (c *requireDeadlineNetConn) SetDeadline(t time.Time) error { + c.writeDeadlineIsSet = !t.Equal(time.Time{}) + c.readDeadlineIsSet = c.writeDeadlineIsSet + return c.c.SetDeadline(t) +} + +func (c *requireDeadlineNetConn) SetReadDeadline(t time.Time) error { + c.readDeadlineIsSet = !t.Equal(time.Time{}) + return c.c.SetDeadline(t) +} + +func (c *requireDeadlineNetConn) SetWriteDeadline(t time.Time) error { + c.writeDeadlineIsSet = !t.Equal(time.Time{}) + return c.c.SetDeadline(t) +} + +func (c *requireDeadlineNetConn) Write(p []byte) (int, error) { + if !c.writeDeadlineIsSet { + c.t.Fatalf("write with no deadline") + } + return c.c.Write(p) +} + +func (c *requireDeadlineNetConn) Read(p []byte) (int, error) { + if !c.readDeadlineIsSet { + c.t.Fatalf("read with no deadline") + } + return c.c.Read(p) +} + +func (c *requireDeadlineNetConn) Close() error { return c.c.Close() } +func (c *requireDeadlineNetConn) LocalAddr() net.Addr { return c.c.LocalAddr() } +func (c *requireDeadlineNetConn) RemoteAddr() net.Addr { return c.c.RemoteAddr() } + +func TestHandshakeTimeout(t *testing.T) { + s := newServer(t) + defer s.Close() + + d := cstDialer + d.NetDial = func(n, a string) (net.Conn, error) { + c, err := net.Dial(n, a) + return &requireDeadlineNetConn{c: c, t: t}, err + } + ws, _, err := d.Dial(s.URL, nil) + if err != nil { + t.Fatal("Dial:", err) + } + ws.Close() +} + +func TestHandshakeTimeoutInContext(t *testing.T) { + s := newServer(t) + defer s.Close() + + d := cstDialer + d.HandshakeTimeout = 0 + d.NetDialContext = func(ctx context.Context, n, a string) (net.Conn, error) { + netDialer := &net.Dialer{} + c, err := netDialer.DialContext(ctx, n, a) + return &requireDeadlineNetConn{c: c, t: t}, err + } + + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(30*time.Second)) + defer cancel() + ws, _, err := d.DialContext(ctx, s.URL, nil) + if err != nil { + t.Fatal("Dial:", err) + } + ws.Close() +} + +func TestDialBadScheme(t *testing.T) { + s := newServer(t) + defer s.Close() + + ws, _, err := cstDialer.Dial(s.Server.URL, nil) + if err == nil { + ws.Close() + t.Fatalf("Dial: nil") + } +} + +func TestDialBadOrigin(t *testing.T) { + s := newServer(t) + defer s.Close() + + ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}}) + if err == nil { + ws.Close() + t.Fatalf("Dial: nil") + } + if resp == nil { + t.Fatalf("resp=nil, err=%v", err) + } + if resp.StatusCode != http.StatusForbidden { + t.Fatalf("status=%d, want %d", resp.StatusCode, http.StatusForbidden) + } +} + +func TestDialBadHeader(t *testing.T) { + s := newServer(t) + defer s.Close() + + for _, k := range []string{"Upgrade", + "Connection", + "Sec-Websocket-Key", + "Sec-Websocket-Version", + "Sec-Websocket-Protocol"} { + h := http.Header{} + h.Set(k, "bad") + ws, _, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}}) + if err == nil { + ws.Close() + t.Errorf("Dial with header %s returned nil", k) + } + } +} + +func TestBadMethod(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ws, err := cstUpgrader.Upgrade(w, r, nil) + if err == nil { + t.Errorf("handshake succeeded, expect fail") + ws.Close() + } + })) + defer s.Close() + + req, err := http.NewRequest(http.MethodPost, s.URL, strings.NewReader("")) + if err != nil { + t.Fatalf("NewRequest returned error %v", err) + } + req.Header.Set("Connection", "upgrade") + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Sec-Websocket-Version", "13") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Do returned error %v", err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusMethodNotAllowed { + t.Errorf("Status = %d, want %d", resp.StatusCode, http.StatusMethodNotAllowed) + } +} + +func TestNoUpgrade(t *testing.T) { + t.Parallel() + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ws, err := cstUpgrader.Upgrade(w, r, nil) + if err == nil { + t.Errorf("handshake succeeded, expect fail") + ws.Close() + } + })) + defer s.Close() + + req, err := http.NewRequest(http.MethodGet, s.URL, strings.NewReader("")) + if err != nil { + t.Fatalf("NewRequest returned error %v", err) + } + req.Header.Set("Connection", "upgrade") + req.Header.Set("Sec-Websocket-Version", "13") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Do returned error %v", err) + } + resp.Body.Close() + if u := resp.Header.Get("Upgrade"); u != "websocket" { + t.Errorf("Upgrade response header is %q, want %q", u, "websocket") + } + if resp.StatusCode != http.StatusUpgradeRequired { + t.Errorf("Status = %d, want %d", resp.StatusCode, http.StatusUpgradeRequired) + } +} + +func TestDialExtraTokensInRespHeaders(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + challengeKey := r.Header.Get("Sec-Websocket-Key") + w.Header().Set("Upgrade", "foo, websocket") + w.Header().Set("Connection", "upgrade, keep-alive") + w.Header().Set("Sec-Websocket-Accept", computeAcceptKey(challengeKey)) + w.WriteHeader(101) + })) + defer s.Close() + + ws, _, err := cstDialer.Dial(makeWsProto(s.URL), nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() +} + +func TestHandshake(t *testing.T) { + s := newServer(t) + defer s.Close() + + ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {s.URL}}) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + + var sessionID string + for _, c := range resp.Cookies() { + if c.Name == "sessionID" { + sessionID = c.Value + } + } + if sessionID != "1234" { + t.Error("Set-Cookie not received from the server.") + } + + if ws.Subprotocol() != "p1" { + t.Errorf("ws.Subprotocol() = %s, want p1", ws.Subprotocol()) + } + sendRecv(t, ws) +} + +func TestRespOnBadHandshake(t *testing.T) { + const expectedStatus = http.StatusGone + const expectedBody = "This is the response body." + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(expectedStatus) + _, _ = io.WriteString(w, expectedBody) + })) + defer s.Close() + + ws, resp, err := cstDialer.Dial(makeWsProto(s.URL), nil) + if err == nil { + ws.Close() + t.Fatalf("Dial: nil") + } + + if resp == nil { + t.Fatalf("resp=nil, err=%v", err) + } + + if resp.StatusCode != expectedStatus { + t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus) + } + + p, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("ReadFull(resp.Body) returned error %v", err) + } + + if string(p) != expectedBody { + t.Errorf("resp.Body=%s, want %s", p, expectedBody) + } +} + +type testLogWriter struct { + t *testing.T +} + +func (w testLogWriter) Write(p []byte) (int, error) { + w.t.Logf("%s", p) + return len(p), nil +} + +// TestHost tests handling of host names and confirms that it matches net/http. +func TestHost(t *testing.T) { + + upgrader := Upgrader{} + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if IsWebSocketUpgrade(r) { + c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}}) + if err != nil { + t.Fatal(err) + } + c.Close() + } else { + w.Header().Set("X-Test-Host", r.Host) + } + }) + + server := httptest.NewServer(handler) + defer server.Close() + + tlsServer := httptest.NewTLSServer(handler) + defer tlsServer.Close() + + addrs := map[*httptest.Server]string{server: server.Listener.Addr().String(), tlsServer: tlsServer.Listener.Addr().String()} + wsProtos := map[*httptest.Server]string{server: "ws://", tlsServer: "wss://"} + httpProtos := map[*httptest.Server]string{server: "http://", tlsServer: "https://"} + + // Avoid log noise from net/http server by logging to testing.T + server.Config.ErrorLog = log.New(testLogWriter{t}, "", 0) + tlsServer.Config.ErrorLog = server.Config.ErrorLog + + cas := rootCAs(t, tlsServer) + + tests := []struct { + fail bool // true if dial / get should fail + server *httptest.Server // server to use + url string // host for request URI + header string // optional request host header + tls string // optional host for tls ServerName + wantAddr string // expected host for dial + wantHeader string // expected request header on server + insecureSkipVerify bool + }{ + { + server: server, + url: addrs[server], + wantAddr: addrs[server], + wantHeader: addrs[server], + }, + { + server: tlsServer, + url: addrs[tlsServer], + wantAddr: addrs[tlsServer], + wantHeader: addrs[tlsServer], + }, + + { + server: server, + url: addrs[server], + header: "badhost.com", + wantAddr: addrs[server], + wantHeader: "badhost.com", + }, + { + server: tlsServer, + url: addrs[tlsServer], + header: "badhost.com", + wantAddr: addrs[tlsServer], + wantHeader: "badhost.com", + }, + + { + server: server, + url: "example.com", + header: "badhost.com", + wantAddr: "example.com:80", + wantHeader: "badhost.com", + }, + { + server: tlsServer, + url: "example.com", + header: "badhost.com", + wantAddr: "example.com:443", + wantHeader: "badhost.com", + }, + + { + server: server, + url: "badhost.com", + header: "example.com", + wantAddr: "badhost.com:80", + wantHeader: "example.com", + }, + { + fail: true, + server: tlsServer, + url: "badhost.com", + header: "example.com", + wantAddr: "badhost.com:443", + }, + { + server: tlsServer, + url: "badhost.com", + insecureSkipVerify: true, + wantAddr: "badhost.com:443", + wantHeader: "badhost.com", + }, + { + server: tlsServer, + url: "badhost.com", + tls: "example.com", + wantAddr: "badhost.com:443", + wantHeader: "badhost.com", + }, + } + + for i, tt := range tests { + + tls := &tls.Config{ + RootCAs: cas, + ServerName: tt.tls, + InsecureSkipVerify: tt.insecureSkipVerify, + } + + var gotAddr string + dialer := Dialer{ + NetDial: func(network, addr string) (net.Conn, error) { + gotAddr = addr + return net.Dial(network, addrs[tt.server]) + }, + TLSClientConfig: tls, + } + + // Test websocket dial + + h := http.Header{} + if tt.header != "" { + h.Set("Host", tt.header) + } + c, resp, err := dialer.Dial(wsProtos[tt.server]+tt.url+"/", h) + if err == nil { + c.Close() + } + + check := func(protos map[*httptest.Server]string) { + name := fmt.Sprintf("%d: %s%s/ header[Host]=%q, tls.ServerName=%q", i+1, protos[tt.server], tt.url, tt.header, tt.tls) + if gotAddr != tt.wantAddr { + t.Errorf("%s: got addr %s, want %s", name, gotAddr, tt.wantAddr) + } + switch { + case tt.fail && err == nil: + t.Errorf("%s: unexpected success", name) + case !tt.fail && err != nil: + t.Errorf("%s: unexpected error %v", name, err) + case !tt.fail && err == nil: + if gotHost := resp.Header.Get("X-Test-Host"); gotHost != tt.wantHeader { + t.Errorf("%s: got host %s, want %s", name, gotHost, tt.wantHeader) + } + } + } + + check(wsProtos) + + // Confirm that net/http has same result + + transport := &http.Transport{ + Dial: dialer.NetDial, + TLSClientConfig: dialer.TLSClientConfig, + } + req, _ := http.NewRequest(http.MethodGet, httpProtos[tt.server]+tt.url+"/", nil) + if tt.header != "" { + req.Host = tt.header + } + client := &http.Client{Transport: transport} + resp, err = client.Do(req) + if err == nil { + resp.Body.Close() + } + transport.CloseIdleConnections() + check(httpProtos) + } +} + +func TestDialCompression(t *testing.T) { + s := newServer(t) + defer s.Close() + + dialer := cstDialer + dialer.EnableCompression = true + ws, _, err := dialer.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + sendRecv(t, ws) +} + +func TestSocksProxyDial(t *testing.T) { + s := newServer(t) + defer s.Close() + + proxyListener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen failed: %v", err) + } + defer proxyListener.Close() + go func() { + c1, err := proxyListener.Accept() + if err != nil { + t.Errorf("proxy accept failed: %v", err) + return + } + defer c1.Close() + + _ = c1.SetDeadline(time.Now().Add(30 * time.Second)) + + buf := make([]byte, 32) + if _, err := io.ReadFull(c1, buf[:3]); err != nil { + t.Errorf("read failed: %v", err) + return + } + if want := []byte{5, 1, 0}; !bytes.Equal(want, buf[:len(want)]) { + t.Errorf("read %x, want %x", buf[:len(want)], want) + } + if _, err := c1.Write([]byte{5, 0}); err != nil { + t.Errorf("write failed: %v", err) + return + } + if _, err := io.ReadFull(c1, buf[:10]); err != nil { + t.Errorf("read failed: %v", err) + return + } + if want := []byte{5, 1, 0, 1}; !bytes.Equal(want, buf[:len(want)]) { + t.Errorf("read %x, want %x", buf[:len(want)], want) + return + } + buf[1] = 0 + if _, err := c1.Write(buf[:10]); err != nil { + t.Errorf("write failed: %v", err) + return + } + + ip := net.IP(buf[4:8]) + port := binary.BigEndian.Uint16(buf[8:10]) + + c2, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: ip, Port: int(port)}) + if err != nil { + t.Errorf("dial failed; %v", err) + return + } + defer c2.Close() + done := make(chan struct{}) + go func() { + _, _ = io.Copy(c1, c2) + close(done) + }() + _, _ = io.Copy(c2, c1) + <-done + }() + + purl, err := url.Parse("socks5://" + proxyListener.Addr().String()) + if err != nil { + t.Fatalf("parse failed: %v", err) + } + + cstDialer := cstDialer // make local copy for modification on next line. + cstDialer.Proxy = http.ProxyURL(purl) + + ws, _, err := cstDialer.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + sendRecv(t, ws) +} + +func TestTracingDialWithContext(t *testing.T) { + + var headersWrote, requestWrote, getConn, gotConn, connectDone, gotFirstResponseByte bool + trace := &httptrace.ClientTrace{ + WroteHeaders: func() { + headersWrote = true + }, + WroteRequest: func(httptrace.WroteRequestInfo) { + requestWrote = true + }, + GetConn: func(hostPort string) { + getConn = true + }, + GotConn: func(info httptrace.GotConnInfo) { + gotConn = true + }, + ConnectDone: func(network, addr string, err error) { + connectDone = true + }, + GotFirstResponseByte: func() { + gotFirstResponseByte = true + }, + } + ctx := httptrace.WithClientTrace(context.Background(), trace) + + s := newTLSServer(t) + defer s.Close() + + d := cstDialer + d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)} + + ws, _, err := d.DialContext(ctx, s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + + if !headersWrote { + t.Fatal("Headers was not written") + } + if !requestWrote { + t.Fatal("Request was not written") + } + if !getConn { + t.Fatal("getConn was not called") + } + if !gotConn { + t.Fatal("gotConn was not called") + } + if !connectDone { + t.Fatal("connectDone was not called") + } + if !gotFirstResponseByte { + t.Fatal("GotFirstResponseByte was not called") + } + + defer ws.Close() + sendRecv(t, ws) +} + +func TestEmptyTracingDialWithContext(t *testing.T) { + + trace := &httptrace.ClientTrace{} + ctx := httptrace.WithClientTrace(context.Background(), trace) + + s := newTLSServer(t) + defer s.Close() + + d := cstDialer + d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)} + + ws, _, err := d.DialContext(ctx, s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + + defer ws.Close() + sendRecv(t, ws) +} + +// TestNetDialConnect tests selection of dial method between NetDial, NetDialContext, NetDialTLS or NetDialTLSContext +func TestNetDialConnect(t *testing.T) { + + upgrader := Upgrader{} + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if IsWebSocketUpgrade(r) { + c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}}) + if err != nil { + t.Fatal(err) + } + c.Close() + } else { + w.Header().Set("X-Test-Host", r.Host) + } + }) + + server := httptest.NewServer(handler) + defer server.Close() + + tlsServer := httptest.NewTLSServer(handler) + defer tlsServer.Close() + + testUrls := map[*httptest.Server]string{ + server: "ws://" + server.Listener.Addr().String() + "/", + tlsServer: "wss://" + tlsServer.Listener.Addr().String() + "/", + } + + cas := rootCAs(t, tlsServer) + tlsConfig := &tls.Config{ + RootCAs: cas, + ServerName: "example.com", + InsecureSkipVerify: false, + } + + tests := []struct { + name string + server *httptest.Server // server to use + netDial func(network, addr string) (net.Conn, error) + netDialContext func(ctx context.Context, network, addr string) (net.Conn, error) + netDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) + tlsClientConfig *tls.Config + }{ + + { + name: "HTTP server, all NetDial* defined, shall use NetDialContext", + server: server, + netDial: func(network, addr string) (net.Conn, error) { + return nil, errors.New("NetDial should not be called") + }, + netDialContext: func(_ context.Context, network, addr string) (net.Conn, error) { + return net.Dial(network, addr) + }, + netDialTLSContext: func(_ context.Context, network, addr string) (net.Conn, error) { + return nil, errors.New("NetDialTLSContext should not be called") + }, + tlsClientConfig: nil, + }, + { + name: "HTTP server, all NetDial* undefined", + server: server, + netDial: nil, + netDialContext: nil, + netDialTLSContext: nil, + tlsClientConfig: nil, + }, + { + name: "HTTP server, NetDialContext undefined, shall fallback to NetDial", + server: server, + netDial: func(network, addr string) (net.Conn, error) { + return net.Dial(network, addr) + }, + netDialContext: nil, + netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, errors.New("NetDialTLSContext should not be called") + }, + tlsClientConfig: nil, + }, + { + name: "HTTPS server, all NetDial* defined, shall use NetDialTLSContext", + server: tlsServer, + netDial: func(network, addr string) (net.Conn, error) { + return nil, errors.New("NetDial should not be called") + }, + netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, errors.New("NetDialContext should not be called") + }, + netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + netConn, err := net.Dial(network, addr) + if err != nil { + return nil, err + } + tlsConn := tls.Client(netConn, tlsConfig) + err = tlsConn.Handshake() + if err != nil { + return nil, err + } + return tlsConn, nil + }, + tlsClientConfig: nil, + }, + { + name: "HTTPS server, NetDialTLSContext undefined, shall fallback to NetDialContext and do handshake", + server: tlsServer, + netDial: func(network, addr string) (net.Conn, error) { + return nil, errors.New("NetDial should not be called") + }, + netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return net.Dial(network, addr) + }, + netDialTLSContext: nil, + tlsClientConfig: tlsConfig, + }, + { + name: "HTTPS server, NetDialTLSContext and NetDialContext undefined, shall fallback to NetDial and do handshake", + server: tlsServer, + netDial: func(network, addr string) (net.Conn, error) { + return net.Dial(network, addr) + }, + netDialContext: nil, + netDialTLSContext: nil, + tlsClientConfig: tlsConfig, + }, + { + name: "HTTPS server, all NetDial* undefined", + server: tlsServer, + netDial: nil, + netDialContext: nil, + netDialTLSContext: nil, + tlsClientConfig: tlsConfig, + }, + { + name: "HTTPS server, all NetDialTLSContext defined, dummy TlsClientConfig defined, shall not do handshake", + server: tlsServer, + netDial: func(network, addr string) (net.Conn, error) { + return nil, errors.New("NetDial should not be called") + }, + netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, errors.New("NetDialContext should not be called") + }, + netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + netConn, err := net.Dial(network, addr) + if err != nil { + return nil, err + } + tlsConn := tls.Client(netConn, tlsConfig) + err = tlsConn.Handshake() + if err != nil { + return nil, err + } + return tlsConn, nil + }, + tlsClientConfig: &tls.Config{ + RootCAs: nil, + ServerName: "badserver.com", + InsecureSkipVerify: false, + }, + }, + } + + for _, tc := range tests { + dialer := Dialer{ + NetDial: tc.netDial, + NetDialContext: tc.netDialContext, + NetDialTLSContext: tc.netDialTLSContext, + TLSClientConfig: tc.tlsClientConfig, + } + + // Test websocket dial + c, _, err := dialer.Dial(testUrls[tc.server], nil) + if err != nil { + t.Errorf("FAILED %s, err: %s", tc.name, err.Error()) + } else { + c.Close() + } + } +} +func TestNextProtos(t *testing.T) { + ts := httptest.NewUnstartedServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), + ) + ts.EnableHTTP2 = true + ts.StartTLS() + defer ts.Close() + + d := Dialer{ + TLSClientConfig: ts.Client().Transport.(*http.Transport).TLSClientConfig, + } + + r, err := ts.Client().Get(ts.URL) + if err != nil { + t.Fatalf("Get: %v", err) + } + r.Body.Close() + + // Asserts that Dialer.TLSClientConfig.NextProtos contains "h2" + // after the Client.Get call from net/http above. + var containsHTTP2 bool = false + for _, proto := range d.TLSClientConfig.NextProtos { + if proto == "h2" { + containsHTTP2 = true + } + } + if !containsHTTP2 { + t.Fatalf("Dialer.TLSClientConfig.NextProtos does not contain \"h2\"") + } + + _, _, err = d.Dial(makeWsProto(ts.URL), nil) + if err == nil { + t.Fatalf("Dial succeeded, expect fail ") + } +} + +type dataBeforeHandshakeResponseWriter struct { + http.ResponseWriter +} + +type dataBeforeHandshakeConnection struct { + net.Conn + io.Reader +} + +func (c *dataBeforeHandshakeConnection) Read(p []byte) (int, error) { + return c.Reader.Read(p) +} + +func (w dataBeforeHandshakeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + // Example single-frame masked text message from section 5.7 of the RFC. + message := []byte{0x81, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d, 0x51, 0x58} + n := len(message) / 2 + + c, rw, err := http.NewResponseController(w.ResponseWriter).Hijack() + if rw != nil { + // Load first part of message into bufio.Reader. If the websocket + // connection reads more than n bytes from the bufio.Reader, then the + // test will fail with an unexpected EOF error. + rw.Reader.Reset(bytes.NewReader(message[:n])) + rw.Reader.Peek(n) + } + if c != nil { + // Inject second part of message before data read from the network connection. + c = &dataBeforeHandshakeConnection{ + Conn: c, + Reader: io.MultiReader(bytes.NewReader(message[n:]), c), + } + } + return c, rw, err +} + +func TestDataReceivedBeforeHandshake(t *testing.T) { + s := newServer(t) + defer s.Close() + + origHandler := s.Server.Config.Handler + s.Server.Config.Handler = http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + origHandler.ServeHTTP(dataBeforeHandshakeResponseWriter{w}, r) + }) + + for _, readBufferSize := range []int{0, 1024} { + t.Run(fmt.Sprintf("ReadBufferSize=%d", readBufferSize), func(t *testing.T) { + dialer := cstDialer + dialer.ReadBufferSize = readBufferSize + ws, _, err := cstDialer.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + _, m, err := ws.ReadMessage() + if err != nil || string(m) != "Hello" { + t.Fatalf("ReadMessage() = %q, %v, want \"Hello\", nil", m, err) + } + }) + } +} +// Copyright 2014 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + + +var hostPortNoPortTests = []struct { + u *url.URL + hostPort, hostNoPort string +}{ + {&url.URL{Scheme: "ws", Host: "example.com"}, "example.com:80", "example.com"}, + {&url.URL{Scheme: "wss", Host: "example.com"}, "example.com:443", "example.com"}, + {&url.URL{Scheme: "ws", Host: "example.com:7777"}, "example.com:7777", "example.com"}, + {&url.URL{Scheme: "wss", Host: "example.com:7777"}, "example.com:7777", "example.com"}, +} + +func TestHostPortNoPort(t *testing.T) { + for _, tt := range hostPortNoPortTests { + hostPort, hostNoPort := hostPortNoPort(tt.u) + if hostPort != tt.hostPort { + t.Errorf("hostPortNoPort(%v) returned hostPort %q, want %q", tt.u, hostPort, tt.hostPort) + } + if hostNoPort != tt.hostNoPort { + t.Errorf("hostPortNoPort(%v) returned hostNoPort %q, want %q", tt.u, hostNoPort, tt.hostNoPort) + } + } +} + +type nopCloser struct{ io.Writer } + +func (nopCloser) Close() error { return nil } + +func TestTruncWriter(t *testing.T) { + const data = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijlkmnopqrstuvwxyz987654321" + for n := 1; n <= 10; n++ { + var b bytes.Buffer + w := &truncWriter{w: nopCloser{&b}} + p := []byte(data) + for len(p) > 0 { + m := len(p) + if m > n { + m = n + } + _, _ = w.Write(p[:m]) + p = p[m:] + } + if b.String() != data[:len(data)-len(w.p)] { + t.Errorf("%d: %q", n, b.String()) + } + } +} + +func textMessages(num int) [][]byte { + messages := make([][]byte, num) + for i := 0; i < num; i++ { + msg := fmt.Sprintf("planet: %d, country: %d, city: %d, street: %d", i, i, i, i) + messages[i] = []byte(msg) + } + return messages +} + +func BenchmarkWriteNoCompression(b *testing.B) { + w := io.Discard + c := newTestConn(nil, w, false) + messages := textMessages(100) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = c.WriteMessage(TextMessage, messages[i%len(messages)]) + } + b.ReportAllocs() +} + +func BenchmarkWriteWithCompression(b *testing.B) { + w := io.Discard + c := newTestConn(nil, w, false) + messages := textMessages(100) + c.enableWriteCompression = true + c.newCompressionWriter = compressNoContextTakeover + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = c.WriteMessage(TextMessage, messages[i%len(messages)]) + } + b.ReportAllocs() +} + +func TestValidCompressionLevel(t *testing.T) { + c := newTestConn(nil, nil, false) + for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} { + if err := c.SetCompressionLevel(level); err == nil { + t.Errorf("no error for level %d", level) + } + } + for _, level := range []int{minCompressionLevel, maxCompressionLevel} { + if err := c.SetCompressionLevel(level); err != nil { + t.Errorf("error for level %d", level) + } + } +} +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + + +// broadcastBench allows to run broadcast benchmarks. +// In every broadcast benchmark we create many connections, then send the same +// message into every connection and wait for all writes complete. This emulates +// an application where many connections listen to the same data - i.e. PUB/SUB +// scenarios with many subscribers in one channel. +type broadcastBench struct { + w io.Writer + closeCh chan struct{} + doneCh chan struct{} + count int32 + conns []*broadcastConn + compression bool + usePrepared bool +} + +type broadcastMessage struct { + payload []byte + prepared *PreparedMessage +} + +type broadcastConn struct { + conn *Conn + msgCh chan *broadcastMessage +} + +func newBroadcastConn(c *Conn) *broadcastConn { + return &broadcastConn{ + conn: c, + msgCh: make(chan *broadcastMessage, 1), + } +} + +func newBroadcastBench(usePrepared, compression bool) *broadcastBench { + bench := &broadcastBench{ + w: io.Discard, + doneCh: make(chan struct{}), + closeCh: make(chan struct{}), + usePrepared: usePrepared, + compression: compression, + } + bench.makeConns(10000) + return bench +} + +func (b *broadcastBench) makeConns(numConns int) { + conns := make([]*broadcastConn, numConns) + + for i := 0; i < numConns; i++ { + c := newTestConn(nil, b.w, true) + if b.compression { + c.enableWriteCompression = true + c.newCompressionWriter = compressNoContextTakeover + } + conns[i] = newBroadcastConn(c) + go func(c *broadcastConn) { + for { + select { + case msg := <-c.msgCh: + if msg.prepared != nil { + _ = c.conn.WritePreparedMessage(msg.prepared) + } else { + _ = c.conn.WriteMessage(TextMessage, msg.payload) + } + val := atomic.AddInt32(&b.count, 1) + if val%int32(numConns) == 0 { + b.doneCh <- struct{}{} + } + case <-b.closeCh: + return + } + } + }(conns[i]) + } + b.conns = conns +} + +func (b *broadcastBench) close() { + close(b.closeCh) +} + +func (b *broadcastBench) broadcastOnce(msg *broadcastMessage) { + for _, c := range b.conns { + c.msgCh <- msg + } + <-b.doneCh +} + +func BenchmarkBroadcast(b *testing.B) { + benchmarks := []struct { + name string + usePrepared bool + compression bool + }{ + {"NoCompression", false, false}, + {"Compression", false, true}, + {"NoCompressionPrepared", true, false}, + {"CompressionPrepared", true, true}, + } + payload := textMessages(1)[0] + for _, bm := range benchmarks { + b.Run(bm.name, func(b *testing.B) { + bench := newBroadcastBench(bm.usePrepared, bm.compression) + defer bench.close() + b.ResetTimer() + for i := 0; i < b.N; i++ { + message := &broadcastMessage{ + payload: payload, + } + if bench.usePrepared { + pm, _ := NewPreparedMessage(TextMessage, message.payload) + message.prepared = pm + } + bench.broadcastOnce(message) + } + b.ReportAllocs() + }) + } +} +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + + +var _ net.Error = errWriteTimeout + +type fakeNetConn struct { + io.Reader + io.Writer +} + +func (c fakeNetConn) Close() error { return nil } +func (c fakeNetConn) LocalAddr() net.Addr { return localAddr } +func (c fakeNetConn) RemoteAddr() net.Addr { return remoteAddr } +func (c fakeNetConn) SetDeadline(t time.Time) error { return nil } +func (c fakeNetConn) SetReadDeadline(t time.Time) error { return nil } +func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil } + +type fakeAddr int + +var ( + localAddr = fakeAddr(1) + remoteAddr = fakeAddr(2) +) + +func (a fakeAddr) Network() string { + return "net" +} + +func (a fakeAddr) String() string { + return "str" +} + +// newTestConn creates a connection backed by a fake network connection using +// default values for buffering. +func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn { + return newConn(fakeNetConn{Reader: r, Writer: w}, isServer, 1024, 1024, nil, nil, nil) +} + +func TestFraming(t *testing.T) { + frameSizes := []int{ + 0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, + // 65536, 65537 + } + var readChunkers = []struct { + name string + f func(io.Reader) io.Reader + }{ + {"half", iotest.HalfReader}, + {"one", iotest.OneByteReader}, + {"asis", func(r io.Reader) io.Reader { return r }}, + } + writeBuf := make([]byte, 65537) + for i := range writeBuf { + writeBuf[i] = byte(i) + } + var writers = []struct { + name string + f func(w io.Writer, n int) (int, error) + }{ + {"iocopy", func(w io.Writer, n int) (int, error) { + nn, err := io.Copy(w, bytes.NewReader(writeBuf[:n])) + return int(nn), err + }}, + {"write", func(w io.Writer, n int) (int, error) { + return w.Write(writeBuf[:n]) + }}, + {"string", func(w io.Writer, n int) (int, error) { + return io.WriteString(w, string(writeBuf[:n])) + }}, + } + + for _, compress := range []bool{false, true} { + for _, isServer := range []bool{true, false} { + for _, chunker := range readChunkers { + + var connBuf bytes.Buffer + wc := newTestConn(nil, &connBuf, isServer) + rc := newTestConn(chunker.f(&connBuf), nil, !isServer) + if compress { + wc.newCompressionWriter = compressNoContextTakeover + rc.newDecompressionReader = decompressNoContextTakeover + } + for _, n := range frameSizes { + for _, writer := range writers { + name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d w:%s", compress, isServer, chunker.name, n, writer.name) + + w, err := wc.NextWriter(TextMessage) + if err != nil { + t.Errorf("%s: wc.NextWriter() returned %v", name, err) + continue + } + nn, err := writer.f(w, n) + if err != nil || nn != n { + t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err) + continue + } + err = w.Close() + if err != nil { + t.Errorf("%s: w.Close() returned %v", name, err) + continue + } + + opCode, r, err := rc.NextReader() + if err != nil || opCode != TextMessage { + t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err) + continue + } + + t.Logf("frame size: %d", n) + rbuf, err := io.ReadAll(r) + if err != nil { + t.Errorf("%s: ReadFull() returned rbuf, %v", name, err) + continue + } + + if len(rbuf) != n { + t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n) + continue + } + + for i, b := range rbuf { + if byte(i) != b { + t.Errorf("%s: bad byte at offset %d", name, i) + break + } + } + } + } + } + } + } +} + +func TestWriteControlDeadline(t *testing.T) { + t.Parallel() + message := []byte("hello") + var connBuf bytes.Buffer + c := newTestConn(nil, &connBuf, true) + if err := c.WriteControl(PongMessage, message, time.Time{}); err != nil { + t.Errorf("WriteControl(..., zero deadline) = %v, want nil", err) + } + if err := c.WriteControl(PongMessage, message, time.Now().Add(time.Second)); err != nil { + t.Errorf("WriteControl(..., future deadline) = %v, want nil", err) + } + if err := c.WriteControl(PongMessage, message, time.Now().Add(-time.Second)); err == nil { + t.Errorf("WriteControl(..., past deadline) = nil, want timeout error") + } +} + +func TestConcurrencyWriteControl(t *testing.T) { + const message = "this is a ping/pong messsage" + loop := 10 + workers := 10 + for i := 0; i < loop; i++ { + var connBuf bytes.Buffer + + wg := sync.WaitGroup{} + wc := newTestConn(nil, &connBuf, true) + + for i := 0; i < workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if err := wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)); err != nil { + t.Errorf("concurrently wc.WriteControl() returned %v", err) + } + }() + } + + wg.Wait() + wc.Close() + } +} + +func TestControl(t *testing.T) { + const message = "this is a ping/pong message" + for _, isServer := range []bool{true, false} { + for _, isWriteControl := range []bool{true, false} { + name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl) + var connBuf bytes.Buffer + wc := newTestConn(nil, &connBuf, isServer) + rc := newTestConn(&connBuf, nil, !isServer) + if isWriteControl { + _ = wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)) + } else { + w, err := wc.NextWriter(PongMessage) + if err != nil { + t.Errorf("%s: wc.NextWriter() returned %v", name, err) + continue + } + if _, err := w.Write([]byte(message)); err != nil { + t.Errorf("%s: w.Write() returned %v", name, err) + continue + } + if err := w.Close(); err != nil { + t.Errorf("%s: w.Close() returned %v", name, err) + continue + } + var actualMessage string + rc.SetPongHandler(func(s string) error { actualMessage = s; return nil }) + _, _, _ = rc.NextReader() + if actualMessage != message { + t.Errorf("%s: pong=%q, want %q", name, actualMessage, message) + continue + } + } + } + } +} + +// simpleBufferPool is an implementation of BufferPool for TestWriteBufferPool. +type simpleBufferPool struct { + v interface{} +} + +func (p *simpleBufferPool) Get() interface{} { + v := p.v + p.v = nil + return v +} + +func (p *simpleBufferPool) Put(v interface{}) { + p.v = v +} + +func TestWriteBufferPool(t *testing.T) { + const message = "Now is the time for all good people to come to the aid of the party." + + var buf bytes.Buffer + var pool simpleBufferPool + rc := newTestConn(&buf, nil, false) + + // Specify writeBufferSize smaller than message size to ensure that pooling + // works with fragmented messages. + wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, len(message)-1, &pool, nil, nil) + + if wc.writeBuf != nil { + t.Fatal("writeBuf not nil after create") + } + + // Part 1: test NextWriter/Write/Close + + w, err := wc.NextWriter(TextMessage) + if err != nil { + t.Fatalf("wc.NextWriter() returned %v", err) + } + + if wc.writeBuf == nil { + t.Fatal("writeBuf is nil after NextWriter") + } + + writeBufAddr := &wc.writeBuf[0] + + if _, err := io.WriteString(w, message); err != nil { + t.Fatalf("io.WriteString(w, message) returned %v", err) + } + + if err := w.Close(); err != nil { + t.Fatalf("w.Close() returned %v", err) + } + + if wc.writeBuf != nil { + t.Fatal("writeBuf not nil after w.Close()") + } + + if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { + t.Fatal("writeBuf not returned to pool") + } + + opCode, p, err := rc.ReadMessage() + if opCode != TextMessage || err != nil { + t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) + } + + if s := string(p); s != message { + t.Fatalf("message is %s, want %s", s, message) + } + + // Part 2: Test WriteMessage. + + if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil { + t.Fatalf("wc.WriteMessage() returned %v", err) + } + + if wc.writeBuf != nil { + t.Fatal("writeBuf not nil after wc.WriteMessage()") + } + + if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { + t.Fatal("writeBuf not returned to pool after WriteMessage") + } + + opCode, p, err = rc.ReadMessage() + if opCode != TextMessage || err != nil { + t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) + } + + if s := string(p); s != message { + t.Fatalf("message is %s, want %s", s, message) + } +} + +// TestWriteBufferPoolSync ensures that *sync.Pool works as a buffer pool. +func TestWriteBufferPoolSync(t *testing.T) { + var buf bytes.Buffer + var pool sync.Pool + wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil) + rc := newTestConn(&buf, nil, false) + + const message = "Hello World!" + for i := 0; i < 3; i++ { + if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil { + t.Fatalf("wc.WriteMessage() returned %v", err) + } + opCode, p, err := rc.ReadMessage() + if opCode != TextMessage || err != nil { + t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) + } + if s := string(p); s != message { + t.Fatalf("message is %s, want %s", s, message) + } + } +} + +// errorWriter is an io.Writer than returns an error on all writes. +type errorWriter struct{} + +func (ew errorWriter) Write(p []byte) (int, error) { return 0, errors.New("error") } + +// TestWriteBufferPoolError ensures that buffer is returned to pool after error +// on write. +func TestWriteBufferPoolError(t *testing.T) { + + // Part 1: Test NextWriter/Write/Close + + var pool simpleBufferPool + wc := newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil) + + w, err := wc.NextWriter(TextMessage) + if err != nil { + t.Fatalf("wc.NextWriter() returned %v", err) + } + + if wc.writeBuf == nil { + t.Fatal("writeBuf is nil after NextWriter") + } + + writeBufAddr := &wc.writeBuf[0] + + if _, err := io.WriteString(w, "Hello"); err != nil { + t.Fatalf("io.WriteString(w, message) returned %v", err) + } + + if err := w.Close(); err == nil { + t.Fatalf("w.Close() did not return error") + } + + if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { + t.Fatal("writeBuf not returned to pool") + } + + // Part 2: Test WriteMessage + + wc = newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil) + + if err := wc.WriteMessage(TextMessage, []byte("Hello")); err == nil { + t.Fatalf("wc.WriteMessage did not return error") + } + + if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { + t.Fatal("writeBuf not returned to pool") + } +} + +func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) { + const bufSize = 512 + + expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"} + + var b1, b2 bytes.Buffer + wc := newConn(&fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize, nil, nil, nil) + rc := newTestConn(&b1, &b2, true) + + w, _ := wc.NextWriter(BinaryMessage) + _, _ = w.Write(make([]byte, bufSize+bufSize/2)) + _ = wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second)) + w.Close() + + op, r, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("NextReader() returned %d, %v", op, err) + } + _, err = io.Copy(io.Discard, r) + if !reflect.DeepEqual(err, expectedErr) { + t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr) + } + _, _, err = rc.NextReader() + if !reflect.DeepEqual(err, expectedErr) { + t.Fatalf("NextReader() returned %v, want %v", err, expectedErr) + } +} + +func TestEOFWithinFrame(t *testing.T) { + const bufSize = 64 + + for n := 0; ; n++ { + var b bytes.Buffer + wc := newTestConn(nil, &b, false) + rc := newTestConn(&b, nil, true) + + w, _ := wc.NextWriter(BinaryMessage) + _, _ = w.Write(make([]byte, bufSize)) + w.Close() + + if n >= b.Len() { + break + } + b.Truncate(n) + + op, r, err := rc.NextReader() + if err == errUnexpectedEOF { + continue + } + if op != BinaryMessage || err != nil { + t.Fatalf("%d: NextReader() returned %d, %v", n, op, err) + } + _, err = io.Copy(io.Discard, r) + if err != errUnexpectedEOF { + t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF) + } + _, _, err = rc.NextReader() + if err != errUnexpectedEOF { + t.Fatalf("%d: NextReader() returned %v, want %v", n, err, errUnexpectedEOF) + } + } +} + +func TestEOFBeforeFinalFrame(t *testing.T) { + const bufSize = 512 + + var b1, b2 bytes.Buffer + wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, bufSize, nil, nil, nil) + rc := newTestConn(&b1, &b2, true) + + w, _ := wc.NextWriter(BinaryMessage) + _, _ = w.Write(make([]byte, bufSize+bufSize/2)) + + op, r, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("NextReader() returned %d, %v", op, err) + } + _, err = io.Copy(io.Discard, r) + if err != errUnexpectedEOF { + t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF) + } + _, _, err = rc.NextReader() + if err != errUnexpectedEOF { + t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF) + } +} + +func TestWriteAfterMessageWriterClose(t *testing.T) { + wc := newTestConn(nil, &bytes.Buffer{}, false) + w, _ := wc.NextWriter(BinaryMessage) + _, _ = io.WriteString(w, "hello") + if err := w.Close(); err != nil { + t.Fatalf("unexpected error closing message writer, %v", err) + } + + if _, err := io.WriteString(w, "world"); err == nil { + t.Fatalf("no error writing after close") + } + + w, _ = wc.NextWriter(BinaryMessage) + _, _ = io.WriteString(w, "hello") + + // close w by getting next writer + _, err := wc.NextWriter(BinaryMessage) + if err != nil { + t.Fatalf("unexpected error getting next writer, %v", err) + } + + if _, err := io.WriteString(w, "world"); err == nil { + t.Fatalf("no error writing after close") + } +} + +func TestReadLimit(t *testing.T) { + t.Run("Test ReadLimit is enforced", func(t *testing.T) { + const readLimit = 512 + message := make([]byte, readLimit+1) + + var b1, b2 bytes.Buffer + wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil) + rc := newTestConn(&b1, &b2, true) + rc.SetReadLimit(readLimit) + + // Send message at the limit with interleaved pong. + w, _ := wc.NextWriter(BinaryMessage) + _, _ = w.Write(message[:readLimit-1]) + _ = wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second)) + _, _ = w.Write(message[:1]) + w.Close() + + // Send message larger than the limit. + _ = wc.WriteMessage(BinaryMessage, message[:readLimit+1]) + + op, _, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("1: NextReader() returned %d, %v", op, err) + } + op, r, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("2: NextReader() returned %d, %v", op, err) + } + _, err = io.Copy(io.Discard, r) + if err != ErrReadLimit { + t.Fatalf("io.Copy() returned %v", err) + } + }) + + t.Run("Test that ReadLimit cannot be overflowed", func(t *testing.T) { + const readLimit = 1 + + var b1, b2 bytes.Buffer + rc := newTestConn(&b1, &b2, true) + rc.SetReadLimit(readLimit) + + // First, send a non-final binary message + b1.Write([]byte("\x02\x81")) + + // Mask key + b1.Write([]byte("\x00\x00\x00\x00")) + + // First payload + b1.Write([]byte("A")) + + // Next, send a negative-length, non-final continuation frame + b1.Write([]byte("\x00\xFF\x80\x00\x00\x00\x00\x00\x00\x00")) + + // Mask key + b1.Write([]byte("\x00\x00\x00\x00")) + + // Next, send a too long, final continuation frame + b1.Write([]byte("\x80\xFF\x00\x00\x00\x00\x00\x00\x00\x05")) + + // Mask key + b1.Write([]byte("\x00\x00\x00\x00")) + + // Too-long payload + b1.Write([]byte("BCDEF")) + + op, r, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("1: NextReader() returned %d, %v", op, err) + } + + var buf [10]byte + var read int + n, err := r.Read(buf[:]) + if err != nil && err != ErrReadLimit { + t.Fatalf("unexpected error testing read limit: %v", err) + } + read += n + + n, err = r.Read(buf[:]) + if err != nil && err != ErrReadLimit { + t.Fatalf("unexpected error testing read limit: %v", err) + } + read += n + + if err == nil && read > readLimit { + t.Fatalf("read limit exceeded: limit %d, read %d", readLimit, read) + } + }) +} + +func TestAddrs(t *testing.T) { + c := newTestConn(nil, nil, true) + if c.LocalAddr() != localAddr { + t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr) + } + if c.RemoteAddr() != remoteAddr { + t.Errorf("RemoteAddr = %v, want %v", c.RemoteAddr(), remoteAddr) + } +} + +func TestDeprecatedUnderlyingConn(t *testing.T) { + var b1, b2 bytes.Buffer + fc := fakeNetConn{Reader: &b1, Writer: &b2} + c := newConn(fc, true, 1024, 1024, nil, nil, nil) + ul := c.UnderlyingConn() + if ul != fc { + t.Fatalf("Underlying conn is not what it should be.") + } +} + +func TestNetConn(t *testing.T) { + var b1, b2 bytes.Buffer + fc := fakeNetConn{Reader: &b1, Writer: &b2} + c := newConn(fc, true, 1024, 1024, nil, nil, nil) + ul := c.NetConn() + if ul != fc { + t.Fatalf("Underlying conn is not what it should be.") + } +} + +func TestBufioReadBytes(t *testing.T) { + // Test calling bufio.ReadBytes for value longer than read buffer size. + + m := make([]byte, 512) + m[len(m)-1] = '\n' + + var b1, b2 bytes.Buffer + wc := newConn(fakeNetConn{Writer: &b1}, false, len(m)+64, len(m)+64, nil, nil, nil) + rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil) + + w, _ := wc.NextWriter(BinaryMessage) + _, _ = w.Write(m) + w.Close() + + op, r, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("NextReader() returned %d, %v", op, err) + } + + br := bufio.NewReader(r) + p, err := br.ReadBytes('\n') + if err != nil { + t.Fatalf("ReadBytes() returned %v", err) + } + if len(p) != len(m) { + t.Fatalf("read returned %d bytes, want %d bytes", len(p), len(m)) + } +} + +var closeErrorTests = []struct { + err error + codes []int + ok bool +}{ + {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, true}, + {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, false}, + {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, true}, + {errors.New("hello"), []int{CloseNormalClosure}, false}, +} + +func TestCloseError(t *testing.T) { + for _, tt := range closeErrorTests { + ok := IsCloseError(tt.err, tt.codes...) + if ok != tt.ok { + t.Errorf("IsCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok) + } + } +} + +var unexpectedCloseErrorTests = []struct { + err error + codes []int + ok bool +}{ + {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, false}, + {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, true}, + {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, false}, + {errors.New("hello"), []int{CloseNormalClosure}, false}, +} + +func TestUnexpectedCloseErrors(t *testing.T) { + for _, tt := range unexpectedCloseErrorTests { + ok := IsUnexpectedCloseError(tt.err, tt.codes...) + if ok != tt.ok { + t.Errorf("IsUnexpectedCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok) + } + } +} + +type blockingWriter struct { + c1, c2 chan struct{} +} + +func (w blockingWriter) Write(p []byte) (int, error) { + // Allow main to continue + close(w.c1) + // Wait for panic in main + <-w.c2 + return len(p), nil +} + +func TestConcurrentWritePanic(t *testing.T) { + w := blockingWriter{make(chan struct{}), make(chan struct{})} + c := newTestConn(nil, w, false) + go func() { + _ = c.WriteMessage(TextMessage, []byte{}) + }() + + // wait for goroutine to block in write. + <-w.c1 + + defer func() { + close(w.c2) + if v := recover(); v != nil { + return + } + }() + + _ = c.WriteMessage(TextMessage, []byte{}) + t.Fatal("should not get here") +} + +type failingReader struct{} + +func (r failingReader) Read(p []byte) (int, error) { + return 0, io.EOF +} + +func TestFailedConnectionReadPanic(t *testing.T) { + c := newTestConn(failingReader{}, nil, false) + + defer func() { + if v := recover(); v != nil { + return + } + }() + + for i := 0; i < 20000; i++ { + _, _, _ = c.ReadMessage() + } + t.Fatal("should not get here") +} +// Copyright 2019 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + + +func TestJoinMessages(t *testing.T) { + messages := []string{"a", "bc", "def", "ghij", "klmno", "0", "12", "345", "6789"} + for _, readChunk := range []int{1, 2, 3, 4, 5, 6, 7} { + for _, term := range []string{"", ","} { + var connBuf bytes.Buffer + wc := newTestConn(nil, &connBuf, true) + rc := newTestConn(&connBuf, nil, false) + for _, m := range messages { + _ = wc.WriteMessage(BinaryMessage, []byte(m)) + } + + var result bytes.Buffer + _, err := io.CopyBuffer(&result, JoinMessages(rc, term), make([]byte, readChunk)) + if IsUnexpectedCloseError(err, CloseAbnormalClosure) { + t.Errorf("readChunk=%d, term=%q: unexpected error %v", readChunk, term, err) + } + want := strings.Join(messages, term) + term + if result.String() != want { + t.Errorf("readChunk=%d, term=%q, got %q, want %q", readChunk, term, result.String(), want) + } + } + } +} +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + + +func TestJSON(t *testing.T) { + var buf bytes.Buffer + wc := newTestConn(nil, &buf, true) + rc := newTestConn(&buf, nil, false) + + var actual, expect struct { + A int + B string + } + expect.A = 1 + expect.B = "hello" + + if err := wc.WriteJSON(&expect); err != nil { + t.Fatal("write", err) + } + + if err := rc.ReadJSON(&actual); err != nil { + t.Fatal("read", err) + } + + if !reflect.DeepEqual(&actual, &expect) { + t.Fatal("equal", actual, expect) + } +} + +func TestPartialJSONRead(t *testing.T) { + var buf0, buf1 bytes.Buffer + wc := newTestConn(nil, &buf0, true) + rc := newTestConn(&buf0, &buf1, false) + + var v struct { + A int + B string + } + v.A = 1 + v.B = "hello" + + messageCount := 0 + + // Partial JSON values. + + data, err := json.Marshal(v) + if err != nil { + t.Fatal(err) + } + for i := len(data) - 1; i >= 0; i-- { + if err := wc.WriteMessage(TextMessage, data[:i]); err != nil { + t.Fatal(err) + } + messageCount++ + } + + // Whitespace. + + if err := wc.WriteMessage(TextMessage, []byte(" ")); err != nil { + t.Fatal(err) + } + messageCount++ + + // Close. + + if err := wc.WriteMessage(CloseMessage, FormatCloseMessage(CloseNormalClosure, "")); err != nil { + t.Fatal(err) + } + + for i := 0; i < messageCount; i++ { + err := rc.ReadJSON(&v) + if err != io.ErrUnexpectedEOF { + t.Error("read", i, err) + } + } + + err = rc.ReadJSON(&v) + if _, ok := err.(*CloseError); !ok { + t.Error("final", err) + } +} + +func TestDeprecatedJSON(t *testing.T) { + var buf bytes.Buffer + wc := newTestConn(nil, &buf, true) + rc := newTestConn(&buf, nil, false) + + var actual, expect struct { + A int + B string + } + expect.A = 1 + expect.B = "hello" + + if err := WriteJSON(wc, &expect); err != nil { + t.Fatal("write", err) + } + + if err := ReadJSON(rc, &actual); err != nil { + t.Fatal("read", err) + } + + if !reflect.DeepEqual(&actual, &expect) { + t.Fatal("equal", actual, expect) + } +} +// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of +// this source code is governed by a BSD-style license that can be found in the +// LICENSE file. + +// !appengine + + +func maskBytesByByte(key [4]byte, pos int, b []byte) int { + for i := range b { + b[i] ^= key[pos&3] + pos++ + } + return pos & 3 +} + +func notzero(b []byte) int { + for i := range b { + if b[i] != 0 { + return i + } + } + return -1 +} + +func TestMaskBytes(t *testing.T) { + key := [4]byte{1, 2, 3, 4} + for size := 1; size <= 1024; size++ { + for align := 0; align < wordSize; align++ { + for pos := 0; pos < 4; pos++ { + b := make([]byte, size+align)[align:] + maskBytes(key, pos, b) + maskBytesByByte(key, pos, b) + if i := notzero(b); i >= 0 { + t.Errorf("size:%d, align:%d, pos:%d, offset:%d", size, align, pos, i) + } + } + } + } +} + +func BenchmarkMaskBytes(b *testing.B) { + for _, size := range []int{2, 4, 8, 16, 32, 512, 1024} { + b.Run(fmt.Sprintf("size-%d", size), func(b *testing.B) { + for _, align := range []int{wordSize / 2} { + b.Run(fmt.Sprintf("align-%d", align), func(b *testing.B) { + for _, fn := range []struct { + name string + fn func(key [4]byte, pos int, b []byte) int + }{ + {"byte", maskBytesByByte}, + {"word", maskBytes}, + } { + b.Run(fn.name, func(b *testing.B) { + key := newMaskKey() + data := make([]byte, size+align)[align:] + for i := 0; i < b.N; i++ { + fn.fn(key, 0, data) + } + b.SetBytes(int64(len(data))) + }) + } + }) + } + }) + } +} +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + + +var preparedMessageTests = []struct { + messageType int + isServer bool + enableWriteCompression bool + compressionLevel int +}{ + // Server + {TextMessage, true, false, flate.BestSpeed}, + {TextMessage, true, true, flate.BestSpeed}, + {TextMessage, true, true, flate.BestCompression}, + {PingMessage, true, false, flate.BestSpeed}, + {PingMessage, true, true, flate.BestSpeed}, + + // Client + {TextMessage, false, false, flate.BestSpeed}, + {TextMessage, false, true, flate.BestSpeed}, + {TextMessage, false, true, flate.BestCompression}, + {PingMessage, false, false, flate.BestSpeed}, + {PingMessage, false, true, flate.BestSpeed}, +} + +func TestPreparedMessage(t *testing.T) { + testRand := rand.New(rand.NewSource(99)) + prevMaskRand := maskRand + maskRand = testRand + defer func() { maskRand = prevMaskRand }() + + for _, tt := range preparedMessageTests { + var data = []byte("this is a test") + var buf bytes.Buffer + c := newTestConn(nil, &buf, tt.isServer) + if tt.enableWriteCompression { + c.newCompressionWriter = compressNoContextTakeover + } + if err := c.SetCompressionLevel(tt.compressionLevel); err != nil { + t.Fatal(err) + } + + // Seed random number generator for consistent frame mask. + testRand.Seed(1234) + + if err := c.WriteMessage(tt.messageType, data); err != nil { + t.Fatal(err) + } + want := buf.String() + + pm, err := NewPreparedMessage(tt.messageType, data) + if err != nil { + t.Fatal(err) + } + + // Scribble on data to ensure that NewPreparedMessage takes a snapshot. + copy(data, "hello world") + + // Seed random number generator for consistent frame mask. + testRand.Seed(1234) + + buf.Reset() + if err := c.WritePreparedMessage(pm); err != nil { + t.Fatal(err) + } + got := buf.String() + + if got != want { + t.Errorf("write message != prepared message for %+v", tt) + } + } +} +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + + +var subprotocolTests = []struct { + h string + protocols []string +}{ + {"", nil}, + {"foo", []string{"foo"}}, + {"foo,bar", []string{"foo", "bar"}}, + {"foo, bar", []string{"foo", "bar"}}, + {" foo, bar", []string{"foo", "bar"}}, + {" foo, bar ", []string{"foo", "bar"}}, +} + +func TestSubprotocols(t *testing.T) { + for _, st := range subprotocolTests { + r := http.Request{Header: http.Header{"Sec-Websocket-Protocol": {st.h}}} + protocols := Subprotocols(&r) + if !reflect.DeepEqual(st.protocols, protocols) { + t.Errorf("SubProtocols(%q) returned %#v, want %#v", st.h, protocols, st.protocols) + } + } +} + +var isWebSocketUpgradeTests = []struct { + ok bool + h http.Header +}{ + {false, http.Header{"Upgrade": {"websocket"}}}, + {false, http.Header{"Connection": {"upgrade"}}}, + {true, http.Header{"Connection": {"upgRade"}, "Upgrade": {"WebSocket"}}}, +} + +func TestIsWebSocketUpgrade(t *testing.T) { + for _, tt := range isWebSocketUpgradeTests { + ok := IsWebSocketUpgrade(&http.Request{Header: tt.h}) + if tt.ok != ok { + t.Errorf("IsWebSocketUpgrade(%v) returned %v, want %v", tt.h, ok, tt.ok) + } + } +} + +func TestSubProtocolSelection(t *testing.T) { + upgrader := Upgrader{ + Subprotocols: []string{"foo", "bar", "baz"}, + } + + r := http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"foo", "bar"}}} + s := upgrader.selectSubprotocol(&r, nil) + if s != "foo" { + t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "foo") + } + + r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"bar", "foo"}}} + s = upgrader.selectSubprotocol(&r, nil) + if s != "bar" { + t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "bar") + } + + r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"baz"}}} + s = upgrader.selectSubprotocol(&r, nil) + if s != "baz" { + t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "baz") + } + + r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"quux"}}} + s = upgrader.selectSubprotocol(&r, nil) + if s != "" { + t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "empty string") + } +} + +var checkSameOriginTests = []struct { + ok bool + r *http.Request +}{ + {false, &http.Request{Host: "example.org", Header: map[string][]string{"Origin": {"https://other.org"}}}}, + {true, &http.Request{Host: "example.org", Header: map[string][]string{"Origin": {"https://example.org"}}}}, + {true, &http.Request{Host: "Example.org", Header: map[string][]string{"Origin": {"https://example.org"}}}}, +} + +func TestCheckSameOrigin(t *testing.T) { + for _, tt := range checkSameOriginTests { + ok := checkSameOrigin(tt.r) + if tt.ok != ok { + t.Errorf("checkSameOrigin(%+v) returned %v, want %v", tt.r, ok, tt.ok) + } + } +} + +type reuseTestResponseWriter struct { + brw *bufio.ReadWriter + http.ResponseWriter +} + +func (resp *reuseTestResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return fakeNetConn{strings.NewReader(""), &bytes.Buffer{}}, resp.brw, nil +} + +var bufioReuseTests = []struct { + n int + reuse bool +}{ + {4096, true}, + {128, false}, +} + +func xTestBufioReuse(t *testing.T) { + for i, tt := range bufioReuseTests { + br := bufio.NewReaderSize(strings.NewReader(""), tt.n) + bw := bufio.NewWriterSize(&bytes.Buffer{}, tt.n) + resp := &reuseTestResponseWriter{ + brw: bufio.NewReadWriter(br, bw), + } + upgrader := Upgrader{} + c, err := upgrader.Upgrade(resp, &http.Request{ + Method: http.MethodGet, + Header: http.Header{ + "Upgrade": []string{"websocket"}, + "Connection": []string{"upgrade"}, + "Sec-Websocket-Key": []string{"dGhlIHNhbXBsZSBub25jZQ=="}, + "Sec-Websocket-Version": []string{"13"}, + }}, nil) + if err != nil { + t.Fatal(err) + } + if reuse := c.br == br; reuse != tt.reuse { + t.Errorf("%d: buffered reader reuse=%v, want %v", i, reuse, tt.reuse) + } + writeBuf := bw.AvailableBuffer() + if reuse := &c.writeBuf[0] == &writeBuf[0]; reuse != tt.reuse { + t.Errorf("%d: write buffer reuse=%v, want %v", i, reuse, tt.reuse) + } + } +} + +func TestHijack_NotSupported(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Connection", "upgrade") + req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") + req.Header.Set("Sec-Websocket-Version", "13") + + recorder := httptest.NewRecorder() + + upgrader := Upgrader{} + _, err := upgrader.Upgrade(recorder, req, nil) + + if want := (HandshakeError{}); !errors.As(err, &want) || recorder.Code != http.StatusInternalServerError { + t.Errorf("want %T and status_code=%d", want, http.StatusInternalServerError) + t.Fatalf("got err=%T and status_code=%d", err, recorder.Code) + } +} +// Copyright 2014 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + + +var equalASCIIFoldTests = []struct { + t, s string + eq bool +}{ + {"WebSocket", "websocket", true}, + {"websocket", "WebSocket", true}, + {"Öyster", "öyster", false}, + {"WebSocket", "WetSocket", false}, +} + +func TestEqualASCIIFold(t *testing.T) { + for _, tt := range equalASCIIFoldTests { + eq := equalASCIIFold(tt.s, tt.t) + if eq != tt.eq { + t.Errorf("equalASCIIFold(%q, %q) = %v, want %v", tt.s, tt.t, eq, tt.eq) + } + } +} + +var tokenListContainsValueTests = []struct { + value string + ok bool +}{ + {"WebSocket", true}, + {"WEBSOCKET", true}, + {"websocket", true}, + {"websockets", false}, + {"x websocket", false}, + {"websocket x", false}, + {"other,websocket,more", true}, + {"other, websocket, more", true}, +} + +func TestTokenListContainsValue(t *testing.T) { + for _, tt := range tokenListContainsValueTests { + h := http.Header{"Upgrade": {tt.value}} + ok := tokenListContainsValue(h, "Upgrade", "websocket") + if ok != tt.ok { + t.Errorf("tokenListContainsValue(h, n, %q) = %v, want %v", tt.value, ok, tt.ok) + } + } +} + +var isValidChallengeKeyTests = []struct { + key string + ok bool +}{ + {"dGhlIHNhbXBsZSBub25jZQ==", true}, + {"", false}, + {"InvalidKey", false}, + {"WHQ4eXhscUtKYjBvOGN3WEdtOEQ=", false}, +} + +func TestIsValidChallengeKey(t *testing.T) { + for _, tt := range isValidChallengeKeyTests { + ok := isValidChallengeKey(tt.key) + if ok != tt.ok { + t.Errorf("isValidChallengeKey returns %v, want %v", ok, tt.ok) + } + } +} + +var parseExtensionTests = []struct { + value string + extensions []map[string]string +}{ + {`foo`, []map[string]string{{"": "foo"}}}, + {`foo, bar; baz=2`, []map[string]string{ + {"": "foo"}, + {"": "bar", "baz": "2"}}}, + {`foo; bar="b,a;z"`, []map[string]string{ + {"": "foo", "bar": "b,a;z"}}}, + {`foo , bar; baz = 2`, []map[string]string{ + {"": "foo"}, + {"": "bar", "baz": "2"}}}, + {`foo, bar; baz=2 junk`, []map[string]string{ + {"": "foo"}}}, + {`foo junk, bar; baz=2 junk`, nil}, + {`mux; max-channels=4; flow-control, deflate-stream`, []map[string]string{ + {"": "mux", "max-channels": "4", "flow-control": ""}, + {"": "deflate-stream"}}}, + {`permessage-foo; x="10"`, []map[string]string{ + {"": "permessage-foo", "x": "10"}}}, + {`permessage-foo; use_y, permessage-foo`, []map[string]string{ + {"": "permessage-foo", "use_y": ""}, + {"": "permessage-foo"}}}, + {`permessage-deflate; client_max_window_bits; server_max_window_bits=10 , permessage-deflate; client_max_window_bits`, []map[string]string{ + {"": "permessage-deflate", "client_max_window_bits": "", "server_max_window_bits": "10"}, + {"": "permessage-deflate", "client_max_window_bits": ""}}}, + {"permessage-deflate; server_no_context_takeover; client_max_window_bits=15", []map[string]string{ + {"": "permessage-deflate", "server_no_context_takeover": "", "client_max_window_bits": "15"}, + }}, +} + +func TestParseExtensions(t *testing.T) { + for _, tt := range parseExtensionTests { + h := http.Header{http.CanonicalHeaderKey("Sec-WebSocket-Extensions"): {tt.value}} + extensions := parseExtensions(h) + if !reflect.DeepEqual(extensions, tt.extensions) { + t.Errorf("parseExtensions(%q)\n = %v,\nwant %v", tt.value, extensions, tt.extensions) + } + } +} + +func TestParseArgs(t *testing.T) { + given := ParseArgs([]string { "x", "y", "z" }) + expected := CLIArgs { + FromAddr: "y", + ToAddr: "z", + } + + g.AssertEqual(given, expected) +} + + + +func MainTest() { + tests := []testing.InternalTest { + { "TestProxyDial", TestProxyDial }, + { "TestProxyAuthorizationDial", TestProxyAuthorizationDial }, + { "TestDial", TestDial }, + { "TestDialCookieJar", TestDialCookieJar }, + { "TestDialTLS", TestDialTLS }, + { "TestDialTimeout", TestDialTimeout }, + { "TestHandshakeTimeout", TestHandshakeTimeout }, + { "TestHandshakeTimeoutInContext", TestHandshakeTimeoutInContext }, + { "TestDialBadScheme", TestDialBadScheme }, + { "TestDialBadOrigin", TestDialBadOrigin }, + { "TestDialBadHeader", TestDialBadHeader }, + { "TestBadMethod", TestBadMethod }, + { "TestNoUpgrade", TestNoUpgrade }, + { "TestDialExtraTokensInRespHeaders", TestDialExtraTokensInRespHeaders }, + { "TestHandshake", TestHandshake }, + { "TestRespOnBadHandshake", TestRespOnBadHandshake }, + { "TestHost", TestHost }, + { "TestDialCompression", TestDialCompression }, + { "TestSocksProxyDial", TestSocksProxyDial }, + { "TestTracingDialWithContext", TestTracingDialWithContext }, + { "TestEmptyTracingDialWithContext", TestEmptyTracingDialWithContext }, + { "TestNetDialConnect", TestNetDialConnect }, + { "TestNextProtos", TestNextProtos }, + { "TestDataReceivedBeforeHandshake", TestDataReceivedBeforeHandshake }, + { "TestHostPortNoPort", TestHostPortNoPort }, + { "TestTruncWriter", TestTruncWriter }, + { "TestValidCompressionLevel", TestValidCompressionLevel }, + { "TestFraming", TestFraming }, + { "TestWriteControlDeadline", TestWriteControlDeadline }, + { "TestConcurrencyWriteControl", TestConcurrencyWriteControl }, + { "TestControl", TestControl }, + { "TestWriteBufferPool", TestWriteBufferPool }, + { "TestWriteBufferPoolSync", TestWriteBufferPoolSync }, + { "TestWriteBufferPoolError", TestWriteBufferPoolError }, + { "TestCloseFrameBeforeFinalMessageFrame", TestCloseFrameBeforeFinalMessageFrame }, + { "TestEOFWithinFrame", TestEOFWithinFrame }, + { "TestEOFBeforeFinalFrame", TestEOFBeforeFinalFrame }, + { "TestWriteAfterMessageWriterClose", TestWriteAfterMessageWriterClose }, + { "TestReadLimit", TestReadLimit }, + { "TestAddrs", TestAddrs }, + { "TestDeprecatedUnderlyingConn", TestDeprecatedUnderlyingConn }, + { "TestNetConn", TestNetConn }, + { "TestBufioReadBytes", TestBufioReadBytes }, + { "TestCloseError", TestCloseError }, + { "TestUnexpectedCloseErrors", TestUnexpectedCloseErrors }, + { "TestConcurrentWritePanic", TestConcurrentWritePanic }, + { "TestFailedConnectionReadPanic", TestFailedConnectionReadPanic }, + { "TestJoinMessages", TestJoinMessages }, + { "TestJSON", TestJSON }, + { "TestPartialJSONRead", TestPartialJSONRead }, + { "TestDeprecatedJSON", TestDeprecatedJSON }, + { "TestMaskBytes", TestMaskBytes }, + { "TestPreparedMessage", TestPreparedMessage }, + { "TestSubprotocols", TestSubprotocols }, + { "TestIsWebSocketUpgrade", TestIsWebSocketUpgrade }, + { "TestSubProtocolSelection", TestSubProtocolSelection }, + { "TestCheckSameOrigin", TestCheckSameOrigin }, + { "TestHijack_NotSupported", TestHijack_NotSupported }, + { "TestEqualASCIIFold", TestEqualASCIIFold }, + { "TestTokenListContainsValue", TestTokenListContainsValue }, + { "TestIsValidChallengeKey", TestIsValidChallengeKey }, + { "TestParseExtensions", TestParseExtensions }, + { "TestParseArgs", TestParseArgs }, + } + + // FIXME: run benchmarks + benchmarks := []testing.InternalBenchmark { + { "BenchmarkWriteNoCompression", BenchmarkWriteNoCompression }, + { "BenchmarkWriteWithCompression", BenchmarkWriteWithCompression }, + { "BenchmarkBroadcast", BenchmarkBroadcast }, + { "BenchmarkMaskBytes", BenchmarkMaskBytes }, + } + + fuzzTargets := []testing.InternalFuzzTarget {} + examples := []testing.InternalExample {} + m := testing.MainStart( + testdeps.TestDeps {}, + tests, + benchmarks, + fuzzTargets, + examples, + ) + os.Exit(m.Run()) +} |