diff --git a/pkg/distributor/otel.go b/pkg/distributor/otel.go
index b9b3b87dda3..84fc55fb306 100644
--- a/pkg/distributor/otel.go
+++ b/pkg/distributor/otel.go
@@ -54,7 +54,7 @@ func OTLPHandler(
 ) http.Handler {
 	discardedDueToOtelParseError := validation.DiscardedSamplesCounter(reg, otelParseError)

-	return handler(maxRecvMsgSize, sourceIPs, allowSkipLabelNameValidation, limits, retryCfg, push, logger, func(ctx context.Context, r *http.Request, maxRecvMsgSize int, dst []byte, req *mimirpb.PreallocWriteRequest, logger log.Logger) ([]byte, error) {
+	return handler(maxRecvMsgSize, sourceIPs, allowSkipLabelNameValidation, limits, retryCfg, push, logger, func(ctx context.Context, r *http.Request, maxRecvMsgSize int, buffers *util.RequestBuffers, req *mimirpb.PreallocWriteRequest, logger log.Logger) error {
 		contentType := r.Header.Get("Content-Type")
 		contentEncoding := r.Header.Get("Content-Encoding")
 		var compression util.CompressionType
@@ -64,65 +64,66 @@ func OTLPHandler(
 		case "":
 			compression = util.NoCompression
 		default:
-			return nil, httpgrpc.Errorf(http.StatusUnsupportedMediaType, "unsupported compression: %s. Only \"gzip\" or no compression supported", contentEncoding)
+			return httpgrpc.Errorf(http.StatusUnsupportedMediaType, "unsupported compression: %s. Only \"gzip\" or no compression supported", contentEncoding)
 		}

-		var decoderFunc func(io.ReadCloser) (pmetricotlp.ExportRequest, []byte, error)
+		var decoderFunc func(io.ReadCloser) (pmetricotlp.ExportRequest, error)

 		switch contentType {
 		case pbContentType:
-			decoderFunc = func(reader io.ReadCloser) (pmetricotlp.ExportRequest, []byte, error) {
+			decoderFunc = func(reader io.ReadCloser) (pmetricotlp.ExportRequest, error) {
 				exportReq := pmetricotlp.NewExportRequest()
 				unmarshaler := otlpProtoUnmarshaler{
 					request: &exportReq,
 				}
-				buf, err := util.ParseProtoReader(ctx, reader, int(r.ContentLength), maxRecvMsgSize, dst, unmarshaler, compression)
+				err := util.ParseProtoReader(ctx, reader, int(r.ContentLength), maxRecvMsgSize, buffers, unmarshaler, compression)
 				var tooLargeErr util.MsgSizeTooLargeErr
 				if errors.As(err, &tooLargeErr) {
-					return exportReq, buf, httpgrpc.Errorf(http.StatusRequestEntityTooLarge, distributorMaxWriteMessageSizeErr{
+					return exportReq, httpgrpc.Errorf(http.StatusRequestEntityTooLarge, distributorMaxWriteMessageSizeErr{
 						actual: tooLargeErr.Actual,
 						limit:  tooLargeErr.Limit,
 					}.Error())
 				}
-				return exportReq, buf, err
+				return exportReq, err
 			}

 		case jsonContentType:
-			decoderFunc = func(reader io.ReadCloser) (pmetricotlp.ExportRequest, []byte, error) {
+			decoderFunc = func(reader io.ReadCloser) (pmetricotlp.ExportRequest, error) {
 				exportReq := pmetricotlp.NewExportRequest()
-				var buf bytes.Buffer
-				if r.ContentLength > 0 {
+				sz := int(r.ContentLength)
+				if sz > 0 {
 					// Extra space guarantees no reallocation
-					buf.Grow(int(r.ContentLength) + bytes.MinRead)
+					sz += bytes.MinRead
 				}
+				buf := buffers.Get(sz)
 				if compression == util.Gzip {
 					var err error
 					reader, err = gzip.NewReader(reader)
 					if err != nil {
-						return exportReq, buf.Bytes(), errors.Wrap(err, "create gzip reader")
+						return exportReq, errors.Wrap(err, "create gzip reader")
 					}
 				}

 				reader = http.MaxBytesReader(nil, reader, int64(maxRecvMsgSize))
 				if _, err := buf.ReadFrom(reader); err != nil {
 					if util.IsRequestBodyTooLarge(err) {
-						return exportReq, buf.Bytes(), httpgrpc.Errorf(http.StatusRequestEntityTooLarge, distributorMaxWriteMessageSizeErr{
+						return exportReq, httpgrpc.Errorf(http.StatusRequestEntityTooLarge, distributorMaxWriteMessageSizeErr{
							actual: -1,
							limit:  maxRecvMsgSize,
						}.Error())
					}
-					return exportReq, buf.Bytes(), errors.Wrap(err, "read write request")
+					return exportReq, errors.Wrap(err, "read write request")
 				}

-				return exportReq, buf.Bytes(), exportReq.UnmarshalJSON(buf.Bytes())
+				return exportReq, exportReq.UnmarshalJSON(buf.Bytes())
 			}

 		default:
-			return nil, httpgrpc.Errorf(http.StatusUnsupportedMediaType, "unsupported content type: %s, supported: [%s, %s]", contentType, jsonContentType, pbContentType)
+			return httpgrpc.Errorf(http.StatusUnsupportedMediaType, "unsupported content type: %s, supported: [%s, %s]", contentType, jsonContentType, pbContentType)
 		}

 		if r.ContentLength > int64(maxRecvMsgSize) {
-			return nil, httpgrpc.Errorf(http.StatusRequestEntityTooLarge, distributorMaxWriteMessageSizeErr{
+			return httpgrpc.Errorf(http.StatusRequestEntityTooLarge, distributorMaxWriteMessageSizeErr{
 				actual: int(r.ContentLength),
 				limit:  maxRecvMsgSize,
 			}.Error())
@@ -135,22 +136,22 @@ func OTLPHandler(
 		spanLogger.SetTag("content_encoding", contentEncoding)
 		spanLogger.SetTag("content_length", r.ContentLength)

-		otlpReq, buf, err := decoderFunc(r.Body)
+		otlpReq, err := decoderFunc(r.Body)
 		if err != nil {
-			return buf, err
+			return err
 		}

 		level.Debug(spanLogger).Log("msg", "decoding complete, starting conversion")

 		tenantID, err := tenant.TenantID(ctx)
 		if err != nil {
-			return buf, err
+			return err
 		}

 		addSuffixes := limits.OTelMetricSuffixesEnabled(tenantID)

 		metrics, err := otelMetricsToTimeseries(tenantID, addSuffixes, discardedDueToOtelParseError, logger, otlpReq.Metrics())
 		if err != nil {
-			return buf, err
+			return err
 		}

 		metricCount := len(metrics)
@@ -179,7 +180,7 @@ func OTLPHandler(
 			req.Metadata = metadata
 		}

-		return buf, nil
+		return nil
 	})
 }
diff --git a/pkg/distributor/push.go b/pkg/distributor/push.go
index b1b07d83d95..2c03d93af0d 100644
--- a/pkg/distributor/push.go
+++ b/pkg/distributor/push.go
@@ -6,6 +6,7 @@ package distributor

 import (
+	"bytes"
 	"context"
 	"errors"
 	"flag"
@@ -33,17 +34,14 @@ import (
 // PushFunc defines the type of the push. It is similar to http.HandlerFunc.
 type PushFunc func(ctx context.Context, req *Request) error

-// parserFunc defines how to read the body the request from an HTTP request
-type parserFunc func(ctx context.Context, r *http.Request, maxSize int, buffer []byte, req *mimirpb.PreallocWriteRequest, logger log.Logger) ([]byte, error)
-
-// Wrap a slice in a struct so we can store a pointer in sync.Pool
-type bufHolder struct {
-	buf []byte
-}
+// parserFunc defines how to read the body of an HTTP request. It takes an optional RequestBuffers.
+type parserFunc func(ctx context.Context, r *http.Request, maxSize int, buffers *util.RequestBuffers, req *mimirpb.PreallocWriteRequest, logger log.Logger) error

 var (
 	bufferPool = sync.Pool{
-		New: func() interface{} { return &bufHolder{buf: make([]byte, 256*1024)} },
+		New: func() any {
+			return bytes.NewBuffer(make([]byte, 0, 256*1024))
+		},
 	}
 	errRetryBaseLessThanOneSecond    = errors.New("retry base duration should not be less than 1 second")
 	errNonPositiveMaxBackoffExponent = errors.New("max backoff exponent should be a positive value")
@@ -87,12 +85,12 @@ func Handler(
 	push PushFunc,
 	logger log.Logger,
 ) http.Handler {
-	return handler(maxRecvMsgSize, sourceIPs, allowSkipLabelNameValidation, limits, retryCfg, push, logger, func(ctx context.Context, r *http.Request, maxRecvMsgSize int, dst []byte, req *mimirpb.PreallocWriteRequest, _ log.Logger) ([]byte, error) {
-		res, err := util.ParseProtoReader(ctx, r.Body, int(r.ContentLength), maxRecvMsgSize, dst, req, util.RawSnappy)
+	return handler(maxRecvMsgSize, sourceIPs, allowSkipLabelNameValidation, limits, retryCfg, push, logger, func(ctx context.Context, r *http.Request, maxRecvMsgSize int, buffers *util.RequestBuffers, req *mimirpb.PreallocWriteRequest, _ log.Logger) error {
+		err := util.ParseProtoReader(ctx, r.Body, int(r.ContentLength), maxRecvMsgSize, buffers, req, util.RawSnappy)
 		if errors.Is(err, util.MsgSizeTooLargeErr{}) {
 			err = distributorMaxWriteMessageSizeErr{actual: int(r.ContentLength), limit: maxRecvMsgSize}
 		}
-		return res, err
+		return err
 	})
 }
@@ -129,22 +127,17 @@ func handler(
 			}
 		}
 		supplier := func() (*mimirpb.WriteRequest, func(), error) {
-			bufHolder := bufferPool.Get().(*bufHolder)
+			rb := util.NewRequestBuffers(&bufferPool)
 			var req mimirpb.PreallocWriteRequest
-			buf, err := parser(ctx, r, maxRecvMsgSize, bufHolder.buf, &req, logger)
-			if err != nil {
+			if err := parser(ctx, r, maxRecvMsgSize, rb, &req, logger); err != nil {
 				// Check for httpgrpc error, default to client error if parsing failed
 				if _, ok := httpgrpc.HTTPResponseFromError(err); !ok {
 					err = httpgrpc.Errorf(http.StatusBadRequest, err.Error())
 				}
-				bufferPool.Put(bufHolder)
+				rb.CleanUp()
 				return nil, nil, err
 			}

-			// If decoding allocated a bigger buffer, put that one back in the pool.
-			if buf = buf[:cap(buf)]; len(buf) > len(bufHolder.buf) {
-				bufHolder.buf = buf
-			}

 			if allowSkipLabelNameValidation {
 				req.SkipLabelNameValidation = req.SkipLabelNameValidation && r.Header.Get(SkipLabelNameValidationHeader) == "true"
@@ -154,7 +147,7 @@ func handler(

 			cleanup := func() {
 				mimirpb.ReuseSlice(req.Timeseries)
-				bufferPool.Put(bufHolder)
+				rb.CleanUp()
 			}
 			return &req.WriteRequest, cleanup, nil
 		}
diff --git a/pkg/distributor/push_test.go b/pkg/distributor/push_test.go
index 00107692027..d6d88aaaa08 100644
--- a/pkg/distributor/push_test.go
+++ b/pkg/distributor/push_test.go
@@ -36,6 +36,7 @@ import (
 	"google.golang.org/grpc/codes"

 	"github.com/grafana/mimir/pkg/mimirpb"
+	"github.com/grafana/mimir/pkg/util"
 	"github.com/grafana/mimir/pkg/util/test"
 	"github.com/grafana/mimir/pkg/util/validation"
 )
@@ -134,18 +135,18 @@ func TestHandlerOTLPPush(t *testing.T) {
 		require.Len(t, series, 1)

 		samples := series[0].Samples
-		assert.Equal(t, 1, len(samples))
+		require.Len(t, samples, 1)
 		assert.Equal(t, float64(1), samples[0].Value)
 		assert.Equal(t, "__name__", series[0].Labels[0].Name)
 		assert.Equal(t, "foo", series[0].Labels[0].Value)

 		metadata := request.Metadata
+		require.Len(t, metadata, 1)
 		assert.Equal(t, mimirpb.GAUGE, metadata[0].GetType())
 		assert.Equal(t, "foo", metadata[0].GetMetricFamilyName())
 		assert.Equal(t, "metric_help", metadata[0].GetHelp())
 		assert.Equal(t, "metric_unit", metadata[0].GetUnit())

-		pushReq.CleanUp()
 		return nil
 	}
@@ -154,7 +155,7 @@ func TestHandlerOTLPPush(t *testing.T) {
 		require.NoError(t, err)

 		series := request.Timeseries
-		assert.Len(t, series, 1)
+		require.Len(t, series, 1)

 		samples := series[0].Samples
 		require.Equal(t, 1, len(samples))
@@ -165,7 +166,6 @@ func TestHandlerOTLPPush(t *testing.T) {
 		metadata := request.Metadata
 		assert.Equal(t, []*mimirpb.MetricMetadata(nil), metadata)

-		pushReq.CleanUp()
 		return nil
 	}
@@ -295,6 +295,7 @@ func TestHandlerOTLPPush(t *testing.T) {
 			require.NoError(t, err)
 			pusher := func(ctx context.Context, pushReq *Request) error {
 				t.Helper()
+				t.Cleanup(pushReq.CleanUp)
 				return tt.verifyFunc(t, pushReq)
 			}
 			handler := OTLPHandler(tt.maxMsgSize, nil, false, tt.enableOtelMetadataStorage, limits, RetryConfig{}, nil, pusher, log.NewNopLogger())
@@ -305,7 +306,7 @@ func TestHandlerOTLPPush(t *testing.T) {
 			assert.Equal(t, tt.responseCode, resp.Code)
 			if tt.errMessage != "" {
 				body, err := io.ReadAll(resp.Body)
-				assert.NoError(t, err)
+				require.NoError(t, err)
 				assert.Contains(t, string(body), tt.errMessage)
 			}
 		})
@@ -520,13 +521,13 @@ func TestHandler_EnsureSkipLabelNameValidationBehaviour(t *testing.T) {
 			req: createRequest(t, createMimirWriteRequestProtobufWithNonSupportedLabelNames(t, true)),
 			verifyReqHandler: func(ctx context.Context, pushReq *Request) error {
 				request, err := pushReq.WriteRequest()
-				assert.NoError(t, err)
+				require.NoError(t, err)
+				t.Cleanup(pushReq.CleanUp)
 				assert.Len(t, request.Timeseries, 1)
 				assert.Equal(t, "a-label", request.Timeseries[0].Labels[0].Name)
 				assert.Equal(t, "value", request.Timeseries[0].Labels[0].Value)
 				assert.Equal(t, mimirpb.RULE, request.Source)
 				assert.False(t, request.SkipLabelNameValidation)
-				pushReq.CleanUp()
 				return nil
 			},
 			includeAllowSkiplabelNameValidationHeader: true,
@@ -602,13 +603,13 @@ func verifyWritePushFunc(t *testing.T, expectSource mimirpb.WriteRequest_SourceEnum) PushFunc {
 	t.Helper()
 	return func(ctx context.Context, pushReq *Request) error {
 		request, err := pushReq.WriteRequest()
-		assert.NoError(t, err)
-		assert.Len(t, request.Timeseries, 1)
-		assert.Equal(t, "__name__", request.Timeseries[0].Labels[0].Name)
-		assert.Equal(t, "foo", request.Timeseries[0].Labels[0].Value)
-		assert.Equal(t, expectSource, request.Source)
-		assert.False(t, request.SkipLabelNameValidation)
-		pushReq.CleanUp()
+		require.NoError(t, err)
+		t.Cleanup(pushReq.CleanUp)
+		require.Len(t, request.Timeseries, 1)
+		require.Equal(t, "__name__", request.Timeseries[0].Labels[0].Name)
+		require.Equal(t, "foo", request.Timeseries[0].Labels[0].Value)
+		require.Equal(t, expectSource, request.Source)
+		require.False(t, request.SkipLabelNameValidation)
 		return nil
 	}
 }
@@ -758,8 +759,8 @@ func TestHandler_ErrorTranslation(t *testing.T) {
 	}
 	for _, tc := range parserTestCases {
 		t.Run(tc.name, func(t *testing.T) {
-			parserFunc := func(context.Context, *http.Request, int, []byte, *mimirpb.PreallocWriteRequest, log.Logger) ([]byte, error) {
-				return nil, tc.err
+			parserFunc := func(context.Context, *http.Request, int, *util.RequestBuffers, *mimirpb.PreallocWriteRequest, log.Logger) error {
+				return tc.err
 			}
 			pushFunc := func(ctx context.Context, req *Request) error {
 				_, err := req.WriteRequest() // just read the body so we can trigger the parser
@@ -825,8 +826,8 @@ func TestHandler_ErrorTranslation(t *testing.T) {
 	for _, tc := range testCases {
 		tc := tc
 		t.Run(tc.name, func(t *testing.T) {
-			parserFunc := func(context.Context, *http.Request, int, []byte, *mimirpb.PreallocWriteRequest, log.Logger) ([]byte, error) {
-				return nil, nil
+			parserFunc := func(context.Context, *http.Request, int, *util.RequestBuffers, *mimirpb.PreallocWriteRequest, log.Logger) error {
+				return nil
 			}
 			pushFunc := func(ctx context.Context, req *Request) error {
 				_, err := req.WriteRequest() // just read the body so we can trigger the parser
@@ -841,7 +842,7 @@ func TestHandler_ErrorTranslation(t *testing.T) {

 			assert.Equal(t, tc.expectedHTTPStatus, recorder.Code)
 			if tc.err != nil {
-				assert.Equal(t, fmt.Sprintf("%s\n", tc.expectedErrorMessage), recorder.Body.String())
+				require.Equal(t, fmt.Sprintf("%s\n", tc.expectedErrorMessage), recorder.Body.String())
 			}
 			header := recorder.Header().Get(server.DoNotLogErrorHeaderKey)
 			if tc.expectedDoNotLogErrorHeader {
diff --git a/pkg/ingester/client/client_test.go b/pkg/ingester/client/client_test.go
index 13519ecf82e..39e385acb3c 100644
--- a/pkg/ingester/client/client_test.go
+++ b/pkg/ingester/client/client_test.go
@@ -50,10 +50,10 @@ func TestMarshall(t *testing.T) {
 			plentySize   = 1024 * 1024
 		)
 		req := mimirpb.WriteRequest{}
-		_, err := util.ParseProtoReader(context.Background(), recorder.Body, recorder.Body.Len(), tooSmallSize, nil, &req, util.RawSnappy)
+		err := util.ParseProtoReader(context.Background(), recorder.Body, recorder.Body.Len(), tooSmallSize, nil, &req, util.RawSnappy)
 		require.Error(t, err)
-		_, err = util.ParseProtoReader(context.Background(), recorder.Body, recorder.Body.Len(), plentySize, nil, &req, util.RawSnappy)
+		err = util.ParseProtoReader(context.Background(), recorder.Body, recorder.Body.Len(), plentySize, nil, &req, util.RawSnappy)
 		require.NoError(t, err)
-		require.Equal(t, numSeries, len(req.Timeseries))
+		require.Len(t, req.Timeseries, numSeries)
 	}
 }
diff --git a/pkg/querier/remote_read.go b/pkg/querier/remote_read.go
index a14bc469e1d..10b0d6f7c64 100644
--- a/pkg/querier/remote_read.go
+++ b/pkg/querier/remote_read.go
@@ -46,7 +46,7 @@ func remoteReadHandler(q storage.SampleAndChunkQueryable, maxBytesInFrame int, lg log.Logger) http.Handler {
 		ctx := r.Context()
 		var req client.ReadRequest
 		logger := util_log.WithContext(r.Context(), lg)
-		if _, err := util.ParseProtoReader(ctx, r.Body, int(r.ContentLength), maxRemoteReadQuerySize, nil, &req, util.RawSnappy); err != nil {
+		if err := util.ParseProtoReader(ctx, r.Body, int(r.ContentLength), maxRemoteReadQuerySize, nil, &req, util.RawSnappy); err != nil {
 			level.Error(logger).Log("msg", "failed to parse proto", "err", err.Error())
 			http.Error(w, err.Error(), http.StatusBadRequest)
 			return
diff --git a/pkg/util/http.go b/pkg/util/http.go
index 57a0b2af019..835411208bb 100644
--- a/pkg/util/http.go
+++ b/pkg/util/http.go
@@ -17,6 +17,7 @@ import (
 	"net/http"
 	"net/url"
 	"strings"
+	"sync"

 	"github.com/go-kit/log"
 	"github.com/go-kit/log/level"
@@ -145,15 +146,15 @@ const (
 )

 // ParseProtoReader parses a compressed proto from an io.Reader.
-// You can pass in and receive back the decompression buffer for pooling, or pass in nil and ignore the return.
-func ParseProtoReader(ctx context.Context, reader io.Reader, expectedSize, maxSize int, dst []byte, req proto.Message, compression CompressionType) ([]byte, error) {
+// You can pass in an optional RequestBuffers for buffer pooling; if nil, fresh buffers are allocated.
+func ParseProtoReader(ctx context.Context, reader io.Reader, expectedSize, maxSize int, buffers *RequestBuffers, req proto.Message, compression CompressionType) error {
 	sp := opentracing.SpanFromContext(ctx)
 	if sp != nil {
 		sp.LogFields(otlog.Event("util.ParseProtoReader[start reading]"))
 	}
-	body, err := decompressRequest(dst, reader, expectedSize, maxSize, compression, sp)
+	body, err := decompressRequest(buffers, reader, expectedSize, maxSize, compression, sp)
 	if err != nil {
-		return nil, err
+		return err
 	}

 	if sp != nil {
@@ -173,14 +174,14 @@ func ParseProtoReader(ctx context.Context, reader io.Reader, expectedSize, maxSize int, dst []byte, req proto.Message, compression CompressionType) ([]byte, error) {
 			sp.LogFields(otlog.Event("util.ParseProtoReader[unmarshal done]"), otlog.Error(err))
 		}

-		return nil, err
+		return err
 	}

 	if sp != nil {
 		sp.LogFields(otlog.Event("util.ParseProtoReader[unmarshal done]"))
 	}

-	return body, nil
+	return nil
 }

 type MsgSizeTooLargeErr struct {
@@ -198,7 +199,7 @@ func (e MsgSizeTooLargeErr) Is(err error) bool {
 	return ok1 || ok2
 }

-func decompressRequest(dst []byte, reader io.Reader, expectedSize, maxSize int, compression CompressionType, sp opentracing.Span) ([]byte, error) {
+func decompressRequest(buffers *RequestBuffers, reader io.Reader, expectedSize, maxSize int, compression CompressionType, sp opentracing.Span) ([]byte, error) {
 	if expectedSize > maxSize {
 		return nil, MsgSizeTooLargeErr{Actual: expectedSize, Limit: maxSize}
 	}
@@ -216,7 +217,7 @@ func decompressRequest(dst []byte, reader io.Reader, expectedSize, maxSize int, compression CompressionType, sp opentracing.Span) ([]byte, error) {
 			return buf.Bytes(), nil
 		}

-		return decompressSnappyFromBuffer(dst, buf, maxSize, sp)
+		return decompressSnappyFromBuffer(buffers, buf, maxSize, sp)
 	}
 }
@@ -234,10 +235,13 @@ func decompressRequest(dst []byte, reader io.Reader, expectedSize, maxSize int, compression CompressionType, sp opentracing.Span) ([]byte, error) {
 	// Limit at maxSize+1 so we can tell when the size is exceeded
 	reader = io.LimitReader(reader, int64(maxSize)+1)
-	var buf bytes.Buffer
-	if expectedSize > 0 {
-		buf.Grow(expectedSize + bytes.MinRead) // extra space guarantees no reallocation
+
+	sz := expectedSize
+	if sz > 0 {
+		// Extra space guarantees no reallocation
+		sz += bytes.MinRead
 	}
+	buf := buffers.Get(sz)
 	if _, err := buf.ReadFrom(reader); err != nil {
 		if compression == Gzip {
 			return nil, errors.Wrap(err, "decompress gzip")
 		}
@@ -246,7 +250,7 @@ func decompressRequest(dst []byte, reader io.Reader, expectedSize, maxSize int, compression CompressionType, sp opentracing.Span) ([]byte, error) {
 	}

 	if compression == RawSnappy {
-		return decompressSnappyFromBuffer(dst, &buf, maxSize, sp)
+		return decompressSnappyFromBuffer(buffers, buf, maxSize, sp)
 	}

 	if buf.Len() > maxSize {
@@ -255,7 +259,7 @@
 	return buf.Bytes(), nil
 }

-func decompressSnappyFromBuffer(dst []byte, buffer *bytes.Buffer, maxSize int, sp opentracing.Span) ([]byte, error) {
+func decompressSnappyFromBuffer(buffers *RequestBuffers, buffer *bytes.Buffer, maxSize int, sp opentracing.Span) ([]byte, error) {
 	if sp != nil {
 		sp.LogFields(otlog.Event("util.ParseProtoReader[decompressSnappy]"), otlog.Int("size", buffer.Len()))
 	}
@@ -267,11 +271,17 @@ func decompressSnappyFromBuffer(dst []byte, buffer *bytes.Buffer, maxSize int, sp opentracing.Span) ([]byte, error) {
 	if size > maxSize {
 		return nil, MsgSizeTooLargeErr{Actual: size, Limit: maxSize}
 	}
-	body, err := snappy.Decode(dst, buffer.Bytes())
+
+	decBuf := buffers.Get(size)
+	// snappy.Decode uses the length of dst, not its capacity, so extend the slice to the decoded size
+	decBufBytes := decBuf.Bytes()[0:size]
+
+	decoded, err := snappy.Decode(decBufBytes, buffer.Bytes())
 	if err != nil {
 		return nil, errors.Wrap(err, "decompress snappy")
 	}
-	return body, nil
+
+	return decoded, nil
 }

 // tryBufferFromReader attempts to cast the reader to a `*bytes.Buffer` this is possible when using httpgrpc.
@@ -374,3 +384,46 @@ func copyValues(src url.Values) url.Values {
 func IsHTTPStatusCode(code codes.Code) bool {
 	return int(code) >= 100 && int(code) < 600
 }
+
+// RequestBuffers provides pooled request buffers.
+type RequestBuffers struct {
+	p       *sync.Pool
+	buffers []*bytes.Buffer
+	// Allows avoiding heap allocation
+	buffersBacking [10]*bytes.Buffer
+}
+
+// NewRequestBuffers returns a new RequestBuffers given a sync.Pool.
+func NewRequestBuffers(p *sync.Pool) *RequestBuffers {
+	rb := &RequestBuffers{
+		p: p,
+	}
+	rb.buffers = rb.buffersBacking[:0]
+	return rb
+}
+
+// Get obtains a buffer from the pool. It is returned to the pool when CleanUp is called.
+func (rb *RequestBuffers) Get(size int) *bytes.Buffer {
+	if rb == nil {
+		if size < 0 {
+			size = 0
+		}
+		return bytes.NewBuffer(make([]byte, 0, size))
+	}
+
+	b := rb.p.Get().(*bytes.Buffer)
+	b.Reset()
+	if size > 0 {
+		b.Grow(size)
+	}
+	rb.buffers = append(rb.buffers, b)
+	return b
+}
+
+// CleanUp releases buffers back to the pool.
+func (rb *RequestBuffers) CleanUp() {
+	for _, b := range rb.buffers {
+		rb.p.Put(b)
+	}
+	rb.buffers = rb.buffers[:0]
+}
diff --git a/pkg/util/http_test.go b/pkg/util/http_test.go
index e7a36647992..f92f88c303f 100644
--- a/pkg/util/http_test.go
+++ b/pkg/util/http_test.go
@@ -205,7 +205,7 @@ func TestParseProtoReader(t *testing.T) {
 				reader = bytesBuffered{Buffer: &buf}
 			}

-			_, err := util.ParseProtoReader(context.Background(), reader, 0, tt.maxSize, nil, &fromWire, tt.compression)
+			err := util.ParseProtoReader(context.Background(), reader, 0, tt.maxSize, nil, &fromWire, tt.compression)
 			if tt.expectErr {
 				require.Error(t, err)
 				return
diff --git a/tools/trafficdump/parser.go b/tools/trafficdump/parser.go
index 3899a0b4cc4..ccfaf323113 100644
--- a/tools/trafficdump/parser.go
+++ b/tools/trafficdump/parser.go
@@ -134,7 +134,9 @@ func (rp *parser) processHTTPRequest(req *http.Request, body []byte) *request {

 	if rp.decodePush && req.Method == "POST" && strings.Contains(req.URL.Path, "/push") {
 		var matched bool
-		r.PushRequest, r.cleanup, matched = rp.decodePushRequest(req, body, rp.matchers)
+		rb := util.NewRequestBuffers(&bufferPool)
+		r.cleanup = rb.CleanUp
+		r.PushRequest, matched = rp.decodePushRequest(req, body, rp.matchers, rb)
 		if !matched {
 			r.ignored = true
 		}
@@ -152,35 +154,19 @@ func (rp *parser) processHTTPRequest(req *http.Request, body []byte) *request {
 	return &r
 }

-// Wrap a slice in a struct so we can store a pointer in sync.Pool
-type bufHolder struct {
-	buf []byte
-}
-
 var bufferPool = sync.Pool{
-	New: func() interface{} { return &bufHolder{buf: make([]byte, 256*1024)} },
+	New: func() any {
+		return bytes.NewBuffer(make([]byte, 0, 256*1024))
+	},
 }

-func (rp *parser) decodePushRequest(req *http.Request, body []byte, matchers []*labels.Matcher) (*pushRequest, func(), bool) {
+func (rp *parser) decodePushRequest(req *http.Request, body []byte, matchers []*labels.Matcher, buffers *util.RequestBuffers) (*pushRequest, bool) {
 	res := &pushRequest{Version: req.Header.Get("X-Prometheus-Remote-Write-Version")}

-	bufHolder := bufferPool.Get().(*bufHolder)
-
-	cleanup := func() {
-		bufferPool.Put(bufHolder)
-	}
-
 	var wr mimirpb.WriteRequest
-	buf, err := util.ParseProtoReader(context.Background(), bytes.NewReader(body), int(req.ContentLength), 100<<20, bufHolder.buf, &wr, util.RawSnappy)
-	if err != nil {
-		cleanup()
+	if err := util.ParseProtoReader(context.Background(), bytes.NewReader(body), int(req.ContentLength), 100<<20, buffers, &wr, util.RawSnappy); err != nil {
 		res.Error = fmt.Errorf("failed to decode decodePush request: %s", err).Error()
-		return nil, nil, true
-	}
-
-	// If decoding allocated a bigger buffer, put that one back in the pool.
-	if len(buf) > len(bufHolder.buf) {
-		bufHolder.buf = buf
+		return nil, true
 	}

 	// See if we find the matching series. If not, we ignore this request.
@@ -195,8 +181,7 @@ func (rp *parser) decodePushRequest(req *http.Request, body []byte, matchers []*labels.Matcher) (*pushRequest, func(), bool) {
 		}

 		if !matched {
-			cleanup()
-			return nil, nil, false
+			return nil, false
 		}
 	}
@@ -218,7 +203,7 @@

 	res.Metadata = wr.Metadata

-	return res, cleanup, true
+	return res, true
 }

 func matches(lbls labels.Labels, matchers []*labels.Matcher) bool {
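
For reviewers, here is a minimal sketch of the calling pattern the new `util.RequestBuffers` API is designed for, mirroring the pool setup in `pkg/distributor/push.go`. The `main` wrapper, the requested size, and the written payload are illustrative only:

```go
package main

import (
	"bytes"
	"fmt"
	"sync"

	"github.com/grafana/mimir/pkg/util"
)

// A pool of *bytes.Buffer, matching the bufferPool declared in pkg/distributor/push.go.
var bufferPool = sync.Pool{
	New: func() any {
		return bytes.NewBuffer(make([]byte, 0, 256*1024))
	},
}

func main() {
	// One RequestBuffers per request; it records every buffer it hands out.
	rb := util.NewRequestBuffers(&bufferPool)
	// CleanUp returns all obtained buffers to the pool in one call, replacing
	// the old pattern of threading a single []byte through the parser.
	defer rb.CleanUp()

	// Get returns a reset buffer grown to at least the requested size,
	// so reads up to that size do not reallocate.
	buf := rb.Get(1024)
	buf.WriteString("request body")
	fmt.Println(buf.Len()) // 12
}
```

Note that a nil `*RequestBuffers` is also valid: `Get` then falls back to a plain, unpooled allocation, which is why call sites such as `pkg/querier/remote_read.go` and the tests can keep passing `nil` to `util.ParseProtoReader`.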