Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Move Account Validity callbacks to a dedicated file #15237

Merged
merged 5 commits into from
Mar 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/15237.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Move various module API callback registration methods to a dedicated class.
99 changes: 14 additions & 85 deletions synapse/handlers/account_validity.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@
import email.mime.multipart
import email.utils
import logging
from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple

from twisted.web.http import Request
from typing import TYPE_CHECKING, List, Optional, Tuple

from synapse.api.errors import AuthError, StoreError, SynapseError
from synapse.metrics.background_process_metrics import wrap_as_background_process
Expand All @@ -30,25 +28,17 @@

logger = logging.getLogger(__name__)

# Types for callbacks to be registered via the module api
IS_USER_EXPIRED_CALLBACK = Callable[[str], Awaitable[Optional[bool]]]
ON_USER_REGISTRATION_CALLBACK = Callable[[str], Awaitable]
# Temporary hooks to allow for a transition from `/_matrix/client` endpoints
# to `/_synapse/client/account_validity`. See `register_account_validity_callbacks`.
ON_LEGACY_SEND_MAIL_CALLBACK = Callable[[str], Awaitable]
ON_LEGACY_RENEW_CALLBACK = Callable[[str], Awaitable[Tuple[bool, bool, int]]]
ON_LEGACY_ADMIN_REQUEST = Callable[[Request], Awaitable]
Comment on lines -33 to -40
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any idea if people import these in modules somewhere (e.g. for type hints)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Users should really only import things from the synapse.module_api module. That being said, one can import classes that the synapse.module_api module itself imports, by doing:

from synapse.module_api import ON_LOGGED_OUT_CALLBACK

for example, and this would not be affected by these changes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense. I know some of the spam checker and media bits don't do this as cleanly.



class AccountValidityHandler:
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.config = hs.config
self.store = self.hs.get_datastores().main
self.send_email_handler = self.hs.get_send_email_handler()
self.clock = self.hs.get_clock()
self.store = hs.get_datastores().main
self.send_email_handler = hs.get_send_email_handler()
self.clock = hs.get_clock()

self._app_name = self.hs.config.email.email_app_name
self._app_name = hs.config.email.email_app_name
self._module_api_callbacks = hs.get_module_api_callbacks().account_validity

self._account_validity_enabled = (
hs.config.account_validity.account_validity_enabled
Expand Down Expand Up @@ -78,69 +68,6 @@ def __init__(self, hs: "HomeServer"):
if hs.config.worker.run_background_tasks:
self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000)

self._is_user_expired_callbacks: List[IS_USER_EXPIRED_CALLBACK] = []
self._on_user_registration_callbacks: List[ON_USER_REGISTRATION_CALLBACK] = []
self._on_legacy_send_mail_callback: Optional[
ON_LEGACY_SEND_MAIL_CALLBACK
] = None
self._on_legacy_renew_callback: Optional[ON_LEGACY_RENEW_CALLBACK] = None

# The legacy admin requests callback isn't a protected attribute because we need
# to access it from the admin servlet, which is outside of this handler.
self.on_legacy_admin_request_callback: Optional[ON_LEGACY_ADMIN_REQUEST] = None

def register_account_validity_callbacks(
self,
is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None,
on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None,
on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None,
on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None,
on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None,
) -> None:
"""Register callbacks from module for each hook."""
if is_user_expired is not None:
self._is_user_expired_callbacks.append(is_user_expired)

if on_user_registration is not None:
self._on_user_registration_callbacks.append(on_user_registration)

# The builtin account validity feature exposes 3 endpoints (send_mail, renew, and
# an admin one). As part of moving the feature into a module, we need to change
# the path from /_matrix/client/unstable/account_validity/... to
# /_synapse/client/account_validity, because:
#
# * the feature isn't part of the Matrix spec thus shouldn't live under /_matrix
# * the way we register servlets means that modules can't register resources
# under /_matrix/client
#
# We need to allow for a transition period between the old and new endpoints
# in order to allow for clients to update (and for emails to be processed).
#
# Once the email-account-validity module is loaded, it will take control of account
# validity by moving the rows from our `account_validity` table into its own table.
#
# Therefore, we need to allow modules (in practice just the one implementing the
# email-based account validity) to temporarily hook into the legacy endpoints so we
# can route the traffic coming into the old endpoints into the module, which is
# why we have the following three temporary hooks.
if on_legacy_send_mail is not None:
if self._on_legacy_send_mail_callback is not None:
raise RuntimeError("Tried to register on_legacy_send_mail twice")

