From 395116a534a4979eec20509ceee91cad661ca0f1 Mon Sep 17 00:00:00 2001 From: Caleb Zulawski Date: Tue, 10 Dec 2024 10:50:12 -0500 Subject: [PATCH] feat: allow using an authenticated `Client` in `Gateway` (#974) --- py-rattler/rattler/repo_data/gateway.py | 5 +++++ py-rattler/src/repo_data/gateway.rs | 8 +++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/py-rattler/rattler/repo_data/gateway.py b/py-rattler/rattler/repo_data/gateway.py index cee4612ec..eeec05885 100644 --- a/py-rattler/rattler/repo_data/gateway.py +++ b/py-rattler/rattler/repo_data/gateway.py @@ -8,6 +8,7 @@ from rattler.channel import Channel from rattler.match_spec import MatchSpec +from rattler.networking import Client from rattler.repo_data.record import RepoDataRecord from rattler.platform import Platform, PlatformLiteral from rattler.package.package_name import PackageName @@ -90,6 +91,7 @@ def __init__( default_config: Optional[SourceConfig] = None, per_channel_config: Optional[dict[str, SourceConfig]] = None, max_concurrent_requests: int = 100, + client: Optional[Client] = None, ) -> None: """ Arguments: @@ -100,6 +102,8 @@ def __init__( prefix, so any channel that starts with the URL uses the configuration. The configuration with the longest matching prefix is used. max_concurrent_requests: The maximum number of concurrent requests that can be made. + client: An authenticated client to use for acquiring repodata. If not specified a default + client will be used. Examples -------- @@ -119,6 +123,7 @@ def __init__( for channel, config in (per_channel_config or {}).items() }, max_concurrent_requests=max_concurrent_requests, + client=client._client if client is not None else None, ) async def query( diff --git a/py-rattler/src/repo_data/gateway.rs b/py-rattler/src/repo_data/gateway.rs index 97aee99c0..36a5a5a2d 100644 --- a/py-rattler/src/repo_data/gateway.rs +++ b/py-rattler/src/repo_data/gateway.rs @@ -1,5 +1,6 @@ use crate::error::PyRattlerError; use crate::match_spec::PyMatchSpec; +use crate::networking::client::PyClientWithMiddleware; use crate::package_name::PyPackageName; use crate::platform::PyPlatform; use crate::record::PyRecord; @@ -50,13 +51,14 @@ impl<'source> FromPyObject<'source> for Wrap { #[pymethods] impl PyGateway { #[new] - #[pyo3(signature = (max_concurrent_requests, default_config, per_channel_config, cache_dir=None) + #[pyo3(signature = (max_concurrent_requests, default_config, per_channel_config, cache_dir=None, client=None) )] pub fn new( max_concurrent_requests: usize, default_config: PySourceConfig, per_channel_config: HashMap, cache_dir: Option, + client: Option, ) -> PyResult { let channel_config = ChannelConfig { default: default_config.into(), @@ -77,6 +79,10 @@ impl PyGateway { gateway.set_cache_dir(cache_dir); } + if let Some(client) = client { + gateway.set_client(client.into()); + } + Ok(Self { inner: gateway.finish(), })