Skip to content

Commit

Permalink
fix(dot/subscription): check websocket message from untrusted data (#…
Browse files Browse the repository at this point in the history
…2527)

* fix: websocket message checks from untrusted data
  • Loading branch information
EclesioMeloJunior committed May 26, 2022
1 parent e29f90c commit 1f20d98
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 44 deletions.
2 changes: 1 addition & 1 deletion dot/rpc/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ func (h *HTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
wsc := NewWSConn(ws, h.serverConfig)
h.wsConns = append(h.wsConns, wsc)

go wsc.HandleComm()
go wsc.HandleConn()
}

// NewWSConn to create new WebSocket Connection struct
Expand Down
84 changes: 44 additions & 40 deletions dot/rpc/subscription/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,21 @@ import (
"github.com/gorilla/websocket"
)

type websocketMessage struct {
ID float64 `json:"id"`
Method string `json:"method"`
Params any `json:"params"`
}

type httpclient interface {
Do(*http.Request) (*http.Response, error)
}

var errCannotReadFromWebsocket = errors.New("cannot read message from websocket")
var errCannotUnmarshalMessage = errors.New("cannot unmarshal webasocket message data")
var (
errCannotReadFromWebsocket = errors.New("cannot read message from websocket")
errEmptyMethod = errors.New("empty method")
)

var logger = log.NewFromGlobal(log.AddContext("pkg", "rpc/subscription"))

// WSConn struct to hold WebSocket Connection references
Expand All @@ -46,87 +55,82 @@ type WSConn struct {
}

