Skip to content
This repository has been archived by the owner on Nov 16, 2023. It is now read-only.

Commit

Permalink
get user ssh keys from user extension secret instead of rest server (#36
Browse files Browse the repository at this point in the history
)
  • Loading branch information
suiguoxin authored Feb 23, 2021
1 parent c77c562 commit c5c6dbc
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 46 deletions.
2 changes: 1 addition & 1 deletion src/init
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ python ${PAI_INIT_DIR}/framework_parser.py genconf framework.json > ${PAI_RUNTIM
# Init plugins
# priority=12
CHILD_PROCESS="PLUGIN_INITIALIZER"
python ${PAI_INIT_DIR}/initializer.py ${PAI_RUNTIME_DIR}/job_config.yaml ${PAI_SECRET_DIR}/secrets.yaml ${PAI_TOKEN_SECRET_DIR}/token ${PAI_WORK_DIR}/plugins ${PAI_RUNTIME_DIR} ${FC_TASKROLE_NAME}
python ${PAI_INIT_DIR}/initializer.py ${PAI_RUNTIME_DIR}/job_config.yaml ${PAI_SECRET_DIR}/secrets.yaml ${PAI_SECRET_DIR}/userExtensionSecrets.yaml ${PAI_TOKEN_SECRET_DIR}/token ${PAI_WORK_DIR}/plugins ${PAI_RUNTIME_DIR} ${FC_TASKROLE_NAME}

# Init plugins
# check port conflict
Expand Down
23 changes: 17 additions & 6 deletions src/init.d/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def collect_plugin_configs(jobconfig, taskrole):
for prerequisite_name in jobconfig['taskRoles'][taskrole]['prerequisites']:
prerequisite_config = copy.deepcopy(prerequisites_name2config[prerequisite_name])
if 'plugin' in prerequisite_config and prerequisite_config['plugin'].startswith(RUNTIME_PLUGIN_PLACE_HOLDER):
# convert prerequisite to runtime plugin config
# convert prerequisite to runtime plugin config
plugin_config = {
# plugin name follows the format "com.microsoft.pai.runtimeplugin.<plugin name>"
'plugin': prerequisite_config.pop('plugin')[len(RUNTIME_PLUGIN_PLACE_HOLDER) + 1:]
Expand All @@ -119,7 +119,7 @@ def collect_plugin_configs(jobconfig, taskrole):
# the remaining keys (other than plugin, failurePolicy, and type) will be treated as parameters
plugin_config['parameters'] = copy.deepcopy(prerequisite_config)
plugin_configs.append(plugin_config)

# collect plugins from jobconfig["extras"]
if "extras" in jobconfig and RUNTIME_PLUGIN_PLACE_HOLDER in jobconfig["extras"]:
for plugin_config in jobconfig["extras"][RUNTIME_PLUGIN_PLACE_HOLDER]:
Expand All @@ -130,13 +130,14 @@ def collect_plugin_configs(jobconfig, taskrole):
return plugin_configs


def init_plugins(jobconfig, secrets, application_token, commands, plugins_path, runtime_path,
def init_plugins(jobconfig, secrets, user_extension, application_token, commands, plugins_path, runtime_path,
taskrole):
"""Init plugins from jobconfig.
Args:
jobconfig: Jobconfig object generated by parser.py from framework.json.
secrets: user secrests passed to runtime.
secrets: config secrests passed to runtime.
user_extension: user extension passed to runtime.
application_token: application token path passed to runtime.
commands: Commands to call in precommands.sh and postcommands.sh.
plugins_path: The base path for all plugins.
Expand All @@ -155,6 +156,8 @@ def init_plugins(jobconfig, secrets, application_token, commands, plugins_path,
taskrole))
plugin_config["parameters"] = parameters

plugin_config["user_extension"] = user_extension

if os.path.exists(application_token):
with open(application_token, "r") as f:
plugin_config["application_token"] = yaml.safe_load(f)
Expand Down Expand Up @@ -219,7 +222,9 @@ def main():
"jobconfig_yaml",
help="jobConfig.yaml generated by parser.py from framework.json")
parser.add_argument("secret_file",
help="secrets.yaml user secrets passed to runtime")
help="secrets.yaml config secrets passed to runtime")
parser.add_argument("user_extension_secrets_file",
help="userExtensionSecrets.yaml user extension secrets passed to runtime")
parser.add_argument("application_token",
help="application token passed to runtime")
parser.add_argument("plugins_path", help="Plugins path")
Expand All @@ -237,8 +242,14 @@ def main():
with open(args.secret_file) as f:
secrets = yaml.safe_load(f.read())

if not os.path.isfile(args.user_extension_secrets_file):
user_extension = None
else:
with open(args.user_extension_secrets_file) as f:
user_extension = yaml.safe_load(f.read())

commands = [[], []]
init_plugins(job_config, secrets, args.application_token, commands, args.plugins_path,
init_plugins(job_config, secrets, user_extension, args.application_token, commands, args.plugins_path,
args.runtime_path, args.task_role)

# pre-commands and post-commands already handled by rest-server.
Expand Down
47 changes: 15 additions & 32 deletions src/plugins/ssh/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import logging
import os
import sys
import requests

sys.path.append(
os.path.join(os.path.dirname(os.path.abspath(__file__)), "../.."))
Expand All @@ -28,38 +27,27 @@
LOGGER = logging.getLogger(__name__)


def get_user_public_keys(application_token, username):
def get_user_public_keys(user_extension):
"""
get user public keys from rest-server
get user public keys from user extension
Format of API `REST_SERVER_URI/api/v2/users/<username>` response:
Format of user extension:
{
"xxx": "xxx",
"extensions": {
"sshKeys": [
{
"title": "title-of-the-public-key",
"value": "ssh-rsa xxxx"
"time": "xxx"
}
]
}
"sshKeys": [
{
"title": "title-of-the-public-key",
"value": "ssh-rsa xxxx"
"time": "xxx"
}
]
}
Returns:
--------
list
a list of public keys
"""
url = "{}/api/v2/users/{}".format(os.environ.get('REST_SERVER_URI'), username)
headers={
'Authorization': "Bearer {}".format(application_token),
}

response = requests.get(url, headers=headers, data={})
response.raise_for_status()

public_keys = [item["value"] for item in response.json()["extension"]["sshKeys"]]
public_keys = [item["value"] for item in user_extension["sshKeys"]]

return public_keys

Expand All @@ -69,6 +57,7 @@ def main():
[plugin_config, pre_script, _] = plugin_init()
plugin_helper = PluginHelper(plugin_config)
parameters = plugin_config.get("parameters")
user_extension = plugin_config.get("user_extension")

if not parameters:
LOGGER.info("Ssh plugin parameters is empty, ignore this")
Expand All @@ -86,16 +75,10 @@ def main():
cmd_params = [jobssh]

if "userssh" in parameters:
# get user public keys from rest server
application_token = plugin_config.get("application_token")
username = os.environ.get("PAI_USER_NAME")
# get user public keys from user extension secret
public_keys = []
if application_token:
try:
public_keys = get_user_public_keys(application_token, username)
except Exception: #pylint: disable=broad-except
LOGGER.error("Failed to get user public keys", exc_info=True)
sys.exit(1)
if user_extension and "sshKeys" in user_extension:
public_keys = get_user_public_keys(user_extension)

if "value" in parameters["userssh"] and parameters["userssh"]["value"] != "":
public_keys.append(parameters["userssh"]["value"])
Expand Down
14 changes: 7 additions & 7 deletions test/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_cmd_plugin(self):
with open(job_path, 'rt') as f:
jobconfig = yaml.safe_load(f)
commands = [[], []]
initializer.init_plugins(jobconfig, {}, "", commands, "../src/plugins",
initializer.init_plugins(jobconfig, {}, {}, "", commands, "../src/plugins",
".", "worker")

def test_cmd_plugin_with_callbacks(self):
Expand All @@ -69,7 +69,7 @@ def test_cmd_plugin_with_callbacks(self):
with open(job_path, 'rt') as f:
jobconfig = yaml.safe_load(f)
commands = [[], []]
initializer.init_plugins(jobconfig, {}, "", commands, "../src/plugins",
initializer.init_plugins(jobconfig, {}, {}, "", commands, "../src/plugins",
".", "worker")

def test_cmd_plugin_with_prerequisites(self):
Expand All @@ -78,7 +78,7 @@ def test_cmd_plugin_with_prerequisites(self):
with open(job_path, 'rt') as f:
jobconfig = yaml.safe_load(f)
commands = [[], []]
initializer.init_plugins(jobconfig, {}, "", commands, "../src/plugins",
initializer.init_plugins(jobconfig, {}, {}, "", commands, "../src/plugins",
".", "worker")

def test_ssh_plugin(self):
Expand All @@ -88,7 +88,7 @@ def test_ssh_plugin(self):
jobconfig = yaml.safe_load(f)
commands = [[], []]
initializer.init_plugins(
jobconfig, {"userssh": "ssh-rsa AAAAB3N/cTbWGQZtN1pai-ssh"}, "",
jobconfig, {"userssh": "ssh-rsa AAAAB3N/cTbWGQZtN1pai-ssh"}, {}, "",
commands, "../src/plugins", ".", "worker")

def test_ssh_plugin_barrier(self):
Expand All @@ -97,10 +97,10 @@ def test_ssh_plugin_barrier(self):
with open(job_path, 'rt') as f:
jobconfig = yaml.safe_load(f)
commands = [[], []]
initializer.init_plugins(jobconfig, {}, "", commands, "../src/plugins",
initializer.init_plugins(jobconfig, {}, {}, "", commands, "../src/plugins",
".", "master")
commands = [[], []]
initializer.init_plugins(jobconfig, {}, "", commands, "../src/plugins",
initializer.init_plugins(jobconfig, {}, {}, "", commands, "../src/plugins",
".", "worker")

def test_git_plugin(self):
Expand All @@ -109,7 +109,7 @@ def test_git_plugin(self):
with open(job_path, 'rt') as f:
jobconfig = yaml.safe_load(f)
commands = [[], []]
initializer.init_plugins(jobconfig, {}, "", commands, "../src/plugins", ".",
initializer.init_plugins(jobconfig, {}, {}, "", commands, "../src/plugins", ".",
"worker")
repo_local_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../src/code")
self.assertTrue(os.path.exists(repo_local_path))
Expand Down

0 comments on commit c5c6dbc

Please sign in to comment.