From 47260f576079ad0621cf2f7dc644c99074783497 Mon Sep 17 00:00:00 2001 From: Masahiro Nagano Date: Thu, 9 May 2019 12:35:08 +0900 Subject: [PATCH] graceful stop --- handler/handler.go | 6 ++++- wsgate-server.go | 57 ++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 55 insertions(+), 8 deletions(-) diff --git a/handler/handler.go b/handler/handler.go index 88c59e0..8e2c425 100644 --- a/handler/handler.go +++ b/handler/handler.go @@ -5,6 +5,7 @@ import ( "io" "net" "net/http" + "sync" "time" "github.com/gorilla/mux" @@ -73,8 +74,11 @@ func (h *Handler) Hello() func(w http.ResponseWriter, r *http.Request) { } // Proxy proxy handler -func (h *Handler) Proxy() func(w http.ResponseWriter, r *http.Request) { +func (h *Handler) Proxy(wg *sync.WaitGroup) func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) { + wg.Add(1) + defer wg.Done() + vars := mux.Vars(r) proxyDest := vars["dest"] upstream := "" diff --git a/wsgate-server.go b/wsgate-server.go index 25631f5..78c181c 100644 --- a/wsgate-server.go +++ b/wsgate-server.go @@ -1,11 +1,16 @@ package main import ( + "context" "flag" "fmt" "net" "net/http" + "os" + "os/signal" "runtime" + "sync" + "syscall" "time" "github.com/gorilla/mux" @@ -24,6 +29,7 @@ var ( handshakeTimeout = flag.Duration("handshake_timeout", 10*time.Second, "Handshake timeout.") dialTimeout = flag.Duration("dial_timeout", 10*time.Second, "Dial timeout.") writeTimeout = flag.Duration("write_timeout", 10*time.Second, "Write timeout.") + shutdownTimeout = flag.Duration("shutdown_timeout", 86400*time.Second, "timeout to wait for all connections to be closed") mapFile = flag.String("map", "", "path and proxy host mapping file") publicKeyFile = flag.String("public-key", "", "public key for signing auth header") dumpTCP = flag.Uint("dump-tcp", 0, "Dump TCP. 0 = disable, 1 = src to dest, 2 = both") @@ -70,10 +76,49 @@ func main() { logger.Fatal("Failed init handler", zap.Error(err)) } + wg := &sync.WaitGroup{} + defer func() { + c := make(chan struct{}) + go func() { + defer close(c) + wg.Wait() + }() + select { + case <-c: + logger.Info("All connections closed. Shutdown") + return + case <-time.After(*shutdownTimeout): + logger.Info("Timeout, close some connections. Shutdown") + return + } + }() + m := mux.NewRouter() m.HandleFunc("/", proxyHandler.Hello()) m.HandleFunc("/live", proxyHandler.Hello()) - m.HandleFunc("/proxy/{dest}", proxyHandler.Proxy()) + m.HandleFunc("/proxy/{dest}", proxyHandler.Proxy(wg)) + + s := &http.Server{ + Handler: m, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + MaxHeaderBytes: 1 << 20, + } + + idleConnsClosed := make(chan struct{}) + go func() { + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGTERM) + <-sigChan + logger.Info("Signal received. Start to shutdown") + ctx, cancel := context.WithTimeout(context.Background(), *shutdownTimeout) + if es := s.Shutdown(ctx); es != nil { + logger.Warn("Shutdown error", zap.Error(err)) + } + cancel() + close(idleConnsClosed) + logger.Info("Waiting for all connections to be closed") + }() l, err := ss.NewListener() if l == nil || err != nil { @@ -84,11 +129,9 @@ func main() { } } - s := &http.Server{ - Handler: m, - ReadTimeout: 10 * time.Second, - WriteTimeout: 10 * time.Second, - MaxHeaderBytes: 1 << 20, + if err := s.Serve(l); err != http.ErrServerClosed { + logger.Error("Error in Serve", zap.Error(err)) } - s.Serve(l) + + <-idleConnsClosed }