diff --git a/internal/internal_integration/interrupt_and_timeout_test.go b/internal/internal_integration/interrupt_and_timeout_test.go index 38b5dbba2..b3d2b170c 100644 --- a/internal/internal_integration/interrupt_and_timeout_test.go +++ b/internal/internal_integration/interrupt_and_timeout_test.go @@ -2,6 +2,7 @@ package internal_integration_test import ( "context" + "strconv" "sync" "time" @@ -1013,71 +1014,95 @@ var _ = Describe("Interrupts and Timeouts", func() { }) Describe("passing contexts to TableEntries", func() { - var times *TimeMap - BeforeEach(func() { - times = NewTimeMap() + Describe("the happy path", func() { + var times *TimeMap + BeforeEach(func() { + times = NewTimeMap() - success, _ := RunFixture(CurrentSpecReport().LeafNodeText, func() { - Context("container", func() { - DescribeTable("timeout table", - func(c SpecContext, d context.Context, key string) { - key = d.Value("key").(string) + key - rt.Run(CurrentSpecReport().LeafNodeText) - t := time.Now() - <-c.Done() - times.Set(key, time.Since(t)) - }, - func(d context.Context, key string) string { - key = d.Value("key").(string) + key - return key - }, - Entry(nil, context.WithValue(context.Background(), "key", "entry-"), "1", NodeTimeout(time.Millisecond)*100), - Entry(nil, context.WithValue(context.Background(), "key", "entry-"), "2", SpecTimeout(time.Millisecond)*150), - ) - - DescribeTable("timeout table", - func(c context.Context, key string) { - rt.Run(CurrentSpecReport().LeafNodeText) - t := time.Now() - <-c.Done() - times.Set(key, time.Since(t)) - }, - func(key string) string { - return key - }, - Entry(nil, "entry-3", NodeTimeout(time.Millisecond)*100), - Entry(nil, "entry-4", SpecTimeout(time.Millisecond)*150), - ) - - DescribeTable("timeout table", - func(c context.Context, key string) { - key = c.Value("key").(string) + key - rt.Run(CurrentSpecReport().LeafNodeText + "-" + key) - }, - func(d context.Context, key string) string { - key = d.Value("key").(string) + key - return key - }, - Entry(nil, context.WithValue(context.Background(), "key", "entry-"), "5"), - Entry(nil, context.WithValue(context.Background(), "key", "entry-"), "6"), - ) + success, _ := RunFixture(CurrentSpecReport().LeafNodeText, func() { + Context("container", func() { + DescribeTable("timeout table", + func(c SpecContext, d context.Context, key string) { + key = d.Value("key").(string) + key + rt.Run(CurrentSpecReport().LeafNodeText) + t := time.Now() + <-c.Done() + times.Set(key, time.Since(t)) + }, + func(d context.Context, key string) string { + key = d.Value("key").(string) + key + return key + }, + Entry(nil, context.WithValue(context.Background(), "key", "entry-"), "1", NodeTimeout(time.Millisecond)*100), + Entry(nil, context.WithValue(context.Background(), "key", "entry-"), "2", SpecTimeout(time.Millisecond)*150), + ) + + DescribeTable("timeout table", + func(c context.Context, key string) { + rt.Run(CurrentSpecReport().LeafNodeText) + t := time.Now() + <-c.Done() + times.Set(key, time.Since(t)) + }, + func(key string) string { + return key + }, + Entry(nil, "entry-3", NodeTimeout(time.Millisecond)*100), + Entry(nil, "entry-4", SpecTimeout(time.Millisecond)*150), + ) + + DescribeTable("timeout table", + func(c context.Context, key string) { + key = c.Value("key").(string) + key + rt.Run(CurrentSpecReport().LeafNodeText + "-" + key) + }, + func(d context.Context, key string) string { + key = d.Value("key").(string) + key + return key + }, + Entry(nil, context.WithValue(context.Background(), "key", "entry-"), "5"), + Entry(nil, context.WithValue(context.Background(), "key", "entry-"), "6"), + ) + }) }) + Ω(success).Should(Equal(false)) + }) + + It("should work", func() { + Ω(rt).Should(HaveTracked("entry-1", "entry-2", "entry-3", "entry-4", "entry-5-entry-5", "entry-6-entry-6")) + Ω(reporter.Did.Find("entry-1")).Should(HaveTimedOut()) + Ω(reporter.Did.Find("entry-2")).Should(HaveTimedOut()) + Ω(reporter.Did.Find("entry-3")).Should(HaveTimedOut()) + Ω(reporter.Did.Find("entry-4")).Should(HaveTimedOut()) + Ω(reporter.Did.Find("entry-1").Failure.ProgressReport.CurrentNodeType).Should(Equal(types.NodeTypeIt)) + + Ω(times.Get("entry-1")).Should(BeNumerically("~", 100*time.Millisecond, 50*time.Millisecond)) + Ω(times.Get("entry-2")).Should(BeNumerically("~", 150*time.Millisecond, 50*time.Millisecond)) + Ω(times.Get("entry-3")).Should(BeNumerically("~", 100*time.Millisecond, 50*time.Millisecond)) + Ω(times.Get("entry-4")).Should(BeNumerically("~", 150*time.Millisecond, 50*time.Millisecond)) }) - Ω(success).Should(Equal(false)) }) - It("should work", func() { - Ω(rt).Should(HaveTracked("entry-1", "entry-2", "entry-3", "entry-4", "entry-5-entry-5", "entry-6-entry-6")) - Ω(reporter.Did.Find("entry-1")).Should(HaveTimedOut()) - Ω(reporter.Did.Find("entry-2")).Should(HaveTimedOut()) - Ω(reporter.Did.Find("entry-3")).Should(HaveTimedOut()) - Ω(reporter.Did.Find("entry-4")).Should(HaveTimedOut()) - Ω(reporter.Did.Find("entry-1").Failure.ProgressReport.CurrentNodeType).Should(Equal(types.NodeTypeIt)) - - Ω(times.Get("entry-1")).Should(BeNumerically("~", 100*time.Millisecond, 50*time.Millisecond)) - Ω(times.Get("entry-2")).Should(BeNumerically("~", 150*time.Millisecond, 50*time.Millisecond)) - Ω(times.Get("entry-3")).Should(BeNumerically("~", 100*time.Millisecond, 50*time.Millisecond)) - Ω(times.Get("entry-4")).Should(BeNumerically("~", 150*time.Millisecond, 50*time.Millisecond)) + Describe("the edge case in #1415", func() { + var four = 4 + var nSix = -6 + DescribeTable("it supports receiving a SpecContext and works with nil parameters", func(ctx context.Context, num *int, s string, cVal string) { + Ω(ctx).ShouldNot(BeNil()) + if num == nil { + Ω(s).Should(Equal("nil")) + } else { + Ω(s).Should(Equal(strconv.Itoa(*num))) + } + if cVal != "" { + Ω(ctx.Value("key")).Should(Equal(cVal)) + } + }, + Entry("4", &four, "4", ""), + Entry("-6", &nSix, "-6", ""), + Entry("nil", nil, "nil", ""), + Entry("4 with context value", context.WithValue(context.Background(), "key", "val"), &four, "4", "val"), + Entry("nil with context value", context.WithValue(context.Background(), "key", "val"), nil, "nil", "val"), + ) }) }) }) diff --git a/internal/internal_integration/table_test.go b/internal/internal_integration/table_test.go index bf954bb63..c8d01b455 100644 --- a/internal/internal_integration/table_test.go +++ b/internal/internal_integration/table_test.go @@ -348,6 +348,7 @@ var _ = Describe("Table driven tests", func() { Ω(b).Should(Equal(c[1])) Ω(c[2]).Should(BeNil()) }, Entry("variadic arguments", 1, "one", 1, "one", nil)) + }) Describe("when table entries are marked pending", func() { diff --git a/table_dsl.go b/table_dsl.go index a3aef821b..c7de7a8be 100644 --- a/table_dsl.go +++ b/table_dsl.go @@ -269,11 +269,15 @@ func generateTable(description string, isSubtree bool, args ...interface{}) { internalNodeArgs = append(internalNodeArgs, entry.decorations...) hasContext := false - if internalBodyType.NumIn() > 0. { + if internalBodyType.NumIn() > 0 { if internalBodyType.In(0).Implements(specContextType) { hasContext = true - } else if internalBodyType.In(0).Implements(contextType) && (len(entry.parameters) == 0 || !reflect.TypeOf(entry.parameters[0]).Implements(contextType)) { + } else if internalBodyType.In(0).Implements(contextType) { hasContext = true + if len(entry.parameters) > 0 && reflect.TypeOf(entry.parameters[0]) != nil && reflect.TypeOf(entry.parameters[0]).Implements(contextType) { + // we allow you to pass in a non-nil context + hasContext = false + } } }