self._on_legacy_send_mail_callback = on_legacy_send_mail

if on_legacy_renew is not None:
if self._on_legacy_renew_callback is not None:
raise RuntimeError("Tried to register on_legacy_renew twice")

self._on_legacy_renew_callback = on_legacy_renew

if on_legacy_admin_request is not None:
if self.on_legacy_admin_request_callback is not None:
raise RuntimeError("Tried to register on_legacy_admin_request twice")

self.on_legacy_admin_request_callback = on_legacy_admin_request

async def is_user_expired(self, user_id: str) -> bool:
"""Checks if a user has expired against third-party modules.

Expand All @@ -150,7 +77,7 @@ async def is_user_expired(self, user_id: str) -> bool:
Returns:
Whether the user has expired.
"""
for callback in self._is_user_expired_callbacks:
for callback in self._module_api_callbacks.is_user_expired_callbacks:
expired = await delay_cancellation(callback(user_id))
if expired is not None:
return expired
Expand All @@ -168,7 +95,7 @@ async def on_user_registration(self, user_id: str) -> None:
Args:
user_id: The ID of the newly registered user.
"""
for callback in self._on_user_registration_callbacks:
for callback in self._module_api_callbacks.on_user_registration_callbacks:
await callback(user_id)

@wrap_as_background_process("send_renewals")
Expand Down Expand Up @@ -198,8 +125,8 @@ async def send_renewal_email_to_user(self, user_id: str) -> None:
"""
# If a module supports sending a renewal email from here, do that, otherwise do
# the legacy dance.
if self._on_legacy_send_mail_callback is not None:
await self._on_legacy_send_mail_callback(user_id)
if self._module_api_callbacks.on_legacy_send_mail_callback is not None:
await self._module_api_callbacks.on_legacy_send_mail_callback(user_id)
return

