diff --git a/sum.go b/sum.go index 53c1805..c18deb1 100644 --- a/sum.go +++ b/sum.go @@ -2,6 +2,7 @@ package multihash import ( "errors" + "fmt" ) // ErrSumNotSupported is returned when the Sum function code is not implemented @@ -27,6 +28,7 @@ func Sum(data []byte, code uint64, length int) (Multihash, error) { sum := hasher.Sum(nil) // Deal with any truncation. + // Unless it's an identity multihash. Those have different rules. if length < 0 { length = hasher.Size() } @@ -34,6 +36,11 @@ func Sum(data []byte, code uint64, length int) (Multihash, error) { return nil, ErrLenTooLarge } if length >= 0 { + if code == IDENTITY { + if length != len(sum) { + return nil, fmt.Errorf("the length of the identity hash (%d) must be equal to the length of the data (%d)", length, len(sum)) + } + } sum = sum[:length] } diff --git a/sum_test.go b/sum_test.go index c2e9c1b..512231c 100644 --- a/sum_test.go +++ b/sum_test.go @@ -123,6 +123,33 @@ func BenchmarkBlake2B(b *testing.B) { } } +func TestSmallerLengthHashID(t *testing.T) { + + data := []byte("Identity hash input data.") + dataLength := len(data) + + // Normal case: `length == len(data)`. + _, err := multihash.Sum(data, multihash.ID, dataLength) + if err != nil { + t.Fatal(err) + } + + // Unconstrained length (-1): also allowed. + _, err = multihash.Sum(data, multihash.ID, -1) + if err != nil { + t.Fatal(err) + } + + // Any other variation of those two scenarios should fail. + for l := (dataLength - 1); l >= 0; l-- { + _, err = multihash.Sum(data, multihash.ID, l) + if err == nil { + t.Fatal(fmt.Sprintf("identity hash of length %d smaller than data length %d didn't fail", + l, dataLength)) + } + } +} + func TestTooLargeLength(t *testing.T) { _, err := multihash.Sum([]byte("test"), multihash.SHA2_256, 33) if err != multihash.ErrLenTooLarge {