Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add auth token propagation for metrics reader #3341

Merged
merged 17 commits into from
Nov 2, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cmd/query/app/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (

"github.com/jaegertracing/jaeger/cmd/query/app/apiv3"
"github.com/jaegertracing/jaeger/cmd/query/app/querysvc"
"github.com/jaegertracing/jaeger/pkg/bearertoken"
"github.com/jaegertracing/jaeger/pkg/healthcheck"
"github.com/jaegertracing/jaeger/pkg/netutils"
"github.com/jaegertracing/jaeger/pkg/recoveryhandler"
Expand Down Expand Up @@ -158,7 +159,7 @@ func createHTTPServer(querySvc *querysvc.QueryService, metricsQuerySvc querysvc.
var handler http.Handler = r
handler = additionalHeadersHandler(handler, queryOpts.AdditionalHeaders)
if queryOpts.BearerTokenPropagation {
handler = bearerTokenPropagationHandler(logger, handler)
handler = bearertoken.PropagationHandler(logger, handler)
}
handler = handlers.CompressHandler(handler)
recoveryHandler := recoveryhandler.NewRecoveryHandler(logger, true)
Expand Down
4 changes: 2 additions & 2 deletions cmd/query/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ import (
"github.com/jaegertracing/jaeger/cmd/query/app"
"github.com/jaegertracing/jaeger/cmd/query/app/querysvc"
"github.com/jaegertracing/jaeger/cmd/status"
"github.com/jaegertracing/jaeger/pkg/bearertoken"
"github.com/jaegertracing/jaeger/pkg/config"
"github.com/jaegertracing/jaeger/pkg/version"
metricsPlugin "github.com/jaegertracing/jaeger/plugin/metrics"
"github.com/jaegertracing/jaeger/plugin/storage"
"github.com/jaegertracing/jaeger/ports"
"github.com/jaegertracing/jaeger/storage/spanstore"
storageMetrics "github.com/jaegertracing/jaeger/storage/spanstore/metrics"
)

Expand Down Expand Up @@ -95,7 +95,7 @@ func main() {
opentracing.SetGlobalTracer(tracer)
queryOpts := new(app.QueryOptions).InitFromViper(v, logger)
// TODO: Need to figure out set enable/disable propagation on storage plugins.
v.Set(spanstore.StoragePropagationKey, queryOpts.BearerTokenPropagation)
v.Set(bearertoken.StoragePropagationKey, queryOpts.BearerTokenPropagation)
storageFactory.InitFromViper(v, logger)
if err := storageFactory.Initialize(baseFactory, logger); err != nil {
logger.Fatal("Failed to init storage factory", zap.Error(err))
Expand Down
26 changes: 26 additions & 0 deletions pkg/bearertoken/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package bearertoken

import "context"

type contextKey string

// Key is the string literal used internally in the implementation of this context.
const Key = "bearer.token"
const bearerToken = contextKey(Key)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see Key being used outside of this package

Suggested change
// Key is the string literal used internally in the implementation of this context.
const Key = "bearer.token"
const bearerToken = contextKey(Key)
type contextKeyType string
const contextKey = contextKeyType("bearer.token")


// StoragePropagationKey is a key for viper configuration to pass this option to storage plugins.
const StoragePropagationKey = "storage.propagate.token"

// ContextWithBearerToken set bearer token in context.
func ContextWithBearerToken(ctx context.Context, token string) context.Context {
if token == "" {
return ctx
}
return context.WithValue(ctx, bearerToken, token)
}

// GetBearerToken from context, or empty string if there is no token.
func GetBearerToken(ctx context.Context) (string, bool) {
val, ok := ctx.Value(bearerToken).(string)
return val, ok
}
17 changes: 17 additions & 0 deletions pkg/bearertoken/context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package bearertoken

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
)

func Test_GetBearerToken(t *testing.T) {
const token = "blah"
ctx := context.Background()
ctx = ContextWithBearerToken(ctx, token)
contextToken, ok := GetBearerToken(ctx)
assert.True(t, ok)
assert.Equal(t, contextToken, token)
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package app
package bearertoken

import (
"net/http"
"strings"

"go.uber.org/zap"

"github.com/jaegertracing/jaeger/storage/spanstore"
)

func bearerTokenPropagationHandler(logger *zap.Logger, h http.Handler) http.Handler {
// PropagationHandler returns a http.Handler containing the logic to extract
// the Authorization token from the http.Request and inserts it into the http.Request
// context for easier access to the request token via GetBearerToken for bearer token
// propagation use cases.
albertteoh marked this conversation as resolved.
Show resolved Hide resolved
func PropagationHandler(logger *zap.Logger, h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
authHeaderValue := r.Header.Get("Authorization")
Expand All @@ -40,15 +42,14 @@ func bearerTokenPropagationHandler(logger *zap.Logger, h http.Handler) http.Hand
token = headerValue[1]
}
} else if len(headerValue) == 1 {
// Tread all value as a token
// Treat the entire value as a token.
token = authHeaderValue
} else {
logger.Warn("Invalid authorization header value, skipping token propagation")
}
h.ServeHTTP(w, r.WithContext(spanstore.ContextWithBearerToken(ctx, token)))
h.ServeHTTP(w, r.WithContext(ContextWithBearerToken(ctx, token)))
} else {
h.ServeHTTP(w, r.WithContext(ctx))
}
})

}
Original file line number Diff line number Diff line change
Expand Up @@ -12,41 +12,44 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package app
package bearertoken

