summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEuAndreh <eu@euandre.org>2024-08-10 11:30:01 -0300
committerEuAndreh <eu@euandre.org>2024-08-10 14:41:23 -0300
commit7b7819f07a790ba5f2e3ca3c3a3a83ea0a8e4a6a (patch)
tree67a0697fcd2ff097725e19969b6602195791292c
parentgo.mod: Use "replace" for all (diff)
downloadwscat-7b7819f07a790ba5f2e3ca3c3a3a83ea0a8e4a6a.tar.gz
wscat-7b7819f07a790ba5f2e3ca3c3a3a83ea0a8e4a6a.tar.xz
Build with "go tool" AND import code from deps
-rw-r--r--.gitignore3
-rw-r--r--Makefile72
-rw-r--r--go.mod16
-rw-r--r--src/lib.go118
-rw-r--r--src/main.go (renamed from src/cmd/main.go)2
-rw-r--r--src/proxy.go405
-rw-r--r--src/socks.go473
-rw-r--r--src/wscat.go3148
-rw-r--r--tests/lib_test.go24
-rw-r--r--tests/main.go7
-rw-r--r--tests/wscat.go2842
11 files changed, 6933 insertions, 177 deletions
diff --git a/.gitignore b/.gitignore
index 85a298a..75d542a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,5 @@
/*.bin
+/src/*.a
+/src/*.bin
+/tests/*.a
/tests/*.bin
diff --git a/Makefile b/Makefile
index a6d47a2..4713192 100644
--- a/Makefile
+++ b/Makefile
@@ -8,6 +8,7 @@ LANGUAGES = en
PREFIX = /usr
BINDIR = $(PREFIX)/bin
LIBDIR = $(PREFIX)/lib
+GOLIBDIR = $(LIBDIR)/go
INCLUDEDIR = $(PREFIX)/include
SRCDIR = $(PREFIX)/src/$(NAME)
SHAREDIR = $(PREFIX)/share
@@ -17,12 +18,11 @@ EXEC = ./
## Where to store the installation. Empty by default.
DESTDIR =
LDLIBS =
-GOFLAGS =
.SUFFIXES:
-.SUFFIXES: .go .bin
+.SUFFIXES: .go .a .bin .bin-check
@@ -30,14 +30,24 @@ all:
include deps.mk
+objects = \
+ src/$(NAME).a \
+ src/proxy.a \
+ src/socks.a \
+ src/main.a \
+ tests/$(NAME).a \
+ tests/main.a \
+
sources = \
- src/lib.go \
- src/cmd/main.go \
+ src/$(NAME).go \
+ src/main.go \
derived-assets = \
+ $(objects) \
+ src/main.bin \
+ tests/main.bin \
$(NAME).bin \
- tests/lib_test.bin \
side-assets = \
$(NAME).socket \
@@ -49,22 +59,48 @@ side-assets = \
all: $(derived-assets)
-$(NAME).bin: src/lib.go src/cmd/main.go Makefile
- go build $(GOFLAGS) -v -o $@ src/cmd/main.go
+$(objects): Makefile
+
+src/$(NAME).a: src/$(NAME).go src/proxy.a
+src/proxy.a: src/proxy.go src/socks.a
+src/socks.a: src/socks.go
+src/$(NAME).a src/proxy.a src/socks.a:
+ go tool compile $(GOCFLAGS) -o $@ -p $(*F) -I src $*.go
+
+tests/$(NAME).a: tests/$(NAME).go src/$(NAME).go src/proxy.a
+tests/$(NAME).a:
+ go tool compile $(GOCFLAGS) -o $@ -p $(*F) -I src $*.go src/$(*F).go
+
+src/main.a: src/main.go src/$(NAME).a
+tests/main.a: tests/main.go tests/$(NAME).a
+src/main.a tests/main.a:
+ go tool compile $(GOCFLAGS) -o $@ -I $(@D) $*.go
+
+src/main.bin: src/main.a
+tests/main.bin: tests/main.a
+src/main.bin tests/main.bin:
+ go tool link $(GOLDFLAGS) -o $@ -L $(@D) -L src $*.a
+
+$(NAME).bin: src/main.bin
+ ln -fs $? $@
+
-tests/lib_test.bin: src/lib.go tests/lib_test.go Makefile
- go test $(GOFLAGS) -v -o $@ -c $*.go
+tests.bin-check = \
+ tests/main.bin-check \
+tests/main.bin-check: tests/main.bin
+$(tests.bin-check):
+ $(EXEC)$*.bin
-check-unit: tests/lib_test.bin
- ./tests/lib_test.bin
+check-unit: $(tests.bin-check)
integration-tests = \
tests/cli-opts.sh \
tests/integration.sh \
+.PRECIOUS: $(integration-tests)
$(integration-tests): $(NAME).bin
$(integration-tests): ALWAYS
sh $@
@@ -90,21 +126,21 @@ clean:
install: all
mkdir -p \
'$(DESTDIR)$(BINDIR)' \
+ '$(DESTDIR)$(GOLIBDIR)' \
+ '$(DESTDIR)$(SRCDIR)' \
cp $(NAME).bin '$(DESTDIR)$(BINDIR)'/$(NAME)
- for f in $(sources); do \
- dir='$(DESTDIR)$(SRCDIR)'/"`dirname "$${f#src/}"`"; \
- mkdir -p "$$dir"; \
- cp -P "$$f" "$$dir"; \
- done
+ cp src/$(NAME).a '$(DESTDIR)$(GOLIBDIR)'
+ cp $(sources) '$(DESTDIR)$(SRCDIR)'
## Uninstalls from $(DESTDIR)$(PREFIX). This is a perfect mirror
## of the "install" target, and removes *all* that was installed.
## A dedicated test asserts that this is always true.
uninstall:
rm -rf \
- '$(DESTDIR)$(BINDIR)'/$(NAME) \
- '$(DESTDIR)$(SRCDIR)' \
+ '$(DESTDIR)$(BINDIR)'/$(NAME) \
+ '$(DESTDIR)$(GOLIBDIR)'/$(NAME).a \
+ '$(DESTDIR)$(SRCDIR)' \
diff --git a/go.mod b/go.mod
deleted file mode 100644
index 96e7eeb..0000000
--- a/go.mod
+++ /dev/null
@@ -1,16 +0,0 @@
-module euandre.org/wscat
-
-go 1.21.5
-
-require (
- euandre.org/gobang v0.1.0
- github.com/gorilla/websocket v1.5.3
-)
-
-require golang.org/x/net v0.26.0 // indirect
-
-replace (
- euandre.org/gobang => ../gobang
- github.com/gorilla/websocket => ../websocket
- golang.org/x/net => ../netx
-)
diff --git a/src/lib.go b/src/lib.go
deleted file mode 100644
index 3733475..0000000
--- a/src/lib.go
+++ /dev/null
@@ -1,118 +0,0 @@
-package wscat
-
-import (
- "fmt"
- "net"
- "net/http"
- "os"
-
- g "euandre.org/gobang/src"
-
- "github.com/gorilla/websocket"
-)
-
-
-
-type CLIArgs struct {
- FromAddr string
- ToAddr string
-}
-
-
-
-const X = 1
-
-var EmitActiveConnection = g.MakeGauge("active-connections")
-
-
-
-func ParseArgs(args []string) CLIArgs {
- if len(args) != 3 {
- fmt.Fprintf(
- os.Stderr,
- "Usage: %s FROM.socket TO.socket\n",
- args[0],
- )
- os.Exit(2)
- }
- return CLIArgs {
- FromAddr: args[1],
- ToAddr: args[2],
- }
-}
-
-func Listen(fromAddr string) net.Listener {
- listener, err := net.Listen("unix", fromAddr)
- g.FatalIf(err)
- g.Info("Started listening", "listen-start", "from-address", fromAddr)
- return listener
-}
-
-func Start(toAddr string, listener net.Listener) {
- upgrader := websocket.Upgrader {}
- http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
- connFrom, err := upgrader.Upgrade(w, r, nil)
- if err != nil {
- g.Warning(
- "Error upgrading connection",
- "upgrade-connection-error",
- "err", err,
- )
- return
- }
- defer connFrom.Close()
- EmitActiveConnection.Inc()
-
- connTo, err := net.Dial("unix", toAddr)
- if err != nil {
- g.Error(
- "Error dialing connection",
- "dial-connection",
- "err", err,
- )
- os.Exit(1)
- }
- defer connTo.Close()
-
- messageType, reader, err := connFrom.NextReader()
- if err != nil {
- g.Warning(
- "Failed to get next reader from connection",
- "connection-next-reader-error",
- "err", err,
- )
- return
- }
-
- writer, err := connFrom.NextWriter(messageType)
- if err != nil {
- g.Warning(
- "Failed to get next reader from connection",
- "connection-next-reader-error",
- "err", err,
- )
- return
- }
-
-
- c := make(chan g.CopyResult)
- go g.CopyData(c, "c2s", connTo, writer)
- go g.CopyData(c, "s2c", reader, connTo)
- go func() {
- <- c
- EmitActiveConnection.Dec()
- }()
- });
-
- server := http.Server{}
- err := server.Serve(listener)
- g.FatalIf(err)
-}
-
-
-func Main() {
- g.Init()
- args := ParseArgs(os.Args)
- listener := Listen(args.FromAddr)
- Start(args.ToAddr, listener)
-}
diff --git a/src/cmd/main.go b/src/main.go
index e6ae309..ddc8f19 100644
--- a/src/cmd/main.go
+++ b/src/main.go
@@ -1,6 +1,6 @@
package main
-import "euandre.org/wscat/src"
+import "wscat"
func main() {
wscat.Main()
diff --git a/src/proxy.go b/src/proxy.go
new file mode 100644
index 0000000..7ef4be3
--- /dev/null
+++ b/src/proxy.go
@@ -0,0 +1,405 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package proxy
+
+import (
+ "context"
+ "errors"
+ "net"
+ "net/url"
+ "os"
+ "strings"
+ "sync"
+
+ "socks"
+)
+
+// A ContextDialer dials using a context.
+type ContextDialer interface {
+ DialContext(ctx context.Context, network, address string) (net.Conn, error)
+}
+
+// Dial works like DialContext on net.Dialer but using a dialer returned by FromEnvironment.
+//
+// The passed ctx is only used for returning the Conn, not the lifetime of the Conn.
+//
+// Custom dialers (registered via RegisterDialerType) that do not implement ContextDialer
+// can leak a goroutine for as long as it takes the underlying Dialer implementation to timeout.
+//
+// A Conn returned from a successful Dial after the context has been cancelled will be immediately closed.
+func Dial(ctx context.Context, network, address string) (net.Conn, error) {
+ d := FromEnvironment()
+ if xd, ok := d.(ContextDialer); ok {
+ return xd.DialContext(ctx, network, address)
+ }
+ return dialContext(ctx, d, network, address)
+}
+
+// WARNING: this can leak a goroutine for as long as the underlying Dialer implementation takes to timeout
+// A Conn returned from a successful Dial after the context has been cancelled will be immediately closed.
+func dialContext(ctx context.Context, d Dialer, network, address string) (net.Conn, error) {
+ var (
+ conn net.Conn
+ done = make(chan struct{}, 1)
+ err error
+ )
+ go func() {
+ conn, err = d.Dial(network, address)
+ close(done)
+ if conn != nil && ctx.Err() != nil {
+ conn.Close()
+ }
+ }()
+ select {
+ case <-ctx.Done():
+ err = ctx.Err()
+ case <-done:
+ }
+ return conn, err
+}
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+
+type direct struct{}
+
+// Direct implements Dialer by making network connections directly using net.Dial or net.DialContext.
+var Direct = direct{}
+
+var (
+ _ Dialer = Direct
+ _ ContextDialer = Direct
+)
+
+// Dial directly invokes net.Dial with the supplied parameters.
+func (direct) Dial(network, addr string) (net.Conn, error) {
+ return net.Dial(network, addr)
+}
+
+// DialContext instantiates a net.Dialer and invokes its DialContext receiver with the supplied parameters.
+func (direct) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
+ var d net.Dialer
+ return d.DialContext(ctx, network, addr)
+}
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+
+// A PerHost directs connections to a default Dialer unless the host name
+// requested matches one of a number of exceptions.
+type PerHost struct {
+ def, bypass Dialer
+
+ bypassNetworks []*net.IPNet
+ bypassIPs []net.IP
+ bypassZones []string
+ bypassHosts []string
+}
+
+// NewPerHost returns a PerHost Dialer that directs connections to either
+// defaultDialer or bypass, depending on whether the connection matches one of
+// the configured rules.
+func NewPerHost(defaultDialer, bypass Dialer) *PerHost {
+ return &PerHost{
+ def: defaultDialer,
+ bypass: bypass,
+ }
+}
+
+// Dial connects to the address addr on the given network through either
+// defaultDialer or bypass.
+func (p *PerHost) Dial(network, addr string) (c net.Conn, err error) {
+ host, _, err := net.SplitHostPort(addr)
+ if err != nil {
+ return nil, err
+ }
+
+ return p.dialerForRequest(host).Dial(network, addr)
+}
+
+// DialContext connects to the address addr on the given network through either
+// defaultDialer or bypass.
+func (p *PerHost) DialContext(ctx context.Context, network, addr string) (c net.Conn, err error) {
+ host, _, err := net.SplitHostPort(addr)
+ if err != nil {
+ return nil, err
+ }
+ d := p.dialerForRequest(host)
+ if x, ok := d.(ContextDialer); ok {
+ return x.DialContext(ctx, network, addr)
+ }
+ return dialContext(ctx, d, network, addr)
+}
+
+func (p *PerHost) dialerForRequest(host string) Dialer {
+ if ip := net.ParseIP(host); ip != nil {
+ for _, net := range p.bypassNetworks {
+ if net.Contains(ip) {
+ return p.bypass
+ }
+ }
+ for _, bypassIP := range p.bypassIPs {
+ if bypassIP.Equal(ip) {
+ return p.bypass
+ }
+ }
+ return p.def
+ }
+
+ for _, zone := range p.bypassZones {
+ if strings.HasSuffix(host, zone) {
+ return p.bypass
+ }
+ if host == zone[1:] {
+ // For a zone ".example.com", we match "example.com"
+ // too.
+ return p.bypass
+ }
+ }
+ for _, bypassHost := range p.bypassHosts {
+ if bypassHost == host {
+ return p.bypass
+ }
+ }
+ return p.def
+}
+
+// AddFromString parses a string that contains comma-separated values
+// specifying hosts that should use the bypass proxy. Each value is either an
+// IP address, a CIDR range, a zone (*.example.com) or a host name
+// (localhost). A best effort is made to parse the string and errors are
+// ignored.
+func (p *PerHost) AddFromString(s string) {
+ hosts := strings.Split(s, ",")
+ for _, host := range hosts {
+ host = strings.TrimSpace(host)
+ if len(host) == 0 {
+ continue
+ }
+ if strings.Contains(host, "/") {
+ // We assume that it's a CIDR address like 127.0.0.0/8
+ if _, net, err := net.ParseCIDR(host); err == nil {
+ p.AddNetwork(net)
+ }
+ continue
+ }
+ if ip := net.ParseIP(host); ip != nil {
+ p.AddIP(ip)
+ continue
+ }
+ if strings.HasPrefix(host, "*.") {
+ p.AddZone(host[1:])
+ continue
+ }
+ p.AddHost(host)
+ }
+}
+
+// AddIP specifies an IP address that will use the bypass proxy. Note that
+// this will only take effect if a literal IP address is dialed. A connection
+// to a named host will never match an IP.
+func (p *PerHost) AddIP(ip net.IP) {
+ p.bypassIPs = append(p.bypassIPs, ip)
+}
+
+// AddNetwork specifies an IP range that will use the bypass proxy. Note that
+// this will only take effect if a literal IP address is dialed. A connection
+// to a named host will never match.
+func (p *PerHost) AddNetwork(net *net.IPNet) {
+ p.bypassNetworks = append(p.bypassNetworks, net)
+}
+
+// AddZone specifies a DNS suffix that will use the bypass proxy. A zone of
+// "example.com" matches "example.com" and all of its subdomains.
+func (p *PerHost) AddZone(zone string) {
+ zone = strings.TrimSuffix(zone, ".")
+ if !strings.HasPrefix(zone, ".") {
+ zone = "." + zone
+ }
+ p.bypassZones = append(p.bypassZones, zone)
+}
+
+// AddHost specifies a host name that will use the bypass proxy.
+func (p *PerHost) AddHost(host string) {
+ host = strings.TrimSuffix(host, ".")
+ p.bypassHosts = append(p.bypassHosts, host)
+}
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package proxy provides support for a variety of protocols to proxy network
+// data.
+// package proxy // import "golang.org/x/net/proxy"
+
+// A Dialer is a means to establish a connection.
+// Custom dialers should also implement ContextDialer.
+type Dialer interface {
+ // Dial connects to the given address via the proxy.
+ Dial(network, addr string) (c net.Conn, err error)
+}
+
+// Auth contains authentication parameters that specific Dialers may require.
+type Auth struct {
+ User, Password string
+}
+
+// FromEnvironment returns the dialer specified by the proxy-related
+// variables in the environment and makes underlying connections
+// directly.
+func FromEnvironment() Dialer {
+ return FromEnvironmentUsing(Direct)
+}
+
+// FromEnvironmentUsing returns the dialer specify by the proxy-related
+// variables in the environment and makes underlying connections
+// using the provided forwarding Dialer (for instance, a *net.Dialer
+// with desired configuration).
+func FromEnvironmentUsing(forward Dialer) Dialer {
+ allProxy := allProxyEnv.Get()
+ if len(allProxy) == 0 {
+ return forward
+ }
+
+ proxyURL, err := url.Parse(allProxy)
+ if err != nil {
+ return forward
+ }
+ proxy, err := FromURL(proxyURL, forward)
+ if err != nil {
+ return forward
+ }
+
+ noProxy := noProxyEnv.Get()
+ if len(noProxy) == 0 {
+ return proxy
+ }
+
+ perHost := NewPerHost(proxy, forward)
+ perHost.AddFromString(noProxy)
+ return perHost
+}
+
+// proxySchemes is a map from URL schemes to a function that creates a Dialer
+// from a URL with such a scheme.
+var proxySchemes map[string]func(*url.URL, Dialer) (Dialer, error)
+
+// RegisterDialerType takes a URL scheme and a function to generate Dialers from
+// a URL with that scheme and a forwarding Dialer. Registered schemes are used
+// by FromURL.
+func RegisterDialerType(scheme string, f func(*url.URL, Dialer) (Dialer, error)) {
+ if proxySchemes == nil {
+ proxySchemes = make(map[string]func(*url.URL, Dialer) (Dialer, error))
+ }
+ proxySchemes[scheme] = f
+}
+
+// FromURL returns a Dialer given a URL specification and an underlying
+// Dialer for it to make network requests.
+func FromURL(u *url.URL, forward Dialer) (Dialer, error) {
+ var auth *Auth
+ if u.User != nil {
+ auth = new(Auth)
+ auth.User = u.User.Username()
+ if p, ok := u.User.Password(); ok {
+ auth.Password = p
+ }
+ }
+
+ switch u.Scheme {
+ case "socks5", "socks5h":
+ addr := u.Hostname()
+ port := u.Port()
+ if port == "" {
+ port = "1080"
+ }
+ return SOCKS5("tcp", net.JoinHostPort(addr, port), auth, forward)
+ }
+
+ // If the scheme doesn't match any of the built-in schemes, see if it
+ // was registered by another package.
+ if proxySchemes != nil {
+ if f, ok := proxySchemes[u.Scheme]; ok {
+ return f(u, forward)
+ }
+ }
+
+ return nil, errors.New("proxy: unknown scheme: " + u.Scheme)
+}
+
+var (
+ allProxyEnv = &envOnce{
+ names: []string{"ALL_PROXY", "all_proxy"},
+ }
+ noProxyEnv = &envOnce{
+ names: []string{"NO_PROXY", "no_proxy"},
+ }
+)
+
+// envOnce looks up an environment variable (optionally by multiple
+// names) once. It mitigates expensive lookups on some platforms
+// (e.g. Windows).
+// (Borrowed from net/http/transport.go)
+type envOnce struct {
+ names []string
+ once sync.Once
+ val string
+}
+
+func (e *envOnce) Get() string {
+ e.once.Do(e.init)
+ return e.val
+}
+
+func (e *envOnce) init() {
+ for _, n := range e.names {
+ e.val = os.Getenv(n)
+ if e.val != "" {
+ return
+ }
+ }
+}
+
+// reset is used by tests
+func (e *envOnce) reset() {
+ e.once = sync.Once{}
+ e.val = ""
+}
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+
+// SOCKS5 returns a Dialer that makes SOCKSv5 connections to the given
+// address with an optional username and password.
+// See RFC 1928 and RFC 1929.
+func SOCKS5(network, address string, auth *Auth, forward Dialer) (Dialer, error) {
+ d := socks.NewDialer(network, address)
+ if forward != nil {
+ if f, ok := forward.(ContextDialer); ok {
+ d.ProxyDial = func(ctx context.Context, network string, address string) (net.Conn, error) {
+ return f.DialContext(ctx, network, address)
+ }
+ } else {
+ d.ProxyDial = func(ctx context.Context, network string, address string) (net.Conn, error) {
+ return dialContext(ctx, forward, network, address)
+ }
+ }
+ }
+ if auth != nil {
+ up := socks.UsernamePassword{
+ Username: auth.User,
+ Password: auth.Password,
+ }
+ d.AuthMethods = []socks.AuthMethod{
+ socks.AuthMethodNotRequired,
+ socks.AuthMethodUsernamePassword,
+ }
+ d.Authenticate = up.Authenticate
+ }
+ return d, nil
+}
diff --git a/src/socks.go b/src/socks.go
new file mode 100644
index 0000000..d506505
--- /dev/null
+++ b/src/socks.go
@@ -0,0 +1,473 @@
+// Package socks provides a SOCKS version 5 client implementation.
+//
+// SOCKS protocol version 5 is defined in RFC 1928.
+// Username/Password authentication for SOCKS version 5 is defined in
+// RFC 1929.
+package socks
+
+import (
+ "context"
+ "errors"
+ "io"
+ "net"
+ "strconv"
+ "time"
+)
+
+
+
+var (
+ noDeadline = time.Time{}
+ aLongTimeAgo = time.Unix(1, 0)
+)
+
+func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net.Addr, ctxErr error) {
+ host, port, err := splitHostPort(address)
+ if err != nil {
+ return nil, err
+ }
+ if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() {
+ c.SetDeadline(deadline)
+ defer c.SetDeadline(noDeadline)
+ }
+ if ctx != context.Background() {
+ errCh := make(chan error, 1)
+ done := make(chan struct{})
+ defer func() {
+ close(done)
+ if ctxErr == nil {
+ ctxErr = <-errCh
+ }
+ }()
+ go func() {
+ select {
+ case <-ctx.Done():
+ c.SetDeadline(aLongTimeAgo)
+ errCh <- ctx.Err()
+ case <-done:
+ errCh <- nil
+ }
+ }()
+ }
+
+ b := make([]byte, 0, 6+len(host)) // the size here is just an estimate
+ b = append(b, Version5)
+ if len(d.AuthMethods) == 0 || d.Authenticate == nil {
+ b = append(b, 1, byte(AuthMethodNotRequired))
+ } else {
+ ams := d.AuthMethods
+ if len(ams) > 255 {
+ return nil, errors.New("too many authentication methods")
+ }
+ b = append(b, byte(len(ams)))
+ for _, am := range ams {
+ b = append(b, byte(am))
+ }
+ }
+ if _, ctxErr = c.Write(b); ctxErr != nil {
+ return
+ }
+
+ if _, ctxErr = io.ReadFull(c, b[:2]); ctxErr != nil {
+ return
+ }
+ if b[0] != Version5 {
+ return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
+ }
+ am := AuthMethod(b[1])
+ if am == AuthMethodNoAcceptableMethods {
+ return nil, errors.New("no acceptable authentication methods")
+ }
+ if d.Authenticate != nil {
+ if ctxErr = d.Authenticate(ctx, c, am); ctxErr != nil {
+ return
+ }
+ }
+
+ b = b[:0]
+ b = append(b, Version5, byte(d.cmd), 0)
+ if ip := net.ParseIP(host); ip != nil {
+ if ip4 := ip.To4(); ip4 != nil {
+ b = append(b, AddrTypeIPv4)
+ b = append(b, ip4...)
+ } else if ip6 := ip.To16(); ip6 != nil {
+ b = append(b, AddrTypeIPv6)
+ b = append(b, ip6...)
+ } else {
+ return nil, errors.New("unknown address type")
+ }
+ } else {
+ if len(host) > 255 {
+ return nil, errors.New("FQDN too long")
+ }
+ b = append(b, AddrTypeFQDN)
+ b = append(b, byte(len(host)))
+ b = append(b, host...)
+ }
+ b = append(b, byte(port>>8), byte(port))
+ if _, ctxErr = c.Write(b); ctxErr != nil {
+ return
+ }
+
+ if _, ctxErr = io.ReadFull(c, b[:4]); ctxErr != nil {
+ return
+ }
+ if b[0] != Version5 {
+ return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
+ }
+ if cmdErr := Reply(b[1]); cmdErr != StatusSucceeded {
+ return nil, errors.New("unknown error " + cmdErr.String())
+ }
+ if b[2] != 0 {
+ return nil, errors.New("non-zero reserved field")
+ }
+ l := 2
+ var a Addr
+ switch b[3] {
+ case AddrTypeIPv4:
+ l += net.IPv4len
+ a.IP = make(net.IP, net.IPv4len)
+ case AddrTypeIPv6:
+ l += net.IPv6len
+ a.IP = make(net.IP, net.IPv6len)
+ case AddrTypeFQDN:
+ if _, err := io.ReadFull(c, b[:1]); err != nil {
+ return nil, err
+ }
+ l += int(b[0])
+ default:
+ return nil, errors.New("unknown address type " + strconv.Itoa(int(b[3])))
+ }
+ if cap(b) < l {
+ b = make([]byte, l)
+ } else {
+ b = b[:l]
+ }
+ if _, ctxErr = io.ReadFull(c, b); ctxErr != nil {
+ return
+ }
+ if a.IP != nil {
+ copy(a.IP, b)
+ } else {
+ a.Name = string(b[:len(b)-2])
+ }
+ a.Port = int(b[len(b)-2])<<8 | int(b[len(b)-1])
+ return &a, nil
+}
+
+func splitHostPort(address string) (string, int, error) {
+ host, port, err := net.SplitHostPort(address)
+ if err != nil {
+ return "", 0, err
+ }
+ portnum, err := strconv.Atoi(port)
+ if err != nil {
+ return "", 0, err
+ }
+ if 1 > portnum || portnum > 0xffff {
+ return "", 0, errors.New("port number out of range " + port)
+ }
+ return host, portnum, nil
+}
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// A Command represents a SOCKS command.
+type Command int
+
+func (cmd Command) String() string {
+ switch cmd {
+ case CmdConnect:
+ return "socks connect"
+ case cmdBind:
+ return "socks bind"
+ default:
+ return "socks " + strconv.Itoa(int(cmd))
+ }
+}
+
+// An AuthMethod represents a SOCKS authentication method.
+type AuthMethod int
+
+// A Reply represents a SOCKS command reply code.
+type Reply int
+
+func (code Reply) String() string {
+ switch code {
+ case StatusSucceeded:
+ return "succeeded"
+ case 0x01:
+ return "general SOCKS server failure"
+ case 0x02:
+ return "connection not allowed by ruleset"
+ case 0x03:
+ return "network unreachable"
+ case 0x04:
+ return "host unreachable"
+ case 0x05:
+ return "connection refused"
+ case 0x06:
+ return "TTL expired"
+ case 0x07:
+ return "command not supported"
+ case 0x08:
+ return "address type not supported"
+ default:
+ return "unknown code: " + strconv.Itoa(int(code))
+ }
+}
+
+// Wire protocol constants.
+const (
+ Version5 = 0x05
+
+ AddrTypeIPv4 = 0x01
+ AddrTypeFQDN = 0x03
+ AddrTypeIPv6 = 0x04
+
+ CmdConnect Command = 0x01 // establishes an active-open forward proxy connection
+ cmdBind Command = 0x02 // establishes a passive-open forward proxy connection
+
+ AuthMethodNotRequired AuthMethod = 0x00 // no authentication required
+ AuthMethodUsernamePassword AuthMethod = 0x02 // use username/password
+ AuthMethodNoAcceptableMethods AuthMethod = 0xff // no acceptable authentication methods
+
+ StatusSucceeded Reply = 0x00
+)
+
+// An Addr represents a SOCKS-specific address.
+// Either Name or IP is used exclusively.
+type Addr struct {
+ Name string // fully-qualified domain name
+ IP net.IP
+ Port int
+}
+
+func (a *Addr) Network() string { return "socks" }
+
+func (a *Addr) String() string {
+ if a == nil {
+ return "<nil>"
+ }
+ port := strconv.Itoa(a.Port)
+ if a.IP == nil {
+ return net.JoinHostPort(a.Name, port)
+ }
+ return net.JoinHostPort(a.IP.String(), port)
+}
+
+// A Conn represents a forward proxy connection.
+type Conn struct {
+ net.Conn
+
+ boundAddr net.Addr
+}
+
+// BoundAddr returns the address assigned by the proxy server for
+// connecting to the command target address from the proxy server.
+func (c *Conn) BoundAddr() net.Addr {
+ if c == nil {
+ return nil
+ }
+ return c.boundAddr
+}
+
+// A Dialer holds SOCKS-specific options.
+type Dialer struct {
+ cmd Command // either CmdConnect or cmdBind
+ proxyNetwork string // network between a proxy server and a client
+ proxyAddress string // proxy server address
+
+ // ProxyDial specifies the optional dial function for
+ // establishing the transport connection.
+ ProxyDial func(context.Context, string, string) (net.Conn, error)
+
+ // AuthMethods specifies the list of request authentication
+ // methods.
+ // If empty, SOCKS client requests only AuthMethodNotRequired.
+ AuthMethods []AuthMethod
+
+ // Authenticate specifies the optional authentication
+ // function. It must be non-nil when AuthMethods is not empty.
+ // It must return an error when the authentication is failed.
+ Authenticate func(context.Context, io.ReadWriter, AuthMethod) error
+}
+
+// DialContext connects to the provided address on the provided
+// network.
+//
+// The returned error value may be a net.OpError. When the Op field of
+// net.OpError contains "socks", the Source field contains a proxy
+// server address and the Addr field contains a command target
+// address.
+//
+// See func Dial of the net package of standard library for a
+// description of the network and address parameters.
+func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
+ if err := d.validateTarget(network, address); err != nil {
+ proxy, dst, _ := d.pathAddrs(address)
+ return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
+ }
+ if ctx == nil {
+ proxy, dst, _ := d.pathAddrs(address)
+ return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")}
+ }
+ var err error
+ var c net.Conn
+ if d.ProxyDial != nil {
+ c, err = d.ProxyDial(ctx, d.proxyNetwork, d.proxyAddress)
+ } else {
+ var dd net.Dialer
+ c, err = dd.DialContext(ctx, d.proxyNetwork, d.proxyAddress)
+ }
+ if err != nil {
+ proxy, dst, _ := d.pathAddrs(address)
+ return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
+ }
+ a, err := d.connect(ctx, c, address)
+ if err != nil {
+ c.Close()
+ proxy, dst, _ := d.pathAddrs(address)
+ return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
+ }
+ return &Conn{Conn: c, boundAddr: a}, nil
+}
+
+// DialWithConn initiates a connection from SOCKS server to the target
+// network and address using the connection c that is already
+// connected to the SOCKS server.
+//
+// It returns the connection's local address assigned by the SOCKS
+// server.
+func (d *Dialer) DialWithConn(ctx context.Context, c net.Conn, network, address string) (net.Addr, error) {
+ if err := d.validateTarget(network, address); err != nil {
+ proxy, dst, _ := d.pathAddrs(address)
+ return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
+ }
+ if ctx == nil {
+ proxy, dst, _ := d.pathAddrs(address)
+ return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")}
+ }
+ a, err := d.connect(ctx, c, address)
+ if err != nil {
+ proxy, dst, _ := d.pathAddrs(address)
+ return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
+ }
+ return a, nil
+}
+
+// Dial connects to the provided address on the provided network.
+//
+// Unlike DialContext, it returns a raw transport connection instead
+// of a forward proxy connection.
+//
+// Deprecated: Use DialContext or DialWithConn instead.
+func (d *Dialer) Dial(network, address string) (net.Conn, error) {
+ if err := d.validateTarget(network, address); err != nil {
+ proxy, dst, _ := d.pathAddrs(address)
+ return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
+ }
+ var err error
+ var c net.Conn
+ if d.ProxyDial != nil {
+ c, err = d.ProxyDial(context.Background(), d.proxyNetwork, d.proxyAddress)
+ } else {
+ c, err = net.Dial(d.proxyNetwork, d.proxyAddress)
+ }
+ if err != nil {
+ proxy, dst, _ := d.pathAddrs(address)
+ return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
+ }
+ if _, err := d.DialWithConn(context.Background(), c, network, address); err != nil {
+ c.Close()
+ return nil, err
+ }
+ return c, nil
+}
+
+func (d *Dialer) validateTarget(network, address string) error {
+ switch network {
+ case "tcp", "tcp6", "tcp4":
+ default:
+ return errors.New("network not implemented")
+ }
+ switch d.cmd {
+ case CmdConnect, cmdBind:
+ default:
+ return errors.New("command not implemented")
+ }
+ return nil
+}
+
+func (d *Dialer) pathAddrs(address string) (proxy, dst net.Addr, err error) {
+ for i, s := range []string{d.proxyAddress, address} {
+ host, port, err := splitHostPort(s)
+ if err != nil {
+ return nil, nil, err
+ }
+ a := &Addr{Port: port}
+ a.IP = net.ParseIP(host)
+ if a.IP == nil {
+ a.Name = host
+ }
+ if i == 0 {
+ proxy = a
+ } else {
+ dst = a
+ }
+ }
+ return
+}
+
+// NewDialer returns a new Dialer that dials through the provided
+// proxy server's network and address.
+func NewDialer(network, address string) *Dialer {
+ return &Dialer{proxyNetwork: network, proxyAddress: address, cmd: CmdConnect}
+}
+
+const (
+ authUsernamePasswordVersion = 0x01
+ authStatusSucceeded = 0x00
+)
+
+// UsernamePassword are the credentials for the username/password
+// authentication method.
+type UsernamePassword struct {
+ Username string
+ Password string
+}
+
+// Authenticate authenticates a pair of username and password with the
+// proxy server.
+func (up *UsernamePassword) Authenticate(ctx context.Context, rw io.ReadWriter, auth AuthMethod) error {
+ switch auth {
+ case AuthMethodNotRequired:
+ return nil
+ case AuthMethodUsernamePassword:
+ if len(up.Username) == 0 || len(up.Username) > 255 || len(up.Password) > 255 {
+ return errors.New("invalid username/password")
+ }
+ b := []byte{authUsernamePasswordVersion}
+ b = append(b, byte(len(up.Username)))
+ b = append(b, up.Username...)
+ b = append(b, byte(len(up.Password)))
+ b = append(b, up.Password...)
+ // TODO(mikio): handle IO deadlines and cancelation if
+ // necessary
+ if _, err := rw.Write(b); err != nil {
+ return err
+ }
+ if _, err := io.ReadFull(rw, b[:2]); err != nil {
+ return err
+ }
+ if b[0] != authUsernamePasswordVersion {
+ return errors.New("invalid username/password version")
+ }
+ if b[1] != authStatusSucceeded {
+ return errors.New("username/password authentication failed")
+ }
+ return nil
+ }
+ return errors.New("unsupported authentication method " + strconv.Itoa(int(auth)))
+}
diff --git a/src/wscat.go b/src/wscat.go
new file mode 100644
index 0000000..b9a17a6
--- /dev/null
+++ b/src/wscat.go
@@ -0,0 +1,3148 @@
+package wscat
+
+import (
+ "bufio"
+ "bytes"
+ "compress/flate"
+ "context"
+ "crypto/rand"
+ "crypto/sha1"
+ "crypto/tls"
+ "encoding/base64"
+ "encoding/binary"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "net/http/httptrace"
+ "net/url"
+ "os"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+ "unicode/utf8"
+ "unsafe"
+
+ "proxy"
+
+ g "gobang"
+)
+
+// ErrBadHandshake is returned when the server response to opening handshake is
+// invalid.
+var ErrBadHandshake = errors.New("websocket: bad handshake")
+
+var errInvalidCompression = errors.New("websocket: invalid compression negotiation")
+
+// NewClient creates a new client connection using the given net connection.
+// The URL u specifies the host and request URI. Use requestHeader to specify
+// the origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies
+// (Cookie). Use the response.Header to get the selected subprotocol
+// (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
+//
+// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
+// non-nil *http.Response so that callers can handle redirects, authentication,
+// etc.
+//
+// Deprecated: Use Dialer instead.
+func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufSize, writeBufSize int) (c *Conn, response *http.Response, err error) {
+ d := Dialer{
+ ReadBufferSize: readBufSize,
+ WriteBufferSize: writeBufSize,
+ NetDial: func(net, addr string) (net.Conn, error) {
+ return netConn, nil
+ },
+ }
+ return d.Dial(u.String(), requestHeader)
+}
+
+// A Dialer contains options for connecting to WebSocket server.
+//
+// It is safe to call Dialer's methods concurrently.
+type Dialer struct {
+ // NetDial specifies the dial function for creating TCP connections. If
+ // NetDial is nil, net.Dialer DialContext is used.
+ NetDial func(network, addr string) (net.Conn, error)
+
+ // NetDialContext specifies the dial function for creating TCP connections. If
+ // NetDialContext is nil, NetDial is used.
+ NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
+
+ // NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If
+ // NetDialTLSContext is nil, NetDialContext is used.
+ // If NetDialTLSContext is set, Dial assumes the TLS handshake is done there and
+ // TLSClientConfig is ignored.
+ NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
+
+ // Proxy specifies a function to return a proxy for a given
+ // Request. If the function returns a non-nil error, the
+ // request is aborted with the provided error.
+ // If Proxy is nil or returns a nil *URL, no proxy is used.
+ Proxy func(*http.Request) (*url.URL, error)
+
+ // TLSClientConfig specifies the TLS configuration to use with tls.Client.
+ // If nil, the default configuration is used.
+ // If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake
+ // is done there and TLSClientConfig is ignored.
+ TLSClientConfig *tls.Config
+
+ // HandshakeTimeout specifies the duration for the handshake to complete.
+ HandshakeTimeout time.Duration
+
+ // ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer
+ // size is zero, then a useful default size is used. The I/O buffer sizes
+ // do not limit the size of the messages that can be sent or received.
+ ReadBufferSize, WriteBufferSize int
+
+ // WriteBufferPool is a pool of buffers for write operations. If the value
+ // is not set, then write buffers are allocated to the connection for the
+ // lifetime of the connection.
+ //
+ // A pool is most useful when the application has a modest volume of writes
+ // across a large number of connections.
+ //
+ // Applications should use a single pool for each unique value of
+ // WriteBufferSize.
+ WriteBufferPool BufferPool
+
+ // Subprotocols specifies the client's requested subprotocols.
+ Subprotocols []string
+
+ // EnableCompression specifies if the client should attempt to negotiate
+ // per message compression (RFC 7692). Setting this value to true does not
+ // guarantee that compression will be supported. Currently only "no context
+ // takeover" modes are supported.
+ EnableCompression bool
+
+ // Jar specifies the cookie jar.
+ // If Jar is nil, cookies are not sent in requests and ignored
+ // in responses.
+ Jar http.CookieJar
+}
+
+// Dial creates a new client connection by calling DialContext with a background context.
+func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
+ return d.DialContext(context.Background(), urlStr, requestHeader)
+}
+
+var errMalformedURL = errors.New("malformed ws or wss URL")
+
+func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
+ hostPort = u.Host
+ hostNoPort = u.Host
+ if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") {
+ hostNoPort = hostNoPort[:i]
+ } else {
+ switch u.Scheme {
+ case "wss":
+ hostPort += ":443"
+ case "https":
+ hostPort += ":443"
+ default:
+ hostPort += ":80"
+ }
+ }
+ return hostPort, hostNoPort
+}
+
+// DefaultDialer is a dialer with all fields set to the default values.
+var DefaultDialer = &Dialer{
+ Proxy: http.ProxyFromEnvironment,
+ HandshakeTimeout: 45 * time.Second,
+}
+
+// nilDialer is dialer to use when receiver is nil.
+var nilDialer = *DefaultDialer
+
+// DialContext creates a new client connection. Use requestHeader to specify the
+// origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie).
+// Use the response.Header to get the selected subprotocol
+// (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
+//
+// The context will be used in the request and in the Dialer.
+//
+// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
+// non-nil *http.Response so that callers can handle redirects, authentication,
+// etcetera. The response body may not contain the entire response and does not
+// need to be closed by the application.
+func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
+ if d == nil {
+ d = &nilDialer
+ }
+
+ challengeKey, err := generateChallengeKey()
+ if err != nil {
+ return nil, nil, err
+ }
+
+ u, err := url.Parse(urlStr)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ switch u.Scheme {
+ case "ws":
+ u.Scheme = "http"
+ case "wss":
+ u.Scheme = "https"
+ default:
+ return nil, nil, errMalformedURL
+ }
+
+ if u.User != nil {
+ // User name and password are not allowed in websocket URIs.
+ return nil, nil, errMalformedURL
+ }
+
+ req := &http.Request{
+ Method: http.MethodGet,
+ URL: u,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: make(http.Header),
+ Host: u.Host,
+ }
+ req = req.WithContext(ctx)
+
+ // Set the cookies present in the cookie jar of the dialer
+ if d.Jar != nil {
+ for _, cookie := range d.Jar.Cookies(u) {
+ req.AddCookie(cookie)
+ }
+ }
+
+ // Set the request headers using the capitalization for names and values in
+ // RFC examples. Although the capitalization shouldn't matter, there are
+ // servers that depend on it. The Header.Set method is not used because the
+ // method canonicalizes the header names.
+ req.Header["Upgrade"] = []string{"websocket"}
+ req.Header["Connection"] = []string{"Upgrade"}
+ req.Header["Sec-WebSocket-Key"] = []string{challengeKey}
+ req.Header["Sec-WebSocket-Version"] = []string{"13"}
+ if len(d.Subprotocols) > 0 {
+ req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")}
+ }
+ for k, vs := range requestHeader {
+ switch {
+ case k == "Host":
+ if len(vs) > 0 {
+ req.Host = vs[0]
+ }
+ case k == "Upgrade" ||
+ k == "Connection" ||
+ k == "Sec-Websocket-Key" ||
+ k == "Sec-Websocket-Version" ||
+ k == "Sec-Websocket-Extensions" ||
+ (k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0):
+ return nil, nil, errors.New("websocket: duplicate header not allowed: " + k)
+ case k == "Sec-Websocket-Protocol":
+ req.Header["Sec-WebSocket-Protocol"] = vs
+ default:
+ req.Header[k] = vs
+ }
+ }
+
+ if d.EnableCompression {
+ req.Header["Sec-WebSocket-Extensions"] = []string{"permessage-deflate; server_no_context_takeover; client_no_context_takeover"}
+ }
+
+ if d.HandshakeTimeout != 0 {
+ var cancel func()
+ ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout)
+ defer cancel()
+ }
+
+ var netDial netDialerFunc
+ switch {
+ case u.Scheme == "https" && d.NetDialTLSContext != nil:
+ netDial = d.NetDialTLSContext
+ case d.NetDialContext != nil:
+ netDial = d.NetDialContext
+ case d.NetDial != nil:
+ netDial = func(ctx context.Context, net, addr string) (net.Conn, error) {
+ return d.NetDial(net, addr)
+ }
+ default:
+ netDial = (&net.Dialer{}).DialContext
+ }
+
+ // If needed, wrap the dial function to set the connection deadline.
+ if deadline, ok := ctx.Deadline(); ok {
+ forwardDial := netDial
+ netDial = func(ctx context.Context, network, addr string) (net.Conn, error) {
+ c, err := forwardDial(ctx, network, addr)
+ if err != nil {
+ return nil, err
+ }
+ err = c.SetDeadline(deadline)
+ if err != nil {
+ c.Close()
+ return nil, err
+ }
+ return c, nil
+ }
+ }
+
+ // If needed, wrap the dial function to connect through a proxy.
+ if d.Proxy != nil {
+ proxyURL, err := d.Proxy(req)
+ if err != nil {
+ return nil, nil, err
+ }
+ if proxyURL != nil {
+ netDial, err = proxyFromURL(proxyURL, netDial)
+ if err != nil {
+ return nil, nil, err
+ }
+ }
+ }
+
+ hostPort, hostNoPort := hostPortNoPort(u)
+ trace := httptrace.ContextClientTrace(ctx)
+ if trace != nil && trace.GetConn != nil {
+ trace.GetConn(hostPort)
+ }
+
+ netConn, err := netDial(ctx, "tcp", hostPort)
+ if err != nil {
+ return nil, nil, err
+ }
+ if trace != nil && trace.GotConn != nil {
+ trace.GotConn(httptrace.GotConnInfo{
+ Conn: netConn,
+ })
+ }
+
+ // Close the network connection when returning an error. The variable
+ // netConn is set to nil before the success return at the end of the
+ // function.
+ defer func() {
+ if netConn != nil {
+ // It's safe to ignore the error from Close() because this code is
+ // only executed when returning a more important error to the
+ // application.
+ _ = netConn.Close()
+ }
+ }()
+
+ if u.Scheme == "https" && d.NetDialTLSContext == nil {
+ // If NetDialTLSContext is set, assume that the TLS handshake has already been done
+
+ cfg := cloneTLSConfig(d.TLSClientConfig)
+ if cfg.ServerName == "" {
+ cfg.ServerName = hostNoPort
+ }
+ tlsConn := tls.Client(netConn, cfg)
+ netConn = tlsConn
+
+ if trace != nil && trace.TLSHandshakeStart != nil {
+ trace.TLSHandshakeStart()
+ }
+ err := doHandshake(ctx, tlsConn, cfg)
+ if trace != nil && trace.TLSHandshakeDone != nil {
+ trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
+ }
+
+ if err != nil {
+ return nil, nil, err
+ }
+ }
+
+ conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, nil, nil)
+
+ if err := req.Write(netConn); err != nil {
+ return nil, nil, err
+ }
+
+ if trace != nil && trace.GotFirstResponseByte != nil {
+ if peek, err := conn.br.Peek(1); err == nil && len(peek) == 1 {
+ trace.GotFirstResponseByte()
+ }
+ }
+
+ resp, err := http.ReadResponse(conn.br, req)
+ if err != nil {
+ if d.TLSClientConfig != nil {
+ for _, proto := range d.TLSClientConfig.NextProtos {
+ if proto != "http/1.1" {
+ return nil, nil, fmt.Errorf(
+ "websocket: protocol %q was given but is not supported;"+
+ "sharing tls.Config with net/http Transport can cause this error: %w",
+ proto, err,
+ )
+ }
+ }
+ }
+ return nil, nil, err
+ }
+
+ if d.Jar != nil {
+ if rc := resp.Cookies(); len(rc) > 0 {
+ d.Jar.SetCookies(u, rc)
+ }
+ }
+
+ if resp.StatusCode != 101 ||
+ !tokenListContainsValue(resp.Header, "Upgrade", "websocket") ||
+ !tokenListContainsValue(resp.Header, "Connection", "upgrade") ||
+ resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
+ // Before closing the network connection on return from this
+ // function, slurp up some of the response to aid application
+ // debugging.
+ buf := make([]byte, 1024)
+ n, _ := io.ReadFull(resp.Body, buf)
+ resp.Body = io.NopCloser(bytes.NewReader(buf[:n]))
+ return nil, resp, ErrBadHandshake
+ }
+
+ for _, ext := range parseExtensions(resp.Header) {
+ if ext[""] != "permessage-deflate" {
+ continue
+ }
+ _, snct := ext["server_no_context_takeover"]
+ _, cnct := ext["client_no_context_takeover"]
+ if !snct || !cnct {
+ return nil, resp, errInvalidCompression
+ }
+ conn.newCompressionWriter = compressNoContextTakeover
+ conn.newDecompressionReader = decompressNoContextTakeover
+ break
+ }
+
+ resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
+ conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
+
+ if err := netConn.SetDeadline(time.Time{}); err != nil {
+ return nil, resp, err
+ }
+
+ // Success! Set netConn to nil to stop the deferred function above from
+ // closing the network connection.
+ netConn = nil
+
+ return conn, resp, nil
+}
+
+func cloneTLSConfig(cfg *tls.Config) *tls.Config {
+ if cfg == nil {
+ return &tls.Config{}
+ }
+ return cfg.Clone()
+}
+
+func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error {
+ if err := tlsConn.HandshakeContext(ctx); err != nil {
+ return err
+ }
+ if !cfg.InsecureSkipVerify {
+ if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+
+const (
+ minCompressionLevel = -2 // flate.HuffmanOnly not defined in Go < 1.6
+ maxCompressionLevel = flate.BestCompression
+ defaultCompressionLevel = 1
+)
+
+var (
+ flateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool
+ flateReaderPool = sync.Pool{New: func() interface{} {
+ return flate.NewReader(nil)
+ }}
+)
+
+func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
+ const tail =
+ // Add four bytes as specified in RFC
+ "\x00\x00\xff\xff" +
+ // Add final block to squelch unexpected EOF error from flate reader.
+ "\x01\x00\x00\xff\xff"
+
+ fr, _ := flateReaderPool.Get().(io.ReadCloser)
+ mr := io.MultiReader(r, strings.NewReader(tail))
+ if err := fr.(flate.Resetter).Reset(mr, nil); err != nil {
+ // Reset never fails, but handle error in case that changes.
+ fr = flate.NewReader(mr)
+ }
+ return &flateReadWrapper{fr}
+}
+
+func isValidCompressionLevel(level int) bool {
+ return minCompressionLevel <= level && level <= maxCompressionLevel
+}
+
+func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser {
+ p := &flateWriterPools[level-minCompressionLevel]
+ tw := &truncWriter{w: w}
+ fw, _ := p.Get().(*flate.Writer)
+ if fw == nil {
+ fw, _ = flate.NewWriter(tw, level)
+ } else {
+ fw.Reset(tw)
+ }
+ return &flateWriteWrapper{fw: fw, tw: tw, p: p}
+}
+
+// truncWriter is an io.Writer that writes all but the last four bytes of the
+// stream to another io.Writer.
+type truncWriter struct {
+ w io.WriteCloser
+ n int
+ p [4]byte
+}
+
+func (w *truncWriter) Write(p []byte) (int, error) {
+ n := 0
+
+ // fill buffer first for simplicity.
+ if w.n < len(w.p) {
+ n = copy(w.p[w.n:], p)
+ p = p[n:]
+ w.n += n
+ if len(p) == 0 {
+ return n, nil
+ }
+ }
+
+ m := len(p)
+ if m > len(w.p) {
+ m = len(w.p)
+ }
+
+ if nn, err := w.w.Write(w.p[:m]); err != nil {
+ return n + nn, err
+ }
+
+ copy(w.p[:], w.p[m:])
+ copy(w.p[len(w.p)-m:], p[len(p)-m:])
+ nn, err := w.w.Write(p[:len(p)-m])
+ return n + nn, err
+}
+
+type flateWriteWrapper struct {
+ fw *flate.Writer
+ tw *truncWriter
+ p *sync.Pool
+}
+
+func (w *flateWriteWrapper) Write(p []byte) (int, error) {
+ if w.fw == nil {
+ return 0, errWriteClosed
+ }
+ return w.fw.Write(p)
+}
+
+func (w *flateWriteWrapper) Close() error {
+ if w.fw == nil {
+ return errWriteClosed
+ }
+ err1 := w.fw.Flush()
+ w.p.Put(w.fw)
+ w.fw = nil
+ if w.tw.p != [4]byte{0, 0, 0xff, 0xff} {
+ return errors.New("websocket: internal error, unexpected bytes at end of flate stream")
+ }
+ err2 := w.tw.w.Close()
+ if err1 != nil {
+ return err1
+ }
+ return err2
+}
+
+type flateReadWrapper struct {
+ fr io.ReadCloser
+}
+
+func (r *flateReadWrapper) Read(p []byte) (int, error) {
+ if r.fr == nil {
+ return 0, io.ErrClosedPipe
+ }
+ n, err := r.fr.Read(p)
+ if err == io.EOF {
+ // Preemptively place the reader back in the pool. This helps with
+ // scenarios where the application does not call NextReader() soon after
+ // this final read.
+ r.Close()
+ }
+ return n, err
+}
+
+func (r *flateReadWrapper) Close() error {
+ if r.fr == nil {
+ return io.ErrClosedPipe
+ }
+ err := r.fr.Close()
+ flateReaderPool.Put(r.fr)
+ r.fr = nil
+ return err
+}
+// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+
+const (
+ // Frame header byte 0 bits from Section 5.2 of RFC 6455
+ finalBit = 1 << 7
+ rsv1Bit = 1 << 6
+ rsv2Bit = 1 << 5
+ rsv3Bit = 1 << 4
+
+ // Frame header byte 1 bits from Section 5.2 of RFC 6455
+ maskBit = 1 << 7
+
+ maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask
+ maxControlFramePayloadSize = 125
+
+ writeWait = time.Second
+
+ defaultReadBufferSize = 4096
+ defaultWriteBufferSize = 4096
+
+ continuationFrame = 0
+ noFrame = -1
+)
+
+// Close codes defined in RFC 6455, section 11.7.
+const (
+ CloseNormalClosure = 1000
+ CloseGoingAway = 1001
+ CloseProtocolError = 1002
+ CloseUnsupportedData = 1003
+ CloseNoStatusReceived = 1005
+ CloseAbnormalClosure = 1006
+ CloseInvalidFramePayloadData = 1007
+ ClosePolicyViolation = 1008
+ CloseMessageTooBig = 1009
+ CloseMandatoryExtension = 1010
+ CloseInternalServerErr = 1011
+ CloseServiceRestart = 1012
+ CloseTryAgainLater = 1013
+ CloseTLSHandshake = 1015
+)
+
+// The message types are defined in RFC 6455, section 11.8.
+const (
+ // TextMessage denotes a text data message. The text message payload is
+ // interpreted as UTF-8 encoded text data.
+ TextMessage = 1
+
+ // BinaryMessage denotes a binary data message.
+ BinaryMessage = 2
+
+ // CloseMessage denotes a close control message. The optional message
+ // payload contains a numeric code and text. Use the FormatCloseMessage
+ // function to format a close message payload.
+ CloseMessage = 8
+
+ // PingMessage denotes a ping control message. The optional message payload
+ // is UTF-8 encoded text.
+ PingMessage = 9
+
+ // PongMessage denotes a pong control message. The optional message payload
+ // is UTF-8 encoded text.
+ PongMessage = 10
+)
+
+// ErrCloseSent is returned when the application writes a message to the
+// connection after sending a close message.
+var ErrCloseSent = errors.New("websocket: close sent")
+
+// ErrReadLimit is returned when reading a message that is larger than the
+// read limit set for the connection.
+var ErrReadLimit = errors.New("websocket: read limit exceeded")
+
+// netError satisfies the net Error interface.
+type netError struct {
+ msg string
+ temporary bool
+ timeout bool
+}
+
+func (e *netError) Error() string { return e.msg }
+func (e *netError) Temporary() bool { return e.temporary }
+func (e *netError) Timeout() bool { return e.timeout }
+
+// CloseError represents a close message.
+type CloseError struct {
+ // Code is defined in RFC 6455, section 11.7.
+ Code int
+
+ // Text is the optional text payload.
+ Text string
+}
+
+func (e *CloseError) Error() string {
+ s := []byte("websocket: close ")
+ s = strconv.AppendInt(s, int64(e.Code), 10)
+ switch e.Code {
+ case CloseNormalClosure:
+ s = append(s, " (normal)"...)
+ case CloseGoingAway:
+ s = append(s, " (going away)"...)
+ case CloseProtocolError:
+ s = append(s, " (protocol error)"...)
+ case CloseUnsupportedData:
+ s = append(s, " (unsupported data)"...)
+ case CloseNoStatusReceived:
+ s = append(s, " (no status)"...)
+ case CloseAbnormalClosure:
+ s = append(s, " (abnormal closure)"...)
+ case CloseInvalidFramePayloadData:
+ s = append(s, " (invalid payload data)"...)
+ case ClosePolicyViolation:
+ s = append(s, " (policy violation)"...)
+ case CloseMessageTooBig:
+ s = append(s, " (message too big)"...)
+ case CloseMandatoryExtension:
+ s = append(s, " (mandatory extension missing)"...)
+ case CloseInternalServerErr:
+ s = append(s, " (internal server error)"...)
+ case CloseTLSHandshake:
+ s = append(s, " (TLS handshake error)"...)
+ }
+ if e.Text != "" {
+ s = append(s, ": "...)
+ s = append(s, e.Text...)
+ }
+ return string(s)
+}
+
+// IsCloseError returns boolean indicating whether the error is a *CloseError
+// with one of the specified codes.
+func IsCloseError(err error, codes ...int) bool {
+ if e, ok := err.(*CloseError); ok {
+ for _, code := range codes {
+ if e.Code == code {
+ return true
+ }
+ }
+ }
+ return false
+}
+
+// IsUnexpectedCloseError returns boolean indicating whether the error is a
+// *CloseError with a code not in the list of expected codes.
+func IsUnexpectedCloseError(err error, expectedCodes ...int) bool {
+ if e, ok := err.(*CloseError); ok {
+ for _, code := range expectedCodes {
+ if e.Code == code {
+ return false
+ }
+ }
+ return true
+ }
+ return false
+}
+
+var (
+ errWriteTimeout = &netError{msg: "websocket: write timeout", timeout: true, temporary: true}
+ errUnexpectedEOF = &CloseError{Code: CloseAbnormalClosure, Text: io.ErrUnexpectedEOF.Error()}
+ errBadWriteOpCode = errors.New("websocket: bad write message type")
+ errWriteClosed = errors.New("websocket: write closed")
+ errInvalidControlFrame = errors.New("websocket: invalid control frame")
+)
+
+// maskRand is an io.Reader for generating mask bytes. The reader is initialized
+// to crypto/rand Reader. Tests swap the reader to a math/rand reader for
+// reproducible results.
+var maskRand = rand.Reader
+
+// newMaskKey returns a new 32 bit value for masking client frames.
+func newMaskKey() [4]byte {
+ var k [4]byte
+ _, _ = io.ReadFull(maskRand, k[:])
+ return k
+}
+
+func isControl(frameType int) bool {
+ return frameType == CloseMessage || frameType == PingMessage || frameType == PongMessage
+}
+
+func isData(frameType int) bool {
+ return frameType == TextMessage || frameType == BinaryMessage
+}
+
+var validReceivedCloseCodes = map[int]bool{
+ // see http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
+
+ CloseNormalClosure: true,
+ CloseGoingAway: true,
+ CloseProtocolError: true,
+ CloseUnsupportedData: true,
+ CloseNoStatusReceived: false,
+ CloseAbnormalClosure: false,
+ CloseInvalidFramePayloadData: true,
+ ClosePolicyViolation: true,
+ CloseMessageTooBig: true,
+ CloseMandatoryExtension: true,
+ CloseInternalServerErr: true,
+ CloseServiceRestart: true,
+ CloseTryAgainLater: true,
+ CloseTLSHandshake: false,
+}
+
+func isValidReceivedCloseCode(code int) bool {
+ return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999)
+}
+
+// BufferPool represents a pool of buffers. The *sync.Pool type satisfies this
+// interface. The type of the value stored in a pool is not specified.
+type BufferPool interface {
+ // Get gets a value from the pool or returns nil if the pool is empty.
+ Get() interface{}
+ // Put adds a value to the pool.
+ Put(interface{})
+}
+
+// writePoolData is the type added to the write buffer pool. This wrapper is
+// used to prevent applications from peeking at and depending on the values
+// added to the pool.
+type writePoolData struct{ buf []byte }
+
+// The Conn type represents a WebSocket connection.
+type Conn struct {
+ conn net.Conn
+ isServer bool
+ subprotocol string
+
+ // Write fields
+ mu chan struct{} // used as mutex to protect write to conn
+ writeBuf []byte // frame is constructed in this buffer.
+ writePool BufferPool
+ writeBufSize int
+ writeDeadline time.Time
+ writer io.WriteCloser // the current writer returned to the application
+ isWriting bool // for best-effort concurrent write detection
+
+ writeErrMu sync.Mutex
+ writeErr error
+
+ enableWriteCompression bool
+ compressionLevel int
+ newCompressionWriter func(io.WriteCloser, int) io.WriteCloser
+
+ // Read fields
+ reader io.ReadCloser // the current reader returned to the application
+ readErr error
+ br *bufio.Reader
+ // bytes remaining in current frame.
+ // set setReadRemaining to safely update this value and prevent overflow
+ readRemaining int64
+ readFinal bool // true the current message has more frames.
+ readLength int64 // Message size.
+ readLimit int64 // Maximum message size.
+ readMaskPos int
+ readMaskKey [4]byte
+ handlePong func(string) error
+ handlePing func(string) error
+ handleClose func(int, string) error
+ readErrCount int
+ messageReader *messageReader // the current low-level reader
+
+ readDecompress bool // whether last read frame had RSV1 set
+ newDecompressionReader func(io.Reader) io.ReadCloser
+}
+
+func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, writeBufferPool BufferPool, br *bufio.Reader, writeBuf []byte) *Conn {
+
+ if br == nil {
+ if readBufferSize == 0 {
+ readBufferSize = defaultReadBufferSize
+ } else if readBufferSize < maxControlFramePayloadSize {
+ // must be large enough for control frame
+ readBufferSize = maxControlFramePayloadSize
+ }
+ br = bufio.NewReaderSize(conn, readBufferSize)
+ }
+
+ if writeBufferSize <= 0 {
+ writeBufferSize = defaultWriteBufferSize
+ }
+ writeBufferSize += maxFrameHeaderSize
+
+ if writeBuf == nil && writeBufferPool == nil {
+ writeBuf = make([]byte, writeBufferSize)
+ }
+
+ mu := make(chan struct{}, 1)
+ mu <- struct{}{}
+ c := &Conn{
+ isServer: isServer,
+ br: br,
+ conn: conn,
+ mu: mu,
+ readFinal: true,
+ writeBuf: writeBuf,
+ writePool: writeBufferPool,
+ writeBufSize: writeBufferSize,
+ enableWriteCompression: true,
+ compressionLevel: defaultCompressionLevel,
+ }
+ c.SetCloseHandler(nil)
+ c.SetPingHandler(nil)
+ c.SetPongHandler(nil)
+ return c
+}
+
+// setReadRemaining tracks the number of bytes remaining on the connection. If n
+// overflows, an ErrReadLimit is returned.
+func (c *Conn) setReadRemaining(n int64) error {
+ if n < 0 {
+ return ErrReadLimit
+ }
+
+ c.readRemaining = n
+ return nil
+}
+
+// Subprotocol returns the negotiated protocol for the connection.
+func (c *Conn) Subprotocol() string {
+ return c.subprotocol
+}
+
+// Close closes the underlying network connection without sending or waiting
+// for a close message.
+func (c *Conn) Close() error {
+ return c.conn.Close()
+}
+
+// LocalAddr returns the local network address.
+func (c *Conn) LocalAddr() net.Addr {
+ return c.conn.LocalAddr()
+}
+
+// RemoteAddr returns the remote network address.
+func (c *Conn) RemoteAddr() net.Addr {
+ return c.conn.RemoteAddr()
+}
+
+// Write methods
+
+func (c *Conn) writeFatal(err error) error {
+ c.writeErrMu.Lock()
+ if c.writeErr == nil {
+ c.writeErr = err
+ }
+ c.writeErrMu.Unlock()
+ return err
+}
+
+func (c *Conn) read(n int) ([]byte, error) {
+ p, err := c.br.Peek(n)
+ if err == io.EOF {
+ err = errUnexpectedEOF
+ }
+ // Discard is guaranteed to succeed because the number of bytes to discard
+ // is less than or equal to the number of bytes buffered.
+ _, _ = c.br.Discard(len(p))
+ return p, err
+}
+
+func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error {
+ <-c.mu
+ defer func() { c.mu <- struct{}{} }()
+
+ c.writeErrMu.Lock()
+ err := c.writeErr
+ c.writeErrMu.Unlock()
+ if err != nil {
+ return err
+ }
+
+ if err := c.conn.SetWriteDeadline(deadline); err != nil {
+ return c.writeFatal(err)
+ }
+ if len(buf1) == 0 {
+ _, err = c.conn.Write(buf0)
+ } else {
+ err = c.writeBufs(buf0, buf1)
+ }
+ if err != nil {
+ return c.writeFatal(err)
+ }
+ if frameType == CloseMessage {
+ _ = c.writeFatal(ErrCloseSent)
+ }
+ return nil
+}
+
+func (c *Conn) writeBufs(bufs ...[]byte) error {
+ b := net.Buffers(bufs)
+ _, err := b.WriteTo(c.conn)
+ return err
+}
+
+// WriteControl writes a control message with the given deadline. The allowed
+// message types are CloseMessage, PingMessage and PongMessage.
+func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error {
+ if !isControl(messageType) {
+ return errBadWriteOpCode
+ }
+ if len(data) > maxControlFramePayloadSize {
+ return errInvalidControlFrame
+ }
+
+ b0 := byte(messageType) | finalBit
+ b1 := byte(len(data))
+ if !c.isServer {
+ b1 |= maskBit
+ }
+
+ buf := make([]byte, 0, maxFrameHeaderSize+maxControlFramePayloadSize)
+ buf = append(buf, b0, b1)
+
+ if c.isServer {
+ buf = append(buf, data...)
+ } else {
+ key := newMaskKey()
+ buf = append(buf, key[:]...)
+ buf = append(buf, data...)
+ maskBytes(key, 0, buf[6:])
+ }
+
+ if deadline.IsZero() {
+ // No timeout for zero time.
+ <-c.mu
+ } else {
+ d := time.Until(deadline)
+ if d < 0 {
+ return errWriteTimeout
+ }
+ select {
+ case <-c.mu:
+ default:
+ timer := time.NewTimer(d)
+ select {
+ case <-c.mu:
+ timer.Stop()
+ case <-timer.C:
+ return errWriteTimeout
+ }
+ }
+ }
+
+ defer func() { c.mu <- struct{}{} }()
+
+ c.writeErrMu.Lock()
+ err := c.writeErr
+ c.writeErrMu.Unlock()
+ if err != nil {
+ return err
+ }
+
+ if err := c.conn.SetWriteDeadline(deadline); err != nil {
+ return c.writeFatal(err)
+ }
+ if _, err = c.conn.Write(buf); err != nil {
+ return c.writeFatal(err)
+ }
+ if messageType == CloseMessage {
+ _ = c.writeFatal(ErrCloseSent)
+ }
+ return err
+}
+
+// beginMessage prepares a connection and message writer for a new message.
+func (c *Conn) beginMessage(mw *messageWriter, messageType int) error {
+ // Close previous writer if not already closed by the application. It's
+ // probably better to return an error in this situation, but we cannot
+ // change this without breaking existing applications.
+ if c.writer != nil {
+ c.writer.Close()
+ c.writer = nil
+ }
+
+ if !isControl(messageType) && !isData(messageType) {
+ return errBadWriteOpCode
+ }
+
+ c.writeErrMu.Lock()
+ err := c.writeErr
+ c.writeErrMu.Unlock()
+ if err != nil {
+ return err
+ }
+
+ mw.c = c
+ mw.frameType = messageType
+ mw.pos = maxFrameHeaderSize
+
+ if c.writeBuf == nil {
+ wpd, ok := c.writePool.Get().(writePoolData)
+ if ok {
+ c.writeBuf = wpd.buf
+ } else {
+ c.writeBuf = make([]byte, c.writeBufSize)
+ }
+ }
+ return nil
+}
+
+// NextWriter returns a writer for the next message to send. The writer's Close
+// method flushes the complete message to the network.
+//
+// There can be at most one open writer on a connection. NextWriter closes the
+// previous writer if the application has not already done so.
+//
+// All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and
+// PongMessage) are supported.
+func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
+ var mw messageWriter
+ if err := c.beginMessage(&mw, messageType); err != nil {
+ return nil, err
+ }
+ c.writer = &mw
+ if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
+ w := c.newCompressionWriter(c.writer, c.compressionLevel)
+ mw.compress = true
+ c.writer = w
+ }
+ return c.writer, nil
+}
+
+type messageWriter struct {
+ c *Conn
+ compress bool // whether next call to flushFrame should set RSV1
+ pos int // end of data in writeBuf.
+ frameType int // type of the current frame.
+ err error
+}
+
+func (w *messageWriter) endMessage(err error) error {
+ if w.err != nil {
+ return err
+ }
+ c := w.c
+ w.err = err
+ c.writer = nil
+ if c.writePool != nil {
+ c.writePool.Put(writePoolData{buf: c.writeBuf})
+ c.writeBuf = nil
+ }
+ return err
+}
+
+// flushFrame writes buffered data and extra as a frame to the network. The
+// final argument indicates that this is the last frame in the message.
+func (w *messageWriter) flushFrame(final bool, extra []byte) error {
+ c := w.c
+ length := w.pos - maxFrameHeaderSize + len(extra)
+
+ // Check for invalid control frames.
+ if isControl(w.frameType) &&
+ (!final || length > maxControlFramePayloadSize) {
+ return w.endMessage(errInvalidControlFrame)
+ }
+
+ b0 := byte(w.frameType)
+ if final {
+ b0 |= finalBit
+ }
+ if w.compress {
+ b0 |= rsv1Bit
+ }
+ w.compress = false
+
+ b1 := byte(0)
+ if !c.isServer {
+ b1 |= maskBit
+ }
+
+ // Assume that the frame starts at beginning of c.writeBuf.
+ framePos := 0
+ if c.isServer {
+ // Adjust up if mask not included in the header.
+ framePos = 4
+ }
+
+ switch {
+ case length >= 65536:
+ c.writeBuf[framePos] = b0
+ c.writeBuf[framePos+1] = b1 | 127
+ binary.BigEndian.PutUint64(c.writeBuf[framePos+2:], uint64(length))
+ case length > 125:
+ framePos += 6
+ c.writeBuf[framePos] = b0
+ c.writeBuf[framePos+1] = b1 | 126
+ binary.BigEndian.PutUint16(c.writeBuf[framePos+2:], uint16(length))
+ default:
+ framePos += 8
+ c.writeBuf[framePos] = b0
+ c.writeBuf[framePos+1] = b1 | byte(length)
+ }
+
+ if !c.isServer {
+ key := newMaskKey()
+ copy(c.writeBuf[maxFrameHeaderSize-4:], key[:])
+ maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos])
+ if len(extra) > 0 {
+ return w.endMessage(c.writeFatal(errors.New("websocket: internal error, extra used in client mode")))
+ }
+ }
+
+ // Write the buffers to the connection with best-effort detection of
+ // concurrent writes. See the concurrency section in the package
+ // documentation for more info.
+
+ if c.isWriting {
+ panic("concurrent write to websocket connection")
+ }
+ c.isWriting = true
+
+ err := c.write(w.frameType, c.writeDeadline, c.writeBuf[framePos:w.pos], extra)
+
+ if !c.isWriting {
+ panic("concurrent write to websocket connection")
+ }
+ c.isWriting = false
+
+ if err != nil {
+ return w.endMessage(err)
+ }
+
+ if final {
+ _ = w.endMessage(errWriteClosed)
+ return nil
+ }
+
+ // Setup for next frame.
+ w.pos = maxFrameHeaderSize
+ w.frameType = continuationFrame
+ return nil
+}
+
+func (w *messageWriter) ncopy(max int) (int, error) {
+ n := len(w.c.writeBuf) - w.pos
+ if n <= 0 {
+ if err := w.flushFrame(false, nil); err != nil {
+ return 0, err
+ }
+ n = len(w.c.writeBuf) - w.pos
+ }
+ if n > max {
+ n = max
+ }
+ return n, nil
+}
+
+func (w *messageWriter) Write(p []byte) (int, error) {
+ if w.err != nil {
+ return 0, w.err
+ }
+
+ if len(p) > 2*len(w.c.writeBuf) && w.c.isServer {
+ // Don't buffer large messages.
+ err := w.flushFrame(false, p)
+ if err != nil {
+ return 0, err
+ }
+ return len(p), nil
+ }
+
+ nn := len(p)
+ for len(p) > 0 {
+ n, err := w.ncopy(len(p))
+ if err != nil {
+ return 0, err
+ }
+ copy(w.c.writeBuf[w.pos:], p[:n])
+ w.pos += n
+ p = p[n:]
+ }
+ return nn, nil
+}
+
+func (w *messageWriter) WriteString(p string) (int, error) {
+ if w.err != nil {
+ return 0, w.err
+ }
+
+ nn := len(p)
+ for len(p) > 0 {
+ n, err := w.ncopy(len(p))
+ if err != nil {
+ return 0, err
+ }
+ copy(w.c.writeBuf[w.pos:], p[:n])
+ w.pos += n
+ p = p[n:]
+ }
+ return nn, nil
+}
+
+func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
+ if w.err != nil {
+ return 0, w.err
+ }
+ for {
+ if w.pos == len(w.c.writeBuf) {
+ err = w.flushFrame(false, nil)
+ if err != nil {
+ break
+ }
+ }
+ var n int
+ n, err = r.Read(w.c.writeBuf[w.pos:])
+ w.pos += n
+ nn += int64(n)
+ if err != nil {
+ if err == io.EOF {
+ err = nil
+ }
+ break
+ }
+ }
+ return nn, err
+}
+
+func (w *messageWriter) Close() error {
+ if w.err != nil {
+ return w.err
+ }
+ return w.flushFrame(true, nil)
+}
+
+// WritePreparedMessage writes prepared message into connection.
+func (c *Conn) WritePreparedMessage(pm *PreparedMessage) error {
+ frameType, frameData, err := pm.frame(prepareKey{
+ isServer: c.isServer,
+ compress: c.newCompressionWriter != nil && c.enableWriteCompression && isData(pm.messageType),
+ compressionLevel: c.compressionLevel,
+ })
+ if err != nil {
+ return err
+ }
+ if c.isWriting {
+ panic("concurrent write to websocket connection")
+ }
+ c.isWriting = true
+ err = c.write(frameType, c.writeDeadline, frameData, nil)
+ if !c.isWriting {
+ panic("concurrent write to websocket connection")
+ }
+ c.isWriting = false
+ return err
+}
+
+// WriteMessage is a helper method for getting a writer using NextWriter,
+// writing the message and closing the writer.
+func (c *Conn) WriteMessage(messageType int, data []byte) error {
+
+ if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) {
+ // Fast path with no allocations and single frame.
+
+ var mw messageWriter
+ if err := c.beginMessage(&mw, messageType); err != nil {
+ return err
+ }
+ n := copy(c.writeBuf[mw.pos:], data)
+ mw.pos += n
+ data = data[n:]
+ return mw.flushFrame(true, data)
+ }
+
+ w, err := c.NextWriter(messageType)
+ if err != nil {
+ return err
+ }
+ if _, err = w.Write(data); err != nil {
+ return err
+ }
+ return w.Close()
+}
+
+// SetWriteDeadline sets the write deadline on the underlying network
+// connection. After a write has timed out, the websocket state is corrupt and
+// all future writes will return an error. A zero value for t means writes will
+// not time out.
+func (c *Conn) SetWriteDeadline(t time.Time) error {
+ c.writeDeadline = t
+ return nil
+}
+
+// Read methods
+
+func (c *Conn) advanceFrame() (int, error) {
+ // 1. Skip remainder of previous frame.
+
+ if c.readRemaining > 0 {
+ if _, err := io.CopyN(io.Discard, c.br, c.readRemaining); err != nil {
+ return noFrame, err
+ }
+ }
+
+ // 2. Read and parse first two bytes of frame header.
+ // To aid debugging, collect and report all errors in the first two bytes
+ // of the header.
+
+ var errors []string
+
+ p, err := c.read(2)
+ if err != nil {
+ return noFrame, err
+ }
+
+ frameType := int(p[0] & 0xf)
+ final := p[0]&finalBit != 0
+ rsv1 := p[0]&rsv1Bit != 0
+ rsv2 := p[0]&rsv2Bit != 0
+ rsv3 := p[0]&rsv3Bit != 0
+ mask := p[1]&maskBit != 0
+ _ = c.setReadRemaining(int64(p[1] & 0x7f)) // will not fail because argument is >= 0
+
+ c.readDecompress = false
+ if rsv1 {
+ if c.newDecompressionReader != nil {
+ c.readDecompress = true
+ } else {
+ errors = append(errors, "RSV1 set")
+ }
+ }
+
+ if rsv2 {
+ errors = append(errors, "RSV2 set")
+ }
+
+ if rsv3 {
+ errors = append(errors, "RSV3 set")
+ }
+
+ switch frameType {
+ case CloseMessage, PingMessage, PongMessage:
+ if c.readRemaining > maxControlFramePayloadSize {
+ errors = append(errors, "len > 125 for control")
+ }
+ if !final {
+ errors = append(errors, "FIN not set on control")
+ }
+ case TextMessage, BinaryMessage:
+ if !c.readFinal {
+ errors = append(errors, "data before FIN")
+ }
+ c.readFinal = final
+ case continuationFrame:
+ if c.readFinal {
+ errors = append(errors, "continuation after FIN")
+ }
+ c.readFinal = final
+ default:
+ errors = append(errors, "bad opcode "+strconv.Itoa(frameType))
+ }
+
+ if mask != c.isServer {
+ errors = append(errors, "bad MASK")
+ }
+
+ if len(errors) > 0 {
+ return noFrame, c.handleProtocolError(strings.Join(errors, ", "))
+ }
+
+ // 3. Read and parse frame length as per
+ // https://tools.ietf.org/html/rfc6455#section-5.2
+ //
+ // The length of the "Payload data", in bytes: if 0-125, that is the payload
+ // length.
+ // - If 126, the following 2 bytes interpreted as a 16-bit unsigned
+ // integer are the payload length.
+ // - If 127, the following 8 bytes interpreted as
+ // a 64-bit unsigned integer (the most significant bit MUST be 0) are the
+ // payload length. Multibyte length quantities are expressed in network byte
+ // order.
+
+ switch c.readRemaining {
+ case 126:
+ p, err := c.read(2)
+ if err != nil {
+ return noFrame, err
+ }
+
+ if err := c.setReadRemaining(int64(binary.BigEndian.Uint16(p))); err != nil {
+ return noFrame, err
+ }
+ case 127:
+ p, err := c.read(8)
+ if err != nil {
+ return noFrame, err
+ }
+
+ if err := c.setReadRemaining(int64(binary.BigEndian.Uint64(p))); err != nil {
+ return noFrame, err
+ }
+ }
+
+ // 4. Handle frame masking.
+
+ if mask {
+ c.readMaskPos = 0
+ p, err := c.read(len(c.readMaskKey))
+ if err != nil {
+ return noFrame, err
+ }
+ copy(c.readMaskKey[:], p)
+ }
+
+ // 5. For text and binary messages, enforce read limit and return.
+
+ if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage {
+
+ c.readLength += c.readRemaining
+ // Don't allow readLength to overflow in the presence of a large readRemaining
+ // counter.
+ if c.readLength < 0 {
+ return noFrame, ErrReadLimit
+ }
+
+ if c.readLimit > 0 && c.readLength > c.readLimit {
+ // Make a best effort to send a close message describing the problem.
+ _ = c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
+ return noFrame, ErrReadLimit
+ }
+
+ return frameType, nil
+ }
+
+ // 6. Read control frame payload.
+
+ var payload []byte
+ if c.readRemaining > 0 {
+ payload, err = c.read(int(c.readRemaining))
+ _ = c.setReadRemaining(0) // will not fail because argument is >= 0
+ if err != nil {
+ return noFrame, err
+ }
+ if c.isServer {
+ maskBytes(c.readMaskKey, 0, payload)
+ }
+ }
+
+ // 7. Process control frame payload.
+
+ switch frameType {
+ case PongMessage:
+ if err := c.handlePong(string(payload)); err != nil {
+ return noFrame, err
+ }
+ case PingMessage:
+ if err := c.handlePing(string(payload)); err != nil {
+ return noFrame, err
+ }
+ case CloseMessage:
+ closeCode := CloseNoStatusReceived
+ closeText := ""
+ if len(payload) >= 2 {
+ closeCode = int(binary.BigEndian.Uint16(payload))
+ if !isValidReceivedCloseCode(closeCode) {
+ return noFrame, c.handleProtocolError("bad close code " + strconv.Itoa(closeCode))
+ }
+ closeText = string(payload[2:])
+ if !utf8.ValidString(closeText) {
+ return noFrame, c.handleProtocolError("invalid utf8 payload in close frame")
+ }
+ }
+ if err := c.handleClose(closeCode, closeText); err != nil {
+ return noFrame, err
+ }
+ return noFrame, &CloseError{Code: closeCode, Text: closeText}
+ }
+
+ return frameType, nil
+}
+
+func (c *Conn) handleProtocolError(message string) error {
+ data := FormatCloseMessage(CloseProtocolError, message)
+ if len(data) > maxControlFramePayloadSize {
+ data = data[:maxControlFramePayloadSize]
+ }
+ // Make a best effor to send a close message describing the problem.
+ _ = c.WriteControl(CloseMessage, data, time.Now().Add(writeWait))
+ return errors.New("websocket: " + message)
+}
+
+// NextReader returns the next data message received from the peer. The
+// returned messageType is either TextMessage or BinaryMessage.
+//
+// There can be at most one open reader on a connection. NextReader discards
+// the previous message if the application has not already consumed it.
+//
+// Applications must break out of the application's read loop when this method
+// returns a non-nil error value. Errors returned from this method are
+// permanent. Once this method returns a non-nil error, all subsequent calls to
+// this method return the same error.
+func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
+ // Close previous reader, only relevant for decompression.
+ if c.reader != nil {
+ c.reader.Close()
+ c.reader = nil
+ }
+
+ c.messageReader = nil
+ c.readLength = 0
+
+ for c.readErr == nil {
+ frameType, err := c.advanceFrame()
+ if err != nil {
+ c.readErr = err
+ break
+ }
+
+ if frameType == TextMessage || frameType == BinaryMessage {
+ c.messageReader = &messageReader{c}
+ c.reader = c.messageReader
+ if c.readDecompress {
+ c.reader = c.newDecompressionReader(c.reader)
+ }
+ return frameType, c.reader, nil
+ }
+ }
+
+ // Applications that do handle the error returned from this method spin in
+ // tight loop on connection failure. To help application developers detect
+ // this error, panic on repeated reads to the failed connection.
+ c.readErrCount++
+ if c.readErrCount >= 1000 {
+ panic("repeated read on failed websocket connection")
+ }
+
+ return noFrame, nil, c.readErr
+}
+
+type messageReader struct{ c *Conn }
+
+func (r *messageReader) Read(b []byte) (int, error) {
+ c := r.c
+ if c.messageReader != r {
+ return 0, io.EOF
+ }
+
+ for c.readErr == nil {
+
+ if c.readRemaining > 0 {
+ if int64(len(b)) > c.readRemaining {
+ b = b[:c.readRemaining]
+ }
+ n, err := c.br.Read(b)
+ c.readErr = err
+ if c.isServer {
+ c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
+ }
+ rem := c.readRemaining
+ rem -= int64(n)
+ _ = c.setReadRemaining(rem) // rem is guaranteed to be >= 0
+ if c.readRemaining > 0 && c.readErr == io.EOF {
+ c.readErr = errUnexpectedEOF
+ }
+ return n, c.readErr
+ }
+
+ if c.readFinal {
+ c.messageReader = nil
+ return 0, io.EOF
+ }
+
+ frameType, err := c.advanceFrame()
+ switch {
+ case err != nil:
+ c.readErr = err
+ case frameType == TextMessage || frameType == BinaryMessage:
+ c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
+ }
+ }
+
+ err := c.readErr
+ if err == io.EOF && c.messageReader == r {
+ err = errUnexpectedEOF
+ }
+ return 0, err
+}
+
+func (r *messageReader) Close() error {
+ return nil
+}
+
+// ReadMessage is a helper method for getting a reader using NextReader and
+// reading from that reader to a buffer.
+func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
+ var r io.Reader
+ messageType, r, err = c.NextReader()
+ if err != nil {
+ return messageType, nil, err
+ }
+ p, err = io.ReadAll(r)
+ return messageType, p, err
+}
+
+// SetReadDeadline sets the read deadline on the underlying network connection.
+// After a read has timed out, the websocket connection state is corrupt and
+// all future reads will return an error. A zero value for t means reads will
+// not time out.
+func (c *Conn) SetReadDeadline(t time.Time) error {
+ return c.conn.SetReadDeadline(t)
+}
+
+// SetReadLimit sets the maximum size in bytes for a message read from the peer. If a
+// message exceeds the limit, the connection sends a close message to the peer
+// and returns ErrReadLimit to the application.
+func (c *Conn) SetReadLimit(limit int64) {
+ c.readLimit = limit
+}
+
+// CloseHandler returns the current close handler
+func (c *Conn) CloseHandler() func(code int, text string) error {
+ return c.handleClose
+}
+
+// SetCloseHandler sets the handler for close messages received from the peer.
+// The code argument to h is the received close code or CloseNoStatusReceived
+// if the close message is empty. The default close handler sends a close
+// message back to the peer.
+//
+// The handler function is called from the NextReader, ReadMessage and message
+// reader Read methods. The application must read the connection to process
+// close messages as described in the section on Control Messages above.
+//
+// The connection read methods return a CloseError when a close message is
+// received. Most applications should handle close messages as part of their
+// normal error handling. Applications should only set a close handler when the
+// application must perform some action before sending a close message back to
+// the peer.
+func (c *Conn) SetCloseHandler(h func(code int, text string) error) {
+ if h == nil {
+ h = func(code int, text string) error {
+ message := FormatCloseMessage(code, "")
+ // Make a best effor to send the close message.
+ _ = c.WriteControl(CloseMessage, message, time.Now().Add(writeWait))
+ return nil
+ }
+ }
+ c.handleClose = h
+}
+
+// PingHandler returns the current ping handler
+func (c *Conn) PingHandler() func(appData string) error {
+ return c.handlePing
+}
+
+// SetPingHandler sets the handler for ping messages received from the peer.
+// The appData argument to h is the PING message application data. The default
+// ping handler sends a pong to the peer.
+//
+// The handler function is called from the NextReader, ReadMessage and message
+// reader Read methods. The application must read the connection to process
+// ping messages as described in the section on Control Messages above.
+func (c *Conn) SetPingHandler(h func(appData string) error) {
+ if h == nil {
+ h = func(message string) error {
+ // Make a best effort to send the pong message.
+ _ = c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait))
+ return nil
+ }
+ }
+ c.handlePing = h
+}
+
+// PongHandler returns the current pong handler
+func (c *Conn) PongHandler() func(appData string) error {
+ return c.handlePong
+}
+
+// SetPongHandler sets the handler for pong messages received from the peer.
+// The appData argument to h is the PONG message application data. The default
+// pong handler does nothing.
+//
+// The handler function is called from the NextReader, ReadMessage and message
+// reader Read methods. The application must read the connection to process
+// pong messages as described in the section on Control Messages above.
+func (c *Conn) SetPongHandler(h func(appData string) error) {
+ if h == nil {
+ h = func(string) error { return nil }
+ }
+ c.handlePong = h
+}
+
+// NetConn returns the underlying connection that is wrapped by c.
+// Note that writing to or reading from this connection directly will corrupt the
+// WebSocket connection.
+func (c *Conn) NetConn() net.Conn {
+ return c.conn
+}
+
+// UnderlyingConn returns the internal net.Conn. This can be used to further
+// modifications to connection specific flags.
+// Deprecated: Use the NetConn method.
+func (c *Conn) UnderlyingConn() net.Conn {
+ return c.conn
+}
+
+// EnableWriteCompression enables and disables write compression of
+// subsequent text and binary messages. This function is a noop if
+// compression was not negotiated with the peer.
+func (c *Conn) EnableWriteCompression(enable bool) {
+ c.enableWriteCompression = enable
+}
+
+// SetCompressionLevel sets the flate compression level for subsequent text and
+// binary messages. This function is a noop if compression was not negotiated
+// with the peer. See the compress/flate package for a description of
+// compression levels.
+func (c *Conn) SetCompressionLevel(level int) error {
+ if !isValidCompressionLevel(level) {
+ return errors.New("websocket: invalid compression level")
+ }
+ c.compressionLevel = level
+ return nil
+}
+
+// FormatCloseMessage formats closeCode and text as a WebSocket close message.
+// An empty message is returned for code CloseNoStatusReceived.
+func FormatCloseMessage(closeCode int, text string) []byte {
+ if closeCode == CloseNoStatusReceived {
+ // Return empty message because it's illegal to send
+ // CloseNoStatusReceived. Return non-nil value in case application
+ // checks for nil.
+ return []byte{}
+ }
+ buf := make([]byte, 2+len(text))
+ binary.BigEndian.PutUint16(buf, uint16(closeCode))
+ copy(buf[2:], text)
+ return buf
+}
+// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package websocket implements the WebSocket protocol defined in RFC 6455.
+//
+// Overview
+//
+// The Conn type represents a WebSocket connection. A server application calls
+// the Upgrader.Upgrade method from an HTTP request handler to get a *Conn:
+//
+// var upgrader = websocket.Upgrader{
+// ReadBufferSize: 1024,
+// WriteBufferSize: 1024,
+// }
+//
+// func handler(w http.ResponseWriter, r *http.Request) {
+// conn, err := upgrader.Upgrade(w, r, nil)
+// if err != nil {
+// log.Println(err)
+// return
+// }
+// ... Use conn to send and receive messages.
+// }
+//
+// Call the connection's WriteMessage and ReadMessage methods to send and
+// receive messages as a slice of bytes. This snippet of code shows how to echo
+// messages using these methods:
+//
+// for {
+// messageType, p, err := conn.ReadMessage()
+// if err != nil {
+// log.Println(err)
+// return
+// }
+// if err := conn.WriteMessage(messageType, p); err != nil {
+// log.Println(err)
+// return
+// }
+// }
+//
+// In above snippet of code, p is a []byte and messageType is an int with value
+// websocket.BinaryMessage or websocket.TextMessage.
+//
+// An application can also send and receive messages using the io.WriteCloser
+// and io.Reader interfaces. To send a message, call the connection NextWriter
+// method to get an io.WriteCloser, write the message to the writer and close
+// the writer when done. To receive a message, call the connection NextReader
+// method to get an io.Reader and read until io.EOF is returned. This snippet
+// shows how to echo messages using the NextWriter and NextReader methods:
+//
+// for {
+// messageType, r, err := conn.NextReader()
+// if err != nil {
+// return
+// }
+// w, err := conn.NextWriter(messageType)
+// if err != nil {
+// return err
+// }
+// if _, err := io.Copy(w, r); err != nil {
+// return err
+// }
+// if err := w.Close(); err != nil {
+// return err
+// }
+// }
+//
+// Data Messages
+//
+// The WebSocket protocol distinguishes between text and binary data messages.
+// Text messages are interpreted as UTF-8 encoded text. The interpretation of
+// binary messages is left to the application.
+//
+// This package uses the TextMessage and BinaryMessage integer constants to
+// identify the two data message types. The ReadMessage and NextReader methods
+// return the type of the received message. The messageType argument to the
+// WriteMessage and NextWriter methods specifies the type of a sent message.
+//
+// It is the application's responsibility to ensure that text messages are
+// valid UTF-8 encoded text.
+//
+// Control Messages
+//
+// The WebSocket protocol defines three types of control messages: close, ping
+// and pong. Call the connection WriteControl, WriteMessage or NextWriter
+// methods to send a control message to the peer.
+//
+// Connections handle received close messages by calling the handler function
+// set with the SetCloseHandler method and by returning a *CloseError from the
+// NextReader, ReadMessage or the message Read method. The default close
+// handler sends a close message to the peer.
+//
+// Connections handle received ping messages by calling the handler function
+// set with the SetPingHandler method. The default ping handler sends a pong
+// message to the peer.
+//
+// Connections handle received pong messages by calling the handler function
+// set with the SetPongHandler method. The default pong handler does nothing.
+// If an application sends ping messages, then the application should set a
+// pong handler to receive the corresponding pong.
+//
+// The control message handler functions are called from the NextReader,
+// ReadMessage and message reader Read methods. The default close and ping
+// handlers can block these methods for a short time when the handler writes to
+// the connection.
+//
+// The application must read the connection to process close, ping and pong
+// messages sent from the peer. If the application is not otherwise interested
+// in messages from the peer, then the application should start a goroutine to
+// read and discard messages from the peer. A simple example is:
+//
+// func readLoop(c *websocket.Conn) {
+// for {
+// if _, _, err := c.NextReader(); err != nil {
+// c.Close()
+// break
+// }
+// }
+// }
+//
+// Concurrency
+//
+// Connections support one concurrent reader and one concurrent writer.
+//
+// Applications are responsible for ensuring that no more than one goroutine
+// calls the write methods (NextWriter, SetWriteDeadline, WriteMessage,
+// WriteJSON, EnableWriteCompression, SetCompressionLevel) concurrently and
+// that no more than one goroutine calls the read methods (NextReader,
+// SetReadDeadline, ReadMessage, ReadJSON, SetPongHandler, SetPingHandler)
+// concurrently.
+//
+// The Close and WriteControl methods can be called concurrently with all other
+// methods.
+//
+// Origin Considerations
+//
+// Web browsers allow Javascript applications to open a WebSocket connection to
+// any host. It's up to the server to enforce an origin policy using the Origin
+// request header sent by the browser.
+//
+// The Upgrader calls the function specified in the CheckOrigin field to check
+// the origin. If the CheckOrigin function returns false, then the Upgrade
+// method fails the WebSocket handshake with HTTP status 403.
+//
+// If the CheckOrigin field is nil, then the Upgrader uses a safe default: fail
+// the handshake if the Origin request header is present and the Origin host is
+// not equal to the Host request header.
+//
+// The deprecated package-level Upgrade function does not perform origin
+// checking. The application is responsible for checking the Origin header
+// before calling the Upgrade function.
+//
+// Buffers
+//
+// Connections buffer network input and output to reduce the number
+// of system calls when reading or writing messages.
+//
+// Write buffers are also used for constructing WebSocket frames. See RFC 6455,
+// Section 5 for a discussion of message framing. A WebSocket frame header is
+// written to the network each time a write buffer is flushed to the network.
+// Decreasing the size of the write buffer can increase the amount of framing
+// overhead on the connection.
+//
+// The buffer sizes in bytes are specified by the ReadBufferSize and
+// WriteBufferSize fields in the Dialer and Upgrader. The Dialer uses a default
+// size of 4096 when a buffer size field is set to zero. The Upgrader reuses
+// buffers created by the HTTP server when a buffer size field is set to zero.
+// The HTTP server buffers have a size of 4096 at the time of this writing.
+//
+// The buffer sizes do not limit the size of a message that can be read or
+// written by a connection.
+//
+// Buffers are held for the lifetime of the connection by default. If the
+// Dialer or Upgrader WriteBufferPool field is set, then a connection holds the
+// write buffer only when writing a message.
+//
+// Applications should tune the buffer sizes to balance memory use and
+// performance. Increasing the buffer size uses more memory, but can reduce the
+// number of system calls to read or write the network. In the case of writing,
+// increasing the buffer size can reduce the number of frame headers written to
+// the network.
+//
+// Some guidelines for setting buffer parameters are:
+//
+// Limit the buffer sizes to the maximum expected message size. Buffers larger
+// than the largest message do not provide any benefit.
+//
+// Depending on the distribution of message sizes, setting the buffer size to
+// a value less than the maximum expected message size can greatly reduce memory
+// use with a small impact on performance. Here's an example: If 99% of the
+// messages are smaller than 256 bytes and the maximum message size is 512
+// bytes, then a buffer size of 256 bytes will result in 1.01 more system calls
+// than a buffer size of 512 bytes. The memory savings is 50%.
+//
+// A write buffer pool is useful when the application has a modest number
+// writes over a large number of connections. when buffers are pooled, a larger
+// buffer size has a reduced impact on total memory use and has the benefit of
+// reducing system calls and frame overhead.
+//
+// Compression EXPERIMENTAL
+//
+// Per message compression extensions (RFC 7692) are experimentally supported
+// by this package in a limited capacity. Setting the EnableCompression option
+// to true in Dialer or Upgrader will attempt to negotiate per message deflate
+// support.
+//
+// var upgrader = websocket.Upgrader{
+// EnableCompression: true,
+// }
+//
+// If compression was successfully negotiated with the connection's peer, any
+// message received in compressed form will be automatically decompressed.
+// All Read methods will return uncompressed bytes.
+//
+// Per message compression of messages written to a connection can be enabled
+// or disabled by calling the corresponding Conn method:
+//
+// conn.EnableWriteCompression(false)
+//
+// Currently this package does not support compression with "context takeover".
+// This means that messages must be compressed and decompressed in isolation,
+// without retaining sliding window or dictionary state across messages. For
+// more details refer to RFC 7692.
+//
+// Use of compression is experimental and may result in decreased performance.
+
+// Copyright 2019 The Gorilla WebSocket Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+
+// JoinMessages concatenates received messages to create a single io.Reader.
+// The string term is appended to each message. The returned reader does not
+// support concurrent calls to the Read method.
+func JoinMessages(c *Conn, term string) io.Reader {
+ return &joinReader{c: c, term: term}
+}
+
+type joinReader struct {
+ c *Conn
+ term string
+ r io.Reader
+}
+
+func (r *joinReader) Read(p []byte) (int, error) {
+ if r.r == nil {
+ var err error
+ _, r.r, err = r.c.NextReader()
+ if err != nil {
+ return 0, err
+ }
+ if r.term != "" {
+ r.r = io.MultiReader(r.r, strings.NewReader(r.term))
+ }
+ }
+ n, err := r.r.Read(p)
+ if err == io.EOF {
+ err = nil
+ r.r = nil
+ }
+ return n, err
+}
+// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+
+// WriteJSON writes the JSON encoding of v as a message.
+//
+// Deprecated: Use c.WriteJSON instead.
+func WriteJSON(c *Conn, v interface{}) error {
+ return c.WriteJSON(v)
+}
+
+// WriteJSON writes the JSON encoding of v as a message.
+//
+// See the documentation for encoding/json Marshal for details about the
+// conversion of Go values to JSON.
+func (c *Conn) WriteJSON(v interface{}) error {
+ w, err := c.NextWriter(TextMessage)
+ if err != nil {
+ return err
+ }
+ err1 := json.NewEncoder(w).Encode(v)
+ err2 := w.Close()
+ if err1 != nil {
+ return err1
+ }
+ return err2
+}
+
+// ReadJSON reads the next JSON-encoded message from the connection and stores
+// it in the value pointed to by v.
+//
+// Deprecated: Use c.ReadJSON instead.
+func ReadJSON(c *Conn, v interface{}) error {
+ return c.ReadJSON(v)
+}
+
+// ReadJSON reads the next JSON-encoded message from the connection and stores
+// it in the value pointed to by v.
+//
+// See the documentation for the encoding/json Unmarshal function for details
+// about the conversion of JSON to a Go value.
+func (c *Conn) ReadJSON(v interface{}) error {
+ _, r, err := c.NextReader()
+ if err != nil {
+ return err
+ }
+ err = json.NewDecoder(r).Decode(v)
+ if err == io.EOF {
+ // One value is expected in the message.
+ err = io.ErrUnexpectedEOF
+ }
+ return err
+}
+// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of
+// this source code is governed by a BSD-style license that can be found in the
+// LICENSE file.
+
+/// go:build !appengine // +build !appengine
+
+
+const wordSize = int(unsafe.Sizeof(uintptr(0)))
+
+func maskBytesUnsafe(key [4]byte, pos int, b []byte) int {
+ // Mask one byte at a time for small buffers.
+ if len(b) < 2*wordSize {
+ for i := range b {
+ b[i] ^= key[pos&3]
+ pos++
+ }
+ return pos & 3
+ }
+
+ // Mask one byte at a time to word boundary.
+ if n := int(uintptr(unsafe.Pointer(&b[0]))) % wordSize; n != 0 {
+ n = wordSize - n
+ for i := range b[:n] {
+ b[i] ^= key[pos&3]
+ pos++
+ }
+ b = b[n:]
+ }
+
+ // Create aligned word size key.
+ var k [wordSize]byte
+ for i := range k {
+ k[i] = key[(pos+i)&3]
+ }
+ kw := *(*uintptr)(unsafe.Pointer(&k))
+
+ // Mask one word at a time.
+ n := (len(b) / wordSize) * wordSize
+ for i := 0; i < n; i += wordSize {
+ *(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&b[0])) + uintptr(i))) ^= kw
+ }
+
+ // Mask one byte at a time for remaining bytes.
+ b = b[n:]
+ for i := range b {
+ b[i] ^= key[pos&3]
+ pos++
+ }
+
+ return pos & 3
+}
+// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of
+// this source code is governed by a BSD-style license that can be found in the
+// LICENSE file.
+
+/// go:build appengine // +build appengine
+
+
+func maskBytes(key [4]byte, pos int, b []byte) int {
+ for i := range b {
+ b[i] ^= key[pos&3]
+ pos++
+ }
+ return pos & 3
+}
+// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+
+// PreparedMessage caches on the wire representations of a message payload.
+// Use PreparedMessage to efficiently send a message payload to multiple
+// connections. PreparedMessage is especially useful when compression is used
+// because the CPU and memory expensive compression operation can be executed
+// once for a given set of compression options.
+type PreparedMessage struct {
+ messageType int
+ data []byte
+ mu sync.Mutex
+ frames map[prepareKey]*preparedFrame
+}
+
+// prepareKey defines a unique set of options to cache prepared frames in PreparedMessage.
+type prepareKey struct {
+ isServer bool
+ compress bool
+ compressionLevel int
+}
+
+// preparedFrame contains data in wire representation.
+type preparedFrame struct {
+ once sync.Once
+ data []byte
+}
+
+// NewPreparedMessage returns an initialized PreparedMessage. You can then send
+// it to connection using WritePreparedMessage method. Valid wire
+// representation will be calculated lazily only once for a set of current
+// connection options.
+func NewPreparedMessage(messageType int, data []byte) (*PreparedMessage, error) {
+ pm := &PreparedMessage{
+ messageType: messageType,
+ frames: make(map[prepareKey]*preparedFrame),
+ data: data,
+ }
+
+ // Prepare a plain server frame.
+ _, frameData, err := pm.frame(prepareKey{isServer: true, compress: false})
+ if err != nil {
+ return nil, err
+ }
+
+ // To protect against caller modifying the data argument, remember the data
+ // copied to the plain server frame.
+ pm.data = frameData[len(frameData)-len(data):]
+ return pm, nil
+}
+
+func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) {
+ pm.mu.Lock()
+ frame, ok := pm.frames[key]
+ if !ok {
+ frame = &preparedFrame{}
+ pm.frames[key] = frame
+ }
+ pm.mu.Unlock()
+
+ var err error
+ frame.once.Do(func() {
+ // Prepare a frame using a 'fake' connection.
+ // TODO: Refactor code in conn.go to allow more direct construction of
+ // the frame.
+ mu := make(chan struct{}, 1)
+ mu <- struct{}{}
+ var nc prepareConn
+ c := &Conn{
+ conn: &nc,
+ mu: mu,
+ isServer: key.isServer,
+ compressionLevel: key.compressionLevel,
+ enableWriteCompression: true,
+ writeBuf: make([]byte, defaultWriteBufferSize+maxFrameHeaderSize),
+ }
+ if key.compress {
+ c.newCompressionWriter = compressNoContextTakeover
+ }
+ err = c.WriteMessage(pm.messageType, pm.data)
+ frame.data = nc.buf.Bytes()
+ })
+ return pm.messageType, frame.data, err
+}
+
+type prepareConn struct {
+ buf bytes.Buffer
+ net.Conn
+}
+
+func (pc *prepareConn) Write(p []byte) (int, error) { return pc.buf.Write(p) }
+func (pc *prepareConn) SetWriteDeadline(t time.Time) error { return nil }
+// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+
+type netDialerFunc func(ctx context.Context, network, addr string) (net.Conn, error)
+
+func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
+ return fn(context.Background(), network, addr)
+}
+
+func (fn netDialerFunc) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
+ return fn(ctx, network, addr)
+}
+
+func proxyFromURL(proxyURL *url.URL, forwardDial netDialerFunc) (netDialerFunc, error) {
+ if proxyURL.Scheme == "http" {
+ return (&httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDial}).DialContext, nil
+ }
+ dialer, err := proxy.FromURL(proxyURL, forwardDial)
+ if err != nil {
+ return nil, err
+ }
+ if d, ok := dialer.(proxy.ContextDialer); ok {
+ return d.DialContext, nil
+ }
+ return func(ctx context.Context, net, addr string) (net.Conn, error) {
+ return dialer.Dial(net, addr)
+ }, nil
+}
+
+type httpProxyDialer struct {
+ proxyURL *url.URL
+ forwardDial netDialerFunc
+}
+
+func (hpd *httpProxyDialer) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
+ hostPort, _ := hostPortNoPort(hpd.proxyURL)
+ conn, err := hpd.forwardDial(ctx, network, hostPort)
+ if err != nil {
+ return nil, err
+ }
+
+ connectHeader := make(http.Header)
+ if user := hpd.proxyURL.User; user != nil {
+ proxyUser := user.Username()
+ if proxyPassword, passwordSet := user.Password(); passwordSet {
+ credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword))
+ connectHeader.Set("Proxy-Authorization", "Basic "+credential)
+ }
+ }
+
+ connectReq := &http.Request{
+ Method: http.MethodConnect,
+ URL: &url.URL{Opaque: addr},
+ Host: addr,
+ Header: connectHeader,
+ }
+
+ if err := connectReq.Write(conn); err != nil {
+ conn.Close()
+ return nil, err
+ }
+
+ // Read response. It's OK to use and discard buffered reader here because
+ // the remote server does not speak until spoken to.
+ br := bufio.NewReader(conn)
+ resp, err := http.ReadResponse(br, connectReq)
+ if err != nil {
+ conn.Close()
+ return nil, err
+ }
+
+ // Close the response body to silence false positives from linters. Reset
+ // the buffered reader first to ensure that Close() does not read from
+ // conn.
+ // Note: Applications must call resp.Body.Close() on a response returned
+ // http.ReadResponse to inspect trailers or read another response from the
+ // buffered reader. The call to resp.Body.Close() does not release
+ // resources.
+ br.Reset(bytes.NewReader(nil))
+ _ = resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ _ = conn.Close()
+ f := strings.SplitN(resp.Status, " ", 2)
+ return nil, errors.New(f[1])
+ }
+ return conn, nil
+}
+// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+
+// HandshakeError describes an error with the handshake from the peer.
+type HandshakeError struct {
+ message string
+}
+
+func (e HandshakeError) Error() string { return e.message }
+
+// Upgrader specifies parameters for upgrading an HTTP connection to a
+// WebSocket connection.
+//
+// It is safe to call Upgrader's methods concurrently.
+type Upgrader struct {
+ // HandshakeTimeout specifies the duration for the handshake to complete.
+ HandshakeTimeout time.Duration
+
+ // ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer
+ // size is zero, then buffers allocated by the HTTP server are used. The
+ // I/O buffer sizes do not limit the size of the messages that can be sent
+ // or received.
+ ReadBufferSize, WriteBufferSize int
+
+ // WriteBufferPool is a pool of buffers for write operations. If the value
+ // is not set, then write buffers are allocated to the connection for the
+ // lifetime of the connection.
+ //
+ // A pool is most useful when the application has a modest volume of writes
+ // across a large number of connections.
+ //
+ // Applications should use a single pool for each unique value of
+ // WriteBufferSize.
+ WriteBufferPool BufferPool
+
+ // Subprotocols specifies the server's supported protocols in order of
+ // preference. If this field is not nil, then the Upgrade method negotiates a
+ // subprotocol by selecting the first match in this list with a protocol
+ // requested by the client. If there's no match, then no protocol is
+ // negotiated (the Sec-Websocket-Protocol header is not included in the
+ // handshake response).
+ Subprotocols []string
+
+ // Error specifies the function for generating HTTP error responses. If Error
+ // is nil, then http.Error is used to generate the HTTP response.
+ Error func(w http.ResponseWriter, r *http.Request, status int, reason error)
+
+ // CheckOrigin returns true if the request Origin header is acceptable. If
+ // CheckOrigin is nil, then a safe default is used: return false if the
+ // Origin request header is present and the origin host is not equal to
+ // request Host header.
+ //
+ // A CheckOrigin function should carefully validate the request origin to
+ // prevent cross-site request forgery.
+ CheckOrigin func(r *http.Request) bool
+
+ // EnableCompression specify if the server should attempt to negotiate per
+ // message compression (RFC 7692). Setting this value to true does not
+ // guarantee that compression will be supported. Currently only "no context
+ // takeover" modes are supported.
+ EnableCompression bool
+}
+
+func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) {
+ err := HandshakeError{reason}
+ if u.Error != nil {
+ u.Error(w, r, status, err)
+ } else {
+ w.Header().Set("Sec-Websocket-Version", "13")
+ http.Error(w, http.StatusText(status), status)
+ }
+ return nil, err
+}
+
+// checkSameOrigin returns true if the origin is not set or is equal to the request host.
+func checkSameOrigin(r *http.Request) bool {
+ origin := r.Header["Origin"]
+ if len(origin) == 0 {
+ return true
+ }
+ u, err := url.Parse(origin[0])
+ if err != nil {
+ return false
+ }
+ return equalASCIIFold(u.Host, r.Host)
+}
+
+func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string {
+ if u.Subprotocols != nil {
+ clientProtocols := Subprotocols(r)
+ for _, clientProtocol := range clientProtocols {
+ for _, serverProtocol := range u.Subprotocols {
+ if clientProtocol == serverProtocol {
+ return clientProtocol
+ }
+ }
+ }
+ } else if responseHeader != nil {
+ return responseHeader.Get("Sec-Websocket-Protocol")
+ }
+ return ""
+}
+
+// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
+//
+// The responseHeader is included in the response to the client's upgrade
+// request. Use the responseHeader to specify cookies (Set-Cookie). To specify
+// subprotocols supported by the server, set Upgrader.Subprotocols directly.
+//
+// If the upgrade fails, then Upgrade replies to the client with an HTTP error
+// response.
+func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
+ const badHandshake = "websocket: the client is not using the websocket protocol: "
+
+ if !tokenListContainsValue(r.Header, "Connection", "upgrade") {
+ return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'upgrade' token not found in 'Connection' header")
+ }
+
+ if !tokenListContainsValue(r.Header, "Upgrade", "websocket") {
+ w.Header().Set("Upgrade", "websocket")
+ return u.returnError(w, r, http.StatusUpgradeRequired, badHandshake+"'websocket' token not found in 'Upgrade' header")
+ }
+
+ if r.Method != http.MethodGet {
+ return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET")
+ }
+
+ if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") {
+ return u.returnError(w, r, http.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header")
+ }
+
+ if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok {
+ return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-WebSocket-Extensions' headers are unsupported")
+ }
+
+ checkOrigin := u.CheckOrigin
+ if checkOrigin == nil {
+ checkOrigin = checkSameOrigin
+ }
+ if !checkOrigin(r) {
+ return u.returnError(w, r, http.StatusForbidden, "websocket: request origin not allowed by Upgrader.CheckOrigin")
+ }
+
+ challengeKey := r.Header.Get("Sec-Websocket-Key")
+ if !isValidChallengeKey(challengeKey) {
+ return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header must be Base64 encoded value of 16-byte in length")
+ }
+
+ subprotocol := u.selectSubprotocol(r, responseHeader)
+
+ // Negotiate PMCE
+ var compress bool
+ if u.EnableCompression {
+ for _, ext := range parseExtensions(r.Header) {
+ if ext[""] != "permessage-deflate" {
+ continue
+ }
+ compress = true
+ break
+ }
+ }
+
+ netConn, brw, err := http.NewResponseController(w).Hijack()
+ if err != nil {
+ return u.returnError(w, r, http.StatusInternalServerError,
+ "websocket: hijack: "+err.Error())
+ }
+
+ // Close the network connection when returning an error. The variable
+ // netConn is set to nil before the success return at the end of the
+ // function.
+ defer func() {
+ if netConn != nil {
+ // It's safe to ignore the error from Close() because this code is
+ // only executed when returning a more important error to the
+ // application.
+ _ = netConn.Close()
+ }
+ }()
+
+ var br *bufio.Reader
+ if u.ReadBufferSize == 0 && brw.Reader.Size() > 256 {
+ // Use hijacked buffered reader as the connection reader.
+ br = brw.Reader
+ } else if brw.Reader.Buffered() > 0 {
+ // Wrap the network connection to read buffered data in brw.Reader
+ // before reading from the network connection. This should be rare
+ // because a client must not send message data before receiving the
+ // handshake response.
+ netConn = &brNetConn{br: brw.Reader, Conn: netConn}
+ }
+
+ buf := brw.Writer.AvailableBuffer()
+
+ var writeBuf []byte
+ if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 {
+ // Reuse hijacked write buffer as connection buffer.
+ writeBuf = buf
+ }
+
+ c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, br, writeBuf)
+ c.subprotocol = subprotocol
+
+ if compress {
+ c.newCompressionWriter = compressNoContextTakeover
+ c.newDecompressionReader = decompressNoContextTakeover
+ }
+
+ // Use larger of hijacked buffer and connection write buffer for header.
+ p := buf
+ if len(c.writeBuf) > len(p) {
+ p = c.writeBuf
+ }
+ p = p[:0]
+
+ p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
+ p = append(p, computeAcceptKey(challengeKey)...)
+ p = append(p, "\r\n"...)
+ if c.subprotocol != "" {
+ p = append(p, "Sec-WebSocket-Protocol: "...)
+ p = append(p, c.subprotocol...)
+ p = append(p, "\r\n"...)
+ }
+ if compress {
+ p = append(p, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)
+ }
+ for k, vs := range responseHeader {
+ if k == "Sec-Websocket-Protocol" {
+ continue
+ }
+ for _, v := range vs {
+ p = append(p, k...)
+ p = append(p, ": "...)
+ for i := 0; i < len(v); i++ {
+ b := v[i]
+ if b <= 31 {
+ // prevent response splitting.
+ b = ' '
+ }
+ p = append(p, b)
+ }
+ p = append(p, "\r\n"...)
+ }
+ }
+ p = append(p, "\r\n"...)
+
+ if u.HandshakeTimeout > 0 {
+ if err := netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)); err != nil {
+ return nil, err
+ }
+ } else {
+ // Clear deadlines set by HTTP server.
+ if err := netConn.SetDeadline(time.Time{}); err != nil {
+ return nil, err
+ }
+ }
+
+ if _, err = netConn.Write(p); err != nil {
+ return nil, err
+ }
+ if u.HandshakeTimeout > 0 {
+ if err := netConn.SetWriteDeadline(time.Time{}); err != nil {
+ return nil, err
+ }
+ }
+
+ // Success! Set netConn to nil to stop the deferred function above from
+ // closing the network connection.
+ netConn = nil
+
+ return c, nil
+}
+
+// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
+//
+// Deprecated: Use websocket.Upgrader instead.
+//
+// Upgrade does not perform origin checking. The application is responsible for
+// checking the Origin header before calling Upgrade. An example implementation
+// of the same origin policy check is:
+//
+// if req.Header.Get("Origin") != "http://"+req.Host {
+// http.Error(w, "Origin not allowed", http.StatusForbidden)
+// return
+// }
+//
+// If the endpoint supports subprotocols, then the application is responsible
+// for negotiating the protocol used on the connection. Use the Subprotocols()
+// function to get the subprotocols requested by the client. Use the
+// Sec-Websocket-Protocol response header to specify the subprotocol selected
+// by the application.
+//
+// The responseHeader is included in the response to the client's upgrade
+// request. Use the responseHeader to specify cookies (Set-Cookie) and the
+// negotiated subprotocol (Sec-Websocket-Protocol).
+//
+// The connection buffers IO to the underlying network connection. The
+// readBufSize and writeBufSize parameters specify the size of the buffers to
+// use. Messages can be larger than the buffers.
+//
+// If the request is not a valid WebSocket handshake, then Upgrade returns an
+// error of type HandshakeError. Applications should handle this error by
+// replying to the client with an HTTP error response.
+func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header, readBufSize, writeBufSize int) (*Conn, error) {
+ u := Upgrader{ReadBufferSize: readBufSize, WriteBufferSize: writeBufSize}
+ u.Error = func(w http.ResponseWriter, r *http.Request, status int, reason error) {
+ // don't return errors to maintain backwards compatibility
+ }
+ u.CheckOrigin = func(r *http.Request) bool {
+ // allow all connections by default
+ return true
+ }
+ return u.Upgrade(w, r, responseHeader)
+}
+
+// Subprotocols returns the subprotocols requested by the client in the
+// Sec-Websocket-Protocol header.
+func Subprotocols(r *http.Request) []string {
+ h := strings.TrimSpace(r.Header.Get("Sec-Websocket-Protocol"))
+ if h == "" {
+ return nil
+ }
+ protocols := strings.Split(h, ",")
+ for i := range protocols {
+ protocols[i] = strings.TrimSpace(protocols[i])
+ }
+ return protocols
+}
+
+// IsWebSocketUpgrade returns true if the client requested upgrade to the
+// WebSocket protocol.
+func IsWebSocketUpgrade(r *http.Request) bool {
+ return tokenListContainsValue(r.Header, "Connection", "upgrade") &&
+ tokenListContainsValue(r.Header, "Upgrade", "websocket")
+}
+
+type brNetConn struct {
+ br *bufio.Reader
+ net.Conn
+}
+
+func (b *brNetConn) Read(p []byte) (n int, err error) {
+ if b.br != nil {
+ // Limit read to buferred data.
+ if n := b.br.Buffered(); len(p) > n {
+ p = p[:n]
+ }
+ n, err = b.br.Read(p)
+ if b.br.Buffered() == 0 {
+ b.br = nil
+ }
+ return n, err
+ }
+ return b.Conn.Read(p)
+}
+
+// NetConn returns the underlying connection that is wrapped by b.
+func (b *brNetConn) NetConn() net.Conn {
+ return b.Conn
+}
+
+// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+
+var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
+
+func computeAcceptKey(challengeKey string) string {
+ h := sha1.New()
+ h.Write([]byte(challengeKey))
+ h.Write(keyGUID)
+ return base64.StdEncoding.EncodeToString(h.Sum(nil))
+}
+
+func generateChallengeKey() (string, error) {
+ p := make([]byte, 16)
+ if _, err := io.ReadFull(rand.Reader, p); err != nil {
+ return "", err
+ }
+ return base64.StdEncoding.EncodeToString(p), nil
+}
+
+// Token octets per RFC 2616.
+var isTokenOctet = [256]bool{
+ '!': true,
+ '#': true,
+ '$': true,
+ '%': true,
+ '&': true,
+ '\'': true,
+ '*': true,
+ '+': true,
+ '-': true,
+ '.': true,
+ '0': true,
+ '1': true,
+ '2': true,
+ '3': true,
+ '4': true,
+ '5': true,
+ '6': true,
+ '7': true,
+ '8': true,
+ '9': true,
+ 'A': true,
+ 'B': true,
+ 'C': true,
+ 'D': true,
+ 'E': true,
+ 'F': true,
+ 'G': true,
+ 'H': true,
+ 'I': true,
+ 'J': true,
+ 'K': true,
+ 'L': true,
+ 'M': true,
+ 'N': true,
+ 'O': true,
+ 'P': true,
+ 'Q': true,
+ 'R': true,
+ 'S': true,
+ 'T': true,
+ 'U': true,
+ 'W': true,
+ 'V': true,
+ 'X': true,
+ 'Y': true,
+ 'Z': true,
+ '^': true,
+ '_': true,
+ '`': true,
+ 'a': true,
+ 'b': true,
+ 'c': true,
+ 'd': true,
+ 'e': true,
+ 'f': true,
+ 'g': true,
+ 'h': true,
+ 'i': true,
+ 'j': true,
+ 'k': true,
+ 'l': true,
+ 'm': true,
+ 'n': true,
+ 'o': true,
+ 'p': true,
+ 'q': true,
+ 'r': true,
+ 's': true,
+ 't': true,
+ 'u': true,
+ 'v': true,
+ 'w': true,
+ 'x': true,
+ 'y': true,
+ 'z': true,
+ '|': true,
+ '~': true,
+}
+
+// skipSpace returns a slice of the string s with all leading RFC 2616 linear
+// whitespace removed.
+func skipSpace(s string) (rest string) {
+ i := 0
+ for ; i < len(s); i++ {
+ if b := s[i]; b != ' ' && b != '\t' {
+ break
+ }
+ }
+ return s[i:]
+}
+
+// nextToken returns the leading RFC 2616 token of s and the string following
+// the token.
+func nextToken(s string) (token, rest string) {
+ i := 0
+ for ; i < len(s); i++ {
+ if !isTokenOctet[s[i]] {
+ break
+ }
+ }
+ return s[:i], s[i:]
+}
+
+// nextTokenOrQuoted returns the leading token or quoted string per RFC 2616
+// and the string following the token or quoted string.
+func nextTokenOrQuoted(s string) (value string, rest string) {
+ if !strings.HasPrefix(s, "\"") {
+ return nextToken(s)
+ }
+ s = s[1:]
+ for i := 0; i < len(s); i++ {
+ switch s[i] {
+ case '"':
+ return s[:i], s[i+1:]
+ case '\\':
+ p := make([]byte, len(s)-1)
+ j := copy(p, s[:i])
+ escape := true
+ for i = i + 1; i < len(s); i++ {
+ b := s[i]
+ switch {
+ case escape:
+ escape = false
+ p[j] = b
+ j++
+ case b == '\\':
+ escape = true
+ case b == '"':
+ return string(p[:j]), s[i+1:]
+ default:
+ p[j] = b
+ j++
+ }
+ }
+ return "", ""
+ }
+ }
+ return "", ""
+}
+
+// equalASCIIFold returns true if s is equal to t with ASCII case folding as
+// defined in RFC 4790.
+func equalASCIIFold(s, t string) bool {
+ for s != "" && t != "" {
+ sr, size := utf8.DecodeRuneInString(s)
+ s = s[size:]
+ tr, size := utf8.DecodeRuneInString(t)
+ t = t[size:]
+ if sr == tr {
+ continue
+ }
+ if 'A' <= sr && sr <= 'Z' {
+ sr = sr + 'a' - 'A'
+ }
+ if 'A' <= tr && tr <= 'Z' {
+ tr = tr + 'a' - 'A'
+ }
+ if sr != tr {
+ return false
+ }
+ }
+ return s == t
+}
+
+// tokenListContainsValue returns true if the 1#token header with the given
+// name contains a token equal to value with ASCII case folding.
+func tokenListContainsValue(header http.Header, name string, value string) bool {
+headers:
+ for _, s := range header[name] {
+ for {
+ var t string
+ t, s = nextToken(skipSpace(s))
+ if t == "" {
+ continue headers
+ }
+ s = skipSpace(s)
+ if s != "" && s[0] != ',' {
+ continue headers
+ }
+ if equalASCIIFold(t, value) {
+ return true
+ }
+ if s == "" {
+ continue headers
+ }
+ s = s[1:]
+ }
+ }
+ return false
+}
+
+// parseExtensions parses WebSocket extensions from a header.
+func parseExtensions(header http.Header) []map[string]string {
+ // From RFC 6455:
+ //
+ // Sec-WebSocket-Extensions = extension-list
+ // extension-list = 1#extension
+ // extension = extension-token *( ";" extension-param )
+ // extension-token = registered-token
+ // registered-token = token
+ // extension-param = token [ "=" (token | quoted-string) ]
+ // ;When using the quoted-string syntax variant, the value
+ // ;after quoted-string unescaping MUST conform to the
+ // ;'token' ABNF.
+
+ var result []map[string]string
+headers:
+ for _, s := range header["Sec-Websocket-Extensions"] {
+ for {
+ var t string
+ t, s = nextToken(skipSpace(s))
+ if t == "" {
+ continue headers
+ }
+ ext := map[string]string{"": t}
+ for {
+ s = skipSpace(s)
+ if !strings.HasPrefix(s, ";") {
+ break
+ }
+ var k string
+ k, s = nextToken(skipSpace(s[1:]))
+ if k == "" {
+ continue headers
+ }
+ s = skipSpace(s)
+ var v string
+ if strings.HasPrefix(s, "=") {
+ v, s = nextTokenOrQuoted(skipSpace(s[1:]))
+ s = skipSpace(s)
+ }
+ if s != "" && s[0] != ',' && s[0] != ';' {
+ continue headers
+ }
+ ext[k] = v
+ }
+ if s != "" && s[0] != ',' {
+ continue headers
+ }
+ result = append(result, ext)
+ if s == "" {
+ continue headers
+ }
+ s = s[1:]
+ }
+ }
+ return result
+}
+
+// isValidChallengeKey checks if the argument meets RFC6455 specification.
+func isValidChallengeKey(s string) bool {
+ // From RFC6455:
+ //
+ // A |Sec-WebSocket-Key| header field with a base64-encoded (see
+ // Section 4 of [RFC4648]) value that, when decoded, is 16 bytes in
+ // length.
+
+ if s == "" {
+ return false
+ }
+ decoded, err := base64.StdEncoding.DecodeString(s)
+ return err == nil && len(decoded) == 16
+}
+
+type CLIArgs struct {
+ FromAddr string
+ ToAddr string
+}
+
+
+
+const X = 1
+
+var EmitActiveConnection = g.MakeGauge("active-connections")
+
+
+
+func ParseArgs(args []string) CLIArgs {
+ if len(args) != 3 {
+ fmt.Fprintf(
+ os.Stderr,
+ "Usage: %s FROM.socket TO.socket\n",
+ args[0],
+ )
+ os.Exit(2)
+ }
+ return CLIArgs {
+ FromAddr: args[1],
+ ToAddr: args[2],
+ }
+}
+
+func Listen(fromAddr string) net.Listener {
+ listener, err := net.Listen("unix", fromAddr)
+ g.FatalIf(err)
+ g.Info("Started listening", "listen-start", "from-address", fromAddr)
+ return listener
+}
+
+func copyData(c chan struct {}, from io.Reader, to io.WriteCloser) {
+ io.Copy(to, from)
+ c <- struct {} {}
+}
+
+func Start(toAddr string, listener net.Listener) {
+ /*
+ upgrader := websocket.Upgrader {}
+ http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+ connFrom, err := upgrader.Upgrade(w, r, nil)
+ if err != nil {
+ g.Warning(
+ "Error upgrading connection",
+ "upgrade-connection-error",
+ "err", err,
+ )
+ return
+ }
+ defer connFrom.Close()
+ EmitActiveConnection.Inc()
+
+ connTo, err := net.Dial("unix", toAddr)
+ if err != nil {
+ g.Error(
+ "Error dialing connection",
+ "dial-connection",
+ "err", err,
+ )
+ os.Exit(1)
+ }
+ defer connTo.Close()
+
+ messageType, reader, err := connFrom.NextReader()
+ if err != nil {
+ g.Warning(
+ "Failed to get next reader from connection",
+ "connection-next-reader-error",
+ "err", err,
+ )
+ return
+ }
+
+ writer, err := connFrom.NextWriter(messageType)
+ if err != nil {
+ g.Warning(
+ "Failed to get next reader from connection",
+ "connection-next-reader-error",
+ "err", err,
+ )
+ return
+ }
+
+
+ c := make(chan struct {})
+ go copyData(c, connTo, writer)
+ go copyData(c, reader, connTo)
+ go func() {
+ <- c
+ EmitActiveConnection.Dec()
+ }()
+ });
+
+ server := http.Server{}
+ err := server.Serve(listener)
+ g.FatalIf(err)
+ */
+}
+
+
+func Main() {
+ g.Init()
+ args := ParseArgs(os.Args)
+ listener := Listen(args.FromAddr)
+ Start(args.ToAddr, listener)
+}
diff --git a/tests/lib_test.go b/tests/lib_test.go
deleted file mode 100644
index 1007609..0000000
--- a/tests/lib_test.go
+++ /dev/null
@@ -1,24 +0,0 @@
-package wscat_test
-
-import (
- "testing"
-
- g "euandre.org/gobang/src"
-
- "euandre.org/wscat/src"
-)
-
-
-func TestPlaceholder(t *testing.T) {
- g.AssertEqual(t, wscat.X, 1)
-}
-
-func TestParseArgs(t *testing.T) {
- given := wscat.ParseArgs([]string { "x", "y", "z" })
- expected := wscat.CLIArgs {
- FromAddr: "y",
- ToAddr: "z",
- }
-
- g.AssertEqual(t, given, expected)
-}
diff --git a/tests/main.go b/tests/main.go
new file mode 100644
index 0000000..8123486
--- /dev/null
+++ b/tests/main.go
@@ -0,0 +1,7 @@
+package main
+
+import "wscat"
+
+func main() {
+ wscat.MainTest()
+}
diff --git a/tests/wscat.go b/tests/wscat.go
new file mode 100644
index 0000000..da16f66
--- /dev/null
+++ b/tests/wscat.go
@@ -0,0 +1,2842 @@
+package wscat
+
+import (
+ "bufio"
+ "bytes"
+ "compress/flate"
+ "context"
+ "crypto/tls"
+ "crypto/x509"
+ "encoding/base64"
+ "encoding/binary"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "log"
+ "math/rand"
+ "net"
+ "net/http"
+ "net/http/cookiejar"
+ "net/http/httptest"
+ "net/http/httptrace"
+ "net/url"
+ "os"
+ "reflect"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "testing/internal/testdeps"
+ "testing/iotest"
+ "time"
+
+ g "gobang"
+)
+
+
+
+var cstUpgrader = Upgrader{
+ Subprotocols: []string{"p0", "p1"},
+ ReadBufferSize: 1024,
+ WriteBufferSize: 1024,
+ EnableCompression: true,
+ Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) {
+ http.Error(w, reason.Error(), status)
+ },
+}
+
+var cstDialer = Dialer{
+ Subprotocols: []string{"p1", "p2"},
+ ReadBufferSize: 1024,
+ WriteBufferSize: 1024,
+ HandshakeTimeout: 30 * time.Second,
+}
+
+type cstHandler struct {
+ *testing.T
+ s *cstServer
+}
+
+type cstServer struct {
+ URL string
+ Server *httptest.Server
+ wg sync.WaitGroup
+}
+
+const (
+ cstPath = "/a/b"
+ cstRawQuery = "x=y"
+ cstRequestURI = cstPath + "?" + cstRawQuery
+)
+
+func (s *cstServer) Close() {
+ s.Server.Close()
+ // Wait for handler functions to complete.
+ s.wg.Wait()
+}
+
+func newServer(t *testing.T) *cstServer {
+ var s cstServer
+ s.Server = httptest.NewServer(cstHandler{T: t, s: &s})
+ s.Server.URL += cstRequestURI
+ s.URL = makeWsProto(s.Server.URL)
+ return &s
+}
+
+func newTLSServer(t *testing.T) *cstServer {
+ var s cstServer
+ s.Server = httptest.NewTLSServer(cstHandler{T: t, s: &s})
+ s.Server.URL += cstRequestURI
+ s.URL = makeWsProto(s.Server.URL)
+ return &s
+}
+
+func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ // Because tests wait for a response from a server, we are guaranteed that
+ // the wait group count is incremented before the test waits on the group
+ // in the call to (*cstServer).Close().
+ t.s.wg.Add(1)
+ defer t.s.wg.Done()
+
+ if r.URL.Path != cstPath {
+ t.Logf("path=%v, want %v", r.URL.Path, cstPath)
+ http.Error(w, "bad path", http.StatusBadRequest)
+ return
+ }
+ if r.URL.RawQuery != cstRawQuery {
+ t.Logf("query=%v, want %v", r.URL.RawQuery, cstRawQuery)
+ http.Error(w, "bad path", http.StatusBadRequest)
+ return
+ }
+ subprotos := Subprotocols(r)
+ if !reflect.DeepEqual(subprotos, cstDialer.Subprotocols) {
+ t.Logf("subprotols=%v, want %v", subprotos, cstDialer.Subprotocols)
+ http.Error(w, "bad protocol", http.StatusBadRequest)
+ return
+ }
+ ws, err := cstUpgrader.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}})
+ if err != nil {
+ t.Logf("Upgrade: %v", err)
+ return
+ }
+ defer ws.Close()
+
+ if ws.Subprotocol() != "p1" {
+ t.Logf("Subprotocol() = %s, want p1", ws.Subprotocol())
+ ws.Close()
+ return
+ }
+ op, rd, err := ws.NextReader()
+ if err != nil {
+ t.Logf("NextReader: %v", err)
+ return
+ }
+ wr, err := ws.NextWriter(op)
+ if err != nil {
+ t.Logf("NextWriter: %v", err)
+ return
+ }
+ if _, err = io.Copy(wr, rd); err != nil {
+ t.Logf("NextWriter: %v", err)
+ return
+ }
+ if err := wr.Close(); err != nil {
+ t.Logf("Close: %v", err)
+ return
+ }
+}
+
+func makeWsProto(s string) string {
+ return "ws" + strings.TrimPrefix(s, "http")
+}
+
+func sendRecv(t *testing.T, ws *Conn) {
+ const message = "Hello World!"
+ if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil {
+ t.Fatalf("SetWriteDeadline: %v", err)
+ }
+ if err := ws.WriteMessage(TextMessage, []byte(message)); err != nil {
+ t.Fatalf("WriteMessage: %v", err)
+ }
+ if err := ws.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
+ t.Fatalf("SetReadDeadline: %v", err)
+ }
+ _, p, err := ws.ReadMessage()
+ if err != nil {
+ t.Fatalf("ReadMessage: %v", err)
+ }
+ if string(p) != message {
+ t.Fatalf("message=%s, want %s", p, message)
+ }
+}
+
+func TestProxyDial(t *testing.T) {
+
+ s := newServer(t)
+ defer s.Close()
+
+ surl, _ := url.Parse(s.Server.URL)
+
+ cstDialer := cstDialer // make local copy for modification on next line.
+ cstDialer.Proxy = http.ProxyURL(surl)
+
+ connect := false
+ origHandler := s.Server.Config.Handler
+
+ // Capture the request Host header.
+ s.Server.Config.Handler = http.HandlerFunc(
+ func(w http.ResponseWriter, r *http.Request) {
+ if r.Method == http.MethodConnect {
+ connect = true
+ w.WriteHeader(http.StatusOK)
+ return
+ }
+
+ if !connect {
+ t.Log("connect not received")
+ http.Error(w, "connect not received", http.StatusMethodNotAllowed)
+ return
+ }
+ origHandler.ServeHTTP(w, r)
+ })
+
+ ws, _, err := cstDialer.Dial(s.URL, nil)
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer ws.Close()
+ sendRecv(t, ws)
+}
+
+func TestProxyAuthorizationDial(t *testing.T) {
+ s := newServer(t)
+ defer s.Close()
+
+ surl, _ := url.Parse(s.Server.URL)
+ surl.User = url.UserPassword("username", "password")
+
+ cstDialer := cstDialer // make local copy for modification on next line.
+ cstDialer.Proxy = http.ProxyURL(surl)
+
+ connect := false
+ origHandler := s.Server.Config.Handler
+
+ // Capture the request Host header.
+ s.Server.Config.Handler = http.HandlerFunc(
+ func(w http.ResponseWriter, r *http.Request) {
+ proxyAuth := r.Header.Get("Proxy-Authorization")
+ expectedProxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("username:password"))
+ if r.Method == http.MethodConnect && proxyAuth == expectedProxyAuth {
+ connect = true
+ w.WriteHeader(http.StatusOK)
+ return
+ }
+
+ if !connect {
+ t.Log("connect with proxy authorization not received")
+ http.Error(w, "connect with proxy authorization not received", http.StatusMethodNotAllowed)
+ return
+ }
+ origHandler.ServeHTTP(w, r)
+ })
+
+ ws, _, err := cstDialer.Dial(s.URL, nil)
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer ws.Close()
+ sendRecv(t, ws)
+}
+
+func TestDial(t *testing.T) {
+ s := newServer(t)
+ defer s.Close()
+
+ ws, _, err := cstDialer.Dial(s.URL, nil)
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer ws.Close()
+ sendRecv(t, ws)
+}
+
+func TestDialCookieJar(t *testing.T) {
+ s := newServer(t)
+ defer s.Close()
+
+ jar, _ := cookiejar.New(nil)
+ d := cstDialer
+ d.Jar = jar
+
+ u, _ := url.Parse(s.URL)
+
+ switch u.Scheme {
+ case "ws":
+ u.Scheme = "http"
+ case "wss":
+ u.Scheme = "https"
+ }
+
+ cookies := []*http.Cookie{{Name: "gorilla", Value: "ws", Path: "/"}}
+ d.Jar.SetCookies(u, cookies)
+
+ ws, _, err := d.Dial(s.URL, nil)
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer ws.Close()
+
+ var gorilla string
+ var sessionID string
+ for _, c := range d.Jar.Cookies(u) {
+ if c.Name == "gorilla" {
+ gorilla = c.Value
+ }
+
+ if c.Name == "sessionID" {
+ sessionID = c.Value
+ }
+ }
+ if gorilla != "ws" {
+ t.Error("Cookie not present in jar.")
+ }
+
+ if sessionID != "1234" {
+ t.Error("Set-Cookie not received from the server.")
+ }
+
+ sendRecv(t, ws)
+}
+
+func rootCAs(t *testing.T, s *httptest.Server) *x509.CertPool {
+ certs := x509.NewCertPool()
+ for _, c := range s.TLS.Certificates {
+ roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
+ if err != nil {
+ t.Fatalf("error parsing server's root cert: %v", err)
+ }
+ for _, root := range roots {
+ certs.AddCert(root)
+ }
+ }
+ return certs
+}
+
+func TestDialTLS(t *testing.T) {
+ s := newTLSServer(t)
+ defer s.Close()
+
+ d := cstDialer
+ d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
+ ws, _, err := d.Dial(s.URL, nil)
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer ws.Close()
+ sendRecv(t, ws)
+}
+
+func TestDialTimeout(t *testing.T) {
+ s := newServer(t)
+ defer s.Close()
+
+ d := cstDialer
+ d.HandshakeTimeout = -1
+ ws, _, err := d.Dial(s.URL, nil)
+ if err == nil {
+ ws.Close()
+ t.Fatalf("Dial: nil")
+ }
+}
+
+// requireDeadlineNetConn fails the current test when Read or Write are called
+// with no deadline.
+type requireDeadlineNetConn struct {
+ t *testing.T
+ c net.Conn
+ readDeadlineIsSet bool
+ writeDeadlineIsSet bool
+}
+
+func (c *requireDeadlineNetConn) SetDeadline(t time.Time) error {
+ c.writeDeadlineIsSet = !t.Equal(time.Time{})
+ c.readDeadlineIsSet = c.writeDeadlineIsSet
+ return c.c.SetDeadline(t)
+}
+
+func (c *requireDeadlineNetConn) SetReadDeadline(t time.Time) error {
+ c.readDeadlineIsSet = !t.Equal(time.Time{})
+ return c.c.SetDeadline(t)
+}
+
+func (c *requireDeadlineNetConn) SetWriteDeadline(t time.Time) error {
+ c.writeDeadlineIsSet = !t.Equal(time.Time{})
+ return c.c.SetDeadline(t)
+}
+
+func (c *requireDeadlineNetConn) Write(p []byte) (int, error) {
+ if !c.writeDeadlineIsSet {
+ c.t.Fatalf("write with no deadline")
+ }
+ return c.c.Write(p)
+}
+
+func (c *requireDeadlineNetConn) Read(p []byte) (int, error) {
+ if !c.readDeadlineIsSet {
+ c.t.Fatalf("read with no deadline")
+ }
+ return c.c.Read(p)
+}
+
+func (c *requireDeadlineNetConn) Close() error { return c.c.Close() }
+func (c *requireDeadlineNetConn) LocalAddr() net.Addr { return c.c.LocalAddr() }
+func (c *requireDeadlineNetConn) RemoteAddr() net.Addr { return c.c.RemoteAddr() }
+
+func TestHandshakeTimeout(t *testing.T) {
+ s := newServer(t)
+ defer s.Close()
+
+ d := cstDialer
+ d.NetDial = func(n, a string) (net.Conn, error) {
+ c, err := net.Dial(n, a)
+ return &requireDeadlineNetConn{c: c, t: t}, err
+ }
+ ws, _, err := d.Dial(s.URL, nil)
+ if err != nil {
+ t.Fatal("Dial:", err)
+ }
+ ws.Close()
+}
+
+func TestHandshakeTimeoutInContext(t *testing.T) {
+ s := newServer(t)
+ defer s.Close()
+
+ d := cstDialer
+ d.HandshakeTimeout = 0
+ d.NetDialContext = func(ctx context.Context, n, a string) (net.Conn, error) {
+ netDialer := &net.Dialer{}
+ c, err := netDialer.DialContext(ctx, n, a)
+ return &requireDeadlineNetConn{c: c, t: t}, err
+ }
+
+ ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(30*time.Second))
+ defer cancel()
+ ws, _, err := d.DialContext(ctx, s.URL, nil)
+ if err != nil {
+ t.Fatal("Dial:", err)
+ }
+ ws.Close()
+}
+
+func TestDialBadScheme(t *testing.T) {
+ s := newServer(t)
+ defer s.Close()
+
+ ws, _, err := cstDialer.Dial(s.Server.URL, nil)
+ if err == nil {
+ ws.Close()
+ t.Fatalf("Dial: nil")
+ }
+}
+
+func TestDialBadOrigin(t *testing.T) {
+ s := newServer(t)
+ defer s.Close()
+
+ ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}})
+ if err == nil {
+ ws.Close()
+ t.Fatalf("Dial: nil")
+ }
+ if resp == nil {
+ t.Fatalf("resp=nil, err=%v", err)
+ }
+ if resp.StatusCode != http.StatusForbidden {
+ t.Fatalf("status=%d, want %d", resp.StatusCode, http.StatusForbidden)
+ }
+}
+
+func TestDialBadHeader(t *testing.T) {
+ s := newServer(t)
+ defer s.Close()
+
+ for _, k := range []string{"Upgrade",
+ "Connection",
+ "Sec-Websocket-Key",
+ "Sec-Websocket-Version",
+ "Sec-Websocket-Protocol"} {
+ h := http.Header{}
+ h.Set(k, "bad")
+ ws, _, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}})
+ if err == nil {
+ ws.Close()
+ t.Errorf("Dial with header %s returned nil", k)
+ }
+ }
+}
+
+func TestBadMethod(t *testing.T) {
+ s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ ws, err := cstUpgrader.Upgrade(w, r, nil)
+ if err == nil {
+ t.Errorf("handshake succeeded, expect fail")
+ ws.Close()
+ }
+ }))
+ defer s.Close()
+
+ req, err := http.NewRequest(http.MethodPost, s.URL, strings.NewReader(""))
+ if err != nil {
+ t.Fatalf("NewRequest returned error %v", err)
+ }
+ req.Header.Set("Connection", "upgrade")
+ req.Header.Set("Upgrade", "websocket")
+ req.Header.Set("Sec-Websocket-Version", "13")
+
+ resp, err := http.DefaultClient.Do(req)
+ if err != nil {
+ t.Fatalf("Do returned error %v", err)
+ }
+ resp.Body.Close()
+ if resp.StatusCode != http.StatusMethodNotAllowed {
+ t.Errorf("Status = %d, want %d", resp.StatusCode, http.StatusMethodNotAllowed)
+ }
+}
+
+func TestNoUpgrade(t *testing.T) {
+ t.Parallel()
+ s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ ws, err := cstUpgrader.Upgrade(w, r, nil)
+ if err == nil {
+ t.Errorf("handshake succeeded, expect fail")
+ ws.Close()
+ }
+ }))
+ defer s.Close()
+
+ req, err := http.NewRequest(http.MethodGet, s.URL, strings.NewReader(""))
+ if err != nil {
+ t.Fatalf("NewRequest returned error %v", err)
+ }
+ req.Header.Set("Connection", "upgrade")
+ req.Header.Set("Sec-Websocket-Version", "13")
+
+ resp, err := http.DefaultClient.Do(req)
+ if err != nil {
+ t.Fatalf("Do returned error %v", err)
+ }
+ resp.Body.Close()
+ if u := resp.Header.Get("Upgrade"); u != "websocket" {
+ t.Errorf("Upgrade response header is %q, want %q", u, "websocket")
+ }
+ if resp.StatusCode != http.StatusUpgradeRequired {
+ t.Errorf("Status = %d, want %d", resp.StatusCode, http.StatusUpgradeRequired)
+ }
+}
+
+func TestDialExtraTokensInRespHeaders(t *testing.T) {
+ s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ challengeKey := r.Header.Get("Sec-Websocket-Key")
+ w.Header().Set("Upgrade", "foo, websocket")
+ w.Header().Set("Connection", "upgrade, keep-alive")
+ w.Header().Set("Sec-Websocket-Accept", computeAcceptKey(challengeKey))
+ w.WriteHeader(101)
+ }))
+ defer s.Close()
+
+ ws, _, err := cstDialer.Dial(makeWsProto(s.URL), nil)
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer ws.Close()
+}
+
+func TestHandshake(t *testing.T) {
+ s := newServer(t)
+ defer s.Close()
+
+ ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {s.URL}})
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer ws.Close()
+
+ var sessionID string
+ for _, c := range resp.Cookies() {
+ if c.Name == "sessionID" {
+ sessionID = c.Value
+ }
+ }
+ if sessionID != "1234" {
+ t.Error("Set-Cookie not received from the server.")
+ }
+
+ if ws.Subprotocol() != "p1" {
+ t.Errorf("ws.Subprotocol() = %s, want p1", ws.Subprotocol())
+ }
+ sendRecv(t, ws)
+}
+
+func TestRespOnBadHandshake(t *testing.T) {
+ const expectedStatus = http.StatusGone
+ const expectedBody = "This is the response body."
+
+ s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(expectedStatus)
+ _, _ = io.WriteString(w, expectedBody)
+ }))
+ defer s.Close()
+
+ ws, resp, err := cstDialer.Dial(makeWsProto(s.URL), nil)
+ if err == nil {
+ ws.Close()
+ t.Fatalf("Dial: nil")
+ }
+
+ if resp == nil {
+ t.Fatalf("resp=nil, err=%v", err)
+ }
+
+ if resp.StatusCode != expectedStatus {
+ t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus)
+ }
+
+ p, err := io.ReadAll(resp.Body)
+ if err != nil {
+ t.Fatalf("ReadFull(resp.Body) returned error %v", err)
+ }
+
+ if string(p) != expectedBody {
+ t.Errorf("resp.Body=%s, want %s", p, expectedBody)
+ }
+}
+
+type testLogWriter struct {
+ t *testing.T
+}
+
+func (w testLogWriter) Write(p []byte) (int, error) {
+ w.t.Logf("%s", p)
+ return len(p), nil
+}
+
+// TestHost tests handling of host names and confirms that it matches net/http.
+func TestHost(t *testing.T) {
+
+ upgrader := Upgrader{}
+ handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if IsWebSocketUpgrade(r) {
+ c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}})
+ if err != nil {
+ t.Fatal(err)
+ }
+ c.Close()
+ } else {
+ w.Header().Set("X-Test-Host", r.Host)
+ }
+ })
+
+ server := httptest.NewServer(handler)
+ defer server.Close()
+
+ tlsServer := httptest.NewTLSServer(handler)
+ defer tlsServer.Close()
+
+ addrs := map[*httptest.Server]string{server: server.Listener.Addr().String(), tlsServer: tlsServer.Listener.Addr().String()}
+ wsProtos := map[*httptest.Server]string{server: "ws://", tlsServer: "wss://"}
+ httpProtos := map[*httptest.Server]string{server: "http://", tlsServer: "https://"}
+
+ // Avoid log noise from net/http server by logging to testing.T
+ server.Config.ErrorLog = log.New(testLogWriter{t}, "", 0)
+ tlsServer.Config.ErrorLog = server.Config.ErrorLog
+
+ cas := rootCAs(t, tlsServer)
+
+ tests := []struct {
+ fail bool // true if dial / get should fail
+ server *httptest.Server // server to use
+ url string // host for request URI
+ header string // optional request host header
+ tls string // optional host for tls ServerName
+ wantAddr string // expected host for dial
+ wantHeader string // expected request header on server
+ insecureSkipVerify bool
+ }{
+ {
+ server: server,
+ url: addrs[server],
+ wantAddr: addrs[server],
+ wantHeader: addrs[server],
+ },
+ {
+ server: tlsServer,
+ url: addrs[tlsServer],
+ wantAddr: addrs[tlsServer],
+ wantHeader: addrs[tlsServer],
+ },
+
+ {
+ server: server,
+ url: addrs[server],
+ header: "badhost.com",
+ wantAddr: addrs[server],
+ wantHeader: "badhost.com",
+ },
+ {
+ server: tlsServer,
+ url: addrs[tlsServer],
+ header: "badhost.com",
+ wantAddr: addrs[tlsServer],
+ wantHeader: "badhost.com",
+ },
+
+ {
+ server: server,
+ url: "example.com",
+ header: "badhost.com",
+ wantAddr: "example.com:80",
+ wantHeader: "badhost.com",
+ },
+ {
+ server: tlsServer,
+ url: "example.com",
+ header: "badhost.com",
+ wantAddr: "example.com:443",
+ wantHeader: "badhost.com",
+ },
+
+ {
+ server: server,
+ url: "badhost.com",
+ header: "example.com",
+ wantAddr: "badhost.com:80",
+ wantHeader: "example.com",
+ },
+ {
+ fail: true,
+ server: tlsServer,
+ url: "badhost.com",
+ header: "example.com",
+ wantAddr: "badhost.com:443",
+ },
+ {
+ server: tlsServer,
+ url: "badhost.com",
+ insecureSkipVerify: true,
+ wantAddr: "badhost.com:443",
+ wantHeader: "badhost.com",
+ },
+ {
+ server: tlsServer,
+ url: "badhost.com",
+ tls: "example.com",
+ wantAddr: "badhost.com:443",
+ wantHeader: "badhost.com",
+ },
+ }
+
+ for i, tt := range tests {
+
+ tls := &tls.Config{
+ RootCAs: cas,
+ ServerName: tt.tls,
+ InsecureSkipVerify: tt.insecureSkipVerify,
+ }
+
+ var gotAddr string
+ dialer := Dialer{
+ NetDial: func(network, addr string) (net.Conn, error) {
+ gotAddr = addr
+ return net.Dial(network, addrs[tt.server])
+ },
+ TLSClientConfig: tls,
+ }
+
+ // Test websocket dial
+
+ h := http.Header{}
+ if tt.header != "" {
+ h.Set("Host", tt.header)
+ }
+ c, resp, err := dialer.Dial(wsProtos[tt.server]+tt.url+"/", h)
+ if err == nil {
+ c.Close()
+ }
+
+ check := func(protos map[*httptest.Server]string) {
+ name := fmt.Sprintf("%d: %s%s/ header[Host]=%q, tls.ServerName=%q", i+1, protos[tt.server], tt.url, tt.header, tt.tls)
+ if gotAddr != tt.wantAddr {
+ t.Errorf("%s: got addr %s, want %s", name, gotAddr, tt.wantAddr)
+ }
+ switch {
+ case tt.fail && err == nil:
+ t.Errorf("%s: unexpected success", name)
+ case !tt.fail && err != nil:
+ t.Errorf("%s: unexpected error %v", name, err)
+ case !tt.fail && err == nil:
+ if gotHost := resp.Header.Get("X-Test-Host"); gotHost != tt.wantHeader {
+ t.Errorf("%s: got host %s, want %s", name, gotHost, tt.wantHeader)
+ }
+ }
+ }
+
+ check(wsProtos)
+
+ // Confirm that net/http has same result
+
+ transport := &http.Transport{
+ Dial: dialer.NetDial,
+ TLSClientConfig: dialer.TLSClientConfig,
+ }
+ req, _ := http.NewRequest(http.MethodGet, httpProtos[tt.server]+tt.url+"/", nil)
+ if tt.header != "" {
+ req.Host = tt.header
+ }
+ client := &http.Client{Transport: transport}
+ resp, err = client.Do(req)
+ if err == nil {
+ resp.Body.Close()
+ }
+ transport.CloseIdleConnections()
+ check(httpProtos)
+ }
+}
+
+func TestDialCompression(t *testing.T) {
+ s := newServer(t)
+ defer s.Close()
+
+ dialer := cstDialer
+ dialer.EnableCompression = true
+ ws, _, err := dialer.Dial(s.URL, nil)
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer ws.Close()
+ sendRecv(t, ws)
+}
+
+func TestSocksProxyDial(t *testing.T) {
+ s := newServer(t)
+ defer s.Close()
+
+ proxyListener, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("listen failed: %v", err)
+ }
+ defer proxyListener.Close()
+ go func() {
+ c1, err := proxyListener.Accept()
+ if err != nil {
+ t.Errorf("proxy accept failed: %v", err)
+ return
+ }
+ defer c1.Close()
+
+ _ = c1.SetDeadline(time.Now().Add(30 * time.Second))
+
+ buf := make([]byte, 32)
+ if _, err := io.ReadFull(c1, buf[:3]); err != nil {
+ t.Errorf("read failed: %v", err)
+ return
+ }
+ if want := []byte{5, 1, 0}; !bytes.Equal(want, buf[:len(want)]) {
+ t.Errorf("read %x, want %x", buf[:len(want)], want)
+ }
+ if _, err := c1.Write([]byte{5, 0}); err != nil {
+ t.Errorf("write failed: %v", err)
+ return
+ }
+ if _, err := io.ReadFull(c1, buf[:10]); err != nil {
+ t.Errorf("read failed: %v", err)
+ return
+ }
+ if want := []byte{5, 1, 0, 1}; !bytes.Equal(want, buf[:len(want)]) {
+ t.Errorf("read %x, want %x", buf[:len(want)], want)
+ return
+ }
+ buf[1] = 0
+ if _, err := c1.Write(buf[:10]); err != nil {
+ t.Errorf("write failed: %v", err)
+ return
+ }
+
+ ip := net.IP(buf[4:8])
+ port := binary.BigEndian.Uint16(buf[8:10])
+
+ c2, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: ip, Port: int(port)})
+ if err != nil {
+ t.Errorf("dial failed; %v", err)
+ return
+ }
+ defer c2.Close()
+ done := make(chan struct{})
+ go func() {
+ _, _ = io.Copy(c1, c2)
+ close(done)
+ }()
+ _, _ = io.Copy(c2, c1)
+ <-done
+ }()
+
+ purl, err := url.Parse("socks5://" + proxyListener.Addr().String())
+ if err != nil {
+ t.Fatalf("parse failed: %v", err)
+ }
+
+ cstDialer := cstDialer // make local copy for modification on next line.
+ cstDialer.Proxy = http.ProxyURL(purl)
+
+ ws, _, err := cstDialer.Dial(s.URL, nil)
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer ws.Close()
+ sendRecv(t, ws)
+}
+
+func TestTracingDialWithContext(t *testing.T) {
+
+ var headersWrote, requestWrote, getConn, gotConn, connectDone, gotFirstResponseByte bool
+ trace := &httptrace.ClientTrace{
+ WroteHeaders: func() {
+ headersWrote = true
+ },
+ WroteRequest: func(httptrace.WroteRequestInfo) {
+ requestWrote = true
+ },
+ GetConn: func(hostPort string) {
+ getConn = true
+ },
+ GotConn: func(info httptrace.GotConnInfo) {
+ gotConn = true
+ },
+ ConnectDone: func(network, addr string, err error) {
+ connectDone = true
+ },
+ GotFirstResponseByte: func() {
+ gotFirstResponseByte = true
+ },
+ }
+ ctx := httptrace.WithClientTrace(context.Background(), trace)
+
+ s := newTLSServer(t)
+ defer s.Close()
+
+ d := cstDialer
+ d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
+
+ ws, _, err := d.DialContext(ctx, s.URL, nil)
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+
+ if !headersWrote {
+ t.Fatal("Headers was not written")
+ }
+ if !requestWrote {
+ t.Fatal("Request was not written")
+ }
+ if !getConn {
+ t.Fatal("getConn was not called")
+ }
+ if !gotConn {
+ t.Fatal("gotConn was not called")
+ }
+ if !connectDone {
+ t.Fatal("connectDone was not called")
+ }
+ if !gotFirstResponseByte {
+ t.Fatal("GotFirstResponseByte was not called")
+ }
+
+ defer ws.Close()
+ sendRecv(t, ws)
+}
+
+func TestEmptyTracingDialWithContext(t *testing.T) {
+
+ trace := &httptrace.ClientTrace{}
+ ctx := httptrace.WithClientTrace(context.Background(), trace)
+
+ s := newTLSServer(t)
+ defer s.Close()
+
+ d := cstDialer
+ d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
+
+ ws, _, err := d.DialContext(ctx, s.URL, nil)
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+
+ defer ws.Close()
+ sendRecv(t, ws)
+}
+
+// TestNetDialConnect tests selection of dial method between NetDial, NetDialContext, NetDialTLS or NetDialTLSContext
+func TestNetDialConnect(t *testing.T) {
+
+ upgrader := Upgrader{}
+ handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if IsWebSocketUpgrade(r) {
+ c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}})
+ if err != nil {
+ t.Fatal(err)
+ }
+ c.Close()
+ } else {
+ w.Header().Set("X-Test-Host", r.Host)
+ }
+ })
+
+ server := httptest.NewServer(handler)
+ defer server.Close()
+
+ tlsServer := httptest.NewTLSServer(handler)
+ defer tlsServer.Close()
+
+ testUrls := map[*httptest.Server]string{
+ server: "ws://" + server.Listener.Addr().String() + "/",
+ tlsServer: "wss://" + tlsServer.Listener.Addr().String() + "/",
+ }
+
+ cas := rootCAs(t, tlsServer)
+ tlsConfig := &tls.Config{
+ RootCAs: cas,
+ ServerName: "example.com",
+ InsecureSkipVerify: false,
+ }
+
+ tests := []struct {
+ name string
+ server *httptest.Server // server to use
+ netDial func(network, addr string) (net.Conn, error)
+ netDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
+ netDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
+ tlsClientConfig *tls.Config
+ }{
+
+ {
+ name: "HTTP server, all NetDial* defined, shall use NetDialContext",
+ server: server,
+ netDial: func(network, addr string) (net.Conn, error) {
+ return nil, errors.New("NetDial should not be called")
+ },
+ netDialContext: func(_ context.Context, network, addr string) (net.Conn, error) {
+ return net.Dial(network, addr)
+ },
+ netDialTLSContext: func(_ context.Context, network, addr string) (net.Conn, error) {
+ return nil, errors.New("NetDialTLSContext should not be called")
+ },
+ tlsClientConfig: nil,
+ },
+ {
+ name: "HTTP server, all NetDial* undefined",
+ server: server,
+ netDial: nil,
+ netDialContext: nil,
+ netDialTLSContext: nil,
+ tlsClientConfig: nil,
+ },
+ {
+ name: "HTTP server, NetDialContext undefined, shall fallback to NetDial",
+ server: server,
+ netDial: func(network, addr string) (net.Conn, error) {
+ return net.Dial(network, addr)
+ },
+ netDialContext: nil,
+ netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ return nil, errors.New("NetDialTLSContext should not be called")
+ },
+ tlsClientConfig: nil,
+ },
+ {
+ name: "HTTPS server, all NetDial* defined, shall use NetDialTLSContext",
+ server: tlsServer,
+ netDial: func(network, addr string) (net.Conn, error) {
+ return nil, errors.New("NetDial should not be called")
+ },
+ netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ return nil, errors.New("NetDialContext should not be called")
+ },
+ netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ netConn, err := net.Dial(network, addr)
+ if err != nil {
+ return nil, err
+ }
+ tlsConn := tls.Client(netConn, tlsConfig)
+ err = tlsConn.Handshake()
+ if err != nil {
+ return nil, err
+ }
+ return tlsConn, nil
+ },
+ tlsClientConfig: nil,
+ },
+ {
+ name: "HTTPS server, NetDialTLSContext undefined, shall fallback to NetDialContext and do handshake",
+ server: tlsServer,
+ netDial: func(network, addr string) (net.Conn, error) {
+ return nil, errors.New("NetDial should not be called")
+ },
+ netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ return net.Dial(network, addr)
+ },
+ netDialTLSContext: nil,
+ tlsClientConfig: tlsConfig,
+ },
+ {
+ name: "HTTPS server, NetDialTLSContext and NetDialContext undefined, shall fallback to NetDial and do handshake",
+ server: tlsServer,
+ netDial: func(network, addr string) (net.Conn, error) {
+ return net.Dial(network, addr)
+ },
+ netDialContext: nil,
+ netDialTLSContext: nil,
+ tlsClientConfig: tlsConfig,
+ },
+ {
+ name: "HTTPS server, all NetDial* undefined",
+ server: tlsServer,
+ netDial: nil,
+ netDialContext: nil,
+ netDialTLSContext: nil,
+ tlsClientConfig: tlsConfig,
+ },
+ {
+ name: "HTTPS server, all NetDialTLSContext defined, dummy TlsClientConfig defined, shall not do handshake",
+ server: tlsServer,
+ netDial: func(network, addr string) (net.Conn, error) {
+ return nil, errors.New("NetDial should not be called")
+ },
+ netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ return nil, errors.New("NetDialContext should not be called")
+ },
+ netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ netConn, err := net.Dial(network, addr)
+ if err != nil {
+ return nil, err
+ }
+ tlsConn := tls.Client(netConn, tlsConfig)
+ err = tlsConn.Handshake()
+ if err != nil {
+ return nil, err
+ }
+ return tlsConn, nil
+ },
+ tlsClientConfig: &tls.Config{
+ RootCAs: nil,
+ ServerName: "badserver.com",
+ InsecureSkipVerify: false,
+ },
+ },
+ }
+
+ for _, tc := range tests {
+ dialer := Dialer{
+ NetDial: tc.netDial,
+ NetDialContext: tc.netDialContext,
+ NetDialTLSContext: tc.netDialTLSContext,
+ TLSClientConfig: tc.tlsClientConfig,
+ }
+
+ // Test websocket dial
+ c, _, err := dialer.Dial(testUrls[tc.server], nil)
+ if err != nil {
+ t.Errorf("FAILED %s, err: %s", tc.name, err.Error())
+ } else {
+ c.Close()
+ }
+ }
+}
+func TestNextProtos(t *testing.T) {
+ ts := httptest.NewUnstartedServer(
+ http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
+ )
+ ts.EnableHTTP2 = true
+ ts.StartTLS()
+ defer ts.Close()
+
+ d := Dialer{
+ TLSClientConfig: ts.Client().Transport.(*http.Transport).TLSClientConfig,
+ }
+
+ r, err := ts.Client().Get(ts.URL)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ r.Body.Close()
+
+ // Asserts that Dialer.TLSClientConfig.NextProtos contains "h2"
+ // after the Client.Get call from net/http above.
+ var containsHTTP2 bool = false
+ for _, proto := range d.TLSClientConfig.NextProtos {
+ if proto == "h2" {
+ containsHTTP2 = true
+ }
+ }
+ if !containsHTTP2 {
+ t.Fatalf("Dialer.TLSClientConfig.NextProtos does not contain \"h2\"")
+ }
+
+ _, _, err = d.Dial(makeWsProto(ts.URL), nil)
+ if err == nil {
+ t.Fatalf("Dial succeeded, expect fail ")
+ }
+}
+
+type dataBeforeHandshakeResponseWriter struct {
+ http.ResponseWriter
+}
+
+type dataBeforeHandshakeConnection struct {
+ net.Conn
+ io.Reader
+}
+
+func (c *dataBeforeHandshakeConnection) Read(p []byte) (int, error) {
+ return c.Reader.Read(p)
+}
+
+func (w dataBeforeHandshakeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+ // Example single-frame masked text message from section 5.7 of the RFC.
+ message := []byte{0x81, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d, 0x51, 0x58}
+ n := len(message) / 2
+
+ c, rw, err := http.NewResponseController(w.ResponseWriter).Hijack()
+ if rw != nil {
+ // Load first part of message into bufio.Reader. If the websocket
+ // connection reads more than n bytes from the bufio.Reader, then the
+ // test will fail with an unexpected EOF error.
+ rw.Reader.Reset(bytes.NewReader(message[:n]))
+ rw.Reader.Peek(n)
+ }
+ if c != nil {
+ // Inject second part of message before data read from the network connection.
+ c = &dataBeforeHandshakeConnection{
+ Conn: c,
+ Reader: io.MultiReader(bytes.NewReader(message[n:]), c),
+ }
+ }
+ return c, rw, err
+}
+
+func TestDataReceivedBeforeHandshake(t *testing.T) {
+ s := newServer(t)
+ defer s.Close()
+
+ origHandler := s.Server.Config.Handler
+ s.Server.Config.Handler = http.HandlerFunc(
+ func(w http.ResponseWriter, r *http.Request) {
+ origHandler.ServeHTTP(dataBeforeHandshakeResponseWriter{w}, r)
+ })
+
+ for _, readBufferSize := range []int{0, 1024} {
+ t.Run(fmt.Sprintf("ReadBufferSize=%d", readBufferSize), func(t *testing.T) {
+ dialer := cstDialer
+ dialer.ReadBufferSize = readBufferSize
+ ws, _, err := cstDialer.Dial(s.URL, nil)
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer ws.Close()
+ _, m, err := ws.ReadMessage()
+ if err != nil || string(m) != "Hello" {
+ t.Fatalf("ReadMessage() = %q, %v, want \"Hello\", nil", m, err)
+ }
+ })
+ }
+}
+// Copyright 2014 The Gorilla WebSocket Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+
+var hostPortNoPortTests = []struct {
+ u *url.URL
+ hostPort, hostNoPort string
+}{
+ {&url.URL{Scheme: "ws", Host: "example.com"}, "example.com:80", "example.com"},
+ {&url.URL{Scheme: "wss", Host: "example.com"}, "example.com:443", "example.com"},
+ {&url.URL{Scheme: "ws", Host: "example.com:7777"}, "example.com:7777", "example.com"},
+ {&url.URL{Scheme: "wss", Host: "example.com:7777"}, "example.com:7777", "example.com"},
+}
+
+func TestHostPortNoPort(t *testing.T) {
+ for _, tt := range hostPortNoPortTests {
+ hostPort, hostNoPort := hostPortNoPort(tt.u)
+ if hostPort != tt.hostPort {
+ t.Errorf("hostPortNoPort(%v) returned hostPort %q, want %q", tt.u, hostPort, tt.hostPort)
+ }
+ if hostNoPort != tt.hostNoPort {
+ t.Errorf("hostPortNoPort(%v) returned hostNoPort %q, want %q", tt.u, hostNoPort, tt.hostNoPort)
+ }
+ }
+}
+
+type nopCloser struct{ io.Writer }
+
+func (nopCloser) Close() error { return nil }
+
+func TestTruncWriter(t *testing.T) {
+ const data = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijlkmnopqrstuvwxyz987654321"
+ for n := 1; n <= 10; n++ {
+ var b bytes.Buffer
+ w := &truncWriter{w: nopCloser{&b}}
+ p := []byte(data)
+ for len(p) > 0 {
+ m := len(p)
+ if m > n {
+ m = n
+ }
+ _, _ = w.Write(p[:m])
+ p = p[m:]
+ }
+ if b.String() != data[:len(data)-len(w.p)] {
+ t.Errorf("%d: %q", n, b.String())
+ }
+ }
+}
+
+func textMessages(num int) [][]byte {
+ messages := make([][]byte, num)
+ for i := 0; i < num; i++ {
+ msg := fmt.Sprintf("planet: %d, country: %d, city: %d, street: %d", i, i, i, i)
+ messages[i] = []byte(msg)
+ }
+ return messages
+}
+
+func BenchmarkWriteNoCompression(b *testing.B) {
+ w := io.Discard
+ c := newTestConn(nil, w, false)
+ messages := textMessages(100)
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ _ = c.WriteMessage(TextMessage, messages[i%len(messages)])
+ }
+ b.ReportAllocs()
+}
+
+func BenchmarkWriteWithCompression(b *testing.B) {
+ w := io.Discard
+ c := newTestConn(nil, w, false)
+ messages := textMessages(100)
+ c.enableWriteCompression = true
+ c.newCompressionWriter = compressNoContextTakeover
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ _ = c.WriteMessage(TextMessage, messages[i%len(messages)])
+ }
+ b.ReportAllocs()
+}
+
+func TestValidCompressionLevel(t *testing.T) {
+ c := newTestConn(nil, nil, false)
+ for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} {
+ if err := c.SetCompressionLevel(level); err == nil {
+ t.Errorf("no error for level %d", level)
+ }
+ }
+ for _, level := range []int{minCompressionLevel, maxCompressionLevel} {
+ if err := c.SetCompressionLevel(level); err != nil {
+ t.Errorf("error for level %d", level)
+ }
+ }
+}
+// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+
+// broadcastBench allows to run broadcast benchmarks.
+// In every broadcast benchmark we create many connections, then send the same
+// message into every connection and wait for all writes complete. This emulates
+// an application where many connections listen to the same data - i.e. PUB/SUB
+// scenarios with many subscribers in one channel.
+type broadcastBench struct {
+ w io.Writer
+ closeCh chan struct{}
+ doneCh chan struct{}
+ count int32
+ conns []*broadcastConn
+ compression bool
+ usePrepared bool
+}
+
+type broadcastMessage struct {
+ payload []byte
+ prepared *PreparedMessage
+}
+
+type broadcastConn struct {
+ conn *Conn
+ msgCh chan *broadcastMessage
+}
+
+func newBroadcastConn(c *Conn) *broadcastConn {
+ return &broadcastConn{
+ conn: c,
+ msgCh: make(chan *broadcastMessage, 1),
+ }
+}
+
+func newBroadcastBench(usePrepared, compression bool) *broadcastBench {
+ bench := &broadcastBench{
+ w: io.Discard,
+ doneCh: make(chan struct{}),
+ closeCh: make(chan struct{}),
+ usePrepared: usePrepared,
+ compression: compression,
+ }
+ bench.makeConns(10000)
+ return bench
+}
+
+func (b *broadcastBench) makeConns(numConns int) {
+ conns := make([]*broadcastConn, numConns)
+
+ for i := 0; i < numConns; i++ {
+ c := newTestConn(nil, b.w, true)
+ if b.compression {
+ c.enableWriteCompression = true
+ c.newCompressionWriter = compressNoContextTakeover
+ }
+ conns[i] = newBroadcastConn(c)
+ go func(c *broadcastConn) {
+ for {
+ select {
+ case msg := <-c.msgCh:
+ if msg.prepared != nil {
+ _ = c.conn.WritePreparedMessage(msg.prepared)
+ } else {
+ _ = c.conn.WriteMessage(TextMessage, msg.payload)
+ }
+ val := atomic.AddInt32(&b.count, 1)
+ if val%int32(numConns) == 0 {
+ b.doneCh <- struct{}{}
+ }
+ case <-b.closeCh:
+ return
+ }
+ }
+ }(conns[i])
+ }
+ b.conns = conns
+}
+
+func (b *broadcastBench) close() {
+ close(b.closeCh)
+}
+
+func (b *broadcastBench) broadcastOnce(msg *broadcastMessage) {
+ for _, c := range b.conns {
+ c.msgCh <- msg
+ }
+ <-b.doneCh
+}
+
+func BenchmarkBroadcast(b *testing.B) {
+ benchmarks := []struct {
+ name string
+ usePrepared bool
+ compression bool
+ }{
+ {"NoCompression", false, false},
+ {"Compression", false, true},
+ {"NoCompressionPrepared", true, false},
+ {"CompressionPrepared", true, true},
+ }
+ payload := textMessages(1)[0]
+ for _, bm := range benchmarks {
+ b.Run(bm.name, func(b *testing.B) {
+ bench := newBroadcastBench(bm.usePrepared, bm.compression)
+ defer bench.close()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ message := &broadcastMessage{
+ payload: payload,
+ }
+ if bench.usePrepared {
+ pm, _ := NewPreparedMessage(TextMessage, message.payload)
+ message.prepared = pm
+ }
+ bench.broadcastOnce(message)
+ }
+ b.ReportAllocs()
+ })
+ }
+}
+// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+
+var _ net.Error = errWriteTimeout
+
+type fakeNetConn struct {
+ io.Reader
+ io.Writer
+}
+
+func (c fakeNetConn) Close() error { return nil }
+func (c fakeNetConn) LocalAddr() net.Addr { return localAddr }
+func (c fakeNetConn) RemoteAddr() net.Addr { return remoteAddr }
+func (c fakeNetConn) SetDeadline(t time.Time) error { return nil }
+func (c fakeNetConn) SetReadDeadline(t time.Time) error { return nil }
+func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil }
+
+type fakeAddr int
+
+var (
+ localAddr = fakeAddr(1)
+ remoteAddr = fakeAddr(2)
+)
+
+func (a fakeAddr) Network() string {
+ return "net"
+}
+
+func (a fakeAddr) String() string {
+ return "str"
+}
+
+// newTestConn creates a connection backed by a fake network connection using
+// default values for buffering.
+func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn {
+ return newConn(fakeNetConn{Reader: r, Writer: w}, isServer, 1024, 1024, nil, nil, nil)
+}
+
+func TestFraming(t *testing.T) {
+ frameSizes := []int{
+ 0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535,
+ // 65536, 65537
+ }
+ var readChunkers = []struct {
+ name string
+ f func(io.Reader) io.Reader
+ }{
+ {"half", iotest.HalfReader},
+ {"one", iotest.OneByteReader},
+ {"asis", func(r io.Reader) io.Reader { return r }},
+ }
+ writeBuf := make([]byte, 65537)
+ for i := range writeBuf {
+ writeBuf[i] = byte(i)
+ }
+ var writers = []struct {
+ name string
+ f func(w io.Writer, n int) (int, error)
+ }{
+ {"iocopy", func(w io.Writer, n int) (int, error) {
+ nn, err := io.Copy(w, bytes.NewReader(writeBuf[:n]))
+ return int(nn), err
+ }},
+ {"write", func(w io.Writer, n int) (int, error) {
+ return w.Write(writeBuf[:n])
+ }},
+ {"string", func(w io.Writer, n int) (int, error) {
+ return io.WriteString(w, string(writeBuf[:n]))
+ }},
+ }
+
+ for _, compress := range []bool{false, true} {
+ for _, isServer := range []bool{true, false} {
+ for _, chunker := range readChunkers {
+
+ var connBuf bytes.Buffer
+ wc := newTestConn(nil, &connBuf, isServer)
+ rc := newTestConn(chunker.f(&connBuf), nil, !isServer)
+ if compress {
+ wc.newCompressionWriter = compressNoContextTakeover
+ rc.newDecompressionReader = decompressNoContextTakeover
+ }
+ for _, n := range frameSizes {
+ for _, writer := range writers {
+ name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d w:%s", compress, isServer, chunker.name, n, writer.name)
+
+ w, err := wc.NextWriter(TextMessage)
+ if err != nil {
+ t.Errorf("%s: wc.NextWriter() returned %v", name, err)
+ continue
+ }
+ nn, err := writer.f(w, n)
+ if err != nil || nn != n {
+ t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err)
+ continue
+ }
+ err = w.Close()
+ if err != nil {
+ t.Errorf("%s: w.Close() returned %v", name, err)
+ continue
+ }
+
+ opCode, r, err := rc.NextReader()
+ if err != nil || opCode != TextMessage {
+ t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err)
+ continue
+ }
+
+ t.Logf("frame size: %d", n)
+ rbuf, err := io.ReadAll(r)
+ if err != nil {
+ t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
+ continue
+ }
+
+ if len(rbuf) != n {
+ t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n)
+ continue
+ }
+
+ for i, b := range rbuf {
+ if byte(i) != b {
+ t.Errorf("%s: bad byte at offset %d", name, i)
+ break
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
+func TestWriteControlDeadline(t *testing.T) {
+ t.Parallel()
+ message := []byte("hello")
+ var connBuf bytes.Buffer
+ c := newTestConn(nil, &connBuf, true)
+ if err := c.WriteControl(PongMessage, message, time.Time{}); err != nil {
+ t.Errorf("WriteControl(..., zero deadline) = %v, want nil", err)
+ }
+ if err := c.WriteControl(PongMessage, message, time.Now().Add(time.Second)); err != nil {
+ t.Errorf("WriteControl(..., future deadline) = %v, want nil", err)
+ }
+ if err := c.WriteControl(PongMessage, message, time.Now().Add(-time.Second)); err == nil {
+ t.Errorf("WriteControl(..., past deadline) = nil, want timeout error")
+ }
+}
+
+func TestConcurrencyWriteControl(t *testing.T) {
+ const message = "this is a ping/pong messsage"
+ loop := 10
+ workers := 10
+ for i := 0; i < loop; i++ {
+ var connBuf bytes.Buffer
+
+ wg := sync.WaitGroup{}
+ wc := newTestConn(nil, &connBuf, true)
+
+ for i := 0; i < workers; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ if err := wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)); err != nil {
+ t.Errorf("concurrently wc.WriteControl() returned %v", err)
+ }
+ }()
+ }
+
+ wg.Wait()
+ wc.Close()
+ }
+}
+
+func TestControl(t *testing.T) {
+ const message = "this is a ping/pong message"
+ for _, isServer := range []bool{true, false} {
+ for _, isWriteControl := range []bool{true, false} {
+ name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl)
+ var connBuf bytes.Buffer
+ wc := newTestConn(nil, &connBuf, isServer)
+ rc := newTestConn(&connBuf, nil, !isServer)
+ if isWriteControl {
+ _ = wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second))
+ } else {
+ w, err := wc.NextWriter(PongMessage)
+ if err != nil {
+ t.Errorf("%s: wc.NextWriter() returned %v", name, err)
+ continue
+ }
+ if _, err := w.Write([]byte(message)); err != nil {
+ t.Errorf("%s: w.Write() returned %v", name, err)
+ continue
+ }
+ if err := w.Close(); err != nil {
+ t.Errorf("%s: w.Close() returned %v", name, err)
+ continue
+ }
+ var actualMessage string
+ rc.SetPongHandler(func(s string) error { actualMessage = s; return nil })
+ _, _, _ = rc.NextReader()
+ if actualMessage != message {
+ t.Errorf("%s: pong=%q, want %q", name, actualMessage, message)
+ continue
+ }
+ }
+ }
+ }
+}
+
+// simpleBufferPool is an implementation of BufferPool for TestWriteBufferPool.
+type simpleBufferPool struct {
+ v interface{}
+}
+
+func (p *simpleBufferPool) Get() interface{} {
+ v := p.v
+ p.v = nil
+ return v
+}
+
+func (p *simpleBufferPool) Put(v interface{}) {
+ p.v = v
+}
+
+func TestWriteBufferPool(t *testing.T) {
+ const message = "Now is the time for all good people to come to the aid of the party."
+
+ var buf bytes.Buffer
+ var pool simpleBufferPool
+ rc := newTestConn(&buf, nil, false)
+
+ // Specify writeBufferSize smaller than message size to ensure that pooling
+ // works with fragmented messages.
+ wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, len(message)-1, &pool, nil, nil)
+
+ if wc.writeBuf != nil {
+ t.Fatal("writeBuf not nil after create")
+ }
+
+ // Part 1: test NextWriter/Write/Close
+
+ w, err := wc.NextWriter(TextMessage)
+ if err != nil {
+ t.Fatalf("wc.NextWriter() returned %v", err)
+ }
+
+ if wc.writeBuf == nil {
+ t.Fatal("writeBuf is nil after NextWriter")
+ }
+
+ writeBufAddr := &wc.writeBuf[0]
+
+ if _, err := io.WriteString(w, message); err != nil {
+ t.Fatalf("io.WriteString(w, message) returned %v", err)
+ }
+
+ if err := w.Close(); err != nil {
+ t.Fatalf("w.Close() returned %v", err)
+ }
+
+ if wc.writeBuf != nil {
+ t.Fatal("writeBuf not nil after w.Close()")
+ }
+
+ if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
+ t.Fatal("writeBuf not returned to pool")
+ }
+
+ opCode, p, err := rc.ReadMessage()
+ if opCode != TextMessage || err != nil {
+ t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
+ }
+
+ if s := string(p); s != message {
+ t.Fatalf("message is %s, want %s", s, message)
+ }
+
+ // Part 2: Test WriteMessage.
+
+ if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil {
+ t.Fatalf("wc.WriteMessage() returned %v", err)
+ }
+
+ if wc.writeBuf != nil {
+ t.Fatal("writeBuf not nil after wc.WriteMessage()")
+ }
+
+ if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
+ t.Fatal("writeBuf not returned to pool after WriteMessage")
+ }
+
+ opCode, p, err = rc.ReadMessage()
+ if opCode != TextMessage || err != nil {
+ t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
+ }
+
+ if s := string(p); s != message {
+ t.Fatalf("message is %s, want %s", s, message)
+ }
+}
+
+// TestWriteBufferPoolSync ensures that *sync.Pool works as a buffer pool.
+func TestWriteBufferPoolSync(t *testing.T) {
+ var buf bytes.Buffer
+ var pool sync.Pool
+ wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil)
+ rc := newTestConn(&buf, nil, false)
+
+ const message = "Hello World!"
+ for i := 0; i < 3; i++ {
+ if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil {
+ t.Fatalf("wc.WriteMessage() returned %v", err)
+ }
+ opCode, p, err := rc.ReadMessage()
+ if opCode != TextMessage || err != nil {
+ t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
+ }
+ if s := string(p); s != message {
+ t.Fatalf("message is %s, want %s", s, message)
+ }
+ }
+}
+
+// errorWriter is an io.Writer than returns an error on all writes.
+type errorWriter struct{}
+
+func (ew errorWriter) Write(p []byte) (int, error) { return 0, errors.New("error") }
+
+// TestWriteBufferPoolError ensures that buffer is returned to pool after error
+// on write.
+func TestWriteBufferPoolError(t *testing.T) {
+
+ // Part 1: Test NextWriter/Write/Close
+
+ var pool simpleBufferPool
+ wc := newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil)
+
+ w, err := wc.NextWriter(TextMessage)
+ if err != nil {
+ t.Fatalf("wc.NextWriter() returned %v", err)
+ }
+
+ if wc.writeBuf == nil {
+ t.Fatal("writeBuf is nil after NextWriter")
+ }
+
+ writeBufAddr := &wc.writeBuf[0]
+
+ if _, err := io.WriteString(w, "Hello"); err != nil {
+ t.Fatalf("io.WriteString(w, message) returned %v", err)
+ }
+
+ if err := w.Close(); err == nil {
+ t.Fatalf("w.Close() did not return error")
+ }
+
+ if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
+ t.Fatal("writeBuf not returned to pool")
+ }
+
+ // Part 2: Test WriteMessage
+
+ wc = newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil)
+
+ if err := wc.WriteMessage(TextMessage, []byte("Hello")); err == nil {
+ t.Fatalf("wc.WriteMessage did not return error")
+ }
+
+ if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
+ t.Fatal("writeBuf not returned to pool")
+ }
+}
+
+func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
+ const bufSize = 512
+
+ expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"}
+
+ var b1, b2 bytes.Buffer
+ wc := newConn(&fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize, nil, nil, nil)
+ rc := newTestConn(&b1, &b2, true)
+
+ w, _ := wc.NextWriter(BinaryMessage)
+ _, _ = w.Write(make([]byte, bufSize+bufSize/2))
+ _ = wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second))
+ w.Close()
+
+ op, r, err := rc.NextReader()
+ if op != BinaryMessage || err != nil {
+ t.Fatalf("NextReader() returned %d, %v", op, err)
+ }
+ _, err = io.Copy(io.Discard, r)
+ if !reflect.DeepEqual(err, expectedErr) {
+ t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr)
+ }
+ _, _, err = rc.NextReader()
+ if !reflect.DeepEqual(err, expectedErr) {
+ t.Fatalf("NextReader() returned %v, want %v", err, expectedErr)
+ }
+}
+
+func TestEOFWithinFrame(t *testing.T) {
+ const bufSize = 64
+
+ for n := 0; ; n++ {
+ var b bytes.Buffer
+ wc := newTestConn(nil, &b, false)
+ rc := newTestConn(&b, nil, true)
+
+ w, _ := wc.NextWriter(BinaryMessage)
+ _, _ = w.Write(make([]byte, bufSize))
+ w.Close()
+
+ if n >= b.Len() {
+ break
+ }
+ b.Truncate(n)
+
+ op, r, err := rc.NextReader()
+ if err == errUnexpectedEOF {
+ continue
+ }
+ if op != BinaryMessage || err != nil {
+ t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
+ }
+ _, err = io.Copy(io.Discard, r)
+ if err != errUnexpectedEOF {
+ t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
+ }
+ _, _, err = rc.NextReader()
+ if err != errUnexpectedEOF {
+ t.Fatalf("%d: NextReader() returned %v, want %v", n, err, errUnexpectedEOF)
+ }
+ }
+}
+
+func TestEOFBeforeFinalFrame(t *testing.T) {
+ const bufSize = 512
+
+ var b1, b2 bytes.Buffer
+ wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, bufSize, nil, nil, nil)
+ rc := newTestConn(&b1, &b2, true)
+
+ w, _ := wc.NextWriter(BinaryMessage)
+ _, _ = w.Write(make([]byte, bufSize+bufSize/2))
+
+ op, r, err := rc.NextReader()
+ if op != BinaryMessage || err != nil {
+ t.Fatalf("NextReader() returned %d, %v", op, err)
+ }
+ _, err = io.Copy(io.Discard, r)
+ if err != errUnexpectedEOF {
+ t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
+ }
+ _, _, err = rc.NextReader()
+ if err != errUnexpectedEOF {
+ t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF)
+ }
+}
+
+func TestWriteAfterMessageWriterClose(t *testing.T) {
+ wc := newTestConn(nil, &bytes.Buffer{}, false)
+ w, _ := wc.NextWriter(BinaryMessage)
+ _, _ = io.WriteString(w, "hello")
+ if err := w.Close(); err != nil {
+ t.Fatalf("unexpected error closing message writer, %v", err)
+ }
+
+ if _, err := io.WriteString(w, "world"); err == nil {
+ t.Fatalf("no error writing after close")
+ }
+
+ w, _ = wc.NextWriter(BinaryMessage)
+ _, _ = io.WriteString(w, "hello")
+
+ // close w by getting next writer
+ _, err := wc.NextWriter(BinaryMessage)
+ if err != nil {
+ t.Fatalf("unexpected error getting next writer, %v", err)
+ }
+
+ if _, err := io.WriteString(w, "world"); err == nil {
+ t.Fatalf("no error writing after close")
+ }
+}
+
+func TestReadLimit(t *testing.T) {
+ t.Run("Test ReadLimit is enforced", func(t *testing.T) {
+ const readLimit = 512
+ message := make([]byte, readLimit+1)
+
+ var b1, b2 bytes.Buffer
+ wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil)
+ rc := newTestConn(&b1, &b2, true)
+ rc.SetReadLimit(readLimit)
+
+ // Send message at the limit with interleaved pong.
+ w, _ := wc.NextWriter(BinaryMessage)
+ _, _ = w.Write(message[:readLimit-1])
+ _ = wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
+ _, _ = w.Write(message[:1])
+ w.Close()
+
+ // Send message larger than the limit.
+ _ = wc.WriteMessage(BinaryMessage, message[:readLimit+1])
+
+ op, _, err := rc.NextReader()
+ if op != BinaryMessage || err != nil {
+ t.Fatalf("1: NextReader() returned %d, %v", op, err)
+ }
+ op, r, err := rc.NextReader()
+ if op != BinaryMessage || err != nil {
+ t.Fatalf("2: NextReader() returned %d, %v", op, err)
+ }
+ _, err = io.Copy(io.Discard, r)
+ if err != ErrReadLimit {
+ t.Fatalf("io.Copy() returned %v", err)
+ }
+ })
+
+ t.Run("Test that ReadLimit cannot be overflowed", func(t *testing.T) {
+ const readLimit = 1
+
+ var b1, b2 bytes.Buffer
+ rc := newTestConn(&b1, &b2, true)
+ rc.SetReadLimit(readLimit)
+
+ // First, send a non-final binary message
+ b1.Write([]byte("\x02\x81"))
+
+ // Mask key
+ b1.Write([]byte("\x00\x00\x00\x00"))
+
+ // First payload
+ b1.Write([]byte("A"))
+
+ // Next, send a negative-length, non-final continuation frame
+ b1.Write([]byte("\x00\xFF\x80\x00\x00\x00\x00\x00\x00\x00"))
+
+ // Mask key
+ b1.Write([]byte("\x00\x00\x00\x00"))
+
+ // Next, send a too long, final continuation frame
+ b1.Write([]byte("\x80\xFF\x00\x00\x00\x00\x00\x00\x00\x05"))
+
+ // Mask key
+ b1.Write([]byte("\x00\x00\x00\x00"))
+
+ // Too-long payload
+ b1.Write([]byte("BCDEF"))
+
+ op, r, err := rc.NextReader()
+ if op != BinaryMessage || err != nil {
+ t.Fatalf("1: NextReader() returned %d, %v", op, err)
+ }
+
+ var buf [10]byte
+ var read int
+ n, err := r.Read(buf[:])
+ if err != nil && err != ErrReadLimit {
+ t.Fatalf("unexpected error testing read limit: %v", err)
+ }
+ read += n
+
+ n, err = r.Read(buf[:])
+ if err != nil && err != ErrReadLimit {
+ t.Fatalf("unexpected error testing read limit: %v", err)
+ }
+ read += n
+
+ if err == nil && read > readLimit {
+ t.Fatalf("read limit exceeded: limit %d, read %d", readLimit, read)
+ }
+ })
+}
+
+func TestAddrs(t *testing.T) {
+ c := newTestConn(nil, nil, true)
+ if c.LocalAddr() != localAddr {
+ t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr)
+ }
+ if c.RemoteAddr() != remoteAddr {
+ t.Errorf("RemoteAddr = %v, want %v", c.RemoteAddr(), remoteAddr)
+ }
+}
+
+func TestDeprecatedUnderlyingConn(t *testing.T) {
+ var b1, b2 bytes.Buffer
+ fc := fakeNetConn{Reader: &b1, Writer: &b2}
+ c := newConn(fc, true, 1024, 1024, nil, nil, nil)
+ ul := c.UnderlyingConn()
+ if ul != fc {
+ t.Fatalf("Underlying conn is not what it should be.")
+ }
+}
+
+func TestNetConn(t *testing.T) {
+ var b1, b2 bytes.Buffer
+ fc := fakeNetConn{Reader: &b1, Writer: &b2}
+ c := newConn(fc, true, 1024, 1024, nil, nil, nil)
+ ul := c.NetConn()
+ if ul != fc {
+ t.Fatalf("Underlying conn is not what it should be.")
+ }
+}
+
+func TestBufioReadBytes(t *testing.T) {
+ // Test calling bufio.ReadBytes for value longer than read buffer size.
+
+ m := make([]byte, 512)
+ m[len(m)-1] = '\n'
+
+ var b1, b2 bytes.Buffer
+ wc := newConn(fakeNetConn{Writer: &b1}, false, len(m)+64, len(m)+64, nil, nil, nil)
+ rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil)
+
+ w, _ := wc.NextWriter(BinaryMessage)
+ _, _ = w.Write(m)
+ w.Close()
+
+ op, r, err := rc.NextReader()
+ if op != BinaryMessage || err != nil {
+ t.Fatalf("NextReader() returned %d, %v", op, err)
+ }
+
+ br := bufio.NewReader(r)
+ p, err := br.ReadBytes('\n')
+ if err != nil {
+ t.Fatalf("ReadBytes() returned %v", err)
+ }
+ if len(p) != len(m) {
+ t.Fatalf("read returned %d bytes, want %d bytes", len(p), len(m))
+ }
+}
+
+var closeErrorTests = []struct {
+ err error
+ codes []int
+ ok bool
+}{
+ {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, true},
+ {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, false},
+ {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, true},
+ {errors.New("hello"), []int{CloseNormalClosure}, false},
+}
+
+func TestCloseError(t *testing.T) {
+ for _, tt := range closeErrorTests {
+ ok := IsCloseError(tt.err, tt.codes...)
+ if ok != tt.ok {
+ t.Errorf("IsCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
+ }
+ }
+}
+
+var unexpectedCloseErrorTests = []struct {
+ err error
+ codes []int
+ ok bool
+}{
+ {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, false},
+ {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, true},
+ {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, false},
+ {errors.New("hello"), []int{CloseNormalClosure}, false},
+}
+
+func TestUnexpectedCloseErrors(t *testing.T) {
+ for _, tt := range unexpectedCloseErrorTests {
+ ok := IsUnexpectedCloseError(tt.err, tt.codes...)
+ if ok != tt.ok {
+ t.Errorf("IsUnexpectedCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
+ }
+ }
+}
+
+type blockingWriter struct {
+ c1, c2 chan struct{}
+}
+
+func (w blockingWriter) Write(p []byte) (int, error) {
+ // Allow main to continue
+ close(w.c1)
+ // Wait for panic in main
+ <-w.c2
+ return len(p), nil
+}
+
+func TestConcurrentWritePanic(t *testing.T) {
+ w := blockingWriter{make(chan struct{}), make(chan struct{})}
+ c := newTestConn(nil, w, false)
+ go func() {
+ _ = c.WriteMessage(TextMessage, []byte{})
+ }()
+
+ // wait for goroutine to block in write.
+ <-w.c1
+
+ defer func() {
+ close(w.c2)
+ if v := recover(); v != nil {
+ return
+ }
+ }()
+
+ _ = c.WriteMessage(TextMessage, []byte{})
+ t.Fatal("should not get here")
+}
+
+type failingReader struct{}
+
+func (r failingReader) Read(p []byte) (int, error) {
+ return 0, io.EOF
+}
+
+func TestFailedConnectionReadPanic(t *testing.T) {
+ c := newTestConn(failingReader{}, nil, false)
+
+ defer func() {
+ if v := recover(); v != nil {
+ return
+ }
+ }()
+
+ for i := 0; i < 20000; i++ {
+ _, _, _ = c.ReadMessage()
+ }
+ t.Fatal("should not get here")
+}
+// Copyright 2019 The Gorilla WebSocket Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+
+func TestJoinMessages(t *testing.T) {
+ messages := []string{"a", "bc", "def", "ghij", "klmno", "0", "12", "345", "6789"}
+ for _, readChunk := range []int{1, 2, 3, 4, 5, 6, 7} {
+ for _, term := range []string{"", ","} {
+ var connBuf bytes.Buffer
+ wc := newTestConn(nil, &connBuf, true)
+ rc := newTestConn(&connBuf, nil, false)
+ for _, m := range messages {
+ _ = wc.WriteMessage(BinaryMessage, []byte(m))
+ }
+
+ var result bytes.Buffer
+ _, err := io.CopyBuffer(&result, JoinMessages(rc, term), make([]byte, readChunk))
+ if IsUnexpectedCloseError(err, CloseAbnormalClosure) {
+ t.Errorf("readChunk=%d, term=%q: unexpected error %v", readChunk, term, err)
+ }
+ want := strings.Join(messages, term) + term
+ if result.String() != want {
+ t.Errorf("readChunk=%d, term=%q, got %q, want %q", readChunk, term, result.String(), want)
+ }
+ }
+ }
+}
+// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+
+func TestJSON(t *testing.T) {
+ var buf bytes.Buffer
+ wc := newTestConn(nil, &buf, true)
+ rc := newTestConn(&buf, nil, false)
+
+ var actual, expect struct {
+ A int
+ B string
+ }
+ expect.A = 1
+ expect.B = "hello"
+
+ if err := wc.WriteJSON(&expect); err != nil {
+ t.Fatal("write", err)
+ }
+
+ if err := rc.ReadJSON(&actual); err != nil {
+ t.Fatal("read", err)
+ }
+
+ if !reflect.DeepEqual(&actual, &expect) {
+ t.Fatal("equal", actual, expect)
+ }
+}
+
+func TestPartialJSONRead(t *testing.T) {
+ var buf0, buf1 bytes.Buffer
+ wc := newTestConn(nil, &buf0, true)
+ rc := newTestConn(&buf0, &buf1, false)
+
+ var v struct {
+ A int
+ B string
+ }
+ v.A = 1
+ v.B = "hello"
+
+ messageCount := 0
+
+ // Partial JSON values.
+
+ data, err := json.Marshal(v)
+ if err != nil {
+ t.Fatal(err)
+ }
+ for i := len(data) - 1; i >= 0; i-- {
+ if err := wc.WriteMessage(TextMessage, data[:i]); err != nil {
+ t.Fatal(err)
+ }
+ messageCount++
+ }
+
+ // Whitespace.
+
+ if err := wc.WriteMessage(TextMessage, []byte(" ")); err != nil {
+ t.Fatal(err)
+ }
+ messageCount++
+
+ // Close.
+
+ if err := wc.WriteMessage(CloseMessage, FormatCloseMessage(CloseNormalClosure, "")); err != nil {
+ t.Fatal(err)
+ }
+
+ for i := 0; i < messageCount; i++ {
+ err := rc.ReadJSON(&v)
+ if err != io.ErrUnexpectedEOF {
+ t.Error("read", i, err)
+ }
+ }
+
+ err = rc.ReadJSON(&v)
+ if _, ok := err.(*CloseError); !ok {
+ t.Error("final", err)
+ }
+}
+
+func TestDeprecatedJSON(t *testing.T) {
+ var buf bytes.Buffer
+ wc := newTestConn(nil, &buf, true)
+ rc := newTestConn(&buf, nil, false)
+
+ var actual, expect struct {
+ A int
+ B string
+ }
+ expect.A = 1
+ expect.B = "hello"
+
+ if err := WriteJSON(wc, &expect); err != nil {
+ t.Fatal("write", err)
+ }
+
+ if err := ReadJSON(rc, &actual); err != nil {
+ t.Fatal("read", err)
+ }
+
+ if !reflect.DeepEqual(&actual, &expect) {
+ t.Fatal("equal", actual, expect)
+ }
+}
+// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of
+// this source code is governed by a BSD-style license that can be found in the
+// LICENSE file.
+
+// !appengine
+
+
+func maskBytesByByte(key [4]byte, pos int, b []byte) int {
+ for i := range b {
+ b[i] ^= key[pos&3]
+ pos++
+ }
+ return pos & 3
+}
+
+func notzero(b []byte) int {
+ for i := range b {
+ if b[i] != 0 {
+ return i
+ }
+ }
+ return -1
+}
+
+func TestMaskBytes(t *testing.T) {
+ key := [4]byte{1, 2, 3, 4}
+ for size := 1; size <= 1024; size++ {
+ for align := 0; align < wordSize; align++ {
+ for pos := 0; pos < 4; pos++ {
+ b := make([]byte, size+align)[align:]
+ maskBytes(key, pos, b)
+ maskBytesByByte(key, pos, b)
+ if i := notzero(b); i >= 0 {
+ t.Errorf("size:%d, align:%d, pos:%d, offset:%d", size, align, pos, i)
+ }
+ }
+ }
+ }
+}
+
+func BenchmarkMaskBytes(b *testing.B) {
+ for _, size := range []int{2, 4, 8, 16, 32, 512, 1024} {
+ b.Run(fmt.Sprintf("size-%d", size), func(b *testing.B) {
+ for _, align := range []int{wordSize / 2} {
+ b.Run(fmt.Sprintf("align-%d", align), func(b *testing.B) {
+ for _, fn := range []struct {
+ name string
+ fn func(key [4]byte, pos int, b []byte) int
+ }{
+ {"byte", maskBytesByByte},
+ {"word", maskBytes},
+ } {
+ b.Run(fn.name, func(b *testing.B) {
+ key := newMaskKey()
+ data := make([]byte, size+align)[align:]
+ for i := 0; i < b.N; i++ {
+ fn.fn(key, 0, data)
+ }
+ b.SetBytes(int64(len(data)))
+ })
+ }
+ })
+ }
+ })
+ }
+}
+// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+
+var preparedMessageTests = []struct {
+ messageType int
+ isServer bool
+ enableWriteCompression bool
+ compressionLevel int
+}{
+ // Server
+ {TextMessage, true, false, flate.BestSpeed},
+ {TextMessage, true, true, flate.BestSpeed},
+ {TextMessage, true, true, flate.BestCompression},
+ {PingMessage, true, false, flate.BestSpeed},
+ {PingMessage, true, true, flate.BestSpeed},
+
+ // Client
+ {TextMessage, false, false, flate.BestSpeed},
+ {TextMessage, false, true, flate.BestSpeed},
+ {TextMessage, false, true, flate.BestCompression},
+ {PingMessage, false, false, flate.BestSpeed},
+ {PingMessage, false, true, flate.BestSpeed},
+}
+
+func TestPreparedMessage(t *testing.T) {
+ testRand := rand.New(rand.NewSource(99))
+ prevMaskRand := maskRand
+ maskRand = testRand
+ defer func() { maskRand = prevMaskRand }()
+
+ for _, tt := range preparedMessageTests {
+ var data = []byte("this is a test")
+ var buf bytes.Buffer
+ c := newTestConn(nil, &buf, tt.isServer)
+ if tt.enableWriteCompression {
+ c.newCompressionWriter = compressNoContextTakeover
+ }
+ if err := c.SetCompressionLevel(tt.compressionLevel); err != nil {
+ t.Fatal(err)
+ }
+
+ // Seed random number generator for consistent frame mask.
+ testRand.Seed(1234)
+
+ if err := c.WriteMessage(tt.messageType, data); err != nil {
+ t.Fatal(err)
+ }
+ want := buf.String()
+
+ pm, err := NewPreparedMessage(tt.messageType, data)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Scribble on data to ensure that NewPreparedMessage takes a snapshot.
+ copy(data, "hello world")
+
+ // Seed random number generator for consistent frame mask.
+ testRand.Seed(1234)
+
+ buf.Reset()
+ if err := c.WritePreparedMessage(pm); err != nil {
+ t.Fatal(err)
+ }
+ got := buf.String()
+
+ if got != want {
+ t.Errorf("write message != prepared message for %+v", tt)
+ }
+ }
+}
+// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+
+var subprotocolTests = []struct {
+ h string
+ protocols []string
+}{
+ {"", nil},
+ {"foo", []string{"foo"}},
+ {"foo,bar", []string{"foo", "bar"}},
+ {"foo, bar", []string{"foo", "bar"}},
+ {" foo, bar", []string{"foo", "bar"}},
+ {" foo, bar ", []string{"foo", "bar"}},
+}
+
+func TestSubprotocols(t *testing.T) {
+ for _, st := range subprotocolTests {
+ r := http.Request{Header: http.Header{"Sec-Websocket-Protocol": {st.h}}}
+ protocols := Subprotocols(&r)
+ if !reflect.DeepEqual(st.protocols, protocols) {
+ t.Errorf("SubProtocols(%q) returned %#v, want %#v", st.h, protocols, st.protocols)
+ }
+ }
+}
+
+var isWebSocketUpgradeTests = []struct {
+ ok bool
+ h http.Header
+}{
+ {false, http.Header{"Upgrade": {"websocket"}}},
+ {false, http.Header{"Connection": {"upgrade"}}},
+ {true, http.Header{"Connection": {"upgRade"}, "Upgrade": {"WebSocket"}}},
+}
+
+func TestIsWebSocketUpgrade(t *testing.T) {
+ for _, tt := range isWebSocketUpgradeTests {
+ ok := IsWebSocketUpgrade(&http.Request{Header: tt.h})
+ if tt.ok != ok {
+ t.Errorf("IsWebSocketUpgrade(%v) returned %v, want %v", tt.h, ok, tt.ok)
+ }
+ }
+}
+
+func TestSubProtocolSelection(t *testing.T) {
+ upgrader := Upgrader{
+ Subprotocols: []string{"foo", "bar", "baz"},
+ }
+
+ r := http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"foo", "bar"}}}
+ s := upgrader.selectSubprotocol(&r, nil)
+ if s != "foo" {
+ t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "foo")
+ }
+
+ r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"bar", "foo"}}}
+ s = upgrader.selectSubprotocol(&r, nil)
+ if s != "bar" {
+ t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "bar")
+ }
+
+ r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"baz"}}}
+ s = upgrader.selectSubprotocol(&r, nil)
+ if s != "baz" {
+ t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "baz")
+ }
+
+ r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"quux"}}}
+ s = upgrader.selectSubprotocol(&r, nil)
+ if s != "" {
+ t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "empty string")
+ }
+}
+
+var checkSameOriginTests = []struct {
+ ok bool
+ r *http.Request
+}{
+ {false, &http.Request{Host: "example.org", Header: map[string][]string{"Origin": {"https://other.org"}}}},
+ {true, &http.Request{Host: "example.org", Header: map[string][]string{"Origin": {"https://example.org"}}}},
+ {true, &http.Request{Host: "Example.org", Header: map[string][]string{"Origin": {"https://example.org"}}}},
+}
+
+func TestCheckSameOrigin(t *testing.T) {
+ for _, tt := range checkSameOriginTests {
+ ok := checkSameOrigin(tt.r)
+ if tt.ok != ok {
+ t.Errorf("checkSameOrigin(%+v) returned %v, want %v", tt.r, ok, tt.ok)
+ }
+ }
+}
+
+type reuseTestResponseWriter struct {
+ brw *bufio.ReadWriter
+ http.ResponseWriter
+}
+
+func (resp *reuseTestResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+ return fakeNetConn{strings.NewReader(""), &bytes.Buffer{}}, resp.brw, nil
+}
+
+var bufioReuseTests = []struct {
+ n int
+ reuse bool
+}{
+ {4096, true},
+ {128, false},
+}
+
+func xTestBufioReuse(t *testing.T) {
+ for i, tt := range bufioReuseTests {
+ br := bufio.NewReaderSize(strings.NewReader(""), tt.n)
+ bw := bufio.NewWriterSize(&bytes.Buffer{}, tt.n)
+ resp := &reuseTestResponseWriter{
+ brw: bufio.NewReadWriter(br, bw),
+ }
+ upgrader := Upgrader{}
+ c, err := upgrader.Upgrade(resp, &http.Request{
+ Method: http.MethodGet,
+ Header: http.Header{
+ "Upgrade": []string{"websocket"},
+ "Connection": []string{"upgrade"},
+ "Sec-Websocket-Key": []string{"dGhlIHNhbXBsZSBub25jZQ=="},
+ "Sec-Websocket-Version": []string{"13"},
+ }}, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if reuse := c.br == br; reuse != tt.reuse {
+ t.Errorf("%d: buffered reader reuse=%v, want %v", i, reuse, tt.reuse)
+ }
+ writeBuf := bw.AvailableBuffer()
+ if reuse := &c.writeBuf[0] == &writeBuf[0]; reuse != tt.reuse {
+ t.Errorf("%d: write buffer reuse=%v, want %v", i, reuse, tt.reuse)
+ }
+ }
+}
+
+func TestHijack_NotSupported(t *testing.T) {
+ t.Parallel()
+
+ req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
+ req.Header.Set("Upgrade", "websocket")
+ req.Header.Set("Connection", "upgrade")
+ req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
+ req.Header.Set("Sec-Websocket-Version", "13")
+
+ recorder := httptest.NewRecorder()
+
+ upgrader := Upgrader{}
+ _, err := upgrader.Upgrade(recorder, req, nil)
+
+ if want := (HandshakeError{}); !errors.As(err, &want) || recorder.Code != http.StatusInternalServerError {
+ t.Errorf("want %T and status_code=%d", want, http.StatusInternalServerError)
+ t.Fatalf("got err=%T and status_code=%d", err, recorder.Code)
+ }
+}
+// Copyright 2014 The Gorilla WebSocket Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+
+var equalASCIIFoldTests = []struct {
+ t, s string
+ eq bool
+}{
+ {"WebSocket", "websocket", true},
+ {"websocket", "WebSocket", true},
+ {"Öyster", "öyster", false},
+ {"WebSocket", "WetSocket", false},
+}
+
+func TestEqualASCIIFold(t *testing.T) {
+ for _, tt := range equalASCIIFoldTests {
+ eq := equalASCIIFold(tt.s, tt.t)
+ if eq != tt.eq {
+ t.Errorf("equalASCIIFold(%q, %q) = %v, want %v", tt.s, tt.t, eq, tt.eq)
+ }
+ }
+}
+
+var tokenListContainsValueTests = []struct {
+ value string
+ ok bool
+}{
+ {"WebSocket", true},
+ {"WEBSOCKET", true},
+ {"websocket", true},
+ {"websockets", false},
+ {"x websocket", false},
+ {"websocket x", false},
+ {"other,websocket,more", true},
+ {"other, websocket, more", true},
+}
+
+func TestTokenListContainsValue(t *testing.T) {
+ for _, tt := range tokenListContainsValueTests {
+ h := http.Header{"Upgrade": {tt.value}}
+ ok := tokenListContainsValue(h, "Upgrade", "websocket")
+ if ok != tt.ok {
+ t.Errorf("tokenListContainsValue(h, n, %q) = %v, want %v", tt.value, ok, tt.ok)
+ }
+ }
+}
+
+var isValidChallengeKeyTests = []struct {
+ key string
+ ok bool
+}{
+ {"dGhlIHNhbXBsZSBub25jZQ==", true},
+ {"", false},
+ {"InvalidKey", false},
+ {"WHQ4eXhscUtKYjBvOGN3WEdtOEQ=", false},
+}
+
+func TestIsValidChallengeKey(t *testing.T) {
+ for _, tt := range isValidChallengeKeyTests {
+ ok := isValidChallengeKey(tt.key)
+ if ok != tt.ok {
+ t.Errorf("isValidChallengeKey returns %v, want %v", ok, tt.ok)
+ }
+ }
+}
+
+var parseExtensionTests = []struct {
+ value string
+ extensions []map[string]string
+}{
+ {`foo`, []map[string]string{{"": "foo"}}},
+ {`foo, bar; baz=2`, []map[string]string{
+ {"": "foo"},
+ {"": "bar", "baz": "2"}}},
+ {`foo; bar="b,a;z"`, []map[string]string{
+ {"": "foo", "bar": "b,a;z"}}},
+ {`foo , bar; baz = 2`, []map[string]string{
+ {"": "foo"},
+ {"": "bar", "baz": "2"}}},
+ {`foo, bar; baz=2 junk`, []map[string]string{
+ {"": "foo"}}},
+ {`foo junk, bar; baz=2 junk`, nil},
+ {`mux; max-channels=4; flow-control, deflate-stream`, []map[string]string{
+ {"": "mux", "max-channels": "4", "flow-control": ""},
+ {"": "deflate-stream"}}},
+ {`permessage-foo; x="10"`, []map[string]string{
+ {"": "permessage-foo", "x": "10"}}},
+ {`permessage-foo; use_y, permessage-foo`, []map[string]string{
+ {"": "permessage-foo", "use_y": ""},
+ {"": "permessage-foo"}}},
+ {`permessage-deflate; client_max_window_bits; server_max_window_bits=10 , permessage-deflate; client_max_window_bits`, []map[string]string{
+ {"": "permessage-deflate", "client_max_window_bits": "", "server_max_window_bits": "10"},
+ {"": "permessage-deflate", "client_max_window_bits": ""}}},
+ {"permessage-deflate; server_no_context_takeover; client_max_window_bits=15", []map[string]string{
+ {"": "permessage-deflate", "server_no_context_takeover": "", "client_max_window_bits": "15"},
+ }},
+}
+
+func TestParseExtensions(t *testing.T) {
+ for _, tt := range parseExtensionTests {
+ h := http.Header{http.CanonicalHeaderKey("Sec-WebSocket-Extensions"): {tt.value}}
+ extensions := parseExtensions(h)
+ if !reflect.DeepEqual(extensions, tt.extensions) {
+ t.Errorf("parseExtensions(%q)\n = %v,\nwant %v", tt.value, extensions, tt.extensions)
+ }
+ }
+}
+
+func TestParseArgs(t *testing.T) {
+ given := ParseArgs([]string { "x", "y", "z" })
+ expected := CLIArgs {
+ FromAddr: "y",
+ ToAddr: "z",
+ }
+
+ g.AssertEqual(given, expected)
+}
+
+
+
+func MainTest() {
+ tests := []testing.InternalTest {
+ { "TestProxyDial", TestProxyDial },
+ { "TestProxyAuthorizationDial", TestProxyAuthorizationDial },
+ { "TestDial", TestDial },
+ { "TestDialCookieJar", TestDialCookieJar },
+ { "TestDialTLS", TestDialTLS },
+ { "TestDialTimeout", TestDialTimeout },
+ { "TestHandshakeTimeout", TestHandshakeTimeout },
+ { "TestHandshakeTimeoutInContext", TestHandshakeTimeoutInContext },
+ { "TestDialBadScheme", TestDialBadScheme },
+ { "TestDialBadOrigin", TestDialBadOrigin },
+ { "TestDialBadHeader", TestDialBadHeader },
+ { "TestBadMethod", TestBadMethod },
+ { "TestNoUpgrade", TestNoUpgrade },
+ { "TestDialExtraTokensInRespHeaders", TestDialExtraTokensInRespHeaders },
+ { "TestHandshake", TestHandshake },
+ { "TestRespOnBadHandshake", TestRespOnBadHandshake },
+ { "TestHost", TestHost },
+ { "TestDialCompression", TestDialCompression },
+ { "TestSocksProxyDial", TestSocksProxyDial },
+ { "TestTracingDialWithContext", TestTracingDialWithContext },
+ { "TestEmptyTracingDialWithContext", TestEmptyTracingDialWithContext },
+ { "TestNetDialConnect", TestNetDialConnect },
+ { "TestNextProtos", TestNextProtos },
+ { "TestDataReceivedBeforeHandshake", TestDataReceivedBeforeHandshake },
+ { "TestHostPortNoPort", TestHostPortNoPort },
+ { "TestTruncWriter", TestTruncWriter },
+ { "TestValidCompressionLevel", TestValidCompressionLevel },
+ { "TestFraming", TestFraming },
+ { "TestWriteControlDeadline", TestWriteControlDeadline },
+ { "TestConcurrencyWriteControl", TestConcurrencyWriteControl },
+ { "TestControl", TestControl },
+ { "TestWriteBufferPool", TestWriteBufferPool },
+ { "TestWriteBufferPoolSync", TestWriteBufferPoolSync },
+ { "TestWriteBufferPoolError", TestWriteBufferPoolError },
+ { "TestCloseFrameBeforeFinalMessageFrame", TestCloseFrameBeforeFinalMessageFrame },
+ { "TestEOFWithinFrame", TestEOFWithinFrame },
+ { "TestEOFBeforeFinalFrame", TestEOFBeforeFinalFrame },
+ { "TestWriteAfterMessageWriterClose", TestWriteAfterMessageWriterClose },
+ { "TestReadLimit", TestReadLimit },
+ { "TestAddrs", TestAddrs },
+ { "TestDeprecatedUnderlyingConn", TestDeprecatedUnderlyingConn },
+ { "TestNetConn", TestNetConn },
+ { "TestBufioReadBytes", TestBufioReadBytes },
+ { "TestCloseError", TestCloseError },
+ { "TestUnexpectedCloseErrors", TestUnexpectedCloseErrors },
+ { "TestConcurrentWritePanic", TestConcurrentWritePanic },
+ { "TestFailedConnectionReadPanic", TestFailedConnectionReadPanic },
+ { "TestJoinMessages", TestJoinMessages },
+ { "TestJSON", TestJSON },
+ { "TestPartialJSONRead", TestPartialJSONRead },
+ { "TestDeprecatedJSON", TestDeprecatedJSON },
+ { "TestMaskBytes", TestMaskBytes },
+ { "TestPreparedMessage", TestPreparedMessage },
+ { "TestSubprotocols", TestSubprotocols },
+ { "TestIsWebSocketUpgrade", TestIsWebSocketUpgrade },
+ { "TestSubProtocolSelection", TestSubProtocolSelection },
+ { "TestCheckSameOrigin", TestCheckSameOrigin },
+ { "TestHijack_NotSupported", TestHijack_NotSupported },
+ { "TestEqualASCIIFold", TestEqualASCIIFold },
+ { "TestTokenListContainsValue", TestTokenListContainsValue },
+ { "TestIsValidChallengeKey", TestIsValidChallengeKey },
+ { "TestParseExtensions", TestParseExtensions },
+ { "TestParseArgs", TestParseArgs },
+ }
+
+ // FIXME: run benchmarks
+ benchmarks := []testing.InternalBenchmark {
+ { "BenchmarkWriteNoCompression", BenchmarkWriteNoCompression },
+ { "BenchmarkWriteWithCompression", BenchmarkWriteWithCompression },
+ { "BenchmarkBroadcast", BenchmarkBroadcast },
+ { "BenchmarkMaskBytes", BenchmarkMaskBytes },
+ }
+
+ fuzzTargets := []testing.InternalFuzzTarget {}
+ examples := []testing.InternalExample {}
+ m := testing.MainStart(
+ testdeps.TestDeps {},
+ tests,
+ benchmarks,
+ fuzzTargets,
+ examples,
+ )
+ os.Exit(m.Run())
+}