Skip to content

Commit

Permalink
all: add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
schzhn committed Mar 28, 2024
1 parent c9015e7 commit 548a15c
Show file tree
Hide file tree
Showing 6 changed files with 177 additions and 89 deletions.
62 changes: 0 additions & 62 deletions internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,65 +175,3 @@ func (r *Runtime) IsEmpty() (ok bool) {
func (r *Runtime) Addr() (ip netip.Addr) {
return r.ip
}

// RuntimeIndex stores information about runtime clients.
type RuntimeIndex struct {
// index maps IP address to runtime client.
index map[netip.Addr]*Runtime
}

// NewRuntimeIndex returns initialized runtime index.
func NewRuntimeIndex() (ri *RuntimeIndex) {
return &RuntimeIndex{
index: map[netip.Addr]*Runtime{},
}
}

// Client returns the saved runtime client by ip. If no such client exists,
// returns nil.
func (ri *RuntimeIndex) Client(ip netip.Addr) (rc *Runtime, ok bool) {
rc, ok = ri.index[ip]

return rc, ok
}

// Add saves the runtime client in the index. IP address of a client must be
// unique. See [Client].
func (ri *RuntimeIndex) Add(rc *Runtime) {
ip := rc.Addr()
ri.index[ip] = rc
}

// Size returns the number of the runtime clients.
func (ri *RuntimeIndex) Size() (n int) {
return len(ri.index)
}

// Range calls cb for each runtime client in an undefined order.
func (ri *RuntimeIndex) Range(cb func(rc *Runtime) (cont bool)) {
for _, rc := range ri.index {
if !cb(rc) {
return
}
}
}

// Delete removes the runtime client by ip.
func (ri *RuntimeIndex) Delete(ip netip.Addr) {
delete(ri.index, ip)
}

// DeleteBySrc removes all runtime clients that have information only from the
// specified source and returns the number of removed clients.
func (ri *RuntimeIndex) DeleteBySrc(src Source) (n int) {
for ip, rc := range ri.index {
rc.Unset(src)

if rc.IsEmpty() {
delete(ri.index, ip)
n++
}
}

return n
}
63 changes: 63 additions & 0 deletions internal/client/runtimeindex.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package client

import "net/netip"

// RuntimeIndex stores information about runtime clients.
type RuntimeIndex struct {
// index maps IP address to runtime client.
index map[netip.Addr]*Runtime
}

// NewRuntimeIndex returns initialized runtime index.
func NewRuntimeIndex() (ri *RuntimeIndex) {
return &RuntimeIndex{
index: map[netip.Addr]*Runtime{},
}
}

// Client returns the saved runtime client by ip. If no such client exists,
// returns nil.
func (ri *RuntimeIndex) Client(ip netip.Addr) (rc *Runtime) {
return ri.index[ip]
}

// Add saves the runtime client in the index. IP address of a client must be
// unique. See [Runtime.Client]. rc must not be nil.
func (ri *RuntimeIndex) Add(rc *Runtime) {
ip := rc.Addr()
ri.index[ip] = rc
}

// Size returns the number of the runtime clients.
func (ri *RuntimeIndex) Size() (n int) {
return len(ri.index)
}

// Range calls f for each runtime client in an undefined order.
func (ri *RuntimeIndex) Range(f func(rc *Runtime) (cont bool)) {
for _, rc := range ri.index {
if !f(rc) {
return
}
}
}

// Delete removes the runtime client by ip.
func (ri *RuntimeIndex) Delete(ip netip.Addr) {
delete(ri.index, ip)
}

// DeleteBySource removes all runtime clients that have information only from
// the specified source and returns the number of removed clients.
func (ri *RuntimeIndex) DeleteBySource(src Source) (n int) {
for ip, rc := range ri.index {
rc.Unset(src)

if rc.IsEmpty() {
delete(ri.index, ip)
n++
}
}

return n
}
85 changes: 85 additions & 0 deletions internal/client/runtimeindex_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package client_test

import (
"net/netip"
"testing"

"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/stretchr/testify/assert"
)

func TestRuntimeIndex(t *testing.T) {
const cliSrc = client.SourceARP

var (
ip1 = netip.MustParseAddr("1.1.1.1")
ip2 = netip.MustParseAddr("2.2.2.2")
ip3 = netip.MustParseAddr("3.3.3.3")
)

ri := client.NewRuntimeIndex()
currentSize := 0

testCases := []struct {
ip netip.Addr
name string
hosts []string
src client.Source
}{{
src: cliSrc,
ip: ip1,
name: "1",
hosts: []string{"host1"},
}, {
src: cliSrc,
ip: ip2,
name: "2",
hosts: []string{"host2"},
}, {
src: cliSrc,
ip: ip3,
name: "3",
hosts: []string{"host3"},
}}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
rc := client.NewRuntime(tc.ip)
rc.SetInfo(tc.src, tc.hosts)

ri.Add(rc)
currentSize++

got := ri.Client(tc.ip)
assert.Equal(t, rc, got)
})
}

