Skip to content

Commit

Permalink
OpenID Connect: Add hook to be able to customize role creation
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrunner committed Sep 19, 2024
1 parent 758fc00 commit 8ca57e1
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 70 deletions.
35 changes: 8 additions & 27 deletions geoportal/c2cgeoportal_geoportal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,22 @@
import sqlalchemy.orm
import zope.event.classhandler
from c2cgeoform import translator
from c2cwsgiutils.broadcast import decorator
from c2cwsgiutils.health_check import HealthCheck
from c2cwsgiutils.prometheus import MemoryMapCollector
from deform import Form
from dogpile.cache import register_backend # type: ignore[attr-defined]
from papyrus.renderers import GeoJSON
from prometheus_client.core import REGISTRY
from pyramid.config import Configurator
from pyramid.httpexceptions import HTTPBadRequest, HTTPException
from pyramid.httpexceptions import HTTPException
from pyramid.path import AssetResolver
from pyramid_mako import add_mako_renderer
from sqlalchemy.orm import Session, joinedload
from sqlalchemy.orm import joinedload

import c2cgeoportal_commons.models
import c2cgeoportal_geoportal.views
from c2cgeoportal_commons.models import InvalidateCacheEvent
from c2cgeoportal_geoportal.lib import C2CPregenerator, caching, check_collector, checker
from c2cgeoportal_geoportal.lib import C2CPregenerator, caching, check_collector, checker, oidc
from c2cgeoportal_geoportal.lib.cacheversion import version_cache_buster
from c2cgeoportal_geoportal.lib.common_headers import Cache, set_common_headers
from c2cgeoportal_geoportal.lib.i18n import available_locale_names
Expand Down Expand Up @@ -317,7 +316,6 @@ def get_user_from_request(
"""
from c2cgeoportal_commons.models import DBSession # pylint: disable=import-outside-toplevel
from c2cgeoportal_commons.models.static import User # pylint: disable=import-outside-toplevel
from c2cgeoportal_geoportal.lib import oidc # pylint: disable=import-outside-toplevel

assert DBSession is not None

Expand Down Expand Up @@ -347,28 +345,10 @@ def get_user_from_request(
)
user_info = oidc.OidcRemember(request).remember(token_response)

if openid_connect_config.get("provide_roles", False) is True:
from c2cgeoportal_commons.models.main import ( # pylint: disable=import-outside-toplevel
Role,
)

request.user_ = oidc.DynamicUser(
username=user_info["username"],
email=user_info["email"],
settings_role=(
DBSession.query(Role).filter_by(name=user_info["settings_role"]).first()
if user_info.get("settings_role") is not None
else None
),
roles=[
DBSession.query(Role).filter_by(name=role).one()
for role in user_info.get("roles", [])
],
)
else:
request.user_ = DBSession.query(User).filter_by(email=user_info["email"]).first()
for user in DBSession.query(User).all():
_LOG.error(user.username)
request.user_ = request.get_user_from_reminder(
user_info,
request.registry.settings.get("authentication", {}).get("openid_connect", {}),
)
else:
# We know we will need the role object of the
# user so we use joined loading
Expand Down Expand Up @@ -517,6 +497,7 @@ def includeme(config: pyramid.config.Configurator) -> None:

config.include("pyramid_mako")
config.include("c2cwsgiutils.pyramid.includeme")
config.include(oidc.includeme)
health_check = HealthCheck(config)
config.registry["health_check"] = health_check

Expand Down
113 changes: 90 additions & 23 deletions geoportal/c2cgeoportal_geoportal/lib/oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import datetime
import json
import logging
from typing import NamedTuple, TypedDict
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, TypedDict, Union

import pyramid.request
import pyramid.response
Expand All @@ -37,9 +37,11 @@
from pyramid.httpexceptions import HTTPBadRequest, HTTPInternalServerError, HTTPUnauthorized
from pyramid.security import remember

from c2cgeoportal_commons.models import main
from c2cgeoportal_geoportal.lib.caching import get_region

if TYPE_CHECKING:
from c2cgeoportal_commons.models import main, static

_LOG = logging.getLogger(__name__)
_CACHE_REGION_OBJ = get_region("obj")

Expand All @@ -52,8 +54,8 @@ class DynamicUser(NamedTuple):

username: str
email: str
settings_role: main.Role | None
roles: list[main.Role]
settings_role: Optional["main.Role"]
roles: list["main.Role"]


@_CACHE_REGION_OBJ.cache_on_arguments()
Expand Down Expand Up @@ -92,6 +94,81 @@ class OidcRememberObject(TypedDict):
roles: list[str]


def get_remember_from_user_info(
user_info: dict[str, Any], remember_object: OidcRememberObject, settings: dict[str, Any]
) -> None:
"""
Fill the remember object from the user info.
The remember object will be stored in a cookie to remember the user.
:param user_info: The user info from the ID token or from the user info view according to the `query_user_info` configuration.
:param remember_object: The object to fill, by default with the `username`, `email`, `settings_role` and `roles`,
the corresponding field from `user_info` can be configured in `user_info_fields`.
:param settings: The OpenID Connect configuration.
"""
settings_fields = settings.get("user_info_fields", {})

for field_, default_field in (
("username", "name"),
("email", "email"),
("settings_role", None),
("roles", None),
):
user_info_field = settings_fields.get(field_, default_field)
if user_info_field is not None:
if user_info_field not in user_info:
_LOG.error(
"Field '%s' not found in user info, available: %s.",
user_info_field,
", ".join(user_info.keys()),
)
raise HTTPInternalServerError(f"Field '{user_info_field}' not found in user info.")
remember_object[field_] = user_info[user_info_field] # type: ignore[literal-required]


def get_user_from_remember(
remember_object: OidcRememberObject, settings: dict[str, Any], create_user: bool = False
) -> Union["static.User", DynamicUser] | None:
"""
Create a user from the remember object filled from `get_remember_from_user_info`.
:param remember_object: The object to fill, by default with the `username`, `email`, `settings_role` and `roles`.
:param settings: The OpenID Connect configuration.
:param create_user: If the user should be created if it does not exist.
"""
from c2cgeoportal_commons import models # pylint: disable=import-outside-toplevel
from c2cgeoportal_commons.models import main, static # pylint: disable=import-outside-toplevel

assert models.DBSession is not None

user: static.User | DynamicUser | None
username = remember_object["username"]
assert username is not None
email = remember_object["email"]
assert email is not None
if settings.get("provide_roles", False) is False:
user = models.DBSession.query(static.User).filter_by(email=email).one_or_none()
if user is None and create_user is True:
user = static.User(username=username, email=email)
models.DBSession.add(user)
else:
user = DynamicUser(
username=username,
email=email,
settings_role=(
models.DBSession.query(main.Role).filter_by(name=remember_object["settings_role"]).first()
if remember_object.get("settings_role") is not None
else None
),
roles=[
models.DBSession.query(main.Role).filter_by(name=role).one()
for role in remember_object.get("roles", [])
],
)
return user


class OidcRemember:
"""
Build the abject that we want to remember in the cookie.
Expand Down Expand Up @@ -142,7 +219,6 @@ def remember(
"settings_role": None,
"roles": [],
}
settings_fields = openid_connect.get("user_info_fields", {})
client = get_oidc_client(self.request)

if openid_connect.get("query_user_info", False) is True:
Expand All @@ -166,24 +242,15 @@ def remember(
),
)

