Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add options lambdaurl.WithDetectContentType and lambda.WithContextValue #516

Merged
merged 15 commits into from
Dec 1, 2023
Merged
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ jobs:
name: run tests
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
go:
- "1.21"
Expand Down
22 changes: 22 additions & 0 deletions lambda/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type Handler interface {
type handlerOptions struct {
handlerFunc
baseContext context.Context
contextValues map[interface{}]interface{}
jsonRequestUseNumber bool
jsonRequestDisallowUnknownFields bool
jsonResponseEscapeHTML bool
Expand Down Expand Up @@ -50,6 +51,23 @@ func WithContext(ctx context.Context) Option {
})
}

// WithContextValue adds a value to the handler context.
// If a base context was set using WithContext, that base is used as the parent.
//
// Usage:
//
// lambda.StartWithOptions(
// func (ctx context.Context) (string, error) {
// return ctx.Value("foo"), nil
// },
// lambda.WithContextValue("foo", "bar")
// )
func WithContextValue(key interface{}, value interface{}) Option {
return Option(func(h *handlerOptions) {
h.contextValues[key] = value
})
}

// WithSetEscapeHTML sets the SetEscapeHTML argument on the underlying json encoder
//
// Usage:
Expand Down Expand Up @@ -211,13 +229,17 @@ func newHandler(handlerFunc interface{}, options ...Option) *handlerOptions {
}
h := &handlerOptions{
baseContext: context.Background(),
contextValues: map[interface{}]interface{}{},
jsonResponseEscapeHTML: false,
jsonResponseIndentPrefix: "",
jsonResponseIndentValue: "",
}
for _, option := range options {
option(h)
}
for k, v := range h.contextValues {
h.baseContext = context.WithValue(h.baseContext, k, v)
}
if h.enableSIGTERM {
enableSIGTERM(h.sigtermCallbacks)
}
Expand Down
12 changes: 7 additions & 5 deletions lambda/sigterm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"os"
"os/exec"
"path"
"strconv"
"strings"
"testing"
"time"
Expand All @@ -17,10 +18,6 @@ import (
"github.com/stretchr/testify/require"
)

const (
rieInvokeAPI = "http://localhost:8080/2015-03-31/functions/function/invocations"
)

