Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update mypy and various typing improvements #3393

Merged
merged 5 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ markupsafe==2.1.2
# via jinja2
mccabe==0.7.0
# via flake8
mypy==1.0.1
mypy==1.10.0
# via -r requirements.in
mypy-extensions==0.4.3
mypy-extensions==1.0.0
# via
# black
# mypy
Expand Down Expand Up @@ -111,9 +111,9 @@ sphinxcontrib-serializinghtml==1.1.5
# via sphinx
tox==4.6.0
# via -r requirements.in
types-pycurl==7.45.2.0
types-pycurl==7.45.3.20240421
# via -r requirements.in
typing-extensions==4.4.0
typing-extensions==4.12.1
# via mypy
urllib3==1.26.18
# via requests
Expand Down
5 changes: 4 additions & 1 deletion tornado/concurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,10 @@ def wrapper(self: Any, *args: Any, **kwargs: Any) -> Future:
_NO_RESULT = object()


def chain_future(a: "Future[_T]", b: "Future[_T]") -> None:
def chain_future(
a: Union["Future[_T]", "futures.Future[_T]"],
b: Union["Future[_T]", "futures.Future[_T]"],
) -> None:
"""Chain two futures together so that when one completes, so does the other.

The result (success or failure) of ``a`` will be copied to ``b``, unless
Expand Down
8 changes: 7 additions & 1 deletion tornado/httputil.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@
from asyncio import Future # noqa: F401
import unittest # noqa: F401

# This can be done unconditionally in the base class of HTTPHeaders
# after we drop support for Python 3.8.
StrMutableMapping = collections.abc.MutableMapping[str, str]
else:
StrMutableMapping = collections.abc.MutableMapping

# To be used with str.strip() and related methods.
HTTP_WHITESPACE = " \t"

Expand All @@ -76,7 +82,7 @@ def _normalize_header(name: str) -> str:
return "-".join([w.capitalize() for w in name.split("-")])


class HTTPHeaders(collections.abc.MutableMapping):
class HTTPHeaders(StrMutableMapping):
"""A dictionary that maintains ``Http-Header-Case`` for all keys.

Supports multiple values per key via a pair of new methods,
Expand Down
15 changes: 13 additions & 2 deletions tornado/platform/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@
Union,
)

if typing.TYPE_CHECKING:
from typing_extensions import TypeVarTuple, Unpack


class _HasFileno(Protocol):
def fileno(self) -> int:
Expand All @@ -59,6 +62,8 @@ def fileno(self) -> int:

_T = TypeVar("_T")

if typing.TYPE_CHECKING:
_Ts = TypeVarTuple("_Ts")

# Collection of selector thread event loops to shut down on exit.
_selector_loops: Set["SelectorThread"] = set()
Expand Down Expand Up @@ -702,12 +707,18 @@ def close(self) -> None:
self._real_loop.close()

def add_reader(
self, fd: "_FileDescriptorLike", callback: Callable[..., None], *args: Any
self,
fd: "_FileDescriptorLike",
callback: Callable[..., None],
*args: "Unpack[_Ts]",
) -> None:
return self._selector.add_reader(fd, callback, *args)

def add_writer(
self, fd: "_FileDescriptorLike", callback: Callable[..., None], *args: Any
self,
fd: "_FileDescriptorLike",
callback: Callable[..., None],
*args: "Unpack[_Ts]",
) -> None:
return self._selector.add_writer(fd, callback, *args)

Expand Down
26 changes: 26 additions & 0 deletions tornado/test/concurrent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from tornado.concurrent import (
Future,
chain_future,
run_on_executor,
future_set_result_unless_cancelled,
)
Expand All @@ -47,6 +48,31 @@ def test_future_set_result_unless_cancelled(self):
self.assertEqual(fut.result(), 42)


class ChainFutureTest(AsyncTestCase):
@gen_test
async def test_asyncio_futures(self):
fut: Future[int] = Future()
fut2: Future[int] = Future()
chain_future(fut, fut2)
fut.set_result(42)
result = await fut2
self.assertEqual(result, 42)

