Skip to content

Commit

Permalink
Feature / Multiple packages per model repo (#210)
Browse files Browse the repository at this point in the history
* Update model definition to support packages

* Utility functions for safe logging of URLs

* PyPi repo testing

* Unzip .whl downloads

* Include package and packageGroup in import model job def

* Handle new package and packageGroup attributes in the orchestrator

* Copy across new package and packageGroup attrs in the model runtime when importing models

* Add test for GitHub repos in the runtime

* Use username/password style credentials in PyPI test case

* Reduce build time in basic build workflow

* Change PyPI config props - do not use hyphen in prop keys

* Add runtime dependency for requests package
  • Loading branch information
Martin Traverse authored Nov 27, 2022
1 parent cb39452 commit c9adc95
Show file tree
Hide file tree
Showing 16 changed files with 478 additions and 68 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ jobs:
cache: gradle

- name: Build
run: ./gradlew build -x test
run: ./gradlew classes testClasses

- name: Unit tests
run: ./gradlew test
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ jobs:
cache: gradle

- name: Build
run: ./gradlew trac-svc-meta:testClasses -x test
run: ./gradlew trac-svc-meta:testClasses

# Auth tool will also create the secrets file if it doesn't exist
- name: Prepare Auth Keys
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,14 @@ message ImportModelJob {

string language = 1;
string repository = 2;
string path = 3;
string entryPoint = 4;

optional string packageGroup = 7;
string package = 8;
string version = 5;

string entryPoint = 4;

string path = 3;

repeated TagUpdate modelAttrs = 6;
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,14 @@ message ModelDefinition {

string language = 1;
string repository = 2;
string path = 3;

// string package = 4;
reserved 4;
reserved "package";
optional string packageGroup = 10;
string package = 11;
string version = 6;

string entryPoint = 5;
string version = 6;

string path = 3;

map<string, ModelParameter> parameters = 7;
map<string, ModelInputSchema> inputs = 8;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ public class MetadataConstants {

public static final String TRAC_MODEL_LANGUAGE = "trac_model_language";
public static final String TRAC_MODEL_REPOSITORY = "trac_model_repository";
public static final String TRAC_MODEL_PATH = "trac_model_path";
public static final String TRAC_MODEL_ENTRY_POINT = "trac_model_entry_point";
public static final String TRAC_MODEL_PACKAGE_GROUP = "trac_model_package_group";
public static final String TRAC_MODEL_PACKAGE = "trac_model_package";
public static final String TRAC_MODEL_VERSION = "trac_model_version";
public static final String TRAC_MODEL_ENTRY_POINT = "trac_model_entry_point";
}
5 changes: 5 additions & 0 deletions tracdap-runtime/python/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ pyyaml >= 6.0, < 7.0
pyarrow >= 9, < 10


# Requests, used for downloading model packages

requests >= 2.28.1, < 3.0


# Data is presented to model code using Pandas and/or PySpark

# Baseline support for Pandas on series 1.x
Expand Down
1 change: 1 addition & 0 deletions tracdap-runtime/python/setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ python_requires = >=3.7

install_requires =
pyyaml >= 6.0.0, < 7.0.0
requests >= 2.28.1, < 3.0.0
pyarrow >= 9.0.0, < 10.0.0
pandas >= 1.2.0, < 2.0.0

Expand Down
26 changes: 14 additions & 12 deletions tracdap-runtime/python/src/tracdap/rt/_exec/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,18 @@ def _execute(self, ctx: NodeContext):
storage_options=options, overwrite=False)


def _model_def_for_import(import_details: meta.ImportModelJob):

return meta.ModelDefinition(
language=import_details.language,
repository=import_details.repository,
packageGroup=import_details.packageGroup,
package=import_details.package,
version=import_details.version,
entryPoint=import_details.entryPoint,
path=import_details.path)


class ImportModelFunc(NodeFunction[meta.ObjectDefinition]):

def __init__(self, node: ImportModelNode, models: _models.ModelLoader):
Expand All @@ -490,12 +502,7 @@ def __init__(self, node: ImportModelNode, models: _models.ModelLoader):

def _execute(self, ctx: NodeContext) -> meta.ObjectDefinition:

stub_model_def = meta.ModelDefinition(
language=self.node.import_details.language,
repository=self.node.import_details.repository,
path=self.node.import_details.path,
entryPoint=self.node.import_details.entryPoint,
version=self.node.import_details.version)
stub_model_def = _model_def_for_import(self.node.import_details)

model_class = self._models.load_model_class(self.node.model_scope, stub_model_def)
model_scan = self._models.scan_model(model_class)
Expand All @@ -516,12 +523,7 @@ def __init__(self, node: ImportAttrsNode, models: _models.ModelLoader):

def _execute(self, ctx: NodeContext) -> _config.TagUpdateList:

stub_model_def = meta.ModelDefinition(
language=self.node.import_details.language,
repository=self.node.import_details.repository,
path=self.node.import_details.path,
entryPoint=self.node.import_details.entryPoint,
version=self.node.import_details.version)
stub_model_def = _model_def_for_import(self.node.import_details)

model_class = self._models.load_model_class(self.node.model_scope, stub_model_def)
return self._models.scan_model_attrs(model_class)
Expand Down
173 changes: 165 additions & 8 deletions tracdap-runtime/python/src/tracdap/rt/_impl/repos.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@

import subprocess
import subprocess as sp
import urllib.parse
import time
import zipfile
import io

import requests

import tracdap.rt.metadata as _meta
import tracdap.rt.config as _cfg
Expand All @@ -27,6 +32,45 @@
from tracdap.rt.ext.repos import *


# Helper functions for handling credentials supplied via HTTP(S) URLs

__REPO_TOKEN_KEY = "token"
__REPO_USER_KEY = "username"
__REPO_PASS_KEY = "password"


def _get_credentials(url: urllib.parse.ParseResult, properties: tp.Dict[str, str]):

if __REPO_TOKEN_KEY in properties:
return properties[__REPO_TOKEN_KEY]

if __REPO_USER_KEY in properties and __REPO_PASS_KEY in properties:
username = properties[__REPO_USER_KEY]
password = properties[__REPO_PASS_KEY]
return f"{username}:{password}"

if url.username:
credentials_sep = url.netloc.index("@")
return url.netloc[:credentials_sep]

return None


def _apply_credentials(url: urllib.parse.ParseResult, credentials: str):

if credentials is None:
return url

if url.username is None:
location = f"{credentials}@{url.netloc}"

else:
location_sep = url.netloc.index("@")
location = f"{credentials}@{url.netloc[location_sep:]}"

return url._replace(netloc=location)


class IntegratedSource(IModelRepository):

def __init__(self, repo_config: _cfg.PluginConfig):
Expand Down Expand Up @@ -87,13 +131,20 @@ class GitRepository(IModelRepository):
GIT_TIMEOUT_SECONDS = 30

def __init__(self, repo_config: _cfg.PluginConfig):

self._repo_config = repo_config
self._log = _util.logger_for_object(self)
self._repo_url = self._repo_config.properties.get(self.REPO_URL_KEY)

if not self._repo_url:
repo_url_prop = self._repo_config.properties.get(self.REPO_URL_KEY)

if not repo_url_prop:
raise _ex.EConfigParse(f"Missing required property [{self.REPO_URL_KEY}] in Git repository config")

repo_url = urllib.parse.urlparse(repo_url_prop)
credentials = _get_credentials(repo_url, repo_config.properties)

self._repo_url = _apply_credentials(repo_url, credentials)

def checkout_key(self, model_def: _meta.ModelDefinition):
return model_def.version

Expand All @@ -105,13 +156,17 @@ def package_path(

def do_checkout(self, model_def: _meta.ModelDefinition, checkout_dir: pathlib.Path) -> pathlib.Path:

self._log.info(f"Git checkout {model_def.repository} {model_def.version} -> {checkout_dir}")
self._log.info(
f"Git checkout: repo = [{model_def.repository}], " +
f"group = [{model_def.packageGroup}], package = [{model_def.package}], version = [{model_def.version}]")

self._log.info(f"Checkout location: [{checkout_dir}]")

git_cli = ["git", "-C", str(checkout_dir)]

git_cmds = [
["init"],
["remote", "add", "origin", self._repo_url],
["remote", "add", "origin", self._repo_url.geturl()],
["fetch", "--depth=1", "origin", model_def.version],
["reset", "--hard", "FETCH_HEAD"]]

Expand All @@ -134,7 +189,8 @@ def do_checkout(self, model_def: _meta.ModelDefinition, checkout_dir: pathlib.Pa

for git_cmd in git_cmds:

self._log.info(f"git {' '.join(git_cmd)}")
safe_cmd = map(_util.log_safe, git_cmd)
self._log.info(f"git {' '.join(safe_cmd)}")

cmd = [*git_cli, *git_cmd]
cmd_result = sp.run(cmd, cwd=checkout_dir, stdout=sp.PIPE, stderr=sp.PIPE, timeout=self.GIT_TIMEOUT_SECONDS)
Expand All @@ -158,11 +214,111 @@ def do_checkout(self, model_def: _meta.ModelDefinition, checkout_dir: pathlib.Pa
for line in cmd_err:
self._log.error(line)

error_msg = f"Git checkout failed for {model_def.repository} {model_def.version}"
error_msg = f"Git checkout failed for {model_def.package} {model_def.version}"
self._log.error(error_msg)
raise _ex.EModelRepo(error_msg)

self._log.info(f"Git checkout succeeded for {model_def.repository} {model_def.version}")
self._log.info(f"Git checkout succeeded for {model_def.package} {model_def.version}")

return self.package_path(model_def, checkout_dir)


class PyPiRepository(IModelRepository):

JSON_PACKAGE_PATH = "{}/{}/{}/json"

PIP_INDEX_KEY = "pipIndex"
PIP_INDEX_URL_KEY = "pipIndexUrl"

def __init__(self, repo_config: _cfg.PluginConfig):

self._log = _util.logger_for_object(self)

self._repo_config = repo_config

if self.PIP_INDEX_KEY not in self._repo_config.properties:
raise _ex.EConfigParse(f"Missing required property [{self.PIP_INDEX_KEY}] in PyPi repository config")

def checkout_key(self, model_def: _meta.ModelDefinition):
return model_def.version

def package_path(
self, model_def: _meta.ModelDefinition,
checkout_dir: pathlib.Path) -> tp.Optional[pathlib.Path]:

return checkout_dir

def do_checkout(self, model_def: _meta.ModelDefinition, checkout_dir: pathlib.Path) -> tp.Optional[pathlib.Path]:

self._log.info(
f"PyPI checkout: repo = [{model_def.repository}], " +
f"package = [{model_def.package}], version = [{model_def.version}]")

self._log.info(f"Checkout location: [{checkout_dir}]")

repo_props = self._repo_config.properties
pip_index = repo_props.get(self.PIP_INDEX_KEY)

if pip_index is None:
raise _ex.EConfigParse(f"Missing required property [{self.PIP_INDEX_KEY}] in PyPi repository config")

json_root_url = urllib.parse.urlparse(pip_index)
json_package_path = self.JSON_PACKAGE_PATH.format(json_root_url.path, model_def.package, model_def.version)
json_package_url = json_root_url._replace(path=json_package_path)

credentials = _get_credentials(json_root_url, self._repo_config.properties)
json_root_url = _apply_credentials(json_package_url, credentials)

json_headers = {"accept": "application/json"}

self._log.info(f"Package query: {_util.log_safe_url(json_root_url)}")

package_req = requests.get(json_package_url.geturl(), headers=json_headers)

if package_req.status_code != requests.codes.OK:
status_code_name = requests.codes.name[package_req.status_code]
message = f"Package lookup failed: [{package_req.status_code}] {status_code_name}"
self._log.error(message)
raise _ex.EModelRepo(message) # todo status code for access, not found etc

package_obj = package_req.json()
package_info = package_obj.get("info") or {}
summary = package_info.get("summary") or "(summary not available)"

self._log.info(f"Package summary: {summary}")

urls = package_obj.get("urls") or []
bdist_urls = list(filter(lambda d: d.get("packagetype") == "bdist_wheel", urls))

if not bdist_urls:
message = "No compatible packages found"
self._log.error(message)
raise _ex.EModelRepo(message)

if len(bdist_urls) > 1:
message = "Multiple compatible packages found (specialized distributions are not supported yet)"
self._log.error(message)
raise _ex.EModelRepo(message)

package_url_info = bdist_urls[0]
package_filename = package_url_info.get("filename")
package_url = urllib.parse.urlparse(package_url_info.get("url"))

package_url = _apply_credentials(package_url, credentials)

self._log.info(f"Downloading [{package_filename}]")

download_req = requests.get(package_url.geturl())
content = download_req.content
elapsed = download_req.elapsed

self._log.info(f"Downloaded [{len(content) / 1024:.1f}] KB in [{elapsed.total_seconds():.1f}] seconds")

download_whl = zipfile.ZipFile(io.BytesIO(download_req.content))
download_whl.extractall(checkout_dir)

self._log.info(f"Unpacked [{len(download_whl.filelist)}] files")
self._log.info(f"PyPI checkout succeeded for {model_def.package} {model_def.version}")

return self.package_path(model_def, checkout_dir)

Expand All @@ -172,7 +328,8 @@ class RepositoryManager:
__repo_types: tp.Dict[str, tp.Callable[[_cfg.PluginConfig], IModelRepository]] = {
"integrated": IntegratedSource,
"local": LocalRepository,
"git": GitRepository
"git": GitRepository,
"pypi": PyPiRepository
}

@classmethod
Expand Down
Loading

0 comments on commit c9adc95

Please sign in to comment.