Skip to content

Commit

Permalink
Dev: utils: Avoid hardcoding the ssh key type as RSA (#1600)
Browse files Browse the repository at this point in the history
## Problem
- The join process will fail if the existing key type is not RSA. See:
#1504 (comment)
- If the init and join nodes are already set up for passwordless access
using an ed25519 key, an RSA key pair is still generated during the
init/join process.

## Changes include:
- Avoid hardcoding the ssh key type as RSA
- Introduced a new function `ssh_key.fetch_public_key_list` to fetch
public keys from local or remote, return as public key path list or
public key content list
- In KeyFileManager, use class variable to store the key type instead of
hardcoding it as RSA
- Improve shell script in generate_ssh_key_pair_on_remote to avoid
hardcoding the key type as RSA
- Remove unused code
  • Loading branch information
liangxin1300 authored Nov 25, 2024
2 parents 9b1d80f + 271fd8a commit ae21432
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 146 deletions.
121 changes: 47 additions & 74 deletions crmsh/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,7 +888,7 @@ def _init_ssh_on_remote_nodes(
for i, (remote_user, node) in enumerate(user_node_list):
utils.ssh_copy_id(local_user, remote_user, node)
# After this, login to remote_node is passwordless
public_key_list.append(swap_public_ssh_key(node, local_user, remote_user, local_user, remote_user, add=True))
public_key_list.append(swap_public_ssh_key(node, local_user, remote_user, local_user, remote_user))
if len(user_node_list) > 1:
shell = sh.LocalShell()
shell_script = _merge_line_into_file('~/.ssh/authorized_keys', public_key_list).encode('utf-8')
Expand Down Expand Up @@ -956,18 +956,6 @@ def _fetch_core_hosts(shell: sh.ClusterShell, remote_host) -> typing.Tuple[typin
return user_list, host_list


def key_files(user):
"""
Find home directory for user and return key files with abspath
"""
keyfile_dict = {}
home_dir = userdir.gethomedir(user)
keyfile_dict['private'] = "{}/.ssh/id_rsa".format(home_dir)
keyfile_dict['public'] = "{}/.ssh/id_rsa.pub".format(home_dir)
keyfile_dict['authorized'] = "{}/.ssh/authorized_keys".format(home_dir)
return keyfile_dict


def is_nologin(user, remote=None):
"""
Check if user's shell is nologin
Expand Down Expand Up @@ -1003,10 +991,8 @@ def change_user_shell(user, remote=None):

def configure_ssh_key(user):
"""
Configure ssh rsa key on local or remote
If <home_dir>/.ssh/id_rsa not exist, generate a new one
Add <home_dir>/.ssh/id_rsa.pub to <home_dir>/.ssh/authorized_keys anyway, make sure itself authorized
Configure ssh key for user, generate a new key pair if needed,
and add the public key to authorized_keys
"""
change_user_shell(user)
shell = sh.LocalShell()
Expand All @@ -1027,25 +1013,50 @@ def generate_ssh_key_pair_on_remote(
shell = sh.LocalShell()
# pass cmd through stdin rather than as arguments. It seems sudo has its own argument parsing mechanics,
# which breaks shell expansion used in cmd
cmd = '''
[ -f ~/.ssh/id_rsa ] || ssh-keygen -q -t rsa -f ~/.ssh/id_rsa -C "Cluster internal on $(hostname)" -N ''
[ -f ~/.ssh/id_rsa.pub ] || ssh-keygen -y -f ~/.ssh/id_rsa > ~/.ssh/id_rsa.pub
generate_key_script = f'''
key_types=({ ' '.join(ssh_key.KeyFileManager.KNOWN_KEY_TYPES) })
for key_type in "${{key_types[@]}}"; do
priv_key_file=~/.ssh/id_${{key_type}}
if [ -f "$priv_key_file" ]; then
pub_key_file=$priv_key_file.pub
break
fi
done
if [ -z "$pub_key_file" ]; then
key_type={ssh_key.KeyFileManager.DEFAULT_KEY_TYPE}
priv_key_file=~/.ssh/id_${{key_type}}
ssh-keygen -q -t $key_type -f $priv_key_file -C "Cluster internal on $(hostname)" -N ''
pub_key_file=$priv_key_file.pub
fi
[ -f "$pub_key_file" ] || ssh-keygen -y -f $priv_key_file > $pub_key_file
'''
result = shell.su_subprocess_run(
local_sudoer,
'ssh {} {}@{} sudo -H -u {} /bin/sh'.format(constants.SSH_OPTION, remote_sudoer, remote_host, remote_user),
input=cmd.encode('utf-8'),
input=generate_key_script.encode('utf-8'),
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
if result.returncode != 0:
raise ValueError(codecs.decode(result.stdout, 'utf-8', 'replace'))

cmd = 'cat ~/.ssh/id_rsa.pub'
fetch_key_script = f'''
key_types=({ ' '.join(ssh_key.KeyFileManager.KNOWN_KEY_TYPES) })
for key_type in "${{key_types[@]}}"; do
priv_key_file=~/.ssh/id_${{key_type}}
if [ -f "$priv_key_file" ]; then
pub_key_file=$priv_key_file.pub
cat $pub_key_file
break
fi
done
'''
result = shell.su_subprocess_run(
local_sudoer,
'ssh {} {}@{} sudo -H -u {} /bin/sh'.format(constants.SSH_OPTION, remote_sudoer, remote_host, remote_user),
input=cmd.encode('utf-8'),
input=fetch_key_script.encode('utf-8'),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
Expand All @@ -1058,9 +1069,7 @@ def export_ssh_key_non_interactive(local_user_to_export, remote_user_to_swap, re
"""Copy ssh key from local to remote's authorized_keys. Require a configured non-interactive ssh authentication."""
# ssh-copy-id will prompt for the password of the destination user
# this is unwanted, so we write to the authorised_keys file ourselve
# cmd = "ssh-copy-id -i ~{}/.ssh/id_rsa.pub {}@{}".format(local_user, remote_user_to_access, remote_node)
with open(os.path.expanduser('~{}/.ssh/id_rsa.pub'.format(local_user_to_export)), 'r', encoding='utf-8') as f:
public_key = f.read()
public_key = ssh_key.fetch_public_key_content_list(None, local_user_to_export)[0]
# FIXME: prevent duplicated entries in authorized_keys
cmd = '''mkdir -p ~{user}/.ssh && chown {user} ~{user}/.ssh && chmod 0700 ~{user}/.ssh && cat >> ~{user}/.ssh/authorized_keys << "EOF"
{key}
Expand All @@ -1079,17 +1088,6 @@ def export_ssh_key_non_interactive(local_user_to_export, remote_user_to_swap, re
))


def import_ssh_key(local_user, remote_user, local_sudoer, remote_node, remote_sudoer):
"Copy ssh key from remote to local authorized_keys"
remote_key_content = remote_public_key_from(remote_user, local_sudoer, remote_node, remote_sudoer)
_, _, local_authorized_file = key_files(local_user).values()
if not utils.check_text_included(remote_key_content, local_authorized_file, remote=None):
sh.LocalShell().get_stdout_or_raise_error(
local_user,
"sed -i '$a {}' '{}'".format(remote_key_content, local_authorized_file),
)


def init_csync2():
host_list = _context.node_list_in_cluster

Expand Down Expand Up @@ -1186,7 +1184,7 @@ def init_qnetd_remote():
Triggered by join_cluster, this function adds the joining node's key to the qnetd's authorized_keys
"""
local_user, remote_user, join_node = _select_user_pair_for_ssh_for_secondary_components(_context.cluster_node)
join_node_key_content = remote_public_key_from(remote_user, local_user, join_node, remote_user)
join_node_key_content = ssh_key.fetch_public_key_content_list(join_node, remote_user)[0]
qnetd_host = corosync.get_value("quorum.device.net.host")
_, qnetd_user, qnetd_host = _select_user_pair_for_ssh_for_secondary_components(qnetd_host)
authorized_key_manager = ssh_key.AuthorizedKeyManager(sh.cluster_shell())
Expand Down Expand Up @@ -1531,7 +1529,7 @@ def _setup_passwordless_ssh_for_qnetd(cluster_node_list: typing.List[str]):
if node == utils.this_node():
continue
local_user, remote_user, node = _select_user_pair_for_ssh_for_secondary_components(node)
remote_key_content = remote_public_key_from(remote_user, local_user, node, remote_user)
remote_key_content = ssh_key.fetch_public_key_content_list(node, remote_user)[0]
in_memory_key = ssh_key.InMemoryPublicKey(remote_key_content)
ssh_key.AuthorizedKeyManager(cluster_shell).add(qnetd_addr, qnetd_user, in_memory_key)

Expand Down Expand Up @@ -1612,7 +1610,7 @@ def join_ssh_impl(local_user, seed_host, seed_user, ssh_public_keys: typing.List
msg += '\nOr, run "{}".'.format(' '.join(args))
raise ValueError(msg)
# After this, login to remote_node is passwordless
swap_public_ssh_key(seed_host, local_user, seed_user, local_user, seed_user, add=True)
swap_public_ssh_key(seed_host, local_user, seed_user, local_user, seed_user)
ssh_shell = sh.SSHShell(local_shell, local_user)
if seed_user != 'root' and 0 != ssh_shell.subprocess_run_without_input(
seed_host, seed_user, 'sudo true',
Expand Down Expand Up @@ -1670,8 +1668,7 @@ def swap_public_ssh_key(
local_user_to_swap,
remote_user_to_swap,
local_sudoer,
remote_sudoer,
add=False,
remote_sudoer
):
"""
Swap public ssh key between remote_node and local
Expand All @@ -1680,35 +1677,11 @@ def swap_public_ssh_key(
if utils.check_ssh_passwd_need(local_user_to_swap, remote_user_to_swap, remote_node):
export_ssh_key_non_interactive(local_user_to_swap, remote_user_to_swap, remote_node, local_sudoer, remote_sudoer)

if add:
public_key = generate_ssh_key_pair_on_remote(local_sudoer, remote_node, remote_sudoer, remote_user_to_swap)
ssh_key.AuthorizedKeyManager(sh.SSHShell(sh.LocalShell(), local_user_to_swap)).add(
None, local_user_to_swap, ssh_key.InMemoryPublicKey(public_key),
)
return public_key
else:
try:
import_ssh_key(local_user_to_swap, remote_user_to_swap, local_sudoer, remote_node, remote_sudoer)
except ValueError as e:
logger.warning(e)


def remote_public_key_from(remote_user, local_sudoer, remote_node, remote_sudoer):
"Get the id_rsa.pub from the remote node"
cmd = 'cat ~/.ssh/id_rsa.pub'
result = sh.LocalShell().su_subprocess_run(
local_sudoer,
'ssh {} {}@{} sudo -H -u {} /bin/sh'.format(constants.SSH_OPTION, remote_sudoer, remote_node, remote_user),
input=cmd.encode('utf-8'),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
if result.returncode != 0:
utils.fatal("Can't get the remote id_rsa.pub from {}: {}".format(
remote_node,
codecs.decode(result.stderr, 'utf-8', 'replace'),
))
return result.stdout.decode('utf-8')
public_key = generate_ssh_key_pair_on_remote(local_sudoer, remote_node, remote_sudoer, remote_user_to_swap)
ssh_key.AuthorizedKeyManager(sh.SSHShell(sh.LocalShell(), local_user_to_swap)).add(
None, local_user_to_swap, ssh_key.InMemoryPublicKey(public_key),
)
return public_key


def join_csync2(seed_host, remote_user):
Expand Down Expand Up @@ -1842,7 +1815,7 @@ def setup_passwordless_with_other_nodes(init_node, remote_user):
swap_public_ssh_key(node, local_user, remote_user_to_swap, local_user, remote_privileged_user)
if local_user != 'hacluster':
change_user_shell('hacluster', node)
swap_public_ssh_key(node, 'hacluster', 'hacluster', local_user, remote_privileged_user, add=True)
swap_public_ssh_key(node, 'hacluster', 'hacluster', local_user, remote_privileged_user)
if local_user != 'hacluster':
swap_key_for_hacluster(cluster_nodes_list)
else:
Expand Down Expand Up @@ -2605,7 +2578,7 @@ def bootstrap_join_geo(context):
sh.LocalShell(additional_environ={'SSH_AUTH_SOCK': ''}),
):
raise ValueError(f"Failed to login to {remote_user}@{node}. Please check the credentials.")
swap_public_ssh_key(node, local_user, remote_user, local_user, remote_user, add=True)
swap_public_ssh_key(node, local_user, remote_user, local_user, remote_user)
user_by_host = utils.HostUserConfig()
user_by_host.add(local_user, utils.this_node())
user_by_host.add(remote_user, node)
Expand Down Expand Up @@ -2644,7 +2617,7 @@ def bootstrap_arbitrator(context):
sh.LocalShell(additional_environ={'SSH_AUTH_SOCK': ''}),
):
raise ValueError(f"Failed to login to {remote_user}@{node}. Please check the credentials.")
swap_public_ssh_key(node, local_user, remote_user, local_user, remote_user, add=True)
swap_public_ssh_key(node, local_user, remote_user, local_user, remote_user)
user_by_host.add(local_user, utils.this_node())
user_by_host.add(remote_user, node)
user_by_host.set_no_generating_ssh_key(context.use_ssh_agent)
Expand Down
65 changes: 61 additions & 4 deletions crmsh/ssh_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ def _add_remote(self, host: str, user: str, key: Key):
tmp.flush()
self._add_by_ssh_copy_id(user, host, tmp.name)

@classmethod
def _add_by_editing_file(cls, user: str, key: Key):
@staticmethod
def _add_by_editing_file(user: str, key: Key):
public_key = key.public_key()
dir = f'~{user}/.ssh'
file = f'{dir}/authorized_keys'
Expand Down Expand Up @@ -183,6 +183,7 @@ def list(self) -> typing.List[Key]:


class KeyFileManager:
DEFAULT_KEY_TYPE = 'rsa'
KNOWN_KEY_TYPES = ['rsa', 'ed25519', 'ecdsa'] # dsa is not listed here as it is not so secure
KNOWN_PUBLIC_KEY_FILENAME_PATTERN = re.compile('/id_(?:{})\\.pub$'.format('|'.join(KNOWN_KEY_TYPES)))

Expand All @@ -208,7 +209,7 @@ def load_public_keys_for_user(self, host: typing.Optional[str], user: str) -> ty
filenames = self.list_public_key_for_user(host, user)
if not filenames:
return list()
cmd = f'cat ~{user}/.ssh/{{{",".join(filenames)}}}'
cmd = f'cat {",".join(filenames)}'
result = self.cluster_shell.subprocess_run_without_input(
host, user,
cmd,
Expand All @@ -232,7 +233,7 @@ def ensure_key_pair_exists_for_user(
* list_of_public_keys: all public keys of known types, including the newly generated one
"""
script = '''if [ ! \\( {condition} \\) ]; then
ssh-keygen -t rsa -f ~/.ssh/id_rsa -q -C "Cluster internal on $(hostname)" -N '' <> /dev/null
ssh-keygen -t {key_type} -f ~/.ssh/id_{key_type} -q -C "Cluster internal on $(hostname)" -N '' <> /dev/null
echo 'GENERATED=1'
fi
for file in ~/.ssh/id_{{{pattern}}}; do
Expand All @@ -245,6 +246,7 @@ def ensure_key_pair_exists_for_user(
done
'''.format(
condition=' -o '.join([f'-f ~/.ssh/id_{t}' for t in self.KNOWN_KEY_TYPES]),
key_type=self.DEFAULT_KEY_TYPE,
pattern=','.join(self.KNOWN_KEY_TYPES),
)
result = self.cluster_shell.subprocess_run_without_input(
Expand All @@ -265,3 +267,58 @@ def ensure_key_pair_exists_for_user(
else:
keys.append(InMemoryPublicKey(line))
return generated, keys


def fetch_public_key_file_list(
host: typing.Optional[str],
user: str,
generate_key_pair: bool = False
) -> typing.List[str]:
"""
Fetch the public key file list for the specified user on the specified host.
:param host: the host where the user is located. If None, the local host is assumed.
:param user: the user name
:param generate_key_pair: whether to generate a new key pair if no key pair is found,
default is False
:return: a list of public key file paths
:raise Error: if no public key file is found for the user
"""
key_file_manager = KeyFileManager(sh.cluster_shell())
if generate_key_pair:
key_file_manager.ensure_key_pair_exists_for_user(host, user)
public_keys = key_file_manager.list_public_key_for_user(host, user)
if not public_keys:
host_str = f'@{host}' if host else ' locally'
raise Error(f'No public key file found for {user}{host_str}')
return public_keys


def fetch_public_key_content_list(
host: typing.Optional[str],
user: str,
generate_key_pair: bool = False
) -> typing.List[str]:
"""
Fetch the public key content list for the specified user on the specified host.
:param host: the host where the user is located. If None, the local host is assumed.
:param user: the user name
:param generate_key_pair: whether to generate a new key pair if no key pair is found,
default is False
:return: a list of public key strings
:raise Error: if no public key file is found for the user
"""
key_file_manager = KeyFileManager(sh.cluster_shell())
if generate_key_pair:
key_file_manager.ensure_key_pair_exists_for_user(host, user)
keys_in_memory = key_file_manager.load_public_keys_for_user(host, user)
public_keys = [key.public_key() for key in keys_in_memory]
if not public_keys:
host_str = f'@{host}' if host else ' locally'
raise Error(f'No public key file found for {user}{host_str}')
return public_keys
29 changes: 3 additions & 26 deletions crmsh/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from . import constants
from . import options
from . import term
from . import ssh_key
from .constants import SSH_OPTION
from . import log
from .prun import prun
Expand Down Expand Up @@ -138,8 +139,9 @@ def ssh_copy_id_no_raise(local_user, remote_user, remote_node, shell: sh.LocalSh
if shell is None:
shell = sh.LocalShell()
if check_ssh_passwd_need(local_user, remote_user, remote_node, shell):
local_public_key = ssh_key.fetch_public_key_file_list(None, local_user)[0]
logger.info("Configuring SSH passwordless with {}@{}".format(remote_user, remote_node))
cmd = "ssh-copy-id -i ~/.ssh/id_rsa.pub '{}@{}' &> /dev/null".format(remote_user, remote_node)
cmd = f"ssh-copy-id -i {local_public_key} '{remote_user}@{remote_node}' &> /dev/null"
result = shell.su_subprocess_run(local_user, cmd, tty=True)
return result.returncode
else:
Expand Down Expand Up @@ -2427,16 +2429,6 @@ def get_default_nic_from_route(self) -> str:
return res.group(1) if res else self.nic_list[0]


def check_text_included(text, target_file, remote=None):
"Check whether target_file includes the text"
if not detect_file(target_file, remote=remote):
return False

cmd = "cat {}".format(target_file)
target_data = sh.cluster_shell().get_stdout_or_raise_error(cmd, remote)
return text in target_data


def package_is_installed(pkg, remote_addr=None):
"""
Check if package is installed
Expand Down Expand Up @@ -2915,21 +2907,6 @@ def diff_and_patch(orig_cib_str, current_cib_str):
return True


def detect_file(_file, remote=None):
"""
Detect if file exists, support both local and remote
"""
rc = False
if not remote:
cmd = "test -f {}".format(_file)
else:
# FIXME
cmd = "ssh {} {}@{} 'test -f {}'".format(SSH_OPTION, user_of(remote), remote, _file)
code, _, _ = ShellUtils().get_stdout_stderr(cmd)
rc = code == 0
return rc


def retry_with_timeout(callable, timeout_sec: float, interval_sec=1):
"""Try callable repeatedly until it returns without raising an exception.
Expand Down
Loading

0 comments on commit ae21432

Please sign in to comment.