Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Distributor: Use pooled buffers for reading/decompressing request body #6836

Merged
merged 10 commits into from
Dec 11, 2023
45 changes: 23 additions & 22 deletions pkg/distributor/otel.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func OTLPHandler(
) http.Handler {
discardedDueToOtelParseError := validation.DiscardedSamplesCounter(reg, otelParseError)

return handler(maxRecvMsgSize, sourceIPs, allowSkipLabelNameValidation, limits, retryCfg, push, logger, func(ctx context.Context, r *http.Request, maxRecvMsgSize int, dst []byte, req *mimirpb.PreallocWriteRequest, logger log.Logger) ([]byte, error) {
return handler(maxRecvMsgSize, sourceIPs, allowSkipLabelNameValidation, limits, retryCfg, push, logger, func(ctx context.Context, r *http.Request, maxRecvMsgSize int, buffers *util.RequestBuffers, req *mimirpb.PreallocWriteRequest, logger log.Logger) error {
contentType := r.Header.Get("Content-Type")
contentEncoding := r.Header.Get("Content-Encoding")
var compression util.CompressionType
Expand All @@ -64,65 +64,66 @@ func OTLPHandler(
case "":
compression = util.NoCompression
default:
return nil, httpgrpc.Errorf(http.StatusUnsupportedMediaType, "unsupported compression: %s. Only \"gzip\" or no compression supported", contentEncoding)
return httpgrpc.Errorf(http.StatusUnsupportedMediaType, "unsupported compression: %s. Only \"gzip\" or no compression supported", contentEncoding)
}

var decoderFunc func(io.ReadCloser) (pmetricotlp.ExportRequest, []byte, error)
var decoderFunc func(io.ReadCloser) (pmetricotlp.ExportRequest, error)
switch contentType {
case pbContentType:
decoderFunc = func(reader io.ReadCloser) (pmetricotlp.ExportRequest, []byte, error) {
decoderFunc = func(reader io.ReadCloser) (pmetricotlp.ExportRequest, error) {
exportReq := pmetricotlp.NewExportRequest()
unmarshaler := otlpProtoUnmarshaler{
request: &exportReq,
}
buf, err := util.ParseProtoReader(ctx, reader, int(r.ContentLength), maxRecvMsgSize, dst, unmarshaler, compression)
err := util.ParseProtoReader(ctx, reader, int(r.ContentLength), maxRecvMsgSize, buffers, unmarshaler, compression)
var tooLargeErr util.MsgSizeTooLargeErr
if errors.As(err, &tooLargeErr) {
return exportReq, buf, httpgrpc.Errorf(http.StatusRequestEntityTooLarge, distributorMaxWriteMessageSizeErr{
return exportReq, httpgrpc.Errorf(http.StatusRequestEntityTooLarge, distributorMaxWriteMessageSizeErr{
actual: tooLargeErr.Actual,
limit: tooLargeErr.Limit,
}.Error())
}
return exportReq, buf, err
return exportReq, err
}

case jsonContentType:
decoderFunc = func(reader io.ReadCloser) (pmetricotlp.ExportRequest, []byte, error) {
decoderFunc = func(reader io.ReadCloser) (pmetricotlp.ExportRequest, error) {
exportReq := pmetricotlp.NewExportRequest()
var buf bytes.Buffer
if r.ContentLength > 0 {
sz := int(r.ContentLength)
if sz > 0 {
// Extra space guarantees no reallocation
buf.Grow(int(r.ContentLength) + bytes.MinRead)
sz += bytes.MinRead
}
buf := buffers.Get(sz)
if compression == util.Gzip {
var err error
reader, err = gzip.NewReader(reader)
if err != nil {
return exportReq, buf.Bytes(), errors.Wrap(err, "create gzip reader")
return exportReq, errors.Wrap(err, "create gzip reader")
}
}

reader = http.MaxBytesReader(nil, reader, int64(maxRecvMsgSize))
if _, err := buf.ReadFrom(reader); err != nil {
if util.IsRequestBodyTooLarge(err) {
return exportReq, buf.Bytes(), httpgrpc.Errorf(http.StatusRequestEntityTooLarge, distributorMaxWriteMessageSizeErr{
return exportReq, httpgrpc.Errorf(http.StatusRequestEntityTooLarge, distributorMaxWriteMessageSizeErr{
actual: -1,
limit: maxRecvMsgSize,
}.Error())
}

return exportReq, buf.Bytes(), errors.Wrap(err, "read write request")
return exportReq, errors.Wrap(err, "read write request")
}

return exportReq, buf.Bytes(), exportReq.UnmarshalJSON(buf.Bytes())
return exportReq, exportReq.UnmarshalJSON(buf.Bytes())
}

default:
return nil, httpgrpc.Errorf(http.StatusUnsupportedMediaType, "unsupported content type: %s, supported: [%s, %s]", contentType, jsonContentType, pbContentType)
return httpgrpc.Errorf(http.StatusUnsupportedMediaType, "unsupported content type: %s, supported: [%s, %s]", contentType, jsonContentType, pbContentType)
}

if r.ContentLength > int64(maxRecvMsgSize) {
return nil, httpgrpc.Errorf(http.StatusRequestEntityTooLarge, distributorMaxWriteMessageSizeErr{
return httpgrpc.Errorf(http.StatusRequestEntityTooLarge, distributorMaxWriteMessageSizeErr{
actual: int(r.ContentLength),
limit: maxRecvMsgSize,
}.Error())
Expand All @@ -135,22 +136,22 @@ func OTLPHandler(
spanLogger.SetTag("content_encoding", contentEncoding)
spanLogger.SetTag("content_length", r.ContentLength)

otlpReq, buf, err := decoderFunc(r.Body)
otlpReq, err := decoderFunc(r.Body)
if err != nil {
return buf, err
return err
}

level.Debug(spanLogger).Log("msg", "decoding complete, starting conversion")

tenantID, err := tenant.TenantID(ctx)
if err != nil {
return buf, err
return err
}
addSuffixes := limits.OTelMetricSuffixesEnabled(tenantID)

metrics, err := otelMetricsToTimeseries(tenantID, addSuffixes, discardedDueToOtelParseError, logger, otlpReq.Metrics())
if err != nil {
return buf, err
return err
}

metricCount := len(metrics)
Expand Down Expand Up @@ -179,7 +180,7 @@ func OTLPHandler(
req.Metadata = metadata
}

return buf, nil
return nil
})
}

Expand Down
33 changes: 13 additions & 20 deletions pkg/distributor/push.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package distributor

import (
"bytes"
"context"
"errors"
"flag"
Expand Down Expand Up @@ -33,17 +34,14 @@ import (
// PushFunc defines the type of the push. It is similar to http.HandlerFunc.
type PushFunc func(ctx context.Context, req *Request) error

// parserFunc defines how to read the body of the request from an HTTP request
type parserFunc func(ctx context.Context, r *http.Request, maxSize int, buffer []byte, req *mimirpb.PreallocWriteRequest, logger log.Logger) ([]byte, error)

// Wrap a slice in a struct so we can store a pointer in sync.Pool
type bufHolder struct {
buf []byte
}
// parserFunc defines how to read the body of the request from an HTTP request. It takes an optional RequestBuffers.
type parserFunc func(ctx context.Context, r *http.Request, maxSize int, buffers *util.RequestBuffers, req *mimirpb.PreallocWriteRequest, logger log.Logger) error

var (
bufferPool = sync.Pool{
New: func() interface{} { return &bufHolder{buf: make([]byte, 256*1024)} },
New: func() any {
return bytes.NewBuffer(make([]byte, 0, 256*1024))
},
}
errRetryBaseLessThanOneSecond = errors.New("retry base duration should not be less than 1 second")
errNonPositiveMaxBackoffExponent = errors.New("max backoff exponent should be a positive value")
Expand Down Expand Up @@ -87,12 +85,12 @@ func Handler(
push PushFunc,
logger log.Logger,
) http.Handler {
return handler(maxRecvMsgSize, sourceIPs, allowSkipLabelNameValidation, limits, retryCfg, push, logger, func(ctx context.Context, r *http.Request, maxRecvMsgSize int, dst []byte, req *mimirpb.PreallocWriteRequest, _ log.Logger) ([]byte, error) {
res, err := util.ParseProtoReader(ctx, r.Body, int(r.ContentLength), maxRecvMsgSize, dst, req, util.RawSnappy)
return handler(maxRecvMsgSize, sourceIPs, allowSkipLabelNameValidation, limits, retryCfg, push, logger, func(ctx context.Context, r *http.Request, maxRecvMsgSize int, buffers *util.RequestBuffers, req *mimirpb.PreallocWriteRequest, _ log.Logger) error {
err := util.ParseProtoReader(ctx, r.Body, int(r.ContentLength), maxRecvMsgSize, buffers, req, util.RawSnappy)
if errors.Is(err, util.MsgSizeTooLargeErr{}) {
err = distributorMaxWriteMessageSizeErr{actual: int(r.ContentLength), limit: maxRecvMsgSize}
}
return res, err
return err
})
}

Expand Down Expand Up @@ -129,22 +127,17 @@ func handler(
}
}
supplier := func() (*mimirpb.WriteRequest, func(), error) {
bufHolder := bufferPool.Get().(*bufHolder)
rb := util.NewRequestBuffers(&bufferPool)
var req mimirpb.PreallocWriteRequest
buf, err := parser(ctx, r, maxRecvMsgSize, bufHolder.buf, &req, logger)
if err != nil {
if err := parser(ctx, r, maxRecvMsgSize, rb, &req, logger); err != nil {
// Check for httpgrpc error, default to client error if parsing failed
if _, ok := httpgrpc.HTTPResponseFromError(err); !ok {
err = httpgrpc.Errorf(http.StatusBadRequest, err.Error())
}

bufferPool.Put(bufHolder)
rb.CleanUp()
return nil, nil, err
}
// If decoding allocated a bigger buffer, put that one back in the pool.
if buf = buf[:cap(buf)]; len(buf) > len(bufHolder.buf) {
bufHolder.buf = buf
}

if allowSkipLabelNameValidation {
req.SkipLabelNameValidation = req.SkipLabelNameValidation && r.Header.Get(SkipLabelNameValidationHeader) == "true"
Expand All @@ -154,7 +147,7 @@ func handler(

cleanup := func() {
mimirpb.ReuseSlice(req.Timeseries)
bufferPool.Put(bufHolder)
rb.CleanUp()
}
return &req.WriteRequest, cleanup, nil
}
Expand Down
39 changes: 20 additions & 19 deletions pkg/distributor/push_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
"google.golang.org/grpc/codes"

"github.com/grafana/mimir/pkg/mimirpb"
"github.com/grafana/mimir/pkg/util"
"github.com/grafana/mimir/pkg/util/test"
"github.com/grafana/mimir/pkg/util/validation"
)
Expand Down Expand Up @@ -134,18 +135,18 @@ func TestHandlerOTLPPush(t *testing.T) {
require.Len(t, series, 1)

samples := series[0].Samples
assert.Equal(t, 1, len(samples))
require.Len(t, samples, 1)
assert.Equal(t, float64(1), samples[0].Value)
assert.Equal(t, "__name__", series[0].Labels[0].Name)
assert.Equal(t, "foo", series[0].Labels[0].Value)

metadata := request.Metadata
require.Len(t, metadata, 1)
assert.Equal(t, mimirpb.GAUGE, metadata[0].GetType())
assert.Equal(t, "foo", metadata[0].GetMetricFamilyName())
assert.Equal(t, "metric_help", metadata[0].GetHelp())
assert.Equal(t, "metric_unit", metadata[0].GetUnit())

pushReq.CleanUp()
return nil
}

Expand All @@ -154,7 +155,7 @@ func TestHandlerOTLPPush(t *testing.T) {
require.NoError(t, err)

series := request.Timeseries
assert.Len(t, series, 1)
require.Len(t, series, 1)

samples := series[0].Samples
require.Equal(t, 1, len(samples))
Expand All @@ -165,7 +166,6 @@ func TestHandlerOTLPPush(t *testing.T) {
metadata := request.Metadata
assert.Equal(t, []*mimirpb.MetricMetadata(nil), metadata)

pushReq.CleanUp()
return nil
}

Expand Down Expand Up @@ -295,6 +295,7 @@ func TestHandlerOTLPPush(t *testing.T) {
require.NoError(t, err)
pusher := func(ctx context.Context, pushReq *Request) error {
t.Helper()
t.Cleanup(pushReq.CleanUp)
return tt.verifyFunc(t, pushReq)
}
handler := OTLPHandler(tt.maxMsgSize, nil, false, tt.enableOtelMetadataStorage, limits, RetryConfig{}, nil, pusher, log.NewNopLogger())
Expand All @@ -305,7 +306,7 @@ func TestHandlerOTLPPush(t *testing.T) {
assert.Equal(t, tt.responseCode, resp.Code)
if tt.errMessage != "" {
body, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
require.NoError(t, err)
assert.Contains(t, string(body), tt.errMessage)
}
})
Expand Down Expand Up @@ -520,13 +521,13 @@ func TestHandler_EnsureSkipLabelNameValidationBehaviour(t *testing.T) {
req: createRequest(t, createMimirWriteRequestProtobufWithNonSupportedLabelNames(t, true)),
verifyReqHandler: func(ctx context.Context, pushReq *Request) error {
request, err := pushReq.WriteRequest()
assert.NoError(t, err)
require.NoError(t, err)
t.Cleanup(pushReq.CleanUp)
assert.Len(t, request.Timeseries, 1)
assert.Equal(t, "a-label", request.Timeseries[0].Labels[0].Name)
assert.Equal(t, "value", request.Timeseries[0].Labels[0].Value)
assert.Equal(t, mimirpb.RULE, request.Source)
assert.False(t, request.SkipLabelNameValidation)
pushReq.CleanUp()
return nil
},
includeAllowSkiplabelNameValidationHeader: true,
Expand Down Expand Up @@ -602,13 +603,13 @@ func verifyWritePushFunc(t *testing.T, expectSource mimirpb.WriteRequest_SourceE
t.Helper()
return func(ctx context.Context, pushReq *Request) error {
request, err := pushReq.WriteRequest()
assert.NoError(t, err)
assert.Len(t, request.Timeseries, 1)
assert.Equal(t, "__name__", request.Timeseries[0].Labels[0].Name)
assert.Equal(t, "foo", request.Timeseries[0].Labels[0].Value)
assert.Equal(t, expectSource, request.Source)
assert.False(t, request.SkipLabelNameValidation)
pushReq.CleanUp()
require.NoError(t, err)
t.Cleanup(pushReq.CleanUp)
require.Len(t, request.Timeseries, 1)
require.Equal(t, "__name__", request.Timeseries[0].Labels[0].Name)
require.Equal(t, "foo", request.Timeseries[0].Labels[0].Value)
require.Equal(t, expectSource, request.Source)
require.False(t, request.SkipLabelNameValidation)
return nil
}
}
Expand Down Expand Up @@ -758,8 +759,8 @@ func TestHandler_ErrorTranslation(t *testing.T) {
}
for _, tc := range parserTestCases {
t.Run(tc.name, func(t *testing.T) {
parserFunc := func(context.Context, *http.Request, int, []byte, *mimirpb.PreallocWriteRequest, log.Logger) ([]byte, error) {
return nil, tc.err
parserFunc := func(context.Context, *http.Request, int, *util.RequestBuffers, *mimirpb.PreallocWriteRequest, log.Logger) error {
return tc.err
}
pushFunc := func(ctx context.Context, req *Request) error {
_, err := req.WriteRequest() // just read the body so we can trigger the parser
Expand Down Expand Up @@ -825,8 +826,8 @@ func TestHandler_ErrorTranslation(t *testing.T) {
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
parserFunc := func(context.Context, *http.Request, int, []byte, *mimirpb.PreallocWriteRequest, log.Logger) ([]byte, error) {
return nil, nil
parserFunc := func(context.Context, *http.Request, int, *util.RequestBuffers, *mimirpb.PreallocWriteRequest, log.Logger) error {
return nil
}
pushFunc := func(ctx context.Context, req *Request) error {
_, err := req.WriteRequest() // just read the body so we can trigger the parser
Expand All @@ -841,7 +842,7 @@ func TestHandler_ErrorTranslation(t *testing.T) {

assert.Equal(t, tc.expectedHTTPStatus, recorder.Code)
if tc.err != nil {
assert.Equal(t, fmt.Sprintf("%s\n", tc.expectedErrorMessage), recorder.Body.String())
require.Equal(t, fmt.Sprintf("%s\n", tc.expectedErrorMessage), recorder.Body.String())
}
header := recorder.Header().Get(server.DoNotLogErrorHeaderKey)
if tc.expectedDoNotLogErrorHeader {
Expand Down
6 changes: 3 additions & 3 deletions pkg/ingester/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ func TestMarshall(t *testing.T) {
plentySize = 1024 * 1024
)
req := mimirpb.WriteRequest{}
_, err := util.ParseProtoReader(context.Background(), recorder.Body, recorder.Body.Len(), tooSmallSize, nil, &req, util.RawSnappy)
err := util.ParseProtoReader(context.Background(), recorder.Body, recorder.Body.Len(), tooSmallSize, nil, &req, util.RawSnappy)
require.Error(t, err)
_, err = util.ParseProtoReader(context.Background(), recorder.Body, recorder.Body.Len(), plentySize, nil, &req, util.RawSnappy)
err = util.ParseProtoReader(context.Background(), recorder.Body, recorder.Body.Len(), plentySize, nil, &req, util.RawSnappy)
require.NoError(t, err)
require.Equal(t, numSeries, len(req.Timeseries))
require.Len(t, req.Timeseries, numSeries)
}
}
2 changes: 1 addition & 1 deletion pkg/querier/remote_read.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func remoteReadHandler(q storage.SampleAndChunkQueryable, maxBytesInFrame int, l
ctx := r.Context()
var req client.ReadRequest
logger := util_log.WithContext(r.Context(), lg)
if _, err := util.ParseProtoReader(ctx, r.Body, int(r.ContentLength), maxRemoteReadQuerySize, nil, &req, util.RawSnappy); err != nil {
if err := util.ParseProtoReader(ctx, r.Body, int(r.ContentLength), maxRemoteReadQuerySize, nil, &req, util.RawSnappy); err != nil {
level.Error(logger).Log("msg", "failed to parse proto", "err", err.Error())
http.Error(w, err.Error(), http.StatusBadRequest)
return
Expand Down
Loading
Loading