Skip to content

Commit

Permalink
Add support for variables in query parameters (#176)
Browse files Browse the repository at this point in the history
* Add support for variables in query parameters

* Fix code checks
  • Loading branch information
viernullvier committed May 12, 2022
1 parent 2e48862 commit 77d1902
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 7 deletions.
2 changes: 2 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[pytest]
asyncio_mode = auto
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ wheel
asgi-lifespan
async-generator; python_version<'3.7'
autoflake
black==21.9b0
black==22.3.0
flake8
flake8-bugbear
flake8-comprehensions
Expand Down
30 changes: 26 additions & 4 deletions src/tartiflette_asgi/_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,36 @@ async def get(self, request: Request) -> Response:

class GraphQLEndpoint(HTTPEndpoint):
async def get(self, request: Request) -> Response:
return await self._get_response(request, data=request.query_params)
variables = None
if "variables" in request.query_params:
try:
variables = json.loads(request.query_params["variables"])
except json.JSONDecodeError:
return JSONResponse(
{"error": "Unable to decode variables: Invalid JSON."}, 400
)
return await self._get_response(
request, data=request.query_params, variables=variables
)

async def post(self, request: Request) -> Response:
content_type = request.headers.get("Content-Type", "")

variables = None
if "variables" in request.query_params:
try:
variables = json.loads(request.query_params["variables"])
except json.JSONDecodeError:
return JSONResponse(
{"error": "Unable to decode variables: Invalid JSON."}, 400
)

if "application/json" in content_type:
try:
data = await request.json()
except json.JSONDecodeError:
return JSONResponse({"error": "Invalid JSON."}, 400)
variables = data.get("variables", variables)
elif "application/graphql" in content_type:
body = await request.body()
data = {"query": body.decode()}
Expand All @@ -51,9 +71,11 @@ async def post(self, request: Request) -> Response:
else:
return PlainTextResponse("Unsupported Media Type", 415)

return await self._get_response(request, data=data)
return await self._get_response(request, data=data, variables=variables)

async def _get_response(self, request: Request, data: QueryParams) -> Response:
async def _get_response(
self, request: Request, data: QueryParams, variables: typing.Optional[dict]
) -> Response:
try:
query = data["query"]
except KeyError:
Expand All @@ -67,7 +89,7 @@ async def _get_response(self, request: Request, data: QueryParams) -> Response:
result: dict = await engine.execute(
query,
context=context,
variables=data.get("variables"),
variables=variables,
operation_name=data.get("operationName"),
)

Expand Down
76 changes: 76 additions & 0 deletions tests/test_graphql_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,20 @@ async def test_get_querystring(engine: Engine) -> None:
assert response.json() == {"data": {"hello": "Hello stranger"}}


@pytest.mark.asyncio
async def test_get_querystring_variables(engine: Engine) -> None:
app = TartifletteApp(engine=engine)
async with get_client(app) as client:
response = await client.get(
(
"/?query=query($name: String) { hello(name: $name) }"
'&variables={ "name": "world" }'
)
)
assert response.status_code == 200
assert response.json() == {"data": {"hello": "Hello world"}}


@pytest.mark.asyncio
@pytest.mark.parametrize("path", ("/", "/?foo=bar", "/?q={ hello }"))
async def test_get_no_query(engine: Engine, path: str) -> None:
Expand All @@ -25,6 +39,15 @@ async def test_get_no_query(engine: Engine, path: str) -> None:
assert response.text == "No GraphQL query found in the request"


@pytest.mark.asyncio
async def test_get_invalid_json(engine: Engine) -> None:
app = TartifletteApp(engine=engine)
async with get_client(app) as client:
response = await client.get("/?query={ hello }&variables={test")
assert response.status_code == 400
assert response.json() == {"error": "Unable to decode variables: Invalid JSON."}


@pytest.mark.asyncio
async def test_post_querystring(engine: Engine) -> None:
app = TartifletteApp(engine=engine)
Expand All @@ -34,6 +57,31 @@ async def test_post_querystring(engine: Engine) -> None:
assert response.json() == {"data": {"hello": "Hello stranger"}}


@pytest.mark.asyncio
async def test_post_querystring_variables(engine: Engine) -> None:
app = TartifletteApp(engine=engine)
async with get_client(app) as client:
response = await client.post(
(
"/?query=query($name: String) { hello(name: $name) }"
'&variables={ "name": "world" }'
)
)
assert response.status_code == 200
assert response.json() == {"data": {"hello": "Hello world"}}


@pytest.mark.asyncio
async def test_post_querystring_invalid_json(engine: Engine) -> None:
app = TartifletteApp(engine=engine)
async with get_client(app) as client:
response = await client.post(
"/?query=query($name: String) { hello(name: $name) }&variables={test"
)
assert response.status_code == 400
assert response.json() == {"error": "Unable to decode variables: Invalid JSON."}


@pytest.mark.asyncio
async def test_post_json(engine: Engine) -> None:
app = TartifletteApp(engine=engine)
Expand All @@ -43,6 +91,21 @@ async def test_post_json(engine: Engine) -> None:
assert response.json() == {"data": {"hello": "Hello stranger"}}


@pytest.mark.asyncio
async def test_post_json_variables(engine: Engine) -> None:
app = TartifletteApp(engine=engine)
async with get_client(app) as client:
response = await client.post(
"/",
json={
"query": "query($name: String) { hello(name: $name) }",
"variables": {"name": "world"},
},
)
assert response.status_code == 200
assert response.json() == {"data": {"hello": "Hello world"}}


@pytest.mark.asyncio
async def test_post_invalid_json(engine: Engine) -> None:
app = TartifletteApp(engine=engine)
Expand All @@ -65,6 +128,19 @@ async def test_post_graphql(engine: Engine) -> None:
assert response.json() == {"data": {"hello": "Hello stranger"}}


@pytest.mark.asyncio
async def test_post_graphql_variables(engine: Engine) -> None:
app = TartifletteApp(engine=engine)
async with get_client(app) as client:
response = await client.post(
'/?variables={ "name": "world" }',
data="query($name: String) { hello(name: $name) }",
headers={"content-type": "application/graphql"},
)
assert response.status_code == 200
assert response.json() == {"data": {"hello": "Hello world"}}


@pytest.mark.asyncio
async def test_post_invalid_media_type(engine: Engine) -> None:
app = TartifletteApp(engine=engine)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,12 @@ def fixture_path(subscriptions: typing.Any) -> str:


def _init(ws: WebSocket) -> None:
ws.send_json({"type": "connection_init"})
ws.send_json({"type": "connection_init"}) # type: ignore
assert ws.receive_json() == {"type": "connection_ack"}


def _terminate(ws: WebSocket) -> None:
ws.send_json({"type": "connection_terminate"})
ws.send_json({"type": "connection_terminate"}) # type: ignore


def test_protocol_connect_disconnect(
Expand Down

0 comments on commit 77d1902

Please sign in to comment.