diff --git a/internal/msggateway/client.go b/internal/msggateway/client.go index 9d594c1fa7..ca72bfb846 100644 --- a/internal/msggateway/client.go +++ b/internal/msggateway/client.go @@ -73,6 +73,7 @@ type Client struct { closed atomic.Bool closedErr error token string + hbCtx context.Context } // ResetClient updates the client's state with new connection and context information. @@ -89,6 +90,7 @@ func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, longConnServer c.closed.Store(false) c.closedErr = nil c.token = ctx.GetToken() + c.hbCtx, _ = context.WithTimeout(c.ctx, pongWait*2) } func (c *Client) pingHandler(_ string) error { @@ -113,23 +115,8 @@ func (c *Client) readMessage() { _ = c.conn.SetReadDeadline(pongWait) c.conn.SetPingHandler(c.pingHandler) - if c.PlatformID == 5 { - go func() { - ticker := time.NewTicker(20) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - if err := c.writePingMsg(); err != nil { - log.ZError(c.ctx, "send Ping Message error.", err) - return - } - case <-c.ctx.Done(): - return - } - } - }() + if c.PlatformID == constant.WebPlatformID { + go c.heartbeat(c.hbCtx) } for { @@ -138,7 +125,6 @@ func (c *Client) readMessage() { if returnErr != nil { log.ZWarn(c.ctx, "readMessage", returnErr, "messageType", messageType) c.closedErr = returnErr - <-c.ctx.Done() return } @@ -146,7 +132,6 @@ func (c *Client) readMessage() { if c.closed.Load() { // The scenario where the connection has just been closed, but the coroutine has not exited c.closedErr = ErrConnClosed - <-c.ctx.Done() return } @@ -156,12 +141,10 @@ func (c *Client) readMessage() { parseDataErr := c.handleMessage(message) if parseDataErr != nil { c.closedErr = parseDataErr - <-c.ctx.Done() return } case MessageText: c.closedErr = ErrNotSupportMessageProtocol - <-c.ctx.Done() return case PingMessage: @@ -170,7 +153,6 @@ func (c *Client) readMessage() { case CloseMessage: c.closedErr = ErrClientClosed - <-c.ctx.Done() return default: @@ -261,6 +243,7 @@ func (c *Client) close() { c.closed.Store(true) c.conn.Close() + <-c.hbCtx.Done() // Close initiated heartbeat in server send. c.longConnServer.UnRegister(c) } @@ -347,20 +330,22 @@ func (c *Client) writeBinaryMsg(resp Resp) error { return c.conn.WriteMessage(MessageBinary, encodedBuf) } -func (c *Client) writePingMsg() error { - if c.closed.Load() { - return nil - } - - c.w.Lock() - defer c.w.Unlock() +func (c *Client) heartbeat(ctx context.Context) { + log.ZDebug(ctx, "server initiative send heartbeat start.") + ticker := time.NewTicker(pingPeriod) + defer ticker.Stop() - err := c.conn.SetWriteDeadline(writeWait) - if err != nil { - return err + for { + select { + case <-ticker.C: + if err := c.conn.WriteMessage(PingMessage, nil); err != nil { + log.ZError(c.ctx, "send Ping Message error.", err) + return + } + case <-c.hbCtx.Done(): + return + } } - - return c.conn.WriteMessage(PingMessage, nil) } func (c *Client) writePongMsg() error { diff --git a/internal/msggateway/constant.go b/internal/msggateway/constant.go index 64664ac0ab..125be16358 100644 --- a/internal/msggateway/constant.go +++ b/internal/msggateway/constant.go @@ -53,6 +53,9 @@ const ( // Time allowed to read the next pong message from the peer. pongWait = 30 * time.Second + // Send pings to peer with this period. Must be less than pongWait. + pingPeriod = (pongWait * 9) / 10 + // Maximum message size allowed from peer. maxMessageSize = 51200 ) diff --git a/internal/msggateway/long_conn.go b/internal/msggateway/long_conn.go index 7d5bef4c3a..c1b3e27c93 100644 --- a/internal/msggateway/long_conn.go +++ b/internal/msggateway/long_conn.go @@ -16,10 +16,11 @@ package msggateway import ( "encoding/json" - "github.com/openimsdk/tools/apiresp" "net/http" "time" + "github.com/openimsdk/tools/apiresp" + "github.com/gorilla/websocket" "github.com/openimsdk/tools/errs" )