// readWebsocketMessage will read and parse the message data to a string->interface{} data
func (c *WSConn) readWebsocketMessage() ([]byte, map[string]interface{}, error) {
_, mbytes, err := c.Wsconn.ReadMessage()
func (c *WSConn) readWebsocketMessage() (rawBytes []byte, wsMessage *websocketMessage, err error) {
_, rawBytes, err = c.Wsconn.ReadMessage()
if err != nil {
logger.Debugf("websocket failed to read message: %s", err)
return nil, nil, errCannotReadFromWebsocket
return nil, nil, fmt.Errorf("%w: %s", errCannotReadFromWebsocket, err.Error())
}

logger.Tracef("websocket message received: %s", string(mbytes))

// determine if request is for subscribe method type
var msg map[string]interface{}
err = json.Unmarshal(mbytes, &msg)

wsMessage = new(websocketMessage)
err = json.Unmarshal(rawBytes, wsMessage)
if err != nil {
logger.Debugf("websocket failed to unmarshal request message: %s", err)
return nil, nil, errCannotUnmarshalMessage
return nil, nil, err
}

return mbytes, msg, nil
if wsMessage.Method == "" {
return nil, nil, errEmptyMethod
}

return rawBytes, wsMessage, nil
}

//HandleComm handles messages received on websocket connections
func (c *WSConn) HandleComm() {
// HandleConn handles messages received on websocket connections
func (c *WSConn) HandleConn() {
for {
mbytes, msg, err := c.readWebsocketMessage()
if errors.Is(err, errCannotReadFromWebsocket) {
return
}
rawBytes, wsMessage, err := c.readWebsocketMessage()
if err != nil {
logger.Debugf("websocket failed to read message: %s", err)
if errors.Is(err, errCannotReadFromWebsocket) {
return
}

if errors.Is(err, errCannotUnmarshalMessage) {
c.safeSendError(0, big.NewInt(InvalidRequestCode), InvalidRequestMessage)
continue
}

params := msg["params"]
reqid := msg["id"].(float64)
method := msg["method"].(string)

logger.Debugf("ws method %s called with params %v", method, params)
logger.Tracef("websocket message received: %s", string(rawBytes))
logger.Debugf("ws method %s called with params %v", wsMessage.Method, wsMessage.Params)

if !strings.Contains(method, "_unsubscribe") && !strings.Contains(method, "_unwatch") {
setupListener := c.getSetupListener(method)
if !strings.Contains(wsMessage.Method, "_unsubscribe") && !strings.Contains(wsMessage.Method, "_unwatch") {
setupListener := c.getSetupListener(wsMessage.Method)

if setupListener == nil {
c.executeRPCCall(mbytes)
c.executeRPCCall(rawBytes)
continue
}

listener, err := setupListener(reqid, params)
listener, err := setupListener(wsMessage.ID, wsMessage.Params)
if err != nil {
logger.Warnf("failed to create listener (method=%s): %s", method, err)
logger.Warnf("failed to create listener (method=%s): %s", wsMessage.Method, err)
continue
}

listener.Listen()
continue
}

listener, err := c.getUnsubListener(params)

listener, err := c.getUnsubListener(wsMessage.Params)
if err != nil {
logger.Warnf("failed to get unsubscriber (method=%s): %s", method, err)
logger.Warnf("failed to get unsubscriber (method=%s): %s", wsMessage.Method, err)

if errors.Is(err, errUknownParamSubscribeID) || errors.Is(err, errCannotFindUnsubsriber) {
c.safeSendError(reqid, big.NewInt(InvalidRequestCode), InvalidRequestMessage)
c.safeSendError(wsMessage.ID, big.NewInt(InvalidRequestCode), InvalidRequestMessage)
continue
}

if errors.Is(err, errCannotParseID) || errors.Is(err, errCannotFindListener) {
c.safeSend(newBooleanResponseJSON(false, reqid))
c.safeSend(newBooleanResponseJSON(false, wsMessage.ID))
continue
}
}

err = listener.Stop()
if err != nil {
logger.Warnf("failed to stop listener goroutine (method=%s): %s", method, err)
c.safeSend(newBooleanResponseJSON(false, reqid))
logger.Warnf("failed to stop listener goroutine (method=%s): %s", wsMessage.Method, err)
c.safeSend(newBooleanResponseJSON(false, wsMessage.ID))
}

c.safeSend(newBooleanResponseJSON(true, reqid))
c.safeSend(newBooleanResponseJSON(true, wsMessage.ID))
continue
}
}
Expand Down
53 changes: 50 additions & 3 deletions dot/rpc/subscription/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ import (
"github.com/stretchr/testify/require"
)

func TestWSConn_HandleComm(t *testing.T) {
func TestWSConn_HandleConn(t *testing.T) {
wsconn, c, cancel := setupWSConn(t)
wsconn.Subscriptions = make(map[uint32]Listener)
defer cancel()

go wsconn.HandleComm()
go wsconn.HandleConn()
time.Sleep(time.Second * 2)

// test storageChangeListener
Expand Down Expand Up @@ -294,7 +294,7 @@ func TestSubscribeAllHeads(t *testing.T) {
wsconn.Subscriptions = make(map[uint32]Listener)
defer cancel()

go wsconn.HandleComm()
go wsconn.HandleConn()
time.Sleep(time.Second * 2)

_, err := wsconn.initAllBlocksListerner(1, nil)
Expand Down Expand Up @@ -372,3 +372,50 @@ func TestSubscribeAllHeads(t *testing.T) {
require.NoError(t, l.Stop())
mockBlockAPI.On("FreeImportedBlockNotifierChannel", mock.AnythingOfType("chan *types.Block"))
}

func TestWSConn_CheckWebsocketInvalidData(t *testing.T) {
wsconn, c, cancel := setupWSConn(t)
wsconn.Subscriptions = make(map[uint32]Listener)
defer cancel()

go wsconn.HandleConn()

tests := []struct {
sentMessage []byte
expected []byte
}{
{
sentMessage: []byte(`{
"jsonrpc": "2.0",
"method": "",
"id": 0,
"params": []
}`),
expected: []byte(`{"jsonrpc":"2.0","error":{"code":-32600,"message":"Invalid request"},"id":0}` + "\n"),
},
{
sentMessage: []byte(`{
"jsonrpc": "2.0",
"params": []
}`),
expected: []byte(`{"jsonrpc":"2.0","error":{"code":-32600,"message":"Invalid request"},"id":0}` + "\n"),
},
{
sentMessage: []byte(`{
"jsonrpc": "2.0",
"id": "abcdef"
"method": "some_method_name"
"params": []
}`),
expected: []byte(`{"jsonrpc":"2.0","error":{"code":-32600,"message":"Invalid request"},"id":0}` + "\n"),
},
}

for _, tt := range tests {
c.WriteMessage(websocket.TextMessage, tt.sentMessage)

_, msg, err := c.ReadMessage()
require.NoError(t, err)
require.Equal(t, tt.expected, msg)
}
}

0 comments on commit 1f20d98

Please sign in to comment.