Skip to content

Commit

Permalink
Add a local OIDC provider (#17)
Browse files Browse the repository at this point in the history
### Description

Implement a local OIDC provider _server_ to enable SSO on backend
services that support OIDC. This new built-in server uses the existing
user identity provided by SAML and/or OIDC.

Claims can be rewritten to match the needs of the local clients, e.g.
the `email` claim can used to derive a `preferred_username` claim by
removing the `@domain` part.

### Type of change

* [x] New feature
* [ ] Feature improvement
* [ ] Bug fix
* [ ] Documentation
* [ ] Cleanup / refactoring
* [ ] Other (please explain)


### How is this change tested ?

* [x] Unit tests
* [x] Manual tests (explain)
* [ ] Tests are not needed
  • Loading branch information
rthellend committed Sep 24, 2023
1 parent 086aad1 commit 3218b53
Show file tree
Hide file tree
Showing 21 changed files with 1,356 additions and 229 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
config.yaml
version.sh
bin
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ Overview of features:
* [x] Terminate _TCP_ connections, and forward the TLS connection to any TLS server (passthrough). The proxy doesn't see the plaintext.
* [x] Terminate HTTPS connections, and forward the requests to HTTP or HTTPS servers (http/1 only, not recommended with c2fmzq-server).
* [x] TLS client authentication & authorization (when the proxy terminates the TLS connections).
* [x] User authentication with OpenID Connect and SAML (for HTTP and HTTPS connections). Optionally issue JSON Web Tokens (JWT) to authenticated users to use with the backend services.
* [x] User authentication with OpenID Connect and SAML (for HTTP and HTTPS connections). Optionally issue JSON Web Tokens (JWT) to authenticated users to use with the backend services and/or run a local OpenID Connect server for backend services.
* [x] Access control by IP address.
* [x] Routing based on Server Name Indication (SNI), with optional default route when SNI isn't used.
* [x] Simple round-robin load balancing between servers.
Expand Down
204 changes: 180 additions & 24 deletions proxy/backend-sso.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,88 +24,234 @@
package proxy

import (
"context"
_ "embed"
"fmt"
"html/template"
"log"
"net"
"net/http"
"slices"
"sort"
"strings"
"time"

jwt "github.com/golang-jwt/jwt/v5"

"github.com/c2FmZQ/tlsproxy/proxy/internal/cookiemanager"
)

type ctxAuthKey struct{}

const (
xTLSProxyUserIDHeader = "X-tlsproxy-user-id"
)

func (be *Backend) getUserAuthentication(w http.ResponseWriter, req *http.Request) bool {
// Filter out the tlsproxy auth cookie.
defer cookiemanager.FilterOutAuthTokenCookie(req)
var (
authCtxKey ctxAuthKey

req.Header.Del(xTLSProxyUserIDHeader)
if be.SSO != nil && !be.checkCookies(w, req) {
return false
//go:embed permission-denied-template.html
permissionDeniedEmbed string
permissionDeniedTemplate *template.Template
//go:embed logout-template.html
logoutEmbed string
logoutTemplate *template.Template
//go:embed sso-status-template.html
ssoStatusEmbed string
ssoStatusTemplate *template.Template
)

func init() {
permissionDeniedTemplate = template.Must(template.New("permission-denied").Parse(permissionDeniedEmbed))
logoutTemplate = template.Must(template.New("logout").Parse(logoutEmbed))
ssoStatusTemplate = template.Must(template.New("sso-status").Parse(ssoStatusEmbed))
}

func claimsFromCtx(ctx context.Context) jwt.Claims {
if v := ctx.Value(authCtxKey); v != nil {
return v.(jwt.Claims)
}
return true
return nil
}

func (be *Backend) userAuthentication(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
req.Header.Del(xTLSProxyUserIDHeader)
if be.SSO != nil {
claims, cont := be.checkCookies(w, req)
if !cont {
return
}
if claims != nil {
sub, err := claims.GetSubject()
if err == nil && sub != "" {
if be.SSO.SetUserIDHeader {
req.Header.Set(xTLSProxyUserIDHeader, sub)
}
req = req.WithContext(context.WithValue(req.Context(), authCtxKey, claims))
}
}
}
// Filter out the tlsproxy auth cookie.
cookiemanager.FilterOutAuthTokenCookie(req)
next.ServeHTTP(w, req)
})
}

func (be *Backend) checkCookies(w http.ResponseWriter, req *http.Request) bool {
func (be *Backend) checkCookies(w http.ResponseWriter, req *http.Request) (jwt.Claims, bool) {
// If a valid ID Token is in the authorization header, use it and
// ignore the cookies.
if tok, err := be.SSO.cm.ValidateAuthorizationHeader(req); err == nil {
if sub, err := tok.Claims.GetSubject(); err == nil && sub != "" {
req.Header.Set(xTLSProxyUserIDHeader, sub)
return true
}
return tok.Claims, true
}

authToken, err := be.SSO.cm.ValidateAuthTokenCookie(req)
if err != nil {
return true
return nil, true
}
sub, err := authToken.Claims.GetSubject()
if err != nil || sub == "" {
return true
return nil, true
}
req.Header.Set(xTLSProxyUserIDHeader, sub)
authClaims := authToken.Claims

if !be.SSO.GenerateIDTokens {
return true
return authClaims, true
}

host := req.Host
if h, _, err := net.SplitHostPort(host); err == nil {
host = h
}
if !slices.Contains(be.ServerNames, host) {
return true
return authClaims, true
}

if err := be.SSO.cm.ValidateIDTokenCookie(req, authToken); err == nil {
// Token is already set, and is valid.
return true
return authClaims, true
}
if err := be.SSO.cm.SetIDTokenCookie(w, req, authToken); err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return false
return nil, false
}
http.Redirect(w, req, req.URL.String(), http.StatusFound)
return false
return nil, false
}

func (be *Backend) serveSSOStatus(w http.ResponseWriter, req *http.Request) {
var claims jwt.MapClaims
if c := claimsFromCtx(req.Context()); c != nil {
claims, _ = c.(jwt.MapClaims)
}
var keys []string
for k := range claims {
keys = append(keys, k)
}
sort.Strings(keys)
type kv struct {
Key, Value string
}
var data struct {
Token string
Claims []kv
}
for _, k := range keys {
if k == "iat" {
v, _ := claims.GetIssuedAt()
data.Claims = append(data.Claims, kv{k, v.String()})
continue
}
if k == "exp" {
v, _ := claims.GetExpirationTime()
data.Claims = append(data.Claims, kv{k, v.String()})
continue
}
data.Claims = append(data.Claims, kv{k, fmt.Sprint(claims[k])})
}
token, _, err := be.makeTokenForURL(req)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
data.Token = token
ssoStatusTemplate.Execute(w, data)
}

func (be *Backend) serveLogout(w http.ResponseWriter, req *http.Request) {
if be.SSO != nil {
be.SSO.cm.ClearCookies(w)
}
req.ParseForm()
if tokenStr := req.Form.Get("u"); tokenStr != "" {
tok, err := be.tm.ValidateToken(tokenStr)
if err == jwt.ErrTokenExpired {
http.Error(w, "data expired", http.StatusBadRequest)
return
}
if err != nil {
http.Error(w, "invalid request", http.StatusBadRequest)
return
}
c, ok := tok.Claims.(jwt.MapClaims)
if !ok {
http.Error(w, "invalid request", http.StatusBadRequest)
return
}
url, ok := c["url"].(string)
if !ok {
http.Error(w, "invalid request", http.StatusBadRequest)
return
}
be.SSO.p.RequestLogin(w, req, url)
return
}
logoutTemplate.Execute(w, nil)
}

func (be *Backend) servePermissionDenied(w http.ResponseWriter, req *http.Request) {
var subject string
if claims := claimsFromCtx(req.Context()); claims != nil {
subject, _ = claims.GetSubject()
}
token, url, err := be.makeTokenForURL(req)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}

data := struct {
Subject string
URL string
DisplayURL string
Token string
}{
Subject: subject,
URL: url,
DisplayURL: url,
Token: token,
}
if len(data.DisplayURL) > 100 {
data.DisplayURL = data.DisplayURL[:97] + "..."
}
w.WriteHeader(http.StatusForbidden)
permissionDeniedTemplate.Execute(w, data)
}

