Skip to content

Commit

Permalink
Code revision and refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
flashguerdon committed Oct 11, 2024
1 parent 2984f05 commit b125321
Show file tree
Hide file tree
Showing 11 changed files with 416 additions and 224 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,6 @@ tests/resources/keys/*.pem
.DS_Store
.vscode
.idea

# snyk
.dccache
86 changes: 64 additions & 22 deletions fence/blueprints/login/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from fence.config import config
from fence.errors import UserError
from fence.metrics import metrics

logger = get_logger(__name__)


Expand Down Expand Up @@ -96,8 +97,11 @@ def __init__(
"OPENID_CONNECT"
].get(self.idp_name, {})
self.app = app
self.check_groups = config.get("CHECK_GROUPS", False)
self.app = app if app is not None else flask.current_app
# this attribute is only applicable to some OAuth clients
# (e.g., not all clients need read_authz_groups_from_tokens)
self.is_read_authz_groups_from_tokens_enabled = getattr(
self.client, "read_authz_groups_from_tokens", False
)

def get(self):
# Check if user granted access
Expand Down Expand Up @@ -145,9 +149,15 @@ def get(self):
if expires is None:
expires = int(time.time()) + config["REFRESH_TOKEN_EXPIRES_IN"]

# # Store refresh token in db
if self.check_groups:
self.client.store_refresh_token(flask.g.user,refresh_token,expires)
# Store refresh token in db
if self.is_read_authz_groups_from_tokens_enabled:
# Ensure flask.g.user exists to avoid a potential AttributeError
if getattr(flask.g, "user", None):
self.client.store_refresh_token(flask.g.user, refresh_token, expires)
else:
self.logger.error(
"User information is missing from flask.g; cannot store refresh token."
)

self.post_login(
user=flask.g.user,
Expand All @@ -157,22 +167,50 @@ def get(self):

return resp

# see if the refresh token is a JWT. if it is decode to get the exp. we do not care about signatures, the
# reason is that the refresh token is checked by the IDP, not us, thus we don't have the key in most circumstances
# Also check exp from introspect results
def extract_exp(self, refresh_token):
"""
Extract the expiration time (exp) from a refresh token.
This function attempts to extract the `exp` (expiration time) from a given refresh token using
three methods:
1. Using PyJWT to decode the token (without signature verification).
2. Introspecting the token (if supported by the identity provider).
3. Manually base64 decoding the token's payload (if it's a JWT).
Disclaimer:
------------
This function assumes that the refresh token is valid and does not perform any JWT validation.
For any JWT coming from an OpenID Connect (OIDC) provider, validation should be done using the
public keys provided by the IdP (from the JWKS endpoint) before using this function to extract
the expiration time (`exp`). Without validation, the token's integrity and authenticity cannot
be guaranteed, which may expose your system to security risks.
Ensure validation is handled prior to calling this function, especially in any public or
production-facing contexts.
Parameters:
------------
refresh_token: str
The JWT refresh token to extract the expiration from.
Returns:
---------
int or None:
The expiration time (exp) in seconds since the epoch, or None if extraction fails.
"""

# Method 1: PyJWT
try:
# Skipping keys since we're not verifying the signature
decoded_refresh_token = jwt.decode(
refresh_token,
options=
{
options={
"verify_aud": False,
"verify_at_hash": False,
"verify_signature": False
"verify_signature": False,
},
algorithms=["RS256", "HS512"]
algorithms=["RS256", "HS512"],
)
exp = decoded_refresh_token.get("exp")

Expand All @@ -194,9 +232,9 @@ def extract_exp(self, refresh_token):
# Method 3: Manual base64 decoding
try:
# Assuming the token is a JWT (header.payload.signature)
payload_encoded = refresh_token.split('.')[1]
payload_encoded = refresh_token.split(".")[1]
# Add necessary padding for base64 decoding
payload_encoded += '=' * (4 - len(payload_encoded) % 4)
payload_encoded += "=" * (4 - len(payload_encoded) % 4)
payload_decoded = base64.urlsafe_b64decode(payload_encoded)
payload_json = json.loads(payload_decoded)
exp = payload_json.get("exp")
Expand All @@ -212,16 +250,16 @@ def extract_exp(self, refresh_token):
def introspect_token(self, token):

try:
introspect_endpoint = self.client.get_value_from_discovery_doc("introspection_endpoint", "")
introspect_endpoint = self.client.get_value_from_discovery_doc(
"introspection_endpoint", ""
)

# Headers and payload for the introspection request
headers = {
"Content-Type": "application/x-www-form-urlencoded"
}
headers = {"Content-Type": "application/x-www-form-urlencoded"}
data = {
"token": token,
"client_id": self.client.settings.get("client_id"),
"client_secret": self.client.settings.get("client_secret")
"client_id": self.client.client_id,
"client_secret": self.client.client_secret,
}

response = requests.post(introspect_endpoint, headers=headers, data=data)
Expand All @@ -247,8 +285,12 @@ def post_login(self, user=None, token_result=None, **kwargs):
client_id=flask.session.get("client_id"),
)

if self.check_groups:
self.client.update_user_authorization(user=user,pkey_cache=None,db_session=None,idp_name=self.idp_name)
# this attribute is only applicable to some OAuth clients
# (e.g., not all clients need read_authz_groups_from_tokens)
if self.is_read_authz_groups_from_tokens_enabled:
self.client.update_user_authorization(
user=user, pkey_cache=None, db_session=None, idp_name=self.idp_name
)

if token_result:
username = token_result.get(self.username_field)
Expand Down
17 changes: 16 additions & 1 deletion fence/config-default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ DB_MIGRATION_POSTGRES_LOCK_KEY: 100
# - WARNING: Be careful changing the *_ALLOWED_SCOPES as you can break basic
# and optional functionality
# //////////////////////////////////////////////////////////////////////////////////////
CHECK_GROUPS: false

OPENID_CONNECT:
# any OIDC IDP that does not differ from the generic implementation can be
# configured without code changes
Expand All @@ -116,6 +116,21 @@ OPENID_CONNECT:
multifactor_auth_claim_info: # optional, include if you're using arborist to enforce mfa on a per-file level
claim: '' # claims field that indicates mfa, either the acr or acm claim.
values: [ "" ] # possible values that indicate mfa was used. At least one value configured here is required to be in the token
# is_authz_groups_sync_enabled: A configuration flag that determines whether the application should
# verify and synchronize user group memberships between the identity provider (IdP)
# and the local authorization system (Arborist). When enabled, the system retrieves
# the user's group information from their token issued by the IdP and compares it against
# the groups defined in the local system. Based on the comparison, the user is added to
# or removed from relevant groups in the local system to ensure their group memberships
# remain up-to-date. If this flag is disabled, no group synchronization occurs
is_authz_groups_sync_enabled: true
authz_groups_sync:
# This defines the prefix used to identify authorization groups.
group_prefix: "some_prefix"
# This flag indicates whether the audience (aud) claim in the JWT should be verified during token validation.
verify_aud: true
# This specifies the expected audience (aud) value for the JWT, ensuring that the token is intended for use with the 'fence' service.
audience: fence
# These Google values must be obtained from Google's Cloud Console
# Follow: https://developers.google.com/identity/protocols/OpenIDConnect
#
Expand Down
6 changes: 4 additions & 2 deletions fence/error_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ def get_error_response(error: Exception):
)
)


#raise error
# TODO: Issue: Error messages are obfuscated, the line below needs be
# uncommented when troubleshooting errors.
# Breaks tests if not commented out / removed. We need a fix for this.
# raise error

# don't include internal details in the public error message
# to do this, only include error messages for known http status codes
Expand Down
37 changes: 17 additions & 20 deletions fence/job/access_token_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(

self.visa_types = config.get("USERSYNC", {}).get("visa_types", {})

#introduce list on self which contains all clients that need update
# introduce list on self which contains all clients that need update
self.oidc_clients_requiring_token_refresh = []

# keep this as a special case, because RAS will not set group information configuration.
Expand All @@ -54,28 +54,24 @@ def __init__(
if "ras" not in oidc:
self.logger.error("RAS client not configured")
else:
#instead of setting self.ras_client add the RASClient to self.oidc_clients_requiring_token_refresh
ras_client = RASClient(
oidc["ras"],
HTTP_PROXY=config.get("HTTP_PROXY"),
logger=logger,
)
self.oidc_clients_requiring_token_refresh.append(ras_client)

#initialise a client for each OIDC client in oidc, which does has group information set to true and add them
# Initialise a client for each OIDC client in oidc, which does has gis_authz_groups_sync_enabled set to true and add them
# to oidc_clients_requiring_token_refresh
if config["CHECK_GROUPS"]:
for oidc_name in oidc:
if "groups" in oidc.get(oidc_name):
groups = oidc.get(oidc_name).get("groups")
if groups.get("read_group_information", False):
oidc_client = OIDCClient(
settings=oidc[oidc_name],
HTTP_PROXY=config.get("HTTP_PROXY"),
logger=logger,
idp=oidc_name
)
self.oidc_clients_requiring_token_refresh.append(oidc_client)
for oidc_name in oidc:
if oidc.get(oidc_name).get("is_authz_groups_sync_enabled", False):
oidc_client = OIDCClient(
settings=oidc[oidc_name],
HTTP_PROXY=config.get("HTTP_PROXY"),
logger=logger,
idp=oidc_name,
)
self.oidc_clients_requiring_token_refresh.append(oidc_client)

async def update_tokens(self, db_session):
"""
Expand All @@ -89,7 +85,7 @@ async def update_tokens(self, db_session):
"""
start_time = time.time()
#Change this line to reflect we are refreshing tokens, not just visas
# Change this line to reflect we are refreshing tokens, not just visas
self.logger.info("Initializing Visa Update and Token refreshing Cronjob . . .")
self.logger.info("Total concurrency size: {}".format(self.concurrency))
self.logger.info("Total thread pool size: {}".format(self.thread_pool_size))
Expand Down Expand Up @@ -181,33 +177,34 @@ async def updater(self, name, updater_queue, db_session):
pkey_cache=self.pkey_cache,
db_session=db_session,
)

else:
self.logger.debug(
f"Updater {name} NOT updating authorization for "
f"user {user.username} because no client was found for IdP: {user.identity_provider}"
)

# Only mark the task as done if processing succeeded
updater_queue.task_done()

except Exception as exc:
self.logger.error(
f"Updater {name} could not update authorization "
f"for {user.username if user else 'unknown user'}. Error: {exc}. Continuing."
)
# Still mark the task as done even if there was an exception
# Ensure task is marked done if exception occurs
updater_queue.task_done()

def _pick_client(self, user):
"""
Select OIDC client based on identity provider.
"""
# change this logic to return any client which is in self.oidc_clients_requiring_token_refresh (check against "name")
self.logger.info(f"Selecting client for user {user.username}")
client = None
for oidc_client in self.oidc_clients_requiring_token_refresh:
if getattr(user.identity_provider, "name") == oidc_client.idp:
self.logger.info(f"Picked client: {oidc_client.idp} for user {user.username}")
self.logger.info(
f"Picked client: {oidc_client.idp} for user {user.username}"
)
client = oidc_client
break
if not client:
Expand Down
Loading

0 comments on commit b125321

Please sign in to comment.