// Package websocket implements the WebSocket protocol defined in RFC 6455. // // Overview // // The Conn type represents a WebSocket connection. A server application calls // the Upgrader.Upgrade method from an HTTP request handler to get a *Conn: // // var upgrader = websocket.Upgrader{ // ReadBufferSize: 1024, // WriteBufferSize: 1024, // } // // func handler(w http.ResponseWriter, r *http.Request) { // conn, err := upgrader.Upgrade(w, r, nil) // if err != nil { // log.Println(err) // return // } // ... Use conn to send and receive messages. // } // // Call the connection's WriteMessage and ReadMessage methods to send and // receive messages as a slice of bytes. This snippet of code shows how to echo // messages using these methods: // // for { // messageType, p, err := conn.ReadMessage() // if err != nil { // log.Println(err) // return // } // if err := conn.WriteMessage(messageType, p); err != nil { // log.Println(err) // return // } // } // // In the above snippet of code, p is a []byte and messageType is an int with value // websocket.BinaryMessage or websocket.TextMessage. // // An application can also send and receive messages using the io.WriteCloser // and io.Reader interfaces. To send a message, call the connection NextWriter // method to get an io.WriteCloser, write the message to the writer and close // the writer when done. To receive a message, call the connection NextReader // method to get an io.Reader and read until io.EOF is returned. This snippet // shows how to echo messages using the NextWriter and NextReader methods: // // for { // messageType, r, err := conn.NextReader() // if err != nil { // return err // } // w, err := conn.NextWriter(messageType) // if err != nil { // return err // } // if _, err := io.Copy(w, r); err != nil { // return err // } // if err := w.Close(); err != nil { // return err // } // } // // Data Messages // // The WebSocket protocol distinguishes between text and binary data messages.
// Text messages are interpreted as UTF-8 encoded text. The interpretation of // binary messages is left to the application. // // This package uses the TextMessage and BinaryMessage integer constants to // identify the two data message types. The ReadMessage and NextReader methods // return the type of the received message. The messageType argument to the // WriteMessage and NextWriter methods specifies the type of a sent message. // // It is the application's responsibility to ensure that text messages are // valid UTF-8 encoded text. // // Control Messages // // The WebSocket protocol defines three types of control messages: close, ping // and pong. Call the connection WriteControl, WriteMessage or NextWriter // methods to send a control message to the peer. // // Connections handle received close messages by calling the handler function // set with the SetCloseHandler method and by returning a *CloseError from the // NextReader, ReadMessage or the message Read method. The default close // handler sends a close message to the peer. // // Connections handle received ping messages by calling the handler function // set with the SetPingHandler method. The default ping handler sends a pong // message to the peer. // // Connections handle received pong messages by calling the handler function // set with the SetPongHandler method. The default pong handler does nothing. // If an application sends ping messages, then the application should set a // pong handler to receive the corresponding pong. // // The control message handler functions are called from the NextReader, // ReadMessage and message reader Read methods. The default close and ping // handlers can block these methods for a short time when the handler writes to // the connection. // // The application must read the connection to process close, ping and pong // messages sent from the peer. 
// If the application is not otherwise interested // in messages from the peer, then the application should start a goroutine to // read and discard messages from the peer. A simple example is: // // func readLoop(c *websocket.Conn) { // for { // if _, _, err := c.NextReader(); err != nil { // c.Close() // break // } // } // } // // Concurrency // // Connections support one concurrent reader and one concurrent writer. // // Applications are responsible for ensuring that no more than one goroutine // calls the write methods (NextWriter, SetWriteDeadline, WriteMessage, // WriteJSON, EnableWriteCompression, SetCompressionLevel) concurrently and // that no more than one goroutine calls the read methods (NextReader, // SetReadDeadline, ReadMessage, ReadJSON, SetPongHandler, SetPingHandler) // concurrently. // // The Close and WriteControl methods can be called concurrently with all other // methods. // // Origin Considerations // // Web browsers allow JavaScript applications to open a WebSocket connection to // any host. It's up to the server to enforce an origin policy using the Origin // request header sent by the browser. // // The Upgrader calls the function specified in the CheckOrigin field to check // the origin. If the CheckOrigin function returns false, then the Upgrade // method fails the WebSocket handshake with HTTP status 403. // // If the CheckOrigin field is nil, then the Upgrader uses a safe default: fail // the handshake if the Origin request header is present and the Origin host is // not equal to the Host request header. // // Buffers // // Connections buffer network input and output to reduce the number // of system calls when reading or writing messages. // // Write buffers are also used for constructing WebSocket frames. See RFC 6455, // Section 5 for a discussion of message framing. A WebSocket frame header is // written to the network each time a write buffer is flushed to the network.
// Decreasing the size of the write buffer can increase the amount of framing // overhead on the connection. // // The buffer sizes in bytes are specified by the ReadBufferSize and // WriteBufferSize fields in the Dialer and Upgrader. The Dialer uses a default // size of 4096 when a buffer size field is set to zero. The Upgrader reuses // buffers created by the HTTP server when a buffer size field is set to zero. // The HTTP server buffers have a size of 4096 at the time of this writing. // // The buffer sizes do not limit the size of a message that can be read or // written by a connection. // // Buffers are held for the lifetime of the connection by default. If the // Dialer or Upgrader WriteBufferPool field is set, then a connection holds the // write buffer only when writing a message. // // Applications should tune the buffer sizes to balance memory use and // performance. Increasing the buffer size uses more memory, but can reduce the // number of system calls to read or write the network. In the case of writing, // increasing the buffer size can reduce the number of frame headers written to // the network. // // Some guidelines for setting buffer parameters are: // // Limit the buffer sizes to the maximum expected message size. Buffers larger // than the largest message do not provide any benefit. // // Depending on the distribution of message sizes, setting the buffer size to // a value less than the maximum expected message size can greatly reduce memory // use with a small impact on performance. Here's an example: If 99% of the // messages are smaller than 256 bytes and the maximum message size is 512 // bytes, then a buffer size of 256 bytes will result in 1.01 times as many system calls // as a buffer size of 512 bytes. The memory savings is 50%. // // A write buffer pool is useful when the application has a modest number // of writes over a large number of connections.
when buffers are pooled, a larger // buffer size has a reduced impact on total memory use and has the benefit of // reducing system calls and frame overhead. // // Compression EXPERIMENTAL // // Per message compression extensions (RFC 7692) are experimentally supported // by this package in a limited capacity. Setting the EnableCompression option // to true in Dialer or Upgrader will attempt to negotiate per message deflate // support. // // var upgrader = websocket.Upgrader{ // EnableCompression: true, // } // // If compression was successfully negotiated with the connection's peer, any // message received in compressed form will be automatically decompressed. // All Read methods will return uncompressed bytes. // // Per message compression of messages written to a connection can be enabled // or disabled by calling the corresponding Conn method: // // conn.EnableWriteCompression(false) // // Currently this package does not support compression with "context takeover". // This means that messages must be compressed and decompressed in isolation, // without retaining sliding window or dictionary state across messages. For // more details refer to RFC 7692. // // Use of compression is experimental and may result in decreased performance. package wscat import ( "bufio" "bytes" "compress/flate" "context" "crypto/rand" "crypto/sha1" "encoding/base64" "encoding/binary" "encoding/json" "errors" "fmt" "io" "net" "net/http" "net/url" "os" "strconv" "strings" "sync" "time" "unicode/utf8" "unsafe" g "gobang" ) // ErrBadHandshake is returned when the server response to opening handshake is // invalid. 
var ErrBadHandshake = errors.New("websocket: bad handshake") func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) { hostPort = u.Host hostNoPort = u.Host if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") { hostNoPort = hostNoPort[:i] } else { switch u.Scheme { case "wss": hostPort += ":443" case "https": hostPort += ":443" default: hostPort += ":80" } } return hostPort, hostNoPort } const ( minCompressionLevel = -2 // flate.HuffmanOnly not defined in Go < 1.6 maxCompressionLevel = flate.BestCompression defaultCompressionLevel = 1 ) var ( flateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool flateReaderPool = sync.Pool{New: func() interface{} { return flate.NewReader(nil) }} ) func decompressNoContextTakeover(r io.Reader) io.ReadCloser { const tail = // Add four bytes as specified in RFC "\x00\x00\xff\xff" + // Add final block to squelch unexpected EOF error from flate reader. "\x01\x00\x00\xff\xff" fr, _ := flateReaderPool.Get().(io.ReadCloser) mr := io.MultiReader(r, strings.NewReader(tail)) if err := fr.(flate.Resetter).Reset(mr, nil); err != nil { // Reset never fails, but handle error in case that changes. fr = flate.NewReader(mr) } return &flateReadWrapper{fr} } func isValidCompressionLevel(level int) bool { return minCompressionLevel <= level && level <= maxCompressionLevel } func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser { p := &flateWriterPools[level-minCompressionLevel] tw := &truncWriter{w: w} fw, _ := p.Get().(*flate.Writer) if fw == nil { fw, _ = flate.NewWriter(tw, level) } else { fw.Reset(tw) } return &flateWriteWrapper{fw: fw, tw: tw, p: p} } // truncWriter is an io.Writer that writes all but the last four bytes of the // stream to another io.Writer. type truncWriter struct { w io.WriteCloser n int p [4]byte } func (w *truncWriter) Write(p []byte) (int, error) { n := 0 // fill buffer first for simplicity. 
if w.n < len(w.p) { n = copy(w.p[w.n:], p) p = p[n:] w.n += n if len(p) == 0 { return n, nil } } m := len(p) if m > len(w.p) { m = len(w.p) } if nn, err := w.w.Write(w.p[:m]); err != nil { return n + nn, err } copy(w.p[:], w.p[m:]) copy(w.p[len(w.p)-m:], p[len(p)-m:]) nn, err := w.w.Write(p[:len(p)-m]) return n + nn, err } type flateWriteWrapper struct { fw *flate.Writer tw *truncWriter p *sync.Pool } func (w *flateWriteWrapper) Write(p []byte) (int, error) { if w.fw == nil { return 0, errWriteClosed } return w.fw.Write(p) } func (w *flateWriteWrapper) Close() error { if w.fw == nil { return errWriteClosed } err1 := w.fw.Flush() w.p.Put(w.fw) w.fw = nil if w.tw.p != [4]byte{0, 0, 0xff, 0xff} { return errors.New("websocket: internal error, unexpected bytes at end of flate stream") } err2 := w.tw.w.Close() if err1 != nil { return err1 } return err2 } type flateReadWrapper struct { fr io.ReadCloser } func (r *flateReadWrapper) Read(p []byte) (int, error) { if r.fr == nil { return 0, io.ErrClosedPipe } n, err := r.fr.Read(p) if err == io.EOF { // Preemptively place the reader back in the pool. This helps with // scenarios where the application does not call NextReader() soon after // this final read. r.Close() } return n, err } func (r *flateReadWrapper) Close() error { if r.fr == nil { return io.ErrClosedPipe } err := r.fr.Close() flateReaderPool.Put(r.fr) r.fr = nil return err } const ( // Frame header byte 0 bits from Section 5.2 of RFC 6455 finalBit = 1 << 7 rsv1Bit = 1 << 6 rsv2Bit = 1 << 5 rsv3Bit = 1 << 4 // Frame header byte 1 bits from Section 5.2 of RFC 6455 maskBit = 1 << 7 maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask maxControlFramePayloadSize = 125 writeWait = time.Second defaultReadBufferSize = 4096 defaultWriteBufferSize = 4096 continuationFrame = 0 noFrame = -1 ) // Close codes defined in RFC 6455, section 11.7. 
const (
	CloseNormalClosure           = 1000
	CloseGoingAway               = 1001
	CloseProtocolError           = 1002
	CloseUnsupportedData         = 1003
	CloseNoStatusReceived        = 1005
	CloseAbnormalClosure         = 1006
	CloseInvalidFramePayloadData = 1007
	ClosePolicyViolation         = 1008
	CloseMessageTooBig           = 1009
	CloseMandatoryExtension      = 1010
	CloseInternalServerErr       = 1011
	CloseServiceRestart          = 1012
	CloseTryAgainLater           = 1013
	CloseTLSHandshake            = 1015
)

// The message types are defined in RFC 6455, section 11.8.
const (
	// TextMessage denotes a text data message. The text message payload is
	// interpreted as UTF-8 encoded text data.
	TextMessage = 1

	// BinaryMessage denotes a binary data message.
	BinaryMessage = 2

	// CloseMessage denotes a close control message. The optional message
	// payload contains a numeric code and text. Use the FormatCloseMessage
	// function to format a close message payload.
	CloseMessage = 8

	// PingMessage denotes a ping control message. The optional message payload
	// is UTF-8 encoded text.
	PingMessage = 9

	// PongMessage denotes a pong control message. The optional message payload
	// is UTF-8 encoded text.
	PongMessage = 10
)

// ErrCloseSent is returned when the application writes a message to the
// connection after sending a close message.
var ErrCloseSent = errors.New("websocket: close sent")

// ErrReadLimit is returned when reading a message that is larger than the
// read limit set for the connection.
var ErrReadLimit = errors.New("websocket: read limit exceeded")

// netError satisfies the net Error interface.
type netError struct {
	msg       string
	temporary bool
	timeout   bool
}

func (e *netError) Error() string   { return e.msg }
func (e *netError) Temporary() bool { return e.temporary }
func (e *netError) Timeout() bool   { return e.timeout }

// CloseError represents a close message.
type CloseError struct {
	// Code is defined in RFC 6455, section 11.7.
	Code int

	// Text is the optional text payload.
	Text string
}

// Error renders "websocket: close <code>", then a short description when the
// code is one of the named Close* constants, then ": <text>" when Text is
// non-empty.
func (e *CloseError) Error() string {
	s := []byte("websocket: close ")
	s = strconv.AppendInt(s, int64(e.Code), 10)
	switch e.Code {
	case CloseNormalClosure:
		s = append(s, " (normal)"...)
	case CloseGoingAway:
		s = append(s, " (going away)"...)
	case CloseProtocolError:
		s = append(s, " (protocol error)"...)
	case CloseUnsupportedData:
		s = append(s, " (unsupported data)"...)
	case CloseNoStatusReceived:
		s = append(s, " (no status)"...)
	case CloseAbnormalClosure:
		s = append(s, " (abnormal closure)"...)
	case CloseInvalidFramePayloadData:
		s = append(s, " (invalid payload data)"...)
	case ClosePolicyViolation:
		s = append(s, " (policy violation)"...)
	case CloseMessageTooBig:
		s = append(s, " (message too big)"...)
	case CloseMandatoryExtension:
		s = append(s, " (mandatory extension missing)"...)
	case CloseInternalServerErr:
		s = append(s, " (internal server error)"...)
	case CloseServiceRestart:
		s = append(s, " (service restart)"...)
	case CloseTryAgainLater:
		s = append(s, " (try again later)"...)
	case CloseTLSHandshake:
		s = append(s, " (TLS handshake error)"...)
	}
	if e.Text != "" {
		s = append(s, ": "...)
		s = append(s, e.Text...)
	}
	return string(s)
}

// IsCloseError returns boolean indicating whether the error is a *CloseError
// with one of the specified codes. A *CloseError wrapped by other errors (for
// example with fmt.Errorf and %w) is found as well.
func IsCloseError(err error, codes ...int) bool {
	var e *CloseError
	if !errors.As(err, &e) {
		return false
	}
	for _, code := range codes {
		if e.Code == code {
			return true
		}
	}
	return false
}

// IsUnexpectedCloseError returns boolean indicating whether the error is a
// *CloseError with a code not in the list of expected codes.
func IsUnexpectedCloseError(err error, expectedCodes ...int) bool { if e, ok := err.(*CloseError); ok { for _, code := range expectedCodes { if e.Code == code { return false } } return true } return false } var ( errWriteTimeout = &netError{msg: "websocket: write timeout", timeout: true, temporary: true} errUnexpectedEOF = &CloseError{Code: CloseAbnormalClosure, Text: io.ErrUnexpectedEOF.Error()} errBadWriteOpCode = errors.New("websocket: bad write message type") errWriteClosed = errors.New("websocket: write closed") errInvalidControlFrame = errors.New("websocket: invalid control frame") ) // maskRand is an io.Reader for generating mask bytes. The reader is initialized // to crypto/rand Reader. Tests swap the reader to a math/rand reader for // reproducible results. var maskRand = rand.Reader // newMaskKey returns a new 32 bit value for masking client frames. func newMaskKey() [4]byte { var k [4]byte _, _ = io.ReadFull(maskRand, k[:]) return k } func isControl(frameType int) bool { return frameType == CloseMessage || frameType == PingMessage || frameType == PongMessage } func isData(frameType int) bool { return frameType == TextMessage || frameType == BinaryMessage } var validReceivedCloseCodes = map[int]bool{ // see http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number CloseNormalClosure: true, CloseGoingAway: true, CloseProtocolError: true, CloseUnsupportedData: true, CloseNoStatusReceived: false, CloseAbnormalClosure: false, CloseInvalidFramePayloadData: true, ClosePolicyViolation: true, CloseMessageTooBig: true, CloseMandatoryExtension: true, CloseInternalServerErr: true, CloseServiceRestart: true, CloseTryAgainLater: true, CloseTLSHandshake: false, } func isValidReceivedCloseCode(code int) bool { return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999) } // BufferPool represents a pool of buffers. The *sync.Pool type satisfies this // interface. The type of the value stored in a pool is not specified. 
type BufferPool interface { // Get gets a value from the pool or returns nil if the pool is empty. Get() interface{} // Put adds a value to the pool. Put(interface{}) } // writePoolData is the type added to the write buffer pool. This wrapper is // used to prevent applications from peeking at and depending on the values // added to the pool. type writePoolData struct{ buf []byte } // The Conn type represents a WebSocket connection. type Conn struct { conn net.Conn isServer bool subprotocol string // Write fields mu chan struct{} // used as mutex to protect write to conn writeBuf []byte // frame is constructed in this buffer. writePool BufferPool writeBufSize int writeDeadline time.Time writer io.WriteCloser // the current writer returned to the application isWriting bool // for best-effort concurrent write detection writeErrMu sync.Mutex writeErr error enableWriteCompression bool compressionLevel int newCompressionWriter func(io.WriteCloser, int) io.WriteCloser // Read fields reader io.ReadCloser // the current reader returned to the application readErr error br *bufio.Reader // bytes remaining in current frame. // set setReadRemaining to safely update this value and prevent overflow readRemaining int64 readFinal bool // true the current message has more frames. readLength int64 // Message size. readLimit int64 // Maximum message size. 
readMaskPos int readMaskKey [4]byte handlePong func(string) error handlePing func(string) error handleClose func(int, string) error readErrCount int messageReader *messageReader // the current low-level reader readDecompress bool // whether last read frame had RSV1 set newDecompressionReader func(io.Reader) io.ReadCloser } func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, writeBufferPool BufferPool, br *bufio.Reader, writeBuf []byte) *Conn { if br == nil { if readBufferSize == 0 { readBufferSize = defaultReadBufferSize } else if readBufferSize < maxControlFramePayloadSize { // must be large enough for control frame readBufferSize = maxControlFramePayloadSize } br = bufio.NewReaderSize(conn, readBufferSize) } if writeBufferSize <= 0 { writeBufferSize = defaultWriteBufferSize } writeBufferSize += maxFrameHeaderSize if writeBuf == nil && writeBufferPool == nil { writeBuf = make([]byte, writeBufferSize) } mu := make(chan struct{}, 1) mu <- struct{}{} c := &Conn{ isServer: isServer, br: br, conn: conn, mu: mu, readFinal: true, writeBuf: writeBuf, writePool: writeBufferPool, writeBufSize: writeBufferSize, enableWriteCompression: true, compressionLevel: defaultCompressionLevel, } c.SetCloseHandler(nil) c.SetPingHandler(nil) c.SetPongHandler(nil) return c } // setReadRemaining tracks the number of bytes remaining on the connection. If n // overflows, an ErrReadLimit is returned. func (c *Conn) setReadRemaining(n int64) error { if n < 0 { return ErrReadLimit } c.readRemaining = n return nil } // Subprotocol returns the negotiated protocol for the connection. func (c *Conn) Subprotocol() string { return c.subprotocol } // Close closes the underlying network connection without sending or waiting // for a close message. func (c *Conn) Close() error { return c.conn.Close() } // LocalAddr returns the local network address. func (c *Conn) LocalAddr() net.Addr { return c.conn.LocalAddr() } // RemoteAddr returns the remote network address. 
func (c *Conn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() } // Write methods func (c *Conn) writeFatal(err error) error { c.writeErrMu.Lock() if c.writeErr == nil { c.writeErr = err } c.writeErrMu.Unlock() return err } func (c *Conn) read(n int) ([]byte, error) { p, err := c.br.Peek(n) if err == io.EOF { err = errUnexpectedEOF } // Discard is guaranteed to succeed because the number of bytes to discard // is less than or equal to the number of bytes buffered. _, _ = c.br.Discard(len(p)) return p, err } func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error { <-c.mu defer func() { c.mu <- struct{}{} }() c.writeErrMu.Lock() err := c.writeErr c.writeErrMu.Unlock() if err != nil { return err } if err := c.conn.SetWriteDeadline(deadline); err != nil { return c.writeFatal(err) } if len(buf1) == 0 { _, err = c.conn.Write(buf0) } else { err = c.writeBufs(buf0, buf1) } if err != nil { return c.writeFatal(err) } if frameType == CloseMessage { _ = c.writeFatal(ErrCloseSent) } return nil } func (c *Conn) writeBufs(bufs ...[]byte) error { b := net.Buffers(bufs) _, err := b.WriteTo(c.conn) return err } // WriteControl writes a control message with the given deadline. The allowed // message types are CloseMessage, PingMessage and PongMessage. func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error { if !isControl(messageType) { return errBadWriteOpCode } if len(data) > maxControlFramePayloadSize { return errInvalidControlFrame } b0 := byte(messageType) | finalBit b1 := byte(len(data)) if !c.isServer { b1 |= maskBit } buf := make([]byte, 0, maxFrameHeaderSize+maxControlFramePayloadSize) buf = append(buf, b0, b1) if c.isServer { buf = append(buf, data...) } else { key := newMaskKey() buf = append(buf, key[:]...) buf = append(buf, data...) maskBytes(key, 0, buf[6:]) } if deadline.IsZero() { // No timeout for zero time. 
<-c.mu } else { d := time.Until(deadline) if d < 0 { return errWriteTimeout } select { case <-c.mu: default: timer := time.NewTimer(d) select { case <-c.mu: timer.Stop() case <-timer.C: return errWriteTimeout } } } defer func() { c.mu <- struct{}{} }() c.writeErrMu.Lock() err := c.writeErr c.writeErrMu.Unlock() if err != nil { return err } if err := c.conn.SetWriteDeadline(deadline); err != nil { return c.writeFatal(err) } if _, err = c.conn.Write(buf); err != nil { return c.writeFatal(err) } if messageType == CloseMessage { _ = c.writeFatal(ErrCloseSent) } return err } // beginMessage prepares a connection and message writer for a new message. func (c *Conn) beginMessage(mw *messageWriter, messageType int) error { // Close previous writer if not already closed by the application. It's // probably better to return an error in this situation, but we cannot // change this without breaking existing applications. if c.writer != nil { c.writer.Close() c.writer = nil } if !isControl(messageType) && !isData(messageType) { return errBadWriteOpCode } c.writeErrMu.Lock() err := c.writeErr c.writeErrMu.Unlock() if err != nil { return err } mw.c = c mw.frameType = messageType mw.pos = maxFrameHeaderSize if c.writeBuf == nil { wpd, ok := c.writePool.Get().(writePoolData) if ok { c.writeBuf = wpd.buf } else { c.writeBuf = make([]byte, c.writeBufSize) } } return nil } // NextWriter returns a writer for the next message to send. The writer's Close // method flushes the complete message to the network. // // There can be at most one open writer on a connection. NextWriter closes the // previous writer if the application has not already done so. // // All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and // PongMessage) are supported. 
func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { var mw messageWriter if err := c.beginMessage(&mw, messageType); err != nil { return nil, err } c.writer = &mw if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) { w := c.newCompressionWriter(c.writer, c.compressionLevel) mw.compress = true c.writer = w } return c.writer, nil } type messageWriter struct { c *Conn compress bool // whether next call to flushFrame should set RSV1 pos int // end of data in writeBuf. frameType int // type of the current frame. err error } func (w *messageWriter) endMessage(err error) error { if w.err != nil { return err } c := w.c w.err = err c.writer = nil if c.writePool != nil { c.writePool.Put(writePoolData{buf: c.writeBuf}) c.writeBuf = nil } return err } // flushFrame writes buffered data and extra as a frame to the network. The // final argument indicates that this is the last frame in the message. func (w *messageWriter) flushFrame(final bool, extra []byte) error { c := w.c length := w.pos - maxFrameHeaderSize + len(extra) // Check for invalid control frames. if isControl(w.frameType) && (!final || length > maxControlFramePayloadSize) { return w.endMessage(errInvalidControlFrame) } b0 := byte(w.frameType) if final { b0 |= finalBit } if w.compress { b0 |= rsv1Bit } w.compress = false b1 := byte(0) if !c.isServer { b1 |= maskBit } // Assume that the frame starts at beginning of c.writeBuf. framePos := 0 if c.isServer { // Adjust up if mask not included in the header. 
framePos = 4 } switch { case length >= 65536: c.writeBuf[framePos] = b0 c.writeBuf[framePos+1] = b1 | 127 binary.BigEndian.PutUint64(c.writeBuf[framePos+2:], uint64(length)) case length > 125: framePos += 6 c.writeBuf[framePos] = b0 c.writeBuf[framePos+1] = b1 | 126 binary.BigEndian.PutUint16(c.writeBuf[framePos+2:], uint16(length)) default: framePos += 8 c.writeBuf[framePos] = b0 c.writeBuf[framePos+1] = b1 | byte(length) } if !c.isServer { key := newMaskKey() copy(c.writeBuf[maxFrameHeaderSize-4:], key[:]) maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos]) if len(extra) > 0 { return w.endMessage(c.writeFatal(errors.New("websocket: internal error, extra used in client mode"))) } } // Write the buffers to the connection with best-effort detection of // concurrent writes. See the concurrency section in the package // documentation for more info. if c.isWriting { panic("concurrent write to websocket connection") } c.isWriting = true err := c.write(w.frameType, c.writeDeadline, c.writeBuf[framePos:w.pos], extra) if !c.isWriting { panic("concurrent write to websocket connection") } c.isWriting = false if err != nil { return w.endMessage(err) } if final { _ = w.endMessage(errWriteClosed) return nil } // Setup for next frame. w.pos = maxFrameHeaderSize w.frameType = continuationFrame return nil } func (w *messageWriter) ncopy(max int) (int, error) { n := len(w.c.writeBuf) - w.pos if n <= 0 { if err := w.flushFrame(false, nil); err != nil { return 0, err } n = len(w.c.writeBuf) - w.pos } if n > max { n = max } return n, nil } func (w *messageWriter) Write(p []byte) (int, error) { if w.err != nil { return 0, w.err } if len(p) > 2*len(w.c.writeBuf) && w.c.isServer { // Don't buffer large messages. 
err := w.flushFrame(false, p)
		if err != nil {
			return 0, err
		}
		return len(p), nil
	}

	// Copy p into the connection write buffer in chunks whose size is
	// decided by w.ncopy (defined elsewhere).
	nn := len(p)
	for len(p) > 0 {
		n, err := w.ncopy(len(p))
		if err != nil {
			return 0, err
		}
		copy(w.c.writeBuf[w.pos:], p[:n])
		w.pos += n
		p = p[n:]
	}
	return nn, nil
}