@gen_test
async def test_concurrent_futures(self):
# A three-step chain: two concurrent futures (showing that both arguments to chain_future
# can be concurrent futures), and then one from a concurrent future to an asyncio future so
# we can use it in await.
fut: futures.Future[int] = futures.Future()
fut2: futures.Future[int] = futures.Future()
fut3: Future[int] = Future()
chain_future(fut, fut2)
chain_future(fut2, fut3)
fut.set_result(42)
result = await fut3
self.assertEqual(result, 42)


# The following series of classes demonstrate and test various styles
# of use, with and without generators and futures.

Expand Down
26 changes: 17 additions & 9 deletions tornado/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,15 @@ class RequestHandler(object):

"""

SUPPORTED_METHODS = ("GET", "HEAD", "POST", "DELETE", "PATCH", "PUT", "OPTIONS")
SUPPORTED_METHODS: Tuple[str, ...] = (
"GET",
"HEAD",
"POST",
"DELETE",
"PATCH",
"PUT",
"OPTIONS",
)

_template_loaders = {} # type: Dict[str, template.BaseLoader]
_template_loader_lock = threading.Lock()
Expand Down Expand Up @@ -1596,14 +1604,14 @@ def check_xsrf_cookie(self) -> None:
# information please see
# http://www.djangoproject.com/weblog/2011/feb/08/security/
# http://weblog.rubyonrails.org/2011/2/8/csrf-protection-bypass-in-ruby-on-rails
token = (
input_token = (
self.get_argument("_xsrf", None)
or self.request.headers.get("X-Xsrftoken")
or self.request.headers.get("X-Csrftoken")
)
if not token:
if not input_token:
raise HTTPError(403, "'_xsrf' argument missing from POST")
_, token, _ = self._decode_xsrf_token(token)
_, token, _ = self._decode_xsrf_token(input_token)
_, expected_token, _ = self._get_raw_xsrf_token()
if not token:
raise HTTPError(403, "'_xsrf' argument has invalid format")
Expand Down Expand Up @@ -1886,7 +1894,7 @@ def render(*args, **kwargs) -> str: # type: ignore
if name not in self._active_modules:
self._active_modules[name] = module(self)
rendered = self._active_modules[name].render(*args, **kwargs)
return rendered
return _unicode(rendered)

return render

Expand Down Expand Up @@ -3323,7 +3331,7 @@ def __init__(self, handler: RequestHandler) -> None:
def current_user(self) -> Any:
return self.handler.current_user

def render(self, *args: Any, **kwargs: Any) -> str:
def render(self, *args: Any, **kwargs: Any) -> Union[str, bytes]:
"""Override in subclasses to return this module's output."""
raise NotImplementedError()

Expand Down Expand Up @@ -3371,12 +3379,12 @@ def render_string(self, path: str, **kwargs: Any) -> bytes:


class _linkify(UIModule):
def render(self, text: str, **kwargs: Any) -> str: # type: ignore
def render(self, text: str, **kwargs: Any) -> str:
return escape.linkify(text, **kwargs)


class _xsrf_form_html(UIModule):
def render(self) -> str: # type: ignore
def render(self) -> str:
return self.handler.xsrf_form_html()


Expand All @@ -3402,7 +3410,7 @@ def __init__(self, handler: RequestHandler) -> None:
self._resource_list = [] # type: List[Dict[str, Any]]
self._resource_dict = {} # type: Dict[str, Dict[str, Any]]

def render(self, path: str, **kwargs: Any) -> bytes: # type: ignore
def render(self, path: str, **kwargs: Any) -> bytes:
def set_resources(**kwargs) -> str: # type: ignore
if path not in self._resource_dict:
self._resource_list.append(kwargs)
Expand Down
2 changes: 1 addition & 1 deletion tornado/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -1380,7 +1380,7 @@ def __init__(
{
"Upgrade": "websocket",
"Connection": "Upgrade",
"Sec-WebSocket-Key": self.key,
"Sec-WebSocket-Key": to_unicode(self.key),
"Sec-WebSocket-Version": "13",
}
)
Expand Down