diff --git a/dot/build_spec.go b/dot/build_spec.go index 361fb151ad..fb877196a2 100644 --- a/dot/build_spec.go +++ b/dot/build_spec.go @@ -4,7 +4,6 @@ package dot import ( - "context" "encoding/json" "fmt" "os" @@ -95,21 +94,15 @@ func BuildFromDB(path string) (*BuildSpec, error) { tmpGen.Genesis.Raw = make(map[string]map[string]string) tmpGen.Genesis.Runtime = make(map[string]map[string]interface{}) - // BootstrapMailer should not return an error here since there is no URLs to connect to - disabledTelemetry, err := telemetry.BootstrapMailer(context.TODO(), nil, false, nil) - if err != nil { - panic("telemetry should not fail at BuildFromDB function: " + err.Error()) - } - config := state.Config{ Path: path, LogLevel: log.Info, - Telemetry: disabledTelemetry, + Telemetry: telemetry.NewNoopMailer(), } stateSrvc := state.NewService(config) - err = stateSrvc.SetupBase() + err := stateSrvc.SetupBase() if err != nil { return nil, fmt.Errorf("cannot setup state database: %w", err) } diff --git a/dot/build_spec_integration_test.go b/dot/build_spec_integration_test.go index e70e916a6e..9812a7c7b1 100644 --- a/dot/build_spec_integration_test.go +++ b/dot/build_spec_integration_test.go @@ -22,7 +22,15 @@ const codeHex = "0x3a636f6465" func TestBuildFromGenesis_Integration(t *testing.T) { t.Parallel() - file := genesis.CreateTestGenesisJSONFile(t, false) + genesisFields := genesis.Fields{ + Raw: map[string]map[string]string{}, + Runtime: map[string]map[string]interface{}{ + "System": { + "code": "mocktestcode", + }, + }, + } + file := genesis.CreateTestGenesisJSONFile(t, genesisFields) bs, err := BuildFromGenesis(file, 0) const expectedChainType = "TESTCHAINTYPE" @@ -43,7 +51,7 @@ func TestBuildFromGenesis_Integration(t *testing.T) { jGen := genesis.Genesis{} err = json.Unmarshal(hr, &jGen) require.NoError(t, err) - genesis.TestGenesis.Genesis = genesis.TestFieldsHR + genesis.TestGenesis.Genesis = genesisFields require.Equal(t, genesis.TestGenesis.Genesis.Runtime, jGen.Genesis.Runtime) require.Equal(t, expectedChainType, jGen.ChainType) require.Equal(t, expectedProperties, jGen.Properties) diff --git a/dot/build_spec_test.go b/dot/build_spec_test.go index cbbd87420d..d04f1b1cdb 100644 --- a/dot/build_spec_test.go +++ b/dot/build_spec_test.go @@ -167,7 +167,15 @@ func TestBuildFromDB(t *testing.T) { } func TestBuildFromGenesis(t *testing.T) { - testGenesisPath := genesis.CreateTestGenesisJSONFile(t, false) + genesisFields := genesis.Fields{ + Raw: map[string]map[string]string{}, + Runtime: map[string]map[string]interface{}{ + "System": { + "code": "mocktestcode", + }, + }, + } + testGenesisPath := genesis.CreateTestGenesisJSONFile(t, genesisFields) type args struct { path string diff --git a/dot/network/config.go b/dot/network/config.go index 0ffbbc3554..8b849467a0 100644 --- a/dot/network/config.go +++ b/dot/network/config.go @@ -97,8 +97,6 @@ type Config struct { // telemetryInterval how often to send telemetry metrics telemetryInterval time.Duration - noPreAllocate bool // internal option - batchSize int // internal option // SlotDuration is the slot duration to produce a block diff --git a/dot/network/connmgr.go b/dot/network/connmgr.go index f35cdf7971..bb0add3c3a 100644 --- a/dot/network/connmgr.go +++ b/dot/network/connmgr.go @@ -110,18 +110,6 @@ func (cm *ConnManager) ListenClose(n network.Network, addr ma.Multiaddr) { "Host %s stopped listening on address %s", n.LocalPeer(), addr) } -// returns a slice of peers that are unprotected and may be pruned. -func (cm *ConnManager) unprotectedPeers(peers []peer.ID) []peer.ID { - unprot := []peer.ID{} - for _, id := range peers { - if !cm.IsProtected(id, "") && !cm.isPersistent(id) { - unprot = append(unprot, id) - } - } - - return unprot -} - // Connected is called when a connection opened func (cm *ConnManager) Connected(n network.Network, c network.Conn) { logger.Tracef( @@ -141,8 +129,3 @@ func (cm *ConnManager) Disconnected(_ network.Network, c network.Conn) { cm.disconnectHandler(c.RemotePeer()) } } - -func (cm *ConnManager) isPersistent(p peer.ID) bool { - _, ok := cm.persistentPeers.Load(p) - return ok -} diff --git a/dot/network/connmgr_test.go b/dot/network/connmgr_test.go index 1ed6001816..01fad5ba1a 100644 --- a/dot/network/connmgr_test.go +++ b/dot/network/connmgr_test.go @@ -115,13 +115,13 @@ func TestProtectUnprotectPeer(t *testing.T) { require.True(t, cm.IsProtected(p1, "")) require.True(t, cm.IsProtected(p2, "")) - unprot := cm.unprotectedPeers([]peer.ID{p1, p2, p3, p4}) + unprot := unprotectedPeers(cm, []peer.ID{p1, p2, p3, p4}) require.Equal(t, unprot, []peer.ID{p3, p4}) cm.Unprotect(p1, "") cm.Unprotect(p2, "") - unprot = cm.unprotectedPeers([]peer.ID{p1, p2, p3, p4}) + unprot = unprotectedPeers(cm, []peer.ID{p1, p2, p3, p4}) require.Equal(t, unprot, []peer.ID{p1, p2, p3, p4}) } @@ -224,7 +224,7 @@ func TestSetReservedPeer(t *testing.T) { addrA := nodes[0].host.multiaddrs()[0] addrB := nodes[1].host.multiaddrs()[0] - addrC := nodes[2].host.addrInfo() + addrC := addrInfo(nodes[2].host) config := &Config{ BasePath: t.TempDir(), diff --git a/dot/network/discovery_test.go b/dot/network/discovery_test.go index 9858d9e428..dfb3165654 100644 --- a/dot/network/discovery_test.go +++ b/dot/network/discovery_test.go @@ -121,7 +121,7 @@ func TestBeginDiscovery(t *testing.T) { nodeB := createTestService(t, configB) nodeB.noGossip = true - addrInfoB := nodeB.host.addrInfo() + addrInfoB := addrInfo(nodeB.host) err := nodeA.host.connect(addrInfoB) if failedToDial(err) { time.Sleep(TestBackoffTimeout) @@ -170,7 +170,7 @@ func TestBeginDiscovery_ThreeNodes(t *testing.T) { nodeC.noGossip = true // connect A and B - addrInfoB := nodeB.host.addrInfo() + addrInfoB := addrInfo(nodeB.host) err := nodeA.host.connect(addrInfoB) if failedToDial(err) { time.Sleep(TestBackoffTimeout) @@ -179,7 +179,7 @@ func TestBeginDiscovery_ThreeNodes(t *testing.T) { require.NoError(t, err) // connect A and C - addrInfoC := nodeC.host.addrInfo() + addrInfoC := addrInfo(nodeC.host) err = nodeA.host.connect(addrInfoC) if failedToDial(err) { time.Sleep(TestBackoffTimeout) diff --git a/dot/network/gossip_test.go b/dot/network/gossip_test.go index 6c80f6905b..7bbe349f53 100644 --- a/dot/network/gossip_test.go +++ b/dot/network/gossip_test.go @@ -42,7 +42,7 @@ func TestGossip(t *testing.T) { handlerB := newTestStreamHandler(testBlockAnnounceMessageDecoder) nodeB.host.registerStreamHandler(nodeB.host.protocolID, handlerB.handleStream) - addrInfoA := nodeA.host.addrInfo() + addrInfoA := addrInfo(nodeA.host) err := nodeB.host.connect(addrInfoA) // retry connect if "failed to dial" error if failedToDial(err) { @@ -70,7 +70,7 @@ func TestGossip(t *testing.T) { } require.NoError(t, err) - addrInfoB := nodeB.host.addrInfo() + addrInfoB := addrInfo(nodeB.host) err = nodeC.host.connect(addrInfoB) // retry connect if "failed to dial" error if failedToDial(err) { diff --git a/dot/network/helpers_test.go b/dot/network/helpers_test.go index fb5d5f8d1e..bd68f2a237 100644 --- a/dot/network/helpers_test.go +++ b/dot/network/helpers_test.go @@ -147,3 +147,28 @@ func testBlockAnnounceHandshakeDecoder(in []byte, _ peer.ID, _ bool) (Message, e err := msg.Decode(in) return msg, err } + +// addrInfo returns the libp2p peer.AddrInfo of the host +func addrInfo(h *host) peer.AddrInfo { + return peer.AddrInfo{ + ID: h.p2pHost.ID(), + Addrs: h.p2pHost.Addrs(), + } +} + +// returns a slice of peers that are unprotected and may be pruned. +func unprotectedPeers(cm *ConnManager, peers []peer.ID) []peer.ID { + unprot := []peer.ID{} + for _, id := range peers { + if cm.IsProtected(id, "") { + continue + } + + _, isPersistent := cm.persistentPeers.Load(id) + if !isPersistent { + unprot = append(unprot, id) + } + } + + return unprot +} diff --git a/dot/network/host.go b/dot/network/host.go index 9426eee7fa..124add7975 100644 --- a/dot/network/host.go +++ b/dot/network/host.go @@ -392,14 +392,6 @@ func (h *host) peerCount() int { return len(peers) } -// addrInfo returns the libp2p peer.AddrInfo of the host -func (h *host) addrInfo() peer.AddrInfo { - return peer.AddrInfo{ - ID: h.p2pHost.ID(), - Addrs: h.p2pHost.Addrs(), - } -} - // multiaddrs returns the multiaddresses of the host func (h *host) multiaddrs() (multiaddrs []ma.Multiaddr) { addrs := h.p2pHost.Addrs() diff --git a/dot/network/host_test.go b/dot/network/host_test.go index baa812a9c4..4cb5bfab8e 100644 --- a/dot/network/host_test.go +++ b/dot/network/host_test.go @@ -29,7 +29,7 @@ func TestExternalAddrs(t *testing.T) { node := createTestService(t, config) - addrInfo := node.host.addrInfo() + addrInfo := addrInfo(node.host) privateIPs, err := newPrivateIPFilters() require.NoError(t, err) @@ -60,7 +60,7 @@ func TestExternalAddrsPublicIP(t *testing.T) { } node := createTestService(t, config) - addrInfo := node.host.addrInfo() + addrInfo := addrInfo(node.host) privateIPs, err := newPrivateIPFilters() require.NoError(t, err) @@ -92,7 +92,7 @@ func TestExternalAddrsPublicDNS(t *testing.T) { } node := createTestService(t, config) - addrInfo := node.host.addrInfo() + addrInfo := addrInfo(node.host) expected := []ma.Multiaddr{ mustNewMultiAddr("/ip4/127.0.0.1/tcp/7001"), @@ -126,7 +126,7 @@ func TestConnect(t *testing.T) { nodeB := createTestService(t, configB) nodeB.noGossip = true - addrInfoB := nodeB.host.addrInfo() + addrInfoB := addrInfo(nodeB.host) err := nodeA.host.connect(addrInfoB) // retry connect if "failed to dial" error if failedToDial(err) { @@ -207,7 +207,7 @@ func TestSend(t *testing.T) { handler := newTestStreamHandler(testBlockRequestMessageDecoder) nodeB.host.registerStreamHandler(nodeB.host.protocolID, handler.handleStream) - addrInfoB := nodeB.host.addrInfo() + addrInfoB := addrInfo(nodeB.host) err := nodeA.host.connect(addrInfoB) // retry connect if "failed to dial" error if failedToDial(err) { @@ -244,7 +244,7 @@ func TestExistingStream(t *testing.T) { handlerA := newTestStreamHandler(testBlockRequestMessageDecoder) nodeA.host.registerStreamHandler(nodeA.host.protocolID, handlerA.handleStream) - addrInfoA := nodeA.host.addrInfo() + addrInfoA := addrInfo(nodeA.host) configB := &Config{ BasePath: t.TempDir(), Port: availablePort(t), @@ -257,7 +257,7 @@ func TestExistingStream(t *testing.T) { handlerB := newTestStreamHandler(testBlockRequestMessageDecoder) nodeB.host.registerStreamHandler(nodeB.host.protocolID, handlerB.handleStream) - addrInfoB := nodeB.host.addrInfo() + addrInfoB := addrInfo(nodeB.host) err := nodeA.host.connect(addrInfoB) // retry connect if "failed to dial" error if failedToDial(err) { @@ -320,7 +320,7 @@ func TestStreamCloseMetadataCleanup(t *testing.T) { handlerB := newTestStreamHandler(testBlockAnnounceHandshakeDecoder) nodeB.host.registerStreamHandler(blockAnnounceID, handlerB.handleStream) - addrInfoB := nodeB.host.addrInfo() + addrInfoB := addrInfo(nodeB.host) err := nodeA.host.connect(addrInfoB) // retry connect if "failed to dial" error if failedToDial(err) { @@ -386,7 +386,7 @@ func Test_PeerSupportsProtocol(t *testing.T) { nodeB := createTestService(t, configB) nodeB.noGossip = true - addrInfoB := nodeB.host.addrInfo() + addrInfoB := addrInfo(nodeB.host) err := nodeA.host.connect(addrInfoB) // retry connect if "failed to dial" error if failedToDial(err) { @@ -490,7 +490,7 @@ func Test_RemoveReservedPeers(t *testing.T) { time.Sleep(100 * time.Millisecond) require.Equal(t, 1, nodeA.host.peerCount()) - pID := nodeB.host.addrInfo().ID.String() + pID := addrInfo(nodeB.host).ID.String() err = nodeA.host.removeReservedPeers(pID) require.NoError(t, err) @@ -498,7 +498,7 @@ func Test_RemoveReservedPeers(t *testing.T) { time.Sleep(100 * time.Millisecond) require.Equal(t, 1, nodeA.host.peerCount()) - isProtected := nodeA.host.p2pHost.ConnManager().IsProtected(nodeB.host.addrInfo().ID, "") + isProtected := nodeA.host.p2pHost.ConnManager().IsProtected(addrInfo(nodeB.host).ID, "") require.False(t, isProtected) err = nodeA.host.removeReservedPeers("unknown_perr_id") @@ -531,7 +531,7 @@ func TestStreamCloseEOF(t *testing.T) { nodeB.host.registerStreamHandler(nodeB.host.protocolID, handler.handleStream) require.False(t, handler.exit) - addrInfoB := nodeB.host.addrInfo() + addrInfoB := addrInfo(nodeB.host) err := nodeA.host.connect(addrInfoB) // retry connect if "failed to dial" error if failedToDial(err) { @@ -582,7 +582,7 @@ func TestPeerConnect(t *testing.T) { nodeB := createTestService(t, configB) nodeB.noGossip = true - addrInfoB := nodeB.host.addrInfo() + addrInfoB := addrInfo(nodeB.host) nodeA.host.p2pHost.Peerstore().AddAddrs(addrInfoB.ID, addrInfoB.Addrs, peerstore.PermanentAddrTTL) nodeA.host.cm.peerSetHandler.AddPeer(0, addrInfoB.ID) @@ -620,7 +620,7 @@ func TestBannedPeer(t *testing.T) { nodeB := createTestService(t, configB) nodeB.noGossip = true - addrInfoB := nodeB.host.addrInfo() + addrInfoB := addrInfo(nodeB.host) nodeA.host.p2pHost.Peerstore().AddAddrs(addrInfoB.ID, addrInfoB.Addrs, peerstore.PermanentAddrTTL) nodeA.host.cm.peerSetHandler.AddPeer(0, addrInfoB.ID) @@ -673,7 +673,7 @@ func TestPeerReputation(t *testing.T) { nodeB := createTestService(t, configB) nodeB.noGossip = true - addrInfoB := nodeB.host.addrInfo() + addrInfoB := addrInfo(nodeB.host) nodeA.host.p2pHost.Peerstore().AddAddrs(addrInfoB.ID, addrInfoB.Addrs, peerstore.PermanentAddrTTL) nodeA.host.cm.peerSetHandler.AddPeer(0, addrInfoB.ID) diff --git a/dot/network/light_test.go b/dot/network/light_test.go index 7cbc037921..c6ee57ff6c 100644 --- a/dot/network/light_test.go +++ b/dot/network/light_test.go @@ -104,7 +104,7 @@ func TestHandleLightMessage_Response(t *testing.T) { } b := createTestService(t, configB) - addrInfoB := b.host.addrInfo() + addrInfoB := addrInfo(b.host) err := s.host.connect(addrInfoB) // retry connect if "failed to dial" error if failedToDial(err) { diff --git a/dot/network/notifications.go b/dot/network/notifications.go index 38fa9fa731..7ddc312d11 100644 --- a/dot/network/notifications.go +++ b/dot/network/notifications.go @@ -9,8 +9,7 @@ import ( "io" "time" - "github.com/libp2p/go-libp2p-core/mux" - libp2pnetwork "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/protocol" @@ -84,10 +83,10 @@ type handshakeData struct { received bool validated bool handshake Handshake - stream libp2pnetwork.Stream + stream network.Stream } -func newHandshakeData(received, validated bool, stream libp2pnetwork.Stream) *handshakeData { +func newHandshakeData(received, validated bool, stream network.Stream) *handshakeData { return &handshakeData{ received: received, validated: validated, @@ -123,7 +122,7 @@ func (s *Service) createNotificationsMessageHandler( notificationsMessageHandler NotificationsMessageHandler, batchHandler NotificationsMessageBatchHandler, ) messageHandler { - return func(stream libp2pnetwork.Stream, m Message) error { + return func(stream network.Stream, m Message) error { if m == nil || info == nil || info.handshakeValidator == nil || notificationsMessageHandler == nil { return nil } @@ -228,7 +227,7 @@ func (s *Service) createNotificationsMessageHandler( } } -func closeOutboundStream(info *notificationsProtocol, peerID peer.ID, stream libp2pnetwork.Stream) { +func closeOutboundStream(info *notificationsProtocol, peerID peer.ID, stream network.Stream) { logger.Debugf( "cleaning up outbound handshake data for protocol=%s, peer=%s", stream.Protocol(), @@ -279,7 +278,7 @@ func (s *Service) sendData(peer peer.ID, hs Handshake, info *notificationsProtoc logger.Debugf("failed to send message to peer %s: %s", peer, err) // the stream was closed or reset, close it on our end and delete it from our peer's data - if errors.Is(err, io.EOF) || errors.Is(err, mux.ErrReset) { + if errors.Is(err, io.EOF) || errors.Is(err, network.ErrReset) { closeOutboundStream(info, peer, stream) } return @@ -299,7 +298,7 @@ func (s *Service) sendData(peer peer.ID, hs Handshake, info *notificationsProtoc var errPeerDisconnected = errors.New("peer disconnected") -func (s *Service) sendHandshake(peer peer.ID, hs Handshake, info *notificationsProtocol) (libp2pnetwork.Stream, error) { +func (s *Service) sendHandshake(peer peer.ID, hs Handshake, info *notificationsProtocol) (network.Stream, error) { // multiple processes could each call this upcoming section, opening multiple streams and // sending multiple handshakes. thus, we need to have a per-peer and per-protocol lock @@ -413,7 +412,7 @@ func (s *Service) broadcastExcluding(info *notificationsProtocol, excluding peer } } -func (s *Service) readHandshake(stream libp2pnetwork.Stream, decoder HandshakeDecoder, maxSize uint64, +func (s *Service) readHandshake(stream network.Stream, decoder HandshakeDecoder, maxSize uint64, ) <-chan *handshakeReader { hsC := make(chan *handshakeReader) diff --git a/dot/network/notifications_test.go b/dot/network/notifications_test.go index c4ad87a5f8..215efea544 100644 --- a/dot/network/notifications_test.go +++ b/dot/network/notifications_test.go @@ -104,7 +104,7 @@ func TestCreateNotificationsMessageHandler_BlockAnnounce(t *testing.T) { testPeerID := b.host.id() // connect nodes - addrInfoB := b.host.addrInfo() + addrInfoB := addrInfo(b.host) err := s.host.connect(addrInfoB) if failedToDial(err) { time.Sleep(TestBackoffTimeout) @@ -173,7 +173,7 @@ func TestCreateNotificationsMessageHandler_BlockAnnounceHandshake(t *testing.T) testPeerID := b.host.id() // connect nodes - addrInfoB := b.host.addrInfo() + addrInfoB := addrInfo(b.host) err := s.host.connect(addrInfoB) if failedToDial(err) { time.Sleep(TestBackoffTimeout) @@ -254,7 +254,7 @@ func Test_HandshakeTimeout(t *testing.T) { // should not respond to a handshake message }) - addrInfosB := nodeB.host.addrInfo() + addrInfosB := addrInfo(nodeB.host) err := nodeA.host.connect(addrInfosB) // retry connect if "failed to dial" error @@ -333,7 +333,7 @@ func TestCreateNotificationsMessageHandler_HandleTransaction(t *testing.T) { txnBatchHandler := srvc1.createBatchMessageHandler(txnBatch) // connect nodes - addrInfoB := srvc2.host.addrInfo() + addrInfoB := addrInfo(srvc2.host) err := srvc1.host.connect(addrInfoB) if failedToDial(err) { time.Sleep(TestBackoffTimeout) diff --git a/dot/network/service_test.go b/dot/network/service_test.go index 5dd4b3ae28..9efe777fb9 100644 --- a/dot/network/service_test.go +++ b/dot/network/service_test.go @@ -160,8 +160,6 @@ func createTestService(t *testing.T, cfg *Config) (srvc *Service) { cfg.Telemetry = telemetryMock } - cfg.noPreAllocate = true - srvc, err := NewService(cfg) require.NoError(t, err) @@ -211,7 +209,7 @@ func TestBroadcastMessages(t *testing.T) { handler := newTestStreamHandler(testBlockAnnounceHandshakeDecoder) nodeB.host.registerStreamHandler(nodeB.host.protocolID+blockAnnounceID, handler.handleStream) - addrInfoB := nodeB.host.addrInfo() + addrInfoB := addrInfo(nodeB.host) err := nodeA.host.connect(addrInfoB) // retry connect if "failed to dial" error if failedToDial(err) { @@ -260,7 +258,7 @@ func TestBroadcastDuplicateMessage(t *testing.T) { handler := newTestStreamHandler(testBlockAnnounceHandshakeDecoder) nodeB.host.registerStreamHandler(nodeB.host.protocolID+blockAnnounceID, handler.handleStream) - addrInfoB := nodeB.host.addrInfo() + addrInfoB := addrInfo(nodeB.host) err := nodeA.host.connect(addrInfoB) // retry connect if "failed to dial" error if failedToDial(err) { @@ -354,7 +352,7 @@ func TestPersistPeerStore(t *testing.T) { nodeA := nodes[0] nodeB := nodes[1] - addrInfoB := nodeB.host.addrInfo() + addrInfoB := addrInfo(nodeB.host) err := nodeA.host.connect(addrInfoB) if failedToDial(err) { time.Sleep(TestBackoffTimeout) @@ -394,7 +392,7 @@ func TestHandleConn(t *testing.T) { nodeB := createTestService(t, configB) - addrInfoB := nodeB.host.addrInfo() + addrInfoB := addrInfo(nodeB.host) err := nodeA.host.connect(addrInfoB) if failedToDial(err) { time.Sleep(TestBackoffTimeout) diff --git a/dot/node.go b/dot/node.go index 86c6aa7de5..edb21cd137 100644 --- a/dot/node.go +++ b/dot/node.go @@ -411,7 +411,11 @@ func newNode(cfg *Config, return node, nil } -func setupTelemetry(cfg *Config, genesisData *genesis.Data) (mailer *telemetry.Mailer, err error) { +func setupTelemetry(cfg *Config, genesisData *genesis.Data) (mailer telemetry.Client, err error) { + if cfg.Global.NoTelemetry { + return telemetry.NewNoopMailer(), nil + } + var telemetryEndpoints []*genesis.TelemetryEndpoint if len(cfg.Global.TelemetryURLs) == 0 && genesisData != nil { telemetryEndpoints = append(telemetryEndpoints, genesisData.TelemetryEndpoints...) @@ -424,7 +428,7 @@ func setupTelemetry(cfg *Config, genesisData *genesis.Data) (mailer *telemetry.M telemetryLogger := log.NewFromGlobal(log.AddContext("pkg", "telemetry")) return telemetry.BootstrapMailer(context.TODO(), - telemetryEndpoints, !cfg.Global.NoTelemetry, telemetryLogger) + telemetryEndpoints, telemetryLogger) } // stores the global node name to reuse diff --git a/dot/rpc/json2/server.go b/dot/rpc/json2/server.go index 201082c9cd..764e58b5e0 100644 --- a/dot/rpc/json2/server.go +++ b/dot/rpc/json2/server.go @@ -202,7 +202,7 @@ func (c *CodecRequest) WriteResponse(w http.ResponseWriter, reply interface{}) { } // WriteError encodes the error and writes it to the ResponseWriter. -func (c *CodecRequest) WriteError(w http.ResponseWriter, status int, err error) { +func (c *CodecRequest) WriteError(w http.ResponseWriter, _ int, err error) { err = c.tryToMapIfNotAnErrorAlready(err) jsonErr, ok := err.(*json2.Error) if !ok { diff --git a/dot/state/pruner/pruner.go b/dot/state/pruner/pruner.go index 57a1ed415f..bbfa074dc1 100644 --- a/dot/state/pruner/pruner.go +++ b/dot/state/pruner/pruner.go @@ -59,7 +59,7 @@ type Pruner interface { type ArchiveNode struct{} // StoreJournalRecord for archive node doesn't do anything. -func (a *ArchiveNode) StoreJournalRecord(_, _ map[common.Hash]struct{}, +func (*ArchiveNode) StoreJournalRecord(_, _ map[common.Hash]struct{}, _ common.Hash, _ int64) error { return nil } @@ -326,7 +326,7 @@ func (p *FullNode) loadDeathList() error { return nil } -func (p *FullNode) deleteJournalRecord(b chaindb.Batch, key *journalKey) error { +func (*FullNode) deleteJournalRecord(b chaindb.Batch, key *journalKey) error { encKey, err := scale.Marshal(*key) if err != nil { return err @@ -373,7 +373,7 @@ func (p *FullNode) getLastPrunedIndex() (int64, error) { return blockNum, nil } -func (p *FullNode) deleteKeys(b chaindb.Batch, nodesHash map[common.Hash]int64) error { +func (*FullNode) deleteKeys(b chaindb.Batch, nodesHash map[common.Hash]int64) error { for k := range nodesHash { err := b.Del(k.ToBytes()) if err != nil { diff --git a/dot/sync/chain_processor.go b/dot/sync/chain_processor.go index 267785021e..0e03b0426e 100644 --- a/dot/sync/chain_processor.go +++ b/dot/sync/chain_processor.go @@ -19,7 +19,7 @@ import ( // ChainProcessor processes ready blocks. // it is implemented by *chainProcessor type ChainProcessor interface { - start() + processReadyBlocks() stop() } @@ -65,10 +65,6 @@ func newChainProcessor(readyBlocks *blockQueue, pendingBlocks DisjointBlockSet, } } -func (s *chainProcessor) start() { - go s.processReadyBlocks() -} - func (s *chainProcessor) stop() { s.cancel() } diff --git a/dot/sync/chain_processor_integration_test.go b/dot/sync/chain_processor_integration_test.go index f875edfa34..8b359176da 100644 --- a/dot/sync/chain_processor_integration_test.go +++ b/dot/sync/chain_processor_integration_test.go @@ -236,7 +236,7 @@ func TestChainProcessor_HandleJustification(t *testing.T) { func TestChainProcessor_processReadyBlocks_errFailedToGetParent(t *testing.T) { syncer := newTestSyncer(t) processor := syncer.chainProcessor.(*chainProcessor) - processor.start() + go processor.processReadyBlocks() defer processor.cancel() header := &types.Header{ diff --git a/dot/sync/mock_chain_processor_test.go b/dot/sync/mock_chain_processor_test.go index 5eeaa0d450..fc6b9c1569 100644 --- a/dot/sync/mock_chain_processor_test.go +++ b/dot/sync/mock_chain_processor_test.go @@ -33,16 +33,16 @@ func (m *MockChainProcessor) EXPECT() *MockChainProcessorMockRecorder { return m.recorder } -// start mocks base method. -func (m *MockChainProcessor) start() { +// processReadyBlocks mocks base method. +func (m *MockChainProcessor) processReadyBlocks() { m.ctrl.T.Helper() - m.ctrl.Call(m, "start") + m.ctrl.Call(m, "processReadyBlocks") } -// start indicates an expected call of start. -func (mr *MockChainProcessorMockRecorder) start() *gomock.Call { +// processReadyBlocks indicates an expected call of processReadyBlocks. +func (mr *MockChainProcessorMockRecorder) processReadyBlocks() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "start", reflect.TypeOf((*MockChainProcessor)(nil).start)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "processReadyBlocks", reflect.TypeOf((*MockChainProcessor)(nil).processReadyBlocks)) } // stop mocks base method. diff --git a/dot/sync/syncer.go b/dot/sync/syncer.go index 6cf7965ada..47f680b315 100644 --- a/dot/sync/syncer.go +++ b/dot/sync/syncer.go @@ -100,7 +100,7 @@ func NewService(cfg *Config) (*Service, error) { // Start begins the chainSync and chainProcessor modules. It begins syncing in bootstrap mode func (s *Service) Start() error { go s.chainSync.start() - s.chainProcessor.start() + go s.chainProcessor.processReadyBlocks() return nil } diff --git a/dot/sync/syncer_test.go b/dot/sync/syncer_test.go index fb45ee57df..08146e6bc7 100644 --- a/dot/sync/syncer_test.go +++ b/dot/sync/syncer_test.go @@ -5,6 +5,7 @@ package sync import ( "errors" + "sync" "testing" "github.com/ChainSafe/gossamer/dot/network" @@ -317,15 +318,20 @@ func TestService_IsSynced(t *testing.T) { func TestService_Start(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) - done := make(chan struct{}) + + var allCalled sync.WaitGroup chainSync := NewMockChainSync(ctrl) + allCalled.Add(1) chainSync.EXPECT().start().DoAndReturn(func() { - close(done) + allCalled.Done() }) chainProcessor := NewMockChainProcessor(ctrl) - chainProcessor.EXPECT().start() + allCalled.Add(1) + chainProcessor.EXPECT().processReadyBlocks().DoAndReturn(func() { + allCalled.Done() + }) service := Service{ chainSync: chainSync, @@ -333,7 +339,7 @@ func TestService_Start(t *testing.T) { } err := service.Start() - <-done + allCalled.Wait() assert.NoError(t, err) } diff --git a/dot/system/service.go b/dot/system/service.go index 608620a951..69ebaf4544 100644 --- a/dot/system/service.go +++ b/dot/system/service.go @@ -48,11 +48,11 @@ func (s *Service) Properties() map[string]interface{} { } // Start implements Service interface -func (s *Service) Start() error { +func (*Service) Start() error { return nil } // Stop implements Service interface -func (s *Service) Stop() error { +func (*Service) Stop() error { return nil } diff --git a/dot/telemetry/mailer.go b/dot/telemetry/mailer.go index 08068415d9..ebc0c47dc1 100644 --- a/dot/telemetry/mailer.go +++ b/dot/telemetry/mailer.go @@ -25,32 +25,19 @@ type telemetryConnection struct { // Mailer can send messages to the telemetry servers. type Mailer struct { - *sync.Mutex + mutex *sync.Mutex - logger log.LeveledLogger - enabled bool + logger log.LeveledLogger connections []*telemetryConnection } -func newMailer(enabled bool, logger log.LeveledLogger) *Mailer { - mailer := &Mailer{ - new(sync.Mutex), - logger, - enabled, - nil, - } - - return mailer -} - // BootstrapMailer setup the mailer, the connections and start the async message shipment -func BootstrapMailer(ctx context.Context, conns []*genesis.TelemetryEndpoint, enabled bool, logger log.LeveledLogger) ( +func BootstrapMailer(ctx context.Context, conns []*genesis.TelemetryEndpoint, logger log.LeveledLogger) ( mailer *Mailer, err error) { - - mailer = newMailer(enabled, logger) - if !enabled { - return mailer, nil + mailer = &Mailer{ + mutex: new(sync.Mutex), + logger: logger, } for _, v := range conns { @@ -90,12 +77,10 @@ func BootstrapMailer(ctx context.Context, conns []*genesis.TelemetryEndpoint, en // SendMessage sends Message to connected telemetry listeners through messageReceiver func (m *Mailer) SendMessage(msg Message) { - m.Lock() - defer m.Unlock() + m.mutex.Lock() + defer m.mutex.Unlock() - if m.enabled { - go m.shipTelemetryMessage(msg) - } + go m.shipTelemetryMessage(msg) } func (m *Mailer) shipTelemetryMessage(msg Message) { diff --git a/dot/telemetry/mailer_test.go b/dot/telemetry/mailer_test.go index 8b4b414dfd..3041cb242d 100644 --- a/dot/telemetry/mailer_test.go +++ b/dot/telemetry/mailer_test.go @@ -47,9 +47,8 @@ func newTestMailer(t *testing.T, handler http.HandlerFunc) (mailer *Mailer) { testEndpoints := []*genesis.TelemetryEndpoint{testEndpoint1} logger := log.New(log.SetWriter(io.Discard)) - const telemetryEnabled = true - mailer, err := BootstrapMailer(context.Background(), testEndpoints, telemetryEnabled, logger) + mailer, err := BootstrapMailer(context.Background(), testEndpoints, logger) require.NoError(t, err) return mailer diff --git a/dot/telemetry/noop.go b/dot/telemetry/noop.go new file mode 100644 index 0000000000..203360a749 --- /dev/null +++ b/dot/telemetry/noop.go @@ -0,0 +1,15 @@ +// Copyright 2022 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package telemetry + +// Noop is a no-op telemetry client implementation. +type Noop struct{} + +// NewNoopMailer returns a no-op telemetry mailer implementation. +func NewNoopMailer() *Noop { + return &Noop{} +} + +// SendMessage does nothing. +func (*Noop) SendMessage(_ Message) {} diff --git a/dot/telemetry/telemetry.go b/dot/telemetry/telemetry.go index 504976ec85..a781275b4c 100644 --- a/dot/telemetry/telemetry.go +++ b/dot/telemetry/telemetry.go @@ -37,4 +37,4 @@ type Message interface { type NoopClient struct{} // SendMessage is an empty implementation used for testing -func (NoopClient) SendMessage(msg Message) {} +func (NoopClient) SendMessage(_ Message) {} diff --git a/dot/types/authority.go b/dot/types/authority.go index ea6966ae62..8962568880 100644 --- a/dot/types/authority.go +++ b/dot/types/authority.go @@ -78,7 +78,7 @@ func (a *Authority) ToRaw() *AuthorityRaw { // DeepCopy creates a deep copy of the Authority func (a *Authority) DeepCopy() *Authority { pk := a.Key.Encode() - pkCopy, _ := sr25519.NewPublicKey(pk[:]) + pkCopy, _ := sr25519.NewPublicKey(pk) return &Authority{ Key: pkCopy, Weight: a.Weight, diff --git a/dot/types/babe.go b/dot/types/babe.go index ab30b0f119..6f4ce09eef 100644 --- a/dot/types/babe.go +++ b/dot/types/babe.go @@ -26,6 +26,10 @@ const ( PrimaryAndSecondaryVRFSlots ) +var ( + ErrChainHeadMissingDigest = errors.New("chain head missing digest") +) + // BabeConfiguration contains the genesis data for BABE //nolint:lll // see: https://github.com/paritytech/substrate/blob/426c26b8bddfcdbaf8d29f45b128e0864b57de1c/core/consensus/babe/primitives/src/lib.rs#L132 @@ -105,7 +109,7 @@ type ConfigData struct { // GetSlotFromHeader returns the BABE slot from the given header func GetSlotFromHeader(header *Header) (uint64, error) { if len(header.Digest.Types) == 0 { - return 0, fmt.Errorf("chain head missing digest") + return 0, ErrChainHeadMissingDigest } preDigest, ok := header.Digest.Types[0].Value().(PreRuntimeDigest) @@ -138,7 +142,7 @@ func IsPrimary(header *Header) (bool, error) { } if len(header.Digest.Types) == 0 { - return false, fmt.Errorf("chain head missing digest") + return false, ErrChainHeadMissingDigest } preDigest, ok := header.Digest.Types[0].Value().(PreRuntimeDigest) diff --git a/dot/types/babe_digest.go b/dot/types/babe_digest.go index 0f547bdc98..f704e41db2 100644 --- a/dot/types/babe_digest.go +++ b/dot/types/babe_digest.go @@ -58,7 +58,7 @@ func (d *BabePrimaryPreDigest) ToPreRuntimeDigest() (*PreRuntimeDigest, error) { } // Index Returns VDT index -func (d BabePrimaryPreDigest) Index() uint { return 1 } +func (BabePrimaryPreDigest) Index() uint { return 1 } // BabeSecondaryPlainPreDigest is included in a block built by a secondary slot authorized producer type BabeSecondaryPlainPreDigest struct { @@ -80,7 +80,7 @@ func (d *BabeSecondaryPlainPreDigest) ToPreRuntimeDigest() (*PreRuntimeDigest, e } // Index Returns VDT index -func (d BabeSecondaryPlainPreDigest) Index() uint { return 2 } +func (BabeSecondaryPlainPreDigest) Index() uint { return 2 } // BabeSecondaryVRFPreDigest is included in a block built by a secondary slot authorized producer type BabeSecondaryVRFPreDigest struct { @@ -108,7 +108,7 @@ func (d *BabeSecondaryVRFPreDigest) ToPreRuntimeDigest() (*PreRuntimeDigest, err } // Index Returns VDT index -func (d BabeSecondaryVRFPreDigest) Index() uint { return 3 } +func (BabeSecondaryVRFPreDigest) Index() uint { return 3 } // toPreRuntimeDigest returns the VaryingDataTypeValue as a PreRuntimeDigest func toPreRuntimeDigest(value scale.VaryingDataTypeValue) (*PreRuntimeDigest, error) { diff --git a/dot/types/consensus_digest.go b/dot/types/consensus_digest.go index b33e884555..6a4dc219e8 100644 --- a/dot/types/consensus_digest.go +++ b/dot/types/consensus_digest.go @@ -27,7 +27,7 @@ type GrandpaScheduledChange struct { } // Index Returns VDT index -func (sc GrandpaScheduledChange) Index() uint { return 1 } +func (GrandpaScheduledChange) Index() uint { return 1 } // GrandpaForcedChange represents a GRANDPA forced authority change type GrandpaForcedChange struct { @@ -40,7 +40,7 @@ type GrandpaForcedChange struct { } // Index Returns VDT index -func (fc GrandpaForcedChange) Index() uint { return 2 } +func (GrandpaForcedChange) Index() uint { return 2 } // GrandpaOnDisabled represents a GRANDPA authority being disabled type GrandpaOnDisabled struct { @@ -48,7 +48,7 @@ type GrandpaOnDisabled struct { } // Index Returns VDT index -func (od GrandpaOnDisabled) Index() uint { return 3 } +func (GrandpaOnDisabled) Index() uint { return 3 } // GrandpaPause represents an authority set pause type GrandpaPause struct { @@ -56,7 +56,7 @@ type GrandpaPause struct { } // Index Returns VDT index -func (p GrandpaPause) Index() uint { return 4 } +func (GrandpaPause) Index() uint { return 4 } // GrandpaResume represents an authority set resume type GrandpaResume struct { @@ -64,7 +64,7 @@ type GrandpaResume struct { } // Index Returns VDT index -func (r GrandpaResume) Index() uint { return 5 } +func (GrandpaResume) Index() uint { return 5 } // NextEpochData is the digest that contains the data for the upcoming BABE epoch. // It is included in the first block of every epoch to describe the next epoch. @@ -74,7 +74,7 @@ type NextEpochData struct { } // Index Returns VDT index -func (d NextEpochData) Index() uint { return 1 } +func (NextEpochData) Index() uint { return 1 } func (d NextEpochData) String() string { return fmt.Sprintf("NextEpochData Authorities=%v Randomness=%v", d.Authorities, d.Randomness) @@ -99,7 +99,7 @@ type BABEOnDisabled struct { } // Index Returns VDT index -func (od BABEOnDisabled) Index() uint { return 2 } +func (BABEOnDisabled) Index() uint { return 2 } // NextConfigData is the digest that contains changes to the BABE configuration. // It is potentially included in the first block of an epoch to describe the next epoch. @@ -110,7 +110,7 @@ type NextConfigData struct { } // Index Returns VDT index -func (d NextConfigData) Index() uint { return 3 } +func (NextConfigData) Index() uint { return 3 } // ToConfigData returns the NextConfigData as ConfigData func (d *NextConfigData) ToConfigData() *ConfigData { diff --git a/dot/types/digest.go b/dot/types/digest.go index 4624d25837..2d156cca45 100644 --- a/dot/types/digest.go +++ b/dot/types/digest.go @@ -45,7 +45,7 @@ type ChangesTrieRootDigest struct { } // Index Returns VDT index -func (d ChangesTrieRootDigest) Index() uint { return 2 } +func (ChangesTrieRootDigest) Index() uint { return 2 } // String returns the digest as a string func (d *ChangesTrieRootDigest) String() string { @@ -59,7 +59,7 @@ type PreRuntimeDigest struct { } // Index Returns VDT index -func (d PreRuntimeDigest) Index() uint { return 6 } +func (PreRuntimeDigest) Index() uint { return 6 } // NewBABEPreRuntimeDigest returns a PreRuntimeDigest with the BABE consensus ID func NewBABEPreRuntimeDigest(data []byte) *PreRuntimeDigest { @@ -81,7 +81,7 @@ type ConsensusDigest struct { } // Index Returns VDT index -func (d ConsensusDigest) Index() uint { return 4 } +func (ConsensusDigest) Index() uint { return 4 } // String returns the digest as a string func (d ConsensusDigest) String() string { @@ -95,7 +95,7 @@ type SealDigest struct { } // Index Returns VDT index -func (d SealDigest) Index() uint { return 5 } +func (SealDigest) Index() uint { return 5 } // String returns the digest as a string func (d *SealDigest) String() string { diff --git a/dot/types/inherents.go b/dot/types/inherents.go index d175add39d..9c199a03c7 100644 --- a/dot/types/inherents.go +++ b/dot/types/inherents.go @@ -74,7 +74,7 @@ func (d *InherentsData) Encode() ([]byte, error) { return nil, err } - _, err = buffer.Write(l[:]) + _, err = buffer.Write(l) if err != nil { return nil, err } diff --git a/internal/log/patch.go b/internal/log/patch.go index d28525776a..7d76bb2075 100644 --- a/internal/log/patch.go +++ b/internal/log/patch.go @@ -10,13 +10,13 @@ func (l *Logger) Patch(options ...Option) { l.mutex.Lock() defer l.mutex.Unlock() - l.patch(options...) + l.patchWithoutLocking(options...) for _, child := range l.childs { - child.patch(options...) + child.patchWithoutLocking(options...) } } -func (l *Logger) patch(options ...Option) { +func (l *Logger) patchWithoutLocking(options ...Option) { var updatedSettings settings updatedSettings.mergeWith(l.settings) updatedSettings.mergeWith(newSettings(options)) diff --git a/internal/log/patch_test.go b/internal/log/patch_test.go index baad6435db..5c1c0561c4 100644 --- a/internal/log/patch_test.go +++ b/internal/log/patch_test.go @@ -169,7 +169,7 @@ func Test_Logger_patch(t *testing.T) { logger := testCase.initialLogger - logger.patch(testCase.options...) + logger.patchWithoutLocking(testCase.options...) assert.Equal(t, testCase.expectedLogger, logger) }) diff --git a/internal/trie/node/branch_encode.go b/internal/trie/node/branch_encode.go index 58f8acd16d..b9b9dc61f7 100644 --- a/internal/trie/node/branch_encode.go +++ b/internal/trie/node/branch_encode.go @@ -56,7 +56,7 @@ func encodeChildrenOpportunisticParallel(children []*Node, buffer io.Writer) (er resultsCh := make(chan encodingAsyncResult, ChildrenCapacity) for i, child := range children { - if child == nil || child.Type() == Leaf { + if child == nil || child.Kind() == Leaf { runEncodeChild(child, i, resultsCh, nil) continue } @@ -153,12 +153,12 @@ func scaleEncodeHash(node *Node) (encoding []byte, err error) { err = hashNode(node, buffer) if err != nil { - return nil, fmt.Errorf("cannot hash %s: %w", node.Type(), err) + return nil, fmt.Errorf("cannot hash %s: %w", node.Kind(), err) } encoding, err = scale.Marshal(buffer.Bytes()) if err != nil { - return nil, fmt.Errorf("cannot scale encode hashed %s: %w", node.Type(), err) + return nil, fmt.Errorf("cannot scale encode hashed %s: %w", node.Kind(), err) } return encoding, nil @@ -171,14 +171,14 @@ func hashNode(node *Node, digestWriter io.Writer) (err error) { err = node.Encode(encodingBuffer) if err != nil { - return fmt.Errorf("cannot encode %s: %w", node.Type(), err) + return fmt.Errorf("cannot encode %s: %w", node.Kind(), err) } // if length of encoded leaf is less than 32 bytes, do not hash if encodingBuffer.Len() < 32 { _, err = digestWriter.Write(encodingBuffer.Bytes()) if err != nil { - return fmt.Errorf("cannot write encoded %s to buffer: %w", node.Type(), err) + return fmt.Errorf("cannot write encoded %s to buffer: %w", node.Kind(), err) } return nil } @@ -191,12 +191,12 @@ func hashNode(node *Node, digestWriter io.Writer) (err error) { // Note: using the sync.Pool's buffer is useful here. _, err = hasher.Write(encodingBuffer.Bytes()) if err != nil { - return fmt.Errorf("cannot hash encoding of %s: %w", node.Type(), err) + return fmt.Errorf("cannot hash encoding of %s: %w", node.Kind(), err) } _, err = digestWriter.Write(hasher.Sum(nil)) if err != nil { - return fmt.Errorf("cannot write hash sum of %s to buffer: %w", node.Type(), err) + return fmt.Errorf("cannot write hash sum of %s to buffer: %w", node.Kind(), err) } return nil } diff --git a/internal/trie/node/copy.go b/internal/trie/node/copy.go index ec2e5c0796..42697544af 100644 --- a/internal/trie/node/copy.go +++ b/internal/trie/node/copy.go @@ -62,7 +62,7 @@ func (n *Node) Copy(settings CopySettings) *Node { Descendants: n.Descendants, } - if n.Type() == Branch { + if n.Kind() == Branch { if settings.CopyChildren { // Copy all fields of children if we deep copy children childSettings := settings diff --git a/internal/trie/node/copy_test.go b/internal/trie/node/copy_test.go index 8816ee4604..664eefbfdc 100644 --- a/internal/trie/node/copy_test.go +++ b/internal/trie/node/copy_test.go @@ -160,7 +160,7 @@ func Test_Node_Copy(t *testing.T) { testForSliceModif(t, testCase.node.HashDigest, nodeCopy.HashDigest) testForSliceModif(t, testCase.node.Encoding, nodeCopy.Encoding) - if testCase.node.Type() == Branch { + if testCase.node.Kind() == Branch { testCase.node.Children[15] = &Node{Key: []byte("modified")} assert.NotEqual(t, nodeCopy.Children, testCase.node.Children) } diff --git a/internal/trie/node/dirty.go b/internal/trie/node/dirty.go index 3c703942b5..3acf9866ef 100644 --- a/internal/trie/node/dirty.go +++ b/internal/trie/node/dirty.go @@ -3,14 +3,17 @@ package node -// SetDirty sets the dirty status to the node. -func (n *Node) SetDirty(dirty bool) { - n.Dirty = dirty - if dirty { - // A node is marked dirty if its key or value is modified. - // This means its cached encoding and hash fields are no longer - // valid. To improve memory usage, we clear these fields. - n.Encoding = nil - n.HashDigest = nil - } +// SetDirty sets the dirty status to true for the node. +func (n *Node) SetDirty() { + n.Dirty = true + // A node is marked dirty if its key or value is modified. + // This means its cached encoding and hash fields are no longer + // valid. To improve memory usage, we clear these fields. + n.Encoding = nil + n.HashDigest = nil +} + +// SetClean sets the dirty status to false for the node. +func (n *Node) SetClean() { + n.Dirty = false } diff --git a/internal/trie/node/dirty_test.go b/internal/trie/node/dirty_test.go index 31d9339447..419adac6ca 100644 --- a/internal/trie/node/dirty_test.go +++ b/internal/trie/node/dirty_test.go @@ -14,46 +14,64 @@ func Test_Node_SetDirty(t *testing.T) { testCases := map[string]struct { node Node - dirty bool expected Node }{ - "not dirty to not dirty": { + "not dirty to dirty": { node: Node{ Encoding: []byte{1}, HashDigest: []byte{1}, }, - expected: Node{ - Encoding: []byte{1}, - HashDigest: []byte{1}, - }, + expected: Node{Dirty: true}, }, - "not dirty to dirty": { + "dirty to dirty": { node: Node{ Encoding: []byte{1}, HashDigest: []byte{1}, + Dirty: true, }, - dirty: true, expected: Node{Dirty: true}, }, - "dirty to not dirty": { + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + testCase.node.SetDirty() + + assert.Equal(t, testCase.expected, testCase.node) + }) + } +} + +func Test_Node_SetClean(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + node Node + expected Node + }{ + "not dirty to not dirty": { node: Node{ Encoding: []byte{1}, HashDigest: []byte{1}, - Dirty: true, }, expected: Node{ Encoding: []byte{1}, HashDigest: []byte{1}, }, }, - "dirty to dirty": { + "dirty to not dirty": { node: Node{ Encoding: []byte{1}, HashDigest: []byte{1}, Dirty: true, }, - dirty: true, - expected: Node{Dirty: true}, + expected: Node{ + Encoding: []byte{1}, + HashDigest: []byte{1}, + }, }, } @@ -62,7 +80,7 @@ func Test_Node_SetDirty(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - testCase.node.SetDirty(testCase.dirty) + testCase.node.SetClean() assert.Equal(t, testCase.expected, testCase.node) }) diff --git a/internal/trie/node/encode.go b/internal/trie/node/encode.go index c605da7302..7d2fd86686 100644 --- a/internal/trie/node/encode.go +++ b/internal/trie/node/encode.go @@ -35,7 +35,7 @@ func (n *Node) Encode(buffer Buffer) (err error) { return fmt.Errorf("cannot write LE key to buffer: %w", err) } - if n.Type() == Branch { + if n.Kind() == Branch { childrenBitmap := common.Uint16ToBytes(n.ChildrenBitmap()) _, err = buffer.Write(childrenBitmap) if err != nil { @@ -57,14 +57,14 @@ func (n *Node) Encode(buffer Buffer) (err error) { } } - if n.Type() == Branch { + if n.Kind() == Branch { err = encodeChildrenOpportunisticParallel(n.Children, buffer) if err != nil { return fmt.Errorf("cannot encode children of branch: %w", err) } } - if n.Type() == Leaf { + if n.Kind() == Leaf { // TODO cache this for branches too and update test cases. // TODO remove this copying since it defeats the purpose of `buffer` // and the sync.Pool. diff --git a/internal/trie/node/hash.go b/internal/trie/node/hash.go index 5c9ef06f1b..2230e49d2a 100644 --- a/internal/trie/node/hash.go +++ b/internal/trie/node/hash.go @@ -11,10 +11,8 @@ import ( ) // EncodeAndHash returns the encoding of the node and -// the blake2b hash digest of the encoding of the node. -// If the encoding is less than 32 bytes, the hash returned -// is the encoding and not the hash of the encoding. -func (n *Node) EncodeAndHash(isRoot bool) (encoding, hash []byte, err error) { +// the Merkle value of the node. +func (n *Node) EncodeAndHash() (encoding, hash []byte, err error) { if !n.Dirty && n.Encoding != nil && n.HashDigest != nil { return n.Encoding, n.HashDigest, nil } @@ -36,7 +34,7 @@ func (n *Node) EncodeAndHash(isRoot bool) (encoding, hash []byte, err error) { copy(n.Encoding, bufferBytes) encoding = n.Encoding // no need to copy - if !isRoot && buffer.Len() < 32 { + if buffer.Len() < 32 { n.HashDigest = make([]byte, len(bufferBytes)) copy(n.HashDigest, bufferBytes) hash = n.HashDigest // no need to copy @@ -53,3 +51,38 @@ func (n *Node) EncodeAndHash(isRoot bool) (encoding, hash []byte, err error) { return encoding, hash, nil } + +// EncodeAndHashRoot returns the encoding of the root node and +// the Merkle value of the root node (the hash of its encoding). +func (n *Node) EncodeAndHashRoot() (encoding, hash []byte, err error) { + if !n.Dirty && n.Encoding != nil && n.HashDigest != nil { + return n.Encoding, n.HashDigest, nil + } + + buffer := pools.EncodingBuffers.Get().(*bytes.Buffer) + buffer.Reset() + defer pools.EncodingBuffers.Put(buffer) + + err = n.Encode(buffer) + if err != nil { + return nil, nil, err + } + + bufferBytes := buffer.Bytes() + + // TODO remove this copying since it defeats the purpose of `buffer` + // and the sync.Pool. + n.Encoding = make([]byte, len(bufferBytes)) + copy(n.Encoding, bufferBytes) + encoding = n.Encoding // no need to copy + + // Note: using the sync.Pool's buffer is useful here. + hashArray, err := common.Blake2bHash(buffer.Bytes()) + if err != nil { + return nil, nil, err + } + n.HashDigest = hashArray[:] + hash = n.HashDigest // no need to copy + + return encoding, hash, nil +} diff --git a/internal/trie/node/hash_test.go b/internal/trie/node/hash_test.go index 703845d514..cae06b09de 100644 --- a/internal/trie/node/hash_test.go +++ b/internal/trie/node/hash_test.go @@ -17,7 +17,6 @@ func Test_Node_EncodeAndHash(t *testing.T) { expectedNode Node encoding []byte hash []byte - isRoot bool errWrapped error errMessage string }{ @@ -32,20 +31,6 @@ func Test_Node_EncodeAndHash(t *testing.T) { }, encoding: []byte{0x41, 0x1, 0x4, 0x2}, hash: []byte{0x41, 0x1, 0x4, 0x2}, - isRoot: false, - }, - "small leaf encoding for root node": { - node: Node{ - Key: []byte{1}, - Value: []byte{2}, - }, - expectedNode: Node{ - Encoding: []byte{0x41, 0x1, 0x4, 0x2}, - HashDigest: []byte{0x60, 0x51, 0x6d, 0xb, 0xb6, 0xe1, 0xbb, 0xfb, 0x12, 0x93, 0xf1, 0xb2, 0x76, 0xea, 0x95, 0x5, 0xe9, 0xf4, 0xa4, 0xe7, 0xd9, 0x8f, 0x62, 0xd, 0x5, 0x11, 0x5e, 0xb, 0x85, 0x27, 0x4a, 0xe1}, //nolint: lll - }, - encoding: []byte{0x41, 0x1, 0x4, 0x2}, - hash: []byte{0x60, 0x51, 0x6d, 0xb, 0xb6, 0xe1, 0xbb, 0xfb, 0x12, 0x93, 0xf1, 0xb2, 0x76, 0xea, 0x95, 0x5, 0xe9, 0xf4, 0xa4, 0xe7, 0xd9, 0x8f, 0x62, 0xd, 0x5, 0x11, 0x5e, 0xb, 0x85, 0x27, 0x4a, 0xe1}, // nolint: lll - isRoot: true, }, "leaf dirty with precomputed encoding and hash": { node: Node{ @@ -61,7 +46,6 @@ func Test_Node_EncodeAndHash(t *testing.T) { }, encoding: []byte{0x41, 0x1, 0x4, 0x2}, hash: []byte{0x41, 0x1, 0x4, 0x2}, - isRoot: false, }, "leaf not dirty with precomputed encoding and hash": { node: Node{ @@ -79,7 +63,6 @@ func Test_Node_EncodeAndHash(t *testing.T) { }, encoding: []byte{3}, hash: []byte{4}, - isRoot: false, }, "large leaf encoding": { node: Node{ @@ -92,7 +75,6 @@ func Test_Node_EncodeAndHash(t *testing.T) { }, encoding: []byte{0x7f, 0x2, 0x7, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x4, 0x1}, //nolint:lll hash: []byte{0xd2, 0x1d, 0x43, 0x7, 0x18, 0x17, 0x1b, 0xf1, 0x45, 0x9c, 0xe5, 0x8f, 0xd7, 0x79, 0x82, 0xb, 0xc8, 0x5c, 0x8, 0x47, 0xfe, 0x6c, 0x99, 0xc5, 0xe9, 0x57, 0x87, 0x7, 0x1d, 0x2e, 0x24, 0x5d}, //nolint:lll - isRoot: false, }, "empty branch": { node: Node{ @@ -105,7 +87,6 @@ func Test_Node_EncodeAndHash(t *testing.T) { }, encoding: []byte{0x80, 0x0, 0x0}, hash: []byte{0x80, 0x0, 0x0}, - isRoot: false, }, "small branch encoding": { node: Node{ @@ -120,22 +101,6 @@ func Test_Node_EncodeAndHash(t *testing.T) { }, encoding: []byte{0xc1, 0x1, 0x0, 0x0, 0x4, 0x2}, hash: []byte{0xc1, 0x1, 0x0, 0x0, 0x4, 0x2}, - isRoot: false, - }, - "small branch encoding for root node": { - node: Node{ - Children: make([]*Node, ChildrenCapacity), - Key: []byte{1}, - Value: []byte{2}, - }, - expectedNode: Node{ - Children: make([]*Node, ChildrenCapacity), - Encoding: []byte{0xc1, 0x1, 0x0, 0x0, 0x4, 0x2}, - HashDigest: []byte{0x48, 0x3c, 0xf6, 0x87, 0xcc, 0x5a, 0x60, 0x42, 0xd3, 0xcf, 0xa6, 0x91, 0xe6, 0x88, 0xfb, 0xdc, 0x1b, 0x38, 0x39, 0x5d, 0x6, 0x0, 0xbf, 0xc3, 0xb, 0x4b, 0x5d, 0x6a, 0x37, 0xd9, 0xc5, 0x1c}, // nolint: lll - }, - encoding: []byte{0xc1, 0x1, 0x0, 0x0, 0x4, 0x2}, - hash: []byte{0x48, 0x3c, 0xf6, 0x87, 0xcc, 0x5a, 0x60, 0x42, 0xd3, 0xcf, 0xa6, 0x91, 0xe6, 0x88, 0xfb, 0xdc, 0x1b, 0x38, 0x39, 0x5d, 0x6, 0x0, 0xbf, 0xc3, 0xb, 0x4b, 0x5d, 0x6a, 0x37, 0xd9, 0xc5, 0x1c}, // nolint: lll - isRoot: true, }, "branch dirty with precomputed encoding and hash": { node: Node{ @@ -153,7 +118,6 @@ func Test_Node_EncodeAndHash(t *testing.T) { }, encoding: []byte{0xc1, 0x1, 0x0, 0x0, 0x4, 0x2}, hash: []byte{0xc1, 0x1, 0x0, 0x0, 0x4, 0x2}, - isRoot: false, }, "branch not dirty with precomputed encoding and hash": { node: Node{ @@ -173,7 +137,6 @@ func Test_Node_EncodeAndHash(t *testing.T) { }, encoding: []byte{3}, hash: []byte{4}, - isRoot: false, }, "large branch encoding": { node: Node{ @@ -187,7 +150,6 @@ func Test_Node_EncodeAndHash(t *testing.T) { }, encoding: []byte{0xbf, 0x2, 0x7, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x0, 0x0}, //nolint:lll hash: []byte{0x6b, 0xd8, 0xcc, 0xac, 0x71, 0x77, 0x44, 0x17, 0xfe, 0xe0, 0xde, 0xda, 0xd5, 0x97, 0x6e, 0x69, 0xeb, 0xe9, 0xdd, 0x80, 0x1d, 0x4b, 0x51, 0xf1, 0x5b, 0xf3, 0x4a, 0x93, 0x27, 0x32, 0x2c, 0xb0}, //nolint:lll - isRoot: false, }, } @@ -196,7 +158,80 @@ func Test_Node_EncodeAndHash(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - encoding, hash, err := testCase.node.EncodeAndHash(testCase.isRoot) + encoding, hash, err := testCase.node.EncodeAndHash() + + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } + assert.Equal(t, testCase.encoding, encoding) + assert.Equal(t, testCase.hash, hash) + }) + } +} + +func Test_Node_EncodeAndHashRoot(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + node Node + expectedNode Node + encoding []byte + hash []byte + errWrapped error + errMessage string + }{ + "leaf not dirty with precomputed encoding and hash": { + node: Node{ + Key: []byte{1}, + Value: []byte{2}, + Dirty: false, + Encoding: []byte{3}, + HashDigest: []byte{4}, + }, + expectedNode: Node{ + Key: []byte{1}, + Value: []byte{2}, + Encoding: []byte{3}, + HashDigest: []byte{4}, + }, + encoding: []byte{3}, + hash: []byte{4}, + }, + "small leaf encoding": { + node: Node{ + Key: []byte{1}, + Value: []byte{2}, + }, + expectedNode: Node{ + Encoding: []byte{0x41, 0x1, 0x4, 0x2}, + HashDigest: []byte{0x60, 0x51, 0x6d, 0xb, 0xb6, 0xe1, 0xbb, 0xfb, 0x12, 0x93, 0xf1, 0xb2, 0x76, 0xea, 0x95, 0x5, 0xe9, 0xf4, 0xa4, 0xe7, 0xd9, 0x8f, 0x62, 0xd, 0x5, 0x11, 0x5e, 0xb, 0x85, 0x27, 0x4a, 0xe1}, //nolint: lll + }, + encoding: []byte{0x41, 0x1, 0x4, 0x2}, + hash: []byte{0x60, 0x51, 0x6d, 0xb, 0xb6, 0xe1, 0xbb, 0xfb, 0x12, 0x93, 0xf1, 0xb2, 0x76, 0xea, 0x95, 0x5, 0xe9, 0xf4, 0xa4, 0xe7, 0xd9, 0x8f, 0x62, 0xd, 0x5, 0x11, 0x5e, 0xb, 0x85, 0x27, 0x4a, 0xe1}, // nolint: lll + }, + "small branch encoding": { + node: Node{ + Children: make([]*Node, ChildrenCapacity), + Key: []byte{1}, + Value: []byte{2}, + }, + expectedNode: Node{ + Children: make([]*Node, ChildrenCapacity), + Encoding: []byte{0xc1, 0x1, 0x0, 0x0, 0x4, 0x2}, + HashDigest: []byte{0x48, 0x3c, 0xf6, 0x87, 0xcc, 0x5a, 0x60, 0x42, 0xd3, 0xcf, 0xa6, 0x91, 0xe6, 0x88, 0xfb, 0xdc, 0x1b, 0x38, 0x39, 0x5d, 0x6, 0x0, 0xbf, 0xc3, 0xb, 0x4b, 0x5d, 0x6a, 0x37, 0xd9, 0xc5, 0x1c}, // nolint: lll + }, + encoding: []byte{0xc1, 0x1, 0x0, 0x0, 0x4, 0x2}, + hash: []byte{0x48, 0x3c, 0xf6, 0x87, 0xcc, 0x5a, 0x60, 0x42, 0xd3, 0xcf, 0xa6, 0x91, 0xe6, 0x88, 0xfb, 0xdc, 0x1b, 0x38, 0x39, 0x5d, 0x6, 0x0, 0xbf, 0xc3, 0xb, 0x4b, 0x5d, 0x6a, 0x37, 0xd9, 0xc5, 0x1c}, // nolint: lll + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + encoding, hash, err := testCase.node.EncodeAndHashRoot() assert.ErrorIs(t, err, testCase.errWrapped) if testCase.errWrapped != nil { diff --git a/internal/trie/node/header.go b/internal/trie/node/header.go index 033c5e84e7..aa069c8fc1 100644 --- a/internal/trie/node/header.go +++ b/internal/trie/node/header.go @@ -18,7 +18,7 @@ func encodeHeader(node *Node, writer io.Writer) (err error) { // Merge variant byte and partial key length together var variant variant - if node.Type() == Leaf { + if node.Kind() == Leaf { variant = leafVariant } else if node.Value == nil { variant = branchVariant diff --git a/internal/trie/node/node.go b/internal/trie/node/node.go index a40cf31fd7..88b099d815 100644 --- a/internal/trie/node/node.go +++ b/internal/trie/node/node.go @@ -41,9 +41,9 @@ type Node struct { Descendants uint32 } -// Type returns Leaf or Branch depending on what type +// Kind returns Leaf or Branch depending on what kind // the node is. -func (n *Node) Type() Type { +func (n *Node) Kind() Kind { if n.Children != nil { return Branch } @@ -57,7 +57,7 @@ func (n *Node) String() string { // StringNode returns a gotree compatible node for String methods. func (n Node) StringNode() (stringNode *gotree.Node) { caser := cases.Title(language.BritishEnglish) - stringNode = gotree.New(caser.String(n.Type().String())) + stringNode = gotree.New(caser.String(n.Kind().String())) stringNode.Appendf("Generation: %d", n.Generation) stringNode.Appendf("Dirty: %t", n.Dirty) stringNode.Appendf("Key: " + bytesToString(n.Key)) diff --git a/internal/trie/node/types.go b/internal/trie/node/types.go index ea4a9be136..702dae2ff2 100644 --- a/internal/trie/node/types.go +++ b/internal/trie/node/types.go @@ -5,23 +5,23 @@ package node import "fmt" -// Type is the type of the node. -type Type byte +// Kind is the type of the node. +type Kind byte const ( - // Leaf type for leaf nodes. - Leaf Type = iota - // Branch type for branches (with or without value). + // Leaf kind for leaf nodes. + Leaf Kind = iota + // Branch kind for branches (with or without value). Branch ) -func (t Type) String() string { - switch t { +func (k Kind) String() string { + switch k { case Leaf: return "leaf" case Branch: return "branch" default: - panic(fmt.Sprintf("invalid node type: %d", t)) + panic(fmt.Sprintf("invalid node type: %d", k)) } } diff --git a/lib/babe/epoch_handler.go b/lib/babe/epoch_handler.go index b2cde8af8c..52e75ba3dd 100644 --- a/lib/babe/epoch_handler.go +++ b/lib/babe/epoch_handler.go @@ -91,7 +91,7 @@ func (h *epochHandler) run(ctx context.Context, errCh chan<- error) { } startTime := getSlotStartTime(authoringSlot, h.constants.slotDuration) - waitTime := startTime.Sub(time.Now()) + waitTime := time.Until(startTime) timer := time.NewTimer(waitTime) slotTimeTimers = append(slotTimeTimers, &slotWithTimer{ diff --git a/lib/babe/errors.go b/lib/babe/errors.go index 3f50917477..670f85dcf1 100644 --- a/lib/babe/errors.go +++ b/lib/babe/errors.go @@ -73,14 +73,7 @@ var ( errChannelClosed = errors.New("block notifier channel was closed") errOverPrimarySlotThreshold = errors.New("cannot claim slot, over primary threshold") errNotOurTurnToPropose = errors.New("cannot claim slot, not our turn to propose a block") - errGetEpochData = errors.New("get epochData error") - errFailedFinalisation = errors.New("failed to check finalisation") - errMissingDigest = errors.New("chain head missing digest") - errSetFirstSlot = errors.New("set first slot error") - errGetEpoch = errors.New("get epoch error") - errSkipVerify = errors.New("skipVerify error") errMissingDigestItems = errors.New("block header is missing digest items") - errDescendant = errors.New("descendant err") errServicePaused = errors.New("service paused") errInvalidSlotTechnique = errors.New("invalid slot claiming technique") errNoBABEAuthorityKeyProvided = errors.New("cannot create BABE service as authority; no keypair provided") @@ -140,19 +133,19 @@ func (e UnmarshalError) Error() string { type Other string // Index Returns VDT index -func (err Other) Index() uint { return 0 } +func (Other) Index() uint { return 0 } // CannotLookup Failed to lookup some data type CannotLookup struct{} // Index Returns VDT index -func (err CannotLookup) Index() uint { return 1 } +func (CannotLookup) Index() uint { return 1 } // BadOrigin A bad origin type BadOrigin struct{} // Index Returns VDT index -func (err BadOrigin) Index() uint { return 2 } +func (BadOrigin) Index() uint { return 2 } // Module A custom error in a module type Module struct { @@ -162,7 +155,7 @@ type Module struct { } // Index Returns VDT index -func (err Module) Index() uint { return 3 } +func (Module) Index() uint { return 3 } func (err Module) string() string { return fmt.Sprintf("index: %d code: %d message: %x", err.Idx, err.Err, *err.Message) @@ -172,79 +165,79 @@ func (err Module) string() string { type ValidityCannotLookup struct{} // Index Returns VDT index -func (err ValidityCannotLookup) Index() uint { return 0 } +func (ValidityCannotLookup) Index() uint { return 0 } // NoUnsignedValidator No validator found for the given unsigned transaction type NoUnsignedValidator struct{} // Index Returns VDT index -func (err NoUnsignedValidator) Index() uint { return 1 } +func (NoUnsignedValidator) Index() uint { return 1 } // UnknownCustom Any other custom unknown validity that is not covered type UnknownCustom uint8 // Index Returns VDT index -func (err UnknownCustom) Index() uint { return 2 } +func (UnknownCustom) Index() uint { return 2 } // Call The call of the transaction is not expected type Call struct{} // Index Returns VDT index -func (err Call) Index() uint { return 0 } +func (Call) Index() uint { return 0 } // Payment General error to do with the inability to pay some fees (e.g. account balance too low) type Payment struct{} // Index Returns VDT index -func (err Payment) Index() uint { return 1 } +func (Payment) Index() uint { return 1 } // Future General error to do with the transaction not yet being valid (e.g. nonce too high) type Future struct{} // Index Returns VDT index -func (err Future) Index() uint { return 2 } +func (Future) Index() uint { return 2 } // Stale General error to do with the transaction being outdated (e.g. nonce too low) type Stale struct{} // Index Returns VDT index -func (err Stale) Index() uint { return 3 } +func (Stale) Index() uint { return 3 } // BadProof General error to do with the transaction’s proofs (e.g. signature) type BadProof struct{} // Index Returns VDT index -func (err BadProof) Index() uint { return 4 } +func (BadProof) Index() uint { return 4 } // AncientBirthBlock The transaction birth block is ancient type AncientBirthBlock struct{} // Index Returns VDT index -func (err AncientBirthBlock) Index() uint { return 5 } +func (AncientBirthBlock) Index() uint { return 5 } // ExhaustsResources The transaction would exhaust the resources of current block type ExhaustsResources struct{} // Index Returns VDT index -func (err ExhaustsResources) Index() uint { return 6 } +func (ExhaustsResources) Index() uint { return 6 } // InvalidCustom Any other custom invalid validity that is not covered type InvalidCustom uint8 // Index Returns VDT index -func (err InvalidCustom) Index() uint { return 7 } +func (InvalidCustom) Index() uint { return 7 } // BadMandatory An extrinsic with a Mandatory dispatch resulted in Error type BadMandatory struct{} // Index Returns VDT index -func (err BadMandatory) Index() uint { return 8 } +func (BadMandatory) Index() uint { return 8 } // MandatoryDispatch A transaction with a mandatory dispatch type MandatoryDispatch struct{} // Index Returns VDT index -func (err MandatoryDispatch) Index() uint { return 9 } +func (MandatoryDispatch) Index() uint { return 9 } func determineErrType(vdt scale.VaryingDataType) error { switch val := vdt.Value().(type) { diff --git a/lib/babe/verify_test.go b/lib/babe/verify_test.go index 9b2adf9e8b..3d9c229f86 100644 --- a/lib/babe/verify_test.go +++ b/lib/babe/verify_test.go @@ -925,27 +925,32 @@ func TestVerificationManager_VerifyBlock(t *testing.T) { mockEpochStateNilBlockStateErr := NewMockEpochState(ctrl) mockEpochStateVerifyAuthorshipErr := NewMockEpochState(ctrl) - mockBlockStateCheckFinErr.EXPECT().NumberIsFinalised(uint(1)).Return(false, errFailedFinalisation) + errTestNumberIsFinalised := errors.New("test number is finalised error") + mockBlockStateCheckFinErr.EXPECT().NumberIsFinalised(uint(1)).Return(false, errTestNumberIsFinalised) mockBlockStateNotFinal.EXPECT().NumberIsFinalised(uint(1)).Return(false, nil) mockBlockStateNotFinal2.EXPECT().NumberIsFinalised(uint(1)).Return(false, nil) - mockEpochStateSetSlotErr.EXPECT().SetFirstSlot(uint64(1)).Return(errSetFirstSlot) + errTestSetFirstSlot := errors.New("test set first slot error") + mockEpochStateSetSlotErr.EXPECT().SetFirstSlot(uint64(1)).Return(errTestSetFirstSlot) + errTestGetEpoch := errors.New("test get epoch error") mockEpochStateGetEpochErr.EXPECT().GetEpochForBlock(testBlockHeaderEmpty). - Return(uint64(0), errGetEpoch) + Return(uint64(0), errTestGetEpoch) mockEpochStateSkipVerifyErr.EXPECT().GetEpochForBlock(testBlockHeaderEmpty).Return(uint64(1), nil) - mockEpochStateSkipVerifyErr.EXPECT().GetEpochData(uint64(1), testBlockHeaderEmpty).Return(nil, errGetEpochData) - mockEpochStateSkipVerifyErr.EXPECT().SkipVerify(testBlockHeaderEmpty).Return(false, errSkipVerify) + errTestGetEpochData := errors.New("test get epoch data error") + mockEpochStateSkipVerifyErr.EXPECT().GetEpochData(uint64(1), testBlockHeaderEmpty).Return(nil, errTestGetEpochData) + errTestSkipVerify := errors.New("test skip verify error") + mockEpochStateSkipVerifyErr.EXPECT().SkipVerify(testBlockHeaderEmpty).Return(false, errTestSkipVerify) mockEpochStateSkipVerifyTrue.EXPECT().GetEpochForBlock(testBlockHeaderEmpty).Return(uint64(1), nil) - mockEpochStateSkipVerifyTrue.EXPECT().GetEpochData(uint64(1), testBlockHeaderEmpty).Return(nil, errGetEpochData) + mockEpochStateSkipVerifyTrue.EXPECT().GetEpochData(uint64(1), testBlockHeaderEmpty).Return(nil, errTestGetEpochData) mockEpochStateSkipVerifyTrue.EXPECT().SkipVerify(testBlockHeaderEmpty).Return(true, nil) mockEpochStateGetVerifierInfoErr.EXPECT().GetEpochForBlock(testBlockHeaderEmpty).Return(uint64(1), nil) mockEpochStateGetVerifierInfoErr.EXPECT().GetEpochData(uint64(1), testBlockHeaderEmpty). - Return(nil, errGetEpochData) + Return(nil, errTestGetEpochData) mockEpochStateGetVerifierInfoErr.EXPECT().SkipVerify(testBlockHeaderEmpty).Return(false, nil) mockEpochStateNilBlockStateErr.EXPECT().GetEpochForBlock(testBlockHeaderEmpty).Return(uint64(1), nil) @@ -1007,31 +1012,31 @@ func TestVerificationManager_VerifyBlock(t *testing.T) { name: "fail to check block 1 finalisation", vm: vm0, header: block1Header, - expErr: fmt.Errorf("failed to check if block 1 is finalised: %w", errFailedFinalisation), + expErr: fmt.Errorf("failed to check if block 1 is finalised: %w", errTestNumberIsFinalised), }, { name: "get slot from header error", vm: vm1, header: block1Header, - expErr: fmt.Errorf("failed to get slot from block 1: %w", errMissingDigest), + expErr: fmt.Errorf("failed to get slot from block 1: %w", types.ErrChainHeadMissingDigest), }, { name: "set first slot error", vm: vm2, header: block1Header2, - expErr: fmt.Errorf("failed to set current epoch after receiving block 1: %w", errSetFirstSlot), + expErr: fmt.Errorf("failed to set current epoch after receiving block 1: %w", errTestSetFirstSlot), }, { name: "get epoch error", vm: vm3, header: testBlockHeaderEmpty, - expErr: fmt.Errorf("failed to get epoch for block header: %w", errGetEpoch), + expErr: fmt.Errorf("failed to get epoch for block header: %w", errTestGetEpoch), }, { name: "skip verify err", vm: vm4, header: testBlockHeaderEmpty, - expErr: fmt.Errorf("failed to check if verification can be skipped: %w", errSkipVerify), + expErr: fmt.Errorf("failed to check if verification can be skipped: %w", errTestSkipVerify), }, { name: "skip verify true", @@ -1043,7 +1048,7 @@ func TestVerificationManager_VerifyBlock(t *testing.T) { vm: vm6, header: testBlockHeaderEmpty, expErr: fmt.Errorf("failed to get verifier info for block 2: "+ - "failed to get epoch data for epoch 1: %w", errGetEpochData), + "failed to get epoch data for epoch 1: %w", errTestGetEpochData), }, { name: "nil blockState error", @@ -1093,17 +1098,20 @@ func TestVerificationManager_SetOnDisabled(t *testing.T) { mockEpochStateOk2 := NewMockEpochState(ctrl) mockEpochStateOk3 := NewMockEpochState(ctrl) - mockEpochStateGetEpochErr.EXPECT().GetEpochForBlock(types.NewEmptyHeader()).Return(uint64(0), errGetEpoch) + errTestGetEpoch := errors.New("test get epoch error") + mockEpochStateGetEpochErr.EXPECT().GetEpochForBlock(types.NewEmptyHeader()).Return(uint64(0), errTestGetEpoch) mockEpochStateGetEpochDataErr.EXPECT().GetEpochForBlock(types.NewEmptyHeader()).Return(uint64(0), nil) - mockEpochStateGetEpochDataErr.EXPECT().GetEpochData(uint64(0), types.NewEmptyHeader()).Return(nil, errGetEpochData) + errTestGetEpochData := errors.New("test get epoch data error") + mockEpochStateGetEpochDataErr.EXPECT().GetEpochData(uint64(0), types.NewEmptyHeader()).Return(nil, errTestGetEpochData) mockEpochStateIndexLenErr.EXPECT().GetEpochForBlock(types.NewEmptyHeader()).Return(uint64(2), nil) mockEpochStateSetDisabledProd.EXPECT().GetEpochForBlock(types.NewEmptyHeader()).Return(uint64(2), nil) mockEpochStateOk.EXPECT().GetEpochForBlock(types.NewEmptyHeader()).Return(uint64(2), nil) - mockBlockStateIsDescendantErr.EXPECT().IsDescendantOf(gomock.Any(), gomock.Any()).Return(false, errDescendant) + errTestDescendant := errors.New("test descendant error") + mockBlockStateIsDescendantErr.EXPECT().IsDescendantOf(gomock.Any(), gomock.Any()).Return(false, errTestDescendant) mockEpochStateOk2.EXPECT().GetEpochForBlock(testHeader).Return(uint64(2), nil) mockBlockStateAuthorityDisabled.EXPECT().IsDescendantOf(gomock.Any(), gomock.Any()).Return(true, nil) @@ -1173,7 +1181,7 @@ func TestVerificationManager_SetOnDisabled(t *testing.T) { args: args{ header: types.NewEmptyHeader(), }, - expErr: errGetEpoch, + expErr: errTestGetEpoch, }, { name: "get epoch data err", @@ -1181,7 +1189,7 @@ func TestVerificationManager_SetOnDisabled(t *testing.T) { args: args{ header: types.NewEmptyHeader(), }, - expErr: fmt.Errorf("failed to get epoch data for epoch %d: %w", 0, errGetEpochData), + expErr: fmt.Errorf("failed to get epoch data for epoch %d: %w", 0, errTestGetEpochData), }, { name: "index length error", @@ -1205,7 +1213,7 @@ func TestVerificationManager_SetOnDisabled(t *testing.T) { args: args{ header: types.NewEmptyHeader(), }, - expErr: errDescendant, + expErr: errTestDescendant, }, { name: "authority already disabled", diff --git a/lib/common/hasher.go b/lib/common/hasher.go index da6fdc7ad9..e9bc0732f2 100644 --- a/lib/common/hasher.go +++ b/lib/common/hasher.go @@ -85,10 +85,10 @@ func Twox64(in []byte) ([]byte, error) { } // Twox128Hash computes xxHash64 twice with seeds 0 and 1 applied on given byte array -func Twox128Hash(msg []byte) ([]byte, error) { +func Twox128Hash(msg []byte) (result []byte, err error) { // compute xxHash64 twice with seeds 0 and 1 applied on given byte array h0 := xxhash.NewS64(0) // create xxHash with 0 seed - _, err := h0.Write(msg) + _, err = h0.Write(msg) if err != nil { return nil, err } @@ -105,9 +105,11 @@ func Twox128Hash(msg []byte) ([]byte, error) { hash1 := make([]byte, 8) binary.LittleEndian.PutUint64(hash1, res1) - //concatenated result - both := append(hash0, hash1...) - return both, nil + result = make([]byte, 16) + copy(result[:8], hash0) + copy(result[8:], hash1) + + return result, nil } // Twox256 returns the twox256 hash of the input data diff --git a/lib/crypto/ed25519/ed25519.go b/lib/crypto/ed25519/ed25519.go index 28ab6c34d2..f24967e2b2 100644 --- a/lib/crypto/ed25519/ed25519.go +++ b/lib/crypto/ed25519/ed25519.go @@ -177,7 +177,7 @@ func Verify(pub *PublicKey, msg, sig []byte) (bool, error) { } // Type returns Ed25519Type -func (kp *Keypair) Type() crypto.KeyType { +func (*Keypair) Type() crypto.KeyType { return crypto.Ed25519Type } diff --git a/lib/crypto/secp256k1/secp256k1.go b/lib/crypto/secp256k1/secp256k1.go index dc4790ae8c..c9d5da73db 100644 --- a/lib/crypto/secp256k1/secp256k1.go +++ b/lib/crypto/secp256k1/secp256k1.go @@ -138,7 +138,7 @@ func GenerateKeypair() (*Keypair, error) { } // Type returns Secp256k1Type -func (kp *Keypair) Type() crypto.KeyType { +func (*Keypair) Type() crypto.KeyType { return crypto.Secp256k1Type } diff --git a/lib/crypto/sr25519/sr25519.go b/lib/crypto/sr25519/sr25519.go index 799f720d54..3bfe52bc7b 100644 --- a/lib/crypto/sr25519/sr25519.go +++ b/lib/crypto/sr25519/sr25519.go @@ -194,7 +194,7 @@ func NewPublicKey(in []byte) (*PublicKey, error) { } // Type returns Sr25519Type -func (kp *Keypair) Type() crypto.KeyType { +func (*Keypair) Type() crypto.KeyType { return crypto.Sr25519Type } diff --git a/lib/genesis/helpers.go b/lib/genesis/helpers.go index 7ff492e84b..08916f328b 100644 --- a/lib/genesis/helpers.go +++ b/lib/genesis/helpers.go @@ -8,7 +8,6 @@ import ( "encoding/json" "errors" "fmt" - "log" "math/big" "os" "path/filepath" @@ -689,7 +688,7 @@ func addAuthoritiesValues(k1, k2 string, kt crypto.KeyType, value []byte, gen *G } b := make([]byte, 8) if _, err = reader.Read(b); err != nil { - log.Fatal(err) + return fmt.Errorf("reading from buffer: %w", err) } var iv uint64 err = scale.Unmarshal(b, &iv) diff --git a/lib/genesis/test_utils.go b/lib/genesis/test_utils.go index a02eab8898..418d743d33 100644 --- a/lib/genesis/test_utils.go +++ b/lib/genesis/test_utils.go @@ -47,16 +47,6 @@ var TestGenesis = &Genesis{ BadBlocks: testBadBlocks, } -// TestFieldsHR instance of human-readable Fields struct for testing, use with TestGenesis -var TestFieldsHR = Fields{ - Raw: map[string]map[string]string{}, - Runtime: map[string]map[string]interface{}{ - "System": { - "code": "mocktestcode", - }, - }, -} - // TestFieldsRaw instance of raw Fields struct for testing use with TestGenesis var TestFieldsRaw = Fields{ Raw: map[string]map[string]string{ @@ -67,33 +57,20 @@ var TestFieldsRaw = Fields{ }, } -// CreateTestGenesisJSONFile utility to create mock test genesis JSON file -func CreateTestGenesisJSONFile(t *testing.T, asRaw bool) (filename string) { - tGen := &Genesis{ +// CreateTestGenesisJSONFile writes a genesis file using the fields given to +// the current test temporary directory. +func CreateTestGenesisJSONFile(t *testing.T, fields Fields) (filename string) { + rawGenesis := &Genesis{ Name: "test", ID: "", Bootnodes: nil, ProtocolID: "", - Genesis: Fields{}, + Genesis: fields, } - - if asRaw { - tGen.Genesis = Fields{ - Raw: map[string]map[string]string{}, - Runtime: map[string]map[string]interface{}{ - "System": { - "code": "mocktestcode", - }, - }, - } - } else { - tGen.Genesis = TestFieldsHR - } - - bz, err := json.Marshal(tGen) + jsonData, err := json.Marshal(rawGenesis) require.NoError(t, err) filename = filepath.Join(t.TempDir(), "genesis-test") - err = os.WriteFile(filename, bz, os.ModePerm) + err = os.WriteFile(filename, jsonData, os.ModePerm) require.NoError(t, err) return filename } diff --git a/lib/grandpa/message.go b/lib/grandpa/message.go index ce24d05df3..3a941f2e77 100644 --- a/lib/grandpa/message.go +++ b/lib/grandpa/message.go @@ -57,7 +57,7 @@ type VoteMessage struct { } // Index Returns VDT index -func (v VoteMessage) Index() uint { return 0 } +func (VoteMessage) Index() uint { return 0 } // ToConsensusMessage converts the VoteMessage into a network-level consensus message func (v *VoteMessage) ToConsensusMessage() (*ConsensusMessage, error) { @@ -86,7 +86,7 @@ type NeighbourMessage struct { } // Index Returns VDT index -func (m NeighbourMessage) Index() uint { return 2 } +func (NeighbourMessage) Index() uint { return 2 } // ToConsensusMessage converts the NeighbourMessage into a network-level consensus message func (m *NeighbourMessage) ToConsensusMessage() (*network.ConsensusMessage, error) { @@ -137,7 +137,7 @@ func (s *Service) newCommitMessage(header *types.Header, round uint64) (*CommitM } // Index Returns VDT index -func (f CommitMessage) Index() uint { return 1 } +func (CommitMessage) Index() uint { return 1 } // ToConsensusMessage converts the CommitMessage into a network-level consensus message func (f *CommitMessage) ToConsensusMessage() (*ConsensusMessage, error) { @@ -203,7 +203,7 @@ func newCatchUpRequest(round, setID uint64) *CatchUpRequest { } // Index Returns VDT index -func (r CatchUpRequest) Index() uint { return 3 } +func (CatchUpRequest) Index() uint { return 3 } // ToConsensusMessage converts the catchUpRequest into a network-level consensus message func (r *CatchUpRequest) ToConsensusMessage() (*ConsensusMessage, error) { @@ -260,7 +260,7 @@ func (s *Service) newCatchUpResponse(round, setID uint64) (*CatchUpResponse, err } // Index Returns VDT index -func (r CatchUpResponse) Index() uint { return 4 } +func (CatchUpResponse) Index() uint { return 4 } // ToConsensusMessage converts the catchUpResponse into a network-level consensus message func (r *CatchUpResponse) ToConsensusMessage() (*ConsensusMessage, error) { diff --git a/lib/keystore/generic_keystore.go b/lib/keystore/generic_keystore.go index 9286f22bd6..76ed219ea2 100644 --- a/lib/keystore/generic_keystore.go +++ b/lib/keystore/generic_keystore.go @@ -35,7 +35,7 @@ func (ks *GenericKeystore) Name() Name { } // Type returns UnknownType since the keystore may contain keys of any type -func (ks *GenericKeystore) Type() crypto.KeyType { +func (*GenericKeystore) Type() crypto.KeyType { return crypto.UnknownType } diff --git a/lib/runtime/version.go b/lib/runtime/version.go index 4db1807eae..8b41dd0b54 100644 --- a/lib/runtime/version.go +++ b/lib/runtime/version.go @@ -82,7 +82,7 @@ func (lvd *LegacyVersionData) APIItems() []APIItem { } // TransactionVersion returns the transaction version -func (lvd *LegacyVersionData) TransactionVersion() uint32 { +func (*LegacyVersionData) TransactionVersion() uint32 { return 0 } diff --git a/lib/trie/child_storage.go b/lib/trie/child_storage.go index 9b49f95a4d..fc80da533f 100644 --- a/lib/trie/child_storage.go +++ b/lib/trie/child_storage.go @@ -23,7 +23,10 @@ func (t *Trie) PutChild(keyToChild []byte, child *Trie) error { if err != nil { return err } - key := append(ChildStorageKeyPrefix, keyToChild...) + + key := make([]byte, len(ChildStorageKeyPrefix)+len(keyToChild)) + copy(key, ChildStorageKeyPrefix) + copy(key[len(ChildStorageKeyPrefix):], keyToChild) t.Put(key, childHash.ToBytes()) t.childTries[childHash] = child @@ -32,7 +35,10 @@ func (t *Trie) PutChild(keyToChild []byte, child *Trie) error { // GetChild returns the child trie at key :child_storage:[keyToChild] func (t *Trie) GetChild(keyToChild []byte) (*Trie, error) { - key := append(ChildStorageKeyPrefix, keyToChild...) + key := make([]byte, len(ChildStorageKeyPrefix)+len(keyToChild)) + copy(key, ChildStorageKeyPrefix) + copy(key[len(ChildStorageKeyPrefix):], keyToChild) + childHash := t.Get(key) if childHash == nil { return nil, fmt.Errorf("%w at key 0x%x%x", ErrChildTrieDoesNotExist, ChildStorageKeyPrefix, keyToChild) @@ -83,7 +89,10 @@ func (t *Trie) GetFromChild(keyToChild, key []byte) ([]byte, error) { // DeleteChild deletes the child storage trie func (t *Trie) DeleteChild(keyToChild []byte) { - key := append(ChildStorageKeyPrefix, keyToChild...) + key := make([]byte, len(ChildStorageKeyPrefix)+len(keyToChild)) + copy(key, ChildStorageKeyPrefix) + copy(key[len(ChildStorageKeyPrefix):], keyToChild) + t.Delete(key) } diff --git a/lib/trie/database.go b/lib/trie/database.go index e4c2faa112..8723d9556b 100644 --- a/lib/trie/database.go +++ b/lib/trie/database.go @@ -32,7 +32,7 @@ func (t *Trie) Store(db chaindb.Database) error { } batch := db.NewBatch() - err := t.store(batch, t.root) + err := t.storeNode(batch, t.root) if err != nil { batch.Reset() return err @@ -41,12 +41,17 @@ func (t *Trie) Store(db chaindb.Database) error { return batch.Flush() } -func (t *Trie) store(db chaindb.Batch, n *Node) error { +func (t *Trie) storeNode(db chaindb.Batch, n *Node) (err error) { if n == nil { return nil } - encoding, hash, err := n.EncodeAndHash(n == t.root) + var encoding, hash []byte + if n == t.root { + encoding, hash, err = n.EncodeAndHashRoot() + } else { + encoding, hash, err = n.EncodeAndHash() + } if err != nil { return err } @@ -56,13 +61,13 @@ func (t *Trie) store(db chaindb.Batch, n *Node) error { return err } - if n.Type() == node.Branch { + if n.Kind() == node.Branch { for _, child := range n.Children { if child == nil { continue } - err = t.store(db, child) + err = t.storeNode(db, child) if err != nil { return err } @@ -70,7 +75,7 @@ func (t *Trie) store(db chaindb.Batch, n *Node) error { } if n.Dirty { - n.SetDirty(false) + n.SetClean() } return nil @@ -97,15 +102,15 @@ func (t *Trie) Load(db Database, rootHash common.Hash) error { } t.root = root - t.root.SetDirty(false) + t.root.SetClean() t.root.Encoding = encodedNode t.root.HashDigest = rootHashBytes - return t.load(db, t.root) + return t.loadNode(db, t.root) } -func (t *Trie) load(db Database, n *Node) error { - if n.Type() != node.Branch { +func (t *Trie) loadNode(db Database, n *Node) error { + if n.Kind() != node.Branch { return nil } @@ -120,11 +125,11 @@ func (t *Trie) load(db Database, n *Node) error { if len(hash) == 0 { // node has already been loaded inline // just set encoding + hash digest - _, _, err := child.EncodeAndHash(false) + _, _, err := child.EncodeAndHash() if err != nil { return err } - child.SetDirty(false) + child.SetClean() continue } @@ -139,17 +144,17 @@ func (t *Trie) load(db Database, n *Node) error { return fmt.Errorf("cannot decode node with hash 0x%x: %w", hash, err) } - decodedNode.SetDirty(false) + decodedNode.SetClean() decodedNode.Encoding = encodedNode decodedNode.HashDigest = hash branch.Children[i] = decodedNode - err = t.load(db, decodedNode) + err = t.loadNode(db, decodedNode) if err != nil { return fmt.Errorf("cannot load child at index %d with hash 0x%x: %w", i, hash, err) } - if decodedNode.Type() == node.Branch { + if decodedNode.Kind() == node.Branch { // Note 1: the node is fully loaded with all its descendants // count only after the database load above. // Note 2: direct child node is already counted as descendant @@ -183,7 +188,7 @@ func (t *Trie) load(db Database, n *Node) error { // PopulateNodeHashes writes hashes of each children of the node given // as keys to the map hashesSet. func (t *Trie) PopulateNodeHashes(n *Node, hashesSet map[common.Hash]struct{}) { - if n.Type() != node.Branch { + if n.Kind() != node.Branch { return } @@ -248,16 +253,16 @@ func GetFromDB(db chaindb.Database, rootHash common.Hash, key []byte) ( return nil, fmt.Errorf("cannot decode root node: %w", err) } - return getFromDB(db, rootNode, k) + return getFromDBAtNode(db, rootNode, k) } -// getFromDB recursively searches through the trie and database +// getFromDBAtNode recursively searches through the trie and database // for the value corresponding to a key. // Note it does not copy the value so modifying the value bytes // slice will modify the value of the node in the trie. -func getFromDB(db chaindb.Database, n *Node, key []byte) ( +func getFromDBAtNode(db chaindb.Database, n *Node, key []byte) ( value []byte, err error) { - if n.Type() == node.Leaf { + if n.Kind() == node.Leaf { if bytes.Equal(n.Key, key) { return n.Value, nil } @@ -287,8 +292,8 @@ func getFromDB(db chaindb.Database, n *Node, key []byte) ( // Child can be either inlined or a hash pointer. childHash := child.HashDigest - if len(childHash) == 0 && child.Type() == node.Leaf { - return getFromDB(db, child, key[commonPrefixLength+1:]) + if len(childHash) == 0 && child.Kind() == node.Leaf { + return getFromDBAtNode(db, child, key[commonPrefixLength+1:]) } encodedChild, err := db.Get(childHash) @@ -306,14 +311,14 @@ func getFromDB(db chaindb.Database, n *Node, key []byte) ( childHash, err) } - return getFromDB(db, decodedChild, key[commonPrefixLength+1:]) + return getFromDBAtNode(db, decodedChild, key[commonPrefixLength+1:]) // Note: do not wrap error since it's called recursively. } // WriteDirty writes all dirty nodes to the database and sets them to clean func (t *Trie) WriteDirty(db chaindb.Database) error { batch := db.NewBatch() - err := t.writeDirty(batch, t.root) + err := t.writeDirtyNode(batch, t.root) if err != nil { batch.Reset() return err @@ -322,12 +327,17 @@ func (t *Trie) WriteDirty(db chaindb.Database) error { return batch.Flush() } -func (t *Trie) writeDirty(db chaindb.Batch, n *Node) error { +func (t *Trie) writeDirtyNode(db chaindb.Batch, n *Node) (err error) { if n == nil || !n.Dirty { return nil } - encoding, hash, err := n.EncodeAndHash(n == t.root) + var encoding, hash []byte + if n == t.root { + encoding, hash, err = n.EncodeAndHashRoot() + } else { + encoding, hash, err = n.EncodeAndHash() + } if err != nil { return fmt.Errorf( "cannot encode and hash node with hash 0x%x: %w", @@ -341,8 +351,8 @@ func (t *Trie) writeDirty(db chaindb.Batch, n *Node) error { hash, err) } - if n.Type() != node.Branch { - n.SetDirty(false) + if n.Kind() != node.Branch { + n.SetClean() return nil } @@ -351,7 +361,7 @@ func (t *Trie) writeDirty(db chaindb.Batch, n *Node) error { continue } - err = t.writeDirty(db, child) + err = t.writeDirtyNode(db, child) if err != nil { // Note: do not wrap error since it's returned recursively. return err @@ -359,12 +369,12 @@ func (t *Trie) writeDirty(db chaindb.Batch, n *Node) error { } for _, childTrie := range t.childTries { - if err := childTrie.writeDirty(db, childTrie.root); err != nil { + if err := childTrie.writeDirtyNode(db, childTrie.root); err != nil { return fmt.Errorf("failed to write dirty node=0x%x to database: %w", childTrie.root.HashDigest, err) } } - n.SetDirty(false) + n.SetClean() return nil } @@ -375,19 +385,24 @@ func (t *Trie) writeDirty(db chaindb.Batch, n *Node) error { // We need to compute the hash values of each newly inserted node. func (t *Trie) GetInsertedNodeHashes() (hashesSet map[common.Hash]struct{}, err error) { hashesSet = make(map[common.Hash]struct{}) - err = t.getInsertedNodeHashes(t.root, hashesSet) + err = t.getInsertedNodeHashesAtNode(t.root, hashesSet) if err != nil { return nil, err } return hashesSet, nil } -func (t *Trie) getInsertedNodeHashes(n *Node, hashes map[common.Hash]struct{}) (err error) { +func (t *Trie) getInsertedNodeHashesAtNode(n *Node, hashes map[common.Hash]struct{}) (err error) { if n == nil || !n.Dirty { return nil } - _, hash, err := n.EncodeAndHash(n == t.root) + var hash []byte + if n == t.root { + _, hash, err = n.EncodeAndHashRoot() + } else { + _, hash, err = n.EncodeAndHash() + } if err != nil { return fmt.Errorf( "cannot encode and hash node with hash 0x%x: %w", @@ -396,7 +411,7 @@ func (t *Trie) getInsertedNodeHashes(n *Node, hashes map[common.Hash]struct{}) ( hashes[common.BytesToHash(hash)] = struct{}{} - if n.Type() != node.Branch { + if n.Kind() != node.Branch { return nil } @@ -405,7 +420,7 @@ func (t *Trie) getInsertedNodeHashes(n *Node, hashes map[common.Hash]struct{}) ( continue } - err := t.getInsertedNodeHashes(child, hashes) + err := t.getInsertedNodeHashesAtNode(child, hashes) if err != nil { // Note: do not wrap error since this is called recursively. return err diff --git a/lib/trie/proof/generate.go b/lib/trie/proof/generate.go index 9ab7c1dc84..f82f0e7f47 100644 --- a/lib/trie/proof/generate.go +++ b/lib/trie/proof/generate.go @@ -39,8 +39,7 @@ func Generate(rootHash []byte, fullKeys [][]byte, database Database) ( hashesSeen := make(map[string]struct{}) for _, fullKey := range fullKeys { fullKeyNibbles := codec.KeyLEToNibbles(fullKey) - const isRoot = true - newEncodedProofNodes, err := walk(rootNode, fullKeyNibbles, isRoot) + newEncodedProofNodes, err := walkRoot(rootNode, fullKeyNibbles) if err != nil { // Note we wrap the full key context here since walk is recursive and // may not be aware of the initial full key. @@ -67,7 +66,52 @@ func Generate(rootHash []byte, fullKeys [][]byte, database Database) ( return encodedProofNodes, nil } -func walk(parent *node.Node, fullKey []byte, isRoot bool) ( +func walkRoot(root *node.Node, fullKey []byte) ( + encodedProofNodes [][]byte, err error) { + if root == nil { + if len(fullKey) == 0 { + return nil, nil + } + return nil, ErrKeyNotFound + } + + // Note we do not use sync.Pool buffers since we would have + // to copy it so it persists in encodedProofNodes. + encodingBuffer := bytes.NewBuffer(nil) + err = root.Encode(encodingBuffer) + if err != nil { + return nil, fmt.Errorf("encode node: %w", err) + } + encodedProofNodes = append(encodedProofNodes, encodingBuffer.Bytes()) + + nodeFound := len(fullKey) == 0 || bytes.Equal(root.Key, fullKey) + if nodeFound { + return encodedProofNodes, nil + } + + if root.Kind() == node.Leaf && !nodeFound { + return nil, ErrKeyNotFound + } + + nodeIsDeeper := len(fullKey) > len(root.Key) + if !nodeIsDeeper { + return nil, ErrKeyNotFound + } + + commonLength := lenCommonPrefix(root.Key, fullKey) + childIndex := fullKey[commonLength] + nextChild := root.Children[childIndex] + nextFullKey := fullKey[commonLength+1:] + deeperEncodedProofNodes, err := walk(nextChild, nextFullKey) + if err != nil { + return nil, err // note: do not wrap since this is recursive + } + + encodedProofNodes = append(encodedProofNodes, deeperEncodedProofNodes...) + return encodedProofNodes, nil +} + +func walk(parent *node.Node, fullKey []byte) ( encodedProofNodes [][]byte, err error) { if parent == nil { if len(fullKey) == 0 { @@ -84,9 +128,8 @@ func walk(parent *node.Node, fullKey []byte, isRoot bool) ( return nil, fmt.Errorf("encode node: %w", err) } - if isRoot || encodingBuffer.Len() >= 32 { - // Only add the root node encoding (whatever its length) - // and child node encodings greater or equal to 32 bytes. + if encodingBuffer.Len() >= 32 { + // Only add (non root) node encodings greater or equal to 32 bytes. // This is because child node encodings of less than 32 bytes // are inlined in the parent node encoding, so there is no need // to duplicate them in the proof generated. @@ -98,7 +141,7 @@ func walk(parent *node.Node, fullKey []byte, isRoot bool) ( return encodedProofNodes, nil } - if parent.Type() == node.Leaf && !nodeFound { + if parent.Kind() == node.Leaf && !nodeFound { return nil, ErrKeyNotFound } @@ -111,8 +154,7 @@ func walk(parent *node.Node, fullKey []byte, isRoot bool) ( childIndex := fullKey[commonLength] nextChild := parent.Children[childIndex] nextFullKey := fullKey[commonLength+1:] - isRoot = false - deeperEncodedProofNodes, err := walk(nextChild, nextFullKey, isRoot) + deeperEncodedProofNodes, err := walk(nextChild, nextFullKey) if err != nil { return nil, err // note: do not wrap since this is recursive } diff --git a/lib/trie/proof/generate_test.go b/lib/trie/proof/generate_test.go index 9eaffde1bf..a354962a01 100644 --- a/lib/trie/proof/generate_test.go +++ b/lib/trie/proof/generate_test.go @@ -266,7 +266,7 @@ func Test_Generate(t *testing.T) { } } -func Test_walk(t *testing.T) { +func Test_walkRoot(t *testing.T) { t.Parallel() largeValue := generateBytes(t, 40) @@ -275,7 +275,6 @@ func Test_walk(t *testing.T) { testCases := map[string]struct { parent *node.Node fullKey []byte // nibbles - isRoot bool encodedProofNodes [][]byte errWrapped error errMessage string @@ -293,7 +292,6 @@ func Test_walk(t *testing.T) { Key: []byte{1, 2}, Value: []byte{1}, }, - isRoot: true, encodedProofNodes: [][]byte{encodeNode(t, node.Node{ Key: []byte{1, 2}, Value: []byte{1}, @@ -337,7 +335,6 @@ func Test_walk(t *testing.T) { }, }), }, - isRoot: true, encodedProofNodes: [][]byte{ encodeNode(t, node.Node{ Key: []byte{1, 2}, @@ -393,7 +390,6 @@ func Test_walk(t *testing.T) { }), }, fullKey: []byte{1, 2}, - isRoot: true, encodedProofNodes: [][]byte{ encodeNode(t, node.Node{ Key: []byte{1, 2}, @@ -419,7 +415,6 @@ func Test_walk(t *testing.T) { }), }, fullKey: []byte{1, 2, 0, 1, 2}, - isRoot: true, encodedProofNodes: [][]byte{ encodeNode(t, node.Node{ Key: []byte{1, 2}, @@ -447,7 +442,6 @@ func Test_walk(t *testing.T) { }), }, fullKey: []byte{1, 2, 0, 1, 2}, - isRoot: true, encodedProofNodes: [][]byte{ encodeNode(t, node.Node{ Key: []byte{1, 2}, @@ -492,7 +486,6 @@ func Test_walk(t *testing.T) { }), }, fullKey: []byte{1, 2, 0x04}, - isRoot: true, encodedProofNodes: [][]byte{ encodeNode(t, node.Node{ Key: []byte{1, 2}, @@ -513,7 +506,258 @@ func Test_walk(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - encodedProofNodes, err := walk(testCase.parent, testCase.fullKey, testCase.isRoot) + encodedProofNodes, err := walkRoot(testCase.parent, testCase.fullKey) + + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } + assert.Equal(t, testCase.encodedProofNodes, encodedProofNodes) + }) + } +} + +func Test_walk(t *testing.T) { + t.Parallel() + + largeValue := generateBytes(t, 40) + assertLongEncoding(t, node.Node{Value: largeValue}) + + testCases := map[string]struct { + parent *node.Node + fullKey []byte // nibbles + encodedProofNodes [][]byte + errWrapped error + errMessage string + }{ + "nil parent and empty full key": {}, + "nil parent and non empty full key": { + fullKey: []byte{1}, + errWrapped: ErrKeyNotFound, + errMessage: "key not found", + }, + // The parent encode error cannot be triggered here + // since it can only be caused by a buffer.Write error. + "parent leaf and empty full key": { + parent: &node.Node{ + Key: []byte{1, 2}, + Value: largeValue, + }, + encodedProofNodes: [][]byte{encodeNode(t, node.Node{ + Key: []byte{1, 2}, + Value: largeValue, + })}, + }, + "parent leaf and shorter full key": { + parent: &node.Node{ + Key: []byte{1, 2}, + Value: []byte{1}, + }, + fullKey: []byte{1}, + errWrapped: ErrKeyNotFound, + errMessage: "key not found", + }, + "parent leaf and mismatching full key": { + parent: &node.Node{ + Key: []byte{1, 2}, + Value: []byte{1}, + }, + fullKey: []byte{1, 3}, + errWrapped: ErrKeyNotFound, + errMessage: "key not found", + }, + "parent leaf and longer full key": { + parent: &node.Node{ + Key: []byte{1, 2}, + Value: []byte{1}, + }, + fullKey: []byte{1, 2, 3}, + errWrapped: ErrKeyNotFound, + errMessage: "key not found", + }, + "branch and empty search key": { + parent: &node.Node{ + Key: []byte{1, 2}, + Value: largeValue, + Children: padRightChildren([]*node.Node{ + { + Key: []byte{4}, + Value: []byte{5}, + }, + }), + }, + encodedProofNodes: [][]byte{ + encodeNode(t, node.Node{ + Key: []byte{1, 2}, + Value: largeValue, + Children: padRightChildren([]*node.Node{ + { + Key: []byte{4}, + Value: []byte{5}, + }, + }), + }), + }, + }, + "branch and shorter full key": { + parent: &node.Node{ + Key: []byte{1, 2}, + Value: []byte{3}, + Children: padRightChildren([]*node.Node{ + { + Key: []byte{4}, + Value: []byte{5}, + }, + }), + }, + fullKey: []byte{1}, + errWrapped: ErrKeyNotFound, + errMessage: "key not found", + }, + "branch and mismatching full key": { + parent: &node.Node{ + Key: []byte{1, 2}, + Value: []byte{3}, + Children: padRightChildren([]*node.Node{ + { + Key: []byte{4}, + Value: []byte{5}, + }, + }), + }, + fullKey: []byte{1, 3}, + errWrapped: ErrKeyNotFound, + errMessage: "key not found", + }, + "branch and matching search key": { + parent: &node.Node{ + Key: []byte{1, 2}, + Value: []byte{3}, + Children: padRightChildren([]*node.Node{ + { + Key: []byte{4}, + Value: largeValue, + }, + }), + }, + fullKey: []byte{1, 2}, + encodedProofNodes: [][]byte{ + encodeNode(t, node.Node{ + Key: []byte{1, 2}, + Value: []byte{3}, + Children: padRightChildren([]*node.Node{ + { + Key: []byte{4}, + Value: largeValue, + }, + }), + }), + }, + }, + "branch and matching search key for small leaf encoding": { + parent: &node.Node{ + Key: []byte{1, 2}, + Value: largeValue, + Children: padRightChildren([]*node.Node{ + { // full key 1, 2, 0, 1, 2 + Key: []byte{1, 2}, + Value: []byte{3}, + }, + }), + }, + fullKey: []byte{1, 2, 0, 1, 2}, + encodedProofNodes: [][]byte{ + encodeNode(t, node.Node{ + Key: []byte{1, 2}, + Value: largeValue, + Children: padRightChildren([]*node.Node{ + { // full key 1, 2, 0, 1, 2 + Key: []byte{1, 2}, + Value: []byte{3}, + }, + }), + }), + // Note the leaf encoding is not added since its encoding + // is less than 32 bytes. + }, + }, + "branch and matching search key for large leaf encoding": { + parent: &node.Node{ + Key: []byte{1, 2}, + Value: []byte{3}, + Children: padRightChildren([]*node.Node{ + { // full key 1, 2, 0, 1, 2 + Key: []byte{1, 2}, + Value: largeValue, + }, + }), + }, + fullKey: []byte{1, 2, 0, 1, 2}, + encodedProofNodes: [][]byte{ + encodeNode(t, node.Node{ + Key: []byte{1, 2}, + Value: []byte{3}, + Children: padRightChildren([]*node.Node{ + { // full key 1, 2, 0, 1, 2 + Key: []byte{1, 2}, + Value: largeValue, + }, + }), + }), + encodeNode(t, node.Node{ + Key: []byte{1, 2}, + Value: largeValue, + }), + }, + }, + "key not found at deeper level": { + parent: &node.Node{ + Key: []byte{1, 2}, + Value: []byte{3}, + Children: padRightChildren([]*node.Node{ + { + Key: []byte{4, 5}, + Value: []byte{5}, + }, + }), + }, + fullKey: []byte{1, 2, 0x04, 4}, + errWrapped: ErrKeyNotFound, + errMessage: "key not found", + }, + "found leaf at deeper level": { + parent: &node.Node{ + Key: []byte{1, 2}, + Value: []byte{3}, + Children: padRightChildren([]*node.Node{ + { + Key: []byte{4}, + Value: largeValue, + }, + }), + }, + fullKey: []byte{1, 2, 0x04}, + encodedProofNodes: [][]byte{ + encodeNode(t, node.Node{ + Key: []byte{1, 2}, + Value: []byte{3}, + Children: padRightChildren([]*node.Node{ + { + Key: []byte{4}, + Value: largeValue, + }, + }), + }), + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + encodedProofNodes, err := walk(testCase.parent, testCase.fullKey) assert.ErrorIs(t, err, testCase.errWrapped) if testCase.errWrapped != nil { @@ -583,7 +827,7 @@ func Test_lenCommonPrefix(t *testing.T) { // In both cases, the performance difference is very small // so the code is kept to this inefficient-looking append, // which is in the end quite performant still. -func Benchmark_walk(b *testing.B) { +func Benchmark_walkRoot(b *testing.B) { trie := trie.NewEmptyTrie() // Build a deep trie. @@ -601,13 +845,12 @@ func Benchmark_walk(b *testing.B) { longestKeyNibbles := codec.KeyLEToNibbles(longestKeyLE) rootNode := trie.RootNode() - const isRoot = true - encodedProofNodes, err := walk(rootNode, longestKeyNibbles, isRoot) + encodedProofNodes, err := walkRoot(rootNode, longestKeyNibbles) require.NoError(b, err) require.Equal(b, len(encodedProofNodes), trieDepth) b.ResetTimer() for i := 0; i < b.N; i++ { - _, _ = walk(rootNode, longestKeyNibbles, isRoot) + _, _ = walkRoot(rootNode, longestKeyNibbles) } } diff --git a/lib/trie/proof/verify.go b/lib/trie/proof/verify.go index a75d96b78a..6c0a0ec833 100644 --- a/lib/trie/proof/verify.go +++ b/lib/trie/proof/verify.go @@ -111,7 +111,7 @@ func buildTrie(encodedProofNodes [][]byte, rootHash []byte) (t *trie.Trie, err e // loadProof is a recursive function that will create all the trie paths based // on the map from node hash digest to node encoding, starting from the node `n`. func loadProof(digestToEncoding map[string][]byte, n *node.Node) (err error) { - if n.Type() != node.Branch { + if n.Kind() != node.Branch { return nil } diff --git a/lib/trie/trie.go b/lib/trie/trie.go index bd26ceddfd..96bf698693 100644 --- a/lib/trie/trie.go +++ b/lib/trie/trie.go @@ -73,7 +73,7 @@ func (t *Trie) prepLeafForMutation(currentLeaf *Node, } else { newLeaf = updateGeneration(currentLeaf, t.generation, t.deletedKeys, copySettings) } - newLeaf.SetDirty(true) + newLeaf.SetDirty() return newLeaf } @@ -86,7 +86,7 @@ func (t *Trie) prepBranchForMutation(currentBranch *Node, } else { newBranch = updateGeneration(currentBranch, t.generation, t.deletedKeys, copySettings) } - newBranch.SetDirty(true) + newBranch.SetDirty() return newBranch } @@ -201,7 +201,7 @@ func entries(parent *Node, prefix []byte, kv map[string][]byte) map[string][]byt return kv } - if parent.Type() == node.Leaf { + if parent.Kind() == node.Leaf { parentKey := parent.Key fullKeyNibbles := concatenateSlices(prefix, parentKey) keyLE := string(codec.NibblesToKeyLE(fullKeyNibbles)) @@ -244,7 +244,7 @@ func findNextKey(parent *Node, prefix, searchKey []byte) (nextKey []byte) { return nil } - if parent.Type() == node.Leaf { + if parent.Kind() == node.Leaf { return findNextKeyLeaf(parent, prefix, searchKey) } return findNextKeyBranch(parent, prefix, searchKey) @@ -319,11 +319,7 @@ func findNextKeyChild(children []*Node, startIndex byte, // key specified in little Endian format. func (t *Trie) Put(keyLE, value []byte) { nibblesKey := codec.KeyLEToNibbles(keyLE) - t.put(nibblesKey, value) -} - -func (t *Trie) put(key, value []byte) { - t.root, _ = t.insert(t.root, key, value) + t.root, _ = t.insert(t.root, nibblesKey, value) } // insert inserts a value in the trie at the key specified. @@ -341,7 +337,7 @@ func (t *Trie) insert(parent *Node, key, value []byte) (newParent *Node, nodesCr // TODO ensure all values have dirty set to true - if parent.Type() == node.Branch { + if parent.Kind() == node.Branch { return t.insertInBranch(parent, key, value) } return t.insertInLeaf(parent, key, value) @@ -529,7 +525,7 @@ func getKeysWithPrefix(parent *Node, prefix, key []byte, return keysLE } - if parent.Type() == node.Leaf { + if parent.Kind() == node.Leaf { return getKeysWithPrefixFromLeaf(parent, prefix, key, keysLE) } @@ -575,7 +571,7 @@ func addAllKeys(parent *Node, prefix []byte, keysLE [][]byte) (newKeysLE [][]byt return keysLE } - if parent.Type() == node.Leaf { + if parent.Kind() == node.Leaf { keyLE := makeFullKeyLE(prefix, parent.Key) keysLE = append(keysLE, keyLE) return keysLE @@ -619,7 +615,7 @@ func retrieve(parent *Node, key []byte) (value []byte) { return nil } - if parent.Type() == node.Leaf { + if parent.Kind() == node.Leaf { return retrieveFromLeaf(parent, key) } return retrieveFromBranch(parent, key) @@ -660,20 +656,20 @@ func (t *Trie) ClearPrefixLimit(prefixLE []byte, limit uint32) (deleted uint32, prefix := codec.KeyLEToNibbles(prefixLE) prefix = bytes.TrimSuffix(prefix, []byte{0}) - t.root, deleted, _, allDeleted = t.clearPrefixLimit(t.root, prefix, limit) + t.root, deleted, _, allDeleted = t.clearPrefixLimitAtNode(t.root, prefix, limit) return deleted, allDeleted } -// clearPrefixLimit deletes the keys having the prefix until the value deletion limit is reached. +// clearPrefixLimitAtNode deletes the keys having the prefix until the value deletion limit is reached. // It returns the updated node newParent, the number of deleted values valuesDeleted and the // allDeleted boolean indicating if there is no key left with the prefix. -func (t *Trie) clearPrefixLimit(parent *Node, prefix []byte, limit uint32) ( +func (t *Trie) clearPrefixLimitAtNode(parent *Node, prefix []byte, limit uint32) ( newParent *Node, valuesDeleted, nodesRemoved uint32, allDeleted bool) { if parent == nil { return nil, 0, 0, true } - if parent.Type() == node.Leaf { + if parent.Kind() == node.Leaf { // if prefix is not found, it's also all deleted. // TODO check this is the same behaviour as in substrate const allDeleted = true @@ -716,7 +712,7 @@ func (t *Trie) clearPrefixLimitBranch(branch *Node, prefix []byte, limit uint32) childPrefix := prefix[len(branch.Key)+1:] child := branch.Children[childIndex] - child, valuesDeleted, nodesRemoved, allDeleted = t.clearPrefixLimit(child, childPrefix, limit) + child, valuesDeleted, nodesRemoved, allDeleted = t.clearPrefixLimitAtNode(child, childPrefix, limit) if valuesDeleted == 0 { return branch, valuesDeleted, nodesRemoved, allDeleted } @@ -780,7 +776,7 @@ func (t *Trie) deleteNodesLimit(parent *Node, prefix []byte, limit uint32) ( return nil, valuesDeleted, nodesRemoved } - if parent.Type() == node.Leaf { + if parent.Kind() == node.Leaf { valuesDeleted, nodesRemoved = 1, 1 return nil, valuesDeleted, nodesRemoved } @@ -809,7 +805,7 @@ func (t *Trie) deleteNodesLimit(parent *Node, prefix []byte, limit uint32) ( nodesRemoved += newNodesRemoved branch.Descendants -= newNodesRemoved - branch.SetDirty(true) + branch.SetDirty() newParent, branchChildMerged = handleDeletion(branch, fullKey) if branchChildMerged { @@ -845,10 +841,10 @@ func (t *Trie) ClearPrefix(prefixLE []byte) { prefix := codec.KeyLEToNibbles(prefixLE) prefix = bytes.TrimSuffix(prefix, []byte{0}) - t.root, _ = t.clearPrefix(t.root, prefix) + t.root, _ = t.clearPrefixAtNode(t.root, prefix) } -func (t *Trie) clearPrefix(parent *Node, prefix []byte) ( +func (t *Trie) clearPrefixAtNode(parent *Node, prefix []byte) ( newParent *Node, nodesRemoved uint32) { if parent == nil { const nodesRemoved = 0 @@ -860,7 +856,7 @@ func (t *Trie) clearPrefix(parent *Node, prefix []byte) ( return nil, nodesRemoved } - if parent.Type() == node.Leaf { + if parent.Kind() == node.Leaf { const nodesRemoved = 0 return parent, nodesRemoved } @@ -901,7 +897,7 @@ func (t *Trie) clearPrefix(parent *Node, prefix []byte) ( childPrefix := prefix[len(branch.Key)+1:] child := branch.Children[childIndex] - child, nodesRemoved = t.clearPrefix(child, childPrefix) + child, nodesRemoved = t.clearPrefixAtNode(child, childPrefix) if nodesRemoved == 0 { return parent, nodesRemoved } @@ -923,17 +919,17 @@ func (t *Trie) clearPrefix(parent *Node, prefix []byte) ( // If no node is found at this key, nothing is deleted. func (t *Trie) Delete(keyLE []byte) { key := codec.KeyLEToNibbles(keyLE) - t.root, _, _ = t.delete(t.root, key) + t.root, _, _ = t.deleteAtNode(t.root, key) } -func (t *Trie) delete(parent *Node, key []byte) ( +func (t *Trie) deleteAtNode(parent *Node, key []byte) ( newParent *Node, deleted bool, nodesRemoved uint32) { if parent == nil { const nodesRemoved = 0 return nil, false, nodesRemoved } - if parent.Type() == node.Leaf { + if parent.Kind() == node.Leaf { if deleteLeaf(parent, key) == nil { const nodesRemoved = 1 return nil, true, nodesRemoved @@ -978,7 +974,7 @@ func (t *Trie) deleteBranch(branch *Node, key []byte) ( childKey := key[commonPrefixLength+1:] child := branch.Children[childIndex] - newChild, deleted, nodesRemoved := t.delete(child, childKey) + newChild, deleted, nodesRemoved := t.deleteAtNode(child, childKey) if !deleted { const nodesRemoved = 0 return branch, false, nodesRemoved @@ -1034,7 +1030,7 @@ func handleDeletion(branch *Node, key []byte) (newNode *Node, branchChildMerged childIndex := firstChildIndex child := branch.Children[firstChildIndex] - if child.Type() == node.Leaf { + if child.Kind() == node.Leaf { newLeafKey := concatenateSlices(branch.Key, intToByteSlice(childIndex), child.Key) return &Node{ Key: newLeafKey, diff --git a/lib/trie/trie_endtoend_test.go b/lib/trie/trie_endtoend_test.go index 7abc7288bb..3bb02256da 100644 --- a/lib/trie/trie_endtoend_test.go +++ b/lib/trie/trie_endtoend_test.go @@ -497,7 +497,7 @@ func TestClearPrefix_Small(t *testing.T) { Value: []byte("other"), Generation: 1, } - expectedRoot.SetDirty(true) + expectedRoot.SetDirty() require.Equal(t, expectedRoot, ssTrie.root) diff --git a/lib/trie/trie_test.go b/lib/trie/trie_test.go index b27b68cdd3..012b2d4f49 100644 --- a/lib/trie/trie_test.go +++ b/lib/trie/trie_test.go @@ -157,7 +157,11 @@ func Test_Trie_updateGeneration(t *testing.T) { // Check for deep copy if newNode != nil && testCase.copied { - newNode.SetDirty(!newNode.Dirty) + if newNode.Dirty { + newNode.SetClean() + } else { + newNode.SetDirty() + } assert.NotEqual(t, testCase.node, newNode) } }) @@ -1055,106 +1059,6 @@ func Test_Trie_Put(t *testing.T) { } } -func Test_Trie_put(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - trie Trie - key []byte - value []byte - expectedTrie Trie - }{ - "nil everything": { - trie: Trie{ - generation: 1, - }, - expectedTrie: Trie{ - generation: 1, - root: &Node{ - Generation: 1, - Dirty: true, - }, - }, - }, - "empty trie with nil key and value": { - trie: Trie{ - generation: 1, - }, - value: []byte{3, 4}, - expectedTrie: Trie{ - generation: 1, - root: &Node{ - Value: []byte{3, 4}, - Generation: 1, - Dirty: true, - }, - }, - }, - "empty trie with key and value": { - trie: Trie{ - generation: 1, - }, - key: []byte{1, 2}, - value: []byte{3, 4}, - expectedTrie: Trie{ - generation: 1, - root: &Node{ - Key: []byte{1, 2}, - Value: []byte{3, 4}, - Generation: 1, - Dirty: true, - }, - }, - }, - "trie with key and value": { - trie: Trie{ - generation: 1, - root: &Node{ - Key: []byte{1, 0, 5}, - Value: []byte{1}, - }, - }, - key: []byte{1, 1, 6}, - value: []byte{2}, - expectedTrie: Trie{ - generation: 1, - root: &Node{ - Key: []byte{1}, - Generation: 1, - Dirty: true, - Descendants: 2, - Children: padRightChildren([]*Node{ - { - Key: []byte{5}, - Value: []byte{1}, - Generation: 1, - Dirty: true, - }, - { - Key: []byte{6}, - Value: []byte{2}, - Generation: 1, - Dirty: true, - }, - }), - }, - }, - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - - trie := testCase.trie - trie.put(testCase.key, testCase.value) - - assert.Equal(t, testCase.expectedTrie, trie) - }) - } -} - func Test_Trie_insert(t *testing.T) { t.Parallel() @@ -2220,7 +2124,7 @@ func Test_Trie_ClearPrefixLimit(t *testing.T) { } } -func Test_Trie_clearPrefixLimit(t *testing.T) { +func Test_Trie_clearPrefixLimitAtNode(t *testing.T) { t.Parallel() testCases := map[string]struct { @@ -2749,7 +2653,7 @@ func Test_Trie_clearPrefixLimit(t *testing.T) { expectedTrie := *trie.DeepCopy() newParent, valuesDeleted, nodesRemoved, allDeleted := - trie.clearPrefixLimit(testCase.parent, testCase.prefix, testCase.limit) + trie.clearPrefixLimitAtNode(testCase.parent, testCase.prefix, testCase.limit) assert.Equal(t, testCase.newParent, newParent) assert.Equal(t, testCase.valuesDeleted, valuesDeleted) @@ -3009,7 +2913,7 @@ func Test_Trie_ClearPrefix(t *testing.T) { } } -func Test_Trie_clearPrefix(t *testing.T) { +func Test_Trie_clearPrefixAtNode(t *testing.T) { t.Parallel() testCases := map[string]struct { @@ -3295,7 +3199,7 @@ func Test_Trie_clearPrefix(t *testing.T) { expectedTrie := *trie.DeepCopy() newParent, nodesRemoved := - trie.clearPrefix(testCase.parent, testCase.prefix) + trie.clearPrefixAtNode(testCase.parent, testCase.prefix) assert.Equal(t, testCase.newParent, newParent) assert.Equal(t, testCase.nodesRemoved, nodesRemoved) @@ -3395,7 +3299,7 @@ func Test_Trie_Delete(t *testing.T) { } } -func Test_Trie_delete(t *testing.T) { +func Test_Trie_deleteAtNode(t *testing.T) { t.Parallel() testCases := map[string]struct { @@ -3695,7 +3599,7 @@ func Test_Trie_delete(t *testing.T) { } expectedTrie := *testCase.trie.DeepCopy() - newParent, updated, nodesRemoved := testCase.trie.delete(testCase.parent, testCase.key) + newParent, updated, nodesRemoved := testCase.trie.deleteAtNode(testCase.parent, testCase.key) assert.Equal(t, testCase.newParent, newParent) assert.Equal(t, testCase.updated, updated) diff --git a/pkg/scale/uint128.go b/pkg/scale/uint128.go index ad782b2691..0b681ec907 100644 --- a/pkg/scale/uint128.go +++ b/pkg/scale/uint128.go @@ -118,7 +118,7 @@ func (u *Uint128) Compare(other *Uint128) int { return 0 } -func (u *Uint128) trimBytes(b []byte, order binary.ByteOrder) []byte { +func (*Uint128) trimBytes(b []byte, order binary.ByteOrder) []byte { switch order { case binary.LittleEndian: for { diff --git a/pkg/scale/varying_data_type_example_test.go b/pkg/scale/varying_data_type_example_test.go index a82fe81184..effa610c42 100644 --- a/pkg/scale/varying_data_type_example_test.go +++ b/pkg/scale/varying_data_type_example_test.go @@ -17,7 +17,7 @@ type MyStruct struct { Foo []byte } -func (ms MyStruct) Index() uint { +func (MyStruct) Index() uint { return 1 } @@ -27,13 +27,13 @@ type MyOtherStruct struct { Baz uint } -func (mos MyOtherStruct) Index() uint { +func (MyOtherStruct) Index() uint { return 2 } type MyInt16 int16 -func (mi16 MyInt16) Index() uint { +func (MyInt16) Index() uint { return 3 } diff --git a/pkg/scale/varying_data_type_nested_example_test.go b/pkg/scale/varying_data_type_nested_example_test.go index 54e1e67530..fa32a547cb 100644 --- a/pkg/scale/varying_data_type_nested_example_test.go +++ b/pkg/scale/varying_data_type_nested_example_test.go @@ -48,7 +48,7 @@ func NewParentVDT() ParentVDT { type ChildVDT scale.VaryingDataType // Index fulfils the VaryingDataTypeValue interface. T -func (cvdt ChildVDT) Index() uint { +func (ChildVDT) Index() uint { return 1 } @@ -87,12 +87,12 @@ func NewChildVDT() ChildVDT { type OtherChildVDT scale.VaryingDataType // Index fulfils the VaryingDataTypeValue interface. -func (ocvdt OtherChildVDT) Index() uint { +func (OtherChildVDT) Index() uint { return 2 } // Set will set a VaryingDataTypeValue using the underlying VaryingDataType -func (cvdt *OtherChildVDT) Set(val scale.VaryingDataTypeValue) (err error) { //nolint:revive +func (cvdt *OtherChildVDT) Set(val scale.VaryingDataTypeValue) (err error) { // cast to VaryingDataType to use VaryingDataType.Set method vdt := scale.VaryingDataType(*cvdt) err = vdt.Set(val) @@ -121,7 +121,7 @@ type ChildInt16 int16 // Index fulfils the VaryingDataTypeValue interface. The ChildVDT type is used as a // VaryingDataTypeValue for ParentVDT -func (ci ChildInt16) Index() uint { +func (ChildInt16) Index() uint { return 1 } @@ -132,7 +132,7 @@ type ChildStruct struct { } // Index fulfils the VaryingDataTypeValue interface -func (cs ChildStruct) Index() uint { +func (ChildStruct) Index() uint { return 2 } @@ -140,7 +140,7 @@ func (cs ChildStruct) Index() uint { type ChildString string // Index fulfils the VaryingDataTypeValue interface -func (cs ChildString) Index() uint { +func (ChildString) Index() uint { return 3 } diff --git a/pkg/scale/varying_data_type_nested_test.go b/pkg/scale/varying_data_type_nested_test.go index dec5dc2413..6056fc3ab5 100644 --- a/pkg/scale/varying_data_type_nested_test.go +++ b/pkg/scale/varying_data_type_nested_test.go @@ -33,7 +33,7 @@ func mustNewParentVDT() parentVDT { type childVDT VaryingDataType -func (cvdt childVDT) Index() uint { +func (childVDT) Index() uint { return 1 } @@ -59,7 +59,7 @@ func mustNewChildVDTAndSet(vdtv VaryingDataTypeValue) childVDT { type childVDT1 VaryingDataType -func (cvdt childVDT1) Index() uint { +func (childVDT1) Index() uint { return 2 } diff --git a/pkg/scale/varying_data_type_test.go b/pkg/scale/varying_data_type_test.go index 3c7b50d6ba..eb66b294cb 100644 --- a/pkg/scale/varying_data_type_test.go +++ b/pkg/scale/varying_data_type_test.go @@ -48,7 +48,7 @@ type VDTValue struct { N bool } -func (ctrd VDTValue) Index() uint { +func (VDTValue) Index() uint { return 1 } @@ -69,7 +69,7 @@ type VDTValue1 struct { AB *bool } -func (ctrd VDTValue1) Index() uint { +func (VDTValue1) Index() uint { return 2 } @@ -94,13 +94,13 @@ type VDTValue2 struct { P [2][2]byte } -func (ctrd VDTValue2) Index() uint { +func (VDTValue2) Index() uint { return 3 } type VDTValue3 int16 -func (ctrd VDTValue3) Index() uint { +func (VDTValue3) Index() uint { return 4 }