diff --git a/sos/deviceauth.py b/sos/deviceauth.py new file mode 100644 index 0000000000..108f46b837 --- /dev/null +++ b/sos/deviceauth.py @@ -0,0 +1,291 @@ +# Copyright (C) 2023 Red Hat, Inc., Jose Castillo + +# This file is part of the sos project: https://github.com/sosreport/sos +# +# This copyrighted material is made available to anyone wishing to use, +# modify, copy, or redistribute it subject to the terms and conditions of +# version 2 of the GNU General Public License. +# +# See the LICENSE file in the source distribution for further information. + +import logging +import requests +import time +from datetime import datetime, timedelta +import os +from sos.keyring import Key, KeyNotFoundError +import shutil + +DEVICE_AUTH_CLIENT_ID = "sos-tools" +GRANT_TYPE_DEVICE_CODE = "urn:ietf:params:oauth:grant-type:device_code" +DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ" + +logger = logging.getLogger("sos") + + +def try_read_refresh_token(): + """Try to read locally stored refresh token + + Returns: + str: Returns ODIC refresh token if found otherwise None + """ + try: + RHELKey = Key() + key_refresh_token = RHELKey.search('sos-tools_refresh_token') + key_username = RHELKey.search('sos-tools-user') + except KeyNotFoundError: + logger.info("Refresh token does not exist in keyring or is expired.") + return None + except Exception as e: + logger.error("Error encoutered while accessing keyring. {}".format(e)) + return None + if key_username.data.decode() != os.getlogin(): + return None + + return key_refresh_token.data.decode() + + +class AuthClass: + """ + Device Authorization Class + """ + + def __init__(self): + + self.client_identifier_url = "https://sso.redhat.com/auth/"\ + "realms/redhat-external/protocol/openid-connect/auth/device" + self.token_endpoint = "https://sso.redhat.com/auth/realms/"\ + "redhat-external/protocol/openid-connect/token" + self._access_token = None + self._access_expires_at = None + self._refresh_token = None + self._refresh_expires_at = None + self._refresh_expires_in = None + self._user_verification_url = None + self.__device_code = None + self.RHELKey = Key() + + # Lets check first if we have keyctl installed so we can + # store the token in the keyring + if not shutil.which("keyctl"): + raise Exception("keyctl tool is not installed" + " and is required to store auth tokens.") + + self._use_device_code_grant() + + def _use_device_code_grant(self): + """ + Start the device auth flow. First check for the refresh token stored in + the session keyring. If they are not stored + or are invalid, request new device code. If the stored refresh token is + valid, use it to get new access_token + + """ + stored_refresh_token = try_read_refresh_token() + + if not stored_refresh_token: + self._request_device_code() + print( + "Please visit the following URL to authenticate this" + f" device {self._verification_uri_complete}") + self.poll_for_auth_completion() + else: + self._use_refresh_token_grant(stored_refresh_token) + + def _request_device_code(self): + """ + Initialize new Device Authorization Grant attempt by + requesting a new device code. + + """ + data = "client_id={}".format(DEVICE_AUTH_CLIENT_ID) + headers = {'content-type': 'application/x-www-form-urlencoded'} + try: + res = requests.post( + self.client_identifier_url, + data=data, + headers=headers) + res.raise_for_status() + response = res.json() + self._user_code = response.get("user_code") + self._verification_uri = response.get("verification_uri") + self._interval = response.get("interval") + self.__device_code = response.get("device_code") + self._verification_uri_complete = response.get( + "verification_uri_complete") + except requests.HTTPError as e: + raise e + except Exception as e: + raise e + + def poll_for_auth_completion(self): + """ + Continuously poll OIDC token endpoint until the user is successfully + authenticated or an error occurs. + + """ + token_data = {'grant_type': GRANT_TYPE_DEVICE_CODE, + 'client_id': DEVICE_AUTH_CLIENT_ID, + 'device_code': self.__device_code} + + while self._access_token is None: + time.sleep(self._interval) + try: + check_auth_completion = requests.post(self.token_endpoint, + data=token_data) + + status_code = check_auth_completion.status_code + + if status_code == 200: + logger.info("The SSO authentication is successful") + self._set_token_data(check_auth_completion.json()) + if status_code not in [200, 400]: + raise Exception(status_code, check_auth_completion.text) + if status_code == 400 and \ + check_auth_completion.json()['error'] not in \ + ("authorization_pending", "slow_down"): + raise Exception(status_code, check_auth_completion.text) + except Exception as e: + raise e + + def _set_token_data(self, token_data): + """ + Set the class attributes as per the input token_data received and + persist it in the local keyring to avoid + visting the browser frequently. + :param token_data: Token data containing access_token, refresh_token + and their expiry etc. + """ + self._access_token = token_data.get("access_token") + self._access_expires_at = datetime.utcnow( + ) + timedelta(seconds=token_data.get("expires_in")) + self._refresh_token = token_data.get("refresh_token") + self._refresh_expires_in = token_data.get("refresh_expires_in") + if self._refresh_expires_in == 0: + self._refresh_expires_at = datetime.max + else: + self._refresh_expires_at = datetime.utcnow( + ) + timedelta(seconds=self._refresh_expires_in) + self.persist_refresh_token() + + def get_access_token(self): + """ + Get the valid access_token at any given time. + :return: Access_token + :rtype: string + """ + if self.is_access_token_valid(): + return self._access_token + + if self.grant_type == "client_credentials": + self._use_client_credentials_grant() + return self._access_token + + elif self.grant_type == "device_auth": + if self.is_refresh_token_valid(): + self._use_refresh_token_grant() + return self._access_token + else: + return self.request_new_device_code() + + def is_access_token_valid(self): + """ + Check the validity of access_token. We are considering it invalid 180 + sec. prior to it's exact expiry time. + :return: True/False + + """ + return self._access_token and self._access_expires_at and \ + self._access_expires_at - timedelta(seconds=180) > \ + datetime.utcnow() + + def is_refresh_token_valid(self): + """ + Check the validity of refresh_token. We are considering it invalid + 180 sec. prior to it's exact expiry time. + + :return: True/False + + """ + return self._refresh_token and self._refresh_expires_at and \ + self._refresh_expires_at - timedelta(seconds=180) > \ + datetime.utcnow() + + def _use_refresh_token_grant(self, refresh_token=None): + """ + Fetch the new access_token and refresh_token using the existing + refresh_token and persist it. + :param refresh_token: optional param for refresh_token + + """ + refresh_token_data = {'client_id': DEVICE_AUTH_CLIENT_ID, + 'grant_type': 'refresh_token', + 'refresh_token': self._refresh_token if not + refresh_token else refresh_token} + + refresh_token_res = requests.post(self.token_endpoint, + data=refresh_token_data) + + if refresh_token_res.status_code == 200: + self._set_token_data(refresh_token_res.json()) + + elif refresh_token_res.status_code == 400 and 'invalid' in\ + refresh_token_res.json()['error']: + logger.warning("Problem while fetching the new tokens from refresh" + " token grant - {} {}." + " New Device code will be requested !".format + (refresh_token_res.status_code, + refresh_token_res.json()['error'])) + self.request_new_device_code() + else: + raise Exception( + "Something went wrong while using the " + "Refresh token grant for fetching tokens.") + + def request_new_device_code(self): + """Initialize new Device Authorization Grant + attempt by requesting a new device code. + """ + self._use_device_code_grant() + + def persist_refresh_token(self): + """Persist current refresh token in keyring using keyctl + + Returns: + bool: True if refresh token was successfully + persisted, otherwise False + + """ + if self.is_refresh_token_valid(): + try: + key_refresh_token = self.RHELKey.search( + 'sos-tools_refresh_token') + key_refresh_token.update(self._refresh_token) + except KeyNotFoundError: + key_refresh_token = self.RHELKey.add( + 'sos-tools_refresh_token', self._refresh_token) + except Exception as e: + logger.error( + "Keyctl error encountered while reading " + "keyring {}".format(e)) + return False + + try: + key_username = self.RHELKey.search('sos-tools-user') + key_username.update(os.getlogin()) + except KeyNotFoundError: + key_username = self.RHELKey.add( + 'sos-tools-user', os.getlogin()) + except Exception as e: + logger.error( + "Keyctl error encountered while reading " + "keyring {}".format(e)) + return False + + key_refresh_token.set_timeout(self._refresh_expires_in - 300) + key_username.set_timeout(self._refresh_expires_in - 300) + + return True + else: + logger.info("Cannot save invalid refresh token in keyring") + return False diff --git a/sos/keyring.py b/sos/keyring.py new file mode 100644 index 0000000000..5bcebe447b --- /dev/null +++ b/sos/keyring.py @@ -0,0 +1,453 @@ +# Copyright (C) 2023 Red Hat, Inc., Jose Castillo + +# This file is part of the sos project: https://github.com/sosreport/sos +# +# This copyrighted material is made available to anyone wishing to use, +# modify, copy, or redistribute it subject to the terms and conditions of +# version 2 of the GNU General Public License. +# +# See the LICENSE file in the source distribution for further information. + +import base64 +import subprocess +from itertools import cycle +import os + +DEFAULT_KEYRING = "@s" +DEFAULT_KEYTYPE = "user" + + +def __xor(salt, string): + """ + A simple utility function to obfuscate the value of the key when a user + elects to save the token in the keyring (via keyctl). This merely provides + a convenience against an accidental display of the key. + (eg. keyctl pipe keyid) + + :param salt: Salt value for obfuscation + :param string: The string which needs to be obfuscated + + :return: Obfuscated string + """ + str_ary = [] + for x, y in zip(string, cycle(salt)): + str_ary.append(chr(ord(x) ^ ord(y))) + return ''.join(str_ary) + + +def value_encode(value, salt): + """ + This function will obfuscate the value of a key or any other value. + + :param value: Value to be obfuscated + :param salt: Salt value on which obfuscation will happen + + :return: The obfuscated string. + """ + if value and salt: + val = __xor(salt, value) + if isinstance(val, str): + val = val.encode() + return base64.urlsafe_b64encode(val) + else: + return None + + +def value_decode(value, salt): + """ + This function will de-obfuscate a value of a key or any other value + previously obfuscated with value_encode(). + + :param value: Value to be de-obfuscated + :param salt: Salt value on which obfuscation happened + + :return: The de-obfuscated string + """ + if value and salt: + if isinstance(value, str): + value = value.encode() + value = base64.urlsafe_b64decode(value) + if not isinstance(value, str): + value = value.decode() + return __xor(salt, value) + else: + return None + + +class KeyctlWrapperException(Exception): + @staticmethod + def _get_key_description(keyid=None, keyname=None): + if keyid is not None and keyname is not None: + keydesc = "({} / '{}')".format(keyid, keyname) + elif keyid is not None: + keydesc = "({})".format(keyid) + elif keyname is not None: + keydesc = "('{}')".format(keyname) + else: + keydesc = '(undefined)' + + return keydesc + + +class KeyNotFoundError(KeyctlWrapperException): + def __init__(self, message=None, keyid=None, keyname=None): + if message is None: + message = 'Key {} does not exists in the kernel keyring.'.format( + self._get_key_description(keyid, keyname) + ) + super(KeyNotFoundError, self).__init__(message) + + +class KeyAlreadyExistError(KeyctlWrapperException): + def __init__(self, message=None, keyid=None, keyname=None): + if message is None: + message = 'Key {} already exists in the kernel keyring.'.format( + self._get_key_description(keyid, keyname) + ) + super(KeyAlreadyExistError, self).__init__(message) + + +class KeyctlOperationError(KeyctlWrapperException): + def __init(self, message=None, keyid=None, keyname=None, errmsg=None): + if message is None: + message = 'Operation on key {} failed. ErrorMsg:{}'.format( + self._get_key_description(keyid, keyname), errmsg + ) + + +class KeyctlWrapper(object): + def __init__(self, keyring, keytype): + self.keyring = keyring + self.keytype = keytype + + if self.keyring != DEFAULT_KEYRING: + ret, out, err = self._system( + ["keyctl", "show", str(self.keyring)], check=False + ) + # If keyring does not exist, create the keyring and set permission + if ret != 0: + self.create_keyring(new_keyring_name=self.keyring) + keyring_id = self.get_custom_keyring_id() + self.set_perm(keyring_id, "0x3f3f3f3f") + + def _system(self, args, data=None, check=True): + try: + p = subprocess.Popen( + args, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + stdin=subprocess.PIPE, + bufsize=4096, + ) + except OSError as e: + raise OSError( + "Command '{}' execution failed. ErrMsg:{}" + .format(" ".join(args), e) + ) + + if data is None: + (out, err) = p.communicate() + else: + (out, err) = p.communicate(input=bytes(data, "utf-8")) + + ret = p.returncode + + if not check: + return ret, out, err + elif ret == 0: + return out + else: + raise KeyctlOperationError( + errmsg="({}){} {}".format(ret, err, out)) + + def get_all_key_ids(self): + out = self._system(["keyctl", "rlist", self.keyring]) + ids = out.split() + key_ids = [int(x) for x in ids] + return key_ids + + def get_id_from_name(self, name): + ret, out, err = self._system( + ["keyctl", "search", self.keyring, self.keytype, name], check=False + ) + + if ret != 0: + raise KeyNotFoundError(keyname=name) + + keyid = int(out.strip()) + + return keyid + + def get_name_from_id(self, keyid): + ret, out, err = self._system(["keyctl", + "rdescribe", str(keyid)], check=False) + + if ret != 0: + raise KeyNotFoundError(keyid=keyid) + + name = b";".join(out.split(b";")[4:]) + + return name.rstrip(b"\n") + + def get_data_from_id(self, keyid, mode="raw"): + if mode.lower() == "raw": + kmode = "pipe" + elif mode.lower() == "hex": + kmode = "read" + else: + raise AttributeError("mode must be one of ['raw', 'hex'].") + + ret, out, err = self._system(["keyctl", + kmode, str(keyid)], check=False) + + if ret == 1: + raise KeyNotFoundError(keyid=keyid) + + if mode == "raw": + return out + else: + h = b"".join(out.splitlines()[1:]) + return h.replace(b" ", b"") + + def get_custom_keyring_id(self): + ret, out, err = self._system( + ["keyctl", "describe", str(self.keyring)], check=False + ) + + if ret != 0: + raise KeyNotFoundError(keyname=self.keyring) + + return out.decode().split(":")[0] + + def add_key(self, name, data): + try: + keyid = self.get_id_from_name(name) + raise KeyAlreadyExistError(keyid=keyid, keyname=name) + except KeyNotFoundError: + pass + + out = self._system(["keyctl", "padd", + self.keytype, name, self.keyring], data) + keyid = int(out) + + return keyid + + def update_key(self, keyid, data): + ret, out, err = self._system( + ["keyctl", "pupdate", str(keyid)], data, check=False + ) + + if ret == 1: + raise KeyNotFoundError(keyid=keyid) + elif ret != 0: + raise KeyctlOperationError(keyid=keyid, + errmsg="({}){}".format(ret, err)) + + def remove_key(self, keyid: int): + ret, out, err = self._system(["keyctl", + "revoke", str(keyid)], check=False) + if ret == 1: + raise KeyNotFoundError(keyid=keyid) + elif ret != 0: + raise KeyctlOperationError(keyid=keyid, + errmsg="({}){}".format(ret, err)) + + self._system(["keyctl", "unlink", str(keyid), self.keyring]) + + def set_timeout(self, keyid, time): + ret, out, err = self._system( + ["keyctl", "timeout", str(keyid), str(time)], check=False + ) + + if ret == 1: + raise KeyNotFoundError(keyid=keyid) + elif ret != 0: + raise KeyctlOperationError(keyid=keyid, + errmsg="({}){}".format(ret, err)) + + def create_keyring(self, new_keyring_name, keyring=DEFAULT_KEYRING): + new_keyring_name = new_keyring_name.split(":")[-1] + ret, out, err = self._system( + ["keyctl", "newring", str(new_keyring_name), str(keyring)], + check=False + ) + + if ret != 0: + raise KeyctlOperationError(errmsg="({}){}".format(ret, err)) + + def set_perm(self, keyid, mask): + ret, out, err = self._system( + ["keyctl", "setperm", str(keyid), str(mask)], check=False + ) + + if ret == 1: + raise KeyNotFoundError(keyid=keyid) + elif ret != 0: + raise KeyctlOperationError(keyid=keyid, + errmsg="({}){}".format(ret, err)) + + def clear_keyring(self): + self._system(["keyctl", "clear", self.keyring]) + + +class Key(object): + """ + Represents a key in a keyring. + Keys and their related operations can be managed. + """ + def __init__(self, keyid=None, keyring=DEFAULT_KEYRING, + keytype=DEFAULT_KEYTYPE): + """ + Initializes a Key object. If keyid is provided, + it loads the key with the specified ID. + + :param keyid: The ID of the key to be initialized. Defaults to None. + :param keyring: The name of the keyring to be used. + Defaults to DEFAULT_KEYRING i.e. @s (Session keyring) + :param keytype: The type of the key. + Defaults to the DEFAULT_KEYTYPE i.e. user + """ + self.id = None + self.name = None + self.data = None + self.data_hex = None + self._keyctl = KeyctlWrapper(keyring, keytype) + + if keyid is not None: + self._load_key(keyid) + + def __repr__(self): + """ + Returns string representation of the Key object. + """ + return "<{}({}, '{}', '{}')>".format( + self.__class__.__name__, self.id, self.name, self.data + ) + + def _load_key(self, keyid): + """ + Loads key information based on the provided key ID. + """ + self.id = keyid + self.name = self._keyctl.get_name_from_id(keyid) + self.data = self._keyctl.get_data_from_id(keyid) + self.data_hex = self._keyctl.get_data_from_id(keyid, "hex") + + def list(self): + """ + Returns a list of Key objects representing all keys in the keyring. + + :return: list object containing Key objects. + :rtype: list + """ + keyids = self._keyctl.get_all_key_ids() + + keylist = [] + for keyid in keyids: + key = Key(keyid, self._keyctl.keyring, self._keyctl.keytype) + keylist.append(key) + + return keylist + + def search(self, name): + """ + Searches for a key by name and returns a Key object + representing it. + If a key is not found by the given name, + KeyNotFoundError exception is raised. + + :param name: Name of the key to be searched in string format + :return: Key object + """ + key = Key(keyring=self._keyctl.keyring, keytype=self._keyctl.keytype) + key.id = key._keyctl.get_id_from_name(name) + key.name = name + key.data = key._keyctl.get_data_from_id(key.id) + key.data_hex = key._keyctl.get_data_from_id(key.id, "hex") + return key + + def add(self, name, data): + """ + Adds a new key with the specified name and + obfuscated data to the keyring. + If a key already exists with the provided name, + KeyAlreadyExistsError exception is raised. + + :param name: Name of the key to be created in string format + :param data: Key data in string format + :return: Key object + """ + data = value_encode(data, os.getlogin()).decode() + keyid = self._keyctl.add_key(name, data) + key = Key(keyid, keyring=self._keyctl.keyring, + keytype=self._keyctl.keytype) + return key + + def get_value(self): + """ + Returns the de-obfuscated data of the key. + + :return: De-obfuscated data in string format + """ + return value_decode(self.data, os.getlogin()) + + def set_timeout(self, time): + """ + Sets a timeout for the key's expiration time. + + :param time: Expiration time in seconds denoted by an integer value + """ + self._keyctl.set_timeout(self.id, time) + + def update(self, data): + """ + Updates the data associated with the key. + + :param data: Key data in string format + """ + data = value_encode(data, os.getlogin()).decode() + self._keyctl.update_key(self.id, data) + self._load_key(self.id) + + def set_perm(self, mask): + """ + Sets the permission mask for the key. + + :param mask: The permission mask in string format. + """ + self._keyctl.set_perm(self.id, mask) + + def delete(self): + """ + Deletes the key from the keyring. + """ + self._keyctl.remove_key(self.id) + + def update_or_add(self, name, data, time=0): + """ + Searches for a key by name and updates its data. + If the key does not exist, it adds a new key. + + :param name: Name of the key in string format + :param data: Key data in string format + :param time: Expiration time in seconds denoted by + an integer value. + Default value is zero which denotes no expiration. + + :return: Key object after the updation/creation + """ + try: + key = self.search(name) + if key: + key.update(data) + except KeyNotFoundError: + key = self.add(name, data) + except Exception as e: + raise e + + key.set_perm('0x3f3f3f3f') + + if time: + key.set_timeout(time) + + return key diff --git a/sos/policies/distros/redhat.py b/sos/policies/distros/redhat.py index 15241e2782..9adb9b0883 100644 --- a/sos/policies/distros/redhat.py +++ b/sos/policies/distros/redhat.py @@ -12,6 +12,7 @@ import os import sys import re +from sos.deviceauth import AuthClass from sos.report.plugins import RedHatPlugin from sos.presets.redhat import (RHEL_PRESETS, ATOMIC_PRESETS, RHV, RHEL, @@ -223,6 +224,7 @@ class RHELPolicy(RedHatPolicy): """ + disclaimer_text + "%(vendor_text)s\n") _upload_url = RH_SFTP_HOST _upload_method = 'post' + _device_token = None def __init__(self, sysroot=None, init=None, probe_runtime=True, remote_exec=None): @@ -260,25 +262,25 @@ def check(cls, remote=''): return False def prompt_for_upload_user(self): - if self.commons['cmdlineopts'].upload_user: - return - # Not using the default, so don't call this prompt for RHCP - if self.commons['cmdlineopts'].upload_url: - super(RHELPolicy, self).prompt_for_upload_user() + self.ui_log.info( + _("The option --upload-user has been deprecated in favour" + " of device authorization") + ) + if self.case_id: + RHELAuth = AuthClass() + self._device_token = RHELAuth.get_access_token() return - if not self.get_upload_user(): - if self.case_id: - self.upload_user = input(_( - "Enter your Red Hat Customer Portal username for " - "uploading [empty for anonymous SFTP]: ") - ) - else: # no case id provided => failover to SFTP - self.upload_url = RH_SFTP_HOST - self.ui_log.info("No case id provided, uploading to SFTP") - self.upload_user = input(_( - "Enter your Red Hat Customer Portal username for " - "uploading to SFTP [empty for anonymous]: ") - ) + else: # no case id provided => failover to SFTP + self.upload_url = RH_SFTP_HOST + self.ui_log.info("No case id provided, uploading to SFTP") + self.upload_user = input(_( + "Enter your Red Hat Customer Portal username for " + "uploading to SFTP [empty for anonymous]: ") + ) + + def prompt_for_upload_password(self): + # With OIDC we don't ask for user/pass anymore + return def get_upload_url(self): if self.upload_url: @@ -291,6 +293,25 @@ def get_upload_url(self): rh_case_api = "/support/v1/cases/%s/attachments" return RH_API_HOST + rh_case_api % self.case_id + def _get_upload_https_auth(self): + str_auth = "Bearer {}".format(self._device_token) + return {'Authorization': str_auth} + + def _upload_https_post(self, archive, verify=True): + """If upload_https() needs to use requests.post(), use this method. + + Policies should override this method instead of the base upload_https() + + :param archive: The open archive file object + """ + files = { + 'file': (archive.name.split('/')[-1], archive, + self._get_upload_headers()) + } + return requests.post(self.get_upload_url(), files=files, + headers=self._get_upload_https_auth(), + verify=verify) + def _get_upload_headers(self): if self.get_upload_url().startswith(RH_API_HOST): return {'isPrivate': 'false', 'cache-control': 'no-cache'} @@ -375,7 +396,8 @@ def upload_archive(self, archive): """ try: if self.upload_url and self.upload_url.startswith(RH_API_HOST) and\ - (not self.get_upload_user() or not self.get_upload_password()): + (not self.get_upload_user() or + not self.get_upload_password()): self.upload_url = RH_SFTP_HOST uploaded = super(RHELPolicy, self).upload_archive(archive) except Exception: @@ -466,9 +488,8 @@ def check(cls, remote=''): if not os.path.exists(host_release): return False try: - with open(host_release, 'r') as afile: - for line in afile.read().splitlines(): - atomic |= ATOMIC_RELEASE_STR in line + for line in open(host_release, "r").read().splitlines(): + atomic |= ATOMIC_RELEASE_STR in line except IOError: pass return atomic @@ -552,9 +573,8 @@ def check(cls, remote=''): return coreos host_release = os.environ[ENV_HOST_SYSROOT] + OS_RELEASE try: - with open(host_release, 'r') as hfile: - for line in hfile.read().splitlines(): - coreos |= 'Red Hat Enterprise Linux CoreOS' in line + for line in open(host_release, 'r').read().splitlines(): + coreos |= 'Red Hat Enterprise Linux CoreOS' in line except IOError: pass return coreos