// WriteString is the string counterpart of the buffered path of Write: it
// copies p into the connection write buffer in chunks sized by w.ncopy and
// returns len(p) on success. A previously recorded writer error is returned
// unchanged.
func (w *messageWriter) WriteString(p string) (int, error) {
	if w.err != nil {
		return 0, w.err
	}

	nn := len(p)
	for len(p) > 0 {
		n, err := w.ncopy(len(p))
		if err != nil {
			return 0, err
		}
		copy(w.c.writeBuf[w.pos:], p[:n])
		w.pos += n
		p = p[n:]
	}
	return nn, nil
}

// ReadFrom reads from r directly into the connection write buffer, flushing a
// non-final frame each time the buffer is full. It returns the number of bytes
// read from r. io.EOF from r ends the copy successfully; any other read error
// or a flush error is returned.
func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
	if w.err != nil {
		return 0, w.err
	}
	for {
		// Buffer full: emit what we have as a non-final frame first.
		if w.pos == len(w.c.writeBuf) {
			err = w.flushFrame(false, nil)
			if err != nil {
				break
			}
		}
		var n int
		n, err = r.Read(w.c.writeBuf[w.pos:])
		w.pos += n
		nn += int64(n)
		if err != nil {
			if err == io.EOF {
				err = nil
			}
			break
		}
	}
	return nn, err
}

