Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix quoting of string arguments that are passed to spawned jobs #60

Merged
merged 7 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions snakemake_interface_executor_plugins/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import asyncio
from collections import UserDict
from pathlib import Path
import re
import shlex
import threading
from typing import Any, List
Expand Down Expand Up @@ -39,16 +40,25 @@ def format_cli_pos_arg(value, quote=True):
elif not_iterable(value):
return format_cli_value(value)
else:
return join_cli_args(format_cli_value(v) for v in value)
return join_cli_args(
format_cli_value(v, quote_if_contains_whitespace=True) for v in value
)


def format_cli_value(value: Any) -> str:
def format_cli_value(value: Any, quote_if_contains_whitespace: bool = False) -> str:
if isinstance(value, SettingsEnumBase):
return value.item_to_choice()
elif isinstance(value, Path):
return shlex.quote(str(value))
elif isinstance(value, str):
return shlex.quote(value)
if is_quoted(value):
# the value is already quoted, do not quote again
return value
elif quote_if_contains_whitespace and " " in value:
# may be expression
return repr(value)
else:
return value
else:
return repr(value)

Expand Down Expand Up @@ -99,3 +109,10 @@ async def async_lock(_lock: threading.Lock):
yield # the lock is held
finally:
_lock.release()


_is_quoted_re = re.compile(r"^['\"].+['\"]")


def is_quoted(value: str) -> bool:
return _is_quoted_re.match(value) is not None
28 changes: 28 additions & 0 deletions tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from snakemake_interface_common.plugin_registry.tests import TestRegistryBase
from snakemake_interface_common.plugin_registry.plugin import PluginBase, SettingsBase
from snakemake_interface_common.plugin_registry import PluginRegistryBase
from snakemake_interface_executor_plugins.utils import format_cli_arg


class TestRegistry(TestRegistryBase):
Expand All @@ -26,3 +27,30 @@ def validate_settings(self, settings: SettingsBase, plugin: PluginBase):

def get_example_args(self) -> List[str]:
return ["--cluster-generic-submit-cmd", "qsub"]


def test_format_cli_arg_single_quote():
fmt = format_cli_arg("--default-resources", {"slurm_extra": "'--gres=gpu:1'"})
assert fmt == "--default-resources \"slurm_extra='--gres=gpu:1'\""


def test_format_cli_arg_double_quote():
fmt = format_cli_arg("--default-resources", {"slurm_extra": '"--gres=gpu:1"'})
assert fmt == "--default-resources 'slurm_extra=\"--gres=gpu:1\"'"


def test_format_cli_arg_int():
fmt = format_cli_arg("--default-resources", {"mem_mb": 200})
assert fmt == "--default-resources 'mem_mb=200'"


def test_format_cli_arg_expr():
fmt = format_cli_arg(
"--default-resources", {"mem_mb": "min(2 * input.size_mb, 2000)"}
)
assert fmt == "--default-resources 'mem_mb=min(2 * input.size_mb, 2000)'"


def test_format_cli_arg_list():
fmt = format_cli_arg("--config", ["foo={'bar': 1}"])
assert fmt == "--config \"foo={'bar': 1}\""
Loading