func TestEnableSigterm(t *testing.T) {
if _, err := exec.LookPath("aws-lambda-rie"); err != nil {
t.Skipf("%v - install from https://github.com/aws/aws-lambda-runtime-interface-emulator/", err)
Expand All @@ -34,6 +31,7 @@ func TestEnableSigterm(t *testing.T) {
handlerBuild.Stdout = os.Stderr
require.NoError(t, handlerBuild.Run())

portI := 0
for name, opts := range map[string]struct {
envVars []string
assertLogs func(t *testing.T, logs string)
Expand All @@ -53,8 +51,12 @@ func TestEnableSigterm(t *testing.T) {
},
} {
t.Run(name, func(t *testing.T) {
portI += 1
addr1 := "localhost:" + strconv.Itoa(8000+portI)
addr2 := "localhost:" + strconv.Itoa(9000+portI)
rieInvokeAPI := "http://" + addr1 + "/2015-03-31/functions/function/invocations"
// run the runtime interface emulator, capture the logs for assertion
cmd := exec.Command("aws-lambda-rie", "sigterm.handler")
cmd := exec.Command("aws-lambda-rie", "--runtime-interface-emulator-address", addr1, "--runtime-api-address", addr2, "sigterm.handler")
cmd.Env = append([]string{
"PATH=" + testDir,
"AWS_LAMBDA_FUNCTION_TIMEOUT=2",
Expand Down
91 changes: 76 additions & 15 deletions lambdaurl/http_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,76 @@ import (
"github.com/aws/aws-lambda-go/lambda"
)

type detectContentTypeContextKey struct{}

// WithDetectContentType sets the behavior of content type detection when the Content-Type header is not already provided.
// When true, the first Write call will pass the intial bytes to http.DetectContentType.
// When false, and if no Content-Type is provided, no Content-Type will be sent back to Lambda,
// and the Lambda Function URL will fallback to it's default.
//
// Note: The http.ResponseWriter passed to the handler is unbuffered.
// This may result in different Content-Type headers in the Function URL response when compared to http.ListenAndServe.
//
// Usage:
//
// lambdaurl.Start(
// http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) {
// w.Write("<!DOCTYPE html><html></html>")
// }),
// lambdaurl.WithDetectContentType(true)
// )
func WithDetectContentType(detectContentType bool) lambda.Option {
return lambda.WithContextValue(detectContentTypeContextKey{}, detectContentType)
}

type httpResponseWriter struct {
detectContentType bool
header http.Header
writer io.Writer
once sync.Once
ready chan<- header
}

type header struct {
code int
header http.Header
writer io.Writer
once sync.Once
status chan<- int
}

func (w *httpResponseWriter) Header() http.Header {
if w.header == nil {
w.header = http.Header{}
}
return w.header
}

func (w *httpResponseWriter) Write(p []byte) (int, error) {
w.once.Do(func() { w.status <- http.StatusOK })
w.writeHeader(http.StatusOK, p)
return w.writer.Write(p)
}

func (w *httpResponseWriter) WriteHeader(statusCode int) {
w.once.Do(func() { w.status <- statusCode })
w.writeHeader(statusCode, nil)
}

func (w *httpResponseWriter) writeHeader(statusCode int, initialPayload []byte) {
w.once.Do(func() {
if w.detectContentType {
if w.Header().Get("Content-Type") == "" {
w.Header().Set("Content-Type", detectContentType(initialPayload))
}
}
w.ready <- header{code: statusCode, header: w.header}
})
}

func detectContentType(p []byte) string {
// http.DetectContentType returns "text/plain; charset=utf-8" for nil and zero-length byte slices.
// This is a weird behavior, since otherwise it defaults to "application/octet-stream"! So we'll do that.
// This differs from http.ListenAndServe, which set no Content-Type when the initial Flush body is empty.
if len(p) == 0 {
return "application/octet-stream"
}
return http.DetectContentType(p)
}

type requestContextKey struct{}
Expand All @@ -46,11 +98,13 @@ func RequestFromContext(ctx context.Context) (*events.LambdaFunctionURLRequest,
return req, ok
}

// Wrap converts an http.Handler into a lambda request handler.
// Wrap converts an http.Handler into a Lambda request handler.
//
// Only Lambda Function URLs configured with `InvokeMode: RESPONSE_STREAM` are supported with the returned handler.
// The response body of the handler will conform to the content-type `application/vnd.awslambda.http-integration-response`
// The response body of the handler will conform to the content-type `application/vnd.awslambda.http-integration-response`.
func Wrap(handler http.Handler) func(context.Context, *events.LambdaFunctionURLRequest) (*events.LambdaFunctionURLStreamingResponse, error) {
return func(ctx context.Context, request *events.LambdaFunctionURLRequest) (*events.LambdaFunctionURLStreamingResponse, error) {

var body io.Reader = strings.NewReader(request.Body)
if request.IsBase64Encoded {
body = base64.NewDecoder(base64.StdEncoding, body)
Expand All @@ -67,21 +121,28 @@ func Wrap(handler http.Handler) func(context.Context, *events.LambdaFunctionURLR
for k, v := range request.Headers {
httpRequest.Header.Add(k, v)
}
status := make(chan int) // Signals when it's OK to start returning the response body to Lambda
header := http.Header{}

ready := make(chan header) // Signals when it's OK to start returning the response body to Lambda
r, w := io.Pipe()
responseWriter := &httpResponseWriter{writer: w, ready: ready}
if detectContentType, ok := ctx.Value(detectContentTypeContextKey{}).(bool); ok {
responseWriter.detectContentType = detectContentType
}
go func() {
defer close(status)
defer close(ready)
defer w.Close() // TODO: recover and CloseWithError the any panic value once the runtime API client supports plumbing fatal errors through the reader
handler.ServeHTTP(&httpResponseWriter{writer: w, header: header, status: status}, httpRequest)
//nolint:errcheck
defer responseWriter.Write(nil) // force default status, headers, content type detection, if none occured during the execution of the handler
handler.ServeHTTP(responseWriter, httpRequest)
}()
header := <-ready
response := &events.LambdaFunctionURLStreamingResponse{
Body: r,
StatusCode: <-status,
StatusCode: header.code,
}
if len(header) > 0 {
response.Headers = make(map[string]string, len(header))
for k, v := range header {
if len(header.header) > 0 {
response.Headers = make(map[string]string, len(header.header))
for k, v := range header.header {
if k == "Set-Cookie" {
response.Cookies = v
} else {
Expand Down
Loading