Skip to content

Commit

Permalink
add CloseWithError
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Aug 26, 2024
1 parent d8cf4e7 commit 18a75f1
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 30 deletions.
4 changes: 2 additions & 2 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ var (

// ErrSessionShutdown is used if there is a shutdown during
// an operation
ErrSessionShutdown = &Error{msg: "session shutdown"}
ErrSessionShutdown = &GoAwayError{ErrorCode: goAwayNormal, Remote: false}

// ErrStreamsExhausted is returned if we have no more
// stream ids to issue
Expand All @@ -87,7 +87,7 @@ var (
ErrUnexpectedFlag = &Error{msg: "unexpected flag"}

// ErrRemoteGoAway is used when we get a go away from the other side
ErrRemoteGoAway = &Error{msg: "remote end is not accepting connections"}
ErrRemoteGoAway = &GoAwayError{Remote: true, ErrorCode: goAwayNormal}

// ErrStreamReset is sent if a stream is reset. This can happen
// if the backlog is exceeded, or if there was a remote GoAway.
Expand Down
92 changes: 68 additions & 24 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ type Session struct {
// recvDoneCh is closed when recv() exits to avoid a race
// between stream registration and stream shutdown
recvDoneCh chan struct{}
// recvErr is the error the receive loop ended with
recvErr error

// sendDoneCh is closed when send() exits to avoid a race
// between returning from a Stream.Write and exiting from the send loop
Expand Down Expand Up @@ -288,10 +290,18 @@ func (s *Session) AcceptStream() (*Stream, error) {
// semantics of the underlying net.Conn. For TCP connections, it may be dropped depending on LINGER value or
// if there's unread data in the kernel receive buffer.
func (s *Session) Close() error {
return s.close(true, goAwayNormal)
return s.closeWithGoAway(goAwayNormal)
}

func (s *Session) close(sendGoAway bool, errCode uint32) error {
// CloseWithError is used to close the session and all streams after sending a GoAway message with errCode.
// The GoAway may not actually be sent depending on the semantics of the underlying net.Conn.
// For TCP connections, it may be dropped depending on LINGER value or if there's unread data in the kernel
// receive buffer.
func (s *Session) CloseWithError(errCode uint32) error {
return s.closeWithGoAway(errCode)
}

func (s *Session) closeWithGoAway(errCode uint32) error {
s.shutdownLock.Lock()
defer s.shutdownLock.Unlock()

Expand All @@ -300,22 +310,25 @@ func (s *Session) close(sendGoAway bool, errCode uint32) error {
}
s.shutdown = true
if s.shutdownErr == nil {
s.shutdownErr = ErrSessionShutdown
if errCode == goAwayNormal {
s.shutdownErr = ErrSessionShutdown
} else {
s.shutdownErr = &GoAwayError{Remote: false, ErrorCode: errCode}
}
}
close(s.shutdownCh)
s.stopKeepalive()

// wait for write loop to exit
_ = s.conn.SetWriteDeadline(time.Now().Add(-1 * time.Hour)) // if SetWriteDeadline errored, any blocked writes will be unblocked
// We need to complete writing the current frame before sending a goaway
// This will wait for at most s.config.ConnectionWriteTimeout
<-s.sendDoneCh
if sendGoAway {
ga := s.goAway(errCode)
if err := s.conn.SetWriteDeadline(time.Now().Add(goAwayWaitTime)); err == nil {
_, _ = s.conn.Write(ga[:]) // there's nothing we can do on error here
}
ga := s.goAway(errCode)
if err := s.conn.SetWriteDeadline(time.Now().Add(goAwayWaitTime)); err == nil {
_, _ = s.conn.Write(ga[:]) // there's nothing we can do on error here
}

s.conn.SetWriteDeadline(time.Time{})

s.conn.Close()
<-s.recvDoneCh

Expand All @@ -329,15 +342,30 @@ func (s *Session) close(sendGoAway bool, errCode uint32) error {
return nil
}

// exitErr is used to handle an error that is causing the
// session to terminate.
func (s *Session) exitErr(err error) {
func (s *Session) closeWithoutGoAway(err error) error {
s.shutdownLock.Lock()
defer s.shutdownLock.Unlock()
if s.shutdown {
return nil
}
s.shutdown = true
if s.shutdownErr == nil {
s.shutdownErr = err
}
s.shutdownLock.Unlock()
s.close(false, 0)
close(s.shutdownCh)
s.conn.Close()
<-s.sendDoneCh
<-s.recvDoneCh
s.stopKeepalive()

s.streamLock.Lock()
defer s.streamLock.Unlock()
for id, stream := range s.streams {
stream.forceClose()
delete(s.streams, id)
stream.memorySpan.Done()
}
return nil
}

// GoAway can be used to prevent accepting further
Expand Down Expand Up @@ -468,7 +496,7 @@ func (s *Session) startKeepalive() {

if err != nil {
s.logger.Printf("[ERR] yamux: keepalive failed: %v", err)
s.exitErr(ErrKeepAliveTimeout)
s.closeWithoutGoAway(ErrKeepAliveTimeout)
}
})
}
Expand Down Expand Up @@ -533,7 +561,20 @@ func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}) err
// send is a long running goroutine that sends data
func (s *Session) send() {
if err := s.sendLoop(); err != nil {
s.exitErr(err)
// Prefer the recvLoop error over the sendLoop error. The receive loop might have the error code
// received in a GoAway frame received just before the TCP RST that closed the sendLoop
s.shutdownLock.Lock()
if s.shutdownErr == nil {
s.conn.Close()
<-s.recvDoneCh
if _, ok := s.recvErr.(*GoAwayError); ok {
s.shutdownErr = s.recvErr
} else {
s.shutdownErr = err
}
}
s.shutdownLock.Unlock()
s.closeWithoutGoAway(err)
}
}

Expand Down Expand Up @@ -661,7 +702,7 @@ func (s *Session) sendLoop() (err error) {
// recv is a long running goroutine that accepts new data
func (s *Session) recv() {
if err := s.recvLoop(); err != nil {
s.exitErr(err)
s.closeWithoutGoAway(err)
}
}

Expand All @@ -683,7 +724,10 @@ func (s *Session) recvLoop() (err error) {
err = fmt.Errorf("panic in yamux receive loop: %s", rerr)
}
}()
defer close(s.recvDoneCh)
defer func() {
s.recvErr = err
close(s.recvDoneCh)
}()
var hdr header
for {
// fmt.Printf("ReadFull from %#v\n", s.reader)
Expand Down Expand Up @@ -799,17 +843,17 @@ func (s *Session) handleGoAway(hdr header) error {
switch code {
case goAwayNormal:
atomic.SwapInt32(&s.remoteGoAway, 1)
// Don't close connection on normal go away. Let the existing streams
// complete gracefully.
return nil
case goAwayProtoErr:
s.logger.Printf("[ERR] yamux: received protocol error go away")
return fmt.Errorf("yamux protocol error")
case goAwayInternalErr:
s.logger.Printf("[ERR] yamux: received internal error go away")
return fmt.Errorf("remote yamux internal error")
default:
s.logger.Printf("[ERR] yamux: received unexpected go away")
return fmt.Errorf("unexpected go away received")
s.logger.Printf("[ERR] yamux: received go away with error code: %d", code)
}
return nil
return &GoAwayError{Remote: true, ErrorCode: code}
}

// incomingStream is used to create a new incoming stream
Expand Down
43 changes: 39 additions & 4 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package yamux
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"math/rand"
Expand Down Expand Up @@ -39,6 +40,8 @@ type pipeConn struct {
writeDeadline pipeDeadline
writeBlocker chan struct{}
closeCh chan struct{}
closeOnce sync.Once
closeErr error
}

func (p *pipeConn) SetDeadline(t time.Time) error {
Expand All @@ -65,10 +68,12 @@ func (p *pipeConn) Write(b []byte) (int, error) {
}

func (p *pipeConn) Close() error {
p.writeDeadline.set(time.Time{})
err := p.Conn.Close()
close(p.closeCh)
return err
p.closeOnce.Do(func() {
p.writeDeadline.set(time.Time{})
p.closeErr = p.Conn.Close()
close(p.closeCh)
})
return p.closeErr
}

func (p *pipeConn) BlockWrites() {
Expand Down Expand Up @@ -650,6 +655,35 @@ func TestGoAway(t *testing.T) {
default:
t.Fatalf("err: %v", err)
}
time.Sleep(50 * time.Millisecond)
}
t.Fatalf("expected GoAway error")
}

func TestCloseWithError(t *testing.T) {
// This test is noisy.
conf := testConf()
conf.LogOutput = io.Discard

client, server := testClientServerConfig(conf)
defer client.Close()
defer server.Close()

if err := server.CloseWithError(42); err != nil {
t.Fatalf("err: %v", err)
}

for i := 0; i < 100; i++ {
s, err := client.Open(context.Background())
if err == nil {
s.Close()
time.Sleep(50 * time.Millisecond)
continue
}
if !errors.Is(err, &GoAwayError{ErrorCode: 42, Remote: true}) {
t.Fatalf("err: %v", err)
}
return
}
t.Fatalf("expected GoAway error")
}
Expand Down Expand Up @@ -1048,6 +1082,7 @@ func TestKeepAlive_Timeout(t *testing.T) {
// Prevent the client from responding
clientConn := client.conn.(*pipeConn)
clientConn.BlockWrites()
defer clientConn.UnblockWrites()

select {
case err := <-errCh:
Expand Down

0 comments on commit 18a75f1

Please sign in to comment.