-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #38 from commit-0/fix/typer
Add typer and remove hydra
- Loading branch information
Showing
7 changed files
with
278 additions
and
172 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,171 +1,10 @@ | ||
import commit0.harness.run_pytest_ids | ||
import commit0.harness.get_pytest_ids | ||
import commit0.harness.build | ||
import commit0.harness.setup | ||
import commit0.harness.evaluate | ||
import commit0.harness.lint | ||
import commit0.harness.save | ||
import copy | ||
import sys | ||
import os | ||
import hydra | ||
from hydra.core.config_store import ConfigStore | ||
from commit0.configs.config_class import Commit0Config | ||
from commit0.harness.constants import COMMANDS, SPLIT | ||
from omegaconf import OmegaConf | ||
from commit0.cli import app as commit0_app | ||
|
||
|
||
def main() -> None: | ||
command = sys.argv[1] | ||
if command not in COMMANDS: | ||
raise ValueError( | ||
f"command must be from {', '.join(COMMANDS)}, but you provided {command}" | ||
) | ||
# type check config values | ||
cs = ConfigStore.instance() | ||
cs.store(name="user", group="Commit0Config", node=Commit0Config) | ||
# have hydra to ignore all command-line arguments | ||
sys_argv = copy.deepcopy(sys.argv) | ||
cfg_arg = next((arg for arg in sys_argv if arg.startswith("--cfg=")), None) | ||
|
||
hydra.initialize(version_base=None, config_path="configs") | ||
config = hydra.compose(config_name="user") | ||
|
||
if cfg_arg: | ||
sys_argv.remove(cfg_arg) | ||
config_name = cfg_arg.split("=")[1] | ||
user_config = OmegaConf.load(config_name) | ||
config = OmegaConf.merge(config, user_config) | ||
|
||
# after hydra gets all configs, put command-line arguments back | ||
sys.argv = sys_argv | ||
# repo_split: split from command line has a higher priority than split in hydra | ||
if command in [ | ||
"clone", | ||
"build", | ||
"evaluate", | ||
"evaluate-reference", | ||
"save", | ||
]: | ||
if len(sys.argv) >= 3: | ||
if sys.argv[2] not in SPLIT: | ||
raise ValueError( | ||
f"repo split must be from {', '.join(SPLIT.keys())}, but you provided {sys.argv[2]}" | ||
) | ||
config.repo_split = sys.argv[2] | ||
config.base_dir = os.path.abspath(config.base_dir) | ||
|
||
if command == "clone": | ||
if len(sys.argv) != 3: | ||
raise ValueError( | ||
"You provided an incorrect number of arguments.\nUsage: commit0 clone {repo_split}" | ||
) | ||
commit0.harness.setup.main( | ||
config.dataset_name, | ||
config.dataset_split, | ||
config.repo_split, | ||
config.base_dir, | ||
) | ||
elif command == "build": | ||
if len(sys.argv) != 3: | ||
raise ValueError( | ||
"You provided an incorrect number of arguments.\nUsage: commit0 build {repo_split}" | ||
) | ||
commit0.harness.build.main( | ||
config.dataset_name, | ||
config.dataset_split, | ||
config.repo_split, | ||
config.num_workers, | ||
config.backend, | ||
) | ||
elif command == "get-tests": | ||
if len(sys.argv) != 3: | ||
raise ValueError( | ||
"You provided an incorrect number of arguments.\nUsage: commit0 get-tests {repo_name}" | ||
) | ||
repo = sys.argv[2] | ||
commit0.harness.get_pytest_ids.main(repo, stdout=True) | ||
elif command == "test" or command == "test-reference": | ||
# this command assume execution in arbitrary working directory | ||
repo_or_repo_path = sys.argv[2] | ||
if command == "test-reference": | ||
if len(sys.argv) != 4: | ||
raise ValueError( | ||
"You provided an incorrect number of arguments.\nUsage: commit0 test-reference {repo_dir} {test_ids}" | ||
) | ||
branch = "reference" | ||
test_ids = sys.argv[3] | ||
else: | ||
if len(sys.argv) != 5: | ||
raise ValueError( | ||
"You provided an incorrect number of arguments.\nUsage: commit0 test {repo_dir} {branch} {test_ids}" | ||
) | ||
branch = sys.argv[3] | ||
test_ids = sys.argv[4] | ||
if branch.startswith("branch="): | ||
branch = branch[len("branch=") :] | ||
commit0.harness.run_pytest_ids.main( | ||
config.dataset_name, | ||
config.dataset_split, | ||
config.base_dir, | ||
repo_or_repo_path, | ||
branch, | ||
test_ids, | ||
config.backend, | ||
config.timeout, | ||
config.num_cpus, | ||
stdout=True, | ||
) | ||
elif command == "evaluate" or command == "evaluate-reference": | ||
if command == "evaluate-reference": | ||
if len(sys.argv) != 3: | ||
raise ValueError( | ||
"You provided an incorrect number of arguments.\nUsage: commit0 evaluate-reference {repo_split}" | ||
) | ||
branch = "reference" | ||
else: | ||
if len(sys.argv) != 4: | ||
raise ValueError( | ||
"You provided an incorrect number of arguments.\nUsage: commit0 evaluate {repo_split} {branch}" | ||
) | ||
branch = sys.argv[3] | ||
if branch.startswith("branch="): | ||
branch = branch[len("branch=") :] | ||
commit0.harness.evaluate.main( | ||
config.dataset_name, | ||
config.dataset_split, | ||
config.repo_split, | ||
config.base_dir, | ||
branch, | ||
config.backend, | ||
config.timeout, | ||
config.num_cpus, | ||
config.num_workers, | ||
) | ||
elif command == "lint": | ||
files = sys.argv[1:] | ||
commit0.harness.lint.main(config.base_dir, files) | ||
elif command == "save": | ||
if len(sys.argv) != 5: | ||
raise ValueError( | ||
"You provided an incorrect number of arguments.\nUsage: commit0 save {repo_split} {owner} {branch}" | ||
) | ||
owner = sys.argv[3] | ||
branch = sys.argv[4] | ||
if branch.startswith("branch="): | ||
branch = branch[len("branch=") :] | ||
commit0.harness.save.main( | ||
config.dataset_name, | ||
config.dataset_split, | ||
config.repo_split, | ||
config.base_dir, | ||
owner, | ||
branch, | ||
config.github_token, | ||
) | ||
"""Main function to run the CLI""" | ||
commit0_app() | ||
|
||
|
||
if __name__ == "__main__": | ||
main() | ||
|
||
__all__ = [] |
Oops, something went wrong.