// Close flushes any buffered data as the final frame of the message. If the
// writer already recorded an error, that error is returned instead.
func (w *messageWriter) Close() error {
	if w.err != nil {
		return w.err
	}
	return w.flushFrame(true, nil)
}

// WritePreparedMessage writes prepared message into connection.
//
// The frame bytes are obtained from pm for this connection's role and
// compression settings. The isWriting checks panic if another write is in
// progress on the connection at the same time.
func (c *Conn) WritePreparedMessage(pm *PreparedMessage) error {
	frameType, frameData, err := pm.frame(prepareKey{
		isServer:         c.isServer,
		compress:         c.newCompressionWriter != nil && c.enableWriteCompression && isData(pm.messageType),
		compressionLevel: c.compressionLevel,
	})
	if err != nil {
		return err
	}
	if c.isWriting {
		panic("concurrent write to websocket connection")
	}
	c.isWriting = true
	err = c.write(frameType, c.writeDeadline, frameData, nil)
	if !c.isWriting {
		panic("concurrent write to websocket connection")
	}
	c.isWriting = false
	return err
}

// WriteMessage is a helper method for getting a writer using NextWriter,
// writing the message and closing the writer.
func (c *Conn) WriteMessage(messageType int, data []byte) error {

	if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) {
		// Fast path with no allocations and single frame.
		// Copy what fits into writeBuf after the frame header and hand the
		// remainder to flushFrame, so the message is sent as one final frame.

		var mw messageWriter
		if err := c.beginMessage(&mw, messageType); err != nil {
			return err
		}
		n := copy(c.writeBuf[mw.pos:], data)
		mw.pos += n
		data = data[n:]
		return mw.flushFrame(true, data)
	}

	w, err := c.NextWriter(messageType)
	if err != nil {
		return err
	}
	if _, err = w.Write(data); err != nil {
		return err
	}
	return w.Close()
}

