Skip to content

Commit

Permalink
Merge pull request #188 from nhooyr/fix-negotiations
Browse files Browse the repository at this point in the history
Fix deflate extension parameter negotiation
  • Loading branch information
nhooyr committed Feb 16, 2020
2 parents 94f9b71 + 95bfb8f commit fbd323c
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 69 deletions.
29 changes: 2 additions & 27 deletions accept.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,6 @@ func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionM

func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
copts := mode.opts()
copts.serverMaxWindowBits = 8

for _, p := range ext.params {
switch p {
Expand All @@ -222,26 +221,7 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
}

if strings.HasPrefix(p, "client_max_window_bits") {
continue

// bits, ok := parseExtensionParameter(p, 15)
// if !ok || bits < 8 || bits > 16 {
// err := fmt.Errorf("invalid client_max_window_bits: %q", p)
// http.Error(w, err.Error(), http.StatusBadRequest)
// return nil, err
// }
// copts.clientMaxWindowBits = bits
// continue
}

if false && strings.HasPrefix(p, "server_max_window_bits") {
// We always send back 8 but make sure to validate.
bits, ok := parseExtensionParameter(p, 0)
if !ok || bits < 8 || bits > 16 {
err := fmt.Errorf("invalid server_max_window_bits: %q", p)
http.Error(w, err.Error(), http.StatusBadRequest)
return nil, err
}
// We cannot adjust the read sliding window so cannot make use of this.
continue
}

Expand All @@ -256,14 +236,9 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
}

// parseExtensionParameter parses the value in the extension parameter p.
// It falls back to defaultVal if there is no value.
// If defaultVal == 0, then ok == false if there is no value.
func parseExtensionParameter(p string, defaultVal int) (int, bool) {
func parseExtensionParameter(p string) (int, bool) {
ps := strings.Split(p, "=")
if len(ps) == 1 {
if defaultVal > 0 {
return defaultVal, true
}
return 0, false
}
i, e := strconv.Atoi(strings.Trim(ps[1], `"`))
Expand Down
1 change: 0 additions & 1 deletion accept_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,6 @@ func Test_acceptCompression(t *testing.T) {
expCopts: &compressionOptions{
clientNoContextTakeover: true,
serverNoContextTakeover: true,
serverMaxWindowBits: 8,
},
},
{
Expand Down
1 change: 1 addition & 0 deletions autobahn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ var excludedAutobahnCases = []string{

// We skip the tests related to requestMaxWindowBits as that is unimplemented due
// to limitations in compress/flate. See https://github.com/golang/go/issues/3155
// Same with klauspost/compress which doesn't allow adjusting the sliding window size.
"13.3.*", "13.4.*", "13.5.*", "13.6.*",
}

Expand Down
14 changes: 4 additions & 10 deletions compress_notjs.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
package websocket

import (
"fmt"
"io"
"net/http"
"sync"
Expand All @@ -20,10 +19,7 @@ func (m CompressionMode) opts() *compressionOptions {

type compressionOptions struct {
clientNoContextTakeover bool
clientMaxWindowBits int

serverNoContextTakeover bool
serverMaxWindowBits int
}

func (copts *compressionOptions) setHeader(h http.Header) {
Expand All @@ -34,12 +30,6 @@ func (copts *compressionOptions) setHeader(h http.Header) {
if copts.serverNoContextTakeover {
s += "; server_no_context_takeover"
}
if false && copts.serverMaxWindowBits > 0 {
s += fmt.Sprintf("; server_max_window_bits=%v", copts.serverMaxWindowBits)
}
if false && copts.clientMaxWindowBits > 0 {
s += fmt.Sprintf("; client_max_window_bits=%v", copts.clientMaxWindowBits)
}
h.Set("Sec-WebSocket-Extensions", s)
}

Expand Down Expand Up @@ -147,6 +137,10 @@ func (sw *slidingWindow) init(n int) {
return
}

if n == 0 {
n = 32768
}

p := slidingWindowPool(n)
buf, ok := p.Get().([]byte)
if ok {
Expand Down
45 changes: 15 additions & 30 deletions dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,12 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err)
}

resp, err := handshakeRequest(ctx, urls, opts, secWebSocketKey)
var copts *compressionOptions
if opts.CompressionMode != CompressionDisabled {
copts = opts.CompressionMode.opts()
}

resp, err := handshakeRequest(ctx, urls, opts, copts, secWebSocketKey)
if err != nil {
return nil, resp, err
}
Expand All @@ -104,7 +109,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
}
}()

copts, err := verifyServerResponse(opts, secWebSocketKey, resp)
copts, err = verifyServerResponse(opts, copts, secWebSocketKey, resp)
if err != nil {
return nil, resp, err
}
Expand All @@ -125,7 +130,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
}), resp, nil
}

