diff --git a/plugin/utils.py b/plugin/utils.py index 8cc27f5..762f53b 100644 --- a/plugin/utils.py +++ b/plugin/utils.py @@ -1,9 +1,50 @@ from __future__ import annotations +import io import os -import re import subprocess -from typing import Any +import sys +from collections.abc import Generator, Iterable +from typing import Any, TypeVar + +_T = TypeVar("_T") + + +def camel_to_snake(s: str) -> str: + """Converts "CamelCase" to "snake_case".""" + return "".join((f"_{c}" if c.isupper() else c) for c in s).strip("_").lower() + + +def snake_to_camel(s: str, *, upper_first: bool = True) -> str: + """Converts "snake_case" to "CamelCase".""" + first, *others = s.split("_") + return (first.title() if upper_first else first.lower()) + "".join(map(str.title, others)) + + +if sys.version_info >= (3, 9): + remove_prefix = str.removeprefix + remove_suffix = str.removesuffix +else: + + def remove_prefix(s: str, prefix: str) -> str: + """Remove the prefix from the string. I.e., str.removeprefix in Python 3.9.""" + return s[len(prefix) :] if s.startswith(prefix) else s + + def remove_suffix(s: str, suffix: str) -> str: + """Remove the suffix from the string. I.e., str.removesuffix in Python 3.9.""" + # suffix="" should not call s[:-0] + return s[: -len(suffix)] if suffix and s.endswith(suffix) else s + + +def drop_falsy(iterable: Iterable[_T | None]) -> Generator[_T, None, None]: + """Drops falsy values from the iterable.""" + yield from filter(None, iterable) + + +def iterate_by_line(s: str) -> Generator[str, None, None]: + """Iterates over lines of the string.""" + with io.StringIO(s) as f: + yield from f def get_default_startupinfo() -> Any: @@ -14,7 +55,3 @@ def get_default_startupinfo() -> Any: STARTUPINFO.wShowWindow = subprocess.SW_HIDE # type: ignore return STARTUPINFO return None - - -def lowercase_drive_letter(path: str) -> str: - return re.sub(r"^[A-Z]+(?=:\\)", lambda m: m.group(0).lower(), path) diff --git a/plugin/venv_finder.py b/plugin/venv_finder.py index bbf2dbd..a61fdf7 100644 --- a/plugin/venv_finder.py +++ b/plugin/venv_finder.py @@ -2,7 +2,6 @@ import configparser import os -import re import shutil import subprocess from abc import ABC, abstractmethod @@ -16,6 +15,7 @@ from typing_extensions import Self from .log import log_error +from .utils import camel_to_snake, get_default_startupinfo, iterate_by_line, remove_suffix def find_venv_by_finder_names(finder_names: Sequence[str], *, project_dir: Path) -> VenvInfo | None: @@ -152,32 +152,15 @@ def from_pyvenv_cfg_file(cls, pyvenv_cfg_file: str | Path) -> Self | None: return cls.from_venv_dir(venv_dir) @staticmethod - def parse_pyvenv_cfg(pyvenv_cfg: Path) -> dict[str, Any]: + def parse_pyvenv_cfg(pyvenv_cfg: Path) -> dict[str, str]: # value of these keys are expected to be a string - str_attr = {"command", "executable", "home", "implementation", "prompt", "uv", "version", "version_info"} - - def _cast(key: str, val: str) -> Any: - if key in str_attr: - return val - if val.lower() == "true": - return True - if val.lower() == "false": - return False - if val.isdigit(): - return int(val) - try: - return float(val) - except ValueError: - pass - return val - config = configparser.ConfigParser() try: content = pyvenv_cfg.read_text(encoding="utf-8") config.read_string(f"[USER]\n{content}") except Exception: return {} - return {k: _cast(k, v) for k, v in config.items("USER")} + return dict(config.items("USER")) class BaseVenvFinder(ABC): @@ -188,12 +171,7 @@ def __init__(self, project_dir: Path) -> None: @final @classmethod def name(cls) -> str: - name = cls.__name__ - # remove trailing "VenvFinder" - if name.endswith("VenvFinder"): - name = name[: -len("VenvFinder")] - # CamelCase to snake_case - return "".join(f"_{c.lower()}" if c.isupper() else c for c in name).lstrip("_") + return camel_to_snake(remove_suffix(cls.__name__, "VenvFinder")) @final @classmethod @@ -227,12 +205,12 @@ def _find_venv(self) -> VenvInfo | None: """Find the virtual environment. Implement this method by the subclass.""" @staticmethod - def _find_from_venv_dir_candidates(candidates: Iterable[Path]) -> VenvInfo | None: + def _find_from_venv_dirs(venv_dirs: Iterable[Path]) -> VenvInfo | None: def _filtered_candidates() -> Generator[Path, None, None]: - for candidate in candidates: + for venv_dir in venv_dirs: try: - if candidate.is_dir(): - yield candidate + if venv_dir.is_dir(): + yield venv_dir except PermissionError: pass @@ -240,20 +218,12 @@ def _filtered_candidates() -> Generator[Path, None, None]: @staticmethod def _run_shell_command(command: str, *, cwd: Path | None = None) -> tuple[str, str, int] | None: - if os.name == "nt": - # do not create a window for the process - startupinfo = subprocess.STARTUPINFO() # type: ignore - startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW # type: ignore - startupinfo.wShowWindow = subprocess.SW_HIDE # type: ignore - else: - startupinfo = None - try: proc = subprocess.Popen( command, cwd=cwd, shell=True, - startupinfo=startupinfo, + startupinfo=get_default_startupinfo(), stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True, @@ -277,7 +247,7 @@ def _can_support(cls, project_dir: Path) -> bool: return True def _find_venv(self) -> VenvInfo | None: - return self._find_from_venv_dir_candidates(self.project_dir.iterdir()) + return self._find_from_venv_dirs(self.project_dir.iterdir()) class EnvVarCondaPrefixVenvFinder(BaseVenvFinder): @@ -289,12 +259,10 @@ class EnvVarCondaPrefixVenvFinder(BaseVenvFinder): @classmethod def _can_support(cls, project_dir: Path) -> bool: - return True + return "CONDA_PREFIX" in os.environ def _find_venv(self) -> VenvInfo | None: - if conda_prefix := os.environ.get("CONDA_PREFIX", ""): - return VenvInfo.from_venv_dir(conda_prefix) - return None + return VenvInfo.from_venv_dir(os.environ["CONDA_PREFIX"]) class EnvVarVirtualEnvVenvFinder(BaseVenvFinder): @@ -306,12 +274,10 @@ class EnvVarVirtualEnvVenvFinder(BaseVenvFinder): @classmethod def _can_support(cls, project_dir: Path) -> bool: - return True + return "VIRTUAL_ENV" in os.environ def _find_venv(self) -> VenvInfo | None: - if virtual_env := os.environ.get("VIRTUAL_ENV", ""): - return VenvInfo.from_venv_dir(virtual_env) - return None + return VenvInfo.from_venv_dir(os.environ["VIRTUAL_ENV"]) class LocalDotVenvVenvFinder(BaseVenvFinder): @@ -326,7 +292,7 @@ def _can_support(cls, project_dir: Path) -> bool: return True def _find_venv(self) -> VenvInfo | None: - return self._find_from_venv_dir_candidates(( + return self._find_from_venv_dirs(( self.project_dir / ".venv", self.project_dir / "venv", )) @@ -452,7 +418,8 @@ def _find_venv(self) -> VenvInfo | None: return None stdout, _, _ = output - if m := re.search(r"^venv: (.*)$", stdout, re.MULTILINE): - venv_dir = m.group(1) - return VenvInfo.from_venv_dir(venv_dir) + for line in iterate_by_line(stdout): + pre, sep, post = line.partition(":") + if sep and pre == "venv": + return VenvInfo.from_venv_dir(post.strip()) return None