diff --git a/nmt_test.go b/nmt_test.go index 10d3d2b..8c157f2 100644 --- a/nmt_test.go +++ b/nmt_test.go @@ -1300,7 +1300,7 @@ func TestComputeSubtreeRoot(t *testing.T) { }, { start: 0, - end: 16, + end: 2, tree: func() *NamespacedMerkleTree { return exampleNMT2(1, true, 0, 1, 2, 3, 4) // tree leaves are not a power of 2 }(), diff --git a/proof.go b/proof.go index f8fae0e..bde6723 100644 --- a/proof.go +++ b/proof.go @@ -512,27 +512,6 @@ func (proof Proof) VerifySubtreeRootInclusion(nth *NmtHasher, subtreeRoots [][]b var computeRoot func(start, end int) ([]byte, error) // computeRoot can return error iff the HashNode function fails while calculating the root computeRoot = func(start, end int) ([]byte, error) { - // reached a leaf - if end-start == 1 { - // if the leaf index falls within the proof range, pop and return a - // leaf - if proof.Start() <= start && start < proof.End() { - // this check should always be the case. - // however, it is added to avoid nil pointer exceptions - if len(ranges) != 0 { - // advance the list of ranges - ranges = ranges[1:] - } - // advance leafHashes - return popIfNonEmpty(&subtreeRoots), nil - } - - // if the leaf index is outside the proof range, pop and return a - // proof node (which in this case is a leaf) if present, else return - // nil because leaf doesn't exist - return popIfNonEmpty(&proof.nodes), nil - } - // if the current range does not overlap with the proof range, pop and // return a proof node if present, else return nil because subtree // doesn't exist @@ -540,7 +519,11 @@ func (proof Proof) VerifySubtreeRootInclusion(nth *NmtHasher, subtreeRoots [][]b return popIfNonEmpty(&proof.nodes), nil } - if len(ranges) != 0 && ranges[0].Start == start && ranges[0].End == end { + if len(ranges) == 0 { + return nil, fmt.Errorf(fmt.Sprintf("expected to have a subtree root for range [%d, %d)", start, end)) + } + + if ranges[0].Start == start && ranges[0].End == end { ranges = ranges[1:] return popIfNonEmpty(&subtreeRoots), nil } diff --git a/proof_test.go b/proof_test.go index f0d22c7..b2ea98f 100644 --- a/proof_test.go +++ b/proof_test.go @@ -1798,6 +1798,22 @@ func TestVerifySubtreeRootInclusion(t *testing.T) { root: root, expectError: true, }, + + { + proof: func() Proof { + p, err := tree.ProveRange(0, 8) + require.NoError(t, err) + return p + }(), + subtreeRoots: func() [][]byte { + subtreeRoot1, err := tree.ComputeSubtreeRoot(0, 4) + require.NoError(t, err) + return [][]byte{subtreeRoot1} // will error because it requires the subtree root of [4,8) too + }(), + subtreeWidth: 4, + root: root, + expectError: true, + }, } for _, tt := range tests {