diff --git a/mautrix_twitter/user.py b/mautrix_twitter/user.py index cd94431..7ed7e9c 100644 --- a/mautrix_twitter/user.py +++ b/mautrix_twitter/user.py @@ -15,7 +15,7 @@ # along with this program. If not, see . from __future__ import annotations -from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterable, Awaitable, cast +from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterable, Awaitable, Callable, TypeVar, cast import asyncio import logging @@ -76,6 +76,8 @@ } ) +T = TypeVar("T") + class User(DBUser, BaseUser): by_mxid: dict[UserID, User] = {} @@ -192,6 +194,23 @@ async def locked_connect(self, auth_token: str, csrf_token: str) -> None: finally: self._connect_task = None + async def _hacky_retry_loop( + self, fn: Callable[[], Awaitable[T]], action: str, max_retry_count: int = 5 + ) -> T: + retry_count = 0 + while True: + try: + return await fn() + except aiohttp.ClientResponseError as e: + if e.status == 403 and retry_count < max_retry_count: + retry_count += 1 + self.log.warning( + f"Unexpected 403 in {action}, retrying in 5 seconds", exc_info=True + ) + await asyncio.sleep(5) + else: + raise + async def _connect(self, auth_token: str | None = None, csrf_token: str | None = None) -> None: client = TwitterAPI( log=logging.getLogger("mau.twitter.api").getChild(self.mxid), @@ -202,20 +221,7 @@ async def _connect(self, auth_token: str | None = None, csrf_token: str | None = client.set_tokens(auth_token or self.auth_token, csrf_token or self.csrf_token) # Initial ping to make sure auth works - initial_ping_retry = 0 - while True: - try: - await client.get_user_identifier() - break - except aiohttp.ClientResponseError as e: - if e.status == 403 and initial_ping_retry < 5: - initial_ping_retry += 1 - self.log.warning( - "Unexpected 403 in initial ping, retrying in 5 seconds", exc_info=True - ) - await asyncio.sleep(5) - else: - raise + await self._hacky_retry_loop(client.get_user_identifier, action="initial ping") self.client = client self.client.add_handler(Conversation, self.handle_conversation_update) @@ -231,7 +237,7 @@ async def _connect(self, auth_token: str | None = None, csrf_token: str | None = self.client.add_handler(PollingErrored, self.on_error) self.client.add_handler(PollingErrorResolved, self.on_error_resolved) - user_info = await self.get_info() + user_info = await self._hacky_retry_loop(self.get_info, action="settings fetch") self.twid = user_info.id self._track_metric(METRIC_LOGGED_IN, True) self.by_twid[self.twid] = self