Skip to content

Commit

Permalink
Merge pull request #374 from redis/fix-rueidislock-infinite-loop
Browse files Browse the repository at this point in the history
fix: rueidislock's infinite loops when receiving bulk invalidations
  • Loading branch information
rueian committed Sep 22, 2023
2 parents 0608d8b + dbcc67f commit ebb261d
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 22 deletions.
44 changes: 23 additions & 21 deletions rueidislock/lock.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ type locker struct {
totalcnt int32
noloop bool
setpx bool
closed bool
}

type gate struct {
Expand Down Expand Up @@ -188,6 +187,10 @@ func (m *locker) waitgate(ctx context.Context, name string) (g *gate, err error)
m.mu.Lock()
g, ok := m.gates[name]
if !ok {
if m.gates == nil {
m.mu.Unlock()
return nil, ErrLockerClosed
}
g = makegate(m.totalcnt)
g.w++
m.gates[name] = g
Expand All @@ -200,7 +203,7 @@ func (m *locker) waitgate(ctx context.Context, name string) (g *gate, err error)
select {
case <-ctx.Done():
m.mu.Lock()
if g.w--; g.w == 0 {
if g.w--; g.w == 0 && m.gates[name] == g {
delete(m.gates, name)
}
m.mu.Unlock()
Expand All @@ -216,7 +219,7 @@ func (m *locker) waitgate(ctx context.Context, name string) (g *gate, err error)
func (m *locker) trygate(name string) (g *gate) {
m.mu.Lock()
_, ok := m.gates[name]
if !ok {
if !ok && m.gates != nil {
g = makegate(m.totalcnt)
g.w++
m.gates[name] = g
Expand All @@ -227,15 +230,20 @@ func (m *locker) trygate(name string) (g *gate) {

func (m *locker) onInvalidations(messages []rueidis.RedisMessage) {
if messages == nil {
m.mu.Lock()
m.mu.RLock()
for _, g := range m.gates {
close(g.ch)
select {
case g.ch <- struct{}{}:
default:
}
for _, ch := range g.csc {
close(ch)
select {
case ch <- struct{}{}:
default:
}
}
}
m.gates = make(map[string]*gate)
m.mu.Unlock()
m.mu.RUnlock()
}
for _, msg := range messages {
k, _ := msg.ToString()
Expand Down Expand Up @@ -306,7 +314,9 @@ func (m *locker) try(ctx context.Context, cancel context.CancelFunc, name string
}
m.mu.Lock()
if g.w--; g.w == 0 {
delete(m.gates, name)
if m.gates[name] == g {
delete(m.gates, name)
}
} else {
if g, ok := m.gates[name]; ok {
g.ch <- struct{}{}
Expand Down Expand Up @@ -375,27 +385,19 @@ func (m *locker) WithContext(ctx context.Context, name string) (context.Context,
}
}
if cancel(); err != nil {
if err == ErrLockerClosed {
m.mu.RLock()
closed := m.closed
m.mu.RUnlock()
if !closed {
continue
}
}
return ctx, cancel, err
}
}
}

func (m *locker) Close() {
m.mu.Lock()
m.closed = true
for _, g := range m.gates {
close(g.ch)
}
m.gates = nil
m.mu.Unlock()
m.client.Close()
if m.noloop {
m.onInvalidations(nil)
}
}

var (
Expand Down
50 changes: 49 additions & 1 deletion rueidislock/lock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"github.com/redis/rueidis"
)

var address = []string{"127.0.0.1:6376"}
var address = []string{"127.0.0.1:6379"}

func newLocker(t *testing.T, noLoop, setpx, nocsc bool) *locker {
impl, err := NewLocker(LockerOption{
Expand Down Expand Up @@ -451,6 +451,9 @@ func TestLocker_Close(t *testing.T) {
if err := ctx.Err(); !errors.Is(err, context.Canceled) {
t.Fatal(err)
}
if _, _, err := locker.WithContext(context.Background(), lck); err != ErrLockerClosed {
t.Error(err)
}
}
for _, nocsc := range []bool{false, true} {
t.Run("Tracking Loop", func(t *testing.T) {
Expand Down Expand Up @@ -509,3 +512,48 @@ func TestLocker_RetryErrLockerClosed(t *testing.T) {
})
}
}

func TestLocker_Flush(t *testing.T) {
test := func(t *testing.T, noLoop, setpx, nocsc bool) {
client, err := rueidis.NewClient(rueidis.ClientOption{InitAddress: address})
if err != nil {
t.Fatal(err)
}
defer client.Close()

locker := newLocker(t, noLoop, setpx, nocsc)

lck := strconv.Itoa(rand.Int())
ctx, _, err := locker.WithContext(context.Background(), lck)
if err != nil {
t.Fatal(err)
}

if err := client.Do(context.Background(), client.B().Flushall().Build()).Error(); err != nil {
t.Fatal(err)
}

<-ctx.Done()

if err := ctx.Err(); !errors.Is(err, context.Canceled) {
t.Fatal(err)
}

ctx, cancel, err := locker.WithContext(context.Background(), strconv.Itoa(rand.Int()))
if err != nil {
t.Fatal(err)
}
cancel()
}
for _, nocsc := range []bool{false, true} {
t.Run("Tracking Loop", func(t *testing.T) {
test(t, false, false, nocsc)
})
t.Run("Tracking NoLoop", func(t *testing.T) {
test(t, true, false, nocsc)
})
t.Run("SET PX", func(t *testing.T) {
test(t, true, true, nocsc)
})
}
}

0 comments on commit ebb261d

Please sign in to comment.