t.Run("size", func(t *testing.T) {
assert.Equal(t, currentSize, ri.Size())
})

t.Run("range", func(t *testing.T) {
s := 0

ri.Range(func(rc *client.Runtime) (cont bool) {
s++

return true
})

assert.Equal(t, currentSize, s)
})

t.Run("delete", func(t *testing.T) {
ri.Delete(ip1)
currentSize--

assert.Equal(t, currentSize, ri.Size())
})

t.Run("delete_by_src", func(t *testing.T) {
assert.Equal(t, currentSize, ri.DeleteBySource(cliSrc))
assert.Equal(t, 0, ri.Size())
})
}
43 changes: 24 additions & 19 deletions internal/home/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,8 +363,8 @@ func (clients *clientsContainer) clientSource(ip netip.Addr) (src client.Source)
return client.SourcePersistent
}

rc, ok := clients.runtimeIndex.Client(ip)
if ok {
rc := clients.runtimeIndex.Client(ip)
if rc != nil {
src, _ = rc.Info()
}

Expand Down Expand Up @@ -420,9 +420,8 @@ func (clients *clientsContainer) clientOrArtificial(
}, false
}

var rc *client.Runtime
rc, ok = clients.findRuntimeClient(ip)
if ok {
rc := clients.findRuntimeClient(ip)
if rc != nil {
_, host := rc.Info()

return &querylog.Client{
Expand Down Expand Up @@ -554,9 +553,9 @@ func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *client.Persistent,

// runtimeClient returns a runtime client from internal index. Note that it
// doesn't include DHCP clients.
func (clients *clientsContainer) runtimeClient(ip netip.Addr) (rc *client.Runtime, ok bool) {
func (clients *clientsContainer) runtimeClient(ip netip.Addr) (rc *client.Runtime) {
if ip == (netip.Addr{}) {
return nil, false
return nil
}

clients.lock.Lock()
Expand All @@ -566,21 +565,21 @@ func (clients *clientsContainer) runtimeClient(ip netip.Addr) (rc *client.Runtim
}

// findRuntimeClient finds a runtime client by their IP.
func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Runtime, ok bool) {
rc, ok = clients.runtimeClient(ip)
func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Runtime) {
rc = clients.runtimeClient(ip)
host := clients.dhcp.HostByIP(ip)

if host != "" {
if !ok {
if rc == nil {
rc = client.NewRuntime(ip)
}

rc.SetInfo(client.SourceDHCP, []string{host})

return rc, true
return rc
}

return rc, ok
return rc
}

// check validates the client. It also sorts the client tags.
Expand Down Expand Up @@ -732,8 +731,8 @@ func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
return
}

rc, ok := clients.runtimeIndex.Client(ip)
if !ok {
rc := clients.runtimeIndex.Client(ip)
if rc == nil {
// Create a RuntimeClient implicitly so that we don't do this check
// again.
rc = client.NewRuntime(ip)
Expand Down Expand Up @@ -796,8 +795,8 @@ func (clients *clientsContainer) addHostLocked(
host string,
src client.Source,
) (ok bool) {
rc, ok := clients.runtimeIndex.Client(ip)
if !ok {
rc := clients.runtimeIndex.Client(ip)
if rc == nil {
if src < client.SourceDHCP {
if clients.dhcp.HostByIP(ip) != "" {
return false
Expand All @@ -810,7 +809,13 @@ func (clients *clientsContainer) addHostLocked(

rc.SetInfo(src, []string{host})

log.Debug("clients: adding client info %s -> %q %q [%d]", ip, src, host, clients.runtimeIndex.Size())
log.Debug(
"clients: adding client info %s -> %q %q [%d]",
ip,
src,
host,
clients.runtimeIndex.Size(),
)

return true
}
Expand All @@ -821,7 +826,7 @@ func (clients *clientsContainer) addFromHostsFile(hosts *hostsfile.DefaultStorag
clients.lock.Lock()
defer clients.lock.Unlock()

deleted := clients.runtimeIndex.DeleteBySrc(client.SourceHostsFile)
deleted := clients.runtimeIndex.DeleteBySource(client.SourceHostsFile)
log.Debug("clients: removed %d client aliases from system hosts file", deleted)

added := 0
Expand Down Expand Up @@ -861,7 +866,7 @@ func (clients *clientsContainer) addFromSystemARP() {
clients.lock.Lock()
defer clients.lock.Unlock()

deleted := clients.runtimeIndex.DeleteBySrc(client.SourceARP)
deleted := clients.runtimeIndex.DeleteBySource(client.SourceARP)
log.Debug("clients: removed %d client aliases from arp neighborhood", deleted)

added := 0
Expand Down
9 changes: 3 additions & 6 deletions internal/home/clients_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,8 @@ func TestClientsWHOIS(t *testing.T) {
t.Run("new_client", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.255")
clients.setWHOISInfo(ip, whois)
rc, ok := clients.runtimeIndex.Client(ip)
rc := clients.runtimeIndex.Client(ip)
require.NotNil(t, rc)
require.True(t, ok)

assert.Equal(t, whois, rc.WHOIS())
})
Expand All @@ -257,9 +256,8 @@ func TestClientsWHOIS(t *testing.T) {
assert.True(t, ok)

clients.setWHOISInfo(ip, whois)
rc, ok := clients.runtimeIndex.Client(ip)
rc := clients.runtimeIndex.Client(ip)
require.NotNil(t, rc)
require.True(t, ok)

assert.Equal(t, whois, rc.WHOIS())
})
Expand All @@ -276,9 +274,8 @@ func TestClientsWHOIS(t *testing.T) {
assert.True(t, ok)

clients.setWHOISInfo(ip, whois)
rc, ok := clients.runtimeIndex.Client(ip)
rc := clients.runtimeIndex.Client(ip)
require.Nil(t, rc)
require.False(t, ok)

assert.True(t, clients.remove("client1"))
})
Expand Down
4 changes: 2 additions & 2 deletions internal/home/clientshttp.go
Original file line number Diff line number Diff line change
Expand Up @@ -465,8 +465,8 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
// /etc/hosts tables, DHCP leases, or blocklists. cj is guaranteed to be
// non-nil.
func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *clientJSON) {
rc, ok := clients.findRuntimeClient(ip)
if !ok {
rc := clients.findRuntimeClient(ip)
if rc == nil {
// It is still possible that the IP used to be in the runtime clients
// list, but then the server was reloaded. So, check the DNS server's
// blocked IP list.
Expand Down

0 comments on commit 548a15c

Please sign in to comment.