Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add missing type hints to replication.http. #11856

Merged
merged 4 commits into from
Feb 8, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11856.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints to replication code.
2 changes: 1 addition & 1 deletion synapse/replication/http/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self, hs: "HomeServer"):
super().__init__(hs, canonical_json=False, extract_context=True)
self.register_servlets(hs)

def register_servlets(self, hs: "HomeServer"):
def register_servlets(self, hs: "HomeServer") -> None:
send_event.register_servlets(hs, self)
federation.register_servlets(hs, self)
presence.register_servlets(hs, self)
Expand Down
31 changes: 20 additions & 11 deletions synapse/replication/http/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,20 @@
import abc
import logging
import re
import urllib
import urllib.parse
from inspect import signature
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple

from prometheus_client import Counter, Gauge

from twisted.web.server import Request

from synapse.api.errors import HttpResponseException, SynapseError
from synapse.http import RequestTimedOutError
from synapse.http.server import HttpServer
from synapse.logging import opentracing
from synapse.logging.opentracing import trace
from synapse.types import JsonDict
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string

Expand Down Expand Up @@ -113,10 +117,12 @@ def __init__(self, hs: "HomeServer"):
if hs.config.worker.worker_replication_secret:
self._replication_secret = hs.config.worker.worker_replication_secret

def _check_auth(self, request) -> None:
def _check_auth(self, request: Request) -> None:
# Get the authorization header.
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")

if not auth_headers:
raise RuntimeError("Missing Authorization header.")
if len(auth_headers) > 1:
raise RuntimeError("Too many Authorization headers.")
parts = auth_headers[0].split(b" ")
Expand All @@ -129,7 +135,7 @@ def _check_auth(self, request) -> None:
raise RuntimeError("Invalid Authorization header.")

@abc.abstractmethod
async def _serialize_payload(**kwargs):
async def _serialize_payload(**kwargs) -> JsonDict:
"""Static method that is called when creating a request.

Concrete implementations should have explicit parameters (rather than
Expand All @@ -144,19 +150,20 @@ async def _serialize_payload(**kwargs):
return {}

@abc.abstractmethod
async def _handle_request(self, request, **kwargs):
async def _handle_request(
self, request: Request, **kwargs: Any
) -> Tuple[int, JsonDict]:
"""Handle incoming request.

This is called with the request object and PATH_ARGS.

Returns:
tuple[int, dict]: HTTP status code and a JSON serialisable dict
to be used as response body of request.
HTTP status code and a JSON serialisable dict to be used as response
body of request.
"""
pass

@classmethod
def make_client(cls, hs: "HomeServer"):
def make_client(cls, hs: "HomeServer") -> Callable:
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
"""Create a client that makes requests.

Returns a callable that accepts the same parameters as
Expand All @@ -182,7 +189,7 @@ def make_client(cls, hs: "HomeServer"):
)

@trace(opname="outgoing_replication_request")
async def send_request(*, instance_name="master", **kwargs):
async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any:
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
with outgoing_gauge.track_inprogress():
if instance_name == local_instance_name:
raise Exception("Trying to send HTTP request to self")
Expand Down Expand Up @@ -268,7 +275,7 @@ async def send_request(*, instance_name="master", **kwargs):

return send_request

def register(self, http_server):
def register(self, http_server: HttpServer) -> None:
"""Called by the server to register this as a handler to the
appropriate path.
"""
Expand All @@ -289,7 +296,9 @@ def register(self, http_server):
self.__class__.__name__,
)

