diff --git a/api/api.go b/api/api.go index 11e663c43..85ee26320 100644 --- a/api/api.go +++ b/api/api.go @@ -244,6 +244,7 @@ func NewServerRoutes(s Server) *mux.Router { } m.Use(refStringMiddleware) + m.Use(OAuthTokenMiddleware) return m } diff --git a/api/middleware.go b/api/middleware.go index 675591a95..28853913c 100644 --- a/api/middleware.go +++ b/api/middleware.go @@ -3,10 +3,12 @@ package api import ( "fmt" "net/http" + "strings" "time" "github.com/gorilla/mux" "github.com/qri-io/qri/api/util" + "github.com/qri-io/qri/auth/token" "github.com/qri-io/qri/dsref" ) @@ -96,3 +98,33 @@ func stripServerSideQueryParams(r *http.Request) { q.Del("refstr") r.URL.RawQuery = q.Encode() } + +const ( + bearerPrefix = "Bearer " + authorizationHeader = "authorization" +) + +// OAuthTokenMiddleware parses any "authorization" header containing a Bearer +// token & adds it to the request context +func OAuthTokenMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reqToken := r.Header.Get(authorizationHeader) + if reqToken == "" && r.FormValue(authorizationHeader) != "" { + reqToken = r.FormValue(authorizationHeader) + } + if reqToken == "" { + next.ServeHTTP(w, r) + return + } + + if !strings.HasPrefix(reqToken, bearerPrefix) { + util.WriteErrResponse(w, http.StatusBadRequest, fmt.Errorf("bad token")) + return + } + tokenStr := strings.TrimPrefix(reqToken, bearerPrefix) + ctx := token.AddToContext(r.Context(), tokenStr) + + r = r.WithContext(ctx) + next.ServeHTTP(w, r) + }) +} diff --git a/auth/token/context.go b/auth/token/context.go index 87a0bf7b8..9b05dc308 100644 --- a/auth/token/context.go +++ b/auth/token/context.go @@ -11,17 +11,17 @@ type CtxKey string // tokenCtxKey is the key for adding an access token to a context.Context const tokenCtxKey CtxKey = "Token" -// AddToContext adds a token value to a context -func AddToContext(ctx context.Context, t Token) context.Context { - return context.WithValue(ctx, tokenCtxKey, t) +// AddToContext adds a token string to a context +func AddToContext(ctx context.Context, s string) context.Context { + return context.WithValue(ctx, tokenCtxKey, s) } // FromCtx extracts the JWT from a given // context if one is set, returning nil otherwise -func FromCtx(ctx context.Context) *Token { +func FromCtx(ctx context.Context) string { iface := ctx.Value(tokenCtxKey) - if ref, ok := iface.(Token); ok { - return &ref + if s, ok := iface.(string); ok { + return s } - return nil + return "" }