Skip to content

Commit

Permalink
do a graceful close
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Aug 23, 2024
1 parent 276b891 commit 1525575
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 41 deletions.
57 changes: 45 additions & 12 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ func (n nullMemoryManagerImpl) Done() {}

var nullMemoryManager = &nullMemoryManagerImpl{}

type CloseWriter interface {
CloseWrite() error
}

// Session is used to wrap a reliable ordered connection and to
// multiplex it into multiple streams.
type Session struct {
Expand Down Expand Up @@ -304,18 +308,27 @@ func (s *Session) closeWithError(errCode uint32, sendGoAway bool) error {
// wait for write loop
_ = s.conn.SetWriteDeadline(time.Now().Add(-1 * time.Hour)) // if SetWriteDeadline errored, any blocked writes will be unblocked
<-s.sendDoneCh

// send the goaway frame
if sendGoAway {
buf := pool.Get(headerSize)
hdr := s.goAway(errCode)
copy(buf, hdr[:])
if err := s.conn.SetWriteDeadline(time.Now().Add(goAwayWaitTime)); err == nil {
_, _ = s.conn.Write(buf) // Ignore the error. We are going to close the connection anyway
if _, err = s.conn.Write(buf); err != nil {
sendGoAway = false
}
} else {
sendGoAway = false
}
}
s.conn.Close()

if w, ok := s.conn.(CloseWriter); ok && sendGoAway {
if err := w.CloseWrite(); err != nil {
s.conn.Close()

Check warning on line 326 in session.go

View check run for this annotation

Codecov / codecov/patch

session.go#L326

Added line #L326 was not covered by tests
}
} else {
s.conn.Close()
}
s.conn.SetReadDeadline(time.Now().Add(-1 * time.Hour))
// wait for read loop
<-s.recvDoneCh

Expand Down Expand Up @@ -533,12 +546,6 @@ func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}) err
// send is a long running goroutine that sends data
func (s *Session) send() {
if err := s.sendLoop(); err != nil {
if !s.IsClosed() && (errors.Is(err, net.ErrClosed) || errors.Is(err, io.ErrClosedPipe) || strings.Contains(err.Error(), "reset") || strings.Contains(err.Error(), "broken pipe")) {
// if remote has closed the connection, wait for recv loop to exit
// unfortunately it is impossible to close the connection such that FIN is sent and not RST
<-s.recvDoneCh
return
}
s.exitErr(err)
}
}
Expand Down Expand Up @@ -654,7 +661,6 @@ func (s *Session) sendLoop() (err error) {

_, err := writer.Write(buf)
pool.Put(buf)

if err != nil {
if os.IsTimeout(err) {
err = ErrConnectionWriteTimeout
Expand Down Expand Up @@ -689,12 +695,39 @@ func (s *Session) recvLoop() (err error) {
err = fmt.Errorf("panic in yamux receive loop: %s", rerr)
}
}()
defer close(s.recvDoneCh)

gracefulCloseErr := errors.New("close gracefully")
defer func() {
close(s.recvDoneCh)
errGoAway := &ErrorGoAway{}
if errors.As(err, &errGoAway) {
return
}
if err != gracefulCloseErr {
s.conn.Close()
return
}
if err := s.conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
s.conn.Close()
return

Check warning on line 712 in session.go

View check run for this annotation

Codecov / codecov/patch

session.go#L711-L712

Added lines #L711 - L712 were not covered by tests
}
buf := make([]byte, 1<<16)
for {
_, err := s.conn.Read(buf)
if err != nil {
s.conn.Close()
return
}
}
}()
var hdr header
for {
// fmt.Printf("ReadFull from %#v\n", s.reader)
// Read the header
if _, err := io.ReadFull(s.reader, hdr[:]); err != nil {
if s.IsClosed() && os.IsTimeout(err) {
return gracefulCloseErr
}
if err != io.EOF && !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "reset by peer") {
s.logger.Printf("[ERR] yamux: Failed to read header: %v", err)
}
Expand Down
37 changes: 8 additions & 29 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ type pipeConn struct {
writeDeadline pipeDeadline
writeBlocker chan struct{}
closeCh chan struct{}
closeOnce sync.Once
closeErr error
}

func (p *pipeConn) SetDeadline(t time.Time) error {
Expand All @@ -66,10 +68,12 @@ func (p *pipeConn) Write(b []byte) (int, error) {
}

func (p *pipeConn) Close() error {
p.writeDeadline.set(time.Time{})
err := p.Conn.Close()
close(p.closeCh)
return err
p.closeOnce.Do(func() {
close(p.closeCh)
p.writeDeadline.set(time.Time{})
p.closeErr = p.Conn.Close()
})
return p.closeErr
}

func (p *pipeConn) BlockWrites() {
Expand Down Expand Up @@ -1821,28 +1825,3 @@ func TestMaxIncomingStreams(t *testing.T) {
_, err = str.Read([]byte{0})
require.NoError(t, err)
}

func TestRSTBehavior(t *testing.T) {
client, server := testTCPConns(t)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
defer server.Close()
server.Write([]byte("hello"))
time.Sleep(20 * time.Second)
buf := make([]byte, 10)
n, err := server.Read(buf)
if err != nil {
t.Error(err)
} else {
t.Log(string(buf[:n]))
}

}()
client.Write([]byte("world"))
time.Sleep(10 * time.Second)
// close client without reading server msg. This ensures that the TCP stack sends an RST
client.Close()
wg.Wait()
}

0 comments on commit 1525575

Please sign in to comment.