From 6cec2ca22e36e702265fd0a9173be341c8e44397 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Wed, 18 Oct 2023 22:47:59 -0700 Subject: [PATCH] close.go: Fix mid read close Closes #355 --- close.go | 7 +++++++ conn_test.go | 21 +++++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/close.go b/close.go index d78a5442..fe1ced34 100644 --- a/close.go +++ b/close.go @@ -182,6 +182,13 @@ func (c *Conn) waitCloseHandshake() error { return c.readCloseFrameErr } + for i := int64(0); i < c.msgReader.payloadLength; i++ { + _, err := c.br.ReadByte() + if err != nil { + return err + } + } + for { h, err := c.readLoop(ctx) if err != nil { diff --git a/conn_test.go b/conn_test.go index 3df6c64a..abc1c81d 100644 --- a/conn_test.go +++ b/conn_test.go @@ -308,6 +308,27 @@ func TestConn(t *testing.T) { assert.ErrorIs(t, websocket.ErrClosed, err1) assert.ErrorIs(t, websocket.ErrClosed, err2) }) + + t.Run("MidReadClose", func(t *testing.T) { + tt, c1, c2 := newConnTest(t, nil, nil) + + tt.goEchoLoop(c2) + + c1.SetReadLimit(131072) + + for i := 0; i < 5; i++ { + err := wstest.Echo(tt.ctx, c1, 131072) + assert.Success(t, err) + } + + err := wsjson.Write(tt.ctx, c1, "four") + assert.Success(t, err) + _, _, err = c1.Reader(tt.ctx) + assert.Success(t, err) + + err = c1.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) + }) } func TestWasm(t *testing.T) {