// SetWriteDeadline sets the write deadline on the underlying network
// connection. After a write has timed out, the websocket state is corrupt and
// all future writes will return an error. A zero value for t means writes will
// not time out.
//
// The value is only stored in c.writeDeadline (it is passed to c.write by the
// write methods); this method always returns nil.
func (c *Conn) SetWriteDeadline(t time.Time) error {
	c.writeDeadline = t
	return nil
}

// Read methods

// advanceFrame reads the next frame header from c.br and validates it.
//
// For data frames (text, binary, continuation) the accumulated message length
// is checked against c.readLimit and the frame type is returned with the
// payload still unread. Control frames have their payload read here and are
// dispatched to the pong, ping or close handler: ping and pong return their own
// frame type, a close frame returns noFrame and a *CloseError.
func (c *Conn) advanceFrame() (int, error) {
	// 1. Skip remainder of previous frame.

	if c.readRemaining > 0 {
		if _, err := io.CopyN(io.Discard, c.br, c.readRemaining); err != nil {
			return noFrame, err
		}
	}

	// 2. Read and parse first two bytes of frame header.
	// To aid debugging, collect and report all errors in the first two bytes
	// of the header.

	var errors []string

	p, err := c.read(2)
	if err != nil {
		return noFrame, err
	}

	frameType := int(p[0] & 0xf)
	final := p[0]&finalBit != 0
	rsv1 := p[0]&rsv1Bit != 0
	rsv2 := p[0]&rsv2Bit != 0
	rsv3 := p[0]&rsv3Bit != 0
	mask := p[1]&maskBit != 0
	_ = c.setReadRemaining(int64(p[1] & 0x7f)) // will not fail because argument is >= 0

	c.readDecompress = false
	if rsv1 {
		// RSV1 is accepted only when a decompressor was negotiated.
		if c.newDecompressionReader != nil {
			c.readDecompress = true
		} else {
			errors = append(errors, "RSV1 set")
		}
	}

	if rsv2 {
		errors = append(errors, "RSV2 set")
	}

	if rsv3 {
		errors = append(errors, "RSV3 set")
	}

	switch frameType {
	case CloseMessage, PingMessage, PongMessage:
		if c.readRemaining > maxControlFramePayloadSize {
			errors = append(errors, "len > 125 for control")
		}
		if !final {
			errors = append(errors, "FIN not set on control")
		}
	case TextMessage, BinaryMessage:
		// A new data message may only start once the previous one finished.
		if !c.readFinal {
			errors = append(errors, "data before FIN")
		}
		c.readFinal = final
	case continuationFrame:
		// A continuation frame is only valid inside an unfinished message.
		if c.readFinal {
			errors = append(errors, "continuation after FIN")
		}
		c.readFinal = final
	default:
		errors = append(errors, "bad opcode "+strconv.Itoa(frameType))
	}

	// A server requires masked frames; a client requires unmasked ones.
	if mask != c.isServer {
		errors = append(errors, "bad MASK")
	}

	if len(errors) > 0 {
		return noFrame, c.handleProtocolError(strings.Join(errors, ", "))
	}

	// 3. Read and parse frame length as per
	// https://tools.ietf.org/html/rfc6455#section-5.2
	//
	// The length of the "Payload data", in bytes: if 0-125, that is the payload
	// length.
	// - If 126, the following 2 bytes interpreted as a 16-bit unsigned
	// integer are the payload length.
	// - If 127, the following 8 bytes interpreted as
	// a 64-bit unsigned integer (the most significant bit MUST be 0) are the
	// payload length. Multibyte length quantities are expressed in network byte
	// order.

	switch c.readRemaining {
	case 126:
		p, err := c.read(2)
		if err != nil {
			return noFrame, err
		}

		if err := c.setReadRemaining(int64(binary.BigEndian.Uint16(p))); err != nil {
			return noFrame, err
		}
	case 127:
		p, err := c.read(8)
		if err != nil {
			return noFrame, err
		}

		if err := c.setReadRemaining(int64(binary.BigEndian.Uint64(p))); err != nil {
			return noFrame, err
		}
	}

	// 4. Handle frame masking.

	if mask {
		c.readMaskPos = 0
		p, err := c.read(len(c.readMaskKey))
		if err != nil {
			return noFrame, err
		}
		copy(c.readMaskKey[:], p)
	}

	// 5. For text and binary messages, enforce read limit and return.

	if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage {

		c.readLength += c.readRemaining
		// Don't allow readLength to overflow in the presence of a large readRemaining
		// counter.
		if c.readLength < 0 {
			return noFrame, ErrReadLimit
		}

		if c.readLimit > 0 && c.readLength > c.readLimit {
			// Make a best effort to send a close message describing the problem.
			_ = c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
			return noFrame, ErrReadLimit
		}

		return frameType, nil
	}

	// 6. Read control frame payload.
	// Its length was bounded by maxControlFramePayloadSize above, so it is
	// read in one call.

	var payload []byte
	if c.readRemaining > 0 {
		payload, err = c.read(int(c.readRemaining))
		_ = c.setReadRemaining(0) // will not fail because argument is >= 0
		if err != nil {
			return noFrame, err
		}
		if c.isServer {
			maskBytes(c.readMaskKey, 0, payload)
		}
	}

	// 7. Process control frame payload.

	switch frameType {
	case PongMessage:
		if err := c.handlePong(string(payload)); err != nil {
			return noFrame, err
		}
	case PingMessage:
		if err := c.handlePing(string(payload)); err != nil {
			return noFrame, err
		}
	case CloseMessage:
		closeCode := CloseNoStatusReceived
		closeText := ""
		if len(payload) >= 2 {
			closeCode = int(binary.BigEndian.Uint16(payload))
			if !isValidReceivedCloseCode(closeCode) {
				return noFrame, c.handleProtocolError("bad close code " + strconv.Itoa(closeCode))
			}
			closeText = string(payload[2:])
			if !utf8.ValidString(closeText) {
				return noFrame, c.handleProtocolError("invalid utf8 payload in close frame")
			}
		}
		if err := c.handleClose(closeCode, closeText); err != nil {
			return noFrame, err
		}
		return noFrame, &CloseError{Code: closeCode, Text: closeText}
	}

	return frameType, nil
}