import (
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
"go.uber.org/zap"

"github.com/jaegertracing/jaeger/storage/spanstore"
)

func Test_bearTokenPropagationHandler(t *testing.T) {
func Test_PropagationHandler(t *testing.T) {
httpClient := &http.Client{
Timeout: 2 * time.Second,
}

logger := zap.NewNop()
bearerToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJhZG1pbiIsIm5hbWUiOiJKb2huIERvZSIsImlhdCI"
const bearerToken = "blah"

validTokenHandler := func(stop *sync.WaitGroup) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
token, ok := spanstore.GetBearerToken(ctx)
token, ok := GetBearerToken(ctx)
assert.Equal(t, token, bearerToken)
assert.True(t, ok)
stop.Done()
})
}
}

emptyHandler := func(stop *sync.WaitGroup) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
token, _ := spanstore.GetBearerToken(ctx)
token, _ := GetBearerToken(ctx)
assert.Empty(t, token, bearerToken)
stop.Done()
})
}
}

testCases := []struct {
Expand All @@ -68,7 +71,7 @@ func Test_bearTokenPropagationHandler(t *testing.T) {
t.Run(testCase.name, func(t *testing.T) {
stop := sync.WaitGroup{}
stop.Add(1)
r := bearerTokenPropagationHandler(logger, testCase.handler(&stop))
r := PropagationHandler(logger, testCase.handler(&stop))
server := httptest.NewServer(r)
defer server.Close()
req, err := http.NewRequest("GET", server.URL, nil)
Expand All @@ -81,5 +84,4 @@ func Test_bearTokenPropagationHandler(t *testing.T) {
stop.Wait()
})
}

}
63 changes: 63 additions & 0 deletions pkg/bearertoken/transport.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package bearertoken

import (
"errors"
"net/http"
)

// transport implements the http.RoundTripper interface,
// itself wrapping an instance of http.RoundTripper.
type transport struct {
defaultToken string
allowOverrideFromCtx bool
wrapped http.RoundTripper
}
albertteoh marked this conversation as resolved.
Show resolved Hide resolved

// Option sets attributes of this transport.
type Option func(*transport)

// WithAllowOverrideFromCtx sets whether the defaultToken can be overridden
// with the token within the request context.
func WithAllowOverrideFromCtx(allow bool) Option {
return func(t *transport) {
t.allowOverrideFromCtx = allow
}
}

// WithToken sets the defaultToken that will be injected into the outbound HTTP
// request's Authorization bearer token header.
// If the WithAllowOverrideFromCtx(true) option is provided, the request context's
// bearer token, will be used in preference to this token.
func WithToken(token string) Option {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I took the functional options approach because the defaultToken wasn't mandatory and didn't feel right to force callers to pass a sentinel token value, nor make callers guess what they should be setting it to if we exported this value in the struct, especially given it's in a separate package.

Let me know if you disagree with the approach.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We started to move away from the use of functional option pattern, it just adds unnecessary mental overhead and has poor discoverability in the docs. I feel like this type can simply be declared with public fields and instantiated directly, without a constructor function. The optionality of the fields is naturally available via struct zero values.

return func(t *transport) {
t.defaultToken = token
}
}

// NewTransport returns a new bearer token transport that wraps the given
// http.RoundTripper, forwarding the authorization token from inbound to
// outbound HTTP requests.
func NewTransport(roundTripper http.RoundTripper, opts ...Option) http.RoundTripper {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't feel like it was necessary to expose the internals of the transport, hence why http.RoundTripper is returned, given that's the only use case right now.

t := &transport{wrapped: roundTripper}
for _, opt := range opts {
opt(t)
}
return t
}

// RoundTrip injects the outbound Authorization header with the
// token provided in the inbound request.
func (tr *transport) RoundTrip(r *http.Request) (*http.Response, error) {
if tr.wrapped == nil {
return nil, errors.New("no http.RoundTripper provided")
}
token := tr.defaultToken
if tr.allowOverrideFromCtx {
headerToken, _ := GetBearerToken(r.Context())
if headerToken != "" {
token = headerToken
}
}
r.Header.Set("Authorization", "Bearer "+token)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to check for token != "" before setting the header?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I considered it but was technically a breaking change; I don't know enough about Auth headers to determine if this is a bug or correct behaviour, though sounds like it's a bug given your suggestion.

return tr.wrapped.RoundTrip(r)
}
98 changes: 98 additions & 0 deletions pkg/bearertoken/transport_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package bearertoken

import (
"context"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

type roundTripFunc func(r *http.Request) (*http.Response, error)

func (s roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
return s(r)
}

func TestNewTransport(t *testing.T) {
for _, tc := range []struct {
name string
roundTripper http.RoundTripper
requestContext context.Context
options []Option
wantError bool
}{
{
name: "No options provided and request context set should have empty Bearer token",
roundTripper: roundTripFunc(func(r *http.Request) (*http.Response, error) {
assert.Equal(t, "Bearer ", r.Header.Get("Authorization"))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wasn't sure if this is an acceptable Auth header value or if an error should be returned. This was the original behaviour so I kept it that way.

return &http.Response{
StatusCode: http.StatusOK,
}, nil
}),
requestContext: ContextWithBearerToken(context.Background(), "tokenFromContext"),
},
{
name: "Allow override from context provided, and request context set should use request context token",
roundTripper: roundTripFunc(func(r *http.Request) (*http.Response, error) {
assert.Equal(t, "Bearer tokenFromContext", r.Header.Get("Authorization"))
return &http.Response{
StatusCode: http.StatusOK,
}, nil
}),
requestContext: ContextWithBearerToken(context.Background(), "tokenFromContext"),
options: []Option{
WithAllowOverrideFromCtx(true),
},
},
{
name: "Allow override from context and token provided, and request context unset should use defaultToken",
roundTripper: roundTripFunc(func(r *http.Request) (*http.Response, error) {
assert.Equal(t, "Bearer initToken", r.Header.Get("Authorization"))
return &http.Response{}, nil
}),
requestContext: context.Background(),
options: []Option{
WithAllowOverrideFromCtx(true),
WithToken("initToken"),
},
},
{
name: "Allow override from context and token provided, and request context set should use context token",
roundTripper: roundTripFunc(func(r *http.Request) (*http.Response, error) {
assert.Equal(t, "Bearer tokenFromContext", r.Header.Get("Authorization"))
return &http.Response{}, nil
}),
requestContext: ContextWithBearerToken(context.Background(), "tokenFromContext"),
options: []Option{
WithAllowOverrideFromCtx(true),
WithToken("initToken"),
},
},
{
name: "Nil roundTripper provided should return an error",
requestContext: context.Background(),
wantError: true,
},
} {
t.Run(tc.name, func(t *testing.T) {
server := httptest.NewServer(nil)
defer server.Close()
req, err := http.NewRequestWithContext(tc.requestContext, "GET", server.URL, nil)
require.NoError(t, err)

tr := NewTransport(tc.roundTripper, tc.options...)
resp, err := tr.RoundTrip(req)

if tc.wantError {
assert.Nil(t, resp)
assert.Error(t, err)
} else {
assert.NotNil(t, resp)
assert.NoError(t, err)
}
})
}
}
Loading