Skip to content

Commit

Permalink
storage: fix auth compatibility for registry backend
Browse files Browse the repository at this point in the history
Signed-off-by: lihuahua123 <771725652@qq.com>
  • Loading branch information
lihuahua123 authored and imeoer committed Sep 27, 2023
1 parent aa9c95a commit d7b1851
Showing 1 changed file with 96 additions and 38 deletions.
134 changes: 96 additions & 38 deletions storage/src/backend/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,19 +93,22 @@ impl Cache {
}

#[derive(Default)]
struct HashCache(RwLock<HashMap<String, String>>);
struct HashCache<T>(RwLock<HashMap<String, T>>);

impl HashCache {
impl<T> HashCache<T> {
fn new() -> Self {
HashCache(RwLock::new(HashMap::new()))
}

fn get(&self, key: &str) -> Option<String> {
fn get(&self, key: &str) -> Option<T>
where
T: Clone,
{
let cached_guard = self.0.read().unwrap();
cached_guard.get(key).cloned()
}

fn set(&self, key: String, value: String) {
fn set(&self, key: String, value: T) {
let mut cached_guard = self.0.write().unwrap();
cached_guard.insert(key, value);
}
Expand Down Expand Up @@ -136,6 +139,7 @@ struct BasicAuth {
}

#[derive(Debug, Clone)]
#[allow(dead_code)]
struct BearerAuth {
realm: String,
service: String,
Expand Down Expand Up @@ -189,10 +193,14 @@ struct RegistryState {
// Example: RwLock<"Bearer <token>">
// RwLock<"Basic base64(<username:password>)">
cached_auth: Cache,
// Cache for the HTTP method when getting auth, it is "true" when using "GET" method.
// Due to the different implementations of various image registries, auth requests
// may use the GET or POST methods, we need to cache the method after the
// fallback, so it can be reused next time and reduce an unnecessary request.
cached_auth_using_http_get: HashCache<bool>,
// Cache 30X redirect url
// Example: RwLock<HashMap<"<blob_id>", "<redirected_url>">>
cached_redirect: HashCache,

cached_redirect: HashCache<String>,
// The epoch timestamp of token expiration, which is obtained from the registry server.
token_expired_at: ArcSwapOption<u64>,
// Cache bearer auth for refreshing token.
Expand Down Expand Up @@ -238,12 +246,85 @@ impl RegistryState {
}
}

/// Request registry authentication server to get bearer token
// Request registry authentication server to get bearer token
fn get_token(&self, auth: BearerAuth, connection: &Arc<Connection>) -> Result<TokenResponse> {
// The information needed for getting token needs to be placed both in
// the query and in the body to be compatible with different registry
// implementations, which have been tested on these platforms:
// docker hub, harbor, github ghcr, aliyun acr.
let http_get = self
.cached_auth_using_http_get
.get(&self.host)
.unwrap_or_default();
let resp = if http_get {
self.get_token_with_get(&auth, connection)?
} else {
match self.get_token_with_post(&auth, connection) {
Ok(resp) => resp,
Err(err) => {
warn!("retry http GET method to get auth token: {}", err);
let resp = self.get_token_with_get(&auth, connection)?;
// Cache http method for next use.
self.cached_auth_using_http_get.set(self.host.clone(), true);
resp
}
}
};

let ret: TokenResponse = resp.json().map_err(|e| {
einval!(format!(
"registry auth server response decode failed: {:?}",
e
))
})?;

if let Ok(now_timestamp) = SystemTime::now().duration_since(UNIX_EPOCH) {
self.token_expired_at
.store(Some(Arc::new(now_timestamp.as_secs() + ret.expires_in)));
debug!(
"cached bearer auth, next time: {}",
now_timestamp.as_secs() + ret.expires_in
);
}

// Cache bearer auth for refreshing token.
self.cached_bearer_auth.store(Some(Arc::new(auth)));

Ok(ret)
}

// Get bearer token using a POST request
fn get_token_with_post(
&self,
auth: &BearerAuth,
connection: &Arc<Connection>,
) -> Result<Response> {
let mut form = HashMap::new();
form.insert("service".to_string(), auth.service.clone());
form.insert("scope".to_string(), auth.scope.clone());
form.insert("grant_type".to_string(), "password".to_string());
form.insert("username".to_string(), self.username.clone());
form.insert("passward".to_string(), self.password.clone());
form.insert("client_id".to_string(), REGISTRY_CLIENT_ID.to_string());

let mut headers = HeaderMap::new();

let token_resp = connection
.call::<&[u8]>(
Method::POST,
auth.realm.as_str(),
None,
Some(ReqBody::Form(form)),
&mut headers,
true,
)
.map_err(|e| einval!(format!("registry auth server request failed {:?}", e)))?;

Ok(token_resp)
}

// Get bearer token using a GET request
fn get_token_with_get(
&self,
auth: &BearerAuth,
connection: &Arc<Connection>,
) -> Result<Response> {
let query = [
("service", auth.service.as_str()),
("scope", auth.scope.as_str()),
Expand All @@ -253,45 +334,20 @@ impl RegistryState {
("client_id", REGISTRY_CLIENT_ID),
];

let mut form = HashMap::new();
for (k, v) in &query {
form.insert(k.to_string(), v.to_string());
}

let mut headers = HeaderMap::new();
if let Some(auth_header) = &auth.header {
headers.insert(HEADER_AUTHORIZATION, auth_header.clone());
}

let token_resp = connection
.call::<&[u8]>(
Method::GET,
auth.realm.as_str(),
Some(&query),
Some(ReqBody::Form(form)),
None,
&mut headers,
true,
)
.map_err(|e| einval!(format!("registry auth server request failed {:?}", e)))?;
let ret: TokenResponse = token_resp.json().map_err(|e| {
einval!(format!(
"registry auth server response decode failed: {:?}",
e
))
})?;
if let Ok(now_timestamp) = SystemTime::now().duration_since(UNIX_EPOCH) {
self.token_expired_at
.store(Some(Arc::new(now_timestamp.as_secs() + ret.expires_in)));
debug!(
"cached bearer auth, next time: {}",
now_timestamp.as_secs() + ret.expires_in
);
}

// Cache bearer auth for refreshing token.
self.cached_bearer_auth.store(Some(Arc::new(auth)));

Ok(ret)
Ok(token_resp)
}

fn get_auth_header(&self, auth: Auth, connection: &Arc<Connection>) -> Result<String> {
Expand Down Expand Up @@ -809,6 +865,7 @@ impl Registry {
retry_limit,
blob_url_scheme: config.blob_url_scheme.clone(),
blob_redirected_host: config.blob_redirected_host.clone(),
cached_auth_using_http_get: HashCache::new(),
cached_redirect: HashCache::new(),
token_expired_at: ArcSwapOption::new(None),
cached_bearer_auth: ArcSwapOption::new(None),
Expand Down Expand Up @@ -990,6 +1047,7 @@ mod tests {
retry_limit: 5,
blob_url_scheme: "https".to_string(),
blob_redirected_host: "oss.alibaba-inc.com".to_string(),
cached_auth_using_http_get: Default::default(),
cached_auth: Default::default(),
cached_redirect: Default::default(),
token_expired_at: ArcSwapOption::new(None),
Expand Down

0 comments on commit d7b1851

Please sign in to comment.