// handleProtocolError sends a best-effort close frame with code
// CloseProtocolError and message (truncated to maxControlFramePayloadSize
// bytes), then returns an error whose text is message prefixed by
// "websocket: ".
func (c *Conn) handleProtocolError(message string) error {
	data := FormatCloseMessage(CloseProtocolError, message)
	if len(data) > maxControlFramePayloadSize {
		data = data[:maxControlFramePayloadSize]
	}
	// Make a best effort to send a close message describing the problem.
	_ = c.WriteControl(CloseMessage, data, time.Now().Add(writeWait))
	return errors.New("websocket: " + message)
}

// NextReader returns the next data message received from the peer. The
// returned messageType is either TextMessage or BinaryMessage.
//
// There can be at most one open reader on a connection. NextReader discards
// the previous message if the application has not already consumed it.
//
// Applications must break out of the application's read loop when this method
// returns a non-nil error value. Errors returned from this method are
// permanent. Once this method returns a non-nil error, all subsequent calls to
// this method return the same error.
func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
	// Close previous reader, only relevant for decompression.
	if c.reader != nil {
		c.reader.Close()
		c.reader = nil
	}

	c.messageReader = nil
	c.readLength = 0

	// Advance until a text or binary frame starts a new message. Control
	// frames are processed inside advanceFrame; any error becomes sticky in
	// c.readErr.
	for c.readErr == nil {
		frameType, err := c.advanceFrame()
		if err != nil {
			c.readErr = err
			break
		}

		if frameType == TextMessage || frameType == BinaryMessage {
			c.messageReader = &messageReader{c}
			c.reader = c.messageReader
			if c.readDecompress {
				c.reader = c.newDecompressionReader(c.reader)
			}
			return frameType, c.reader, nil
		}
	}

	// Applications that do not handle the error returned from this method spin
	// in a tight loop on connection failure. To help application developers
	// detect this error, panic on repeated reads to the failed connection.
	c.readErrCount++
	if c.readErrCount >= 1000 {
		panic("repeated read on failed websocket connection")
	}

	return noFrame, nil, c.readErr
}

// messageReader reads the payload of the current data message from c,
// following continuation frames until the final one.
type messageReader struct{ c *Conn }

// Read returns payload bytes of the current message. It returns io.EOF once
// the final frame is consumed, or when r is no longer the connection's active
// reader. Read errors are stored in c.readErr and returned on later calls.
func (r *messageReader) Read(b []byte) (int, error) {
	c := r.c
	if c.messageReader != r {
		return 0, io.EOF
	}

	for c.readErr == nil {

		if c.readRemaining > 0 {
			// Never read past the end of the current frame.
			if int64(len(b)) > c.readRemaining {
				b = b[:c.readRemaining]
			}
			n, err := c.br.Read(b)
			c.readErr = err
			if c.isServer {
				// Unmask in place, continuing from the saved mask position.
				c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
			}
			rem := c.readRemaining
			rem -= int64(n)
			_ = c.setReadRemaining(rem) // rem is guaranteed to be >= 0
			if c.readRemaining > 0 && c.readErr == io.EOF {
				c.readErr = errUnexpectedEOF
			}
			return n, c.readErr
		}

		if c.readFinal {
			c.messageReader = nil
			return 0, io.EOF
		}

		// Current frame exhausted but the message continues: move to the
		// next frame. A new text/binary frame here would normally already be
		// rejected by advanceFrame ("data before FIN").
		frameType, err := c.advanceFrame()
		switch {
		case err != nil:
			c.readErr = err
		case frameType == TextMessage || frameType == BinaryMessage:
			c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
		}
	}

	err := c.readErr
	if err == io.EOF && c.messageReader == r {
		err = errUnexpectedEOF
	}
	return 0, err
}

// Close is a no-op. It lets a messageReader be stored in c.reader, whose Close
// method NextReader calls before starting the next message.
func (r *messageReader) Close() error {
	return nil
}

