Skip to content

Commit

Permalink
fix(dot/subscription): unsafe type casting from untrusted input (#2529)
Browse files Browse the repository at this point in the history
* chore: using checks to use `interface{}` types
  • Loading branch information
EclesioMeloJunior committed May 30, 2022
1 parent d2ee47e commit 1015733
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 31 deletions.
70 changes: 50 additions & 20 deletions dot/rpc/subscription/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ type httpclient interface {
}

var (
errUnexpectedType = errors.New("unexpected type")
errUnexpectedParamLen = errors.New("unexpected params length")
errCannotReadFromWebsocket = errors.New("cannot read message from websocket")
errEmptyMethod = errors.New("empty method")
)
Expand Down Expand Up @@ -163,25 +165,35 @@ func (c *WSConn) initStorageChangeListener(reqID float64, params interface{}) (L
wsconn: c,
}

pA, ok := params.([]interface{})
if !ok {
return nil, fmt.Errorf("unknown parameter type")
}
for _, param := range pA {
switch p := param.(type) {
case []interface{}:
for _, pp := range param.([]interface{}) {
data, ok := pp.(string)
if !ok {
return nil, fmt.Errorf("unknown parameter type")
// the following type checking/casting is needed in order to satisfy some
// websocket request field params eg.:
// "params": ["0x..."] or
// "params": [["0x...", "0x..."]]
switch filters := params.(type) {
case []interface{}:
for _, interfaceKey := range filters {
switch key := interfaceKey.(type) {
case string:
stgobs.filter[key] = []byte{}
case []string:
for _, k := range key {
stgobs.filter[k] = []byte{}
}
case []interface{}:
for _, k := range key {
k, ok := k.(string)
if !ok {
return nil, fmt.Errorf("%w: %T, expected type string", errUnexpectedType, k)
}

stgobs.filter[k] = []byte{}
}
stgobs.filter[data] = []byte{}
default:
return nil, fmt.Errorf("%w: %T, expected type string, []string, []interface{}", errUnexpectedType, interfaceKey)
}
case string:
stgobs.filter[p] = []byte{}
default:
return nil, fmt.Errorf("unknown parameter type")
}
default:
return nil, fmt.Errorf("%w: %T, expected type []interface{}", errUnexpectedType, params)
}

c.mu.Lock()
Expand Down Expand Up @@ -269,14 +281,32 @@ func (c *WSConn) initAllBlocksListerner(reqID float64, _ interface{}) (Listener,
}

func (c *WSConn) initExtrinsicWatch(reqID float64, params interface{}) (Listener, error) {
pA := params.([]interface{})
var encodedExtrinsic string

switch encodedHex := params.(type) {
case []string:
if len(encodedHex) != 1 {
return nil, fmt.Errorf("%w: expected 1 param, got: %d", errUnexpectedParamLen, len(encodedHex))
}
encodedExtrinsic = encodedHex[0]
// the bellow case is needed to cover a interface{} slice containing one string
// as `[]interface{"a"}` is not the same as `[]string{"a"}`
case []interface{}:
if len(encodedHex) != 1 {
return nil, fmt.Errorf("%w: expected 1 param, got: %d", errUnexpectedParamLen, len(encodedHex))
}

if len(pA) != 1 {
return nil, errors.New("expecting only one parameter")
var ok bool
encodedExtrinsic, ok = encodedHex[0].(string)
if !ok {
return nil, fmt.Errorf("%w: %T, expected type string", errUnexpectedType, encodedHex[0])
}
default:
return nil, fmt.Errorf("%w: %T, expected type []string or []interface{}", errUnexpectedType, params)
}

// The passed parameter should be a HEX of a SCALE encoded extrinsic
extBytes, err := common.HexToBytes(pA[0].(string))
extBytes, err := common.HexToBytes(encodedExtrinsic)
if err != nil {
return nil, err
}
Expand Down
25 changes: 14 additions & 11 deletions dot/rpc/subscription/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ func TestWSConn_HandleConn(t *testing.T) {
res, err = wsconn.initStorageChangeListener(1, nil)
require.Nil(t, res)
require.Len(t, wsconn.Subscriptions, 0)
require.EqualError(t, err, "unknown parameter type")
require.ErrorIs(t, err, errUnexpectedType)
require.EqualError(t, err, "unexpected type: <nil>, expected type []interface{}")

res, err = wsconn.initStorageChangeListener(2, []interface{}{})
require.NotNil(t, res)
Expand All @@ -55,33 +56,35 @@ func TestWSConn_HandleConn(t *testing.T) {
require.NoError(t, err)
require.Equal(t, []byte(`{"jsonrpc":"2.0","result":1,"id":2}`+"\n"), msg)

res, err = wsconn.initStorageChangeListener(3, []interface{}{"0x26aa"})
var testFilter0 = []interface{}{"0x26aa"}
res, err = wsconn.initStorageChangeListener(3, testFilter0)
require.NotNil(t, res)
require.NoError(t, err)
require.Len(t, wsconn.Subscriptions, 2)
_, msg, err = c.ReadMessage()
require.NoError(t, err)
require.Equal(t, []byte(`{"jsonrpc":"2.0","result":2,"id":3}`+"\n"), msg)

var testFilters = []interface{}{}
var testFilter1 = []interface{}{"0x26aa", "0x26a1"}
res, err = wsconn.initStorageChangeListener(4, append(testFilters, testFilter1))
require.NotNil(t, res)
var testFilter1 = []interface{}{[]interface{}{"0x26aa", "0x26a1"}}
res, err = wsconn.initStorageChangeListener(4, testFilter1)
require.NoError(t, err)
require.NotNil(t, res)
require.Len(t, wsconn.Subscriptions, 3)
_, msg, err = c.ReadMessage()
require.NoError(t, err)
require.Equal(t, []byte(`{"jsonrpc":"2.0","result":3,"id":4}`+"\n"), msg)

var testFilterWrongType = []interface{}{"0x26aa", 1}
res, err = wsconn.initStorageChangeListener(5, append(testFilters, testFilterWrongType))
require.EqualError(t, err, "unknown parameter type")
var testFilterWrongType = []interface{}{[]int{123}}
res, err = wsconn.initStorageChangeListener(5, testFilterWrongType)
require.ErrorIs(t, err, errUnexpectedType)
require.EqualError(t, err, "unexpected type: []int, expected type string, []string, []interface{}")
require.Nil(t, res)
// keep subscriptions len == 3, no additions was made
require.Len(t, wsconn.Subscriptions, 3)

res, err = wsconn.initStorageChangeListener(6, []interface{}{1})
require.EqualError(t, err, "unknown parameter type")
require.ErrorIs(t, err, errUnexpectedType)
require.EqualError(t, err, "unexpected type: int, expected type string, []string, []interface{}")
require.Nil(t, res)
require.Len(t, wsconn.Subscriptions, 3)

Expand Down Expand Up @@ -207,7 +210,7 @@ func TestWSConn_HandleConn(t *testing.T) {
wsconn.CoreAPI = modules.NewMockCoreAPI()
wsconn.BlockAPI = nil
wsconn.TxStateAPI = modules.NewMockTransactionStateAPI()
listner, err := wsconn.initExtrinsicWatch(0, []interface{}{"NotHex"})
listner, err := wsconn.initExtrinsicWatch(0, []string{"NotHex"})
require.EqualError(t, err, "could not byteify non 0x prefixed string: NotHex")
require.Nil(t, listner)

Expand Down

0 comments on commit 1015733

Please sign in to comment.