Skip to content

Commit

Permalink
refactor(catalog/rest): Split http client logic to seperate mod (apac…
Browse files Browse the repository at this point in the history
…he#423)

Signed-off-by: Xuanwo <github@xuanwo.io>
  • Loading branch information
Xuanwo authored and shaeqahmed committed Dec 9, 2024
1 parent 3eef000 commit 6339351
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 123 deletions.
147 changes: 24 additions & 123 deletions crates/catalog/rest/src/catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ use std::str::FromStr;
use async_trait::async_trait;
use itertools::Itertools;
use reqwest::header::{self, HeaderMap, HeaderName, HeaderValue};
use reqwest::{Client, Request, Response, StatusCode, Url};
use serde::de::DeserializeOwned;
use reqwest::{Method, StatusCode, Url};
use typed_builder::TypedBuilder;
use urlencoding::encode;

use crate::catalog::_serde::{
CommitTableRequest, CommitTableResponse, CreateTableRequest, LoadTableResponse,
};
use crate::client::HttpClient;
use iceberg::io::FileIO;
use iceberg::table::Table;
use iceberg::Result;
Expand Down Expand Up @@ -162,10 +162,7 @@ impl RestCatalogConfig {
fn try_create_rest_client(&self) -> Result<HttpClient> {
// TODO: We will add ssl config, sigv4 later
let headers = self.http_headers()?;

Ok(HttpClient(
Client::builder().default_headers(headers).build()?,
))
HttpClient::try_create(headers)
}

fn optional_oauth_params(&self) -> HashMap<&str, &str> {
Expand All @@ -185,97 +182,6 @@ impl RestCatalogConfig {
}
}

#[derive(Debug)]
struct HttpClient(Client);

impl HttpClient {
async fn query<
R: DeserializeOwned,
E: DeserializeOwned + Into<Error>,
const SUCCESS_CODE: u16,
>(
&self,
request: Request,
) -> Result<R> {
let resp = self.0.execute(request).await?;

if resp.status().as_u16() == SUCCESS_CODE {
let text = resp.bytes().await?;
Ok(serde_json::from_slice::<R>(&text).map_err(|e| {
Error::new(
ErrorKind::Unexpected,
"Failed to parse response from rest catalog server!",
)
.with_context("json", String::from_utf8_lossy(&text))
.with_source(e)
})?)
} else {
let code = resp.status();
let text = resp.bytes().await?;
let e = serde_json::from_slice::<E>(&text).map_err(|e| {
Error::new(
ErrorKind::Unexpected,
"Failed to parse response from rest catalog server!",
)
.with_context("json", String::from_utf8_lossy(&text))
.with_context("code", code.to_string())
.with_source(e)
})?;
Err(e.into())
}
}

async fn execute<E: DeserializeOwned + Into<Error>, const SUCCESS_CODE: u16>(
&self,
request: Request,
) -> Result<()> {
let resp = self.0.execute(request).await?;

if resp.status().as_u16() == SUCCESS_CODE {
Ok(())
} else {
let code = resp.status();
let text = resp.bytes().await?;
let e = serde_json::from_slice::<E>(&text).map_err(|e| {
Error::new(
ErrorKind::Unexpected,
"Failed to parse response from rest catalog server!",
)
.with_context("json", String::from_utf8_lossy(&text))
.with_context("code", code.to_string())
.with_source(e)
})?;
Err(e.into())
}
}

/// More generic logic handling for special cases like head.
async fn do_execute<R, E: DeserializeOwned + Into<Error>>(
&self,
request: Request,
handler: impl FnOnce(&Response) -> Option<R>,
) -> Result<R> {
let resp = self.0.execute(request).await?;

if let Some(ret) = handler(&resp) {
Ok(ret)
} else {
let code = resp.status();
let text = resp.bytes().await?;
let e = serde_json::from_slice::<E>(&text).map_err(|e| {
Error::new(
ErrorKind::Unexpected,
"Failed to parse response from rest catalog server!",
)
.with_context("code", code.to_string())
.with_context("json", String::from_utf8_lossy(&text))
.with_source(e)
})?;
Err(e.into())
}
}
}

/// Rest catalog implementation.
#[derive(Debug)]
pub struct RestCatalog {
Expand All @@ -290,7 +196,9 @@ impl Catalog for RestCatalog {
&self,
parent: Option<&NamespaceIdent>,
) -> Result<Vec<NamespaceIdent>> {
let mut request = self.client.0.get(self.config.namespaces_endpoint());
let mut request = self
.client
.request(Method::GET, self.config.namespaces_endpoint());
if let Some(ns) = parent {
request = request.query(&[("parent", ns.encode_in_url())]);
}
Expand All @@ -314,8 +222,7 @@ impl Catalog for RestCatalog {
) -> Result<Namespace> {
let request = self
.client
.0
.post(self.config.namespaces_endpoint())
.request(Method::POST, self.config.namespaces_endpoint())
.json(&NamespaceSerde {
namespace: namespace.as_ref().clone(),
properties: Some(properties),
Expand All @@ -334,8 +241,7 @@ impl Catalog for RestCatalog {
async fn get_namespace(&self, namespace: &NamespaceIdent) -> Result<Namespace> {
let request = self
.client
.0
.get(self.config.namespace_endpoint(namespace))
.request(Method::GET, self.config.namespace_endpoint(namespace))
.build()?;

let resp = self
Expand Down Expand Up @@ -364,8 +270,7 @@ impl Catalog for RestCatalog {
async fn namespace_exists(&self, ns: &NamespaceIdent) -> Result<bool> {
let request = self
.client
.0
.head(self.config.namespace_endpoint(ns))
.request(Method::HEAD, self.config.namespace_endpoint(ns))
.build()?;

self.client
Expand All @@ -381,8 +286,7 @@ impl Catalog for RestCatalog {
async fn drop_namespace(&self, namespace: &NamespaceIdent) -> Result<()> {
let request = self
.client
.0
.delete(self.config.namespace_endpoint(namespace))
.request(Method::DELETE, self.config.namespace_endpoint(namespace))
.build()?;

self.client
Expand All @@ -394,8 +298,7 @@ impl Catalog for RestCatalog {
async fn list_tables(&self, namespace: &NamespaceIdent) -> Result<Vec<TableIdent>> {
let request = self
.client
.0
.get(self.config.tables_endpoint(namespace))
.request(Method::GET, self.config.tables_endpoint(namespace))
.build()?;

let resp = self
Expand All @@ -416,8 +319,7 @@ impl Catalog for RestCatalog {

let request = self
.client
.0
.post(self.config.tables_endpoint(namespace))
.request(Method::POST, self.config.tables_endpoint(namespace))
.json(&CreateTableRequest {
name: creation.name,
location: creation.location,
Expand Down Expand Up @@ -460,8 +362,7 @@ impl Catalog for RestCatalog {
async fn load_table(&self, table: &TableIdent) -> Result<Table> {
let request = self
.client
.0
.get(self.config.table_endpoint(table))
.request(Method::GET, self.config.table_endpoint(table))
.build()?;

let resp = self
Expand All @@ -487,8 +388,7 @@ impl Catalog for RestCatalog {
async fn drop_table(&self, table: &TableIdent) -> Result<()> {
let request = self
.client
.0
.delete(self.config.table_endpoint(table))
.request(Method::DELETE, self.config.table_endpoint(table))
.build()?;

self.client
Expand All @@ -500,8 +400,7 @@ impl Catalog for RestCatalog {
async fn table_exists(&self, table: &TableIdent) -> Result<bool> {
let request = self
.client
.0
.head(self.config.table_endpoint(table))
.request(Method::HEAD, self.config.table_endpoint(table))
.build()?;

self.client
Expand All @@ -517,8 +416,7 @@ impl Catalog for RestCatalog {
async fn rename_table(&self, src: &TableIdent, dest: &TableIdent) -> Result<()> {
let request = self
.client
.0
.post(self.config.rename_table_endpoint())
.request(Method::POST, self.config.rename_table_endpoint())
.json(&RenameTableRequest {
source: src.clone(),
destination: dest.clone(),
Expand All @@ -534,8 +432,10 @@ impl Catalog for RestCatalog {
async fn update_table(&self, mut commit: TableCommit) -> Result<Table> {
let request = self
.client
.0
.post(self.config.table_endpoint(commit.identifier()))
.request(
Method::POST,
self.config.table_endpoint(commit.identifier()),
)
.json(&CommitTableRequest {
identifier: commit.identifier().clone(),
requirements: commit.take_requirements(),
Expand Down Expand Up @@ -594,8 +494,7 @@ impl RestCatalog {
params.extend(optional_oauth_params);
let req = self
.client
.0
.post(self.config.get_token_endpoint())
.request(Method::POST, self.config.get_token_endpoint())
.form(&params)
.build()?;
let res = self
Expand All @@ -617,7 +516,9 @@ impl RestCatalog {
}

async fn update_config(&mut self) -> Result<()> {
let mut request = self.client.0.get(self.config.config_endpoint());
let mut request = self
.client
.request(Method::GET, self.config.config_endpoint());

if let Some(warehouse_location) = &self.config.warehouse {
request = request.query(&[("warehouse", warehouse_location)]);
Expand Down
124 changes: 124 additions & 0 deletions crates/catalog/rest/src/client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use iceberg::Result;
use iceberg::{Error, ErrorKind};
use reqwest::header::HeaderMap;
use reqwest::{Client, IntoUrl, Method, Request, RequestBuilder, Response};
use serde::de::DeserializeOwned;

#[derive(Debug)]
pub(crate) struct HttpClient(Client);

impl HttpClient {
pub fn try_create(default_headers: HeaderMap) -> Result<Self> {
Ok(HttpClient(
Client::builder().default_headers(default_headers).build()?,
))
}

#[inline]
pub fn request<U: IntoUrl>(&self, method: Method, url: U) -> RequestBuilder {
self.0.request(method, url)
}

pub async fn query<
R: DeserializeOwned,
E: DeserializeOwned + Into<Error>,
const SUCCESS_CODE: u16,
>(
&self,
request: Request,
) -> Result<R> {
let resp = self.0.execute(request).await?;

if resp.status().as_u16() == SUCCESS_CODE {
let text = resp.bytes().await?;
Ok(serde_json::from_slice::<R>(&text).map_err(|e| {
Error::new(
ErrorKind::Unexpected,
"Failed to parse response from rest catalog server!",
)
.with_context("json", String::from_utf8_lossy(&text))
.with_source(e)
})?)
} else {
let code = resp.status();
let text = resp.bytes().await?;
let e = serde_json::from_slice::<E>(&text).map_err(|e| {
Error::new(
ErrorKind::Unexpected,
"Failed to parse response from rest catalog server!",
)
.with_context("json", String::from_utf8_lossy(&text))
.with_context("code", code.to_string())
.with_source(e)
})?;
Err(e.into())
}
}

pub async fn execute<E: DeserializeOwned + Into<Error>, const SUCCESS_CODE: u16>(
&self,
request: Request,
) -> Result<()> {
let resp = self.0.execute(request).await?;

if resp.status().as_u16() == SUCCESS_CODE {
Ok(())
} else {
let code = resp.status();
let text = resp.bytes().await?;
let e = serde_json::from_slice::<E>(&text).map_err(|e| {
Error::new(
ErrorKind::Unexpected,
"Failed to parse response from rest catalog server!",
)
.with_context("json", String::from_utf8_lossy(&text))
.with_context("code", code.to_string())
.with_source(e)
})?;
Err(e.into())
}
}

/// More generic logic handling for special cases like head.
pub async fn do_execute<R, E: DeserializeOwned + Into<Error>>(
&self,
request: Request,
handler: impl FnOnce(&Response) -> Option<R>,
) -> Result<R> {
let resp = self.0.execute(request).await?;

if let Some(ret) = handler(&resp) {
Ok(ret)
} else {
let code = resp.status();
let text = resp.bytes().await?;
let e = serde_json::from_slice::<E>(&text).map_err(|e| {
Error::new(
ErrorKind::Unexpected,
"Failed to parse response from rest catalog server!",
)
.with_context("code", code.to_string())
.with_context("json", String::from_utf8_lossy(&text))
.with_source(e)
})?;
Err(e.into())
}
}
}
Loading

0 comments on commit 6339351

Please sign in to comment.