// ReadMessage is a helper method for getting a reader using NextReader and
// reading from that reader to a buffer.
func (c *Conn) ReadMessage() (messageType int, p []byte, err error) { var r io.Reader messageType, r, err = c.NextReader() if err != nil { return messageType, nil, err } p, err = io.ReadAll(r) return messageType, p, err } // SetReadDeadline sets the read deadline on the underlying network connection. // After a read has timed out, the websocket connection state is corrupt and // all future reads will return an error. A zero value for t means reads will // not time out. func (c *Conn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) } // SetReadLimit sets the maximum size in bytes for a message read from the peer. If a // message exceeds the limit, the connection sends a close message to the peer // and returns ErrReadLimit to the application. func (c *Conn) SetReadLimit(limit int64) { c.readLimit = limit } // CloseHandler returns the current close handler func (c *Conn) CloseHandler() func(code int, text string) error { return c.handleClose } // SetCloseHandler sets the handler for close messages received from the peer. // The code argument to h is the received close code or CloseNoStatusReceived // if the close message is empty. The default close handler sends a close // message back to the peer. // // The handler function is called from the NextReader, ReadMessage and message // reader Read methods. The application must read the connection to process // close messages as described in the section on Control Messages above. // // The connection read methods return a CloseError when a close message is // received. Most applications should handle close messages as part of their // normal error handling. Applications should only set a close handler when the // application must perform some action before sending a close message back to // the peer. 
func (c *Conn) SetCloseHandler(h func(code int, text string) error) { if h == nil { h = func(code int, text string) error { message := FormatCloseMessage(code, "") // Make a best effor to send the close message. _ = c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)) return nil } } c.handleClose = h } // PingHandler returns the current ping handler func (c *Conn) PingHandler() func(appData string) error { return c.handlePing } // SetPingHandler sets the handler for ping messages received from the peer. // The appData argument to h is the PING message application data. The default // ping handler sends a pong to the peer. // // The handler function is called from the NextReader, ReadMessage and message // reader Read methods. The application must read the connection to process // ping messages as described in the section on Control Messages above. func (c *Conn) SetPingHandler(h func(appData string) error) { if h == nil { h = func(message string) error { // Make a best effort to send the pong message. _ = c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait)) return nil } } c.handlePing = h } // PongHandler returns the current pong handler func (c *Conn) PongHandler() func(appData string) error { return c.handlePong } // SetPongHandler sets the handler for pong messages received from the peer. // The appData argument to h is the PONG message application data. The default // pong handler does nothing. // // The handler function is called from the NextReader, ReadMessage and message // reader Read methods. The application must read the connection to process // pong messages as described in the section on Control Messages above. func (c *Conn) SetPongHandler(h func(appData string) error) { if h == nil { h = func(string) error { return nil } } c.handlePong = h } // NetConn returns the underlying connection that is wrapped by c. // Note that writing to or reading from this connection directly will corrupt the // WebSocket connection. 
func (c *Conn) NetConn() net.Conn {
	underlying := c.conn
	return underlying
}

// EnableWriteCompression enables and disables write compression of
// subsequent text and binary messages. This function is a noop if
// compression was not negotiated with the peer.
func (c *Conn) EnableWriteCompression(enable bool) {
	c.enableWriteCompression = enable
}

// SetCompressionLevel sets the flate compression level for subsequent text and
// binary messages. This function is a noop if compression was not negotiated
// with the peer. See the compress/flate package for a description of
// compression levels.
//
// An error is returned, and the current level kept, when level is not
// accepted by isValidCompressionLevel.
func (c *Conn) SetCompressionLevel(level int) error {
	if isValidCompressionLevel(level) {
		c.compressionLevel = level
		return nil
	}
	return errors.New("websocket: invalid compression level")
}

// FormatCloseMessage formats closeCode and text as a WebSocket close message.
// An empty message is returned for code CloseNoStatusReceived.
//
// Otherwise the result is the close code as a big-endian uint16 followed by the
// bytes of text.
func FormatCloseMessage(closeCode int, text string) []byte {
	if closeCode == CloseNoStatusReceived {
		// CloseNoStatusReceived must never be sent on the wire, so the payload is
		// empty. It is a non-nil slice for callers that compare against nil.
		return []byte{}
	}
	msg := make([]byte, 2, 2+len(text))
	binary.BigEndian.PutUint16(msg, uint16(closeCode))
	return append(msg, text...)
}

// JoinMessages concatenates received messages to create a single io.Reader.
// The string term is appended to each message. The returned reader does not
// support concurrent calls to the Read method.
func JoinMessages(c *Conn, term string) io.Reader {
	jr := &joinReader{c: c, term: term}
	return jr
}

// joinReader streams consecutive messages from c, each followed by term.
type joinReader struct {
	c    *Conn
	term string
	r    io.Reader // reader for the current message (plus term); nil between messages
}

// Read reads from the current message, opening the next message when none is
// active. The end of a message is not reported as io.EOF; errors from
// NextReader are returned as-is.
func (r *joinReader) Read(p []byte) (int, error) {
	if r.r == nil {
		_, next, err := r.c.NextReader()
		if err != nil {
			return 0, err
		}
		if r.term == "" {
			r.r = next
		} else {
			r.r = io.MultiReader(next, strings.NewReader(r.term))
		}
	}
	n, err := r.r.Read(p)
	if err != io.EOF {
		return n, err
	}
	// Current message exhausted: the next Read starts a new one.
	r.r = nil
	return n, nil
}

// WriteJSON writes the JSON encoding of v as a message.
//
// See the documentation for encoding/json Marshal for details about the
// conversion of Go values to JSON.
func (c *Conn) WriteJSON(v interface{}) error {
	w, err := c.NextWriter(TextMessage)
	if err != nil {
		return err
	}
	// Always close the writer, even if encoding failed; an encoding error
	// takes precedence over a close error.
	err1 := json.NewEncoder(w).Encode(v)
	err2 := w.Close()
	if err1 != nil {
		return err1
	}
	return err2
}

// ReadJSON reads the next JSON-encoded message from the connection and stores
// it in the value pointed to by v.
//
// See the documentation for the encoding/json Unmarshal function for details
// about the conversion of JSON to a Go value.
func (c *Conn) ReadJSON(v interface{}) error {
	_, r, err := c.NextReader()
	if err != nil {
		return err
	}
	err = json.NewDecoder(r).Decode(v)
	if err == io.EOF {
		// One value is expected in the message.
		err = io.ErrUnexpectedEOF
	}
	return err
}

// NOTE(review): the code below was originally guarded by a "!appengine" build
// constraint. Build constraints only take effect before the package clause, so
// in this file that constraint would be inert. The word-at-a-time masker is
// therefore an ordinary function, presumably renamed maskBytesUnsafe so it can
// coexist with maskBytes — confirm.

// wordSize is the size in bytes of a uintptr on the target platform.
const wordSize = int(unsafe.Sizeof(uintptr(0)))

// maskBytesUnsafe XORs b in place with the 4-byte key, starting at key index
// pos, and returns the key index to use for the byte after b ((pos+len(b)) & 3).
// Buffers of at least 2*wordSize bytes are processed one machine word at a time
// once b is aligned to a word boundary; the rest is done byte by byte.
func maskBytesUnsafe(key [4]byte, pos int, b []byte) int {
	// Mask one byte at a time for small buffers.
	if len(b) < 2*wordSize {
		for i := range b {
			b[i] ^= key[pos&3]
			pos++
		}
		return pos & 3
	}

	// Mask one byte at a time to word boundary.
	if n := int(uintptr(unsafe.Pointer(&b[0]))) % wordSize; n != 0 {
		n = wordSize - n
		for i := range b[:n] {
			b[i] ^= key[pos&3]
			pos++
		}
		b = b[n:]
	}

	// Create aligned word size key.
	// The key bytes are rotated so that word position 0 lines up with the
	// current pos.
	var k [wordSize]byte
	for i := range k {
		k[i] = key[(pos+i)&3]
	}
	kw := *(*uintptr)(unsafe.Pointer(&k))

	// Mask one word at a time.
	// pos does not need advancing here: a whole word is a multiple of 4 bytes.
	n := (len(b) / wordSize) * wordSize
	for i := 0; i < n; i += wordSize {
		*(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&b[0])) + uintptr(i))) ^= kw
	}

	// Mask one byte at a time for remaining bytes.
	b = b[n:]
	for i := range b {
		b[i] ^= key[pos&3]
		pos++
	}

	return pos & 3
}

// NOTE(review): the function below was originally guarded by an "appengine"
// build constraint, which has no effect at this position in the file.

// maskBytes XORs b in place with the 4-byte key, one byte at a time, starting
// at key index pos, and returns the key index for the byte after b.
func maskBytes(key [4]byte, pos int, b []byte) int {
	for i := range b {
		b[i] ^= key[pos&3]
		pos++
	}
	return pos & 3
}

