diff --git a/pkg/scale/encode.go b/pkg/scale/encode.go index 9074d03aa2..614581fedb 100644 --- a/pkg/scale/encode.go +++ b/pkg/scale/encode.go @@ -7,25 +7,48 @@ import ( "bytes" "encoding/binary" "fmt" + "io" "math/big" "reflect" ) +// Encoder scale encodes to a given io.Writer. +type Encoder struct { + encodeState +} + +// NewEncoder creates a new encoder with the given writer. +func NewEncoder(writer io.Writer) (encoder *Encoder) { + return &Encoder{ + encodeState: encodeState{ + Writer: writer, + fieldScaleIndicesCache: cache, + }, + } +} + +// Encode scale encodes value to the encoder writer. +func (e *Encoder) Encode(value interface{}) (err error) { + return e.marshal(value) +} + // Marshal takes in an interface{} and attempts to marshal into []byte func Marshal(v interface{}) (b []byte, err error) { + buffer := bytes.NewBuffer(nil) es := encodeState{ + Writer: buffer, fieldScaleIndicesCache: cache, } err = es.marshal(v) if err != nil { return } - b = es.Bytes() + b = buffer.Bytes() return } type encodeState struct { - bytes.Buffer + io.Writer *fieldScaleIndicesCache } @@ -64,9 +87,9 @@ func (es *encodeState) marshal(in interface{}) (err error) { elem := reflect.ValueOf(in).Elem() switch elem.IsValid() { case false: - err = es.WriteByte(0) + _, err = es.Write([]byte{0}) default: - err = es.WriteByte(1) + _, err = es.Write([]byte{1}) if err != nil { return } @@ -133,13 +156,13 @@ func (es *encodeState) encodeResult(res Result) (err error) { var in interface{} switch res.mode { case OK: - err = es.WriteByte(0) + _, err = es.Write([]byte{0}) if err != nil { return } in = res.ok case Err: - err = es.WriteByte(1) + _, err = es.Write([]byte{1}) if err != nil { return } @@ -159,7 +182,7 @@ func (es *encodeState) encodeCustomVaryingDataType(in interface{}) (err error) { } func (es *encodeState) encodeVaryingDataType(vdt VaryingDataType) (err error) { - err = es.WriteByte(byte(vdt.value.Index())) + _, err = es.Write([]byte{byte(vdt.value.Index())}) if err != nil { return } diff --git a/pkg/scale/encode_test.go b/pkg/scale/encode_test.go index 637bde1c6c..4d56b40e77 100644 --- a/pkg/scale/encode_test.go +++ b/pkg/scale/encode_test.go @@ -4,12 +4,74 @@ package scale import ( + "bytes" "math/big" "reflect" "strings" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +func Test_NewEncoder(t *testing.T) { + t.Parallel() + + cache.Lock() + defer cache.Unlock() + + writer := bytes.NewBuffer(nil) + encoder := NewEncoder(writer) + + expectedEncoder := &Encoder{ + encodeState: encodeState{ + Writer: writer, + fieldScaleIndicesCache: cache, + }, + } + + assert.Equal(t, expectedEncoder, encoder) +} + +func Test_Encoder_Encode(t *testing.T) { + t.Parallel() + + buffer := bytes.NewBuffer(nil) + encoder := NewEncoder(buffer) + + err := encoder.Encode(uint16(1)) + require.NoError(t, err) + + err = encoder.Encode(uint8(2)) + require.NoError(t, err) + + array := [2]byte{4, 5} + err = encoder.Encode(array) + require.NoError(t, err) + + type T struct { + Array [2]byte + } + + someStruct := T{Array: [2]byte{6, 7}} + err = encoder.Encode(someStruct) + require.NoError(t, err) + + structSlice := []T{{Array: [2]byte{8, 9}}} + err = encoder.Encode(structSlice) + require.NoError(t, err) + + written := buffer.Bytes() + expectedWritten := []byte{ + 1, 0, + 2, + 4, 5, + 6, 7, + 4, 8, 9, + } + assert.Equal(t, expectedWritten, written) +} + type test struct { name string in interface{} @@ -869,12 +931,15 @@ type MyStructWithPrivate struct { func Test_encodeState_encodeFixedWidthInteger(t *testing.T) { for _, tt := range fixedWidthIntegerTests { t.Run(tt.name, func(t *testing.T) { - es := &encodeState{} + buffer := bytes.NewBuffer(nil) + es := &encodeState{ + Writer: buffer, + } if err := es.marshal(tt.in); (err != nil) != tt.wantErr { t.Errorf("encodeState.encodeFixedWidthInt() error = %v, wantErr %v", err, tt.wantErr) } - if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) { - t.Errorf("encodeState.encodeFixedWidthInt() = %v, want %v", es.Buffer.Bytes(), tt.want) + if !reflect.DeepEqual(buffer.Bytes(), tt.want) { + t.Errorf("encodeState.encodeFixedWidthInt() = %v, want %v", buffer.Bytes(), tt.want) } }) } @@ -883,12 +948,15 @@ func Test_encodeState_encodeFixedWidthInteger(t *testing.T) { func Test_encodeState_encodeVariableWidthIntegers(t *testing.T) { for _, tt := range variableWidthIntegerTests { t.Run(tt.name, func(t *testing.T) { - es := &encodeState{} + buffer := bytes.NewBuffer(nil) + es := &encodeState{ + Writer: buffer, + } if err := es.marshal(tt.in); (err != nil) != tt.wantErr { t.Errorf("encodeState.encodeFixedWidthInt() error = %v, wantErr %v", err, tt.wantErr) } - if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) { - t.Errorf("encodeState.encodeFixedWidthInt() = %v, want %v", es.Buffer.Bytes(), tt.want) + if !reflect.DeepEqual(buffer.Bytes(), tt.want) { + t.Errorf("encodeState.encodeFixedWidthInt() = %v, want %v", buffer.Bytes(), tt.want) } }) } @@ -897,12 +965,15 @@ func Test_encodeState_encodeVariableWidthIntegers(t *testing.T) { func Test_encodeState_encodeBigInt(t *testing.T) { for _, tt := range bigIntTests { t.Run(tt.name, func(t *testing.T) { - es := &encodeState{} + buffer := bytes.NewBuffer(nil) + es := &encodeState{ + Writer: buffer, + } if err := es.marshal(tt.in); (err != nil) != tt.wantErr { t.Errorf("encodeState.encodeBigInt() error = %v, wantErr %v", err, tt.wantErr) } - if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) { - t.Errorf("encodeState.encodeBigInt() = %v, want %v", es.Buffer.Bytes(), tt.want) + if !reflect.DeepEqual(buffer.Bytes(), tt.want) { + t.Errorf("encodeState.encodeBigInt() = %v, want %v", buffer.Bytes(), tt.want) } }) } @@ -911,12 +982,15 @@ func Test_encodeState_encodeBigInt(t *testing.T) { func Test_encodeState_encodeUint128(t *testing.T) { for _, tt := range uint128Tests { t.Run(tt.name, func(t *testing.T) { - es := &encodeState{} + buffer := bytes.NewBuffer(nil) + es := &encodeState{ + Writer: buffer, + } if err := es.marshal(tt.in); (err != nil) != tt.wantErr { t.Errorf("encodeState.encodeUin128() error = %v, wantErr %v", err, tt.wantErr) } - if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) { - t.Errorf("encodeState.encodeUin128() = %v, want %v", es.Buffer.Bytes(), tt.want) + if !reflect.DeepEqual(buffer.Bytes(), tt.want) { + t.Errorf("encodeState.encodeUin128() = %v, want %v", buffer.Bytes(), tt.want) } }) } @@ -925,12 +999,16 @@ func Test_encodeState_encodeUint128(t *testing.T) { func Test_encodeState_encodeBytes(t *testing.T) { for _, tt := range stringTests { t.Run(tt.name, func(t *testing.T) { - es := &encodeState{} + + buffer := bytes.NewBuffer(nil) + es := &encodeState{ + Writer: buffer, + } if err := es.marshal(tt.in); (err != nil) != tt.wantErr { t.Errorf("encodeState.encodeBytes() error = %v, wantErr %v", err, tt.wantErr) } - if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) { - t.Errorf("encodeState.encodeBytes() = %v, want %v", es.Buffer.Bytes(), tt.want) + if !reflect.DeepEqual(buffer.Bytes(), tt.want) { + t.Errorf("encodeState.encodeBytes() = %v, want %v", buffer.Bytes(), tt.want) } }) } @@ -939,12 +1017,16 @@ func Test_encodeState_encodeBytes(t *testing.T) { func Test_encodeState_encodeBool(t *testing.T) { for _, tt := range boolTests { t.Run(tt.name, func(t *testing.T) { - es := &encodeState{} + + buffer := bytes.NewBuffer(nil) + es := &encodeState{ + Writer: buffer, + } if err := es.marshal(tt.in); (err != nil) != tt.wantErr { t.Errorf("encodeState.encodeBool() error = %v, wantErr %v", err, tt.wantErr) } - if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) { - t.Errorf("encodeState.encodeBool() = %v, want %v", es.Buffer.Bytes(), tt.want) + if !reflect.DeepEqual(buffer.Bytes(), tt.want) { + t.Errorf("encodeState.encodeBool() = %v, want %v", buffer.Bytes(), tt.want) } }) } @@ -953,12 +1035,16 @@ func Test_encodeState_encodeBool(t *testing.T) { func Test_encodeState_encodeStruct(t *testing.T) { for _, tt := range structTests { t.Run(tt.name, func(t *testing.T) { - es := &encodeState{fieldScaleIndicesCache: cache} + buffer := bytes.NewBuffer(nil) + es := &encodeState{ + Writer: buffer, + fieldScaleIndicesCache: cache, + } if err := es.marshal(tt.in); (err != nil) != tt.wantErr { t.Errorf("encodeState.encodeStruct() error = %v, wantErr %v", err, tt.wantErr) } - if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) { - t.Errorf("encodeState.encodeStruct() = %v, want %v", es.Buffer.Bytes(), tt.want) + if !reflect.DeepEqual(buffer.Bytes(), tt.want) { + t.Errorf("encodeState.encodeStruct() = %v, want %v", buffer.Bytes(), tt.want) } }) } @@ -967,12 +1053,16 @@ func Test_encodeState_encodeStruct(t *testing.T) { func Test_encodeState_encodeSlice(t *testing.T) { for _, tt := range sliceTests { t.Run(tt.name, func(t *testing.T) { - es := &encodeState{fieldScaleIndicesCache: cache} + buffer := bytes.NewBuffer(nil) + es := &encodeState{ + Writer: buffer, + fieldScaleIndicesCache: cache, + } if err := es.marshal(tt.in); (err != nil) != tt.wantErr { t.Errorf("encodeState.encodeSlice() error = %v, wantErr %v", err, tt.wantErr) } - if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) { - t.Errorf("encodeState.encodeSlice() = %v, want %v", es.Buffer.Bytes(), tt.want) + if !reflect.DeepEqual(buffer.Bytes(), tt.want) { + t.Errorf("encodeState.encodeSlice() = %v, want %v", buffer.Bytes(), tt.want) } }) } @@ -981,12 +1071,16 @@ func Test_encodeState_encodeSlice(t *testing.T) { func Test_encodeState_encodeArray(t *testing.T) { for _, tt := range arrayTests { t.Run(tt.name, func(t *testing.T) { - es := &encodeState{fieldScaleIndicesCache: cache} + buffer := bytes.NewBuffer(nil) + es := &encodeState{ + Writer: buffer, + fieldScaleIndicesCache: cache, + } if err := es.marshal(tt.in); (err != nil) != tt.wantErr { t.Errorf("encodeState.encodeArray() error = %v, wantErr %v", err, tt.wantErr) } - if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) { - t.Errorf("encodeState.encodeArray() = %v, want %v", es.Buffer.Bytes(), tt.want) + if !reflect.DeepEqual(buffer.Bytes(), tt.want) { + t.Errorf("encodeState.encodeArray() = %v, want %v", buffer.Bytes(), tt.want) } }) } @@ -1007,12 +1101,16 @@ func Test_marshal_optionality(t *testing.T) { } for _, tt := range ptrTests { t.Run(tt.name, func(t *testing.T) { - es := &encodeState{fieldScaleIndicesCache: cache} + buffer := bytes.NewBuffer(nil) + es := &encodeState{ + Writer: buffer, + fieldScaleIndicesCache: cache, + } if err := es.marshal(tt.in); (err != nil) != tt.wantErr { t.Errorf("encodeState.encodeFixedWidthInt() error = %v, wantErr %v", err, tt.wantErr) } - if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) { - t.Errorf("encodeState.encodeFixedWidthInt() = %v, want %v", es.Buffer.Bytes(), tt.want) + if !reflect.DeepEqual(buffer.Bytes(), tt.want) { + t.Errorf("encodeState.encodeFixedWidthInt() = %v, want %v", buffer.Bytes(), tt.want) } }) } @@ -1043,12 +1141,16 @@ func Test_marshal_optionality_nil_cases(t *testing.T) { } for _, tt := range ptrTests { t.Run(tt.name, func(t *testing.T) { - es := &encodeState{fieldScaleIndicesCache: cache} + buffer := bytes.NewBuffer(nil) + es := &encodeState{ + Writer: buffer, + fieldScaleIndicesCache: cache, + } if err := es.marshal(tt.in); (err != nil) != tt.wantErr { t.Errorf("encodeState.encodeFixedWidthInt() error = %v, wantErr %v", err, tt.wantErr) } - if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) { - t.Errorf("encodeState.encodeFixedWidthInt() = %v, want %v", es.Buffer.Bytes(), tt.want) + if !reflect.DeepEqual(buffer.Bytes(), tt.want) { + t.Errorf("encodeState.encodeFixedWidthInt() = %v, want %v", buffer.Bytes(), tt.want) } }) } diff --git a/pkg/scale/varying_data_type_test.go b/pkg/scale/varying_data_type_test.go index eb66b294cb..34606a34db 100644 --- a/pkg/scale/varying_data_type_test.go +++ b/pkg/scale/varying_data_type_test.go @@ -4,6 +4,7 @@ package scale import ( + "bytes" "math/big" "reflect" "testing" @@ -293,13 +294,17 @@ var varyingDataTypeTests = tests{ func Test_encodeState_encodeVaryingDataType(t *testing.T) { for _, tt := range varyingDataTypeTests { t.Run(tt.name, func(t *testing.T) { - es := &encodeState{fieldScaleIndicesCache: cache} + buffer := bytes.NewBuffer(nil) + es := &encodeState{ + Writer: buffer, + fieldScaleIndicesCache: cache, + } vdt := tt.in.(VaryingDataType) if err := es.marshal(vdt); (err != nil) != tt.wantErr { t.Errorf("encodeState.marshal() error = %v, wantErr %v", err, tt.wantErr) } - if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) { - t.Errorf("encodeState.marshal() = %v, want %v", es.Buffer.Bytes(), tt.want) + if !reflect.DeepEqual(buffer.Bytes(), tt.want) { + t.Errorf("encodeState.marshal() = %v, want %v", buffer.Bytes(), tt.want) } }) } @@ -329,14 +334,18 @@ func Test_decodeState_decodeVaryingDataType(t *testing.T) { func Test_encodeState_encodeCustomVaryingDataType(t *testing.T) { for _, tt := range varyingDataTypeTests { t.Run(tt.name, func(t *testing.T) { - es := &encodeState{fieldScaleIndicesCache: cache} + buffer := bytes.NewBuffer(nil) + es := &encodeState{ + Writer: buffer, + fieldScaleIndicesCache: cache, + } vdt := tt.in.(VaryingDataType) cvdt := customVDT(vdt) if err := es.marshal(cvdt); (err != nil) != tt.wantErr { t.Errorf("encodeState.encodeStruct() error = %v, wantErr %v", err, tt.wantErr) } - if !reflect.DeepEqual(es.Buffer.Bytes(), tt.want) { - t.Errorf("encodeState.encodeStruct() = %v, want %v", es.Buffer.Bytes(), tt.want) + if !reflect.DeepEqual(buffer.Bytes(), tt.want) { + t.Errorf("encodeState.encodeStruct() = %v, want %v", buffer.Bytes(), tt.want) } }) }