diff --git a/generate_self_schema.py b/generate_self_schema.py index acecf19d8..ac12062f3 100644 --- a/generate_self_schema.py +++ b/generate_self_schema.py @@ -10,10 +10,11 @@ import decimal import importlib.util import re +import sys from collections.abc import Callable from datetime import date, datetime, time, timedelta from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, ForwardRef, List, Set, Type, Union +from typing import TYPE_CHECKING, Any, Dict, ForwardRef, List, Pattern, Set, Type, Union from typing_extensions import TypedDict, get_args, get_origin, is_typeddict @@ -46,7 +47,7 @@ schema_ref_validator = {'type': 'definition-ref', 'schema_ref': 'root-schema'} -def get_schema(obj: Any, definitions: dict[str, core_schema.CoreSchema]) -> core_schema.CoreSchema: +def get_schema(obj: Any, definitions: dict[str, core_schema.CoreSchema]) -> core_schema.CoreSchema: # noqa: C901 if isinstance(obj, str): return {'type': obj} elif obj in (datetime, timedelta, date, time, bool, int, float, str, decimal.Decimal): @@ -81,6 +82,9 @@ def get_schema(obj: Any, definitions: dict[str, core_schema.CoreSchema]) -> core elif issubclass(origin, Type): # can't really use 'is-instance' since this is used for the class_ parameter of 'is-instance' validators return {'type': 'any'} + elif origin in (Pattern, re.Pattern): + # can't really use 'is-instance' easily with Pattern, so we use `any` as a placeholder for now + return {'type': 'any'} else: # debug(obj) raise TypeError(f'Unknown type: {obj!r}') @@ -189,16 +193,12 @@ def all_literal_values(type_: type[core_schema.Literal]) -> list[any]: def eval_forward_ref(type_: Any) -> Any: - try: - try: - # Python 3.12+ - return type_._evaluate(core_schema.__dict__, None, type_params=set(), recursive_guard=set()) - except TypeError: - # Python 3.9+ - return type_._evaluate(core_schema.__dict__, None, set()) - except TypeError: - # for Python 3.8 + if sys.version_info < (3, 9): return type_._evaluate(core_schema.__dict__, None) + elif sys.version_info < (3, 12, 4): + return type_._evaluate(core_schema.__dict__, None, recursive_guard=set()) + else: + return type_._evaluate(core_schema.__dict__, None, type_params=set(), recursive_guard=set()) def main() -> None: diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index 13dc26317..6644c7fc8 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -10,7 +10,7 @@ from collections.abc import Mapping from datetime import date, datetime, time, timedelta from decimal import Decimal -from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, List, Set, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, List, Pattern, Set, Tuple, Type, Union from typing_extensions import deprecated @@ -744,7 +744,7 @@ def decimal_schema( class StringSchema(TypedDict, total=False): type: Required[Literal['str']] - pattern: str + pattern: Union[str, Pattern[str]] max_length: int min_length: int strip_whitespace: bool @@ -760,7 +760,7 @@ class StringSchema(TypedDict, total=False): def str_schema( *, - pattern: str | None = None, + pattern: str | Pattern[str] | None = None, max_length: int | None = None, min_length: int | None = None, strip_whitespace: bool | None = None, diff --git a/src/validators/string.rs b/src/validators/string.rs index a7becf6bb..672bc64f5 100644 --- a/src/validators/string.rs +++ b/src/validators/string.rs @@ -164,7 +164,7 @@ impl StrConstrainedValidator { .map(|s| s.to_str()) .transpose()? .unwrap_or(RegexEngine::RUST_REGEX); - Pattern::compile(py, s, regex_engine) + Pattern::compile(s, regex_engine) }) .transpose()?; let min_length: Option = @@ -230,18 +230,47 @@ impl RegexEngine { } impl Pattern { - fn compile(py: Python<'_>, pattern: String, engine: &str) -> PyResult { - let engine = match engine { - RegexEngine::RUST_REGEX => { - RegexEngine::RustRegex(Regex::new(&pattern).map_err(|e| py_schema_error_type!("{}", e))?) - } - RegexEngine::PYTHON_RE => { - let re_compile = py.import_bound(intern!(py, "re"))?.getattr(intern!(py, "compile"))?; - RegexEngine::PythonRe(re_compile.call1((&pattern,))?.into()) - } - _ => return Err(py_schema_error_type!("Invalid regex engine: {}", engine)), - }; - Ok(Self { pattern, engine }) + fn extract_pattern_str(pattern: &Bound<'_, PyAny>) -> PyResult { + if pattern.is_instance_of::() { + Ok(pattern.to_string()) + } else { + pattern + .getattr("pattern") + .and_then(|attr| attr.extract::()) + .map_err(|_| py_schema_error_type!("Invalid pattern, must be str or re.Pattern: {}", pattern)) + } + } + + fn compile(pattern: Bound<'_, PyAny>, engine: &str) -> PyResult { + let pattern_str = Self::extract_pattern_str(&pattern)?; + + let py = pattern.py(); + + let re_module = py.import_bound(intern!(py, "re"))?; + let re_compile = re_module.getattr(intern!(py, "compile"))?; + let re_pattern = re_module.getattr(intern!(py, "Pattern"))?; + + if pattern.is_instance(&re_pattern)? { + // if the pattern is already a compiled regex object, we default to using the python re engine + // so that any flags, etc. are preserved + Ok(Self { + pattern: pattern_str, + engine: RegexEngine::PythonRe(pattern.to_object(py)), + }) + } else { + let engine = match engine { + RegexEngine::RUST_REGEX => { + RegexEngine::RustRegex(Regex::new(&pattern_str).map_err(|e| py_schema_error_type!("{}", e))?) + } + RegexEngine::PYTHON_RE => RegexEngine::PythonRe(re_compile.call1((pattern,))?.into()), + _ => return Err(py_schema_error_type!("Invalid regex engine: {}", engine)), + }; + + Ok(Self { + pattern: pattern_str, + engine, + }) + } } fn is_match(&self, py: Python<'_>, target: &str) -> PyResult { diff --git a/tests/validators/test_string.py b/tests/validators/test_string.py index f4f5f3888..75edcc1b0 100644 --- a/tests/validators/test_string.py +++ b/tests/validators/test_string.py @@ -398,3 +398,10 @@ def test_coerce_numbers_to_str_schema_with_strict_mode(number: int): v.validate_python(number) with pytest.raises(ValidationError): v.validate_json(str(number)) + + +@pytest.mark.parametrize('engine', [None, 'rust-regex', 'python-re']) +def test_compiled_regex(engine) -> None: + v = SchemaValidator(core_schema.str_schema(pattern=re.compile('abc', re.IGNORECASE), regex_engine=engine)) + assert v.validate_python('abc') == 'abc' + assert v.validate_python('ABC') == 'ABC'