// PreparedMessage caches on the wire representations of a message payload.
// Use PreparedMessage to efficiently send a message payload to multiple // connections. PreparedMessage is especially useful when compression is used // because the CPU and memory expensive compression operation can be executed // once for a given set of compression options. type PreparedMessage struct { messageType int data []byte mu sync.Mutex frames map[prepareKey]*preparedFrame } // prepareKey defines a unique set of options to cache prepared frames in PreparedMessage. type prepareKey struct { isServer bool compress bool compressionLevel int } // preparedFrame contains data in wire representation. type preparedFrame struct { once sync.Once data []byte } // NewPreparedMessage returns an initialized PreparedMessage. You can then send // it to connection using WritePreparedMessage method. Valid wire // representation will be calculated lazily only once for a set of current // connection options. func NewPreparedMessage(messageType int, data []byte) (*PreparedMessage, error) { pm := &PreparedMessage{ messageType: messageType, frames: make(map[prepareKey]*preparedFrame), data: data, } // Prepare a plain server frame. _, frameData, err := pm.frame(prepareKey{isServer: true, compress: false}) if err != nil { return nil, err } // To protect against caller modifying the data argument, remember the data // copied to the plain server frame. pm.data = frameData[len(frameData)-len(data):] return pm, nil } func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) { pm.mu.Lock() frame, ok := pm.frames[key] if !ok { frame = &preparedFrame{} pm.frames[key] = frame } pm.mu.Unlock() var err error frame.once.Do(func() { // Prepare a frame using a 'fake' connection. // TODO: Refactor code in conn.go to allow more direct construction of // the frame. 
mu := make(chan struct{}, 1) mu <- struct{}{} var nc prepareConn c := &Conn{ conn: &nc, mu: mu, isServer: key.isServer, compressionLevel: key.compressionLevel, enableWriteCompression: true, writeBuf: make([]byte, defaultWriteBufferSize+maxFrameHeaderSize), } if key.compress { c.newCompressionWriter = compressNoContextTakeover } err = c.WriteMessage(pm.messageType, pm.data) frame.data = nc.buf.Bytes() }) return pm.messageType, frame.data, err } type prepareConn struct { buf bytes.Buffer net.Conn } func (pc *prepareConn) Write(p []byte) (int, error) { return pc.buf.Write(p) } func (pc *prepareConn) SetWriteDeadline(t time.Time) error { return nil } type netDialerFunc func(ctx context.Context, network, addr string) (net.Conn, error) func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) { return fn(context.Background(), network, addr) } func (fn netDialerFunc) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { return fn(ctx, network, addr) } // HandshakeError describes an error with the handshake from the peer. type HandshakeError struct { message string } func (e HandshakeError) Error() string { return e.message } // Upgrader specifies parameters for upgrading an HTTP connection to a // WebSocket connection. // // It is safe to call Upgrader's methods concurrently. type Upgrader struct { // HandshakeTimeout specifies the duration for the handshake to complete. HandshakeTimeout time.Duration // ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer // size is zero, then buffers allocated by the HTTP server are used. The // I/O buffer sizes do not limit the size of the messages that can be sent // or received. ReadBufferSize, WriteBufferSize int // WriteBufferPool is a pool of buffers for write operations. If the value // is not set, then write buffers are allocated to the connection for the // lifetime of the connection. 
// // A pool is most useful when the application has a modest volume of writes // across a large number of connections. // // Applications should use a single pool for each unique value of // WriteBufferSize. WriteBufferPool BufferPool // Subprotocols specifies the server's supported protocols in order of // preference. If this field is not nil, then the Upgrade method negotiates a // subprotocol by selecting the first match in this list with a protocol // requested by the client. If there's no match, then no protocol is // negotiated (the Sec-Websocket-Protocol header is not included in the // handshake response). Subprotocols []string // Error specifies the function for generating HTTP error responses. If Error // is nil, then http.Error is used to generate the HTTP response. Error func(w http.ResponseWriter, r *http.Request, status int, reason error) // CheckOrigin returns true if the request Origin header is acceptable. If // CheckOrigin is nil, then a safe default is used: return false if the // Origin request header is present and the origin host is not equal to // request Host header. // // A CheckOrigin function should carefully validate the request origin to // prevent cross-site request forgery. CheckOrigin func(r *http.Request) bool // EnableCompression specify if the server should attempt to negotiate per // message compression (RFC 7692). Setting this value to true does not // guarantee that compression will be supported. Currently only "no context // takeover" modes are supported. EnableCompression bool } func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) { err := HandshakeError{reason} if u.Error != nil { u.Error(w, r, status, err) } else { w.Header().Set("Sec-Websocket-Version", "13") http.Error(w, http.StatusText(status), status) } return nil, err } // checkSameOrigin returns true if the origin is not set or is equal to the request host. 
func checkSameOrigin(r *http.Request) bool { origin := r.Header["Origin"] if len(origin) == 0 { return true } u, err := url.Parse(origin[0]) if err != nil { return false } return equalASCIIFold(u.Host, r.Host) } func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string { if u.Subprotocols != nil { clientProtocols := Subprotocols(r) for _, clientProtocol := range clientProtocols { for _, serverProtocol := range u.Subprotocols { if clientProtocol == serverProtocol { return clientProtocol } } } } else if responseHeader != nil { return responseHeader.Get("Sec-Websocket-Protocol") } return "" } // Upgrade upgrades the HTTP server connection to the WebSocket protocol. // // The responseHeader is included in the response to the client's upgrade // request. Use the responseHeader to specify cookies (Set-Cookie). To specify // subprotocols supported by the server, set Upgrader.Subprotocols directly. // // If the upgrade fails, then Upgrade replies to the client with an HTTP error // response. 
func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
	const badHandshake = "websocket: the client is not using the websocket protocol: "

	// Validate the client's handshake request before touching the connection.
	if !tokenListContainsValue(r.Header, "Connection", "upgrade") {
		return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'upgrade' token not found in 'Connection' header")
	}

	if !tokenListContainsValue(r.Header, "Upgrade", "websocket") {
		w.Header().Set("Upgrade", "websocket")
		return u.returnError(w, r, http.StatusUpgradeRequired, badHandshake+"'websocket' token not found in 'Upgrade' header")
	}

	if r.Method != http.MethodGet {
		return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET")
	}

	if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") {
		return u.returnError(w, r, http.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header")
	}

	// Extensions are negotiated by this package only; the application may not
	// supply its own.
	if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok {
		return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-WebSocket-Extensions' headers are unsupported")
	}

	checkOrigin := u.CheckOrigin
	if checkOrigin == nil {
		checkOrigin = checkSameOrigin
	}
	if !checkOrigin(r) {
		return u.returnError(w, r, http.StatusForbidden, "websocket: request origin not allowed by Upgrader.CheckOrigin")
	}

	challengeKey := r.Header.Get("Sec-Websocket-Key")
	if !isValidChallengeKey(challengeKey) {
		return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header must be Base64 encoded value of 16-byte in length")
	}

	subprotocol := u.selectSubprotocol(r, responseHeader)

	// Negotiate PMCE
	var compress bool
	if u.EnableCompression {
		for _, ext := range parseExtensions(r.Header) {
			if ext[""] != "permessage-deflate" {
				continue
			}
			compress = true
			break
		}
	}

	netConn, brw, err := http.NewResponseController(w).Hijack()
	if err != nil {
		return u.returnError(w, r, http.StatusInternalServerError,
			"websocket: hijack: "+err.Error())
	}

	// Close the network connection when returning an error. The variable
	// netConn is set to nil before the success return at the end of the
	// function.
	defer func() {
		if netConn != nil {
			// It's safe to ignore the error from Close() because this code is
			// only executed when returning a more important error to the
			// application.
			_ = netConn.Close()
		}
	}()

	var br *bufio.Reader
	if u.ReadBufferSize == 0 && brw.Reader.Size() > 256 {
		// Use hijacked buffered reader as the connection reader.
		br = brw.Reader
	} else if brw.Reader.Buffered() > 0 {
		// Wrap the network connection to read buffered data in brw.Reader
		// before reading from the network connection. This should be rare
		// because a client must not send message data before receiving the
		// handshake response.
		netConn = &brNetConn{br: brw.Reader, Conn: netConn}
	}

	buf := brw.Writer.AvailableBuffer()

	var writeBuf []byte
	if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 {
		// Reuse hijacked write buffer as connection buffer.
		writeBuf = buf
	}

	c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, br, writeBuf)
	c.subprotocol = subprotocol

	if compress {
		c.newCompressionWriter = compressNoContextTakeover
		c.newDecompressionReader = decompressNoContextTakeover
	}

	// Use larger of hijacked buffer and connection write buffer for header.
	p := buf
	if len(c.writeBuf) > len(p) {
		p = c.writeBuf
	}
	p = p[:0]

	p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
	p = append(p, computeAcceptKey(challengeKey)...)
	p = append(p, "\r\n"...)
	if c.subprotocol != "" {
		p = append(p, "Sec-WebSocket-Protocol: "...)
		p = append(p, c.subprotocol...)
		p = append(p, "\r\n"...)
	}
	if compress {
		p = append(p, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)
	}
	// Copy application headers; the subprotocol header was written above from
	// the negotiated value. Only header values are sanitized here; header
	// names are written as given.
	for k, vs := range responseHeader {
		if k == "Sec-Websocket-Protocol" {
			continue
		}
		for _, v := range vs {
			p = append(p, k...)
			p = append(p, ": "...)
			for i := 0; i < len(v); i++ {
				b := v[i]
				if b <= 31 {
					// prevent response splitting.
					b = ' '
				}
				p = append(p, b)
			}
			p = append(p, "\r\n"...)
		}
	}
	p = append(p, "\r\n"...)

	if u.HandshakeTimeout > 0 {
		if err := netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)); err != nil {
			return nil, err
		}
	} else {
		// Clear deadlines set by HTTP server.
		if err := netConn.SetDeadline(time.Time{}); err != nil {
			return nil, err
		}
	}

	if _, err = netConn.Write(p); err != nil {
		return nil, err
	}
	if u.HandshakeTimeout > 0 {
		if err := netConn.SetWriteDeadline(time.Time{}); err != nil {
			return nil, err
		}
	}

	// Success! Set netConn to nil to stop the deferred function above from
	// closing the network connection.
	netConn = nil

	return c, nil
}

