Skip to content

Commit

Permalink
Merge pull request #169 from nhooyr/race
Browse files Browse the repository at this point in the history
Fix race with c.readerShouldLock
  • Loading branch information
nhooyr committed Nov 5, 2019
2 parents e36318f + 8b47056 commit f178ccf
Showing 1 changed file with 40 additions and 17 deletions.
57 changes: 40 additions & 17 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,10 @@ type Conn struct {
readLock chan struct{}

// messageReader state.
readerMsgCtx context.Context
readerMsgHeader header
readerFrameEOF bool
readerMaskPos int
readerShouldLock bool
readerMsgCtx context.Context
readerMsgHeader header
readerFrameEOF bool
readerMaskPos int

setReadTimeout chan context.Context
setWriteTimeout chan context.Context
Expand Down Expand Up @@ -237,6 +236,10 @@ func (c *Conn) readTillMsg(ctx context.Context) (header, error) {
if h.opcode.controlOp() {
err = c.handleControl(ctx, h)
if err != nil {
// Pass through CloseErrors when receiving a close frame.
if h.opcode == opClose && CloseStatus(err) != -1 {
return header{}, err
}
return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err)
}
continue
Expand Down Expand Up @@ -445,7 +448,6 @@ func (c *Conn) reader(ctx context.Context, lock bool) (MessageType, io.Reader, e
c.readerFrameEOF = false
c.readerMaskPos = 0
c.readMsgLeft = c.msgReadLimit.Load()
c.readerShouldLock = lock

r := &messageReader{
c: c,
Expand All @@ -465,7 +467,11 @@ func (r *messageReader) eof() bool {

// Read reads as many bytes as possible into p.
func (r *messageReader) Read(p []byte) (int, error) {
n, err := r.read(p)
return r.exportedRead(p, true)
}

func (r *messageReader) exportedRead(p []byte, lock bool) (int, error) {
n, err := r.read(p, lock)
if err != nil {
// Have to return io.EOF directly for now, we cannot wrap as errors.Is
// isn't used widely yet.
Expand All @@ -477,17 +483,29 @@ func (r *messageReader) Read(p []byte) (int, error) {
return n, nil
}

func (r *messageReader) read(p []byte) (int, error) {
if r.c.readerShouldLock {
err := r.c.acquireLock(r.c.readerMsgCtx, r.c.readLock)
if err != nil {
return 0, err
func (r *messageReader) readUnlocked(p []byte) (int, error) {
return r.exportedRead(p, false)
}

func (r *messageReader) read(p []byte, lock bool) (int, error) {
if lock {
// If we cannot acquire the read lock, then
// there is either a concurrent read or the close handshake
// is proceeding.
select {
case r.c.readLock <- struct{}{}:
defer r.c.releaseLock(r.c.readLock)
default:
if r.c.closing.Load() == 1 {
<-r.c.closed
return 0, r.c.closeErr
}
return 0, errors.New("concurrent read detected")
}
defer r.c.releaseLock(r.c.readLock)
}

if r.eof() {
return 0, fmt.Errorf("cannot use EOFed reader")
return 0, errors.New("cannot use EOFed reader")
}

if r.c.readMsgLeft <= 0 {
Expand Down Expand Up @@ -950,8 +968,6 @@ func (c *Conn) waitClose() error {
return c.closeReceived
}

c.readerShouldLock = false

b := bpool.Get()
buf := b.Bytes()
buf = buf[:cap(buf)]
Expand All @@ -965,7 +981,8 @@ func (c *Conn) waitClose() error {
}
}

_, err = io.CopyBuffer(ioutil.Discard, c.activeReader, buf)
r := readerFunc(c.activeReader.readUnlocked)
_, err = io.CopyBuffer(ioutil.Discard, r, buf)
if err != nil {
return err
}
Expand Down Expand Up @@ -1019,6 +1036,12 @@ func (c *Conn) ping(ctx context.Context, p string) error {
}
}

type readerFunc func(p []byte) (int, error)

func (f readerFunc) Read(p []byte) (int, error) {
return f(p)
}

type writerFunc func(p []byte) (int, error)

func (f writerFunc) Write(p []byte) (int, error) {
Expand Down

0 comments on commit f178ccf

Please sign in to comment.