diff --git a/README.md b/README.md index 47cc3296..1ba912a4 100644 --- a/README.md +++ b/README.md @@ -123,24 +123,17 @@ it has to reinvent hooks for TLS and proxies and prevents support of HTTP/2. Some more advantages of nhooyr/websocket are that it supports concurrent writes and makes it very easy to close the connection with a status code and reason. -nhooyr/websocket also responds to pings, pongs and close frames in a separate goroutine so that -your application doesn't always need to read from the connection unless it expects a data message. -gorilla/websocket requires you to constantly read from the connection to respond to control frames -even if you don't expect the peer to send any messages. - The ping API is also much nicer. gorilla/websocket requires registering a pong handler on the Conn which results in awkward control flow. With nhooyr/websocket you use the Ping method on the Conn that sends a ping and also waits for the pong. -In terms of performance, the differences depend on your application code. nhooyr/websocket -reuses buffers efficiently out of the box if you use the wsjson and wspb subpackages whereas -gorilla/websocket does not. As mentioned above, nhooyr/websocket also supports concurrent -writers out of the box. +In terms of performance, the differences mostly depend on your application code. nhooyr/websocket +reuses message buffers out of the box if you use the wsjson and wspb subpackages. +As mentioned above, nhooyr/websocket also supports concurrent writers. -The only performance con to nhooyr/websocket is that uses two extra goroutines. One for -reading pings, pongs and close frames async to application code and another to support -context.Context cancellation. This costs 4 KB of memory which is cheap compared -to the benefits. +The only performance con to nhooyr/websocket is that uses one extra goroutine to support +cancellation with context.Context and the net/http client side body upgrade. +This costs 2 KB of memory which is cheap compared to simplicity benefits. ### x/net/websocket diff --git a/accept.go b/accept.go index bf2ed3c8..ca1eeeaf 100644 --- a/accept.go +++ b/accept.go @@ -81,9 +81,6 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { // Accept will reject the handshake if the Origin domain is not the same as the Host unless // the InsecureSkipVerify option is set. In other words, by default it does not allow // cross origin requests. -// -// The returned connection will be bound by r.Context(). Use conn.Context() to change -// the bounding context. func Accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, error) { c, err := accept(w, r, opts) if err != nil { @@ -109,7 +106,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, hj, ok := w.(http.Hijacker) if !ok { err = xerrors.New("passed ResponseWriter does not implement http.Hijacker") - http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented) return nil, err } @@ -143,7 +140,6 @@ func accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, closer: netConn, } c.init() - c.Context(r.Context()) return c, nil } diff --git a/example_test.go b/example_test.go index 57f0aa5e..050af907 100644 --- a/example_test.go +++ b/example_test.go @@ -59,3 +59,45 @@ func ExampleDial() { c.Close(websocket.StatusNormalClosure, "") } + +// This example shows how to correctly handle a WebSocket connection +// on which you will only write and do not expect to read data messages. +func Example_writeOnly() { + fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + if err != nil { + log.Println(err) + return + } + defer c.Close(websocket.StatusInternalError, "the sky is falling") + + ctx, cancel := context.WithTimeout(r.Context(), time.Minute*10) + defer cancel() + + go func() { + defer cancel() + c.Reader(ctx) + c.Close(websocket.StatusPolicyViolation, "server doesn't accept data messages") + }() + + t := time.NewTicker(time.Second * 30) + defer t.Stop() + + for { + select { + case <-ctx.Done(): + c.Close(websocket.StatusNormalClosure, "") + return + case <-t.C: + err = wsjson.Write(ctx, c, "hi") + if err != nil { + log.Println(err) + return + } + } + } + }) + + err := http.ListenAndServe("localhost:8080", fn) + log.Fatal(err) +} diff --git a/header.go b/header.go index 62b30b38..16ab6474 100644 --- a/header.go +++ b/header.go @@ -31,10 +31,19 @@ type header struct { maskKey [4]byte } +func makeWriteHeaderBuf() []byte { + return make([]byte, maxHeaderSize) +} + // bytes returns the bytes of the header. // See https://tools.ietf.org/html/rfc6455#section-5.2 -func marshalHeader(h header) []byte { - b := make([]byte, 2, maxHeaderSize) +func writeHeader(b []byte, h header) []byte { + if b == nil { + b = makeWriteHeaderBuf() + } + + b = b[:2] + b[0] = 0 if h.fin { b[0] |= 1 << 7 @@ -75,12 +84,20 @@ func marshalHeader(h header) []byte { return b } +func makeReadHeaderBuf() []byte { + return make([]byte, maxHeaderSize-2) +} + // readHeader reads a header from the reader. // See https://tools.ietf.org/html/rfc6455#section-5.2 -func readHeader(r io.Reader) (header, error) { - // We read the first two bytes directly so that we know +func readHeader(b []byte, r io.Reader) (header, error) { + if b == nil { + b = makeReadHeaderBuf() + } + + // We read the first two bytes first so that we know // exactly how long the header is. - b := make([]byte, 2, maxHeaderSize-2) + b = b[:2] _, err := io.ReadFull(r, b) if err != nil { return header{}, err diff --git a/header_test.go b/header_test.go index b9cf351b..b45854ea 100644 --- a/header_test.go +++ b/header_test.go @@ -24,7 +24,7 @@ func TestHeader(t *testing.T) { t.Run("readNegativeLength", func(t *testing.T) { t.Parallel() - b := marshalHeader(header{ + b := writeHeader(nil, header{ payloadLength: 1<<16 + 1, }) @@ -32,7 +32,7 @@ func TestHeader(t *testing.T) { b[2] |= 1 << 7 r := bytes.NewReader(b) - _, err := readHeader(r) + _, err := readHeader(nil, r) if err == nil { t.Fatalf("unexpected error value: %+v", err) } @@ -90,9 +90,9 @@ func TestHeader(t *testing.T) { } func testHeader(t *testing.T, h header) { - b := marshalHeader(h) + b := writeHeader(nil, h) r := bytes.NewReader(b) - h2, err := readHeader(r) + h2, err := readHeader(nil, r) if err != nil { t.Logf("header: %#v", h) t.Logf("bytes: %b", b) diff --git a/internal/bpool/bpool_test.go b/internal/bpool/bpool_test.go index 2b302a47..5dfe56e6 100644 --- a/internal/bpool/bpool_test.go +++ b/internal/bpool/bpool_test.go @@ -32,7 +32,6 @@ func BenchmarkSyncPool(b *testing.B) { p := sync.Pool{} - b.ResetTimer() for i := 0; i < b.N; i++ { buf := p.Get() if buf == nil { diff --git a/limitedreader.go b/limitedreader.go deleted file mode 100644 index 63bf40c4..00000000 --- a/limitedreader.go +++ /dev/null @@ -1,34 +0,0 @@ -package websocket - -import ( - "fmt" - "io" - - "golang.org/x/xerrors" -) - -type limitedReader struct { - c *Conn - r io.Reader - left int64 - limit int64 -} - -func (lr *limitedReader) Read(p []byte) (int, error) { - if lr.limit == 0 { - lr.limit = lr.left - } - - if lr.left <= 0 { - msg := fmt.Sprintf("read limited at %v bytes", lr.limit) - lr.c.Close(StatusPolicyViolation, msg) - return 0, xerrors.Errorf(msg) - } - - if int64(len(p)) > lr.left { - p = p[:lr.left] - } - n, err := lr.r.Read(p) - lr.left -= int64(n) - return n, err -} diff --git a/websocket.go b/websocket.go index 37719932..2efc485d 100644 --- a/websocket.go +++ b/websocket.go @@ -21,6 +21,9 @@ import ( // All methods may be called concurrently except for Reader, Read // and SetReadLimit. // +// You must always read from the connection. Otherwise control +// frames will not be handled. See the docs on Reader. +// // Please be sure to call Close on the connection when you // are finished with it to release the associated resources. type Conn struct { @@ -28,7 +31,7 @@ type Conn struct { br *bufio.Reader bw *bufio.Writer // writeBuf is used for masking, its the buffer in bufio.Writer. - // Only used by the client. + // Only used by the client for masking the bytes in the buffer. writeBuf []byte closer io.Closer client bool @@ -45,23 +48,29 @@ type Conn struct { // writeFrameLock is acquired to write a single frame. // Effectively meaning whoever holds it gets to write to bw. writeFrameLock chan struct{} + writeHeaderBuf []byte + writeHeader *header + + // messageWriter state. + writeMsgOpcode opcode + writeMsgCtx context.Context // Used to ensure the previous reader is read till EOF before allowing // a new one. previousReader *messageReader // readFrameLock is acquired to read from bw. - readFrameLock chan struct{} - // readMsg is used by messageReader to receive frames from - // readLoop. - readMsg chan header - // readMsgDone is used to tell the readLoop to continue after - // messageReader has read a frame. - readMsgDone chan struct{} + readFrameLock chan struct{} + readHeaderBuf []byte + controlPayloadBuf []byte + + // messageReader state + readMsgCtx context.Context + readMsgHeader header + readFrameEOF bool + readMaskPos int setReadTimeout chan context.Context setWriteTimeout chan context.Context - setConnContext chan context.Context - getConnContext chan context.Context activePingsMu sync.Mutex activePings map[string]chan<- struct{} @@ -76,22 +85,22 @@ func (c *Conn) init() { c.writeFrameLock = make(chan struct{}, 1) c.readFrameLock = make(chan struct{}, 1) - c.readMsg = make(chan header) - c.readMsgDone = make(chan struct{}) c.setReadTimeout = make(chan context.Context) c.setWriteTimeout = make(chan context.Context) - c.setConnContext = make(chan context.Context) - c.getConnContext = make(chan context.Context) c.activePings = make(map[string]chan<- struct{}) + c.writeHeaderBuf = makeWriteHeaderBuf() + c.writeHeader = &header{} + c.readHeaderBuf = makeReadHeaderBuf() + c.controlPayloadBuf = make([]byte, maxControlFramePayload) + runtime.SetFinalizer(c, func(c *Conn) { c.close(xerrors.New("connection garbage collected")) }) go c.timeoutLoop() - go c.readLoop() } // Subprotocol returns the negotiated subprotocol. @@ -131,56 +140,23 @@ func (c *Conn) close(err error) { func (c *Conn) timeoutLoop() { readCtx := context.Background() writeCtx := context.Background() - parentCtx := context.Background() for { select { case <-c.closed: return + case writeCtx = <-c.setWriteTimeout: case readCtx = <-c.setReadTimeout: + case <-readCtx.Done(): c.close(xerrors.Errorf("data read timed out: %w", readCtx.Err())) case <-writeCtx.Done(): c.close(xerrors.Errorf("data write timed out: %w", writeCtx.Err())) - case <-parentCtx.Done(): - c.close(xerrors.Errorf("parent context cancelled: %w", parentCtx.Err())) - return - case parentCtx = <-c.setConnContext: - ctx, cancelCtx := context.WithCancel(parentCtx) - defer cancelCtx() - - select { - case <-c.closed: - return - case c.getConnContext <- ctx: - } } } } -// Context returns a context derived from parent that will be cancelled -// when the connection is closed or broken. -// If the parent context is cancelled, the connection will be closed. -func (c *Conn) Context(parent context.Context) context.Context { - select { - case <-c.closed: - ctx, cancel := context.WithCancel(parent) - cancel() - return ctx - case c.setConnContext <- parent: - } - - select { - case <-c.closed: - ctx, cancel := context.WithCancel(parent) - cancel() - return ctx - case ctx := <-c.getConnContext: - return ctx - } -} - func (c *Conn) acquireLock(ctx context.Context, lock chan struct{}) error { select { case <-ctx.Done(): @@ -210,30 +186,9 @@ func (c *Conn) releaseLock(lock chan struct{}) { } } -func (c *Conn) readLoop() { - for { - h, err := c.readTillMsg() - if err != nil { - return - } - - select { - case <-c.closed: - return - case c.readMsg <- h: - } - - select { - case <-c.closed: - return - case <-c.readMsgDone: - } - } -} - -func (c *Conn) readTillMsg() (header, error) { +func (c *Conn) readTillMsg(ctx context.Context) (header, error) { for { - h, err := c.readFrameHeader() + h, err := c.readFrameHeader(ctx) if err != nil { return header{}, err } @@ -245,7 +200,10 @@ func (c *Conn) readTillMsg() (header, error) { } if h.opcode.controlOp() { - c.handleControl(h) + err = c.handleControl(ctx, h) + if err != nil { + return header{}, xerrors.Errorf("failed to handle control frame: %w", err) + } continue } @@ -260,43 +218,63 @@ func (c *Conn) readTillMsg() (header, error) { } } -func (c *Conn) readFrameHeader() (header, error) { +func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { err := c.acquireLock(context.Background(), c.readFrameLock) if err != nil { return header{}, err } defer c.releaseLock(c.readFrameLock) - h, err := readHeader(c.br) + select { + case <-c.closed: + return header{}, c.closeErr + case c.setReadTimeout <- ctx: + } + + h, err := readHeader(c.readHeaderBuf, c.br) if err != nil { + select { + case <-c.closed: + return header{}, c.closeErr + case <-ctx.Done(): + err = ctx.Err() + default: + } err := xerrors.Errorf("failed to read header: %w", err) c.releaseLock(c.readFrameLock) c.close(err) return header{}, err } + select { + case <-c.closed: + return header{}, c.closeErr + case c.setReadTimeout <- context.Background(): + } + return h, nil } -func (c *Conn) handleControl(h header) { +func (c *Conn) handleControl(ctx context.Context, h header) error { if h.payloadLength > maxControlFramePayload { - c.Close(StatusProtocolError, "control frame too large") - return + err := xerrors.Errorf("control frame too large at %v bytes", h.payloadLength) + c.Close(StatusProtocolError, err.Error()) + return err } if !h.fin { - c.Close(StatusProtocolError, "control frame cannot be fragmented") - return + err := xerrors.Errorf("received fragmented control frame") + c.Close(StatusProtocolError, err.Error()) + return err } - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + ctx, cancel := context.WithTimeout(ctx, time.Second*5) defer cancel() - b := make([]byte, h.payloadLength) - + b := c.controlPayloadBuf[:h.payloadLength] _, err := c.readFramePayload(ctx, b) if err != nil { - return + return err } if h.masked { @@ -305,7 +283,7 @@ func (c *Conn) handleControl(h header) { switch h.opcode { case opPing: - c.writePong(b) + return c.writePong(b) case opPong: c.activePingsMu.Lock() pong, ok := c.activePings[string(b)] @@ -313,17 +291,15 @@ func (c *Conn) handleControl(h header) { if ok { close(pong) } + return nil case opClose: ce, err := parseClosePayload(b) if err != nil { - c.close(xerrors.Errorf("received invalid close payload: %w", err)) - return - } - if ce.Code == StatusNoStatusRcvd { - c.writeClose(nil, ce) - } else { - c.Close(ce.Code, ce.Reason) + c.Close(StatusProtocolError, "received invalid close payload") + return xerrors.Errorf("received invalid close payload: %w", err) } + c.writeClose(b, xerrors.Errorf("received close frame: %w", ce)) + return c.closeErr default: panic(fmt.Sprintf("websocket: unexpected control opcode: %#v", h)) } @@ -335,11 +311,10 @@ func (c *Conn) handleControl(h header) { // The passed context will also bound the reader. // Ensure you read to EOF otherwise the connection will hang. // -// Control (ping, pong, close) frames will be handled automatically -// in a separate goroutine so if you do not expect any data messages, -// you do not need to read from the connection. However, if the peer -// sends a data message, further pings, pongs and close frames will not -// be read if you do not read the message from the connection. +// You must read from the connection for close frames to be read. +// If you do not expect any data messages from the peer, just call +// Reader in a separate goroutine and close the connection with StatusPolicyViolation +// when it returns. See the writeOnly example. // // Only one Reader may be open at a time. // @@ -352,15 +327,11 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { if err != nil { return 0, nil, xerrors.Errorf("failed to get reader: %w", err) } - return typ, &limitedReader{ - c: c, - r: r, - left: c.msgReadLimit, - }, nil + return typ, r, nil } func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { - if c.previousReader != nil && c.previousReader.h != nil { + if c.previousReader != nil && !c.readFrameEOF { // The only way we know for sure the previous reader is not yet complete is // if there is an active frame not yet fully read. // Otherwise, a user may have read the last byte but not the EOF if the EOF @@ -368,57 +339,52 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { return 0, nil, xerrors.Errorf("previous message not read to completion") } - select { - case <-c.closed: - return 0, nil, c.closeErr - case <-ctx.Done(): - return 0, nil, ctx.Err() - case h := <-c.readMsg: - if c.previousReader != nil && !c.previousReader.done { - if h.opcode != opContinuation { - err := xerrors.Errorf("received new data message without finishing the previous message") - c.Close(StatusProtocolError, err.Error()) - return 0, nil, err - } + h, err := c.readTillMsg(ctx) + if err != nil { + return 0, nil, err + } - if !h.fin || h.payloadLength > 0 { - return 0, nil, xerrors.Errorf("previous message not read to completion") - } + if c.previousReader != nil && !c.previousReader.eof { + if h.opcode != opContinuation { + err := xerrors.Errorf("received new data message without finishing the previous message") + c.Close(StatusProtocolError, err.Error()) + return 0, nil, err + } - c.previousReader.done = true + if !h.fin || h.payloadLength > 0 { + return 0, nil, xerrors.Errorf("previous message not read to completion") + } - select { - case <-c.closed: - return 0, nil, c.closeErr - case c.readMsgDone <- struct{}{}: - } + c.previousReader.eof = true - return c.reader(ctx) - } else if h.opcode == opContinuation { - err := xerrors.Errorf("received continuation frame not after data or text frame") - c.Close(StatusProtocolError, err.Error()) + h, err = c.readTillMsg(ctx) + if err != nil { return 0, nil, err } + } else if h.opcode == opContinuation { + err := xerrors.Errorf("received continuation frame not after data or text frame") + c.Close(StatusProtocolError, err.Error()) + return 0, nil, err + } - r := &messageReader{ - ctx: ctx, - c: c, + c.readMsgCtx = ctx + c.readMsgHeader = h + c.readFrameEOF = false + c.readMaskPos = 0 - h: &h, - } - c.previousReader = r - return MessageType(h.opcode), r, nil + r := &messageReader{ + c: c, + left: c.msgReadLimit, } + c.previousReader = r + return MessageType(h.opcode), r, nil } // messageReader enables reading a data frame from the WebSocket connection. type messageReader struct { - ctx context.Context - c *Conn - - h *header - maskPos int - done bool + c *Conn + left int64 + eof bool } // Read reads as many bytes as possible into p. @@ -436,62 +402,62 @@ func (r *messageReader) Read(p []byte) (int, error) { } func (r *messageReader) read(p []byte) (int, error) { - if r.done { + if r.eof { return 0, xerrors.Errorf("cannot use EOFed reader") } - if r.h == nil { - select { - case <-r.c.closed: - return 0, r.c.closeErr - case <-r.ctx.Done(): - r.c.close(xerrors.Errorf("failed to read: %w", r.ctx.Err())) - return 0, r.ctx.Err() - case h := <-r.c.readMsg: - if h.opcode != opContinuation { - err := xerrors.Errorf("received new data frame without finishing the previous frame") - r.c.Close(StatusProtocolError, err.Error()) - return 0, err - } - r.h = &h + if r.left <= 0 { + err := xerrors.Errorf("read limited at %v bytes", r.c.msgReadLimit) + r.c.Close(StatusMessageTooBig, err.Error()) + return 0, err + } + + if int64(len(p)) > r.left { + p = p[:r.left] + } + + if r.c.readFrameEOF { + h, err := r.c.readTillMsg(r.c.readMsgCtx) + if err != nil { + return 0, err + } + + if h.opcode != opContinuation { + err := xerrors.Errorf("received new data message without finishing the previous message") + r.c.Close(StatusProtocolError, err.Error()) + return 0, err } + + r.c.readMsgHeader = h + r.c.readFrameEOF = false + r.c.readMaskPos = 0 } - if int64(len(p)) > r.h.payloadLength { - p = p[:r.h.payloadLength] + h := r.c.readMsgHeader + if int64(len(p)) > h.payloadLength { + p = p[:h.payloadLength] } - n, err := r.c.readFramePayload(r.ctx, p) + n, err := r.c.readFramePayload(r.c.readMsgCtx, p) - r.h.payloadLength -= int64(n) - if r.h.masked { - r.maskPos = fastXOR(r.h.maskKey, r.maskPos, p) + h.payloadLength -= int64(n) + r.left -= int64(n) + if h.masked { + r.c.readMaskPos = fastXOR(h.maskKey, r.c.readMaskPos, p) } + r.c.readMsgHeader = h if err != nil { return n, err } - if r.h.payloadLength == 0 { - select { - case <-r.c.closed: - return n, r.c.closeErr - case r.c.readMsgDone <- struct{}{}: - } - - fin := r.h.fin - - // Need to nil this as Reader uses it to check - // whether there is active data on the previous reader and - // now there isn't. - r.h = nil + if h.payloadLength == 0 { + r.c.readFrameEOF = true - if fin { - r.done = true + if h.fin { + r.eof = true return n, io.EOF } - - r.maskPos = 0 } return n, nil @@ -519,7 +485,7 @@ func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { err = ctx.Err() default: } - err = xerrors.Errorf("failed to read from connection: %w", err) + err = xerrors.Errorf("failed to read frame payload: %w", err) c.releaseLock(c.readFrameLock) c.close(err) return n, err @@ -539,7 +505,7 @@ func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { // // By default, the connection has a message read limit of 32768 bytes. // -// When the limit is hit, the connection will be closed with StatusPolicyViolation. +// When the limit is hit, the connection will be closed with StatusMessageTooBig. func (c *Conn) SetReadLimit(n int64) { c.msgReadLimit = n } @@ -578,10 +544,10 @@ func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, err if err != nil { return nil, err } + c.writeMsgCtx = ctx + c.writeMsgOpcode = opcode(typ) return &messageWriter{ - ctx: ctx, - opcode: opcode(typ), - c: c, + c: c, }, nil } @@ -610,8 +576,6 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error // messageWriter enables writing to a WebSocket connection. type messageWriter struct { - ctx context.Context - opcode opcode c *Conn closed bool } @@ -629,11 +593,11 @@ func (w *messageWriter) write(p []byte) (int, error) { if w.closed { return 0, xerrors.Errorf("cannot use closed writer") } - n, err := w.c.writeFrame(w.ctx, false, w.opcode, p) + n, err := w.c.writeFrame(w.c.writeMsgCtx, false, w.c.writeMsgOpcode, p) if err != nil { return n, xerrors.Errorf("failed to write data frame: %w", err) } - w.opcode = opContinuation + w.c.writeMsgOpcode = opContinuation return n, nil } @@ -653,7 +617,7 @@ func (w *messageWriter) close() error { } w.closed = true - _, err := w.c.writeFrame(w.ctx, true, w.opcode, nil) + _, err := w.c.writeFrame(w.c.writeMsgCtx, true, w.c.writeMsgOpcode, nil) if err != nil { return xerrors.Errorf("failed to write fin frame: %w", err) } @@ -672,65 +636,78 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error // writeFrame handles all writes to the connection. func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte) (int, error) { - h := header{ - fin: fin, - opcode: opcode, - masked: c.client, - payloadLength: int64(len(p)), + err := c.acquireLock(ctx, c.writeFrameLock) + if err != nil { + return 0, err } + defer c.releaseLock(c.writeFrameLock) + + select { + case <-c.closed: + return 0, c.closeErr + case c.setWriteTimeout <- ctx: + } + + c.writeHeader.fin = fin + c.writeHeader.opcode = opcode + c.writeHeader.masked = c.client + c.writeHeader.payloadLength = int64(len(p)) if c.client { - _, err := io.ReadFull(cryptorand.Reader, h.maskKey[:]) + _, err := io.ReadFull(cryptorand.Reader, c.writeHeader.maskKey[:]) if err != nil { return 0, xerrors.Errorf("failed to generate masking key: %w", err) } } - b2 := marshalHeader(h) - - err := c.acquireLock(ctx, c.writeFrameLock) + n, err := c.realWriteFrame(ctx, *c.writeHeader, p) if err != nil { - return 0, err + return n, err } - defer c.releaseLock(c.writeFrameLock) + // We already finished writing, no need to potentially brick the connection if + // the context expires. select { case <-c.closed: - return 0, c.closeErr - case c.setWriteTimeout <- ctx: + return n, c.closeErr + case c.setWriteTimeout <- context.Background(): } - writeErr := func(err error) error { - select { - case <-c.closed: - return c.closeErr - case <-ctx.Done(): - err = ctx.Err() - default: - } + return n, nil +} - err = xerrors.Errorf("failed to write to connection: %w", err) - // We need to release the lock first before closing the connection to ensure - // the lock can be acquired inside close to ensure no one can access c.bw. - c.releaseLock(c.writeFrameLock) - c.close(err) +func (c *Conn) realWriteFrame(ctx context.Context, h header, p []byte) (n int, err error) { + defer func() { + if err != nil { + select { + case <-c.closed: + err = c.closeErr + case <-ctx.Done(): + err = ctx.Err() + default: + } - return err - } + err = xerrors.Errorf("failed to write %v frame: %w", h.opcode, err) + // We need to release the lock first before closing the connection to ensure + // the lock can be acquired inside close to ensure no one can access c.bw. + c.releaseLock(c.writeFrameLock) + c.close(err) + } + }() - _, err = c.bw.Write(b2) + headerBytes := writeHeader(c.writeHeaderBuf, h) + _, err = c.bw.Write(headerBytes) if err != nil { - return 0, writeErr(err) + return 0, err } - var n int if c.client { var keypos int for len(p) > 0 { if c.bw.Available() == 0 { err = c.bw.Flush() if err != nil { - return n, writeErr(err) + return n, err } } @@ -744,7 +721,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte n2, err := c.bw.Write(p2) if err != nil { - return n, writeErr(err) + return n, err } keypos = fastXOR(h.maskKey, keypos, c.writeBuf[i:i+n2]) @@ -755,25 +732,17 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte } else { n, err = c.bw.Write(p) if err != nil { - return n, writeErr(err) + return n, err } } - if fin { + if h.fin { err = c.bw.Flush() if err != nil { - return n, writeErr(err) + return n, err } } - // We already finished writing, no need to potentially brick the connection if - // the context expires. - select { - case <-c.closed: - return n, c.closeErr - case c.setWriteTimeout <- context.Background(): - } - return n, nil } @@ -822,22 +791,29 @@ func (c *Conn) exportedClose(code StatusCode, reason string) error { p, _ = ce.bytes() } - return c.writeClose(p, ce) + err = c.writeClose(p, xerrors.Errorf("sent close frame: %w", ce)) + if err != nil { + return err + } + + if !xerrors.Is(c.closeErr, ce) { + return c.closeErr + } + + return nil } -func (c *Conn) writeClose(p []byte, cerr CloseError) error { +func (c *Conn) writeClose(p []byte, cerr error) error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() + // If this fails, the connection had to have died. err := c.writeControl(ctx, opClose, p) if err != nil { return err } c.close(cerr) - if !xerrors.Is(c.closeErr, cerr) { - return c.closeErr - } return nil } diff --git a/websocket_test.go b/websocket_test.go index 9d867b50..5209e2d7 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -383,6 +383,8 @@ func TestHandshake(t *testing.T) { } defer c.Close(websocket.StatusInternalError, "") + go c.Reader(r.Context()) + err = c.Ping(r.Context()) if err != nil { return err @@ -403,18 +405,19 @@ func TestHandshake(t *testing.T) { } defer c.Close(websocket.StatusInternalError, "") - err = c.Ping(ctx) - if err != nil { - return err - } + errc := make(chan error, 1) + go func() { + errc <- c.Ping(ctx) + }() _, _, err = c.Read(ctx) if err != nil { return err } + err = <-errc c.Close(websocket.StatusNormalClosure, "") - return nil + return err }, }, { @@ -439,6 +442,8 @@ func TestHandshake(t *testing.T) { } defer c.Close(websocket.StatusInternalError, "") + go c.Reader(ctx) + err = c.Write(ctx, websocket.MessageBinary, []byte(strings.Repeat("x", 32769))) if err != nil { return err @@ -447,53 +452,13 @@ func TestHandshake(t *testing.T) { err = c.Ping(ctx) var ce websocket.CloseError - if !xerrors.As(err, &ce) || ce.Code != websocket.StatusPolicyViolation { + if !xerrors.As(err, &ce) || ce.Code != websocket.StatusMessageTooBig { return xerrors.Errorf("unexpected error: %w", err) } return nil }, }, - { - name: "context", - server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - - ctx, cancel := context.WithTimeout(r.Context(), time.Second) - defer cancel() - - c.Context(ctx) - - for r.Context().Err() == nil { - err = c.Ping(ctx) - if err != nil { - return nil - } - } - - return xerrors.Errorf("all pings succeeded") - }, - client: func(ctx context.Context, u string) error { - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{}) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - - cctx := c.Context(ctx) - - select { - case <-ctx.Done(): - return xerrors.Errorf("child context never cancelled") - case <-cctx.Done(): - return nil - } - }, - }, } for _, tc := range testCases { @@ -844,7 +809,7 @@ func benchConn(b *testing.B, echo, stream bool, size int) { defer c.Close(websocket.StatusInternalError, "") msg := []byte(strings.Repeat("2", size)) - buf := make([]byte, len(msg)) + readBuf := make([]byte, len(msg)) b.SetBytes(int64(len(msg))) b.ReportAllocs() b.ResetTimer() @@ -877,7 +842,7 @@ func benchConn(b *testing.B, echo, stream bool, size int) { b.Fatal(err) } - _, err = io.ReadFull(r, buf) + _, err = io.ReadFull(r, readBuf) if err != nil { b.Fatal(err) } @@ -914,7 +879,7 @@ func BenchmarkConn(b *testing.B) { b.Run("echo", func(b *testing.B) { for _, size := range sizes { b.Run(strconv.Itoa(size), func(b *testing.B) { - benchConn(b, true, true, size) + benchConn(b, false, false, size) }) } }) diff --git a/wsjson/wsjson.go b/wsjson/wsjson.go index 19e3e6d7..b72d562f 100644 --- a/wsjson/wsjson.go +++ b/wsjson/wsjson.go @@ -44,6 +44,7 @@ func read(ctx context.Context, c *websocket.Conn, v interface{}) error { err = json.Unmarshal(b.Bytes(), v) if err != nil { + c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal JSON") return xerrors.Errorf("failed to unmarshal json: %w", err) } diff --git a/wspb/wspb.go b/wspb/wspb.go index 49c2ae54..56b14ee8 100644 --- a/wspb/wspb.go +++ b/wspb/wspb.go @@ -46,6 +46,7 @@ func read(ctx context.Context, c *websocket.Conn, v proto.Message) error { err = proto.Unmarshal(b.Bytes(), v) if err != nil { + c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal protobuf") return xerrors.Errorf("failed to unmarshal protobuf: %w", err) }