// Subprotocols returns the subprotocols requested by the client in the
// Sec-Websocket-Protocol header.
//
// The header is split on commas and each entry is trimmed of surrounding
// space; nil is returned when the header is absent or blank.
func Subprotocols(r *http.Request) []string {
	h := strings.TrimSpace(r.Header.Get("Sec-Websocket-Protocol"))
	if h == "" {
		return nil
	}
	protocols := strings.Split(h, ",")
	for i := range protocols {
		protocols[i] = strings.TrimSpace(protocols[i])
	}
	return protocols
}

// IsWebSocketUpgrade returns true if the client requested upgrade to the
// WebSocket protocol.
func IsWebSocketUpgrade(r *http.Request) bool {
	return tokenListContainsValue(r.Header, "Connection", "upgrade") &&
		tokenListContainsValue(r.Header, "Upgrade", "websocket")
}

// brNetConn is a net.Conn whose reads first drain the data already buffered
// in br (bytes read by the HTTP server before the hijack) and then go to the
// embedded Conn.
type brNetConn struct {
	br *bufio.Reader
	net.Conn
}

// Read serves buffered bytes from br until it is empty, then drops br and
// reads from the underlying connection.
func (b *brNetConn) Read(p []byte) (n int, err error) {
	if b.br != nil {
		// Limit read to buffered data.
		if n := b.br.Buffered(); len(p) > n {
			p = p[:n]
		}
		n, err = b.br.Read(p)
		if b.br.Buffered() == 0 {
			b.br = nil
		}
		return n, err
	}
	return b.Conn.Read(p)
}

// NetConn returns the underlying connection that is wrapped by b.
func (b *brNetConn) NetConn() net.Conn { return b.Conn } var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") func computeAcceptKey(challengeKey string) string { h := sha1.New() h.Write([]byte(challengeKey)) h.Write(keyGUID) return base64.StdEncoding.EncodeToString(h.Sum(nil)) } func generateChallengeKey() (string, error) { p := make([]byte, 16) if _, err := io.ReadFull(rand.Reader, p); err != nil { return "", err } return base64.StdEncoding.EncodeToString(p), nil } // Token octets per RFC 2616. var isTokenOctet = [256]bool{ '!': true, '#': true, '$': true, '%': true, '&': true, '\'': true, '*': true, '+': true, '-': true, '.': true, '0': true, '1': true, '2': true, '3': true, '4': true, '5': true, '6': true, '7': true, '8': true, '9': true, 'A': true, 'B': true, 'C': true, 'D': true, 'E': true, 'F': true, 'G': true, 'H': true, 'I': true, 'J': true, 'K': true, 'L': true, 'M': true, 'N': true, 'O': true, 'P': true, 'Q': true, 'R': true, 'S': true, 'T': true, 'U': true, 'W': true, 'V': true, 'X': true, 'Y': true, 'Z': true, '^': true, '_': true, '`': true, 'a': true, 'b': true, 'c': true, 'd': true, 'e': true, 'f': true, 'g': true, 'h': true, 'i': true, 'j': true, 'k': true, 'l': true, 'm': true, 'n': true, 'o': true, 'p': true, 'q': true, 'r': true, 's': true, 't': true, 'u': true, 'v': true, 'w': true, 'x': true, 'y': true, 'z': true, '|': true, '~': true, } // skipSpace returns a slice of the string s with all leading RFC 2616 linear // whitespace removed. func skipSpace(s string) (rest string) { i := 0 for ; i < len(s); i++ { if b := s[i]; b != ' ' && b != '\t' { break } } return s[i:] } // nextToken returns the leading RFC 2616 token of s and the string following // the token. func nextToken(s string) (token, rest string) { i := 0 for ; i < len(s); i++ { if !isTokenOctet[s[i]] { break } } return s[:i], s[i:] } // nextTokenOrQuoted returns the leading token or quoted string per RFC 2616 // and the string following the token or quoted string. 
func nextTokenOrQuoted(s string) (value string, rest string) { if !strings.HasPrefix(s, "\"") { return nextToken(s) } s = s[1:] for i := 0; i < len(s); i++ { switch s[i] { case '"': return s[:i], s[i+1:] case '\\': p := make([]byte, len(s)-1) j := copy(p, s[:i]) escape := true for i = i + 1; i < len(s); i++ { b := s[i] switch { case escape: escape = false p[j] = b j++ case b == '\\': escape = true case b == '"': return string(p[:j]), s[i+1:] default: p[j] = b j++ } } return "", "" } } return "", "" } // equalASCIIFold returns true if s is equal to t with ASCII case folding as // defined in RFC 4790. func equalASCIIFold(s, t string) bool { for s != "" && t != "" { sr, size := utf8.DecodeRuneInString(s) s = s[size:] tr, size := utf8.DecodeRuneInString(t) t = t[size:] if sr == tr { continue } if 'A' <= sr && sr <= 'Z' { sr = sr + 'a' - 'A' } if 'A' <= tr && tr <= 'Z' { tr = tr + 'a' - 'A' } if sr != tr { return false } } return s == t } // tokenListContainsValue returns true if the 1#token header with the given // name contains a token equal to value with ASCII case folding. func tokenListContainsValue(header http.Header, name string, value string) bool { headers: for _, s := range header[name] { for { var t string t, s = nextToken(skipSpace(s)) if t == "" { continue headers } s = skipSpace(s) if s != "" && s[0] != ',' { continue headers } if equalASCIIFold(t, value) { return true } if s == "" { continue headers } s = s[1:] } } return false } // parseExtensions parses WebSocket extensions from a header. func parseExtensions(header http.Header) []map[string]string { // From RFC 6455: // // Sec-WebSocket-Extensions = extension-list // extension-list = 1#extension // extension = extension-token *( ";" extension-param ) // extension-token = registered-token // registered-token = token // extension-param = token [ "=" (token | quoted-string) ] // ;When using the quoted-string syntax variant, the value // ;after quoted-string unescaping MUST conform to the // ;'token' ABNF. 
var result []map[string]string headers: for _, s := range header["Sec-Websocket-Extensions"] { for { var t string t, s = nextToken(skipSpace(s)) if t == "" { continue headers } ext := map[string]string{"": t} for { s = skipSpace(s) if !strings.HasPrefix(s, ";") { break } var k string k, s = nextToken(skipSpace(s[1:])) if k == "" { continue headers } s = skipSpace(s) var v string if strings.HasPrefix(s, "=") { v, s = nextTokenOrQuoted(skipSpace(s[1:])) s = skipSpace(s) } if s != "" && s[0] != ',' && s[0] != ';' { continue headers } ext[k] = v } if s != "" && s[0] != ',' { continue headers } result = append(result, ext) if s == "" { continue headers } s = s[1:] } } return result } // isValidChallengeKey checks if the argument meets RFC6455 specification. func isValidChallengeKey(s string) bool { // From RFC6455: // // A |Sec-WebSocket-Key| header field with a base64-encoded (see // Section 4 of [RFC4648]) value that, when decoded, is 16 bytes in // length. if s == "" { return false } decoded, err := base64.StdEncoding.DecodeString(s) return err == nil && len(decoded) == 16 } type _CLIArgs struct { fromAddr string toAddr string } var emitActiveConnection = g.MakeGauge("active-connections") func parseArgs(args []string) _CLIArgs { if len(args) != 3 { fmt.Fprintf( os.Stderr, "Usage: %s FROM.socket TO.socket\n", args[0], ) os.Exit(2) } return _CLIArgs { fromAddr: args[1], toAddr: args[2], } } func listen(fromAddr string) net.Listener { listener, err := net.Listen("unix", fromAddr) g.FatalIf(err) g.Info("Started listening", "listen-start", "from-address", fromAddr) return listener } func copyData(c chan struct {}, from io.Reader, to io.WriteCloser) { io.Copy(to, from) c <- struct {} {} } func Start(toAddr string, listener net.Listener) { upgrader := Upgrader {} http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { connFrom, err := upgrader.Upgrade(w, r, nil) if err != nil { g.Error( "Error upgrading connection", "upgrade-connection-error", "err", err, ) 
return } defer connFrom.Close() emitActiveConnection.Inc() connTo, err := net.Dial("unix", toAddr) if err != nil { g.Error( "Error dialing connection", "dial-connection-error", "err", err, ) return } defer connTo.Close() messageType, reader, err := connFrom.NextReader() if err != nil { g.Error( "Failed to get next reader from connection", "connection-next-reader-error", "err", err, ) return } writer, err := connFrom.NextWriter(messageType) if err != nil { g.Error( "Failed to get next writer from connection", "connection-next-writer-error", "err", err, ) return } c := make(chan struct {}) go copyData(c, connTo, writer) go copyData(c, reader, connTo) go func() { <- c emitActiveConnection.Dec() }() }); server := http.Server{} err := server.Serve(listener) g.FatalIf(err) } func Main() { g.Init() args := parseArgs(os.Args) listener := listen(args.fromAddr) Start(args.toAddr, listener) }