diff --git a/objstore.go b/objstore.go index c913068..a83f7d2 100644 --- a/objstore.go +++ b/objstore.go @@ -551,6 +551,7 @@ func (b *metricBucket) Get(ctx context.Context, name string) (io.ReadCloser, err return nil, err } return newTimingReader( + ctx, rc, true, op, @@ -574,6 +575,7 @@ func (b *metricBucket) GetRange(ctx context.Context, name string, off, length in return nil, err } return newTimingReader( + ctx, rc, true, op, @@ -606,6 +608,7 @@ func (b *metricBucket) Upload(ctx context.Context, name string, r io.Reader) err b.ops.WithLabelValues(op).Inc() trc := newTimingReader( + ctx, r, false, op, @@ -663,6 +666,8 @@ func (b *metricBucket) Name() string { type timingReader struct { io.Reader + ctx context.Context + // closeReader holds whether the wrapper io.Reader should be closed when // Close() is called on the timingReader. closeReader bool @@ -682,7 +687,7 @@ type timingReader struct { transferredBytes *prometheus.HistogramVec } -func newTimingReader(r io.Reader, closeReader bool, op string, dur *prometheus.HistogramVec, failed *prometheus.CounterVec, isFailureExpected IsOpFailureExpectedFunc, fetchedBytes *prometheus.CounterVec, transferredBytes *prometheus.HistogramVec) io.ReadCloser { +func newTimingReader(ctx context.Context, r io.Reader, closeReader bool, op string, dur *prometheus.HistogramVec, failed *prometheus.CounterVec, isFailureExpected IsOpFailureExpectedFunc, fetchedBytes *prometheus.CounterVec, transferredBytes *prometheus.HistogramVec) io.ReadCloser { // Initialize the metrics with 0. dur.WithLabelValues(op) failed.WithLabelValues(op) @@ -690,6 +695,7 @@ func newTimingReader(r io.Reader, closeReader bool, op string, dur *prometheus.H trc := timingReader{ Reader: r, + ctx: ctx, closeReader: closeReader, objSize: objSize, objSizeErr: objSizeErr, @@ -756,7 +762,7 @@ func (r *timingReader) Read(b []byte) (n int, err error) { r.readBytes += int64(n) // Report metric just once. if !r.alreadyGotErr && err != nil && err != io.EOF { - if !r.isFailureExpected(err) { + if !r.isFailureExpected(err) && r.ctx.Err() != context.Canceled { r.failed.WithLabelValues(r.op).Inc() } r.alreadyGotErr = true diff --git a/objstore_test.go b/objstore_test.go index f83073c..35a5d88 100644 --- a/objstore_test.go +++ b/objstore_test.go @@ -412,7 +412,7 @@ func TestDownloadUploadDirConcurrency(t *testing.T) { func TestTimingReader(t *testing.T) { m := WrapWithMetrics(NewInMemBucket(), nil, "") r := bytes.NewReader([]byte("hello world")) - tr := newTimingReader(r, true, OpGet, m.opsDuration, m.opsFailures, func(err error) bool { + tr := newTimingReader(context.Background(), r, true, OpGet, m.opsDuration, m.opsFailures, func(err error) bool { return false }, m.opsFetchedBytes, m.opsTransferredBytes) @@ -447,7 +447,7 @@ func TestTimingReader_ExpectedError(t *testing.T) { m := WrapWithMetrics(NewInMemBucket(), nil, "") r := dummyReader{readerErr} - tr := newTimingReader(r, true, OpGet, m.opsDuration, m.opsFailures, func(err error) bool { return errors.Is(err, readerErr) }, m.opsFetchedBytes, m.opsTransferredBytes) + tr := newTimingReader(context.Background(), r, true, OpGet, m.opsDuration, m.opsFailures, func(err error) bool { return errors.Is(err, readerErr) }, m.opsFetchedBytes, m.opsTransferredBytes) buf := make([]byte, 1) _, err := io.ReadFull(tr, buf) @@ -461,7 +461,7 @@ func TestTimingReader_UnexpectedError(t *testing.T) { m := WrapWithMetrics(NewInMemBucket(), nil, "") r := dummyReader{readerErr} - tr := newTimingReader(r, true, OpGet, m.opsDuration, m.opsFailures, func(err error) bool { return false }, m.opsFetchedBytes, m.opsTransferredBytes) + tr := newTimingReader(context.Background(), r, true, OpGet, m.opsDuration, m.opsFailures, func(err error) bool { return false }, m.opsFetchedBytes, m.opsTransferredBytes) buf := make([]byte, 1) _, err := io.ReadFull(tr, buf) @@ -471,13 +471,16 @@ func TestTimingReader_UnexpectedError(t *testing.T) { } func TestTimingReader_ContextCancellation(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + m := WrapWithMetrics(NewInMemBucket(), nil, "") - r := dummyReader{context.Canceled} - tr := newTimingReader(r, true, OpGet, m.opsDuration, m.opsFailures, func(err error) bool { return false }, m.opsFetchedBytes, m.opsTransferredBytes) + r := dummyReader{ctx.Err()} + tr := newTimingReader(ctx, r, true, OpGet, m.opsDuration, m.opsFailures, func(err error) bool { return false }, m.opsFetchedBytes, m.opsTransferredBytes) buf := make([]byte, 1) _, err := io.ReadFull(tr, buf) - testutil.Equals(t, context.Canceled, err) + testutil.Equals(t, ctx.Err(), err) testutil.Equals(t, float64(0), promtest.ToFloat64(m.opsFailures.WithLabelValues(OpGet))) } @@ -503,7 +506,7 @@ func TestTimingReader_ShouldCorrectlyWrapFile(t *testing.T) { }) m := WrapWithMetrics(NewInMemBucket(), nil, "") - r := newTimingReader(file, true, "", m.opsDuration, m.opsFailures, func(err error) bool { + r := newTimingReader(context.Background(), file, true, "", m.opsDuration, m.opsFailures, func(err error) bool { return false }, m.opsFetchedBytes, m.opsTransferredBytes)