diff --git a/internal/aghalg/orderedmap.go b/internal/aghalg/orderedmap.go index ec915e50cc7..5e15cf0733f 100644 --- a/internal/aghalg/orderedmap.go +++ b/internal/aghalg/orderedmap.go @@ -5,7 +5,7 @@ import ( ) // SortedMap is a map that keeps elements in order with internal sorting -// function. +// function. Must be initialised by the [NewSortedMap]. type SortedMap[K comparable, V any] struct { vals map[K]V cmp func(a, b K) (res int) @@ -23,36 +23,50 @@ func NewSortedMap[K comparable, V any](cmp func(a, b K) (res int)) SortedMap[K, } } -// Set adds val with key to the sorted map. +// Set adds val with key to the sorted map. It panics if the m is nil. func (m *SortedMap[K, V]) Set(key K, val V) { + m.vals[key] = val + i, has := slices.BinarySearchFunc(m.keys, key, m.cmp) if has { m.keys[i] = key - m.vals[key] = val + } else { + m.keys = slices.Insert(m.keys, i, key) + } +} +// Get returns val by key from the sorted map. +func (m *SortedMap[K, V]) Get(key K) (val V, ok bool) { + if m == nil { return } - m.keys = slices.Insert(m.keys, i, key) - m.vals[key] = val -} + val, ok = m.vals[key] -// Get returns val by key from the sorted map. -func (m *SortedMap[K, V]) Get(key K) (val V) { - return m.vals[key] + return val, ok } // Del removes the value by key from the sorted map. func (m *SortedMap[K, V]) Del(key K) { - i, has := slices.BinarySearchFunc(m.keys, key, m.cmp) - if has { - m.keys = slices.Delete(m.keys, i, i+1) - delete(m.vals, key) + if m == nil { + return } + + if _, has := m.vals[key]; !has { + return + } + + delete(m.vals, key) + i, _ := slices.BinarySearchFunc(m.keys, key, m.cmp) + m.keys = slices.Delete(m.keys, i, i+1) } // Clear removes all elements from the sorted map. func (m *SortedMap[K, V]) Clear() { + if m == nil { + return + } + // TODO(s.chzhen): Use built-in clear in Go 1.21. m.keys = nil m.vals = make(map[K]V) @@ -61,6 +75,10 @@ func (m *SortedMap[K, V]) Clear() { // Range calls cb for each element of the map, sorted by m.cmp. If cb returns // false it stops. func (m *SortedMap[K, V]) Range(cb func(K, V) (cont bool)) { + if m == nil { + return + } + for _, k := range m.keys { if !cb(k, m.vals[k]) { return diff --git a/internal/aghalg/orderedmap_test.go b/internal/aghalg/orderedmap_test.go index e7a5ae50cdc..46128ed0e5e 100644 --- a/internal/aghalg/orderedmap_test.go +++ b/internal/aghalg/orderedmap_test.go @@ -36,13 +36,18 @@ func TestNewSortedMap(t *testing.T) { assert.Equal(t, letters, gotLetters) assert.Equal(t, nums, gotNums) - assert.Equal(t, nums[0], m.Get(letters[0])) + + n, ok := m.Get(letters[0]) + assert.True(t, ok) + assert.Equal(t, nums[0], n) }) t.Run("clear", func(t *testing.T) { lastLetter := letters[len(letters)-1] m.Del(lastLetter) - assert.Equal(t, 0, m.Get(lastLetter)) + + _, ok := m.Get(lastLetter) + assert.False(t, ok) m.Clear() @@ -56,3 +61,35 @@ func TestNewSortedMap(t *testing.T) { assert.Len(t, gotLetters, 0) }) } + +func TestNewSortedMap_nil(t *testing.T) { + const ( + key = "key" + val = "val" + ) + + var m SortedMap[string, string] + + assert.Panics(t, func() { + m.Set(key, val) + }) + + assert.NotPanics(t, func() { + _, ok := m.Get(key) + assert.False(t, ok) + }) + + assert.NotPanics(t, func() { + m.Range(func(_, _ string) (cont bool) { + return true + }) + }) + + assert.NotPanics(t, func() { + m.Del(key) + }) + + assert.NotPanics(t, func() { + m.Clear() + }) +}