for field_, default_field in (
("username", "name"),
("email", "email"),
("settings_role", None),
("roles", None),
):
user_info_field = settings_fields.get(field_, default_field)
if user_info_field is not None:
user_info_dict = user_info.dict()
if user_info_field not in user_info_dict:
_LOG.error(
"Field '%s' not found in user info, available: %s.",
user_info_field,
", ".join(user_info_dict.keys()),
)
raise HTTPInternalServerError(f"Field '{user_info_field}' not found in user info.")
remember_object[field_] = user_info_dict[user_info_field] # type: ignore[literal-required]

self.request.get_remember_from_user_info(user_info.dict(), remember_object, openid_connect)
self.request.response.headers.extend(remember(self.request, json.dumps(remember_object)))

return remember_object


def includeme(config: pyramid.config.Configurator) -> None:
"""
Pyramid includeme function.
"""
config.add_request_method(get_remember_from_user_info, name="get_remember_from_user_info")
config.add_request_method(get_user_from_remember, name="get_user_from_remember")
23 changes: 3 additions & 20 deletions geoportal/c2cgeoportal_geoportal/views/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,26 +644,9 @@ def oidc_callback(self) -> pyramid.response.Response:

remember_object = oidc.OidcRemember(self.request).remember(token_response)

user: static.User | oidc.DynamicUser | None
if self.authentication_settings.get("openid_connect", {}).get("provide_roles", False) is False:
user = models.DBSession.query(static.User).filter_by(email=remember_object["email"]).one_or_none()
if user is None:
user = static.User(username=remember_object["username"], email=remember_object["email"])
models.DBSession.add(user)
else:
user = oidc.DynamicUser(
username=remember_object["username"],
email=remember_object["email"],
settings_role=(
models.DBSession.query(main.Role).filter_by(name=remember_object["settings_role"]).first()
if remember_object.get("settings_role") is not None
else None
),
roles=[
models.DBSession.query(main.Role).filter_by(name=role).one()
for role in remember_object.get("roles", [])
],
)
user: static.User | oidc.DynamicUser | None = self.request.get_user_from_remember(
remember_object, self.authentication_settings
)
assert user is not None
self.request.user_ = user

Expand Down

0 comments on commit 8ca57e1

Please sign in to comment.