func (be *Backend) enforceSSOPolicy(w http.ResponseWriter, req *http.Request) bool {
if be.SSO == nil || !pathMatches(be.SSO.Paths, req.URL.Path) {
return true
}
userID := req.Header.Get(xTLSProxyUserIDHeader)
if userID == "" {
claims := claimsFromCtx(req.Context())
if claims == nil {
u := req.URL
u.Scheme = "https"
u.Host = req.Host
log.Printf("REQ %s ➔ %s %s ➔ status:%d (SSO)", formatReqDesc(req), req.Method, req.RequestURI, http.StatusFound)
be.SSO.p.RequestLogin(w, req, u.String())
return false
}
userID, _ := claims.GetSubject()
host := req.Host
if h, _, err := net.SplitHostPort(host); err == nil {
host = h
Expand All @@ -114,8 +260,7 @@ func (be *Backend) enforceSSOPolicy(w http.ResponseWriter, req *http.Request) bo
if be.SSO.ACL != nil && !slices.Contains(*be.SSO.ACL, userID) && !slices.Contains(*be.SSO.ACL, "@"+userDomain) {
be.recordEvent(fmt.Sprintf("deny %s to %s", userID, host))
log.Printf("REQ %s ➔ %s %s ➔ status:%d (SSO)", formatReqDesc(req), req.Method, req.RequestURI, http.StatusForbidden)
be.SSO.cm.ClearCookies(w)
http.Error(w, "Forbidden", http.StatusForbidden)
be.servePermissionDenied(w, req)
return false
}
be.recordEvent(fmt.Sprintf("allow %s to %s", userID, host))
Expand All @@ -133,3 +278,14 @@ func pathMatches(prefixes []string, path string) bool {
}
return false
}

func (be *Backend) makeTokenForURL(req *http.Request) (string, string, error) {
u := req.URL
u.Scheme = "https"
u.Host = req.Host
token, err := be.tm.CreateToken(jwt.MapClaims{
"url": u.String(),
"exp": time.Now().Add(time.Hour).Unix(),
}, "ES256")
return token, u.String(), err
}
49 changes: 19 additions & 30 deletions proxy/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,35 +198,32 @@ func (be *Backend) bridgeConns(client, server net.Conn) error {
return retErr
}

func (be *Backend) runLocalHandlersAndAuthz(w http.ResponseWriter, req *http.Request) bool {
h, exists := be.localHandlers[req.URL.Path]
if exists && h.ssoBypass {
h.handler(w, req)
return false
}
if !be.enforceSSOPolicy(w, req) {
return false
}
if exists && !h.ssoBypass {
h.handler(w, req)
return false
}
return true
}

func (be *Backend) consoleHandler() http.Handler {
func (be *Backend) localHandlersAndAuthz(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if !be.getUserAuthentication(w, req) {
h, exists := be.localHandlers[req.URL.Path]
if exists && h.ssoBypass {
h.handler.ServeHTTP(w, req)
return
}
if !be.enforceSSOPolicy(w, req) {
return
}
if exists && !h.ssoBypass {
h.handler.ServeHTTP(w, req)
return
}
logRequest(req)
if !be.runLocalHandlersAndAuthz(w, req) {
if next == nil {
http.NotFound(w, req)
return
}
http.NotFound(w, req)
next.ServeHTTP(w, req)
})
}

func (be *Backend) consoleHandler() http.Handler {
return be.userAuthentication(logHandler(be.localHandlersAndAuthz(nil)))
}

func (be *Backend) reverseProxy() http.Handler {
var rp http.Handler
if len(be.Addresses) == 0 {
Expand All @@ -250,15 +247,7 @@ func (be *Backend) reverseProxy() http.Handler {
ModifyResponse: be.reverseProxyModifyResponse,
}
}
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if !be.getUserAuthentication(w, req) {
return
}
if !be.runLocalHandlersAndAuthz(w, req) {
return
}
rp.ServeHTTP(w, req)
})
return be.userAuthentication(be.localHandlersAndAuthz(rp))
}

func (be *Backend) reverseProxyDial(ctx context.Context, network, addr string) (net.Conn, error) {
Expand Down
Loading

0 comments on commit 3218b53

Please sign in to comment.