diff --git a/tests/main_test.py b/tests/main_test.py
index bf6f7a79..c4050f7d 100644
--- a/tests/main_test.py
+++ b/tests/main_test.py
@@ -8,6 +8,7 @@
 from collections.abc import Sequence
 from pathlib import Path
 from typing import Any
+from unittest.mock import Mock

 import pytest
 import pytest_mock
@@ -449,6 +450,34 @@ def test_cache_timeouts(tmp_path: Path) -> None:
         tldextract.suffix_list.find_first_response(cache, [server], 5)


+@responses.activate
+def test_find_first_response_without_session(tmp_path: Path) -> None:
+    """Test it is able to find the first response without a session passed in."""
+    server = "http://some-server.com"
+    response_text = "server response"
+    responses.add(responses.GET, server, status=200, body=response_text)
+    cache = DiskCache(str(tmp_path))
+
+    result = tldextract.suffix_list.find_first_response(cache, [server], 5)
+    assert result == response_text
+
+
+def test_find_first_response_with_session(tmp_path: Path) -> None:
+    """Test it is able to find the first response with a passed-in session."""
+    server = "http://some-server.com"
+    response_text = "server response"
+    cache = DiskCache(str(tmp_path))
+    mock_session = Mock()
+    mock_session.get.return_value.text = response_text
+
+    result = tldextract.suffix_list.find_first_response(
+        cache, [server], 5, mock_session
+    )
+    assert result == response_text
+    mock_session.get.assert_called_once_with(server, timeout=5)
+    mock_session.close.assert_not_called()
+
+
 def test_include_psl_private_domain_attr() -> None:
     """Test private domains, which default to not being treated differently."""
     extract_private = tldextract.TLDExtract(include_psl_private_domains=True)
diff --git a/tldextract/suffix_list.py b/tldextract/suffix_list.py
index 62427367..192f6333 100644
--- a/tldextract/suffix_list.py
+++ b/tldextract/suffix_list.py
@@ -31,11 +31,16 @@ def find_first_response(
     cache: DiskCache,
     urls: Sequence[str],
     cache_fetch_timeout: float | int | None = None,
+    session: requests.Session | None = None,
 ) -> str:
     """Decode the first successfully fetched URL, from UTF-8 encoding to Python unicode."""
-    with requests.Session() as session:
+    session_created = False
+    if session is None:
+        session = requests.Session()
         session.mount("file://", FileAdapter())
+        session_created = True
+
+    try:
         for url in urls:
             try:
                 return cache.cached_fetch_url(
                     session=session, url=url, timeout=cache_fetch_timeout
                 )
             except requests.exceptions.RequestException:
                 LOG.exception("Exception reading Public Suffix List url %s", url)
+    finally:
+        # Ensure the session is always closed if it was constructed in this function
+        if session_created:
+            session.close()
+
     raise SuffixListNotFound(
         "No remote Public Suffix List found. Consider using a mirror, or avoid this"
         " fetch by constructing your TLDExtract with `suffix_list_urls=()`."
@@ -65,6 +75,7 @@ def get_suffix_lists(
     urls: Sequence[str],
     cache_fetch_timeout: float | int | None,
     fallback_to_snapshot: bool,
+    session: requests.Session | None = None,
 ) -> tuple[list[str], list[str]]:
     """Fetch, parse, and cache the suffix lists."""
     return cache.run_and_cache(
@@ -75,6 +86,7 @@
             "urls": urls,
             "cache_fetch_timeout": cache_fetch_timeout,
             "fallback_to_snapshot": fallback_to_snapshot,
+            "session": session,
         },
         hashed_argnames=["urls", "fallback_to_snapshot"],
     )
@@ -85,10 +97,13 @@ def _get_suffix_lists(
     urls: Sequence[str],
     cache_fetch_timeout: float | int | None,
     fallback_to_snapshot: bool,
+    session: requests.Session | None = None,
 ) -> tuple[list[str], list[str]]:
     """Fetch, parse, and cache the suffix lists."""
     try:
-        text = find_first_response(cache, urls, cache_fetch_timeout=cache_fetch_timeout)
+        text = find_first_response(
+            cache, urls, cache_fetch_timeout=cache_fetch_timeout, session=session
+        )
     except SuffixListNotFound as exc:
         if fallback_to_snapshot:
             maybe_pkg_data = pkgutil.get_data("tldextract", ".tld_set_snapshot")
diff --git a/tldextract/tldextract.py b/tldextract/tldextract.py
index 95a7acd0..902cae69 100644
--- a/tldextract/tldextract.py
+++ b/tldextract/tldextract.py
@@ -44,6 +44,7 @@
 from functools import wraps

 import idna
+import requests

 from .cache import DiskCache, get_cache_dir
 from .remote import lenient_netloc, looks_like_ip, looks_like_ipv6
@@ -221,13 +222,19 @@ def __init__(
         self._cache = DiskCache(cache_dir)

     def __call__(
-        self, url: str, include_psl_private_domains: bool | None = None
+        self,
+        url: str,
+        include_psl_private_domains: bool | None = None,
+        session: requests.Session | None = None,
     ) -> ExtractResult:
         """Alias for `extract_str`."""
-        return self.extract_str(url, include_psl_private_domains)
+        return self.extract_str(url, include_psl_private_domains, session=session)

     def extract_str(
-        self, url: str, include_psl_private_domains: bool | None = None
+        self,
+        url: str,
+        include_psl_private_domains: bool | None = None,
+        session: requests.Session | None = None,
     ) -> ExtractResult:
         """Take a string URL and splits it into its subdomain, domain, and suffix components.
@@ -238,13 +245,27 @@ def extract_str(
         ExtractResult(subdomain='forums.news', domain='cnn', suffix='com', is_private=False)
         >>> extractor.extract_str('http://forums.bbc.co.uk/')
         ExtractResult(subdomain='forums', domain='bbc', suffix='co.uk', is_private=False)
+
+        Allows configuring the HTTP request via the optional `session`
+        parameter, for example if you need to use an HTTP proxy. See also
+        `requests.Session`.
+
+        >>> import requests
+        >>> session = requests.Session()
+        >>> # customize your session here
+        >>> with session:
+        ...     extractor.extract_str("http://forums.news.cnn.com/", session=session)
+        ExtractResult(subdomain='forums.news', domain='cnn', suffix='com', is_private=False)
         """
-        return self._extract_netloc(lenient_netloc(url), include_psl_private_domains)
+        return self._extract_netloc(
+            lenient_netloc(url), include_psl_private_domains, session=session
+        )

     def extract_urllib(
         self,
         url: urllib.parse.ParseResult | urllib.parse.SplitResult,
         include_psl_private_domains: bool | None = None,
+        session: requests.Session | None = None,
     ) -> ExtractResult:
         """Take the output of urllib.parse URL parsing methods and further splits the parsed URL.
@@ -260,10 +281,15 @@ def extract_urllib(
         >>> extractor.extract_urllib(urllib.parse.urlsplit('http://forums.bbc.co.uk/'))
         ExtractResult(subdomain='forums', domain='bbc', suffix='co.uk', is_private=False)
         """
-        return self._extract_netloc(url.netloc, include_psl_private_domains)
+        return self._extract_netloc(
+            url.netloc, include_psl_private_domains, session=session
+        )

     def _extract_netloc(
-        self, netloc: str, include_psl_private_domains: bool | None
+        self,
+        netloc: str,
+        include_psl_private_domains: bool | None,
+        session: requests.Session | None = None,
     ) -> ExtractResult:
         netloc_with_ascii_dots = (
             netloc.replace("\u3002", "\u002e")
@@ -282,9 +308,9 @@ def _extract_netloc(

         labels = netloc_with_ascii_dots.split(".")

-        suffix_index, is_private = self._get_tld_extractor().suffix_index(
-            labels, include_psl_private_domains=include_psl_private_domains
-        )
+        suffix_index, is_private = self._get_tld_extractor(
+            session=session
+        ).suffix_index(labels, include_psl_private_domains=include_psl_private_domains)

         num_ipv4_labels = 4
         if suffix_index == len(labels) == num_ipv4_labels and looks_like_ip(
@@ -297,23 +323,27 @@ def _extract_netloc(
         domain = labels[suffix_index - 1] if suffix_index else ""
         return ExtractResult(subdomain, domain, suffix, is_private)

-    def update(self, fetch_now: bool = False) -> None:
+    def update(
+        self, fetch_now: bool = False, session: requests.Session | None = None
+    ) -> None:
         """Force fetch the latest suffix list definitions."""
         self._extractor = None
         self._cache.clear()
         if fetch_now:
-            self._get_tld_extractor()
+            self._get_tld_extractor(session=session)

     @property
-    def tlds(self) -> list[str]:
+    def tlds(self, session: requests.Session | None = None) -> list[str]:
         """
         Returns the list of tld's used by default.

         This will vary based on `include_psl_private_domains` and `extra_suffixes`
         """
-        return list(self._get_tld_extractor().tlds())
+        return list(self._get_tld_extractor(session=session).tlds())

-    def _get_tld_extractor(self) -> _PublicSuffixListTLDExtractor:
+    def _get_tld_extractor(
+        self, session: requests.Session | None = None
+    ) -> _PublicSuffixListTLDExtractor:
         """Get or compute this object's TLDExtractor.

         Looks up the TLDExtractor in roughly the following order, based on the
@@ -332,6 +362,7 @@ def _get_tld_extractor(
             urls=self.suffix_list_urls,
             cache_fetch_timeout=self.cache_fetch_timeout,
             fallback_to_snapshot=self.fallback_to_snapshot,
+            session=session,
         )

         if not any([public_tlds, private_tlds, self.extra_suffixes]):
@@ -400,9 +431,13 @@ def add_suffix(self, suffix: str, is_private: bool = False) -> None:

 @wraps(TLD_EXTRACTOR.__call__)
 def extract(  # noqa: D103
-    url: str, include_psl_private_domains: bool | None = False
+    url: str,
+    include_psl_private_domains: bool | None = False,
+    session: requests.Session | None = None,
 ) -> ExtractResult:
-    return TLD_EXTRACTOR(url, include_psl_private_domains=include_psl_private_domains)
+    return TLD_EXTRACTOR(
+        url, include_psl_private_domains=include_psl_private_domains, session=session
+    )


 @wraps(TLD_EXTRACTOR.update)
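Taken together, these changes thread an optional caller-supplied `requests.Session` from the public entry points (`extract`, `TLDExtract.__call__`, `extract_str`, `extract_urllib`, `update`) down to `find_first_response`, which now closes only sessions it created itself. A minimal usage sketch of the new parameter, assuming the diff above is applied; the proxy address is a placeholder, not part of the change:

```python
import requests

import tldextract

# Configure the session however you need; here it routes the one-time
# Public Suffix List fetch through a (hypothetical) HTTP proxy.
session = requests.Session()
session.proxies = {"https": "http://proxy.example.com:8080"}  # placeholder

extractor = tldextract.TLDExtract()
with session:
    # The caller owns the session: tldextract reuses it for each suffix-list
    # URL it tries and does not close it.
    result = extractor.extract_str("http://forums.news.cnn.com/", session=session)

print(result.registered_domain)  # cnn.com
```

Note one asymmetry in the patch: `find_first_response` mounts the `file://` adapter only on sessions it constructs itself, so callers who pass their own session and rely on `file://` suffix-list URLs must mount `FileAdapter` on it themselves.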