Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Persist user interactive authentication sessions #7302

Merged
merged 36 commits into from
Apr 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
b030fdb
Persist the storage of UI Auth sessions into the database.
clokep Apr 10, 2020
4a83da9
Ensure that UI auth stages are idempotent.
clokep Apr 15, 2020
e7a2db6
Fix postgresql issues.
clokep Apr 17, 2020
b9dd110
Add a changelog file.
clokep Apr 20, 2020
831f8a2
Expire old sessions.
clokep Apr 20, 2020
04d3d8b
Keep the last_used time up-to-date.
clokep Apr 20, 2020
0fdb22c
Properly run the looping call in the background.
clokep Apr 20, 2020
3d9b3a8
Properly call get_session as async.
clokep Apr 20, 2020
b5fc1b9
Attempt to avoid clashes in session IDs.
clokep Apr 20, 2020
d9157c4
Properly await runInteraction calls.
clokep Apr 20, 2020
ae45238
Add the UIAuthStore to workers.
clokep Apr 21, 2020
9ac9c54
Remove unnecessary lambda
clokep Apr 21, 2020
1c861b8
Only expire old sessions on the master.
clokep Apr 21, 2020
0895971
Match the looping_call signature in unit tests.
clokep Apr 21, 2020
42c4bca
Fix mypy typing and run mypy on the new file.
clokep Apr 21, 2020
2d1bcad
Add a few return types.
clokep Apr 21, 2020
f2e5151
Prefix a number to the delta file.
clokep Apr 22, 2020
7091341
Clarify comments.
clokep Apr 22, 2020
ff14b66
Rename methods based on feedback.
clokep Apr 22, 2020
2a4a910
Avoid re-doing work.
clokep Apr 22, 2020
5a60f2d
Use JsonDict in some places.
clokep Apr 22, 2020
6b4a6df
Create a return type for UI auth session data.
clokep Apr 22, 2020
1a5101b
Ensure the session exists before marking a stage complete.
clokep Apr 22, 2020
264ef03
Use creation time instead of last updated time.
clokep Apr 22, 2020
8b5ef4a
Rename the identity parameter to result.
clokep Apr 22, 2020
f179c21
Separate the unsafe worker methods to a separate object.
clokep Apr 22, 2020
64c709b
Use _txn in method names.
clokep Apr 29, 2020
568a778
Document possible error states better.
clokep Apr 29, 2020
79c5be5
Review feedback.
clokep Apr 29, 2020
18b3494
Do not directly raise SynapseError.
clokep Apr 29, 2020
5340662
Use foreign keys to simplify logic.
clokep Apr 29, 2020
5f0bf19
Again fix idempotency of the registration API.
clokep Apr 29, 2020
eccb670
Fix lint.
clokep Apr 29, 2020
106dca9
Fix typo in docstring.
clokep Apr 30, 2020
64852bf
Raise a 400 error, not 404.
clokep Apr 30, 2020
ae27afd
Convert StoreErrors to SynapseErrors.
clokep Apr 30, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/7302.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Persist user interactive authentication sessions across workers and Synapse restarts.
2 changes: 2 additions & 0 deletions synapse/app/generic_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@
MonthlyActiveUsersWorkerStore,
)
from synapse.storage.data_stores.main.presence import UserPresenceState
from synapse.storage.data_stores.main.ui_auth import UIAuthWorkerStore
from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
from synapse.types import ReadReceipt
from synapse.util.async_helpers import Linearizer
Expand Down Expand Up @@ -425,6 +426,7 @@ class GenericWorkerSlavedStore(
# FIXME(#3714): We need to add UserDirectoryStore as we write directly
# rather than going via the correct worker.
UserDirectoryStore,
UIAuthWorkerStore,
SlavedDeviceInboxStore,
SlavedDeviceStore,
SlavedReceiptsStore,
Expand Down
175 changes: 61 additions & 114 deletions synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@
from synapse.http.server import finish_request
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api import ModuleApi
from synapse.push.mailer import load_jinja2_templates
from synapse.types import Requester, UserID
from synapse.util.caches.expiringcache import ExpiringCache

from ._base import BaseHandler

Expand All @@ -69,15 +69,6 @@ def __init__(self, hs):

self.bcrypt_rounds = hs.config.bcrypt_rounds

# This is not a cache per se, but a store of all current sessions that
# expire after N hours
self.sessions = ExpiringCache(
cache_name="register_sessions",
clock=hs.get_clock(),
expiry_ms=self.SESSION_EXPIRE_MS,
reset_expiry_on_get=True,
)

account_handler = ModuleApi(hs, self)
self.password_providers = [
module(config=config, account_handler=account_handler)
Expand Down Expand Up @@ -119,6 +110,15 @@ def __init__(self, hs):

self._clock = self.hs.get_clock()

# Expire old UI auth sessions after a period of time.
if hs.config.worker_app is None:
self._clock.looping_call(
run_as_background_process,
5 * 60 * 1000,
"expire_old_sessions",
self._expire_old_sessions,
)

# Load the SSO HTML templates.

# The following template is shown to the user during a client login via SSO,
Expand Down Expand Up @@ -301,16 +301,21 @@ async def check_auth(
if "session" in authdict:
sid = authdict["session"]

# Convert the URI and method to strings.
uri = request.uri.decode("utf-8")
method = request.uri.decode("utf-8")

# If there's no session ID, create a new session.
if not sid:
session = self._create_session(
clientdict, (request.uri, request.method, clientdict), description
session = await self.store.create_ui_auth_session(
clientdict, uri, method, description
)
session_id = session["id"]

else:
session = self._get_session_info(sid)
session_id = sid
try:
session = await self.store.get_ui_auth_session(sid)
except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (sid,))

if not clientdict:
# This was designed to allow the client to omit the parameters
Expand All @@ -322,36 +327,35 @@ async def check_auth(
# on a homeserver.
# Revisit: Assuming the REST APIs do sensible validation, the data
# isn't arbitrary.
clientdict = session["clientdict"]
clientdict = session.clientdict

# Ensure that the queried operation does not vary between stages of
# the UI authentication session. This is done by generating a stable
# comparator based on the URI, method, and body (minus the auth dict)
# and storing it during the initial query. Subsequent queries ensure
# that this comparator has not changed.
comparator = (request.uri, request.method, clientdict)
if session["ui_auth"] != comparator:
comparator = (uri, method, clientdict)
if (session.uri, session.method, session.clientdict) != comparator:
raise SynapseError(
403,
"Requested operation has changed during the UI authentication session.",
)

if not authdict:
raise InteractiveAuthIncompleteError(
self._auth_dict_for_flows(flows, session_id)
self._auth_dict_for_flows(flows, session.session_id)
)

creds = session["creds"]

# check auth type currently being presented
errordict = {} # type: Dict[str, Any]
if "type" in authdict:
login_type = authdict["type"] # type: str
try:
result = await self._check_auth_dict(authdict, clientip)
if result:
creds[login_type] = result
self._save_session(session)
await self.store.mark_ui_auth_stage_complete(
session.session_id, login_type, result
)
except LoginError as e:
if login_type == LoginType.EMAIL_IDENTITY:
# riot used to have a bug where it would request a new
Expand All @@ -367,6 +371,7 @@ async def check_auth(
# so that the client can have another go.
errordict = e.error_dict()

creds = await self.store.get_completed_ui_auth_stages(session.session_id)
for f in flows:
if len(set(f) - set(creds)) == 0:
# it's very useful to know what args are stored, but this can
Expand All @@ -380,9 +385,9 @@ async def check_auth(
list(clientdict),
)

return creds, clientdict, session_id
return creds, clientdict, session.session_id

ret = self._auth_dict_for_flows(flows, session_id)
ret = self._auth_dict_for_flows(flows, session.session_id)
ret["completed"] = list(creds)
ret.update(errordict)
raise InteractiveAuthIncompleteError(ret)
Expand All @@ -399,13 +404,11 @@ async def add_oob_auth(
if "session" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM)

sess = self._get_session_info(authdict["session"])
creds = sess["creds"]

result = await self.checkers[stagetype].check_auth(authdict, clientip)
if result:
creds[stagetype] = result
self._save_session(sess)
await self.store.mark_ui_auth_stage_complete(
authdict["session"], stagetype, result
)
return True
return False

Expand All @@ -427,7 +430,7 @@ def get_session_id(self, clientdict: Dict[str, Any]) -> Optional[str]:
sid = authdict["session"]
return sid

def set_session_data(self, session_id: str, key: str, value: Any) -> None:
async def set_session_data(self, session_id: str, key: str, value: Any) -> None:
"""
Store a key-value pair into the sessions data associated with this
request. This data is stored server-side and cannot be modified by
Expand All @@ -438,11 +441,12 @@ def set_session_data(self, session_id: str, key: str, value: Any) -> None:
key: The key to store the data under
value: The data to store
"""
sess = self._get_session_info(session_id)
sess["serverdict"][key] = value
self._save_session(sess)
try:
await self.store.set_ui_auth_session_data(session_id, key, value)
except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))

def get_session_data(
async def get_session_data(
self, session_id: str, key: str, default: Optional[Any] = None
) -> Any:
"""
Expand All @@ -453,8 +457,18 @@ def get_session_data(
key: The key to store the data under
default: Value to return if the key has not been set
"""
sess = self._get_session_info(session_id)
return sess["serverdict"].get(key, default)
try:
return await self.store.get_ui_auth_session_data(session_id, key, default)
except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))

async def _expire_old_sessions(self):
"""
Invalidate any user interactive authentication sessions that have expired.
"""
now = self._clock.time_msec()
expiration_time = now - self.SESSION_EXPIRE_MS
await self.store.delete_old_ui_auth_sessions(expiration_time)

async def _check_auth_dict(
self, authdict: Dict[str, Any], clientip: str
Expand Down Expand Up @@ -534,67 +548,6 @@ def _auth_dict_for_flows(
"params": params,
}

def _create_session(
self,
clientdict: Dict[str, Any],
ui_auth: Tuple[bytes, bytes, Dict[str, Any]],
description: str,
) -> dict:
"""
Creates a new user interactive authentication session.

The session can be used to track data across multiple requests, e.g. for
interactive authentication.

Each session has the following keys:

id:
A unique identifier for this session. Passed back to the client
and returned for each stage.
clientdict:
The dictionary from the client root level, not the 'auth' key.
ui_auth:
A tuple which is checked at each stage of the authentication to
ensure that the asked for operation has not changed.
creds:
A map, which maps each auth-type (str) to the relevant identity
authenticated by that auth-type (mostly str, but for captcha, bool).
serverdict:
A map of data that is stored server-side and cannot be modified
by the client.
description:
A string description of the operation that the current
authentication is authorising.
Returns:
The newly created session.
"""
session_id = None
while session_id is None or session_id in self.sessions:
session_id = stringutils.random_string(24)

self.sessions[session_id] = {
"id": session_id,
"clientdict": clientdict,
"ui_auth": ui_auth,
"creds": {},
"serverdict": {},
"description": description,
}

return self.sessions[session_id]

def _get_session_info(self, session_id: str) -> dict:
"""
Gets a session given a session ID.

The session can be used to track data across multiple requests, e.g. for
interactive authentication.
"""
try:
return self.sessions[session_id]
except KeyError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))

async def get_access_token_for_user_id(
self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
):
Expand Down Expand Up @@ -994,13 +947,6 @@ async def delete_threepid(
await self.store.user_delete_threepid(user_id, medium, address)
return result

def _save_session(self, session: Dict[str, Any]) -> None:
"""Update the last used time on the session to now and add it back to the session store."""
# TODO: Persistent storage
logger.debug("Saving session %s", session)
session["last_used"] = self.hs.get_clock().time_msec()
self.sessions[session["id"]] = session

async def hash(self, password: str) -> str:
"""Computes a secure hash of password.

Expand Down Expand Up @@ -1052,7 +998,7 @@ def _do_validate_hash():
else:
return False

def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
async def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
"""
Get the HTML for the SSO redirect confirmation page.

Expand All @@ -1063,12 +1009,15 @@ def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
Returns:
The HTML to render.
"""
session = self._get_session_info(session_id)
try:
session = await self.store.get_ui_auth_session(session_id)
except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
return self._sso_auth_confirm_template.render(
description=session["description"], redirect_url=redirect_url,
description=session.description, redirect_url=redirect_url,
)

def complete_sso_ui_auth(
async def complete_sso_ui_auth(
self, registered_user_id: str, session_id: str, request: SynapseRequest,
):
"""Having figured out a mxid for this user, complete the HTTP request
Expand All @@ -1080,13 +1029,11 @@ def complete_sso_ui_auth(
process.
"""
# Mark the stage of the authentication as successful.
sess = self._get_session_info(session_id)
creds = sess["creds"]

# Save the user who authenticated with SSO, this will be used to ensure
# that the account be modified is also the person who logged in.
creds[LoginType.SSO] = registered_user_id
self._save_session(sess)
await self.store.mark_ui_auth_stage_complete(
session_id, LoginType.SSO, registered_user_id
)

# Render the HTML and return.
html_bytes = self._sso_auth_success_template.encode("utf-8")
Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/cas_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ async def handle_ticket(
registered_user_id = await self._auth_handler.check_user_exists(user_id)

if session:
self._auth_handler.complete_sso_ui_auth(
await self._auth_handler.complete_sso_ui_auth(
registered_user_id, session, request,
)

Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/saml_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ async def handle_saml_response(self, request):

# Complete the interactive auth session or the login.
if current_session and current_session.ui_auth_session_id:
self._auth_handler.complete_sso_ui_auth(
await self._auth_handler.complete_sso_ui_auth(
user_id, current_session.ui_auth_session_id, request
)

Expand Down
4 changes: 2 additions & 2 deletions synapse/rest/client/v2_alpha/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def __init__(self, hs):
self._cas_server_url = hs.config.cas_server_url
self._cas_service_url = hs.config.cas_service_url

def on_GET(self, request, stagetype):
async def on_GET(self, request, stagetype):
session = parse_string(request, "session")
if not session:
raise SynapseError(400, "No session supplied")
Expand Down Expand Up @@ -180,7 +180,7 @@ def on_GET(self, request, stagetype):
else:
raise SynapseError(400, "Homeserver not configured for SSO.")

html = self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)

else:
raise SynapseError(404, "Unknown auth stage type")
Expand Down
4 changes: 2 additions & 2 deletions synapse/rest/client/v2_alpha/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ async def on_POST(self, request):
# registered a user for this session, so we could just return the
# user here. We carry on and go through the auth checks though,
# for paranoia.
registered_user_id = self.auth_handler.get_session_data(
registered_user_id = await self.auth_handler.get_session_data(
session_id, "registered_user_id", None
)

Expand Down Expand Up @@ -588,7 +588,7 @@ async def on_POST(self, request):

# remember that we've now registered that user account, and with
# what user ID (since the user may not have specified)
self.auth_handler.set_session_data(
await self.auth_handler.set_session_data(
session_id, "registered_user_id", registered_user_id
)

Expand Down
Loading