diff --git a/CHANGELOG_PENDING.md b/CHANGELOG_PENDING.md index 8a1927a39ca..d10c15397c3 100644 --- a/CHANGELOG_PENDING.md +++ b/CHANGELOG_PENDING.md @@ -3,3 +3,5 @@ ### SDK Enhancements ### SDK Bugs +* Add missing bool error matching. + * This enables waiters defined to match on presence/absence of errors. diff --git a/aws/request/waiter.go b/aws/request/waiter.go index 4601f883cc5..992ed0464b9 100644 --- a/aws/request/waiter.go +++ b/aws/request/waiter.go @@ -256,8 +256,17 @@ func (a *WaiterAcceptor) match(name string, l aws.Logger, req *Request, err erro s := a.Expected.(int) result = s == req.HTTPResponse.StatusCode case ErrorWaiterMatch: - if aerr, ok := err.(awserr.Error); ok { - result = aerr.Code() == a.Expected.(string) + switch ex := a.Expected.(type) { + case string: + if aerr, ok := err.(awserr.Error); ok { + result = aerr.Code() == ex + } + case bool: + if ex { + result = err != nil + } else { + result = err == nil + } } default: waiterLogf(l, "WARNING: Waiter %s encountered unexpected matcher: %s", diff --git a/aws/request/waiter_test.go b/aws/request/waiter_test.go index 7a9f29d9f1d..d1ac3d13ead 100644 --- a/aws/request/waiter_test.go +++ b/aws/request/waiter_test.go @@ -386,6 +386,217 @@ func TestWaiterError(t *testing.T) { } } +func TestWaiterRetryAnyError(t *testing.T) { + svc := &mockClient{Client: awstesting.NewClient(&aws.Config{ + Region: aws.String("mock-region"), + })} + svc.Handlers.Send.Clear() // mock sending + svc.Handlers.Unmarshal.Clear() + svc.Handlers.UnmarshalMeta.Clear() + svc.Handlers.UnmarshalError.Clear() + svc.Handlers.ValidateResponse.Clear() + + var reqNum int + results := []struct { + Out *MockOutput + Err error + }{ + { // retry + Err: awserr.New( + "MockException1", "mock exception message", nil, + ), + }, + { // retry + Err: awserr.New( + "MockException2", "mock exception message", nil, + ), + }, + { // success + Out: &MockOutput{ + States: []*MockState{ + {aws.String("running")}, + {aws.String("running")}, + }, + }, + }, + { // shouldn't happen + Out: &MockOutput{ + States: []*MockState{ + {aws.String("running")}, + {aws.String("running")}, + }, + }, + }, + } + + numBuiltReq := 0 + svc.Handlers.Build.PushBack(func(r *request.Request) { + numBuiltReq++ + }) + svc.Handlers.Send.PushBack(func(r *request.Request) { + code := http.StatusOK + r.HTTPResponse = &http.Response{ + StatusCode: code, + Status: http.StatusText(code), + Body: ioutil.NopCloser(bytes.NewReader([]byte{})), + } + }) + svc.Handlers.Unmarshal.PushBack(func(r *request.Request) { + if reqNum >= len(results) { + t.Errorf("too many polling requests made") + return + } + r.Data = results[reqNum].Out + reqNum++ + }) + svc.Handlers.UnmarshalMeta.PushBack(func(r *request.Request) { + // If there was an error unmarshal error will be called instead of unmarshal + // need to increment count here also + if err := results[reqNum].Err; err != nil { + r.Error = err + reqNum++ + } + }) + + w := request.Waiter{ + MaxAttempts: 10, + Delay: request.ConstantWaiterDelay(0), + SleepWithContext: aws.SleepWithContext, + Acceptors: []request.WaiterAcceptor{ + { + State: request.SuccessWaiterState, + Matcher: request.PathAllWaiterMatch, + Argument: "States[].State", + Expected: "running", + }, + { + State: request.RetryWaiterState, + Matcher: request.ErrorWaiterMatch, + Argument: "", + Expected: true, + }, + { + State: request.FailureWaiterState, + Matcher: request.ErrorWaiterMatch, + Argument: "", + Expected: "FailureException", + }, + }, + NewRequest: BuildNewMockRequest(svc, &MockInput{}), + } + + err := w.WaitWithContext(aws.BackgroundContext()) + if err != nil { + t.Fatalf("expected no error, but did get one: %v", err) + } + if e, a := 3, numBuiltReq; e != a { + t.Errorf("expect %d built requests got %d", e, a) + } + if e, a := 3, reqNum; e != a { + t.Errorf("expect %d reqNum got %d", e, a) + } +} + +func TestWaiterSuccessNoError(t *testing.T) { + svc := &mockClient{Client: awstesting.NewClient(&aws.Config{ + Region: aws.String("mock-region"), + })} + svc.Handlers.Send.Clear() // mock sending + svc.Handlers.Unmarshal.Clear() + svc.Handlers.UnmarshalMeta.Clear() + svc.Handlers.UnmarshalError.Clear() + svc.Handlers.ValidateResponse.Clear() + + var reqNum int + results := []struct { + Out *MockOutput + Err error + }{ + { // success + Out: &MockOutput{ + States: []*MockState{ + {aws.String("pending")}, + }, + }, + }, + { // shouldn't happen + Out: &MockOutput{ + States: []*MockState{ + {aws.String("pending")}, + {aws.String("pending")}, + }, + }, + }, + } + + numBuiltReq := 0 + svc.Handlers.Build.PushBack(func(r *request.Request) { + numBuiltReq++ + }) + svc.Handlers.Send.PushBack(func(r *request.Request) { + code := http.StatusOK + r.HTTPResponse = &http.Response{ + StatusCode: code, + Status: http.StatusText(code), + Body: ioutil.NopCloser(bytes.NewReader([]byte{})), + } + }) + svc.Handlers.Unmarshal.PushBack(func(r *request.Request) { + if reqNum >= len(results) { + t.Errorf("too many polling requests made") + return + } + r.Data = results[reqNum].Out + reqNum++ + }) + svc.Handlers.UnmarshalMeta.PushBack(func(r *request.Request) { + // If there was an error unmarshal error will be called instead of unmarshal + // need to increment count here also + if err := results[reqNum].Err; err != nil { + r.Error = err + reqNum++ + } + }) + + w := request.Waiter{ + MaxAttempts: 10, + Delay: request.ConstantWaiterDelay(0), + SleepWithContext: aws.SleepWithContext, + Acceptors: []request.WaiterAcceptor{ + { + State: request.SuccessWaiterState, + Matcher: request.PathAllWaiterMatch, + Argument: "States[].State", + Expected: "running", + }, + { + State: request.SuccessWaiterState, + Matcher: request.ErrorWaiterMatch, + Argument: "", + Expected: false, + }, + { + State: request.FailureWaiterState, + Matcher: request.ErrorWaiterMatch, + Argument: "", + Expected: "FailureException", + }, + }, + NewRequest: BuildNewMockRequest(svc, &MockInput{}), + } + + err := w.WaitWithContext(aws.BackgroundContext()) + if err != nil { + t.Fatalf("expected no error, but did get one") + } + if e, a := 1, numBuiltReq; e != a { + t.Errorf("expect %d built requests got %d", e, a) + } + if e, a := 1, reqNum; e != a { + t.Errorf("expect %d reqNum got %d", e, a) + } +} + func TestWaiterStatus(t *testing.T) { svc := &mockClient{Client: awstesting.NewClient(&aws.Config{ Region: aws.String("mock-region"),