Skip to content
This repository has been archived by the owner on Aug 22, 2023. It is now read-only.

Commit

Permalink
Added validation for selecting the SSH public key (#52)
Browse files Browse the repository at this point in the history
Signed-off-by: Arjun Naik <arjun.rn@gmail.com>
  • Loading branch information
arjunrn authored Mar 25, 2020
1 parent 710faca commit daee100
Show file tree
Hide file tree
Showing 9 changed files with 100 additions and 13 deletions.
42 changes: 34 additions & 8 deletions piu/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@
import socket
import sys
import time

import sshpubkeys
import yaml
import zign.api
import re

from clickclick import error, AliasedGroup, print_table, OutputFormat, warning

from .error_handling import handle_exceptions

import piu
Expand Down Expand Up @@ -275,11 +278,7 @@ def cli(ctx, config_file):
default=True,
)
@click.option(
"-i",
"--ssh-public-key",
help="The public key to use to SSH",
type=click.Path(),
default=os.path.expanduser("~/.ssh/id_rsa.pub"),
"-i", "--ssh-public-key", help="The public key to use to SSH", type=click.Path(),
)
@region_option
@click.pass_obj
Expand All @@ -303,8 +302,9 @@ def request_access(
config = load_config(config_file)
even_url = even_url or config.get("even_url")
odd_host = odd_host or piu.utils.find_odd_host(region) or config.get("odd_host")
if not check_ssh_key(ssh_public_key) and check_ssh_key(config.get("ssh_public_key")):
ssh_public_key = config.get("ssh_public_key")
ssh_public_key = validate_ssh_key(
ssh_public_key, config.get("ssh_public_key"), os.path.expanduser("~/.ssh/id_rsa.pub"), interactive
)

if interactive:
host, odd_host, reason, ssh_public_key = request_access_interactive(region, odd_host, ssh_public_key)
Expand Down Expand Up @@ -409,9 +409,35 @@ def request_access(
sys.exit(1)


def validate_ssh_key(option_path: str, config_path: str, fallback_path: str, interactive: str) -> str:
if option_path:
if check_ssh_key(option_path):
return option_path
if not interactive:
error("specified ssh public key at {0:s} is not a valid key".format(option_path))
sys.exit(1)
elif check_ssh_key(config_path):
return config_path
elif check_ssh_key(fallback_path):
return fallback_path
if not interactive:
error(
"No valid SSH public key could be determined. "
"Please specify one with the -i flag. Consult help for details"
)
sys.exit(1)
return ""


def check_ssh_key(key_path: str) -> bool:
# TODO: verify that the input key is actually an SSH public key
if key_path and os.path.exists(key_path):
with open(key_path) as key:
contents = key.read()
key = sshpubkeys.SSHKey(contents)
try:
key.parse()
except (sshpubkeys.InvalidKeyError, NotImplementedError) as e:
return False
return True
return False

Expand Down
9 changes: 5 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
clickclick>=0.10
PyYAML
requests
pyperclip
stups-zign>=1.1.26
boto3>=1.12.0
botocore>=1.4.10
click>=1.2.2
clickclick>=0.10
pyperclip
requests
sshpubkeys>=3.1.0
stups-zign>=1.1.26
1 change: 1 addition & 0 deletions tests/resources/id_rsa.pub
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQDR5i0JujjiY5A1fuThjjxvrNDMv/WrvEFN+dTrq0fHsa2I4p51eUdYMnRrkZaQ2rvr6jTuYSt75rtZguNB7DnFFOklG4YDgsIxb+XlizDImnbs5hXawPw869X6uDtaXYxP/ddK1zcVEeW1J0NWBWbJ6t3o2lr0+YkotZ5L/iRC/PR7btNH92oKyqxz11uuL/QtomSScQJoL9nNVAB3MtAO0t5GNRsGk4M7ln9l+R9Xbqcb6abLm6qiIffcfEcDZmLMOYrfI52O0YaLRk7/ifPz4BDLaU8feP/rsc9wAzCwH3nHW+/ymM3JhG7FAEXePswHQ5CdeUmltqXRl3THG4JYUkevD1xIzIl6/DHe6qvoFmIHGsK2P8EEt6ZX4pt1lBsh+g5WMH8PxR9hCuFIoAj5IpHh1lS9suNhN3vck1U7VRQwDtL/WzesKOJbSopoA0TYsahVVrq4KZOTbInTrbe9wu3bKj4xv9BeWAB+tIVN3WSzsFZ5btZRuqtVzzIfI3s= test@blash
1 change: 1 addition & 0 deletions tests/resources/id_rsa_config.pub
1 change: 1 addition & 0 deletions tests/resources/id_rsa_fallback.pub
1 change: 1 addition & 0 deletions tests/resources/id_rsa_malformed.pub
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ssh-mal malformed test@blash
1 change: 1 addition & 0 deletions tests/resources/id_rsa_option.pub
2 changes: 1 addition & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import tempfile

import boto3
from click.testing import CliRunner
from unittest.mock import MagicMock, Mock
import zign.api
Expand All @@ -24,6 +23,7 @@ def mock_aws(monkeypatch):
monkeypatch.setattr("boto3.client", MagicMock(return_value={}))
monkeypatch.setattr("piu.cli.compatible_ami", MagicMock(return_value=False))
monkeypatch.setattr("piu.cli.check_ssh_key", MagicMock(return_value=True))
monkeypatch.setattr("piu.cli.validate_ssh_key", MagicMock(return_value="nonexistent"))
monkeypatch.setattr("piu.cli.send_odd_ssh_key", MagicMock(return_value=True))
mock_list_running_instances(monkeypatch)
yield
Expand Down
55 changes: 55 additions & 0 deletions tests/test_validate_ssh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from os import path
from unittest.mock import MagicMock

from piu.cli import validate_ssh_key, check_ssh_key


def test_validate_ssh_fallback(monkeypatch):
mock_exit = MagicMock()
monkeypatch.setattr("sys.exit", mock_exit)
fallback_path = path.join(path.abspath(path.dirname(__file__)), "resources/id_rsa_fallback.pub")
final_path = validate_ssh_key("", "", fallback_path, False)
assert final_path == fallback_path
mock_exit.assert_not_called()


def test_validate_ssh_valid_input(monkeypatch):
mock_exit = MagicMock()
monkeypatch.setattr("sys.exit", mock_exit)
option_path = path.join(path.abspath(path.dirname(__file__)), "resources/id_rsa_option.pub")
config_path = path.join(path.abspath(path.dirname(__file__)), "resources/id_rsa_config.pub")
fallback_path = path.join(path.abspath(path.dirname(__file__)), "resources/id_rsa_fallback.pub")
final_path = validate_ssh_key(option_path, config_path, fallback_path, False)
assert final_path == option_path
mock_exit.assert_not_called()


def test_validate_ssh_valid_config(monkeypatch):
mock_exit = MagicMock()
monkeypatch.setattr("sys.exit", mock_exit)
config_path = path.join(path.abspath(path.dirname(__file__)), "resources/id_rsa_config.pub")
fallback_path = path.join(path.abspath(path.dirname(__file__)), "resources/id_rsa_fallback.pub")
final_path = validate_ssh_key("", config_path, fallback_path, False)
assert final_path == config_path
mock_exit.assert_not_called()


def test_validate_ssh_valid_error(monkeypatch):
mock_exit = MagicMock()
monkeypatch.setattr("sys.exit", mock_exit)
final_path = validate_ssh_key("", "", "", False)
assert final_path == ""
mock_exit.assert_called_with(1)


def test_validate_ssh_valid_no_error_interactive(monkeypatch):
mock_exit = MagicMock()
monkeypatch.setattr("sys.exit", mock_exit)
final_path = validate_ssh_key("", "", "", True)
assert final_path == ""
mock_exit.assert_not_called()


def test_malformed_ssh_key(monkeypatch):
malformed_path = path.join(path.abspath(path.dirname(__file__)), "resources/id_rsa_malformed.pub")
assert not check_ssh_key(malformed_path)

0 comments on commit daee100

Please sign in to comment.