diff --git a/snakemake_interface_executor_plugins/utils.py b/snakemake_interface_executor_plugins/utils.py index 7749871..b72ac7c 100644 --- a/snakemake_interface_executor_plugins/utils.py +++ b/snakemake_interface_executor_plugins/utils.py @@ -6,6 +6,7 @@ import asyncio from collections import UserDict from pathlib import Path +import re import shlex import threading from typing import Any, List @@ -39,16 +40,25 @@ def format_cli_pos_arg(value, quote=True): elif not_iterable(value): return format_cli_value(value) else: - return join_cli_args(format_cli_value(v) for v in value) + return join_cli_args( + format_cli_value(v, quote_if_contains_whitespace=True) for v in value + ) -def format_cli_value(value: Any) -> str: +def format_cli_value(value: Any, quote_if_contains_whitespace: bool = False) -> str: if isinstance(value, SettingsEnumBase): return value.item_to_choice() elif isinstance(value, Path): return shlex.quote(str(value)) elif isinstance(value, str): - return shlex.quote(value) + if is_quoted(value): + # the value is already quoted, do not quote again + return value + elif quote_if_contains_whitespace and " " in value: + # may be expression + return repr(value) + else: + return value else: return repr(value) @@ -99,3 +109,10 @@ async def async_lock(_lock: threading.Lock): yield # the lock is held finally: _lock.release() + + +_is_quoted_re = re.compile(r"^['\"].+['\"]") + + +def is_quoted(value: str) -> bool: + return _is_quoted_re.match(value) is not None diff --git a/tests/tests.py b/tests/tests.py index 6cc3f02..3c3453b 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -3,6 +3,7 @@ from snakemake_interface_common.plugin_registry.tests import TestRegistryBase from snakemake_interface_common.plugin_registry.plugin import PluginBase, SettingsBase from snakemake_interface_common.plugin_registry import PluginRegistryBase +from snakemake_interface_executor_plugins.utils import format_cli_arg class TestRegistry(TestRegistryBase): @@ -26,3 +27,30 @@ def validate_settings(self, settings: SettingsBase, plugin: PluginBase): def get_example_args(self) -> List[str]: return ["--cluster-generic-submit-cmd", "qsub"] + + +def test_format_cli_arg_single_quote(): + fmt = format_cli_arg("--default-resources", {"slurm_extra": "'--gres=gpu:1'"}) + assert fmt == "--default-resources \"slurm_extra='--gres=gpu:1'\"" + + +def test_format_cli_arg_double_quote(): + fmt = format_cli_arg("--default-resources", {"slurm_extra": '"--gres=gpu:1"'}) + assert fmt == "--default-resources 'slurm_extra=\"--gres=gpu:1\"'" + + +def test_format_cli_arg_int(): + fmt = format_cli_arg("--default-resources", {"mem_mb": 200}) + assert fmt == "--default-resources 'mem_mb=200'" + + +def test_format_cli_arg_expr(): + fmt = format_cli_arg( + "--default-resources", {"mem_mb": "min(2 * input.size_mb, 2000)"} + ) + assert fmt == "--default-resources 'mem_mb=min(2 * input.size_mb, 2000)'" + + +def test_format_cli_arg_list(): + fmt = format_cli_arg("--config", ["foo={'bar': 1}"]) + assert fmt == "--config \"foo={'bar': 1}\""