if not self._account_validity_renew_by_email_enabled:
Expand Down Expand Up @@ -336,8 +263,10 @@ async def renew_account(self, renewal_token: str) -> Tuple[bool, bool, int]:
"""
# If a module supports triggering a renew from here, do that, otherwise do the
# legacy dance.
if self._on_legacy_renew_callback is not None:
return await self._on_legacy_renew_callback(renewal_token)
if self._module_api_callbacks.on_legacy_renew_callback is not None:
return await self._module_api_callbacks.on_legacy_renew_callback(
renewal_token
)

try:
(
Expand Down
18 changes: 9 additions & 9 deletions synapse/module_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,6 @@
ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK,
)
from synapse.handlers.account_data import ON_ACCOUNT_DATA_UPDATED_CALLBACK
from synapse.handlers.account_validity import (
IS_USER_EXPIRED_CALLBACK,
ON_LEGACY_ADMIN_REQUEST,
ON_LEGACY_RENEW_CALLBACK,
ON_LEGACY_SEND_MAIL_CALLBACK,
ON_USER_REGISTRATION_CALLBACK,
)
from synapse.handlers.auth import (
CHECK_3PID_AUTH_CALLBACK,
CHECK_AUTH_CALLBACK,
Expand All @@ -105,6 +98,13 @@
run_in_background,
)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api.callbacks.account_validity_callbacks import (
IS_USER_EXPIRED_CALLBACK,
ON_LEGACY_ADMIN_REQUEST,
ON_LEGACY_RENEW_CALLBACK,
ON_LEGACY_SEND_MAIL_CALLBACK,
ON_USER_REGISTRATION_CALLBACK,
)
from synapse.rest.client.login import LoginResponse
from synapse.storage import DataStore
from synapse.storage.background_updates import (
Expand Down Expand Up @@ -250,6 +250,7 @@ def __init__(self, hs: "HomeServer", auth_handler: AuthHandler) -> None:
self._push_rules_handler = hs.get_push_rules_handler()
self._device_handler = hs.get_device_handler()
self.custom_template_dir = hs.config.server.custom_template_directory
self._callbacks = hs.get_module_api_callbacks()

try:
app_name = self._hs.config.email.email_app_name
Expand All @@ -271,7 +272,6 @@ def __init__(self, hs: "HomeServer", auth_handler: AuthHandler) -> None:
self._account_data_manager = AccountDataManager(hs)

self._spam_checker = hs.get_spam_checker()
self._account_validity_handler = hs.get_account_validity_handler()
self._third_party_event_rules = hs.get_third_party_event_rules()
self._password_auth_provider = hs.get_password_auth_provider()
self._presence_router = hs.get_presence_router()
Expand Down Expand Up @@ -332,7 +332,7 @@ def register_account_validity_callbacks(

Added in Synapse v1.39.0.
"""
return self._account_validity_handler.register_account_validity_callbacks(
return self._callbacks.account_validity.register_callbacks(
is_user_expired=is_user_expired,
on_user_registration=on_user_registration,
on_legacy_send_mail=on_legacy_send_mail,
Expand Down
22 changes: 22 additions & 0 deletions synapse/module_api/callbacks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from synapse.module_api.callbacks.account_validity_callbacks import (
AccountValidityModuleApiCallbacks,
)


class ModuleApiCallbacks:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm assuming by how you set this up you're planning to add other module APIs to this too?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct - see this poorly named branch for the commits that move the rest of the callback classes: https://github.com/matrix-org/synapse/compare/anoa/public_rooms_module_api_backup

def __init__(self) -> None:
self.account_validity = AccountValidityModuleApiCallbacks()
93 changes: 93 additions & 0 deletions synapse/module_api/callbacks/account_validity_callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Awaitable, Callable, List, Optional, Tuple

from twisted.web.http import Request

logger = logging.getLogger(__name__)

# Types for callbacks to be registered via the module api
IS_USER_EXPIRED_CALLBACK = Callable[[str], Awaitable[Optional[bool]]]
ON_USER_REGISTRATION_CALLBACK = Callable[[str], Awaitable]
# Temporary hooks to allow for a transition from `/_matrix/client` endpoints
# to `/_synapse/client/account_validity`. See `register_callbacks` below.
ON_LEGACY_SEND_MAIL_CALLBACK = Callable[[str], Awaitable]
ON_LEGACY_RENEW_CALLBACK = Callable[[str], Awaitable[Tuple[bool, bool, int]]]
ON_LEGACY_ADMIN_REQUEST = Callable[[Request], Awaitable]


class AccountValidityModuleApiCallbacks:
def __init__(self) -> None:
self.is_user_expired_callbacks: List[IS_USER_EXPIRED_CALLBACK] = []
self.on_user_registration_callbacks: List[ON_USER_REGISTRATION_CALLBACK] = []
self.on_legacy_send_mail_callback: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None
self.on_legacy_renew_callback: Optional[ON_LEGACY_RENEW_CALLBACK] = None

# The legacy admin requests callback isn't a protected attribute because we need
# to access it from the admin servlet, which is outside of this handler.
self.on_legacy_admin_request_callback: Optional[ON_LEGACY_ADMIN_REQUEST] = None

def register_callbacks(
self,
is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None,
on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None,
on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None,
on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None,
on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None,
) -> None:
"""Register callbacks from module for each hook."""
if is_user_expired is not None:
self.is_user_expired_callbacks.append(is_user_expired)

if on_user_registration is not None:
self.on_user_registration_callbacks.append(on_user_registration)

# The builtin account validity feature exposes 3 endpoints (send_mail, renew, and
# an admin one). As part of moving the feature into a module, we need to change
# the path from /_matrix/client/unstable/account_validity/... to
# /_synapse/client/account_validity, because:
#
# * the feature isn't part of the Matrix spec thus shouldn't live under /_matrix
# * the way we register servlets means that modules can't register resources
# under /_matrix/client
#
# We need to allow for a transition period between the old and new endpoints
# in order to allow for clients to update (and for emails to be processed).
#
# Once the email-account-validity module is loaded, it will take control of account
# validity by moving the rows from our `account_validity` table into its own table.
#
# Therefore, we need to allow modules (in practice just the one implementing the
# email-based account validity) to temporarily hook into the legacy endpoints so we
# can route the traffic coming into the old endpoints into the module, which is
# why we have the following three temporary hooks.
if on_legacy_send_mail is not None:
if self.on_legacy_send_mail_callback is not None:
raise RuntimeError("Tried to register on_legacy_send_mail twice")

self.on_legacy_send_mail_callback = on_legacy_send_mail

if on_legacy_renew is not None:
if self.on_legacy_renew_callback is not None:
raise RuntimeError("Tried to register on_legacy_renew twice")

self.on_legacy_renew_callback = on_legacy_renew

if on_legacy_admin_request is not None:
if self.on_legacy_admin_request_callback is not None:
raise RuntimeError("Tried to register on_legacy_admin_request twice")

self.on_legacy_admin_request_callback = on_legacy_admin_request
17 changes: 8 additions & 9 deletions synapse/rest/admin/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,19 +683,18 @@ class AccountValidityRenewServlet(RestServlet):
PATTERNS = admin_patterns("/account_validity/validity$")

def __init__(self, hs: "HomeServer"):
self.account_activity_handler = hs.get_account_validity_handler()
self.account_validity_handler = hs.get_account_validity_handler()
self.account_validity_module_callbacks = (
hs.get_module_api_callbacks().account_validity
)
self.auth = hs.get_auth()

async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)

if self.account_activity_handler.on_legacy_admin_request_callback:
expiration_ts = (
await (
self.account_activity_handler.on_legacy_admin_request_callback(
request
)
)
if self.account_validity_module_callbacks.on_legacy_admin_request_callback:
expiration_ts = await self.account_validity_module_callbacks.on_legacy_admin_request_callback(
request
)
else:
body = parse_json_object_from_request(request)
Expand All @@ -706,7 +705,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"Missing property 'user_id' in the request body",
)

expiration_ts = await self.account_activity_handler.renew_account_for_user(
expiration_ts = await self.account_validity_handler.renew_account_for_user(
body["user_id"],
body.get("expiration_ts"),
not body.get("enable_renewal_emails", True),
Expand Down
5 changes: 5 additions & 0 deletions synapse/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@
from synapse.media.media_repository import MediaRepository
from synapse.metrics.common_usage_metrics import CommonUsageMetricsManager
from synapse.module_api import ModuleApi
from synapse.module_api.callbacks import ModuleApiCallbacks
from synapse.notifier import Notifier, ReplicationNotifier
from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
from synapse.push.pusherpool import PusherPool
Expand Down Expand Up @@ -777,6 +778,10 @@ def get_federation_ratelimiter(self) -> FederationRateLimiter:
def get_module_api(self) -> ModuleApi:
return ModuleApi(self, self.get_auth_handler())

@cache_in_self
def get_module_api_callbacks(self) -> ModuleApiCallbacks:
return ModuleApiCallbacks()

@cache_in_self
def get_account_data_handler(self) -> AccountDataHandler:
return AccountDataHandler(self)
Expand Down
5 changes: 2 additions & 3 deletions tests/rest/client/test_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -1249,9 +1249,8 @@ async def is_expired(user_id: str) -> bool:
# account status will fail.
return UserID.from_string(user_id).localpart == "someuser"

self.hs.get_account_validity_handler()._is_user_expired_callbacks.append(
is_expired
)
account_validity_callbacks = self.hs.get_module_api_callbacks().account_validity
account_validity_callbacks.is_user_expired_callbacks.append(is_expired)

self._test_status(
users=[user],
Expand Down