diff --git a/README.md b/README.md index 455e775..554b3e1 100644 --- a/README.md +++ b/README.md @@ -87,18 +87,24 @@ db, err := sql.Open("mysql", "yyyy:xxx@websocket(https://example.com/proxy/mysql ``` Usage of ./wsgate-server: --dial_timeout duration - Dial timeout. (default 10s) --handshake_timeout duration - Handshake timeout. (default 10s) --listen string - Address to listen to. (default "127.0.0.1:8086") --map string - path and proxy host mapping file --public-key string - public key for signing auth header --version - show version --write_timeout duration - Write timeout. (default 10s) + -dial_timeout duration + Dial timeout. (default 10s) + -dump-tcp uint + Dump TCP. 0 = disable, 1 = src to dest, 2 = both + -handshake_timeout duration + Handshake timeout. (default 10s) + -jwt-freshness duration + time in seconds to allow generated jwt tokens (default 1h0m0s) + -listen string + Address to listen to. (default "127.0.0.1:8086") + -map string + path and proxy host mapping file + -public-key string + public key for verifying JWT auth header + -shutdown_timeout duration + timeout to wait for all connections to be closed (default 24h0m0s) + -version + show version + -write_timeout duration + Write timeout. (default 10s) ``` diff --git a/handler/handler.go b/handler/handler.go index 8e2c425..58f493f 100644 --- a/handler/handler.go +++ b/handler/handler.go @@ -89,19 +89,22 @@ func (h *Handler) Proxy(wg *sync.WaitGroup) func(w http.ResponseWriter, r *http. logger := h.logger.With( zap.Uint64("seq", h.sq.Next()), - zap.String("user-email", r.Header.Get("X-Goog-Authenticated-User-Email")), zap.String("x-forwarded-for", r.Header.Get("X-Forwarded-For")), zap.String("remote-addr", r.RemoteAddr), zap.String("destination", proxyDest), ) if h.pk.Enabled() { - _, err := h.pk.Verify(r.Header.Get("Authorization")) + sub, err := h.pk.Verify(r.Header.Get("Authorization")) if err != nil { - logger.Warn("No authorize", zap.Error(err)) + logger.Warn("Failed to authorize", zap.Error(err)) http.Error(w, err.Error(), http.StatusUnauthorized) return } + logger = logger.With(zap.String("user-email", sub)) + + } else { + logger = logger.With(zap.String("user-email", r.Header.Get("X-Goog-Authenticated-User-Email"))) } upstream, ok := h.mp.Get(proxyDest) diff --git a/publickey/publickey.go b/publickey/publickey.go index 34a2a9d..e739684 100644 --- a/publickey/publickey.go +++ b/publickey/publickey.go @@ -5,6 +5,7 @@ import ( "fmt" "io/ioutil" "strings" + "time" jwt "github.com/dgrijalva/jwt-go" "github.com/pkg/errors" @@ -15,10 +16,11 @@ import ( type Publickey struct { publicKeyFile string verifyKey *rsa.PublicKey + freshnessTime time.Duration } // New publickey reader/checker -func New(publicKeyFile string, logger *zap.Logger) (*Publickey, error) { +func New(publicKeyFile string, freshnessTime time.Duration, logger *zap.Logger) (*Publickey, error) { var verifyKey *rsa.PublicKey if publicKeyFile != "" { verifyBytes, err := ioutil.ReadFile(publicKeyFile) @@ -33,6 +35,7 @@ func New(publicKeyFile string, logger *zap.Logger) (*Publickey, error) { return &Publickey{ publicKeyFile: publicKeyFile, verifyKey: verifyKey, + freshnessTime: freshnessTime, }, nil } @@ -42,9 +45,9 @@ func (pk Publickey) Enabled() bool { } // Verify verify auth header -func (pk Publickey) Verify(t string) (bool, error) { +func (pk Publickey) Verify(t string) (string, error) { if t == "" { - return false, fmt.Errorf("no tokenString") + return "", fmt.Errorf("no tokenString") } t = strings.TrimPrefix(t, "Bearer ") claims := &jwt.StandardClaims{} @@ -57,8 +60,17 @@ func (pk Publickey) Verify(t string) (bool, error) { }) if err != nil { - return false, fmt.Errorf("Token is invalid: %v", err) + return "", fmt.Errorf("Token is invalid: %v", err) } - return true, nil + now := time.Now() + iat := now.Add(-pk.freshnessTime) + if claims.ExpiresAt == 0 || claims.ExpiresAt < now.Unix() { + return "", fmt.Errorf("Token is expired") + } + if claims.IssuedAt == 0 || claims.IssuedAt < iat.Unix() { + return "", fmt.Errorf("Token is too old") + } + + return claims.Subject, nil } diff --git a/wsgate-server.go b/wsgate-server.go index 78c181c..e3debb1 100644 --- a/wsgate-server.go +++ b/wsgate-server.go @@ -31,7 +31,8 @@ var ( writeTimeout = flag.Duration("write_timeout", 10*time.Second, "Write timeout.") shutdownTimeout = flag.Duration("shutdown_timeout", 86400*time.Second, "timeout to wait for all connections to be closed") mapFile = flag.String("map", "", "path and proxy host mapping file") - publicKeyFile = flag.String("public-key", "", "public key for signing auth header") + publicKeyFile = flag.String("public-key", "", "public key for verifying JWT auth header") + jwtFreshness = flag.Duration("jwt-freshness", 3600*time.Second, "time in seconds to allow generated jwt tokens") dumpTCP = flag.Uint("dump-tcp", 0, "Dump TCP. 0 = disable, 1 = src to dest, 2 = both") ) @@ -58,7 +59,7 @@ func main() { logger.Fatal("Failed init mapping", zap.Error(err)) } - pk, err := publickey.New(*publicKeyFile, logger) + pk, err := publickey.New(*publicKeyFile, *jwtFreshness, logger) if err != nil { logger.Fatal("Failed init publickey", zap.Error(err)) }