diff --git a/okhttp-testing-support/src/jvmMain/kotlin/okhttp3/OkHttpClientTestRule.kt b/okhttp-testing-support/src/jvmMain/kotlin/okhttp3/OkHttpClientTestRule.kt index d5df9acae272..811b6e19645f 100644 --- a/okhttp-testing-support/src/jvmMain/kotlin/okhttp3/OkHttpClientTestRule.kt +++ b/okhttp-testing-support/src/jvmMain/kotlin/okhttp3/OkHttpClientTestRule.kt @@ -190,7 +190,9 @@ class OkHttpClientTestRule : BeforeEachCallback, AfterEachCallback { println("After delay: " + connectionPool.connectionCount()) } - assertEquals(0, connectionPool.connectionCount()) + assertEquals(0, connectionPool.connectionCount()) { + "Still ${connectionPool.connectionCount()} connections open" + } } } diff --git a/okhttp/src/jvmMain/kotlin/okhttp3/internal/ws/RealWebSocket.kt b/okhttp/src/jvmMain/kotlin/okhttp3/internal/ws/RealWebSocket.kt index 7934222ff008..f048f8488c56 100644 --- a/okhttp/src/jvmMain/kotlin/okhttp3/internal/ws/RealWebSocket.kt +++ b/okhttp/src/jvmMain/kotlin/okhttp3/internal/ws/RealWebSocket.kt @@ -170,9 +170,9 @@ class RealWebSocket( checkUpgradeSuccess(response, exchange) streams = exchange!!.newWebSocketStreams() } catch (e: IOException) { - exchange?.webSocketUpgradeFailed() failWebSocket(e, response) response.closeQuietly() + exchange?.webSocketUpgradeFailed() return } diff --git a/okhttp/src/jvmTest/java/okhttp3/internal/ws/WebSocketHttpTest.java b/okhttp/src/jvmTest/java/okhttp3/internal/ws/WebSocketHttpTest.java index 01f1f3f77b8c..3f7eef947bf3 100644 --- a/okhttp/src/jvmTest/java/okhttp3/internal/ws/WebSocketHttpTest.java +++ b/okhttp/src/jvmTest/java/okhttp3/internal/ws/WebSocketHttpTest.java @@ -22,6 +22,8 @@ import java.net.ProtocolException; import java.net.SocketTimeoutException; import java.time.Duration; +import java.util.ArrayList; +import java.util.List; import java.util.Random; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -33,6 +35,7 @@ import mockwebserver3.SocketPolicy; import mockwebserver3.SocketPolicy.KeepOpen; import mockwebserver3.SocketPolicy.NoResponse; +import okhttp3.ConnectionPool; import okhttp3.OkHttpClient; import okhttp3.OkHttpClientTestRule; import okhttp3.Protocol; @@ -351,10 +354,16 @@ private OkHttpClientTestRule configureClientTestRule() { .setHeader("Upgrade", "websocket") .setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk=") .build()); - newWebSocket(); + webServer.enqueue(new MockResponse.Builder() + .socketPolicy(SocketPolicy.DisconnectAtStart.INSTANCE) + .build()); + + RealWebSocket webSocket = newWebSocket(); clientListener.assertFailure(101, null, ProtocolException.class, "Expected 'Connection' header value 'Upgrade' but was 'null'"); + + webSocket.cancel(); } @Test public void wrongConnectionHeader() throws IOException { @@ -364,10 +373,16 @@ private OkHttpClientTestRule configureClientTestRule() { .setHeader("Connection", "Downgrade") .setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk=") .build()); - newWebSocket(); + webServer.enqueue(new MockResponse.Builder() + .socketPolicy(SocketPolicy.DisconnectAtStart.INSTANCE) + .build()); + + RealWebSocket webSocket = newWebSocket(); clientListener.assertFailure(101, null, ProtocolException.class, "Expected 'Connection' header value 'Upgrade' but was 'Downgrade'"); + + webSocket.cancel(); } @Test public void missingUpgradeHeader() throws IOException { @@ -376,10 +391,16 @@ private OkHttpClientTestRule configureClientTestRule() { .setHeader("Connection", "Upgrade") .setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk=") .build()); - newWebSocket(); + webServer.enqueue(new MockResponse.Builder() + .socketPolicy(SocketPolicy.DisconnectAtStart.INSTANCE) + .build()); + + RealWebSocket webSocket = newWebSocket(); clientListener.assertFailure(101, null, ProtocolException.class, "Expected 'Upgrade' header value 'websocket' but was 'null'"); + + webSocket.cancel(); } @Test public void wrongUpgradeHeader() throws IOException { @@ -389,10 +410,16 @@ private OkHttpClientTestRule configureClientTestRule() { .setHeader("Upgrade", "Pepsi") .setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk=") .build()); - newWebSocket(); + webServer.enqueue(new MockResponse.Builder() + .socketPolicy(SocketPolicy.DisconnectAtStart.INSTANCE) + .build()); + + RealWebSocket webSocket = newWebSocket(); clientListener.assertFailure(101, null, ProtocolException.class, "Expected 'Upgrade' header value 'websocket' but was 'Pepsi'"); + + webSocket.cancel(); } @Test public void missingMagicHeader() throws IOException { @@ -401,10 +428,16 @@ private OkHttpClientTestRule configureClientTestRule() { .setHeader("Connection", "Upgrade") .setHeader("Upgrade", "websocket") .build()); - newWebSocket(); + webServer.enqueue(new MockResponse.Builder() + .socketPolicy(SocketPolicy.DisconnectAtStart.INSTANCE) + .build()); + + RealWebSocket webSocket = newWebSocket(); clientListener.assertFailure(101, null, ProtocolException.class, "Expected 'Sec-WebSocket-Accept' header value 'ujmZX4KXZqjwy6vi1aQFH5p4Ygk=' but was 'null'"); + + webSocket.cancel(); } @Test public void wrongMagicHeader() throws IOException { @@ -414,10 +447,16 @@ private OkHttpClientTestRule configureClientTestRule() { .setHeader("Upgrade", "websocket") .setHeader("Sec-WebSocket-Accept", "magic") .build()); - newWebSocket(); + webServer.enqueue(new MockResponse.Builder() + .socketPolicy(SocketPolicy.DisconnectAtStart.INSTANCE) + .build()); + + RealWebSocket webSocket = newWebSocket(); clientListener.assertFailure(101, null, ProtocolException.class, "Expected 'Sec-WebSocket-Accept' header value 'ujmZX4KXZqjwy6vi1aQFH5p4Ygk=' but was 'magic'"); + + webSocket.cancel(); } @Test public void clientIncludesForbiddenHeader() throws IOException { @@ -868,6 +907,53 @@ private OkHttpClientTestRule configureClientTestRule() { webSocket.close(1000, null); } + /** https://github.com/square/okhttp/issues/7768 */ + @Test public void reconnectingToNonWebSocket() throws InterruptedException { + // Async test is problematic + client = this.client.newBuilder() + .connectionPool(new ConnectionPool()) + .build(); + + for (int i = 0; i < 30; i++) { + webServer.enqueue(new MockResponse.Builder() + .bodyDelay(100, TimeUnit.MILLISECONDS) + .body("Wrong endpoint") + .code(401) + .build()); + } + + Request request = new Request.Builder() + .url(webServer.url("/")) + .build(); + + CountDownLatch attempts = new CountDownLatch(20); + + List webSockets = new ArrayList<>(); + + WebSocketListener reconnectOnFailure = new WebSocketListener() { + @Override + public void onFailure(WebSocket webSocket, Throwable t, Response response) { + if (attempts.getCount() > 0) { + clientListener.setNextEventDelegate(this); + webSockets.add(client.newWebSocket(request, clientListener)); + attempts.countDown(); + } + } + }; + + clientListener.setNextEventDelegate(reconnectOnFailure); + + webSockets.add(client.newWebSocket(request, clientListener)); + + attempts.await(); + + for (WebSocket webSocket: webSockets) { + webSocket.cancel(); + } + client.dispatcher().cancelAll(); + client.connectionPool().evictAll(); + } + @Test public void compressedMessages() throws Exception { successfulExtensions("permessage-deflate"); }