Skip to content

Commit

Permalink
devide to 3 pkgs
Browse files Browse the repository at this point in the history
  • Loading branch information
kazeburo committed Feb 4, 2019
1 parent 2abc29b commit 573f74b
Show file tree
Hide file tree
Showing 4 changed files with 358 additions and 233 deletions.
216 changes: 216 additions & 0 deletions handler/handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
package handler

import (
"fmt"
"io"
"net"
"net/http"
"time"

"github.com/gorilla/mux"
"github.com/gorilla/websocket"
"github.com/kazeburo/wsgate-server/mapping"
"github.com/kazeburo/wsgate-server/publickey"
"go.uber.org/zap"
)

// Handler handlers
type Handler struct {
logger *zap.Logger
upgrader websocket.Upgrader
dialTimeout time.Duration
writeTimeout time.Duration
mp *mapping.Mapping
pk *publickey.Publickey
}

// New new handler
func New(
handshakeTimeout time.Duration,
dialTimeout time.Duration,
writeTimeout time.Duration,
mp *mapping.Mapping,
pk *publickey.Publickey,
logger *zap.Logger) (*Handler, error) {

upgrader := websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
HandshakeTimeout: handshakeTimeout,
CheckOrigin: func(r *http.Request) bool {
return true
},
}

return &Handler{
logger: logger,
upgrader: upgrader,
dialTimeout: dialTimeout,
writeTimeout: writeTimeout,
mp: mp,
pk: pk,
}, nil
}

// Hello hello handler
func (h *Handler) Hello() func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("OK\n"))
}
}

// Proxy proxy handler
func (h *Handler) Proxy() func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
proxyDest := vars["dest"]
upstream := ""
readLen := int64(0)
writeLen := int64(0)
hasError := false
disconnectAt := ""

logger := h.logger.With(
zap.String("user-email", r.Header.Get("X-Goog-Authenticated-User-Email")),
zap.String("x-forwarded-for", r.Header.Get("X-Forwarded-For")),
zap.String("remote-addr", r.RemoteAddr),
zap.String("destination", proxyDest),
)

if h.pk.Enabled() {
_, err := h.pk.Verify(r.Header.Get("Authorization"))
if err != nil {
logger.Warn("No authorize", zap.Error(err))
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
}

upstream, ok := h.mp.Get(proxyDest)
if !ok {
hasError = true
logger.Warn("No map found")
http.Error(w, fmt.Sprintf("Not found: %s", proxyDest), 404)
return
}

logger = logger.With(zap.String("upstream", upstream))

s, err := net.DialTimeout("tcp", upstream, h.dialTimeout)

if err != nil {
hasError = true
logger.Warn("DialTimeout", zap.Error(err))
http.Error(w, fmt.Sprintf("Could not connect upstream: %v", err), 500)
return
}

conn, err := h.upgrader.Upgrade(w, r, nil)
if err != nil {
hasError = true
s.Close()
logger.Warn("Failed to Upgrade", zap.Error(err))
return
}

logger.Info("log", zap.String("status", "Connected"))

defer func() {
status := "Suceeded"
if hasError {
status = "Failed"
}
logger.Info("log",
zap.String("status", status),
zap.Int64("read", readLen),
zap.Int64("write", writeLen),
zap.String("disconnect_at", disconnectAt),
)
}()

doneCh := make(chan bool)
goClose := false

// websocket -> server
go func() {
defer func() { doneCh <- true }()
for {
mt, r, err := conn.NextReader()
if websocket.IsCloseError(err,
websocket.CloseNormalClosure, // Normal.
websocket.CloseAbnormalClosure, // OpenSSH killed proxy client.
) {
return
}
if err != nil {
if !goClose {
logger.Warn("NextReader", zap.Error(err))
hasError = true
}
if disconnectAt == "" {
disconnectAt = "client_nextreader"
}
return
}
if mt != websocket.BinaryMessage {
logger.Warn("BinaryMessage required", zap.Int("messageType", mt))
hasError = true
return
}
n, err := io.Copy(s, r)
if err != nil {
if !goClose {
logger.Warn("Reading from websocket", zap.Error(err))
hasError = true
}
if disconnectAt == "" {
disconnectAt = "client_upstream_copy"
}
return
}
readLen += n
}
}()

// server -> websocket
go func() {
defer func() { doneCh <- true }()
for {
b := make([]byte, 64*1024)
n, err := s.Read(b)
if err != nil {
if !goClose && err != io.EOF {
logger.Warn("Reading from dest", zap.Error(err))
hasError = true
}
if disconnectAt == "" {
disconnectAt = "upstream_read"
}
return
}

b = b[:n]

if err := conn.WriteMessage(websocket.BinaryMessage, b); err != nil {
if !goClose {
logger.Warn("WriteMessage", zap.Error(err))
hasError = true
}
if disconnectAt == "" {
disconnectAt = "client_write"
}
return
}
writeLen += int64(n)
}
}()

<-doneCh
goClose = true
s.Close()
conn.Close()
<-doneCh

}

}
51 changes: 51 additions & 0 deletions mapping/mapping.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package mapping

