diff --git a/errors.go b/errors.go index 5ffc742..7bedec5 100644 --- a/errors.go +++ b/errors.go @@ -64,7 +64,7 @@ var ( // ErrSessionShutdown is used if there is a shutdown during // an operation - ErrSessionShutdown = &Error{msg: "session shutdown"} + ErrSessionShutdown = &GoAwayError{ErrorCode: goAwayNormal, Remote: false} // ErrStreamsExhausted is returned if we have no more // stream ids to issue @@ -87,7 +87,7 @@ var ( ErrUnexpectedFlag = &Error{msg: "unexpected flag"} // ErrRemoteGoAway is used when we get a go away from the other side - ErrRemoteGoAway = &Error{msg: "remote end is not accepting connections"} + ErrRemoteGoAway = &GoAwayError{Remote: true, ErrorCode: goAwayNormal} // ErrStreamReset is sent if a stream is reset. This can happen // if the backlog is exceeded, or if there was a remote GoAway. diff --git a/session.go b/session.go index 62ea2c3..3a62ac2 100644 --- a/session.go +++ b/session.go @@ -102,6 +102,8 @@ type Session struct { // recvDoneCh is closed when recv() exits to avoid a race // between stream registration and stream shutdown recvDoneCh chan struct{} + // recvErr is the error the receive loop ended with + recvErr error // sendDoneCh is closed when send() exits to avoid a race // between returning from a Stream.Write and exiting from the send loop @@ -288,10 +290,18 @@ func (s *Session) AcceptStream() (*Stream, error) { // semantics of the underlying net.Conn. For TCP connections, it may be dropped depending on LINGER value or // if there's unread data in the kernel receive buffer. func (s *Session) Close() error { - return s.close(true, goAwayNormal) + return s.closeWithGoAway(goAwayNormal) } -func (s *Session) close(sendGoAway bool, errCode uint32) error { +// CloseWithError is used to close the session and all streams after sending a GoAway message with errCode. +// The GoAway may not actually be sent depending on the semantics of the underlying net.Conn. +// For TCP connections, it may be dropped depending on LINGER value or if there's unread data in the kernel +// receive buffer. +func (s *Session) CloseWithError(errCode uint32) error { + return s.closeWithGoAway(errCode) +} + +func (s *Session) closeWithGoAway(errCode uint32) error { s.shutdownLock.Lock() defer s.shutdownLock.Unlock() @@ -300,22 +310,25 @@ func (s *Session) close(sendGoAway bool, errCode uint32) error { } s.shutdown = true if s.shutdownErr == nil { - s.shutdownErr = ErrSessionShutdown + if errCode == goAwayNormal { + s.shutdownErr = ErrSessionShutdown + } else { + s.shutdownErr = &GoAwayError{Remote: false, ErrorCode: errCode} + } } close(s.shutdownCh) s.stopKeepalive() // wait for write loop to exit - _ = s.conn.SetWriteDeadline(time.Now().Add(-1 * time.Hour)) // if SetWriteDeadline errored, any blocked writes will be unblocked + // We need to complete writing the current frame before sending a goaway + // This will wait for at most s.config.ConnectionWriteTimeout <-s.sendDoneCh - if sendGoAway { - ga := s.goAway(errCode) - if err := s.conn.SetWriteDeadline(time.Now().Add(goAwayWaitTime)); err == nil { - _, _ = s.conn.Write(ga[:]) // there's nothing we can do on error here - } + ga := s.goAway(errCode) + if err := s.conn.SetWriteDeadline(time.Now().Add(goAwayWaitTime)); err == nil { + _, _ = s.conn.Write(ga[:]) // there's nothing we can do on error here } - s.conn.SetWriteDeadline(time.Time{}) + s.conn.Close() <-s.recvDoneCh @@ -329,15 +342,30 @@ func (s *Session) close(sendGoAway bool, errCode uint32) error { return nil } -// exitErr is used to handle an error that is causing the -// session to terminate. -func (s *Session) exitErr(err error) { +func (s *Session) closeWithoutGoAway(err error) error { s.shutdownLock.Lock() + defer s.shutdownLock.Unlock() + if s.shutdown { + return nil + } + s.shutdown = true if s.shutdownErr == nil { s.shutdownErr = err } - s.shutdownLock.Unlock() - s.close(false, 0) + close(s.shutdownCh) + s.conn.Close() + <-s.sendDoneCh + <-s.recvDoneCh + s.stopKeepalive() + + s.streamLock.Lock() + defer s.streamLock.Unlock() + for id, stream := range s.streams { + stream.forceClose() + delete(s.streams, id) + stream.memorySpan.Done() + } + return nil } // GoAway can be used to prevent accepting further @@ -468,7 +496,7 @@ func (s *Session) startKeepalive() { if err != nil { s.logger.Printf("[ERR] yamux: keepalive failed: %v", err) - s.exitErr(ErrKeepAliveTimeout) + s.closeWithoutGoAway(ErrKeepAliveTimeout) } }) } @@ -533,7 +561,20 @@ func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}) err // send is a long running goroutine that sends data func (s *Session) send() { if err := s.sendLoop(); err != nil { - s.exitErr(err) + // Prefer the recvLoop error over the sendLoop error. The receive loop might have the error code + // received in a GoAway frame received just before the TCP RST that closed the sendLoop + s.shutdownLock.Lock() + if s.shutdownErr == nil { + s.conn.Close() + <-s.recvDoneCh + if _, ok := s.recvErr.(*GoAwayError); ok { + s.shutdownErr = s.recvErr + } else { + s.shutdownErr = err + } + } + s.shutdownLock.Unlock() + s.closeWithoutGoAway(err) } } @@ -661,7 +702,7 @@ func (s *Session) sendLoop() (err error) { // recv is a long running goroutine that accepts new data func (s *Session) recv() { if err := s.recvLoop(); err != nil { - s.exitErr(err) + s.closeWithoutGoAway(err) } } @@ -683,7 +724,10 @@ func (s *Session) recvLoop() (err error) { err = fmt.Errorf("panic in yamux receive loop: %s", rerr) } }() - defer close(s.recvDoneCh) + defer func() { + s.recvErr = err + close(s.recvDoneCh) + }() var hdr header for { // fmt.Printf("ReadFull from %#v\n", s.reader) @@ -799,17 +843,17 @@ func (s *Session) handleGoAway(hdr header) error { switch code { case goAwayNormal: atomic.SwapInt32(&s.remoteGoAway, 1) + // Don't close connection on normal go away. Let the existing streams + // complete gracefully. + return nil case goAwayProtoErr: s.logger.Printf("[ERR] yamux: received protocol error go away") - return fmt.Errorf("yamux protocol error") case goAwayInternalErr: s.logger.Printf("[ERR] yamux: received internal error go away") - return fmt.Errorf("remote yamux internal error") default: - s.logger.Printf("[ERR] yamux: received unexpected go away") - return fmt.Errorf("unexpected go away received") + s.logger.Printf("[ERR] yamux: received go away with error code: %d", code) } - return nil + return &GoAwayError{Remote: true, ErrorCode: code} } // incomingStream is used to create a new incoming stream diff --git a/session_test.go b/session_test.go index 974b6d5..df3e3c9 100644 --- a/session_test.go +++ b/session_test.go @@ -3,6 +3,7 @@ package yamux import ( "bytes" "context" + "errors" "fmt" "io" "math/rand" @@ -39,6 +40,8 @@ type pipeConn struct { writeDeadline pipeDeadline writeBlocker chan struct{} closeCh chan struct{} + closeOnce sync.Once + closeErr error } func (p *pipeConn) SetDeadline(t time.Time) error { @@ -65,10 +68,12 @@ func (p *pipeConn) Write(b []byte) (int, error) { } func (p *pipeConn) Close() error { - p.writeDeadline.set(time.Time{}) - err := p.Conn.Close() - close(p.closeCh) - return err + p.closeOnce.Do(func() { + p.writeDeadline.set(time.Time{}) + p.closeErr = p.Conn.Close() + close(p.closeCh) + }) + return p.closeErr } func (p *pipeConn) BlockWrites() { @@ -650,6 +655,35 @@ func TestGoAway(t *testing.T) { default: t.Fatalf("err: %v", err) } + time.Sleep(50 * time.Millisecond) + } + t.Fatalf("expected GoAway error") +} + +func TestCloseWithError(t *testing.T) { + // This test is noisy. + conf := testConf() + conf.LogOutput = io.Discard + + client, server := testClientServerConfig(conf) + defer client.Close() + defer server.Close() + + if err := server.CloseWithError(42); err != nil { + t.Fatalf("err: %v", err) + } + + for i := 0; i < 100; i++ { + s, err := client.Open(context.Background()) + if err == nil { + s.Close() + time.Sleep(50 * time.Millisecond) + continue + } + if !errors.Is(err, &GoAwayError{ErrorCode: 42, Remote: true}) { + t.Fatalf("err: %v", err) + } + return } t.Fatalf("expected GoAway error") } @@ -1048,6 +1082,7 @@ func TestKeepAlive_Timeout(t *testing.T) { // Prevent the client from responding clientConn := client.conn.(*pipeConn) clientConn.BlockWrites() + defer clientConn.UnblockWrites() select { case err := <-errCh: