Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

concurrency.ForEachJob() #113

Merged
merged 9 commits into from
Jan 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
* [CHANGE] grpcutil: Convert Resolver into concrete type. #105
* [CHANGE] grpcutil.Resolver.Resolve: Take a service parameter. #102
* [CHANGE] grpcutil.Update: Remove gRPC LB related metadata. #102
* [CHANGE] concurrency.ForEach: deprecated and reimplemented by new `concurrency.ForEachJob`. #113
* [ENHANCEMENT] Add middleware package. #38
* [ENHANCEMENT] Add the ring package #45
* [ENHANCEMENT] Add limiter package. #41
Expand Down
55 changes: 32 additions & 23 deletions concurrency/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"sync"

"go.uber.org/atomic"
"golang.org/x/sync/errgroup"

"github.com/grafana/dskit/internal/math"
Expand Down Expand Up @@ -62,45 +63,53 @@ func ForEachUser(ctx context.Context, userIDs []string, concurrency int, userFun

// ForEach runs the provided jobFunc for each job up to concurrency concurrent workers.
// The execution breaks on first error encountered.
//
// Deprecated: use ForEachJob instead.
func ForEach(ctx context.Context, jobs []interface{}, concurrency int, jobFunc func(ctx context.Context, job interface{}) error) error {
if len(jobs) == 0 {
return nil
return ForEachJob(ctx, len(jobs), concurrency, func(ctx context.Context, idx int) error {
return jobFunc(ctx, jobs[idx])
})
}

// CreateJobsFromStrings is an utility to create jobs from an slice of strings.
//
// Deprecated: will be removed as it's not needed when using ForEachJob.
func CreateJobsFromStrings(values []string) []interface{} {
jobs := make([]interface{}, len(values))
for i := 0; i < len(values); i++ {
jobs[i] = values[i]
}
return jobs
}

// Push all jobs to a channel.
ch := make(chan interface{}, len(jobs))
for _, job := range jobs {
ch <- job
// ForEachJob runs the provided jobFunc for each job index in [0, jobs) up to concurrency concurrent workers.
// The execution breaks on first error encountered.
func ForEachJob(ctx context.Context, jobs int, concurrency int, jobFunc func(ctx context.Context, idx int) error) error {
colega marked this conversation as resolved.
Show resolved Hide resolved
if jobs == 0 {
return nil
}
close(ch)

// Initialise indexes with -1 so first Inc() returns index 0.
indexes := atomic.NewInt64(-1)

// Start workers to process jobs.
g, ctx := errgroup.WithContext(ctx)
for ix := 0; ix < math.Min(concurrency, len(jobs)); ix++ {
for ix := 0; ix < math.Min(concurrency, jobs); ix++ {
g.Go(func() error {
for job := range ch {
if err := ctx.Err(); err != nil {
return err
for ctx.Err() == nil {
idx := int(indexes.Inc())
if idx >= jobs {
return nil
}

if err := jobFunc(ctx, job); err != nil {
if err := jobFunc(ctx, idx); err != nil {
return err
}
}

return nil
return ctx.Err()
})
}

// Wait until done (or context has canceled).
return g.Wait()
}

// CreateJobsFromStrings is an utility to create jobs from an slice of strings.
func CreateJobsFromStrings(values []string) []interface{} {
jobs := make([]interface{}, len(values))
for i := 0; i < len(values); i++ {
jobs[i] = values[i]
}
return jobs
}
110 changes: 90 additions & 20 deletions concurrency/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,14 @@ import (

func TestForEachUser(t *testing.T) {
var (
ctx = context.Background()

// Keep track of processed users.
processedMx sync.Mutex
processed []string
)

input := []string{"a", "b", "c"}

err := ForEachUser(ctx, input, 2, func(ctx context.Context, user string) error {
err := ForEachUser(context.Background(), input, 2, func(ctx context.Context, user string) error {
processedMx.Lock()
defer processedMx.Unlock()
processed = append(processed, user)
Expand All @@ -35,16 +33,12 @@ func TestForEachUser(t *testing.T) {
}

func TestForEachUser_ShouldContinueOnErrorButReturnIt(t *testing.T) {
var (
ctx = context.Background()

// Keep the processed users count.
processed atomic.Int32
)
// Keep the processed users count.
var processed atomic.Int32

input := []string{"a", "b", "c"}

err := ForEachUser(ctx, input, 2, func(ctx context.Context, user string) error {
err := ForEachUser(context.Background(), input, 2, func(ctx context.Context, user string) error {
if processed.CAS(0, 1) {
return errors.New("the first request is failing")
}
Expand Down Expand Up @@ -72,18 +66,98 @@ func TestForEachUser_ShouldReturnImmediatelyOnNoUsersProvided(t *testing.T) {
}))
}

func TestForEachJob(t *testing.T) {
jobs := []string{"a", "b", "c"}
processed := make([]string, len(jobs))

err := ForEachJob(context.Background(), len(jobs), 2, func(ctx context.Context, idx int) error {
processed[idx] = jobs[idx]
return nil
})

require.NoError(t, err)
assert.ElementsMatch(t, jobs, processed)
}

func TestForEachJob_ShouldBreakOnFirstError_ContextCancellationHandled(t *testing.T) {
// Keep the processed jobs count.
var processed atomic.Int32

err := ForEachJob(context.Background(), 3, 2, func(ctx context.Context, idx int) error {
if processed.CAS(0, 1) {
return errors.New("the first request is failing")
}

// Wait 1s and increase the number of processed jobs, unless the context get canceled earlier.
select {
case <-time.After(time.Second):
processed.Add(1)
case <-ctx.Done():
return ctx.Err()
}

return nil
})

require.EqualError(t, err, "the first request is failing")

// Since we expect the first error interrupts the workers, we should only see
// 1 job processed (the one which immediately returned error).
assert.Equal(t, int32(1), processed.Load())
}

func TestForEachJob_ShouldBreakOnFirstError_ContextCancellationUnhandled(t *testing.T) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What exactly does this test checks as opposed to TestForEachJob_ShouldBreakOnFirstError_ContextCancellationHandled?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From what I see, TestForEachJob_ShouldBreakOnFirstError_ContextCancellationHandled is a simple test that returns an error and makes sure that context is canceled, without extra assertions on how many jobs were launched.

This TestForEachJob_ShouldBreakOnFirstError_ContextCancellationUnhandled makes sure that when two jobs (with concurrency=2) are launched, then the third job is not launched at all.

// Keep the processed jobs count.
var processed atomic.Int32

// waitGroup to await the start of the first two jobs
var wg sync.WaitGroup
wg.Add(2)

err := ForEachJob(context.Background(), 3, 2, func(ctx context.Context, idx int) error {
wg.Done()

if processed.CAS(0, 1) {
// wait till two jobs have been started
wg.Wait()
return errors.New("the first request is failing")
}

// Wait till context is cancelled to add processed jobs.
<-ctx.Done()
processed.Add(1)

return nil
})

require.EqualError(t, err, "the first request is failing")

// Since we expect the first error interrupts the workers, we should only
// see 2 job processed (the one which immediately returned error and the
// job with "b").
assert.Equal(t, int32(2), processed.Load())
}

func TestForEachJob_ShouldReturnImmediatelyOnNoJobsProvided(t *testing.T) {
// Keep the processed jobs count.
var processed atomic.Int32
require.NoError(t, ForEachJob(context.Background(), 0, 2, func(ctx context.Context, idx int) error {
processed.Inc()
return nil
}))
require.Zero(t, processed.Load())
}

func TestForEach(t *testing.T) {
var (
ctx = context.Background()

// Keep track of processed jobs.
processedMx sync.Mutex
processed []string
)

jobs := []string{"a", "b", "c"}

err := ForEach(ctx, CreateJobsFromStrings(jobs), 2, func(ctx context.Context, job interface{}) error {
err := ForEach(context.Background(), CreateJobsFromStrings(jobs), 2, func(ctx context.Context, job interface{}) error {
processedMx.Lock()
defer processedMx.Unlock()
processed = append(processed, job.(string))
Expand Down Expand Up @@ -126,18 +200,14 @@ func TestForEach_ShouldBreakOnFirstError_ContextCancellationHandled(t *testing.T
}

func TestForEach_ShouldBreakOnFirstError_ContextCancellationUnhandled(t *testing.T) {
var (
ctx = context.Background()

// Keep the processed jobs count.
processed atomic.Int32
)
// Keep the processed jobs count.
var processed atomic.Int32

// waitGroup to await the start of the first two jobs
var wg sync.WaitGroup
wg.Add(2)

err := ForEach(ctx, []interface{}{"a", "b", "c"}, 2, func(ctx context.Context, job interface{}) error {
err := ForEach(context.Background(), []interface{}{"a", "b", "c"}, 2, func(ctx context.Context, job interface{}) error {
wg.Done()

if processed.CAS(0, 1) {
Expand Down