Skip to content

Commit

Permalink
SCALE Updates (#371)
Browse files Browse the repository at this point in the history
- Fixes string decoding issues
- Adds encode->decode test and fixes uncovered issues
  • Loading branch information
ansermino committed Nov 5, 2019
1 parent d906a02 commit 8fefce8
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 23 deletions.
80 changes: 63 additions & 17 deletions codec/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"math/big"
"reflect"
Expand Down Expand Up @@ -151,7 +152,20 @@ func (sd *Decoder) DecodeFixedWidthInt(t interface{}) (o interface{}, err error)
if err == nil {
o = binary.LittleEndian.Uint64(buf)
}
case int:
buf := make([]byte, 8)
_, err = sd.Reader.Read(buf)
if err == nil {
o = int(binary.LittleEndian.Uint64(buf))
}
case uint:
buf := make([]byte, 8)
_, err = sd.Reader.Read(buf)
if err == nil {
o = uint(binary.LittleEndian.Uint64(buf))
}
}

return o, err
}

Expand Down Expand Up @@ -283,8 +297,10 @@ func (sd *Decoder) DecodeInterface(t interface{}) (interface{}, error) {
}
case reflect.Slice, reflect.Array:
return sd.DecodeArray(t)
default:
case reflect.Struct:
return sd.DecodeTuple(t)
default:
return nil, fmt.Errorf("unexpected kind: %s", reflect.ValueOf(t).Kind())
}
}