import (
"bufio"
"os"
"regexp"
"strings"

"github.com/pkg/errors"
"go.uber.org/zap"
)

// Mapping struct
type Mapping struct {
m map[string]string
}

// New new mapping
func New(mapFile string, logger *zap.Logger) (*Mapping, error) {
r := regexp.MustCompile(`^ *#`)
m := make(map[string]string)
if mapFile != "" {
f, err := os.Open(mapFile)
if err != nil {
return nil, errors.Wrap(err, "Failed to open mapFile")
}
s := bufio.NewScanner(f)
for s.Scan() {
if r.MatchString(s.Text()) {
continue
}
l := strings.SplitN(s.Text(), ",", 2)
if len(l) != 2 {
return nil, errors.Wrapf(err, "Invalid line: %s", s.Text())
}
logger.Info("Created map",
zap.String("from", l[0]),
zap.String("to", l[1]))
m[l[0]] = l[1]
}
}
return &Mapping{
m: m,
}, nil
}

// Get get mapping
func (mp *Mapping) Get(proxyDest string) (string, bool) {
upstream, ok := mp.m[proxyDest]
return upstream, ok
}
65 changes: 65 additions & 0 deletions publickey/publickey.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package publickey

import (
"crypto/rsa"
"fmt"
"io/ioutil"
"strings"

jwt "github.com/dgrijalva/jwt-go"
"github.com/pkg/errors"
"go.uber.org/zap"
)

// Publickey struct
type Publickey struct {
publicKeyFile string
verifyKey *rsa.PublicKey
}

// New publickey reader/checker
func New(publicKeyFile string, logger *zap.Logger) (*Publickey, error) {
var verifyKey *rsa.PublicKey
if publicKeyFile != "" {
verifyBytes, err := ioutil.ReadFile(publicKeyFile)
if err != nil {
return nil, errors.Wrap(err, "Failed read pubkey")
}
verifyKey, err = jwt.ParseRSAPublicKeyFromPEM(verifyBytes)
if err != nil {
return nil, errors.Wrap(err, "Failed parse pubkey")
}
}
return &Publickey{
publicKeyFile: publicKeyFile,
verifyKey: verifyKey,
}, nil
}

// Enabled publickey is enabled
func (pk Publickey) Enabled() bool {
return pk.publicKeyFile != ""
}

// Verify verify auth header
func (pk Publickey) Verify(t string) (bool, error) {
if t == "" {
return false, fmt.Errorf("no tokenString")
}
t = strings.TrimPrefix(t, "Bearer ")
claims := &jwt.StandardClaims{}
token, err := jwt.ParseWithClaims(t, claims, func(token *jwt.Token) (interface{}, error) {
return pk.verifyKey, nil
})

if err != nil {
return false, fmt.Errorf("Token is invalid: %v", err)
}
if !token.Valid {
return false, fmt.Errorf("Token is invalid")
}
if claims.Valid() != nil {
return false, fmt.Errorf("Invalid claims: %v", claims.Valid())
}
return true, nil
}
Loading

0 comments on commit 573f74b

Please sign in to comment.