diff --git a/p2p/ipld/nmt_wrapper.go b/p2p/ipld/nmt_wrapper.go index 62abbb5958..0d7b9ed18e 100644 --- a/p2p/ipld/nmt_wrapper.go +++ b/p2p/ipld/nmt_wrapper.go @@ -26,14 +26,15 @@ type ErasuredNamespacedMerkleTree struct { // NewErasuredNamespacedMerkleTree issues a new ErasuredNamespacedMerkleTree func NewErasuredNamespacedMerkleTree(squareSize uint64, setters ...nmt.Option) ErasuredNamespacedMerkleTree { - return ErasuredNamespacedMerkleTree{squareSize: squareSize, options: setters} + tree := nmt.New(sha256.New(), setters...) + return ErasuredNamespacedMerkleTree{squareSize: squareSize, options: setters, tree: tree} } // Constructor acts as the rsmt2d.TreeConstructorFn for // ErasuredNamespacedMerkleTree func (w ErasuredNamespacedMerkleTree) Constructor() rsmt2d.Tree { - w.tree = nmt.New(sha256.New(), w.options...) - return &w + newTree := NewErasuredNamespacedMerkleTree(w.squareSize, w.options...) + return &newTree } // Push adds the provided data to the underlying NamespaceMerkleTree, and diff --git a/p2p/ipld/read_test.go b/p2p/ipld/read_test.go index 014642bfd8..6a17bb2776 100644 --- a/p2p/ipld/read_test.go +++ b/p2p/ipld/read_test.go @@ -3,8 +3,9 @@ package ipld import ( "bytes" "context" - "crypto/rand" "crypto/sha256" + "math" + "math/rand" "sort" "strings" "testing" @@ -18,7 +19,9 @@ import ( "github.com/lazyledger/lazyledger-core/p2p/ipld/plugin/nodes" "github.com/lazyledger/lazyledger-core/types" "github.com/lazyledger/nmt" + "github.com/lazyledger/rsmt2d" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestLeafPath(t *testing.T) { @@ -153,6 +156,88 @@ func TestGetLeafData(t *testing.T) { } } +func TestBlockRecovery(t *testing.T) { + // adjustedLeafSize describes the size of a leaf that will not get split + adjustedLeafSize := types.MsgShareSize + + originalSquareWidth := 2 + sharecount := originalSquareWidth * originalSquareWidth + extendedSquareWidth := originalSquareWidth * originalSquareWidth + extendedShareCount := extendedSquareWidth * extendedSquareWidth + + // generate test data + quarterShares := generateRandNamespacedRawData(sharecount, types.NamespaceSize, adjustedLeafSize) + allShares := generateRandNamespacedRawData(sharecount, types.NamespaceSize, adjustedLeafSize) + + testCases := []struct { + name string + // blockData types.Data + shares [][]byte + expectErr bool + errString string + d int // number of shares to delete + }{ + // missing more shares causes RepairExtendedDataSquare to hang see + // https://github.com/lazyledger/rsmt2d/issues/21 + {"missing 1/4 shares", quarterShares, false, "", extendedShareCount / 4}, + {"missing all but one shares", allShares, true, "failed to solve data square", extendedShareCount - 1}, + } + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + squareSize := uint64(math.Sqrt(float64(len(tc.shares)))) + + // create trees for creating roots + tree := NewErasuredNamespacedMerkleTree(squareSize) + recoverTree := NewErasuredNamespacedMerkleTree(squareSize) + + eds, err := rsmt2d.ComputeExtendedDataSquare(tc.shares, rsmt2d.RSGF8, tree.Constructor) + if err != nil { + t.Error(err) + } + + // calculate roots using the first complete square + rowRoots := eds.RowRoots() + colRoots := eds.ColumnRoots() + + flat := flatten(eds) + + // recover a partially complete square + reds, err := rsmt2d.RepairExtendedDataSquare( + rowRoots, + colRoots, + removeRandShares(flat, tc.d), + rsmt2d.RSGF8, + recoverTree.Constructor, + ) + + if tc.expectErr { + require.Error(t, err) + require.Contains(t, err.Error(), tc.errString) + return + } + + require.NoError(t, err) + + // check that the squares are equal + assert.Equal(t, flatten(eds), flatten(reds)) + }) + } +} + +func flatten(eds *rsmt2d.ExtendedDataSquare) [][]byte { + out := make([][]byte, eds.Width()*eds.Width()) + count := 0 + for i := uint(0); i < eds.Width(); i++ { + for _, share := range eds.Row(i) { + out[count] = share + count++ + } + } + return out +} + // nmtcommitment generates the nmt root of some namespaced data func createNmtTree( ctx context.Context, @@ -199,3 +284,18 @@ func generateRandNamespacedRawData(total int, nidSize int, leafSize int) [][]byt func sortByteArrays(src [][]byte) { sort.Slice(src, func(i, j int) bool { return bytes.Compare(src[i], src[j]) < 0 }) } + +// removes d shares from data +func removeRandShares(data [][]byte, d int) [][]byte { + count := len(data) + // remove shares randomly + for i := 0; i < d; { + ind := rand.Intn(count) + if len(data[ind]) == 0 { + continue + } + data[ind] = nil + i++ + } + return data +} diff --git a/types/block_test.go b/types/block_test.go index 23b69a7054..a1250af095 100644 --- a/types/block_test.go +++ b/types/block_test.go @@ -5,9 +5,9 @@ import ( // number generator here and we can run the tests a bit faster stdbytes "bytes" "context" - "crypto/rand" "encoding/hex" "math" + "math/rand" "os" "reflect" "sort" @@ -195,8 +195,8 @@ func makeBlockIDRandom() BlockID { blockHash = make([]byte, tmhash.Size) partSetHash = make([]byte, tmhash.Size) ) - rand.Read(blockHash) //nolint: errcheck // ignore errcheck for read - rand.Read(partSetHash) //nolint: errcheck // ignore errcheck for read + rand.Read(blockHash) + rand.Read(partSetHash) return BlockID{blockHash, PartSetHeader{123, partSetHash}} } @@ -1335,10 +1335,10 @@ func TestPutBlock(t *testing.T) { expectErr bool errString string }{ - {"no leaves", generateRandomData(0), false, ""}, - {"single leaf", generateRandomData(1), false, ""}, - {"16 leaves", generateRandomData(16), false, ""}, - {"max square size", generateRandomData(MaxSquareSize), false, ""}, + {"no leaves", generateRandomMsgOnlyData(0), false, ""}, + {"single leaf", generateRandomMsgOnlyData(1), false, ""}, + {"16 leaves", generateRandomMsgOnlyData(16), false, ""}, + {"max square size", generateRandomMsgOnlyData(MaxSquareSize), false, ""}, } ctx := context.Background() for _, tc := range testCases { @@ -1360,7 +1360,6 @@ func TestPutBlock(t *testing.T) { defer cancel() block.fillDataAvailabilityHeader() - tc.blockData.ComputeShares() for _, rowRoot := range block.DataAvailabilityHeader.RowsRoots.Bytes() { // recreate the cids using only the computed roots cid, err := nodes.CidFromNamespacedSha256(rowRoot) @@ -1387,10 +1386,10 @@ func TestPutBlock(t *testing.T) { } } -func generateRandomData(msgCount int) Data { +func generateRandomMsgOnlyData(msgCount int) Data { out := make([]Message, msgCount) - for i, msg := range generateRandNamespacedRawData(msgCount, NamespaceSize, ShareSize) { - out[i] = Message{NamespaceID: msg[:NamespaceSize], Data: msg[:NamespaceSize]} + for i, msg := range generateRandNamespacedRawData(msgCount, NamespaceSize, MsgShareSize-2) { + out[i] = Message{NamespaceID: msg[:NamespaceSize], Data: msg[NamespaceSize:]} } return Data{ Messages: Messages{MessagesList: out}, diff --git a/types/shares.go b/types/shares.go index 2611921edb..649f74b774 100644 --- a/types/shares.go +++ b/types/shares.go @@ -52,7 +52,6 @@ func (m Message) MarshalDelimited() ([]byte, error) { lenBuf := make([]byte, binary.MaxVarintLen64) length := uint64(len(m.Data)) n := binary.PutUvarint(lenBuf, length) - return append(lenBuf[:n], m.Data...), nil } @@ -60,7 +59,11 @@ func (m Message) MarshalDelimited() ([]byte, error) { // Used for messages. func appendToShares(shares []NamespacedShare, nid namespace.ID, rawData []byte) []NamespacedShare { if len(rawData) <= MsgShareSize { - rawShare := []byte(append(nid, rawData...)) + rawShare := append(append( + make([]byte, 0, len(nid)+len(rawData)), + nid...), + rawData..., + ) paddedShare := zeroPadIfNecessary(rawShare, ShareSize) share := NamespacedShare{paddedShare, nid} shares = append(shares, share) @@ -82,7 +85,11 @@ func splitContiguous(nid namespace.ID, rawDatas [][]byte) []NamespacedShare { var rawData []byte startIndex := 0 rawData, outerIndex, innerIndex, startIndex = getNextChunk(rawDatas, outerIndex, innerIndex, TxShareSize) - rawShare := []byte(append(append(nid, byte(startIndex)), rawData...)) + rawShare := append(append(append( + make([]byte, 0, len(nid)+1+len(rawData)), + nid...), + byte(startIndex)), + rawData...) paddedShare := zeroPadIfNecessary(rawShare, ShareSize) share := NamespacedShare{paddedShare, nid} shares = append(shares, share) @@ -94,14 +101,20 @@ func splitContiguous(nid namespace.ID, rawDatas [][]byte) []NamespacedShare { // shares for a particular namespace func split(rawData []byte, nid namespace.ID) []NamespacedShare { shares := make([]NamespacedShare, 0) - firstRawShare := []byte(append(nid, rawData[:MsgShareSize]...)) + firstRawShare := append(append( + make([]byte, 0, len(nid)+len(rawData[:MsgShareSize])), + nid...), + rawData[:MsgShareSize]..., + ) shares = append(shares, NamespacedShare{firstRawShare, nid}) rawData = rawData[MsgShareSize:] for len(rawData) > 0 { shareSizeOrLen := min(MsgShareSize, len(rawData)) - rawShare := make([]byte, NamespaceSize) - copy(rawShare, nid) - rawShare = append(rawShare, rawData[:shareSizeOrLen]...) + rawShare := append(append( + make([]byte, 0, len(nid)+1+len(rawData[:shareSizeOrLen])), + nid...), + rawData[:shareSizeOrLen]..., + ) paddedShare := zeroPadIfNecessary(rawShare, ShareSize) share := NamespacedShare{paddedShare, nid} shares = append(shares, share)