diff --git a/requirements/main.in b/requirements/main.in index c983692c62e6..020ba2999e97 100644 --- a/requirements/main.in +++ b/requirements/main.in @@ -43,6 +43,7 @@ pyramid_rpc>=0.7 pyramid_services>=2.1 pyramid_tm>=0.12 python-slugify +PyJWT[crypto]>=2.3.0 readme-renderer[md]>=0.7.0 requests requests-aws4auth diff --git a/requirements/main.txt b/requirements/main.txt index b20ea387ded3..ad4309b8e6c5 100644 --- a/requirements/main.txt +++ b/requirements/main.txt @@ -283,7 +283,8 @@ cryptography==36.0.1 \ --hash=sha256:ebc15b1c22e55c4d5566e3ca4db8689470a0ca2babef8e3a9ee057a8b82ce4b1 \ --hash=sha256:ec63da4e7e4a5f924b90af42eddf20b698a70e58d86a72d943857c4c6045b3ee # via - # -r requirements/main.in + # -r main.in + # pyjwt # pyopenssl # webauthn cssselect==1.1.0 \ @@ -883,6 +884,10 @@ pygments==2.10.0 \ --hash=sha256:b8e67fe6af78f492b3c4b3e2970c0624cbf08beb1e493b2c99b9fa1b67a20380 \ --hash=sha256:f398865f7eb6874156579fdf36bc840a03cab64d1cde9e93d68f46a425ec52c6 # via readme-renderer +pyjwt[crypto]==2.3.0 \ + --hash=sha256:b888b4d56f06f6dcd777210c334e69c737be74755d3e5e9ee3fe67dc18a0ee41 \ + --hash=sha256:e0c4bb8d9f0af0c7f5b1ec4c5036309617d03d56932877f2f7a0beeb5318322f + # via -r main.in pymacaroons==0.13.0 \ --hash=sha256:1e6bba42a5f66c245adf38a5a4006a99dcc06a0703786ea636098667d42903b8 \ --hash=sha256:3e14dff6a262fdbf1a15e769ce635a8aea72e6f8f91e408f9a97166c53b91907 diff --git a/tests/unit/oidc/test_services.py b/tests/unit/oidc/test_services.py new file mode 100644 index 000000000000..75363eee9665 --- /dev/null +++ b/tests/unit/oidc/test_services.py @@ -0,0 +1,418 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools + +import fakeredis +import pretend + +from jwt import PyJWK +from zope.interface.verify import verifyClass + +from warehouse.oidc import interfaces, services + + +def test_oidc_provider_service_factory(): + factory = services.OIDCProviderServiceFactory( + provider="example", issuer_url="https://example.com" + ) + + assert factory.provider == "example" + assert factory.issuer_url == "https://example.com" + assert verifyClass(interfaces.IOIDCProviderService, factory.service_class) + + metrics = pretend.stub() + request = pretend.stub( + registry=pretend.stub( + settings={"oidc.jwk_cache_url": "https://another.example.com"} + ), + find_service=lambda *a, **kw: metrics, + ) + service = factory(pretend.stub(), request) + + assert isinstance(service, factory.service_class) + assert service.provider == factory.provider + assert service.issuer_url == factory.issuer_url + assert service.cache_url == "https://another.example.com" + assert service.metrics == metrics + + assert factory != object() + assert factory != services.OIDCProviderServiceFactory( + provider="another", issuer_url="https://foo.example.com" + ) + + +class TestOIDCProviderService: + def test_verify(self): + service = services.OIDCProviderService( + provider=pretend.stub(), + issuer_url=pretend.stub(), + cache_url=pretend.stub(), + metrics=pretend.stub(), + ) + assert service.verify(pretend.stub()) == NotImplemented + + def test_get_keyset_not_cached(self, monkeypatch): + service = services.OIDCProviderService( + provider="example", + issuer_url=pretend.stub(), + cache_url="rediss://fake.example.com", + metrics=pretend.stub(), + ) + + monkeypatch.setattr(services.redis, "StrictRedis", fakeredis.FakeStrictRedis) + keys, timeout = service._get_keyset() + + assert not keys + assert timeout is False + + def test_get_keyset_cached(self, monkeypatch): + service = services.OIDCProviderService( + provider="example", + issuer_url=pretend.stub(), + cache_url="rediss://fake.example.com", + metrics=pretend.stub(), + ) + + # Create a fake server to provide persistent state through each + # StrictRedis.from_url context manager. + server = fakeredis.FakeServer() + from_url = functools.partial(fakeredis.FakeStrictRedis.from_url, server=server) + monkeypatch.setattr(services.redis.StrictRedis, "from_url", from_url) + + keyset = {"fake-key-id": {"foo": "bar"}} + service._store_keyset(keyset) + keys, timeout = service._get_keyset() + + assert keys == keyset + assert timeout is True + + def test_refresh_keyset_timeout(self, monkeypatch): + metrics = pretend.stub(increment=pretend.call_recorder(lambda *a, **kw: None)) + service = services.OIDCProviderService( + provider="example", + issuer_url="https://example.com", + cache_url="rediss://fake.example.com", + metrics=metrics, + ) + + # Create a fake server to provide persistent state through each + # StrictRedis.from_url context manager. + server = fakeredis.FakeServer() + from_url = functools.partial(fakeredis.FakeStrictRedis.from_url, server=server) + monkeypatch.setattr(services.redis.StrictRedis, "from_url", from_url) + + keyset = {"fake-key-id": {"foo": "bar"}} + service._store_keyset(keyset) + + keys = service._refresh_keyset() + assert keys == keyset + assert metrics.increment.calls == [ + pretend.call( + "warehouse.oidc.refresh_keyset.timeout", tags=["provider:example"] + ) + ] + + def test_refresh_keyset_oidc_config_fails(self, monkeypatch): + metrics = pretend.stub(increment=pretend.call_recorder(lambda *a, **kw: None)) + service = services.OIDCProviderService( + provider="example", + issuer_url="https://example.com", + cache_url="rediss://fake.example.com", + metrics=metrics, + ) + + monkeypatch.setattr(services.redis, "StrictRedis", fakeredis.FakeStrictRedis) + + requests = pretend.stub( + get=pretend.call_recorder(lambda url: pretend.stub(ok=False)) + ) + sentry_sdk = pretend.stub( + capture_message=pretend.call_recorder(lambda msg: pretend.stub()) + ) + monkeypatch.setattr(services, "requests", requests) + monkeypatch.setattr(services, "sentry_sdk", sentry_sdk) + + keys = service._refresh_keyset() + + assert keys == {} + assert metrics.increment.calls == [] + assert requests.get.calls == [ + pretend.call("https://example.com/.well-known/openid-configuration") + ] + assert sentry_sdk.capture_message.calls == [ + pretend.call( + "OIDC provider example failed to return configuration: " + "https://example.com/.well-known/openid-configuration" + ) + ] + + def test_refresh_keyset_oidc_config_no_jwks_uri(self, monkeypatch): + metrics = pretend.stub(increment=pretend.call_recorder(lambda *a, **kw: None)) + service = services.OIDCProviderService( + provider="example", + issuer_url="https://example.com", + cache_url="rediss://fake.example.com", + metrics=metrics, + ) + + monkeypatch.setattr(services.redis, "StrictRedis", fakeredis.FakeStrictRedis) + + requests = pretend.stub( + get=pretend.call_recorder( + lambda url: pretend.stub(ok=True, json=lambda: {}) + ) + ) + sentry_sdk = pretend.stub( + capture_message=pretend.call_recorder(lambda msg: pretend.stub()) + ) + monkeypatch.setattr(services, "requests", requests) + monkeypatch.setattr(services, "sentry_sdk", sentry_sdk) + + keys = service._refresh_keyset() + + assert keys == {} + assert metrics.increment.calls == [] + assert requests.get.calls == [ + pretend.call("https://example.com/.well-known/openid-configuration") + ] + assert sentry_sdk.capture_message.calls == [ + pretend.call( + "OIDC provider example is returning malformed configuration " + "(no jwks_uri)" + ) + ] + + def test_refresh_keyset_oidc_config_no_jwks_json(self, monkeypatch): + metrics = pretend.stub(increment=pretend.call_recorder(lambda *a, **kw: None)) + service = services.OIDCProviderService( + provider="example", + issuer_url="https://example.com", + cache_url="rediss://fake.example.com", + metrics=metrics, + ) + + monkeypatch.setattr(services.redis, "StrictRedis", fakeredis.FakeStrictRedis) + + openid_resp = pretend.stub( + ok=True, + json=lambda: { + "jwks_uri": "https://example.com/.well-known/jwks.json", + }, + ) + jwks_resp = pretend.stub(ok=False) + + def get(url): + if url == "https://example.com/.well-known/jwks.json": + return jwks_resp + else: + return openid_resp + + requests = pretend.stub(get=pretend.call_recorder(get)) + sentry_sdk = pretend.stub( + capture_message=pretend.call_recorder(lambda msg: pretend.stub()) + ) + monkeypatch.setattr(services, "requests", requests) + monkeypatch.setattr(services, "sentry_sdk", sentry_sdk) + + keys = service._refresh_keyset() + + assert keys == {} + assert metrics.increment.calls == [] + assert requests.get.calls == [ + pretend.call("https://example.com/.well-known/openid-configuration"), + pretend.call("https://example.com/.well-known/jwks.json"), + ] + assert sentry_sdk.capture_message.calls == [ + pretend.call( + "OIDC provider example failed to return JWKS JSON: " + "https://example.com/.well-known/jwks.json" + ) + ] + + def test_refresh_keyset_oidc_config_no_jwks_keys(self, monkeypatch): + metrics = pretend.stub(increment=pretend.call_recorder(lambda *a, **kw: None)) + service = services.OIDCProviderService( + provider="example", + issuer_url="https://example.com", + cache_url="rediss://fake.example.com", + metrics=metrics, + ) + + monkeypatch.setattr(services.redis, "StrictRedis", fakeredis.FakeStrictRedis) + + openid_resp = pretend.stub( + ok=True, + json=lambda: { + "jwks_uri": "https://example.com/.well-known/jwks.json", + }, + ) + jwks_resp = pretend.stub(ok=True, json=lambda: {}) + + def get(url): + if url == "https://example.com/.well-known/jwks.json": + return jwks_resp + else: + return openid_resp + + requests = pretend.stub(get=pretend.call_recorder(get)) + sentry_sdk = pretend.stub( + capture_message=pretend.call_recorder(lambda msg: pretend.stub()) + ) + monkeypatch.setattr(services, "requests", requests) + monkeypatch.setattr(services, "sentry_sdk", sentry_sdk) + + keys = service._refresh_keyset() + + assert keys == {} + assert metrics.increment.calls == [] + assert requests.get.calls == [ + pretend.call("https://example.com/.well-known/openid-configuration"), + pretend.call("https://example.com/.well-known/jwks.json"), + ] + assert sentry_sdk.capture_message.calls == [ + pretend.call("OIDC provider example returned JWKS JSON but no keys") + ] + + def test_refresh_keyset_successful(self, monkeypatch): + metrics = pretend.stub(increment=pretend.call_recorder(lambda *a, **kw: None)) + service = services.OIDCProviderService( + provider="example", + issuer_url="https://example.com", + cache_url="rediss://fake.example.com", + metrics=metrics, + ) + + # Create a fake server to provide persistent state through each + # StrictRedis.from_url context manager. + server = fakeredis.FakeServer() + from_url = functools.partial(fakeredis.FakeStrictRedis.from_url, server=server) + monkeypatch.setattr(services.redis.StrictRedis, "from_url", from_url) + + openid_resp = pretend.stub( + ok=True, + json=lambda: { + "jwks_uri": "https://example.com/.well-known/jwks.json", + }, + ) + jwks_resp = pretend.stub( + ok=True, json=lambda: {"keys": [{"kid": "fake-key-id", "foo": "bar"}]} + ) + + def get(url): + if url == "https://example.com/.well-known/jwks.json": + return jwks_resp + else: + return openid_resp + + requests = pretend.stub(get=pretend.call_recorder(get)) + sentry_sdk = pretend.stub( + capture_message=pretend.call_recorder(lambda msg: pretend.stub()) + ) + monkeypatch.setattr(services, "requests", requests) + monkeypatch.setattr(services, "sentry_sdk", sentry_sdk) + + keys = service._refresh_keyset() + + assert keys == {"fake-key-id": {"kid": "fake-key-id", "foo": "bar"}} + assert metrics.increment.calls == [] + assert requests.get.calls == [ + pretend.call("https://example.com/.well-known/openid-configuration"), + pretend.call("https://example.com/.well-known/jwks.json"), + ] + assert sentry_sdk.capture_message.calls == [] + + # Ensure that we also cached the updated keyset as part of refreshing. + keys, timeout = service._get_keyset() + assert keys == {"fake-key-id": {"kid": "fake-key-id", "foo": "bar"}} + assert timeout is True + + def test_get_key_cached(self, monkeypatch): + metrics = pretend.stub(increment=pretend.call_recorder(lambda *a, **kw: None)) + service = services.OIDCProviderService( + provider="example", + issuer_url="https://example.com", + cache_url="rediss://fake.example.com", + metrics=metrics, + ) + + keyset = { + "fake-key-id": { + "kid": "fake-key-id", + "n": "ZHVtbXkK", + "kty": "RSA", + "alg": "RS256", + "e": "AQAB", + "use": "sig", + "x5c": ["dummy"], + "x5t": "dummy", + } + } + monkeypatch.setattr(service, "_get_keyset", lambda: (keyset, True)) + + key = service.get_key("fake-key-id") + assert isinstance(key, PyJWK) + assert key.key_id == "fake-key-id" + + assert metrics.increment.calls == [] + + def test_get_key_uncached(self, monkeypatch): + metrics = pretend.stub(increment=pretend.call_recorder(lambda *a, **kw: None)) + service = services.OIDCProviderService( + provider="example", + issuer_url="https://example.com", + cache_url="rediss://fake.example.com", + metrics=metrics, + ) + + keyset = { + "fake-key-id": { + "kid": "fake-key-id", + "n": "ZHVtbXkK", + "kty": "RSA", + "alg": "RS256", + "e": "AQAB", + "use": "sig", + "x5c": ["dummy"], + "x5t": "dummy", + } + } + monkeypatch.setattr(service, "_get_keyset", lambda: ({}, False)) + monkeypatch.setattr(service, "_refresh_keyset", lambda: keyset) + + key = service.get_key("fake-key-id") + assert isinstance(key, PyJWK) + assert key.key_id == "fake-key-id" + + assert metrics.increment.calls == [] + + def test_get_key_refresh_fails(self, monkeypatch): + metrics = pretend.stub(increment=pretend.call_recorder(lambda *a, **kw: None)) + service = services.OIDCProviderService( + provider="example", + issuer_url="https://example.com", + cache_url="rediss://fake.example.com", + metrics=metrics, + ) + + monkeypatch.setattr(service, "_get_keyset", lambda: ({}, False)) + monkeypatch.setattr(service, "_refresh_keyset", lambda: {}) + + key = service.get_key("fake-key-id") + assert key is None + + assert metrics.increment.calls == [ + pretend.call( + "warehouse.oidc.get_key.error", + tags=["provider:example", "key_id:fake-key-id"], + ) + ] diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 4a6c360003ef..5ee76cce8c69 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -332,6 +332,7 @@ def __init__(self): pretend.call(".email"), pretend.call(".accounts"), pretend.call(".macaroons"), + pretend.call(".oidc"), pretend.call(".malware"), pretend.call(".manage"), pretend.call(".packaging"), diff --git a/warehouse/config.py b/warehouse/config.py index 879145189874..5af6481bc764 100644 --- a/warehouse/config.py +++ b/warehouse/config.py @@ -178,6 +178,7 @@ def configure(settings=None): maybe_set(settings, "celery.broker_url", "BROKER_URL") maybe_set(settings, "celery.result_url", "REDIS_URL") maybe_set(settings, "celery.scheduler_url", "REDIS_URL") + maybe_set(settings, "oidc.jwk_cache_url", "REDIS_URL") maybe_set(settings, "database.url", "DATABASE_URL") maybe_set(settings, "elasticsearch.url", "ELASTICSEARCH_URL") maybe_set(settings, "elasticsearch.url", "ELASTICSEARCH_SIX_URL") @@ -459,6 +460,9 @@ def configure(settings=None): # Register support for Macaroon based authentication config.include(".macaroons") + # Register support for OIDC provider based authentication + config.include(".oidc") + # Register support for malware checks config.include(".malware") diff --git a/warehouse/oidc/__init__.py b/warehouse/oidc/__init__.py new file mode 100644 index 000000000000..de37ed217d44 --- /dev/null +++ b/warehouse/oidc/__init__.py @@ -0,0 +1,24 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from warehouse.oidc.interfaces import IOIDCProviderService +from warehouse.oidc.services import OIDCProviderServiceFactory + + +def includeme(config): + config.register_service_factory( + OIDCProviderServiceFactory( + provider="github", issuer_url="https://token.actions.githubusercontent.com" + ), + IOIDCProviderService, + name="github", + ) diff --git a/warehouse/oidc/interfaces.py b/warehouse/oidc/interfaces.py new file mode 100644 index 000000000000..4c1ea13e9297 --- /dev/null +++ b/warehouse/oidc/interfaces.py @@ -0,0 +1,32 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from zope.interface import Interface + + +class IOIDCProviderService(Interface): + def get_key(key_id): + """ + Return the JWK identified by the given KID, + fetching it if not already cached locally. + + Returns None if the JWK does not exist or the access pattern is + invalid (i.e., exceeds our internal limit on JWK requests to + each provider). + """ + pass + + def verify(token): + """ + Verify the given JWT. + """ diff --git a/warehouse/oidc/services.py b/warehouse/oidc/services.py new file mode 100644 index 000000000000..f84a116a6cf0 --- /dev/null +++ b/warehouse/oidc/services.py @@ -0,0 +1,177 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import redis +import requests +import sentry_sdk + +from jwt import PyJWK +from zope.interface import implementer + +from warehouse.metrics.interfaces import IMetricsService +from warehouse.oidc.interfaces import IOIDCProviderService + + +@implementer(IOIDCProviderService) +class OIDCProviderService: + def __init__(self, provider, issuer_url, cache_url, metrics): + self.provider = provider + self.issuer_url = issuer_url + self.cache_url = cache_url + self.metrics = metrics + + self._provider_jwk_key = f"/warehouse/oidc/jwks/{self.provider}" + self._provider_timeout_key = f"{self._provider_jwk_key}/timeout" + + def _store_keyset(self, keys): + """ + Store the given keyset for the given provider, setting the timeout key + in the process. + """ + + with redis.StrictRedis.from_url(self.cache_url) as r: + r.set(self._provider_jwk_key, json.dumps(keys)) + r.setex(self._provider_timeout_key, 60, "placeholder") + + def _get_keyset(self): + """ + Return the cached keyset for the given provider, or an empty + keyset if no keys are currently cached. + """ + + with redis.StrictRedis.from_url(self.cache_url) as r: + keys = r.get(self._provider_jwk_key) + timeout = bool(r.exists(self._provider_timeout_key)) + if keys is not None: + return (json.loads(keys), timeout) + else: + return ({}, timeout) + + def _refresh_keyset(self): + """ + Attempt to refresh the keyset from the OIDC provider, assuming no + timeout is in effect. + + Returns the refreshed keyset, or the cached keyset if a timeout is + in effect. + + Returns the cached keyset on any provider access or format errors. + """ + + # Fast path: we're in a cooldown from a previous refresh. + keys, timeout = self._get_keyset() + if timeout: + self.metrics.increment( + "warehouse.oidc.refresh_keyset.timeout", + tags=[f"provider:{self.provider}"], + ) + return keys + + oidc_url = f"{self.issuer_url}/.well-known/openid-configuration" + + resp = requests.get(oidc_url) + + # For whatever reason, an OIDC provider's configuration URL might be + # offline. We don't want to completely explode here, since other + # providers might still be online (and need updating), so we spit + # out an error and return None instead of raising. + if not resp.ok: + sentry_sdk.capture_message( + f"OIDC provider {self.provider} failed to return configuration: " + f"{oidc_url}" + ) + return keys + + oidc_conf = resp.json() + jwks_url = oidc_conf.get("jwks_uri") + + # A valid OIDC configuration MUST have a `jwks_uri`, but we + # defend against its absence anyways. + if jwks_url is None: + sentry_sdk.capture_message( + f"OIDC provider {self.provider} is returning malformed " + "configuration (no jwks_uri)" + ) + return keys + + resp = requests.get(jwks_url) + + # Same reasoning as above. + if not resp.ok: + sentry_sdk.capture_message( + f"OIDC provider {self.provider} failed to return JWKS JSON: " + f"{jwks_url}" + ) + return keys + + jwks_conf = resp.json() + new_keys = jwks_conf.get("keys") + + # Another sanity test: an OIDC provider should never return an empty + # keyset, but there's nothing stopping them from doing so. We don't + # want to cache an empty keyset just in case it's a short-lived error, + # so we check here, error, and return the current cache instead. + if not new_keys: + sentry_sdk.capture_message( + f"OIDC provider {self.provider} returned JWKS JSON but no keys" + ) + return keys + + keys = {key["kid"]: key for key in new_keys} + self._store_keyset(keys) + + return keys + + def get_key(self, key_id): + """ + Return a JWK for the given key ID, or None if the key can't be found + in this provider's keyset. + """ + + keyset, _ = self._get_keyset() + if key_id not in keyset: + keyset = self._refresh_keyset() + if key_id not in keyset: + self.metrics.increment( + "warehouse.oidc.get_key.error", + tags=[f"provider:{self.provider}", f"key_id:{key_id}"], + ) + return None + return PyJWK(keyset[key_id]) + + def verify(self, token): + return NotImplemented + + +class OIDCProviderServiceFactory: + def __init__(self, provider, issuer_url, service_class=OIDCProviderService): + self.provider = provider + self.issuer_url = issuer_url + self.service_class = service_class + + def __call__(self, _context, request): + cache_url = request.registry.settings["oidc.jwk_cache_url"] + metrics = request.find_service(IMetricsService, context=None) + + return self.service_class(self.provider, self.issuer_url, cache_url, metrics) + + def __eq__(self, other): + if not isinstance(other, OIDCProviderServiceFactory): + return NotImplemented + + return (self.provider, self.issuer_url, self.service_class) == ( + other.provider, + other.issuer_url, + other.service_class, + )