Skip to content

Commit

Permalink
feat: allow using an authenticated Client in Gateway (#974)
Browse files Browse the repository at this point in the history
  • Loading branch information
calebzulawski authored Dec 10, 2024
1 parent c74c657 commit 395116a
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
5 changes: 5 additions & 0 deletions py-rattler/rattler/repo_data/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from rattler.channel import Channel
from rattler.match_spec import MatchSpec
from rattler.networking import Client
from rattler.repo_data.record import RepoDataRecord
from rattler.platform import Platform, PlatformLiteral
from rattler.package.package_name import PackageName
Expand Down Expand Up @@ -90,6 +91,7 @@ def __init__(
default_config: Optional[SourceConfig] = None,
per_channel_config: Optional[dict[str, SourceConfig]] = None,
max_concurrent_requests: int = 100,
client: Optional[Client] = None,
) -> None:
"""
Arguments:
Expand All @@ -100,6 +102,8 @@ def __init__(
prefix, so any channel that starts with the URL uses the configuration.
The configuration with the longest matching prefix is used.
max_concurrent_requests: The maximum number of concurrent requests that can be made.
client: An authenticated client to use for acquiring repodata. If not specified a default
client will be used.
Examples
--------
Expand All @@ -119,6 +123,7 @@ def __init__(
for channel, config in (per_channel_config or {}).items()
},
max_concurrent_requests=max_concurrent_requests,
client=client._client if client is not None else None,
)

async def query(
Expand Down
8 changes: 7 additions & 1 deletion py-rattler/src/repo_data/gateway.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::error::PyRattlerError;
use crate::match_spec::PyMatchSpec;
use crate::networking::client::PyClientWithMiddleware;
use crate::package_name::PyPackageName;
use crate::platform::PyPlatform;
use crate::record::PyRecord;
Expand Down Expand Up @@ -50,13 +51,14 @@ impl<'source> FromPyObject<'source> for Wrap<SubdirSelection> {
#[pymethods]
impl PyGateway {
#[new]
#[pyo3(signature = (max_concurrent_requests, default_config, per_channel_config, cache_dir=None)
#[pyo3(signature = (max_concurrent_requests, default_config, per_channel_config, cache_dir=None, client=None)
)]
pub fn new(
max_concurrent_requests: usize,
default_config: PySourceConfig,
per_channel_config: HashMap<String, PySourceConfig>,
cache_dir: Option<PathBuf>,
client: Option<PyClientWithMiddleware>,
) -> PyResult<Self> {
let channel_config = ChannelConfig {
default: default_config.into(),
Expand All @@ -77,6 +79,10 @@ impl PyGateway {
gateway.set_cache_dir(cache_dir);
}

if let Some(client) = client {
gateway.set_client(client.into());
}

Ok(Self {
inner: gateway.finish(),
})
Expand Down

0 comments on commit 395116a

Please sign in to comment.