This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type hints to state database #10823

Merged: 7 commits, Sep 15, 2021
Changes from 4 commits
1 change: 1 addition & 0 deletions changelog.d/10823.misc
@@ -0,0 +1 @@
+ Add type hints to the state database.
1 change: 1 addition & 0 deletions mypy.ini
@@ -60,6 +60,7 @@ files =
synapse/storage/databases/main/session.py,
synapse/storage/databases/main/stream.py,
synapse/storage/databases/main/ui_auth.py,
+ synapse/storage/databases/state,
synapse/storage/database.py,
synapse/storage/engines,
synapse/storage/keys.py,
52 changes: 36 additions & 16 deletions synapse/storage/databases/state/bg_updates.py
@@ -13,13 +13,20 @@
# limitations under the License.

import logging
- from typing import Optional
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

from synapse.storage._base import SQLBaseStore
- from synapse.storage.database import DatabasePool
+ from synapse.storage.database import (
+     DatabasePool,
+     LoggingDatabaseConnection,
+     LoggingTransaction,
+ )
from synapse.storage.engines import PostgresEngine
from synapse.storage.state import StateFilter

+ if TYPE_CHECKING:
+     from synapse.server import HomeServer

logger = logging.getLogger(__name__)


@@ -31,7 +38,9 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
updates.
"""

- def _count_state_group_hops_txn(self, txn, state_group):
+ def _count_state_group_hops_txn(
+     self, txn: LoggingTransaction, state_group: int
+ ) -> int:
"""Given a state group, count how many hops there are in the tree.

This is used to ensure the delta chains don't get too long.
@@ -56,7 +65,7 @@ def _count_state_group_hops_txn(self, txn, state_group):
else:
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
- next_group = state_group
+ next_group: Optional[int] = state_group
count = 0

while next_group:
@@ -73,11 +82,14 @@ def _get_state_groups_from_groups_txn(
return count

def _get_state_groups_from_groups_txn(
- self, txn, groups, state_filter: Optional[StateFilter] = None
+ self,
+ txn: LoggingTransaction,
+ groups: List[int],
+ state_filter: Optional[StateFilter] = None,
):
clokep marked this conversation as resolved.
state_filter = state_filter or StateFilter.all()

- results = {group: {} for group in groups}
+ results: Dict[int, Dict[Tuple[str, str], str]] = {group: {} for group in groups}

where_clause, where_args = state_filter.make_sql_filter_clause()

@@ -117,7 +129,7 @@ def _get_state_groups_from_groups_txn(
"""

for group in groups:
- args = [group]
+ args: List[Union[int, str]] = [group]
args.extend(where_args)

txn.execute(sql % (where_clause,), args)
@@ -131,7 +143,7 @@ def _get_state_groups_from_groups_txn(
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
for group in groups:
- next_group = group
+ next_group: Optional[int] = group

while next_group:
# We did this before by getting the list of group ids, and
@@ -182,7 +194,12 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"

- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(
+     self,
+     database: DatabasePool,
+     db_conn: LoggingDatabaseConnection,
+     hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
@@ -198,7 +215,9 @@ def __init__(self, database: DatabasePool, db_conn, hs):
columns=["room_id"],
)

- async def _background_deduplicate_state(self, progress, batch_size):
+ async def _background_deduplicate_state(
+     self, progress: dict, batch_size: int
Member Author:

I'm not sure if progress can be more specific? Maybe JsonDict?

Contributor:

Not easily. David and I looked at a similar case. (Though it actually seems reasonable to believe it might be doable, as long as you update all background processes at the same time.)

Contributor:

I hoped I could define my own TypedDict, but this ended up clashing with the background process stuff. Specifically:

  • register_background_update_handler's update_handler has type Callable[[JsonDict, int], Awaitable[int]]
  • this means that the progress argument type ends up being compared against JsonDict = Dict[str, Any]
  • TypedDict is not a subtype of JsonDict because you can add new keys or remove existing keys from a Dict
    • so passing t: TypedDict to a function process_dict(d: Dict) might alter t's keys. This would be bad.
  • TypedDict is a subtype of Mapping[str, Any] because the latter is immutable
  • More details here: https://www.python.org/dev/peps/pep-0589/#id20

An update_handler itself is called here:

        progress = db_to_json(progress_json)

        time_start = self._clock.time_msec()
        items_updated = await update_handler(progress, batch_size)

and db_to_json returns a JsonDict.

We didn't find a nice way of joining this all up that didn't involve a sea of casts.
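A minimal, self-contained illustration of the subtyping problem described above (the Progress shape here is hypothetical; real progress dicts vary per background update):

        from typing import Any, Dict, TypedDict

        JsonDict = Dict[str, Any]

        class Progress(TypedDict):
            # Hypothetical shape for illustration only.
            last_state_group: int

        def update_handler(progress: JsonDict, batch_size: int) -> int:
            # A Dict[str, Any] may legally gain or lose arbitrary keys...
            progress["extra"] = object()
            return batch_size

        p: Progress = {"last_state_group": 0}
        # ...so mypy rejects the call: Argument 1 to "update_handler" has
        # incompatible type "Progress"; expected "Dict[str, Any]"
        update_handler(p, 100)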

+ ) -> int:
"""This background update will slowly deduplicate state by reencoding
them as deltas.
"""
@@ -218,7 +237,7 @@ async def _background_deduplicate_state(self, progress, batch_size):
)
max_group = rows[0][0]

- def reindex_txn(txn):
+ def reindex_txn(txn: LoggingTransaction) -> Tuple[bool, int]:
new_last_state_group = last_state_group
for count in range(batch_size):
txn.execute(
@@ -251,7 +270,8 @@ def reindex_txn(txn):
" WHERE id < ? AND room_id = ?",
(state_group, room_id),
)
- (prev_group,) = txn.fetchone()
+ # There will be a result due to the coalesce.
+ (prev_group,) = txn.fetchone()  # type: ignore
new_last_state_group = state_group

if prev_group:
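An aside on the type: ignore a few lines up: fetchone() is typed as returning an Optional row, so mypy cannot see that the COALESCE guarantees one. An equivalent way to satisfy it (a sketch, not what this PR does) is an explicit assert:

        row = txn.fetchone()
        assert row is not None  # COALESCE(...) always yields exactly one row
        (prev_group,) = row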
@@ -340,14 +360,14 @@ def reindex_txn(txn):

return result * BATCH_SIZE_SCALE_FACTOR

- async def _background_index_state(self, progress, batch_size):
-     def reindex_txn(conn):
+ async def _background_index_state(self, progress: dict, batch_size: int) -> int:
+     def reindex_txn(conn: LoggingDatabaseConnection) -> None:
conn.rollback()
if isinstance(self.database_engine, PostgresEngine):
# postgres insists on autocommit for the index
conn.set_session(autocommit=True)
try:
- txn = conn.cursor()
+ txn = conn.LoggingTransaction()
clokep marked this conversation as resolved.
txn.execute(
"CREATE INDEX CONCURRENTLY state_groups_state_type_idx"
" ON state_groups_state(state_group, type, state_key)"
@@ -356,7 +376,7 @@ def reindex_txn(conn):
finally:
conn.set_session(autocommit=False)
else:
- txn = conn.cursor()
+ txn = conn.LoggingTransaction()
txn.execute(
"CREATE INDEX state_groups_state_type_idx"
" ON state_groups_state(state_group, type, state_key)"