This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add missing type hints to synapse.replication. #11938

Merged (6 commits) on Feb 8, 2022
1 change: 1 addition & 0 deletions changelog.d/11938.misc
@@ -0,0 +1 @@
Add missing type hints to replication code.
3 changes: 3 additions & 0 deletions mypy.ini
@@ -169,6 +169,9 @@ disallow_untyped_defs = True
[mypy-synapse.push.*]
disallow_untyped_defs = True

[mypy-synapse.replication.*]
disallow_untyped_defs = True

[mypy-synapse.rest.*]
disallow_untyped_defs = True

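For context, `disallow_untyped_defs = True` makes mypy reject any function in the matched modules whose signature is not fully annotated — which is what forces the annotations added throughout this PR. A minimal illustration (hypothetical functions, not from this diff):

```python
# With [mypy-synapse.replication.*] / disallow_untyped_defs = True in mypy.ini,
# mypy reports every unannotated function under synapse.replication:

def get_token(instance_name):  # error: Function is missing a type annotation
    return 0


def get_token_annotated(instance_name: str) -> int:  # OK: fully annotated
    return 0
```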
2 changes: 1 addition & 1 deletion synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -40,7 +40,7 @@ def __init__(
for table, column in extra_tables:
self.advance(None, _load_current_id(db_conn, table, column))

def advance(self, instance_name: Optional[str], new_id: int):
def advance(self, instance_name: Optional[str], new_id: int) -> None:
self._current = (max if self.step > 0 else min)(self._current, new_id)

def get_current_token(self) -> int:
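The one-line body of `advance` keeps the tracked ID monotonic in the direction of `step`. A simplified standalone sketch of the same pattern (not the real SlavedIdTracker, which also loads initial IDs from the database via `_load_current_id`):

```python
from typing import Optional


class MonotonicIdTracker:
    """Tracks the highest (or, for negative steps, lowest) stream ID seen."""

    def __init__(self, current: int, step: int = 1) -> None:
        self._current = current
        self.step = step

    def advance(self, instance_name: Optional[str], new_id: int) -> None:
        # max for ascending streams, min for descending (step < 0) ones, so
        # stale or out-of-order updates can never move the token backwards.
        self._current = (max if self.step > 0 else min)(self._current, new_id)

    def get_current_token(self) -> int:
        return self._current


tracker = MonotonicIdTracker(current=10)
tracker.advance(None, 7)   # out-of-order update: ignored
tracker.advance(None, 12)  # advances the token
assert tracker.get_current_token() == 12
```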
4 changes: 3 additions & 1 deletion synapse/replication/slave/storage/client_ips.py
@@ -37,7 +37,9 @@
cache_name="client_ip_last_seen", max_size=50000
)

async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
async def insert_client_ip(
self, user_id: str, access_token: str, ip: str, user_agent: str, device_id: str
) -> None:
now = int(self._clock.time_msec())
key = (user_id, access_token, ip)

10 changes: 7 additions & 3 deletions synapse/replication/slave/storage/devices.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Iterable

from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
@@ -60,7 +60,9 @@ def get_device_stream_token(self) -> int:
def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()

def process_replication_rows(self, stream_name, instance_name, token, rows):
def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
Contributor:
Should rows be Iterable[DeviceListsStream.DeviceListsStreamRow] as on line +76?

I was about to add "and similar for other overrides of process_replication_rows", but I suspect mypy doesn't like it when a subclass narrows an argument type in an override: something to do with co/contravariance, etc.

Member Author:

The APIs here are annoying -- it can be any of the various types of replication rows. The super() call at the bottom of this method just passes through all the rows again, so they're not very well typed. 😢

Contributor:

Maybe there's some way to use Generics, but... not one for now.

Member Author:

I think it actually has to accept them all? It is a bit annoying though cause the (stream_name, row_type) is essentially paired, but I don't think we can tell mypy that? Essentially DeviceListsStream.DeviceListsStreamRow only ever occurs when we have stream_name = 'device_stream'...

Anyway, yes I think this is not one for now. 👍
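To make the co/contravariance point concrete, a minimal standalone sketch (hypothetical `Base` and `DeviceRow` classes, not Synapse code) of how mypy treats an override that narrows a parameter type, and why `Iterable[Any]` passes the check:

```python
from typing import Any, Iterable


class DeviceRow:
    entity: str


class Base:
    def process_replication_rows(self, token: int, rows: Iterable[object]) -> None:
        ...


class Narrowed(Base):
    # mypy error: Argument 2 of "process_replication_rows" is incompatible
    # with supertype "Base" (violates the Liskov substitution principle) --
    # the override must accept everything the base method accepts.
    def process_replication_rows(self, token: int, rows: Iterable[DeviceRow]) -> None:
        ...


class Untyped(Base):
    # No error: Any is compatible in both directions, so Iterable[Any]
    # (as used in this PR) passes the override check.
    def process_replication_rows(self, token: int, rows: Iterable[Any]) -> None:
        ...
```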

) -> None:
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(instance_name, token)
self._invalidate_caches_for_devices(token, rows)
@@ -70,7 +72,9 @@ def process_replication_rows(self, stream_name, instance_name, token, rows):
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)

def _invalidate_caches_for_devices(self, token, rows):
def _invalidate_caches_for_devices(
self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
) -> None:
for row in rows:
# The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about
8 changes: 5 additions & 3 deletions synapse/replication/slave/storage/groups.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Iterable

from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
@@ -44,10 +44,12 @@ def __init__(
self._group_updates_id_gen.get_current_token(),
)

def get_group_stream_token(self):
def get_group_stream_token(self) -> int:
return self._group_updates_id_gen.get_current_token()

def process_replication_rows(self, stream_name, instance_name, token, rows):
def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None:
if stream_name == GroupServerStream.NAME:
self._group_updates_id_gen.advance(instance_name, token)
for row in rows:
7 changes: 5 additions & 2 deletions synapse/replication/slave/storage/push_rule.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Iterable

from synapse.replication.tcp.streams import PushRulesStream
from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
@@ -20,10 +21,12 @@


class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
def get_max_push_rules_stream_id(self):
def get_max_push_rules_stream_id(self) -> int:
return self._push_rules_stream_id_gen.get_current_token()

def process_replication_rows(self, stream_name, instance_name, token, rows):
def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None:
if stream_name == PushRulesStream.NAME:
self._push_rules_stream_id_gen.advance(instance_name, token)
for row in rows:
6 changes: 3 additions & 3 deletions synapse/replication/slave/storage/pushers.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Iterable

from synapse.replication.tcp.streams import PushersStream
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
@@ -41,8 +41,8 @@ def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token()

def process_replication_rows(
self, stream_name: str, instance_name: str, token, rows
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None:
if stream_name == PushersStream.NAME:
self._pushers_id_gen.advance(instance_name, token) # type: ignore
self._pushers_id_gen.advance(instance_name, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
45 changes: 26 additions & 19 deletions synapse/replication/tcp/client.py
@@ -14,10 +14,12 @@
"""A replication client for use by synapse workers.
"""
import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple

from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IAddress, IConnector
from twisted.internet.protocol import ReconnectingClientFactory
from twisted.python.failure import Failure

from synapse.api.constants import EventTypes
from synapse.federation import send_queue
@@ -79,10 +81,10 @@

hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)

def startedConnecting(self, connector):
def startedConnecting(self, connector: IConnector) -> None:
logger.info("Connecting to replication: %r", connector.getDestination())

def buildProtocol(self, addr):
def buildProtocol(self, addr: IAddress) -> ClientReplicationStreamProtocol:
logger.info("Connected to replication: %r", addr)
return ClientReplicationStreamProtocol(
self.hs,
@@ -92,11 +94,11 @@ def buildProtocol(self, addr):
self.command_handler,
)

def clientConnectionLost(self, connector, reason):
def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None:
logger.error("Lost replication conn: %r", reason)
ReconnectingClientFactory.clientConnectionLost(self, connector, reason)

def clientConnectionFailed(self, connector, reason):
def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None:
logger.error("Failed to connect to replication: %r", reason)
ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)

@@ -131,7 +133,7 @@ def __init__(self, hs: "HomeServer"):

async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
):
) -> None:
"""Called to handle a batch of replication data with a given stream token.

By default this just pokes the slave store. Can be overridden in subclasses to
@@ -252,14 +254,16 @@ async def on_rdata(
# loop. (This maintains the order so no need to resort)
waiting_list[:] = waiting_list[index_of_first_deferred_not_called:]

async def on_position(self, stream_name: str, instance_name: str, token: int):
async def on_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
await self.on_rdata(stream_name, instance_name, token, [])

# We poke the generic "replication" notifier to wake anything up that
# may be streaming.
self.notifier.notify_replication()

def on_remote_server_up(self, server: str):
def on_remote_server_up(self, server: str) -> None:
"""Called when get a new REMOTE_SERVER_UP command."""

# Let's wake up the transaction queue for the server in case we have
@@ -269,7 +273,7 @@ def on_remote_server_up(self, server: str):

async def wait_for_stream_position(
self, instance_name: str, stream_name: str, position: int
):
) -> None:
"""Wait until this instance has received updates up to and including
the given stream position.
"""
@@ -304,7 +308,7 @@ async def wait_for_stream_position(
"Finished waiting for repl stream %r to reach %s", stream_name, position
)

def stop_pusher(self, user_id, app_id, pushkey):
def stop_pusher(self, user_id: str, app_id: str, pushkey: str) -> None:
if not self._notify_pushers:
return

@@ -316,13 +320,13 @@ def stop_pusher(self, user_id, app_id, pushkey):
logger.info("Stopping pusher %r / %r", user_id, key)
pusher.on_stop()

async def start_pusher(self, user_id, app_id, pushkey):
async def start_pusher(self, user_id: str, app_id: str, pushkey: str) -> None:
if not self._notify_pushers:
return

key = "%s:%s" % (app_id, pushkey)
logger.info("Starting pusher %r / %r", user_id, key)
return await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id)
await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id)


class FederationSenderHandler:
@@ -353,10 +357,12 @@ def __init__(self, hs: "HomeServer"):

self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")

def wake_destination(self, server: str):
def wake_destination(self, server: str) -> None:
self.federation_sender.wake_destination(server)

async def process_replication_rows(self, stream_name, token, rows):
async def process_replication_rows(
self, stream_name: str, token: int, rows: list
) -> None:
# The federation stream contains things that we want to send out, e.g.
# presence, typing, etc.
if stream_name == "federation":
@@ -384,11 +390,12 @@ async def process_replication_rows(self, stream_name, token, rows):
for host in hosts:
self.federation_sender.send_device_messages(host)

async def _on_new_receipts(self, rows):
async def _on_new_receipts(
self, rows: Iterable[ReceiptsStream.ReceiptsStreamRow]
) -> None:
"""
Args:
rows (Iterable[synapse.replication.tcp.streams.ReceiptsStream.ReceiptsStreamRow]):
new receipts to be processed
rows: new receipts to be processed
"""
for receipt in rows:
# we only want to send on receipts for our own users
@@ -408,7 +415,7 @@ async def _on_new_receipts(self, rows):
)
await self.federation_sender.send_read_receipt(receipt_info)

async def update_token(self, token):
async def update_token(self, token: int) -> None:
"""Update the record of where we have processed to in the federation stream.

Called after we have processed an update received over replication. Sends
@@ -428,7 +435,7 @@

run_as_background_process("_save_and_send_ack", self._save_and_send_ack)

async def _save_and_send_ack(self):
async def _save_and_send_ack(self) -> None:
"""Save the current federation position in the database and send an ACK
to master with where we're up to.
"""