From 249edb209389a1b6fd3b1f79de78417982077284 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Wed, 18 Oct 2023 20:44:27 -0700 Subject: [PATCH] dial_test: Add TestDialViaProxy For #395 Somehow currently reproduces #391... Debugging still. --- conn_test.go | 26 ++++++++++++ dial_test.go | 110 ++++++++++++++++++++++++++++++++++++++++++------- export_test.go | 7 ++++ 3 files changed, 128 insertions(+), 15 deletions(-) diff --git a/conn_test.go b/conn_test.go index 50b844b9..5f78cad5 100644 --- a/conn_test.go +++ b/conn_test.go @@ -526,3 +526,29 @@ func echoServer(w http.ResponseWriter, r *http.Request, opts *websocket.AcceptOp err = wstest.EchoLoop(r.Context(), c) return assertCloseStatus(websocket.StatusNormalClosure, err) } + +func assertEcho(tb testing.TB, ctx context.Context, c *websocket.Conn) { + exp := xrand.String(xrand.Int(131072)) + + werr := xsync.Go(func() error { + return wsjson.Write(ctx, c, exp) + }) + + var act interface{} + err := wsjson.Read(ctx, c, &act) + assert.Success(tb, err) + assert.Equal(tb, "read msg", exp, act) + + select { + case err := <-werr: + assert.Success(tb, err) + case <-ctx.Done(): + tb.Fatal(ctx.Err()) + } +} + +func assertClose(tb testing.TB, c *websocket.Conn) { + tb.Helper() + err := c.Close(websocket.StatusNormalClosure, "") + assert.Success(tb, err) +} diff --git a/dial_test.go b/dial_test.go index 3652f8d4..7a84436d 100644 --- a/dial_test.go +++ b/dial_test.go @@ -1,7 +1,7 @@ //go:build !js // +build !js -package websocket +package websocket_test import ( "bytes" @@ -10,12 +10,15 @@ import ( "io" "net/http" "net/http/httptest" + "net/url" "strings" "testing" "time" + "nhooyr.io/websocket" "nhooyr.io/websocket/internal/test/assert" "nhooyr.io/websocket/internal/util" + "nhooyr.io/websocket/internal/xsync" ) func TestBadDials(t *testing.T) { @@ -27,7 +30,7 @@ func TestBadDials(t *testing.T) { testCases := []struct { name string url string - opts *DialOptions + opts *websocket.DialOptions rand util.ReaderFunc nilCtx bool }{ @@ -72,7 +75,7 @@ func TestBadDials(t *testing.T) { tc.rand = rand.Reader.Read } - _, _, err := dial(ctx, tc.url, tc.opts, tc.rand) + _, _, err := websocket.ExportedDial(ctx, tc.url, tc.opts, tc.rand) assert.Error(t, err) }) } @@ -84,7 +87,7 @@ func TestBadDials(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - _, _, err := Dial(ctx, "ws://example.com", &DialOptions{ + _, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{ HTTPClient: mockHTTPClient(func(*http.Request) (*http.Response, error) { return &http.Response{ Body: io.NopCloser(strings.NewReader("hi")), @@ -104,7 +107,7 @@ func TestBadDials(t *testing.T) { h := http.Header{} h.Set("Connection", "Upgrade") h.Set("Upgrade", "websocket") - h.Set("Sec-WebSocket-Accept", secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key"))) + h.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(r.Header.Get("Sec-WebSocket-Key"))) return &http.Response{ StatusCode: http.StatusSwitchingProtocols, @@ -113,7 +116,7 @@ func TestBadDials(t *testing.T) { }, nil } - _, _, err := Dial(ctx, "ws://example.com", &DialOptions{ + _, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{ HTTPClient: mockHTTPClient(rt), }) assert.Contains(t, err, "response body is not a io.ReadWriteCloser") @@ -152,7 +155,7 @@ func Test_verifyHostOverride(t *testing.T) { h := http.Header{} h.Set("Connection", "Upgrade") h.Set("Upgrade", "websocket") - h.Set("Sec-WebSocket-Accept", secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key"))) + h.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(r.Header.Get("Sec-WebSocket-Key"))) return &http.Response{ StatusCode: http.StatusSwitchingProtocols, @@ -161,7 +164,7 @@ func Test_verifyHostOverride(t *testing.T) { }, nil } - _, _, err := Dial(ctx, "ws://example.com", &DialOptions{ + _, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{ HTTPClient: mockHTTPClient(rt), Host: tc.host, }) @@ -272,18 +275,18 @@ func Test_verifyServerHandshake(t *testing.T) { resp := w.Result() r := httptest.NewRequest("GET", "/", nil) - key, err := secWebSocketKey(rand.Reader) + key, err := websocket.SecWebSocketKey(rand.Reader) assert.Success(t, err) r.Header.Set("Sec-WebSocket-Key", key) if resp.Header.Get("Sec-WebSocket-Accept") == "" { - resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) + resp.Header.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(key)) } - opts := &DialOptions{ + opts := &websocket.DialOptions{ Subprotocols: strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ","), } - _, err = verifyServerResponse(opts, opts.CompressionMode.opts(), key, resp) + _, err = websocket.VerifyServerResponse(opts, websocket.CompressionModeOpts(opts.CompressionMode), key, resp) if tc.success { assert.Success(t, err) } else { @@ -311,7 +314,7 @@ func TestDialRedirect(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - _, _, err := Dial(ctx, "ws://example.com", &DialOptions{ + _, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{ HTTPClient: mockHTTPClient(func(r *http.Request) (*http.Response, error) { resp := &http.Response{ Header: http.Header{}, @@ -321,11 +324,88 @@ func TestDialRedirect(t *testing.T) { resp.StatusCode = http.StatusFound return resp, nil } - resp.Header.Set("Connection", "Upgrade") - resp.Header.Set("Upgrade", "meow") + resp.Header.Set("Connection", "Upgrade") + resp.Header.Set("Upgrade", "meow") resp.StatusCode = http.StatusSwitchingProtocols return resp, nil }), }) assert.Contains(t, err, "failed to WebSocket dial: WebSocket protocol violation: Upgrade header \"meow\" does not contain websocket") } + +type forwardProxy struct { + hc *http.Client +} + +func newForwardProxy() *forwardProxy { + return &forwardProxy{ + hc: &http.Client{}, + } +} + +func (fc *forwardProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), time.Second*10) + defer cancel() + + r = r.WithContext(ctx) + r.RequestURI = "" + resp, err := fc.hc.Do(r) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + defer resp.Body.Close() + + for k, v := range resp.Header { + w.Header()[k] = v + } + w.Header().Set("PROXIED", "true") + w.WriteHeader(resp.StatusCode) + errc1 := xsync.Go(func() error { + _, err := io.Copy(w, resp.Body) + return err + }) + var errc2 <-chan error + if bodyw, ok := resp.Body.(io.Writer); ok { + errc2 = xsync.Go(func() error { + _, err := io.Copy(bodyw, r.Body) + return err + }) + } + select { + case <-errc1: + case <-errc2: + case <-r.Context().Done(): + } +} + +func TestDialViaProxy(t *testing.T) { + t.Parallel() + + ps := httptest.NewServer(newForwardProxy()) + defer ps.Close() + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + err := echoServer(w, r, nil) + assert.Success(t, err) + })) + defer s.Close() + + psu, err := url.Parse(ps.URL) + assert.Success(t, err) + proxyTransport := http.DefaultTransport.(*http.Transport).Clone() + proxyTransport.Proxy = http.ProxyURL(psu) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + c, resp, err := websocket.Dial(ctx, s.URL, &websocket.DialOptions{ + HTTPClient: &http.Client{ + Transport: proxyTransport, + }, + }) + assert.Success(t, err) + assert.Equal(t, "", "true", resp.Header.Get("PROXIED")) + + assertEcho(t, ctx, c) + assertClose(t, c) +} diff --git a/export_test.go b/export_test.go index 114796d0..e322c36f 100644 --- a/export_test.go +++ b/export_test.go @@ -25,3 +25,10 @@ func (c *Conn) RecordBytesRead() *int { } var ErrClosed = errClosed + +var ExportedDial = dial +var SecWebSocketAccept = secWebSocketAccept +var SecWebSocketKey = secWebSocketKey +var VerifyServerResponse = verifyServerResponse + +var CompressionModeOpts = CompressionMode.opts