diff --git a/CHANGES/8444.bugfix b/CHANGES/8444.bugfix new file mode 100644 index 00000000000..774e13064a7 --- /dev/null +++ b/CHANGES/8444.bugfix @@ -0,0 +1,2 @@ +Fix ``ws_connect`` not respecting `receive_timeout`` on WS(S) connection. +-- by :user:`arcivanov`. diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index 4442664118f..202193375dd 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -46,6 +46,7 @@ Anes Abismail Antoine Pietri Anton Kasyanov Anton Zhdan-Pushkin +Arcadiy Ivanov Arseny Timoniq Artem Yushkovskiy Arthur Darcet diff --git a/aiohttp/client.py b/aiohttp/client.py index 2541addcd06..c70ad65c59e 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -1009,6 +1009,16 @@ async def _ws_connect( assert conn is not None conn_proto = conn.protocol assert conn_proto is not None + + # For WS connection the read_timeout must be either receive_timeout or greater + # None == no timeout, i.e. infinite timeout, so None is the max timeout possible + if receive_timeout is None: + # Reset regardless + conn_proto.read_timeout = receive_timeout + elif conn_proto.read_timeout is not None: + # If read_timeout was set check which wins + conn_proto.read_timeout = max(receive_timeout, conn_proto.read_timeout) + transport = conn.transport assert transport is not None reader: FlowControlDataQueue[WSMessage] = FlowControlDataQueue( diff --git a/aiohttp/client_proto.py b/aiohttp/client_proto.py index 28e9d3cd9e5..f8c83240209 100644 --- a/aiohttp/client_proto.py +++ b/aiohttp/client_proto.py @@ -224,6 +224,14 @@ def _reschedule_timeout(self) -> None: def start_timeout(self) -> None: self._reschedule_timeout() + @property + def read_timeout(self) -> Optional[float]: + return self._read_timeout + + @read_timeout.setter + def read_timeout(self, read_timeout: Optional[float]) -> None: + self._read_timeout = read_timeout + def _on_read_timeout(self) -> None: exc = SocketTimeoutError("Timeout on reading data from socket") self.set_exception(exc) diff --git a/tests/test_client_ws.py b/tests/test_client_ws.py index 4be404f7752..ebc9d910c1a 100644 --- a/tests/test_client_ws.py +++ b/tests/test_client_ws.py @@ -23,6 +23,7 @@ async def test_ws_connect(ws_key: Any, loop: Any, key_data: Any) -> None: hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, hdrs.SEC_WEBSOCKET_PROTOCOL: "chat", } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data @@ -38,6 +39,97 @@ async def test_ws_connect(ws_key: Any, loop: Any, key_data: Any) -> None: assert hdrs.ORIGIN not in m_req.call_args[1]["headers"] +async def test_ws_connect_read_timeout_is_reset_to_inf( + ws_key: Any, loop: Any, key_data: Any +) -> None: + resp = mock.Mock() + resp.status = 101 + resp.headers = { + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", + hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, + hdrs.SEC_WEBSOCKET_PROTOCOL: "chat", + } + resp.connection.protocol.read_timeout = 0.5 + with mock.patch("aiohttp.client.os") as m_os, mock.patch( + "aiohttp.client.ClientSession.request" + ) as m_req: + m_os.urandom.return_value = key_data + m_req.return_value = loop.create_future() + m_req.return_value.set_result(resp) + + res = await aiohttp.ClientSession().ws_connect( + "http://test.org", protocols=("t1", "t2", "chat") + ) + + assert isinstance(res, client.ClientWebSocketResponse) + assert res.protocol == "chat" + assert hdrs.ORIGIN not in m_req.call_args[1]["headers"] + assert resp.connection.protocol.read_timeout is None + + +async def test_ws_connect_read_timeout_stays_inf( + ws_key: Any, loop: Any, key_data: Any +) -> None: + resp = mock.Mock() + resp.status = 101 + resp.headers = { + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", + hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, + hdrs.SEC_WEBSOCKET_PROTOCOL: "chat", + } + resp.connection.protocol.read_timeout = None + with mock.patch("aiohttp.client.os") as m_os, mock.patch( + "aiohttp.client.ClientSession.request" + ) as m_req: + m_os.urandom.return_value = key_data + m_req.return_value = loop.create_future() + m_req.return_value.set_result(resp) + + res = await aiohttp.ClientSession().ws_connect( + "http://test.org", + protocols=("t1", "t2", "chat"), + receive_timeout=0.5, + ) + + assert isinstance(res, client.ClientWebSocketResponse) + assert res.protocol == "chat" + assert hdrs.ORIGIN not in m_req.call_args[1]["headers"] + assert resp.connection.protocol.read_timeout is None + + +async def test_ws_connect_read_timeout_reset_to_max( + ws_key: Any, loop: Any, key_data: Any +) -> None: + resp = mock.Mock() + resp.status = 101 + resp.headers = { + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", + hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, + hdrs.SEC_WEBSOCKET_PROTOCOL: "chat", + } + resp.connection.protocol.read_timeout = 0.5 + with mock.patch("aiohttp.client.os") as m_os, mock.patch( + "aiohttp.client.ClientSession.request" + ) as m_req: + m_os.urandom.return_value = key_data + m_req.return_value = loop.create_future() + m_req.return_value.set_result(resp) + + res = await aiohttp.ClientSession().ws_connect( + "http://test.org", + protocols=("t1", "t2", "chat"), + receive_timeout=1.0, + ) + + assert isinstance(res, client.ClientWebSocketResponse) + assert res.protocol == "chat" + assert hdrs.ORIGIN not in m_req.call_args[1]["headers"] + assert resp.connection.protocol.read_timeout == 1.0 + + async def test_ws_connect_with_origin(key_data, loop) -> None: resp = mock.Mock() resp.status = 403 @@ -68,6 +160,7 @@ async def test_ws_connect_with_params(ws_key, loop, key_data) -> None: hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, hdrs.SEC_WEBSOCKET_PROTOCOL: "chat", } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data @@ -93,6 +186,7 @@ def read(self, decode=False): hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data @@ -215,6 +309,7 @@ async def mock_get(*args, **kwargs): hdrs.SEC_WEBSOCKET_ACCEPT: accept, hdrs.SEC_WEBSOCKET_PROTOCOL: "chat", } + resp.connection.protocol.read_timeout = None return resp with mock.patch("aiohttp.client.os") as m_os: @@ -245,6 +340,7 @@ async def test_close(loop, ws_key, key_data) -> None: hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter: with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: @@ -285,6 +381,7 @@ async def test_close_eofstream(loop, ws_key, key_data) -> None: hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter: with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: @@ -315,6 +412,7 @@ async def test_close_exc(loop, ws_key, key_data) -> None: hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter: with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: @@ -347,6 +445,7 @@ async def test_close_exc2(loop, ws_key, key_data) -> None: hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter: with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: @@ -381,6 +480,7 @@ async def test_send_data_after_close(ws_key, key_data, loop) -> None: hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data @@ -409,6 +509,7 @@ async def test_send_data_type_errors(ws_key, key_data, loop) -> None: hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter: with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: @@ -437,6 +538,7 @@ async def test_reader_read_exception(ws_key, key_data, loop) -> None: hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } + hresp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter: with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: @@ -501,6 +603,7 @@ async def test_ws_connect_non_overlapped_protocols(ws_key, loop, key_data) -> No hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, hdrs.SEC_WEBSOCKET_PROTOCOL: "other,another", } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data @@ -523,6 +626,7 @@ async def test_ws_connect_non_overlapped_protocols_2(ws_key, loop, key_data) -> hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, hdrs.SEC_WEBSOCKET_PROTOCOL: "other,another", } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data @@ -547,6 +651,7 @@ async def test_ws_connect_deflate(loop, ws_key, key_data) -> None: hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate", } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data @@ -570,6 +675,7 @@ async def test_ws_connect_deflate_per_message(loop, ws_key, key_data) -> None: hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate", } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter: with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: @@ -602,6 +708,7 @@ async def test_ws_connect_deflate_server_not_support(loop, ws_key, key_data) -> hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data @@ -626,6 +733,7 @@ async def test_ws_connect_deflate_notakeover(loop, ws_key, key_data) -> None: hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate; " "client_no_context_takeover", } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data @@ -650,6 +758,7 @@ async def test_ws_connect_deflate_client_wbits(loop, ws_key, key_data) -> None: hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate; " "client_max_window_bits=10", } + resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data