func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWebSocketKey string) (*http.Response, error) {
func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) {
if opts.HTTPClient.Timeout > 0 {
return nil, errors.New("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67")
}
Expand Down Expand Up @@ -153,9 +158,7 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWe
if len(opts.Subprotocols) > 0 {
req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
}
if opts.CompressionMode != CompressionDisabled {
copts := opts.CompressionMode.opts()
copts.clientMaxWindowBits = 8
if copts != nil {
copts.setHeader(req.Header)
}

Expand All @@ -178,7 +181,7 @@ func secWebSocketKey(rr io.Reader) (string, error) {
return base64.StdEncoding.EncodeToString(b), nil
}

func verifyServerResponse(opts *DialOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) {
func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) {
if resp.StatusCode != http.StatusSwitchingProtocols {
return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
}
Expand All @@ -203,7 +206,7 @@ func verifyServerResponse(opts *DialOptions, secWebSocketKey string, resp *http.
return nil, err
}

return verifyServerExtensions(resp.Header)
return verifyServerExtensions(copts, resp.Header)
}

func verifySubprotocol(subprotos []string, resp *http.Response) error {
Expand All @@ -221,19 +224,19 @@ func verifySubprotocol(subprotos []string, resp *http.Response) error {
return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto)
}

func verifyServerExtensions(h http.Header) (*compressionOptions, error) {
func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compressionOptions, error) {
exts := websocketExtensions(h)
if len(exts) == 0 {
return nil, nil
}

ext := exts[0]
if ext.name != "permessage-deflate" || len(exts) > 1 {
if ext.name != "permessage-deflate" || len(exts) > 1 || copts == nil {
return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:])
}

copts := &compressionOptions{}
copts.clientMaxWindowBits = 8
copts = &*copts

for _, p := range ext.params {
switch p {
case "client_no_context_takeover":
Expand All @@ -244,24 +247,6 @@ func verifyServerExtensions(h http.Header) (*compressionOptions, error) {
continue
}

if false && strings.HasPrefix(p, "server_max_window_bits") {
bits, ok := parseExtensionParameter(p, 0)
if !ok || bits < 8 || bits > 16 {
return nil, fmt.Errorf("invalid server_max_window_bits: %q", p)
}
copts.serverMaxWindowBits = bits
continue
}

if false && strings.HasPrefix(p, "client_max_window_bits") {
bits, ok := parseExtensionParameter(p, 0)
if !ok || bits < 8 || bits > 16 {
return nil, fmt.Errorf("invalid client_max_window_bits: %q", p)
}
copts.clientMaxWindowBits = 8
continue
}

return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
}

Expand Down
2 changes: 1 addition & 1 deletion dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ func Test_verifyServerHandshake(t *testing.T) {
opts := &DialOptions{
Subprotocols: strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ","),
}
_, err = verifyServerResponse(opts, key, resp)
_, err = verifyServerResponse(opts, opts.CompressionMode.opts(), key, resp)
if tc.success {
assert.Success(t, err)
} else {
Expand Down

0 comments on commit fbd323c

Please sign in to comment.