From 6c9fdd65a0c692a481355afe7cb9e2b0e01696b7 Mon Sep 17 00:00:00 2001 From: anthony sottile Date: Thu, 13 Jun 2024 11:39:50 -0400 Subject: [PATCH] ref: improve typing of base_query_set --- pyproject.toml | 1 + .../db/models/manager/base_query_set.py | 26 +++++++++++-------- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ed8bd7e33fb69f..bdc15f3a16c0da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -521,6 +521,7 @@ module = [ "sentry.api.helpers.source_map_helper", "sentry.buffer.*", "sentry.build.*", + "sentry.db.models.manager.base_query_set", "sentry.eventstore.reprocessing.redis", "sentry.eventtypes.error", "sentry.grouping.component", diff --git a/src/sentry/db/models/manager/base_query_set.py b/src/sentry/db/models/manager/base_query_set.py index 13e5843e0af25e..e7c23b970dcdac 100644 --- a/src/sentry/db/models/manager/base_query_set.py +++ b/src/sentry/db/models/manager/base_query_set.py @@ -1,20 +1,22 @@ -import abc -from typing import Any +from __future__ import annotations + +from typing import Any, Self from django.core import exceptions from django.core.exceptions import EmptyResultSet from django.db import connections, router, transaction from django.db.models import QuerySet, sql +from sentry.db.models.manager import M from sentry.signals import post_update -class BaseQuerySet(QuerySet, abc.ABC): - def __init__(self, *args, **kwargs): +class BaseQuerySet(QuerySet[M]): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self._with_post_update_signal = False - def with_post_update_signal(self, enable: bool) -> "BaseQuerySet": + def with_post_update_signal(self, enable: bool) -> Self: """ Enables sending a `post_update` signal after this queryset runs an update command. Note that this is less efficient than just running the update. To get the list of group ids affected, we first run the query to @@ -24,12 +26,12 @@ def with_post_update_signal(self, enable: bool) -> "BaseQuerySet": qs._with_post_update_signal = enable return qs - def _clone(self) -> "BaseQuerySet": + def _clone(self) -> Self: qs = super()._clone() # type: ignore[misc] qs._with_post_update_signal = self._with_post_update_signal return qs - def update_with_returning(self, returned_fields: list[str], **kwargs): + def update_with_returning(self, returned_fields: list[str], **kwargs: Any) -> list[tuple[int]]: """ Copied and modified from `Queryset.update()` to support `RETURNING ` """ @@ -77,7 +79,9 @@ def update_with_returning(self, returned_fields: list[str], **kwargs): def update(self, **kwargs: Any) -> int: if self._with_post_update_signal: - pk = self.model._meta.pk.name + pk_field = self.model._meta.pk + assert pk_field is not None + pk = pk_field.name ids = [result[0] for result in self.update_with_returning([pk], **kwargs)] if ids: updated_fields = list(kwargs.keys()) @@ -86,17 +90,17 @@ def update(self, **kwargs: Any) -> int: else: return super().update(**kwargs) - def using_replica(self) -> "BaseQuerySet": + def using_replica(self) -> Self: """ Use read replica for this query. Database router is expected to use the `replica=True` hint to make routing decision. """ return self.using(router.db_for_read(self.model, replica=True)) - def defer(self, *args: Any, **kwargs: Any) -> "BaseQuerySet": + def defer(self, *args: Any, **kwargs: Any) -> Self: raise NotImplementedError("Use ``values_list`` instead [performance].") - def only(self, *args: Any, **kwargs: Any) -> "BaseQuerySet": + def only(self, *args: Any, **kwargs: Any) -> Self: # In rare cases Django can use this if a field is unexpectedly deferred. This # mostly can happen if a field is added to a model, and then an old pickle is # passed to a process running the new code. So if you see this error after a