async def _check_auth_and_handle(self, request, **kwargs):
async def _check_auth_and_handle(
self, request: Request, **kwargs: Any
) -> Tuple[int, JsonDict]:
"""Called on new incoming requests when caching is enabled. Checks
if there is a cached response for the request and returns that,
otherwise calls `_handle_request` and caches its response.
Expand Down
38 changes: 28 additions & 10 deletions synapse/replication/http/account_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Tuple

from twisted.web.server import Request

from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -48,14 +52,18 @@ def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()

@staticmethod
async def _serialize_payload(user_id, account_data_type, content):
async def _serialize_payload( # type: ignore[override]
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
user_id: str, account_data_type: str, content: JsonDict
) -> JsonDict:
payload = {
"content": content,
}

return payload

async def _handle_request(self, request, user_id, account_data_type):
async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str, account_data_type: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)

max_stream_id = await self.handler.add_account_data_for_user(
Expand Down Expand Up @@ -89,14 +97,18 @@ def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()

@staticmethod
async def _serialize_payload(user_id, room_id, account_data_type, content):
async def _serialize_payload( # type: ignore[override]
user_id: str, room_id: str, account_data_type: str, content: JsonDict
) -> JsonDict:
payload = {
"content": content,
}

return payload

async def _handle_request(self, request, user_id, room_id, account_data_type):
async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str, room_id: str, account_data_type: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)

max_stream_id = await self.handler.add_account_data_to_room(
Expand Down Expand Up @@ -130,14 +142,18 @@ def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()

@staticmethod
async def _serialize_payload(user_id, room_id, tag, content):
async def _serialize_payload( # type: ignore[override]
user_id: str, room_id: str, tag: str, content: JsonDict
) -> JsonDict:
payload = {
"content": content,
}

return payload

async def _handle_request(self, request, user_id, room_id, tag):
async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str, room_id: str, tag: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)

max_stream_id = await self.handler.add_tag_to_room(
Expand Down Expand Up @@ -173,11 +189,13 @@ def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()

@staticmethod
async def _serialize_payload(user_id, room_id, tag):
async def _serialize_payload(user_id: str, room_id: str, tag: str) -> JsonDict: # type: ignore[override]

return {}

async def _handle_request(self, request, user_id, room_id, tag):
async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str, room_id: str, tag: str
) -> Tuple[int, JsonDict]:
max_stream_id = await self.handler.remove_tag_from_room(
user_id,
room_id,
Expand All @@ -187,7 +205,7 @@ async def _handle_request(self, request, user_id, room_id, tag):
return 200, {"max_stream_id": max_stream_id}


def register_servlets(hs: "HomeServer", http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReplicationUserAccountDataRestServlet(hs).register(http_server)
ReplicationRoomAccountDataRestServlet(hs).register(http_server)
ReplicationAddTagRestServlet(hs).register(http_server)
Expand Down
14 changes: 10 additions & 4 deletions synapse/replication/http/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,13 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Tuple

from twisted.web.server import Request

from synapse.http.server import HttpServer
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -63,14 +67,16 @@ def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()

@staticmethod
async def _serialize_payload(user_id):
async def _serialize_payload(user_id: str) -> JsonDict: # type: ignore[override]
return {}

async def _handle_request(self, request, user_id):
async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str
) -> Tuple[int, JsonDict]:
user_devices = await self.device_list_updater.user_device_resync(user_id)

return 200, user_devices


def register_servlets(hs: "HomeServer", http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReplicationUserDevicesResyncRestServlet(hs).register(http_server)
65 changes: 42 additions & 23 deletions synapse/replication/http/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,18 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, List, Tuple

from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import make_event_from_dict
from twisted.web.server import Request

from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.storage.databases.main import DataStore
from synapse.types import JsonDict
from synapse.util.metrics import Measure

if TYPE_CHECKING:
Expand Down Expand Up @@ -69,14 +74,18 @@ def __init__(self, hs: "HomeServer"):
self.federation_event_handler = hs.get_federation_event_handler()

@staticmethod
async def _serialize_payload(store, room_id, event_and_contexts, backfilled):
async def _serialize_payload( # type: ignore[override]
store: DataStore,
room_id: str,
event_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool,
) -> JsonDict:
"""
Args:
store
room_id (str)
event_and_contexts (list[tuple[FrozenEvent, EventContext]])
backfilled (bool): Whether or not the events are the result of
backfilling
room_id
event_and_contexts
backfilled: Whether or not the events are the result of backfilling
"""
event_payloads = []
for event, context in event_and_contexts:
Expand All @@ -102,7 +111,7 @@ async def _serialize_payload(store, room_id, event_and_contexts, backfilled):

return payload

async def _handle_request(self, request):
async def _handle_request(self, request: Request) -> Tuple[int, JsonDict]: # type: ignore[override]
with Measure(self.clock, "repl_fed_send_events_parse"):
content = parse_json_object_from_request(request)

Expand Down Expand Up @@ -163,10 +172,14 @@ def __init__(self, hs: "HomeServer"):
self.registry = hs.get_federation_registry()

@staticmethod
async def _serialize_payload(edu_type, origin, content):
async def _serialize_payload( # type: ignore[override]
edu_type: str, origin: str, content: JsonDict
) -> JsonDict:
return {"origin": origin, "content": content}

async def _handle_request(self, request, edu_type):
async def _handle_request( # type: ignore[override]
self, request: Request, edu_type: str
) -> Tuple[int, JsonDict]:
with Measure(self.clock, "repl_fed_send_edu_parse"):
content = parse_json_object_from_request(request)

Expand All @@ -175,9 +188,9 @@ async def _handle_request(self, request, edu_type):

logger.info("Got %r edu from %s", edu_type, origin)

result = await self.registry.on_edu(edu_type, origin, edu_content)
await self.registry.on_edu(edu_type, origin, edu_content)

return 200, result
return 200, {}


class ReplicationGetQueryRestServlet(ReplicationEndpoint):
Expand Down Expand Up @@ -206,15 +219,17 @@ def __init__(self, hs: "HomeServer"):
self.registry = hs.get_federation_registry()

@staticmethod
async def _serialize_payload(query_type, args):
async def _serialize_payload(query_type: str, args: JsonDict) -> JsonDict: # type: ignore[override]
"""
Args:
query_type (str)
args (dict): The arguments received for the given query type
query_type
args: The arguments received for the given query type
"""
return {"args": args}

async def _handle_request(self, request, query_type):
async def _handle_request( # type: ignore[override]
self, request: Request, query_type: str
) -> Tuple[int, JsonDict]:
with Measure(self.clock, "repl_fed_query_parse"):
content = parse_json_object_from_request(request)

Expand Down Expand Up @@ -248,14 +263,16 @@ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()

@staticmethod
async def _serialize_payload(room_id, args):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only caller of this seems to only pass a room_id, I'm unsure if this is fixing a bug or not:

async def _clean_room_for_join(self, room_id: str) -> None:
"""Called to clean up any data in DB for a given room, ready for the
server to join the room.
Args:
room_id
"""
if self.config.worker.worker_app:
await self._clean_room_for_join_client(room_id)
else:
await self.store.clean_room_for_join(room_id)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Anything in sentry for this? Maybe it's just never called?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems to be called via do_invite_join which is definitely called, it seems to eventually come from either RoomMemberMaster or RoomMemberWorker, maybe we're not routing anything that calls the worker version to the worker?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is a clue here:

# TODO: We should be able to call this on workers, but the upgrading of
# room stuff after join currently doesn't work on workers.
assert self.config.worker.worker_app is None
logger.debug("Joining %s to %s", joinee, room_id)
origin, event, room_version_obj = await self._make_and_verify_event(
target_hosts,
room_id,
joinee,
"join",
content,
params={"ver": KNOWN_ROOM_VERSIONS},
)
# This shouldn't happen, because the RoomMemberHandler has a
# linearizer lock which only allows one operation per user per room
# at a time - so this is just paranoia.
assert room_id not in self._federation_event_handler.room_queues
self._federation_event_handler.room_queues[room_id] = []
await self._clean_room_for_join(room_id)

I couldn't see any other call site of _clean_room_for_join. So yeah, I think this is never called on workers as you say. Given that, let's go ahead with the fix as written. (Strictly speaking it probably ought to be a separate change, but I don't mind sneaking it in here.)

async def _serialize_payload(room_id: str) -> JsonDict: # type: ignore[override]
"""
Args:
room_id (str)
room_id
"""
return {}

async def _handle_request(self, request, room_id):
async def _handle_request( # type: ignore[override]
self, request: Request, room_id: str
) -> Tuple[int, JsonDict]:
await self.store.clean_room_for_join(room_id)

return 200, {}
Expand Down Expand Up @@ -283,17 +300,19 @@ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()

@staticmethod
async def _serialize_payload(room_id, room_version):
async def _serialize_payload(room_id: str, room_version: RoomVersion) -> JsonDict: # type: ignore[override]
return {"room_version": room_version.identifier}

async def _handle_request(self, request, room_id):
async def _handle_request( # type: ignore[override]
self, request: Request, room_id: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
room_version = KNOWN_ROOM_VERSIONS[content["room_version"]]
await self.store.maybe_store_room_on_outlier_membership(room_id, room_version)
return 200, {}


def register_servlets(hs: "HomeServer", http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReplicationFederationSendEventsRestServlet(hs).register(http_server)
ReplicationFederationSendEduRestServlet(hs).register(http_server)
ReplicationGetQueryRestServlet(hs).register(http_server)
Expand Down
Loading