Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add and use ClientConnectionResetError #9137

Merged
merged 18 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES/9137.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Added :exc:`aiohttp.ClientConnectionResetError`. Client code that previously threw :exc:`ConnectionResetError`
will now throw this -- by :user:`Dreamsorcerer`.
bdraco marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 2 additions & 0 deletions aiohttp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .client import (
BaseConnector,
ClientConnectionError,
ClientConnectionResetError,
ClientConnectorCertificateError,
ClientConnectorError,
ClientConnectorSSLError,
Expand Down Expand Up @@ -117,6 +118,7 @@
# client
"BaseConnector",
"ClientConnectionError",
"ClientConnectionResetError",
"ClientConnectorCertificateError",
"ClientConnectorError",
"ClientConnectorSSLError",
Expand Down
3 changes: 2 additions & 1 deletion aiohttp/base_protocol.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
from typing import Optional, cast

from .client_exceptions import ClientConnectionResetError
from .helpers import set_exception
from .tcp_helpers import tcp_nodelay

Expand Down Expand Up @@ -85,7 +86,7 @@ def connection_lost(self, exc: Optional[BaseException]) -> None:

async def _drain_helper(self) -> None:
if not self.connected:
raise ConnectionResetError("Connection lost")
raise ClientConnectionResetError("Connection lost")
if not self._paused:
return
waiter = self._drain_waiter
Expand Down
2 changes: 2 additions & 0 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from .abc import AbstractCookieJar
from .client_exceptions import (
ClientConnectionError,
ClientConnectionResetError,
ClientConnectorCertificateError,
ClientConnectorError,
ClientConnectorSSLError,
Expand Down Expand Up @@ -107,6 +108,7 @@
__all__ = (
# client_exceptions
"ClientConnectionError",
"ClientConnectionResetError",
"ClientConnectorCertificateError",
"ClientConnectorError",
"ClientConnectorSSLError",
Expand Down
9 changes: 7 additions & 2 deletions aiohttp/client_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from multidict import MultiMapping

from .http_parser import RawResponseMessage
from .typedefs import StrOrURL

try:
Expand All @@ -18,12 +17,14 @@

if TYPE_CHECKING:
from .client_reqrep import ClientResponse, ConnectionKey, Fingerprint, RequestInfo
from .http_parser import RawResponseMessage
else:
RequestInfo = ClientResponse = ConnectionKey = None
RequestInfo = ClientResponse = ConnectionKey = RawResponseMessage = None

__all__ = (
"ClientError",
"ClientConnectionError",
"ClientConnectionResetError",
"ClientOSError",
"ClientConnectorError",
"ClientProxyConnectionError",
Expand Down Expand Up @@ -126,6 +127,10 @@ class ClientConnectionError(ClientError):
"""Base class for client socket errors."""


class ClientConnectionResetError(ClientConnectionError, ConnectionResetError):
"""ConnectionResetError"""


class ClientOSError(ClientConnectionError, OSError):
"""OSError error."""

Expand Down
5 changes: 3 additions & 2 deletions aiohttp/http_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
)

from .base_protocol import BaseProtocol
from .client_exceptions import ClientConnectionResetError
from .compression_utils import ZLibCompressor, ZLibDecompressor
from .helpers import NO_EXTENSIONS, set_exception
from .streams import DataQueue
Expand Down Expand Up @@ -609,7 +610,7 @@ async def _send_frame(
) -> None:
"""Send a frame over the websocket with message as its payload."""
if self._closing and not (opcode & WSMsgType.CLOSE):
raise ConnectionResetError("Cannot write to closing transport")
raise ClientConnectionResetError("Cannot write to closing transport")

# RSV are the reserved bits in the frame header. They are used to
# indicate that the frame is using an extension.
Expand Down Expand Up @@ -704,7 +705,7 @@ def _make_compress_obj(self, compress: int) -> ZLibCompressor:

def _write(self, data: bytes) -> None:
if self.transport.is_closing():
raise ConnectionResetError("Cannot write to closing transport")
raise ClientConnectionResetError("Cannot write to closing transport")
self.transport.write(data)

async def pong(self, message: Union[bytes, str] = b"") -> None:
Expand Down
3 changes: 2 additions & 1 deletion aiohttp/http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from .abc import AbstractStreamWriter
from .base_protocol import BaseProtocol
from .client_exceptions import ClientConnectionResetError
from .compression_utils import ZLibCompressor
from .helpers import NO_EXTENSIONS

Expand Down Expand Up @@ -72,7 +73,7 @@ def _write(self, chunk: bytes) -> None:
self.output_size += size
transport = self.transport
if not self._protocol.connected or transport is None or transport.is_closing():
raise ConnectionResetError("Cannot write to closing transport")
raise ClientConnectionResetError("Cannot write to closing transport")
transport.write(chunk)

async def write(
Expand Down
6 changes: 6 additions & 0 deletions docs/client_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2193,6 +2193,10 @@ Connection errors

Derived from :exc:`ClientError`

.. class:: ClientConnectionResetError

Derived from :exc:`ClientConnectionError` and :exc:`ConnectionResetError`

.. class:: ClientOSError

Subset of connection errors that are initiated by an :exc:`OSError`
Expand Down Expand Up @@ -2279,6 +2283,8 @@ Hierarchy of exceptions

* :exc:`ClientConnectionError`

* :exc:`ClientConnectionResetError`

* :exc:`ClientOSError`

* :exc:`ClientConnectorError`
Expand Down
20 changes: 14 additions & 6 deletions tests/test_client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,19 @@
import base64
import hashlib
import os
from typing import Mapping
from typing import Mapping, Type
from unittest import mock

import pytest

import aiohttp
from aiohttp import client, hdrs
from aiohttp.client_exceptions import ServerDisconnectedError
from aiohttp.client_ws import ClientWSTimeout
from aiohttp import (
ClientConnectionResetError,
ClientWSTimeout,
ServerDisconnectedError,
client,
hdrs,
)
from aiohttp.http import WS_KEY
from aiohttp.streams import EofStream
from aiohttp.test_utils import make_mocked_coro
Expand Down Expand Up @@ -535,8 +539,12 @@ async def test_close_exc2(
await resp.close()


@pytest.mark.parametrize("exc", (ClientConnectionResetError, ConnectionResetError))
async def test_send_data_after_close(
ws_key: bytes, key_data: bytes, loop: asyncio.AbstractEventLoop
exc: Type[Exception],
ws_key: bytes,
key_data: bytes,
loop: asyncio.AbstractEventLoop,
) -> None:
mresp = mock.Mock()
mresp.status = 101
Expand All @@ -562,7 +570,7 @@ async def test_send_data_after_close(
(resp.send_bytes, (b"b",)),
(resp.send_json, ({},)),
):
with pytest.raises(ConnectionResetError):
with pytest.raises(exc): # Verify exc can be caught with both classes
await meth(*args)


Expand Down
4 changes: 2 additions & 2 deletions tests/test_client_ws_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest

import aiohttp
from aiohttp import ServerTimeoutError, WSMsgType, hdrs, web
from aiohttp import ClientConnectionResetError, ServerTimeoutError, WSMsgType, hdrs, web
from aiohttp.client_ws import ClientWSTimeout
from aiohttp.http import WSCloseCode
from aiohttp.pytest_plugin import AiohttpClient, AiohttpServer
Expand Down Expand Up @@ -681,7 +681,7 @@ async def handler(request: web.Request) -> NoReturn:
# would cancel the heartbeat task and we wouldn't get a ping
assert resp._conn is not None
with mock.patch.object(
resp._conn.transport, "write", side_effect=ConnectionResetError
resp._conn.transport, "write", side_effect=ClientConnectionResetError
), mock.patch.object(resp._writer, "ping", wraps=resp._writer.ping) as ping:
await resp.receive()
ping_count = ping.call_count
Expand Down
10 changes: 6 additions & 4 deletions tests/test_http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest
from multidict import CIMultiDict

from aiohttp import http
from aiohttp import ClientConnectionResetError, http
from aiohttp.base_protocol import BaseProtocol
from aiohttp.test_utils import make_mocked_coro

Expand Down Expand Up @@ -301,7 +301,7 @@ async def test_write_to_closing_transport(
await msg.write(b"Before closing")
transport.is_closing.return_value = True # type: ignore[attr-defined]

with pytest.raises(ConnectionResetError):
with pytest.raises(ClientConnectionResetError):
await msg.write(b"After closing")


Expand All @@ -310,7 +310,7 @@ async def test_write_to_closed_transport(
transport: asyncio.Transport,
loop: asyncio.AbstractEventLoop,
) -> None:
"""Test that writing to a closed transport raises ConnectionResetError.
"""Test that writing to a closed transport raises ClientConnectionResetError.

The StreamWriter checks to see if protocol.transport is None before
writing to the transport. If it is None, it raises ConnectionResetError.
Expand All @@ -320,7 +320,9 @@ async def test_write_to_closed_transport(
await msg.write(b"Before transport close")
protocol.transport = None

with pytest.raises(ConnectionResetError, match="Cannot write to closing transport"):
with pytest.raises(
ClientConnectionResetError, match="Cannot write to closing transport"
):
await msg.write(b"After transport closed")


Expand Down
Loading