diff --git a/handler/handler.go b/handler/handler.go new file mode 100644 index 0000000..4ba7602 --- /dev/null +++ b/handler/handler.go @@ -0,0 +1,216 @@ +package handler + +import ( + "fmt" + "io" + "net" + "net/http" + "time" + + "github.com/gorilla/mux" + "github.com/gorilla/websocket" + "github.com/kazeburo/wsgate-server/mapping" + "github.com/kazeburo/wsgate-server/publickey" + "go.uber.org/zap" +) + +// Handler handlers +type Handler struct { + logger *zap.Logger + upgrader websocket.Upgrader + dialTimeout time.Duration + writeTimeout time.Duration + mp *mapping.Mapping + pk *publickey.Publickey +} + +// New new handler +func New( + handshakeTimeout time.Duration, + dialTimeout time.Duration, + writeTimeout time.Duration, + mp *mapping.Mapping, + pk *publickey.Publickey, + logger *zap.Logger) (*Handler, error) { + + upgrader := websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + HandshakeTimeout: handshakeTimeout, + CheckOrigin: func(r *http.Request) bool { + return true + }, + } + + return &Handler{ + logger: logger, + upgrader: upgrader, + dialTimeout: dialTimeout, + writeTimeout: writeTimeout, + mp: mp, + pk: pk, + }, nil +} + +// Hello hello handler +func (h *Handler) Hello() func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("OK\n")) + } +} + +// Proxy proxy handler +func (h *Handler) Proxy() func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + proxyDest := vars["dest"] + upstream := "" + readLen := int64(0) + writeLen := int64(0) + hasError := false + disconnectAt := "" + + logger := h.logger.With( + zap.String("user-email", r.Header.Get("X-Goog-Authenticated-User-Email")), + zap.String("x-forwarded-for", r.Header.Get("X-Forwarded-For")), + zap.String("remote-addr", r.RemoteAddr), + zap.String("destination", proxyDest), + ) + + if h.pk.Enabled() { + _, err := h.pk.Verify(r.Header.Get("Authorization")) + if err != nil { + logger.Warn("No authorize", zap.Error(err)) + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + } + + upstream, ok := h.mp.Get(proxyDest) + if !ok { + hasError = true + logger.Warn("No map found") + http.Error(w, fmt.Sprintf("Not found: %s", proxyDest), 404) + return + } + + logger = logger.With(zap.String("upstream", upstream)) + + s, err := net.DialTimeout("tcp", upstream, h.dialTimeout) + + if err != nil { + hasError = true + logger.Warn("DialTimeout", zap.Error(err)) + http.Error(w, fmt.Sprintf("Could not connect upstream: %v", err), 500) + return + } + + conn, err := h.upgrader.Upgrade(w, r, nil) + if err != nil { + hasError = true + s.Close() + logger.Warn("Failed to Upgrade", zap.Error(err)) + return + } + + logger.Info("log", zap.String("status", "Connected")) + + defer func() { + status := "Suceeded" + if hasError { + status = "Failed" + } + logger.Info("log", + zap.String("status", status), + zap.Int64("read", readLen), + zap.Int64("write", writeLen), + zap.String("disconnect_at", disconnectAt), + ) + }() + + doneCh := make(chan bool) + goClose := false + + // websocket -> server + go func() { + defer func() { doneCh <- true }() + for { + mt, r, err := conn.NextReader() + if websocket.IsCloseError(err, + websocket.CloseNormalClosure, // Normal. + websocket.CloseAbnormalClosure, // OpenSSH killed proxy client. + ) { + return + } + if err != nil { + if !goClose { + logger.Warn("NextReader", zap.Error(err)) + hasError = true + } + if disconnectAt == "" { + disconnectAt = "client_nextreader" + } + return + } + if mt != websocket.BinaryMessage { + logger.Warn("BinaryMessage required", zap.Int("messageType", mt)) + hasError = true + return + } + n, err := io.Copy(s, r) + if err != nil { + if !goClose { + logger.Warn("Reading from websocket", zap.Error(err)) + hasError = true + } + if disconnectAt == "" { + disconnectAt = "client_upstream_copy" + } + return + } + readLen += n + } + }() + + // server -> websocket + go func() { + defer func() { doneCh <- true }() + for { + b := make([]byte, 64*1024) + n, err := s.Read(b) + if err != nil { + if !goClose && err != io.EOF { + logger.Warn("Reading from dest", zap.Error(err)) + hasError = true + } + if disconnectAt == "" { + disconnectAt = "upstream_read" + } + return + } + + b = b[:n] + + if err := conn.WriteMessage(websocket.BinaryMessage, b); err != nil { + if !goClose { + logger.Warn("WriteMessage", zap.Error(err)) + hasError = true + } + if disconnectAt == "" { + disconnectAt = "client_write" + } + return + } + writeLen += int64(n) + } + }() + + <-doneCh + goClose = true + s.Close() + conn.Close() + <-doneCh + + } + +} diff --git a/mapping/mapping.go b/mapping/mapping.go new file mode 100644 index 0000000..667ec47 --- /dev/null +++ b/mapping/mapping.go @@ -0,0 +1,51 @@ +package mapping + +import ( + "bufio" + "os" + "regexp" + "strings" + + "github.com/pkg/errors" + "go.uber.org/zap" +) + +// Mapping struct +type Mapping struct { + m map[string]string +} + +// New new mapping +func New(mapFile string, logger *zap.Logger) (*Mapping, error) { + r := regexp.MustCompile(`^ *#`) + m := make(map[string]string) + if mapFile != "" { + f, err := os.Open(mapFile) + if err != nil { + return nil, errors.Wrap(err, "Failed to open mapFile") + } + s := bufio.NewScanner(f) + for s.Scan() { + if r.MatchString(s.Text()) { + continue + } + l := strings.SplitN(s.Text(), ",", 2) + if len(l) != 2 { + return nil, errors.Wrapf(err, "Invalid line: %s", s.Text()) + } + logger.Info("Created map", + zap.String("from", l[0]), + zap.String("to", l[1])) + m[l[0]] = l[1] + } + } + return &Mapping{ + m: m, + }, nil +} + +// Get get mapping +func (mp *Mapping) Get(proxyDest string) (string, bool) { + upstream, ok := mp.m[proxyDest] + return upstream, ok +} diff --git a/publickey/publickey.go b/publickey/publickey.go new file mode 100644 index 0000000..2f3bfd7 --- /dev/null +++ b/publickey/publickey.go @@ -0,0 +1,65 @@ +package publickey + +import ( + "crypto/rsa" + "fmt" + "io/ioutil" + "strings" + + jwt "github.com/dgrijalva/jwt-go" + "github.com/pkg/errors" + "go.uber.org/zap" +) + +// Publickey struct +type Publickey struct { + publicKeyFile string + verifyKey *rsa.PublicKey +} + +// New publickey reader/checker +func New(publicKeyFile string, logger *zap.Logger) (*Publickey, error) { + var verifyKey *rsa.PublicKey + if publicKeyFile != "" { + verifyBytes, err := ioutil.ReadFile(publicKeyFile) + if err != nil { + return nil, errors.Wrap(err, "Failed read pubkey") + } + verifyKey, err = jwt.ParseRSAPublicKeyFromPEM(verifyBytes) + if err != nil { + return nil, errors.Wrap(err, "Failed parse pubkey") + } + } + return &Publickey{ + publicKeyFile: publicKeyFile, + verifyKey: verifyKey, + }, nil +} + +// Enabled publickey is enabled +func (pk Publickey) Enabled() bool { + return pk.publicKeyFile != "" +} + +// Verify verify auth header +func (pk Publickey) Verify(t string) (bool, error) { + if t == "" { + return false, fmt.Errorf("no tokenString") + } + t = strings.TrimPrefix(t, "Bearer ") + claims := &jwt.StandardClaims{} + token, err := jwt.ParseWithClaims(t, claims, func(token *jwt.Token) (interface{}, error) { + return pk.verifyKey, nil + }) + + if err != nil { + return false, fmt.Errorf("Token is invalid: %v", err) + } + if !token.Valid { + return false, fmt.Errorf("Token is invalid") + } + if claims.Valid() != nil { + return false, fmt.Errorf("Invalid claims: %v", claims.Valid()) + } + return true, nil +} diff --git a/wsgate-server.go b/wsgate-server.go index deb3e73..1bc09e5 100644 --- a/wsgate-server.go +++ b/wsgate-server.go @@ -1,215 +1,33 @@ package main import ( - "bufio" - "crypto/rsa" "flag" "fmt" - "io" - "io/ioutil" "net" "net/http" - "os" - "regexp" "runtime" - "strings" "time" - jwt "github.com/dgrijalva/jwt-go" "github.com/gorilla/mux" - "github.com/gorilla/websocket" + "github.com/kazeburo/wsgate-server/handler" + "github.com/kazeburo/wsgate-server/mapping" + "github.com/kazeburo/wsgate-server/publickey" ss "github.com/lestrrat/go-server-starter-listener" "go.uber.org/zap" ) var ( - Version string - + // Version wsgate-server version + Version string + showVersion = flag.Bool("version", false, "show version") listen = flag.String("listen", "127.0.0.1:8086", "Address to listen to.") - dialTimeout = flag.Duration("dial_timeout", 10*time.Second, "Dial timeout.") handshakeTimeout = flag.Duration("handshake_timeout", 10*time.Second, "Handshake timeout.") + dialTimeout = flag.Duration("dial_timeout", 10*time.Second, "Dial timeout.") writeTimeout = flag.Duration("write_timeout", 10*time.Second, "Write timeout.") - showVersion = flag.Bool("version", false, "show version") mapFile = flag.String("map", "", "path and proxy host mapping file") publicKeyFile = flag.String("public-key", "", "public key for signing auth header") - - upgrader websocket.Upgrader - mapping map[string]string - verifyKey *rsa.PublicKey ) -func handleHello(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("OK\n")) -} - -func handleProxy(w http.ResponseWriter, r *http.Request, logger *zap.Logger) { - vars := mux.Vars(r) - proxyDest := vars["dest"] - upstream := "" - readLen := int64(0) - writeLen := int64(0) - hasError := false - disconnectAt := "" - - logger = logger.With( - zap.String("user-email", r.Header.Get("X-Goog-Authenticated-User-Email")), - zap.String("x-forwarded-for", r.Header.Get("X-Forwarded-For")), - zap.String("remote-addr", r.RemoteAddr), - zap.String("destination", proxyDest), - ) - - if *publicKeyFile != "" { - tokenString := r.Header.Get("Authorization") - if tokenString == "" { - http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) - return - } - tokenString = strings.TrimPrefix(tokenString, "Bearer ") - claims := &jwt.StandardClaims{} - token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) { - return verifyKey, nil - }) - - if err != nil { - http.Error(w, fmt.Sprintf("Token is invalid: %v", err), http.StatusUnauthorized) - return - } - if !token.Valid { - http.Error(w, fmt.Sprintf("Token is invalid"), http.StatusUnauthorized) - return - } - if claims.Valid() != nil { - http.Error(w, fmt.Sprintf("Invalid claims: %v", claims.Valid()), http.StatusUnauthorized) - return - } - } - upstream, ok := mapping[proxyDest] - if !ok { - hasError = true - logger.Warn("No map found") - http.Error(w, fmt.Sprintf("Not found: %s", proxyDest), 404) - return - } - - logger = logger.With(zap.String("upstream", upstream)) - - s, err := net.DialTimeout("tcp", upstream, *dialTimeout) - - if err != nil { - hasError = true - logger.Warn("DialTimeout", zap.Error(err)) - http.Error(w, fmt.Sprintf("Could not connect upstream: %v", err), 500) - return - } - - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - hasError = true - s.Close() - logger.Warn("Failed to Upgrade", zap.Error(err)) - return - } - - logger.Info("log", zap.String("status", "Connected")) - - defer func() { - status := "Suceeded" - if hasError { - status = "Failed" - } - logger.Info("log", - zap.String("status", status), - zap.Int64("read", readLen), - zap.Int64("write", writeLen), - zap.String("disconnect_at", disconnectAt), - ) - }() - - doneCh := make(chan bool) - goClose := false - - // websocket -> server - go func() { - defer func() { doneCh <- true }() - for { - mt, r, err := conn.NextReader() - if websocket.IsCloseError(err, - websocket.CloseNormalClosure, // Normal. - websocket.CloseAbnormalClosure, // OpenSSH killed proxy client. - ) { - return - } - if err != nil { - if !goClose { - logger.Warn("NextReader", zap.Error(err)) - hasError = true - } - if disconnectAt == "" { - disconnectAt = "client_nextreader" - } - return - } - if mt != websocket.BinaryMessage { - logger.Warn("BinaryMessage required", zap.Int("messageType", mt)) - hasError = true - return - } - n, err := io.Copy(s, r) - if err != nil { - if !goClose { - logger.Warn("Reading from websocket", zap.Error(err)) - hasError = true - } - if disconnectAt == "" { - disconnectAt = "client_upstream_copy" - } - return - } - readLen += n - } - }() - - // server -> websocket - go func() { - defer func() { doneCh <- true }() - for { - b := make([]byte, 64*1024) - n, err := s.Read(b) - if err != nil { - if !goClose && err != io.EOF { - logger.Warn("Reading from dest", zap.Error(err)) - hasError = true - } - if disconnectAt == "" { - disconnectAt = "upstream_read" - } - return - } - - b = b[:n] - - if err := conn.WriteMessage(websocket.BinaryMessage, b); err != nil { - if !goClose { - logger.Warn("WriteMessage", zap.Error(err)) - hasError = true - } - if disconnectAt == "" { - disconnectAt = "client_write" - } - return - } - writeLen += int64(n) - } - }() - - <-doneCh - goClose = true - s.Close() - conn.Close() - <-doneCh - -} - func printVersion() { fmt.Printf(`wsgate-server %s Compiler: %s %s @@ -228,57 +46,32 @@ func main() { logger, _ := zap.NewProduction() - r := regexp.MustCompile(`^ *#`) - mapping = make(map[string]string) - if *mapFile != "" { - f, err := os.Open(*mapFile) - if err != nil { - logger.Fatal("Failed to open mapFile", zap.Error(err)) - } - s := bufio.NewScanner(f) - for s.Scan() { - if r.MatchString(s.Text()) { - continue - } - l := strings.SplitN(s.Text(), ",", 2) - if len(l) != 2 { - logger.Fatal("Invalid line", - zap.String("mapFile", *mapFile), - zap.String("line", s.Text())) - } - logger.Info("Created map", - zap.String("from", l[0]), - zap.String("to", l[1])) - mapping[l[0]] = l[1] - } + mp, err := mapping.New(*mapFile, logger) + if err != nil { + logger.Fatal("Failed init mapping", zap.Error(err)) } - if *publicKeyFile != "" { - verifyBytes, err := ioutil.ReadFile(*publicKeyFile) - if err != nil { - logger.Fatal("Failed read pubkey", zap.Error(err)) - } - verifyKey, err = jwt.ParseRSAPublicKeyFromPEM(verifyBytes) - if err != nil { - logger.Fatal("Failed read pubkey", zap.Error(err)) - } + pk, err := publickey.New(*publicKeyFile, logger) + if err != nil { + logger.Fatal("Failed init publickey", zap.Error(err)) } - upgrader = websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - HandshakeTimeout: *handshakeTimeout, - CheckOrigin: func(r *http.Request) bool { - return true - }, + proxyHandler, err := handler.New( + *handshakeTimeout, + *dialTimeout, + *writeTimeout, + mp, + pk, + logger, + ) + if err != nil { + logger.Fatal("Failed init handler", zap.Error(err)) } m := mux.NewRouter() - m.HandleFunc("/", handleHello) - m.HandleFunc("/live", handleHello) - m.HandleFunc("/proxy/{dest}", func(w http.ResponseWriter, r *http.Request) { - handleProxy(w, r, logger) - }) + m.HandleFunc("/", proxyHandler.Hello()) + m.HandleFunc("/live", proxyHandler.Hello()) + m.HandleFunc("/proxy/{dest}", proxyHandler.Proxy()) l, err := ss.NewListener() if l == nil || err != nil {