Expand Down Expand Up @@ -352,24 +368,29 @@ func (sd *Decoder) DecodeArray(t interface{}) (interface{}, error) {
// Note that we return the same interface that was passed to this function; this is because we are writing directly to the
// struct that is passed in, using reflect to get each of the fields.
func (sd *Decoder) DecodeTuple(t interface{}) (interface{}, error) {
v := reflect.ValueOf(t).Elem()
var v reflect.Value
switch reflect.ValueOf(t).Kind() {
case reflect.Ptr:
v = reflect.ValueOf(t).Elem()
default:
v = reflect.ValueOf(t)
}

var err error
var o interface{}

val := reflect.Indirect(reflect.ValueOf(t))

// iterate through each field in the struct
for i := 0; i < v.NumField(); i++ {
// get the field value at i
fieldValue := val.Field(i)
field := v.Field(i)
fieldValue := field.Addr().Interface()

switch v.Field(i).Interface().(type) {
case byte:
b := make([]byte, 1)
_, err = sd.Reader.Read(b)

ptr := fieldValue.Addr().Interface().(*byte)
ptr := fieldValue.(*byte)
*ptr = b[0]
case []byte:
o, err = sd.DecodeByteArray()
Expand All @@ -378,86 +399,111 @@ func (sd *Decoder) DecodeTuple(t interface{}) (interface{}, error) {
}

// get the pointer to the value and set the value
ptr := fieldValue.Addr().Interface().(*[]byte)
ptr := fieldValue.(*[]byte)
*ptr = o.([]byte)
case int8:
o, err = sd.DecodeFixedWidthInt(int8(0))
if err != nil {
break
}

ptr := fieldValue.Addr().Interface().(*int8)
ptr := fieldValue.(*int8)
*ptr = o.(int8)
case int16:
o, err = sd.DecodeFixedWidthInt(int16(0))
if err != nil {
break
}

ptr := fieldValue.Addr().Interface().(*int16)
ptr := fieldValue.(*int16)
*ptr = o.(int16)
case int32:
o, err = sd.DecodeFixedWidthInt(int32(0))
if err != nil {
break
}

ptr := fieldValue.Addr().Interface().(*int32)
ptr := fieldValue.(*int32)
*ptr = o.(int32)
case int64:
o, err = sd.DecodeFixedWidthInt(int64(0))
if err != nil {
break
}

ptr := fieldValue.Addr().Interface().(*int64)
ptr := fieldValue.(*int64)
*ptr = o.(int64)
case uint16:
o, err = sd.DecodeFixedWidthInt(uint16(0))
if err != nil {
break
}

ptr := fieldValue.Addr().Interface().(*uint16)
ptr := fieldValue.(*uint16)
*ptr = o.(uint16)
case uint32:
o, err = sd.DecodeFixedWidthInt(uint32(0))
if err != nil {
break
}

ptr := fieldValue.Addr().Interface().(*uint32)
ptr := fieldValue.(*uint32)
*ptr = o.(uint32)
case uint64:
o, err = sd.DecodeFixedWidthInt(uint64(0))
if err != nil {
break
}

ptr := fieldValue.Addr().Interface().(*uint64)
ptr := fieldValue.(*uint64)
*ptr = o.(uint64)
case int:
o, err = sd.DecodeFixedWidthInt(int(0))
if err != nil {
break
}

ptr := fieldValue.(*int)
*ptr = o.(int)
case uint:
o, err = sd.DecodeFixedWidthInt(uint(0))
if err != nil {
break
}

ptr := fieldValue.(*uint)
*ptr = o.(uint)
case bool:
o, err = sd.DecodeBool()
if err != nil {
break
}

ptr := fieldValue.Addr().Interface().(*bool)
ptr := fieldValue.(*bool)
*ptr = o.(bool)
case *big.Int:
o, err = sd.DecodeBigInt()
if err != nil {
break
}

ptr := fieldValue.Addr().Interface().(**big.Int)
ptr := fieldValue.(**big.Int)
*ptr = o.(*big.Int)
case common.Hash:
b := make([]byte, 32)
_, err = sd.Reader.Read(b)

ptr := fieldValue.Addr().Interface().(*common.Hash)
ptr := fieldValue.(*common.Hash)
*ptr = common.NewHash(b)
case string:
o, err = sd.DecodeByteArray()
if err != nil {
break
}

// get the pointer to the value and set the value
ptr := fieldValue.(*string)
*ptr = string(o.([]byte))
default:
_, err = sd.Decode(v.Field(i).Interface())
if err != nil {
Expand Down
67 changes: 67 additions & 0 deletions codec/enc_dec_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package codec

import (
"reflect"
"testing"
)

func TestEncodeDecodeComplexStruct(t *testing.T) {
type SimpleStruct struct {
A int64
B bool
}

type ComplexStruct struct {
B bool
I int
I8 int8
I16 int16
I32 int32
I64 int64
U uint
U8 uint8
U16 uint16
U32 uint32
U64 uint64
Str string
Bz []byte
Sub *SimpleStruct
}

test := &ComplexStruct{
B: true,
I: 1,
I8: 2,
I16: 3,
I32: 4,
I64: 5,
U: 6,
U8: 7,
U16: 8,
U32: 9,
U64: 10,
Str: "choansafe",
Bz: []byte{0xDE, 0xAD, 0xBE, 0xEF},
Sub: &SimpleStruct{
A: 99,
B: true,
},
}

enc, err := Encode(test)
if err != nil {
t.Fatal(err)
}

res := &ComplexStruct{
Sub: &SimpleStruct{},
}
output, err := Decode(enc, res)
if err != nil {
t.Fatal(err)
}

if !reflect.DeepEqual(output.(*ComplexStruct), test) {
t.Errorf("Fail: got %+v expected %+v", output.(*ComplexStruct), test)
}
}
18 changes: 12 additions & 6 deletions codec/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package codec
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"math/big"
"reflect"
Expand Down Expand Up @@ -50,7 +50,7 @@ func (se *Encoder) Encode(b interface{}) (n int, err error) {
n, err = se.encodeByteArray(v[:])
case *big.Int:
n, err = se.encodeBigInteger(v)
case int8, uint8, int16, uint16, int32, uint32, int64, uint64:
case int, uint, int8, uint8, int16, uint16, int32, uint32, int64, uint64:
n, err = se.encodeFixedWidthInteger(v)
case string:
n, err = se.encodeByteArray([]byte(v))
Expand All @@ -68,10 +68,10 @@ func (se *Encoder) Encode(b interface{}) (n int, err error) {
case reflect.Slice, reflect.Array:
n, err = se.encodeArray(v)
default:
return 0, errors.New("unsupported type")
return 0, fmt.Errorf("unsupported type: %T", b)
}
default:
return 0, errors.New("unsupported type")
return 0, fmt.Errorf("unsupported type: %T", b)
}

return n, err
Expand Down Expand Up @@ -120,8 +120,14 @@ func (se *Encoder) encodeFixedWidthInteger(in interface{}) (bytesEncoded int, er
case uint64:
err = binary.Write(se.Writer, binary.LittleEndian, uint64(i))
bytesEncoded = 8
case int:
err = binary.Write(se.Writer, binary.LittleEndian, int64(i))
bytesEncoded = 8
case uint:
err = binary.Write(se.Writer, binary.LittleEndian, uint64(i))
bytesEncoded = 8
default:
err = errors.New("could not encode fixed width int: invalid type")
err = fmt.Errorf("could not encode fixed width int, invalid type: %T", in)
}

return bytesEncoded, err
Expand Down Expand Up @@ -223,7 +229,7 @@ func (se *Encoder) encodeTuple(t interface{}) (bytesEncoded int, err error) {
switch reflect.ValueOf(t).Kind() {
case reflect.Ptr:
v = reflect.ValueOf(t).Elem()
case reflect.Slice, reflect.Array:
case reflect.Slice, reflect.Array, reflect.Struct:
v = reflect.ValueOf(t)
}

Expand Down
23 changes: 23 additions & 0 deletions codec/encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package codec
import (
"bytes"
"math/big"
"reflect"
"strings"
"testing"
)
Expand Down Expand Up @@ -167,3 +168,25 @@ func TestEncode(t *testing.T) {
}
}
}

func TestEncodeAndDecodeStringInStruct(t *testing.T) {
test := &struct {
A string
}{
A: "noot",
}

enc, err := Encode(test)
if err != nil {
t.Fatal(err)
}

dec, err := Decode(enc, &struct{ A string }{A: ""})
if err != nil {
t.Fatal(err)
}

if !reflect.DeepEqual(test, dec) {
t.Fatalf("Fail: got %v expected %v", dec, test)
}
}

0 comments on commit 8fefce8

Please sign in to comment.