Skip to content

Commit

Permalink
Use AppKey (#621)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dreamsorcerer committed Nov 19, 2023
1 parent 23948e3 commit 38c22c7
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 60 deletions.
49 changes: 32 additions & 17 deletions aiohttp_devtools/runserver/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import json
import mimetypes
import sys
import warnings
from errno import EADDRINUSE
from pathlib import Path
from typing import Any, Iterator, Optional, NoReturn
from typing import Any, Iterator, NoReturn, Optional, Set, Tuple

from aiohttp import WSMsgType, web
from aiohttp.hdrs import LAST_MODIFIED, CONTENT_LENGTH
Expand All @@ -23,19 +24,37 @@
from .log_handlers import AccessLogger
from .utils import MutableValue

try:
from aiohttp_jinja2 import static_root_key
except ImportError:
static_root_key = None # type: ignore[assignment]

LIVE_RELOAD_HOST_SNIPPET = '\n<script src="http://{}:{}/livereload.js"></script>\n'
LIVE_RELOAD_LOCAL_SNIPPET = b'\n<script src="/livereload.js"></script>\n'
HOST = '0.0.0.0'

LIVERELOAD_SCRIPT = web.AppKey("LIVERELOAD_SCRIPT", bytes)
STATIC_PATH = web.AppKey("STATIC_PATH", str)
STATIC_URL = web.AppKey("STATIC_URL", str)
WS = web.AppKey("WS", Set[Tuple[web.WebSocketResponse, str]])


def _set_static_url(app: web.Application, url: str) -> None:
app["static_root_url"] = MutableValue(url)
if static_root_key is None: # TODO: Remove fallback
with warnings.catch_warnings(): # type: ignore[unreachable]
app["static_root_url"] = MutableValue(url)
else:
app[static_root_key] = MutableValue(url) # type: ignore[misc]
for subapp in app._subapps:
_set_static_url(subapp, url)


def _change_static_url(app: web.Application, url: str) -> None:
app["static_root_url"].change(url)
if static_root_key is None: # TODO: Remove fallback
with warnings.catch_warnings(): # type: ignore[unreachable]
app["static_root_url"].change(url)
else:
app[static_root_key].change(url) # type: ignore[attr-defined]
for subapp in app._subapps:
_change_static_url(subapp, url)

Expand Down Expand Up @@ -174,23 +193,20 @@ async def create_main_app(config: Config, app_factory: AppFactory) -> web.AppRun
modify_main_app(app, config)

await check_port_open(config.main_port)
return web.AppRunner(app, access_log_class=AccessLogger)
return web.AppRunner(app, access_log_class=AccessLogger, shutdown_timeout=0.1)


async def start_main_app(runner: web.AppRunner, port: int) -> None:
await runner.setup()
site = web.TCPSite(runner, host=HOST, port=port, shutdown_timeout=0.1)
site = web.TCPSite(runner, host=HOST, port=port)
await site.start()


WS = 'websockets'


async def src_reload(app: web.Application, path: Optional[str] = None) -> int:
"""
prompt each connected browser to reload by sending websocket message.
:param path: if supplied this must be a path relative to app['static_path'],
:param path: if supplied this must be a path relative to `static_path`,
eg. reload of a single file is only supported for static resources.
:return: number of sources reloaded
"""
Expand All @@ -200,7 +216,7 @@ async def src_reload(app: web.Application, path: Optional[str] = None) -> int:

is_html = None
if path:
path = str(Path(app['static_url']) / Path(path).relative_to(app['static_path']))
path = str(Path(app[STATIC_URL]) / Path(path).relative_to(app[STATIC_PATH]))
is_html = mimetypes.guess_type(path)[0] == 'text/html'

reloads = 0
Expand Down Expand Up @@ -239,16 +255,15 @@ def create_auxiliary_app(
*, static_path: Optional[str], static_url: str = "/", livereload: bool = True,
browser_cache: bool = False) -> web.Application:
app = web.Application()
app[WS] = set()
app.update(
static_path=static_path,
static_url=static_url,
)
ws: Set[Tuple[web.WebSocketResponse, str]] = set()
app[STATIC_PATH] = static_path or ""
app[STATIC_URL] = static_url
app[WS] = ws
app.on_shutdown.append(cleanup_aux_app)

if livereload:
lr_path = Path(__file__).resolve().parent / 'livereload.js'
app['livereload_script'] = lr_path.read_bytes()
app[LIVERELOAD_SCRIPT] = lr_path.read_bytes()
app.router.add_route('GET', '/livereload.js', livereload_js)
app.router.add_route('GET', '/livereload', websocket_handler)
aux_logger.debug('enabling livereload on auxiliary app')
Expand All @@ -271,7 +286,7 @@ async def livereload_js(request: web.Request) -> web.Response:
if request.if_modified_since:
raise HTTPNotModified()

lr_script = request.app['livereload_script']
lr_script = request.app[LIVERELOAD_SCRIPT]
return web.Response(body=lr_script, content_type='application/javascript',
headers={LAST_MODIFIED: 'Fri, 01 Jan 2016 00:00:00 GMT'})

Expand Down
4 changes: 2 additions & 2 deletions aiohttp_devtools/runserver/watch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ..exceptions import AiohttpDevException
from ..logs import rs_dft_logger as logger
from .config import Config
from .serve import WS, serve_main_app, src_reload
from .serve import STATIC_PATH, WS, serve_main_app, src_reload


class WatchTask:
Expand Down Expand Up @@ -64,7 +64,7 @@ async def _run(self, live_checks: int = 150) -> None:
try:
self._start_dev_server()

static_path = str(self._app['static_path'])
static_path = self._app[STATIC_PATH]

def is_static(changes: Iterable[Tuple[object, str]]) -> bool:
return all(str(c[1]).startswith(static_path) for c in changes)
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
aiohttp==3.8.5
aiohttp==3.9.0
aiohttp-jinja2==1.6
click==8.1.7
coverage==7.3.2
devtools==0.12.2
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
aiohttp-devtools=aiohttp_devtools.cli:cli
""",
install_requires=[
'aiohttp>=3.8.0',
"aiohttp>=3.9",
'click>=6.6',
'devtools>=0.6',
'Pygments>=2.2.0',
Expand Down
8 changes: 4 additions & 4 deletions tests/test_runserver_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

from aiohttp_devtools.runserver import runserver
from aiohttp_devtools.runserver.config import Config
from aiohttp_devtools.runserver.serve import (create_auxiliary_app, create_main_app, modify_main_app, src_reload,
start_main_app)
from aiohttp_devtools.runserver.serve import (
WS, create_auxiliary_app, create_main_app, modify_main_app, src_reload, start_main_app)

from .conftest import SIMPLE_APP, forked

Expand Down Expand Up @@ -218,12 +218,12 @@ async def test_websocket_hello(aux_cli, smart_caplog):


async def test_websocket_info(aux_cli, event_loop):
assert len(aux_cli.server.app['websockets']) == 0
assert len(aux_cli.server.app[WS]) == 0
ws = await aux_cli.session.ws_connect(aux_cli.make_url('/livereload'))
try:
await ws.send_json({'command': 'info', 'url': 'foobar', 'plugins': 'bang'})
await asyncio.sleep(0.05)
assert len(aux_cli.server.app['websockets']) == 1
assert len(aux_cli.server.app[WS]) == 1
finally:
await ws.close()

Expand Down
55 changes: 25 additions & 30 deletions tests/test_runserver_serve.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import json
import pathlib
import socket
from typing import Dict
from typing import Any, Dict
from unittest.mock import MagicMock

import pytest
from aiohttp.web import Application, Request, Response
from aiohttp.web import Application, AppKey, Request, Response
from aiohttp_jinja2 import static_root_key
from pytest_toolbox import mktree

from aiohttp_devtools.exceptions import AiohttpDevException
from aiohttp_devtools.runserver.config import Config
from aiohttp_devtools.runserver.log_handlers import fmt_size
from aiohttp_devtools.runserver.serve import check_port_open, cleanup_aux_app, modify_main_app, src_reload
from aiohttp_devtools.runserver.serve import (
STATIC_PATH, STATIC_URL, WS, check_port_open, cleanup_aux_app,
modify_main_app, src_reload)

from .conftest import SIMPLE_APP, create_future

Expand All @@ -33,11 +36,9 @@ async def test_aux_reload(smart_caplog):
aux_app = Application()
ws = MagicMock()
ws.send_str = MagicMock(return_value=create_future())
aux_app.update(
websockets=[(ws, '/foo/bar')],
static_url='/static/',
static_path='/path/to/static_files/'
)
aux_app[STATIC_PATH] = "/path/to/static_files/"
aux_app[STATIC_URL] = "/static/"
aux_app[WS] = set(((ws, "/foo/bar"),)) # type: ignore[misc]
assert 1 == await src_reload(aux_app, '/path/to/static_files/the_file.js')
assert ws.send_str.call_count == 1
send_obj = json.loads(ws.send_str.call_args[0][0])
Expand All @@ -55,11 +56,9 @@ async def test_aux_reload_no_path():
aux_app = Application()
ws = MagicMock()
ws.send_str = MagicMock(return_value=create_future())
aux_app.update(
websockets=[(ws, '/foo/bar')],
static_url='/static/',
static_path='/path/to/static_files/'
)
aux_app[STATIC_PATH] = "/path/to/static_files/"
aux_app[STATIC_URL] = "/static/"
aux_app[WS] = set(((ws, "/foo/bar"),)) # type: ignore[misc]
assert 1 == await src_reload(aux_app)
assert ws.send_str.call_count == 1
send_obj = json.loads(ws.send_str.call_args[0][0])
Expand All @@ -75,11 +74,9 @@ async def test_aux_reload_html_different():
aux_app = Application()
ws = MagicMock()
ws.send_str = MagicMock(return_value=create_future())
aux_app.update(
websockets=[(ws, '/foo/bar')],
static_url='/static/',
static_path='/path/to/static_files/'
)
aux_app[STATIC_PATH] = "/path/to/static_files/"
aux_app[STATIC_URL] = "/static/"
aux_app[WS] = set(((ws, "/foo/bar"),)) # type: ignore[misc]
assert 0 == await src_reload(aux_app, '/path/to/static_files/foo/bar.html')
assert ws.send_str.call_count == 0

Expand All @@ -89,11 +86,9 @@ async def test_aux_reload_runtime_error(smart_caplog):
ws = MagicMock()
ws.send_str = MagicMock(return_value=create_future())
ws.send_str = MagicMock(side_effect=RuntimeError('foobar'))
aux_app.update(
websockets=[(ws, '/foo/bar')],
static_url='/static/',
static_path='/path/to/static_files/'
)
aux_app[STATIC_PATH] = "/path/to/static_files/"
aux_app[STATIC_URL] = "/static/"
aux_app[WS] = set(((ws, "/foo/bar"),)) # type: ignore[misc]
assert 0 == await src_reload(aux_app)
assert ws.send_str.call_count == 1
assert 'adev.server.aux ERROR: Error broadcasting change to /foo/bar, RuntimeError: foobar\n' == smart_caplog
Expand All @@ -104,7 +99,7 @@ async def test_aux_cleanup(event_loop):
aux_app.on_cleanup.append(cleanup_aux_app)
ws = MagicMock()
ws.close = MagicMock(return_value=create_future())
aux_app['websockets'] = [(ws, '/foo/bar')]
aux_app[WS] = set(((ws, "/foo/bar"),)) # type: ignore[misc]
aux_app.freeze()
await aux_app.cleanup()
assert ws.close.call_count == 1
Expand All @@ -120,14 +115,14 @@ def test_fmt_size_large(value, result):
assert fmt_size(value) == result


class DummyApplication(Dict[str, object]):
class DummyApplication(Dict[AppKey[Any], object]):
_debug = False

def __init__(self):
self.on_response_prepare = []
self.middlewares = []
self.router = MagicMock()
self['static_root_url'] = '/static/'
self[static_root_key] = '/static/'
self._subapps = []

def add_subapp(self, path, app):
Expand All @@ -144,8 +139,8 @@ def test_modify_main_app_all_off(tmpworkdir):
modify_main_app(app, config) # type: ignore[arg-type]
assert len(app.on_response_prepare) == 0
assert len(app.middlewares) == 0
assert app['static_root_url'] == 'http://foobar.com:8001/static'
assert subapp["static_root_url"] == "http://foobar.com:8001/static"
assert app[static_root_key] == "http://foobar.com:8001/static"
assert subapp[static_root_key] == "http://foobar.com:8001/static"
assert app._debug is True


Expand All @@ -158,8 +153,8 @@ def test_modify_main_app_all_on(tmpworkdir):
modify_main_app(app, config) # type: ignore[arg-type]
assert len(app.on_response_prepare) == 1
assert len(app.middlewares) == 2
assert app['static_root_url'] == 'http://localhost:8001/static'
assert subapp['static_root_url'] == "http://localhost:8001/static"
assert app[static_root_key] == "http://localhost:8001/static"
assert subapp[static_root_key] == "http://localhost:8001/static"
assert app._debug is True


Expand Down
13 changes: 8 additions & 5 deletions tests/test_runserver_watch.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import asyncio
from functools import partial
from typing import Set, Tuple
from unittest.mock import MagicMock, call

from aiohttp import ClientSession
from aiohttp.web import Application
from aiohttp.web import Application, WebSocketResponse

from aiohttp_devtools.runserver.serve import STATIC_PATH, WS
from aiohttp_devtools.runserver.watch import AppTask, LiveReloadTask

from .conftest import create_future
Expand Down Expand Up @@ -39,7 +41,7 @@ async def test_single_file_change(event_loop, mocker):
stop_mock = mocker.patch.object(app_task, "_stop_dev_server", autospec=True)
app = MagicMock()
await app_task.start(app)
d = {'static_path': '/path/to/'}
d = {STATIC_PATH: "/path/to/"}
app.__getitem__.side_effect = d.__getitem__
assert app_task._task is not None
await app_task._task
Expand Down Expand Up @@ -79,13 +81,13 @@ async def test_python_no_server(event_loop, mocker):
stop_mock = mocker.patch.object(app_task, "_stop_dev_server", autospec=True)
mocker.patch.object(app_task, "_run", partial(app_task._run, live_checks=2))
app = Application()
app['static_path'] = '/path/to/'
app[STATIC_PATH] = "/path/to/"
app.src_reload = MagicMock()
mock_ws = MagicMock()
f: asyncio.Future[int] = asyncio.Future()
f.set_result(1)
mock_ws.send_str = MagicMock(return_value=f)
app['websockets'] = [(mock_ws, '/')]
app[WS] = set(((mock_ws, "/"),)) # type: ignore[misc]
await app_task.start(app)
assert app_task._task is not None
await app_task._task
Expand All @@ -98,7 +100,8 @@ async def test_python_no_server(event_loop, mocker):

async def test_reload_server_running(event_loop, aiohttp_client, mocker):
app = Application()
app['websockets'] = [None]
ws: Set[Tuple[WebSocketResponse, str]] = set(((MagicMock(), "/foo"),))
app[WS] = ws
mock_src_reload = mocker.patch('aiohttp_devtools.runserver.watch.src_reload', return_value=create_future())
cli = await aiohttp_client(app)
config = MagicMock()
Expand Down

0 comments on commit 38c22c7

Please sign in to comment.