Skip to content

Commit

Permalink
internal/grpcsync: refactor test (#6427)
Browse files Browse the repository at this point in the history
  • Loading branch information
arvindbr8 committed Jun 30, 2023
1 parent 51042db commit acbfcbb
Showing 1 changed file with 45 additions and 66 deletions.
111 changes: 45 additions & 66 deletions internal/grpcsync/pubsub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,44 +19,26 @@
package grpcsync

import (
"fmt"
"sync"
"testing"
"time"

"github.com/google/go-cmp/cmp"
)

type testSubscriber struct {
mu sync.Mutex
msgs []int
onMsgCh chan struct{}
onMsgCh chan int
}

func newTestSubscriber(chSize int) *testSubscriber {
return &testSubscriber{onMsgCh: make(chan struct{}, chSize)}
return &testSubscriber{onMsgCh: make(chan int, chSize)}
}

func (ts *testSubscriber) OnMessage(msg interface{}) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.msgs = append(ts.msgs, msg.(int))
select {
case ts.onMsgCh <- struct{}{}:
case ts.onMsgCh <- msg.(int):
default:
}
}

func (ts *testSubscriber) receivedMsgs() []int {
ts.mu.Lock()
defer ts.mu.Unlock()

msgs := make([]int, len(ts.msgs))
copy(msgs, ts.msgs)

return msgs
}

func (s) TestPubSub_PublishNoMsg(t *testing.T) {
pubsub := NewPubSub()
defer pubsub.Stop()
Expand All @@ -66,7 +48,7 @@ func (s) TestPubSub_PublishNoMsg(t *testing.T) {

select {
case <-ts.onMsgCh:
t.Fatalf("Subscriber callback invoked when no message was published")
t.Fatal("Subscriber callback invoked when no message was published")
case <-time.After(defaultTestShortTimeout):
}
}
Expand All @@ -78,95 +60,92 @@ func (s) TestPubSub_PublishMsgs_RegisterSubs_And_Stop(t *testing.T) {

ts1 := newTestSubscriber(numPublished)
pubsub.Subscribe(ts1)
wantMsgs1 := []int{}

var wg sync.WaitGroup
wg.Add(2)
// Publish ten messages on the pubsub and ensure that they are received in order by the subscriber.
go func() {
for i := 0; i < numPublished; i++ {
pubsub.Publish(i)
wantMsgs1 = append(wantMsgs1, i)
}
wg.Done()
}()

isTimeout := false
go func() {
defer wg.Done()
for i := 0; i < numPublished; i++ {
select {
case <-ts1.onMsgCh:
case m := <-ts1.onMsgCh:
if m != i {
t.Errorf("Received unexpected message: %q; want: %q", m, i)
return
}
case <-time.After(defaultTestTimeout):
isTimeout = true
t.Error("Timeout when expecting the onMessage() callback to be invoked")
return
}
}
wg.Done()
}()

wg.Wait()
if isTimeout {
t.Fatalf("Timeout when expecting the onMessage() callback to be invoked")
}
if gotMsgs1 := ts1.receivedMsgs(); !cmp.Equal(gotMsgs1, wantMsgs1) {
t.Fatalf("Received messages is %v, want %v", gotMsgs1, wantMsgs1)
if t.Failed() {
t.FailNow()
}

// Register another subscriber and ensure that it receives the last published message.
ts2 := newTestSubscriber(numPublished)
pubsub.Subscribe(ts2)
wantMsgs2 := wantMsgs1[len(wantMsgs1)-1:]

select {
case <-ts2.onMsgCh:
case m := <-ts2.onMsgCh:
if m != numPublished-1 {
t.Fatalf("Received unexpected message: %q; want: %q", m, numPublished-1)
}
case <-time.After(defaultTestShortTimeout):
t.Fatalf("Timeout when expecting the onMessage() callback to be invoked")
}
if gotMsgs2 := ts2.receivedMsgs(); !cmp.Equal(gotMsgs2, wantMsgs2) {
t.Fatalf("Received messages is %v, want %v", gotMsgs2, wantMsgs2)
t.Fatal("Timeout when expecting the onMessage() callback to be invoked")
}

wg.Add(3)
// Publish ten messages on the pubsub and ensure that they are received in order by the subscribers.
go func() {
for i := 0; i < numPublished; i++ {
pubsub.Publish(i)
wantMsgs1 = append(wantMsgs1, i)
wantMsgs2 = append(wantMsgs2, i)
}
wg.Done()
}()
errCh := make(chan error, 1)
go func() {
defer wg.Done()
for i := 0; i < numPublished; i++ {
select {
case <-ts1.onMsgCh:
case m := <-ts1.onMsgCh:
if m != i {
t.Errorf("Received unexpected message: %q; want: %q", m, i)
return
}
case <-time.After(defaultTestTimeout):
errCh <- fmt.Errorf("")
t.Error("Timeout when expecting the onMessage() callback to be invoked")
return
}
}
wg.Done()

}()
go func() {
defer wg.Done()
for i := 0; i < numPublished; i++ {
select {
case <-ts2.onMsgCh:
case m := <-ts2.onMsgCh:
if m != i {
t.Errorf("Received unexpected message: %q; want: %q", m, i)
return
}
case <-time.After(defaultTestTimeout):
errCh <- fmt.Errorf("")
t.Error("Timeout when expecting the onMessage() callback to be invoked")
return
}
}
wg.Done()
}()
wg.Wait()
select {
case <-errCh:
t.Fatalf("Timeout when expecting the onMessage() callback to be invoked")
default:
}
if gotMsgs1 := ts1.receivedMsgs(); !cmp.Equal(gotMsgs1, wantMsgs1) {
t.Fatalf("Received messages is %v, want %v", gotMsgs1, wantMsgs1)
}
if gotMsgs2 := ts2.receivedMsgs(); !cmp.Equal(gotMsgs2, wantMsgs2) {
t.Fatalf("Received messages is %v, want %v", gotMsgs2, wantMsgs2)
if t.Failed() {
t.FailNow()
}

pubsub.Stop()
Expand All @@ -178,9 +157,9 @@ func (s) TestPubSub_PublishMsgs_RegisterSubs_And_Stop(t *testing.T) {
// pubsub has already closed.
select {
case <-ts1.onMsgCh:
t.Fatalf("The callback was invoked after pubsub being stopped")
t.Fatal("The callback was invoked after pubsub being stopped")
case <-ts2.onMsgCh: