Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement lazy root computation #18

Merged
merged 4 commits into from
Mar 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 32 additions & 11 deletions datasquare.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,17 +141,8 @@ func (ds *dataSquare) computeRoots() {
rowRoots := make([][]byte, ds.width)
columnRoots := make([][]byte, ds.width)
for i := uint(0); i < ds.width; i++ {
rowTree := ds.createTreeFn()
columnTree := ds.createTreeFn()
rowData := ds.Row(i)
columnData := ds.Column(i)
for j := uint(0); j < ds.width; j++ {
rowTree.Push(rowData[j])
columnTree.Push(columnData[j])
}

rowRoots[i] = rowTree.Root()
columnRoots[i] = columnTree.Root()
rowRoots[i] = ds.RowRoot(i)
columnRoots[i] = ds.ColRoot(i)
}

ds.rowRoots = rowRoots
Expand All @@ -167,6 +158,21 @@ func (ds *dataSquare) RowRoots() [][]byte {
return ds.rowRoots
}

// RowRoot calculates and returns the root of the selected row. Note: unlike the
// RowRoots method, RowRoot uses the built-in cache when available.
func (ds *dataSquare) RowRoot(x uint) []byte {
if ds.rowRoots != nil {
return ds.rowRoots[x]
}

tree := ds.createTreeFn()
for _, d := range ds.Row(x) {
tree.Push(d)
}

return tree.Root()
}

// ColumnRoots returns the Merkle roots of all the columns in the square.
func (ds *dataSquare) ColumnRoots() [][]byte {
if ds.columnRoots == nil {
Expand All @@ -176,6 +182,21 @@ func (ds *dataSquare) ColumnRoots() [][]byte {
return ds.columnRoots
}

// ColRoot calculates and returns the root of the selected row. Note: unlike the
// ColRoots method, ColRoot does not use the built in cache
func (ds *dataSquare) ColRoot(y uint) []byte {
if ds.columnRoots != nil {
return ds.columnRoots[y]
}

tree := ds.createTreeFn()
for _, d := range ds.Column(y) {
tree.Push(d)
}

return tree.Root()
}

func (ds *dataSquare) computeRowProof(x uint, y uint) ([]byte, [][]byte, uint, uint, error) {
tree := ds.createTreeFn()
data := ds.Row(x)
Expand Down
45 changes: 45 additions & 0 deletions datasquare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,51 @@ func TestRoots(t *testing.T) {
}
}

func TestLazyRootGeneration(t *testing.T) {
square, err := newDataSquare([][]byte{{1}, {2}, {3}, {4}}, NewDefaultTree)
if err != nil {
panic(err)
}

var rowRoots [][]byte
var colRoots [][]byte

for i := uint(0); i < square.width; i++ {
rowRoots = append(rowRoots, square.RowRoot(i))
colRoots = append(rowRoots, square.ColRoot(i))
}

square.computeRoots()
liamsi marked this conversation as resolved.
Show resolved Hide resolved

if !reflect.DeepEqual(square.rowRoots, rowRoots) && !reflect.DeepEqual(square.columnRoots, colRoots) {
t.Error("RowRoot or ColumnRoot did not produce identical roots to computeRoots")
}
}

func TestRootAPI(t *testing.T) {
square, err := newDataSquare([][]byte{{1}, {2}, {3}, {4}}, NewDefaultTree)
if err != nil {
panic(err)
}

for i := uint(0); i < square.width; i++ {
if !reflect.DeepEqual(square.RowRoots()[i], square.RowRoot(i)) {
t.Errorf(
"Row root API results in different roots, expected %v go %v",
square.RowRoots()[i],
square.RowRoot(i),
)
}
if !reflect.DeepEqual(square.ColumnRoots()[i], square.ColRoot(i)) {
t.Errorf(
"Column root API results in different roots, expected %v go %v",
square.ColumnRoots()[i],
square.ColRoot(i),
)
}
}
}

func TestProofs(t *testing.T) {
result, err := newDataSquare([][]byte{{1, 2}, {3, 4}, {5, 6}, {7, 8}}, NewDefaultTree)
if err != nil {
Expand Down
41 changes: 32 additions & 9 deletions extendeddatacrossword.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package rsmt2d

import (
"bytes"
"errors"
"fmt"

"gonum.org/v1/gonum/mat"
Expand Down Expand Up @@ -170,11 +169,11 @@ func (eds *ExtendedDataSquare) solveCrossword(rowRoots [][]byte, columnRoots [][

// Check that rebuilt vector matches given merkle root
if mode == row {
if !bytes.Equal(eds.RowRoots()[i], rowRoots[i]) {
if !bytes.Equal(eds.RowRoot(i), rowRoots[i]) {
return &ByzantineRowError{i, edsBackup}
}
} else if mode == column {
if !bytes.Equal(eds.ColumnRoots()[i], columnRoots[i]) {
if !bytes.Equal(eds.ColRoot(i), columnRoots[i]) {
return &ByzantineColumnError{i, edsBackup}
}
}
Expand All @@ -184,12 +183,12 @@ func (eds *ExtendedDataSquare) solveCrossword(rowRoots [][]byte, columnRoots [][
if vectorMask.AtVec(int(j)) == 0 {
if mode == row {
adjMask := mask.ColView(int(j))
if vecNumTrue(adjMask) == adjMask.Len()-1 && !bytes.Equal(eds.ColumnRoots()[j], columnRoots[j]) {
if vecNumTrue(adjMask) == adjMask.Len()-1 && !bytes.Equal(eds.ColRoot(j), columnRoots[j]) {
return &ByzantineColumnError{j, edsBackup}
}
} else if mode == column {
adjMask := mask.RowView(int(j))
if vecNumTrue(adjMask) == adjMask.Len()-1 && !bytes.Equal(eds.RowRoots()[j], rowRoots[j]) {
if vecNumTrue(adjMask) == adjMask.Len()-1 && !bytes.Equal(eds.RowRoot(j), rowRoots[j]) {
Copy link
Member

@liamsi liamsi Mar 23, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Orthogonal to this PR but we should really refactor this. The massive indentation / nesting alone cries for splitting this up into multiple methods...

return &ByzantineRowError{j, edsBackup}
}
}
Expand Down Expand Up @@ -229,11 +228,26 @@ func (eds *ExtendedDataSquare) prerepairSanityCheck(rowRoots [][]byte, columnRoo
for i := uint(0); i < eds.width; i++ {
rowMask := mask.RowView(int(i))
columnMask := mask.ColView(int(i))
if (vecIsTrue(rowMask) && !bytes.Equal(rowRoots[i], eds.RowRoots()[i])) || (vecIsTrue(columnMask) && !bytes.Equal(columnRoots[i], eds.ColumnRoots()[i])) {
return errors.New("bad roots input")
rowMaskIsVec := vecIsTrue(rowMask)
columnMaskIsVec := vecIsTrue(columnMask)

// if there's no missing data in the this row
if noMissingData(eds.Row(i)) {
// ensure that the roots are equal and that rowMask is a vector
if rowMaskIsVec && !bytes.Equal(rowRoots[i], eds.RowRoot(i)) {
return fmt.Errorf("bad root input: row %d expected %v got %v", i, rowRoots[i], eds.RowRoot(i))
}
}

if vecIsTrue(rowMask) {
// if there's no missing data in the this col
if noMissingData(eds.Column(i)) {
// ensure that the roots are equal and that rowMask is a vector
if columnMaskIsVec && !bytes.Equal(columnRoots[i], eds.ColRoot(i)) {
return fmt.Errorf("bad root input: col %d expected %v got %v", i, columnRoots[i], eds.ColRoot(i))
}
}

if rowMaskIsVec {
shares, err = Encode(eds.rowSlice(i, 0, eds.originalDataWidth), eds.codec)
if err != nil {
return err
Expand All @@ -243,7 +257,7 @@ func (eds *ExtendedDataSquare) prerepairSanityCheck(rowRoots [][]byte, columnRoo
}
}

if vecIsTrue(columnMask) {
if columnMaskIsVec {
shares, err = Encode(eds.columnSlice(0, i, eds.originalDataWidth), eds.codec)
if err != nil {
return err
Expand All @@ -257,6 +271,15 @@ func (eds *ExtendedDataSquare) prerepairSanityCheck(rowRoots [][]byte, columnRoo
return nil
}

func noMissingData(input [][]byte) bool {
for _, d := range input {
if d == nil {
return false
}
}
return true
}

func vecIsTrue(vec mat.Vector) bool {
for i := 0; i < vec.Len(); i++ {
if vec.AtVec(i) == 0 {
Expand Down