Skip to content

Commit

Permalink
Merge pull request #5 from kazeburo/graceflul-stop
Browse files Browse the repository at this point in the history
graceful stop
  • Loading branch information
kazeburo committed May 9, 2019
2 parents 09f22b2 + 47260f5 commit 3edcb9b
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 8 deletions.
6 changes: 5 additions & 1 deletion handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"io"
"net"
"net/http"
"sync"
"time"

"github.com/gorilla/mux"
Expand Down Expand Up @@ -73,8 +74,11 @@ func (h *Handler) Hello() func(w http.ResponseWriter, r *http.Request) {
}

// Proxy proxy handler
func (h *Handler) Proxy() func(w http.ResponseWriter, r *http.Request) {
func (h *Handler) Proxy(wg *sync.WaitGroup) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
wg.Add(1)
defer wg.Done()

vars := mux.Vars(r)
proxyDest := vars["dest"]
upstream := ""
Expand Down
57 changes: 50 additions & 7 deletions wsgate-server.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
package main

import (
"context"
"flag"
"fmt"
"net"
"net/http"
"os"
"os/signal"
"runtime"
"sync"
"syscall"
"time"

"github.com/gorilla/mux"
Expand All @@ -24,6 +29,7 @@ var (
handshakeTimeout = flag.Duration("handshake_timeout", 10*time.Second, "Handshake timeout.")
dialTimeout = flag.Duration("dial_timeout", 10*time.Second, "Dial timeout.")
writeTimeout = flag.Duration("write_timeout", 10*time.Second, "Write timeout.")
shutdownTimeout = flag.Duration("shutdown_timeout", 86400*time.Second, "timeout to wait for all connections to be closed")
mapFile = flag.String("map", "", "path and proxy host mapping file")
publicKeyFile = flag.String("public-key", "", "public key for signing auth header")
dumpTCP = flag.Uint("dump-tcp", 0, "Dump TCP. 0 = disable, 1 = src to dest, 2 = both")
Expand Down Expand Up @@ -70,10 +76,49 @@ func main() {
logger.Fatal("Failed init handler", zap.Error(err))
}

wg := &sync.WaitGroup{}
defer func() {
c := make(chan struct{})
go func() {
defer close(c)
wg.Wait()
}()
select {
case <-c:
logger.Info("All connections closed. Shutdown")
return
case <-time.After(*shutdownTimeout):
logger.Info("Timeout, close some connections. Shutdown")
return
}
}()

m := mux.NewRouter()
m.HandleFunc("/", proxyHandler.Hello())
m.HandleFunc("/live", proxyHandler.Hello())
m.HandleFunc("/proxy/{dest}", proxyHandler.Proxy())
m.HandleFunc("/proxy/{dest}", proxyHandler.Proxy(wg))

s := &http.Server{
Handler: m,
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
MaxHeaderBytes: 1 << 20,
}

idleConnsClosed := make(chan struct{})
go func() {
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGTERM)
<-sigChan
logger.Info("Signal received. Start to shutdown")
ctx, cancel := context.WithTimeout(context.Background(), *shutdownTimeout)
if es := s.Shutdown(ctx); es != nil {
logger.Warn("Shutdown error", zap.Error(err))
}
cancel()
close(idleConnsClosed)
logger.Info("Waiting for all connections to be closed")
}()

l, err := ss.NewListener()
if l == nil || err != nil {
Expand All @@ -84,11 +129,9 @@ func main() {
}
}

s := &http.Server{
Handler: m,
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
MaxHeaderBytes: 1 << 20,
if err := s.Serve(l); err != http.ErrServerClosed {
logger.Error("Error in Serve", zap.Error(err))
}
s.Serve(l)

<-idleConnsClosed
}

0 comments on commit 3edcb9b

Please sign in to comment.