Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ref: adjust typing so rate_limits can be a callable #74807

Merged
merged 1 commit into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/sentry/api/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,8 @@ class Endpoint(APIView):

owner: ApiOwner = ApiOwner.UNOWNED
publish_status: dict[HTTP_METHOD_NAME, ApiPublishStatus] = {}
rate_limits: RateLimitConfig | dict[
str, dict[RateLimitCategory, RateLimit]
rate_limits: RateLimitConfig | dict[str, dict[RateLimitCategory, RateLimit]] | Callable[
..., RateLimitConfig | dict[str, dict[RateLimitCategory, RateLimit]]
] = DEFAULT_RATE_LIMIT_CONFIG
enforce_rate_limit: bool = settings.SENTRY_RATELIMITER_ENABLED
snuba_methods: list[HTTP_METHOD_NAME] = []
Expand Down
3 changes: 2 additions & 1 deletion tests/sentry/middleware/test_ratelimit_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,8 @@ class CallableRateLimitConfigEndpoint(Endpoint):
permission_classes = (AllowAny,)
enforce_rate_limit = True

def rate_limits(request):
@staticmethod
def rate_limits(*a, **k):
return {
"GET": {
RateLimitCategory.IP: RateLimit(limit=20, window=1),
Expand Down
28 changes: 0 additions & 28 deletions tests/sentry/ratelimits/utils/test_get_rate_limit_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,31 +88,3 @@ class ChildEndpoint(ParentEndpoint):
assert get_rate_limit_value(
"GET", RateLimitCategory.IP, rate_limit_config
) == get_default_rate_limits_for_group("foo", RateLimitCategory.IP)

def test_multiple_inheritance(self):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this test wasn't actually testing anything beyond "does python do inheritance" so I removed it

class ParentEndpoint(Endpoint):
rate_limits: RateLimitConfig | dict[str, dict[RateLimitCategory, RateLimit]]
rate_limits = {"GET": {RateLimitCategory.IP: RateLimit(limit=100, window=5)}}

class Mixin:
rate_limits: RateLimitConfig | dict[str, dict[RateLimitCategory, RateLimit]]
rate_limits = {"GET": {RateLimitCategory.IP: RateLimit(limit=2, window=4)}}

class ChildEndpoint(ParentEndpoint, Mixin):
pass

_child_endpoint = ChildEndpoint.as_view()
rate_limit_config = get_rate_limit_config(_child_endpoint.view_class)

class ChildEndpointReverse(Mixin, ParentEndpoint):
pass

_child_endpoint_reverse = ChildEndpointReverse.as_view()
rate_limit_config_reverse = get_rate_limit_config(_child_endpoint_reverse.view_class)

assert get_rate_limit_value("GET", RateLimitCategory.IP, rate_limit_config) == RateLimit(
100, 5
)
assert get_rate_limit_value(
"GET", RateLimitCategory.IP, rate_limit_config_reverse
) == RateLimit(limit=2, window=4)
Loading