Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass module API to OIDC mapping provider #16974

Merged
merged 1 commit into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/16974.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
As done for SAML mapping provider, let's pass the module API to the OIDC one so the mapper can do more logic in its code.
4 changes: 3 additions & 1 deletion docs/sso_mapping_providers.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,13 @@ comment these options out and use those specified by the module instead.

A custom mapping provider must specify the following methods:

* `def __init__(self, parsed_config)`
* `def __init__(self, parsed_config, module_api)`
- Arguments:
- `parsed_config` - A configuration object that is the return value of the
`parse_config` method. You should set any configuration options needed by
the module here.
- `module_api` - a `synapse.module_api.ModuleApi` object which provides the
stable API available for extension modules.
* `def parse_config(config)`
- This method should have the `@staticmethod` decoration.
- Arguments:
Expand Down
17 changes: 14 additions & 3 deletions synapse/handlers/oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.module_api import ModuleApi
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
from synapse.util import Clock, json_decoder
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
Expand Down Expand Up @@ -421,9 +422,19 @@ def __init__(
# from the IdP's jwks_uri, if required.
self._jwks = RetryOnExceptionCachedCall(self._load_jwks)

self._user_mapping_provider = provider.user_mapping_provider_class(
provider.user_mapping_provider_config
user_mapping_provider_init_method = (
provider.user_mapping_provider_class.__init__
)
if len(inspect.signature(user_mapping_provider_init_method).parameters) == 3:
self._user_mapping_provider = provider.user_mapping_provider_class(
provider.user_mapping_provider_config,
ModuleApi(hs, hs.get_auth_handler()),
)
else:
self._user_mapping_provider = provider.user_mapping_provider_class(
provider.user_mapping_provider_config,
)

self._skip_verification = provider.skip_verification
self._allow_existing_users = provider.allow_existing_users

Expand Down Expand Up @@ -1583,7 +1594,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
This is the default mapping provider.
"""

def __init__(self, config: JinjaOidcMappingConfig):
def __init__(self, config: JinjaOidcMappingConfig, module_api: ModuleApi):
self._config = config

@staticmethod
Expand Down
Loading