diff --git a/Cargo.lock b/Cargo.lock index 95b98d170add..48b2e4b8324d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -793,24 +793,6 @@ dependencies = [ "syn 2.0.43", ] -[[package]] -name = "axum-test-helper" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "298f62fa902c2515c169ab0bfb56c593229f33faa01131215d58e3d4898e3aa9" -dependencies = [ - "axum", - "bytes", - "http", - "http-body", - "hyper", - "reqwest", - "serde", - "tokio", - "tower", - "tower-service", -] - [[package]] name = "backon" version = "0.4.1" @@ -3987,9 +3969,9 @@ dependencies = [ [[package]] name = "http" -version = "0.2.11" +version = "0.2.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8947b1a6fad4393052c7ba1f4cd97bed3e953a95c79c92ad9b051a04611d9fbb" +checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1" dependencies = [ "bytes", "fnv", @@ -9029,7 +9011,6 @@ dependencies = [ "auth", "axum", "axum-macros", - "axum-test-helper", "base64 0.21.5", "bytes", "catalog", @@ -9060,6 +9041,7 @@ dependencies = [ "hashbrown 0.14.3", "headers", "hostname", + "http", "http-body", "humantime-serde", "hyper", @@ -10104,7 +10086,6 @@ dependencies = [ "async-trait", "auth", "axum", - "axum-test-helper", "catalog", "chrono", "client", diff --git a/Cargo.toml b/Cargo.toml index b6efa5731918..9f0a7d3c78cb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -133,6 +133,7 @@ reqwest = { version = "0.11", default-features = false, features = [ "json", "rustls-tls-native-roots", "stream", + "multipart", ] } rskafka = "0.5" rust_decimal = "1.33" diff --git a/src/servers/Cargo.toml b/src/servers/Cargo.toml index c86770246df5..595c2ca61121 100644 --- a/src/servers/Cargo.toml +++ b/src/servers/Cargo.toml @@ -51,6 +51,7 @@ futures = "0.3" hashbrown = "0.14" headers = "0.3" hostname = "0.3.1" +http = "0.2.12" http-body = "0.4" humantime-serde.workspace = true hyper = { version = "0.14", features = ["full"] } @@ -109,7 +110,6 @@ tikv-jemalloc-ctl = { version = "0.5", features = ["use_std"] } [dev-dependencies] auth = { workspace = true, features = ["testing"] } -axum-test-helper = "0.3" catalog = { workspace = true, features = ["testing"] } client.workspace = true common-base.workspace = true diff --git a/src/servers/src/http.rs b/src/servers/src/http.rs index 67debfd376dc..2d57d675baa2 100644 --- a/src/servers/src/http.rs +++ b/src/servers/src/http.rs @@ -94,6 +94,9 @@ pub mod greptime_result_v1; pub mod influxdb_result_v1; pub mod table_result; +#[cfg(any(test, feature = "testing"))] +pub mod test_helpers; + pub const HTTP_API_VERSION: &str = "v1"; pub const HTTP_API_PREFIX: &str = "/v1/"; /// Default http body limit (64M). @@ -824,7 +827,6 @@ mod test { use axum::handler::Handler; use axum::http::StatusCode; use axum::routing::get; - use axum_test_helper::TestClient; use common_query::Output; use common_recordbatch::RecordBatches; use datatypes::prelude::*; @@ -838,6 +840,7 @@ mod test { use super::*; use crate::error::Error; + use crate::http::test_helpers::TestClient; use crate::query_handler::grpc::GrpcQueryHandler; use crate::query_handler::sql::{ServerSqlQueryHandlerAdapter, SqlQueryHandler}; diff --git a/src/servers/src/http/test_helpers.rs b/src/servers/src/http/test_helpers.rs new file mode 100644 index 000000000000..bdfc348e4772 --- /dev/null +++ b/src/servers/src/http/test_helpers.rs @@ -0,0 +1,204 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::convert::TryFrom; +use std::net::{SocketAddr, TcpListener}; + +use axum::body::HttpBody; +use axum::BoxError; +use bytes::Bytes; +use common_telemetry::info; +use http::header::{HeaderName, HeaderValue}; +use http::{Request, StatusCode}; +use hyper::service::Service; +use hyper::{Body, Server}; +use tower::make::Shared; + +pub struct TestClient { + client: reqwest::Client, + addr: SocketAddr, +} + +impl TestClient { + pub fn new(svc: S) -> Self + where + S: Service, Response = http::Response> + Clone + Send + 'static, + ResBody: HttpBody + Send + 'static, + ResBody::Data: Send, + ResBody::Error: Into, + S::Future: Send, + S::Error: Into, + { + let listener = TcpListener::bind("127.0.0.1:0").expect("Could not bind ephemeral socket"); + let addr = listener.local_addr().unwrap(); + info!("Listening on {}", addr); + + tokio::spawn(async move { + let server = Server::from_tcp(listener).unwrap().serve(Shared::new(svc)); + server.await.expect("server error"); + }); + + let client = reqwest::Client::builder() + .redirect(reqwest::redirect::Policy::none()) + .build() + .unwrap(); + + TestClient { client, addr } + } + + /// returns the base URL (http://ip:port) for this TestClient + /// + /// this is useful when trying to check if Location headers in responses + /// are generated correctly as Location contains an absolute URL + pub fn base_url(&self) -> String { + format!("http://{}", self.addr) + } + + pub fn get(&self, url: &str) -> RequestBuilder { + RequestBuilder { + builder: self.client.get(format!("http://{}{}", self.addr, url)), + } + } + + pub fn head(&self, url: &str) -> RequestBuilder { + RequestBuilder { + builder: self.client.head(format!("http://{}{}", self.addr, url)), + } + } + + pub fn post(&self, url: &str) -> RequestBuilder { + RequestBuilder { + builder: self.client.post(format!("http://{}{}", self.addr, url)), + } + } + + pub fn put(&self, url: &str) -> RequestBuilder { + RequestBuilder { + builder: self.client.put(format!("http://{}{}", self.addr, url)), + } + } + + pub fn patch(&self, url: &str) -> RequestBuilder { + RequestBuilder { + builder: self.client.patch(format!("http://{}{}", self.addr, url)), + } + } + + pub fn delete(&self, url: &str) -> RequestBuilder { + RequestBuilder { + builder: self.client.delete(format!("http://{}{}", self.addr, url)), + } + } +} + +pub struct RequestBuilder { + builder: reqwest::RequestBuilder, +} + +impl RequestBuilder { + pub async fn send(self) -> TestResponse { + TestResponse { + response: self.builder.send().await.unwrap(), + } + } + + pub fn body(mut self, body: impl Into) -> Self { + self.builder = self.builder.body(body); + self + } + + pub fn form(mut self, form: &T) -> Self { + self.builder = self.builder.form(&form); + self + } + + pub fn json(mut self, json: &T) -> Self + where + T: serde::Serialize, + { + self.builder = self.builder.json(json); + self + } + + pub fn header(mut self, key: K, value: V) -> Self + where + HeaderName: TryFrom, + >::Error: Into, + HeaderValue: TryFrom, + >::Error: Into, + { + self.builder = self.builder.header(key, value); + self + } + + pub fn multipart(mut self, form: reqwest::multipart::Form) -> Self { + self.builder = self.builder.multipart(form); + self + } +} + +/// A wrapper around [`reqwest::Response`] that provides common methods with internal `unwrap()`s. +/// +/// This is conventient for tests where panics are what you want. For access to +/// non-panicking versions or the complete `Response` API use `into_inner()` or +/// `as_ref()`. +pub struct TestResponse { + response: reqwest::Response, +} + +impl TestResponse { + pub async fn text(self) -> String { + self.response.text().await.unwrap() + } + + #[allow(dead_code)] + pub async fn bytes(self) -> Bytes { + self.response.bytes().await.unwrap() + } + + pub async fn json(self) -> T + where + T: serde::de::DeserializeOwned, + { + self.response.json().await.unwrap() + } + + pub fn status(&self) -> StatusCode { + self.response.status() + } + + pub fn headers(&self) -> &http::HeaderMap { + self.response.headers() + } + + pub async fn chunk(&mut self) -> Option { + self.response.chunk().await.unwrap() + } + + pub async fn chunk_text(&mut self) -> Option { + let chunk = self.chunk().await?; + Some(String::from_utf8(chunk.to_vec()).unwrap()) + } + + /// Get the inner [`reqwest::Response`] for less convenient but more complete access. + pub fn into_inner(self) -> reqwest::Response { + self.response + } +} + +impl AsRef for TestResponse { + fn as_ref(&self) -> &reqwest::Response { + &self.response + } +} diff --git a/src/servers/tests/http/http_test.rs b/src/servers/tests/http/http_test.rs index b9f9dc5fa25d..071d0f6548cb 100644 --- a/src/servers/tests/http/http_test.rs +++ b/src/servers/tests/http/http_test.rs @@ -13,8 +13,8 @@ // limitations under the License. use axum::Router; -use axum_test_helper::TestClient; use common_test_util::ports; +use servers::http::test_helpers::TestClient; use servers::http::{HttpOptions, HttpServerBuilder}; use table::test_util::MemTable; diff --git a/src/servers/tests/http/influxdb_test.rs b/src/servers/tests/http/influxdb_test.rs index 8c0a6031a594..0282264f85b8 100644 --- a/src/servers/tests/http/influxdb_test.rs +++ b/src/servers/tests/http/influxdb_test.rs @@ -19,7 +19,6 @@ use api::v1::RowInsertRequests; use async_trait::async_trait; use auth::tests::{DatabaseAuthInfo, MockUserProvider}; use axum::{http, Router}; -use axum_test_helper::TestClient; use common_query::Output; use common_test_util::ports; use query::parser::PromQuery; @@ -27,6 +26,7 @@ use query::plan::LogicalPlan; use query::query_engine::DescribeResult; use servers::error::{Error, Result}; use servers::http::header::constants::GREPTIME_DB_HEADER_NAME; +use servers::http::test_helpers::TestClient; use servers::http::{HttpOptions, HttpServerBuilder}; use servers::influxdb::InfluxdbRequest; use servers::query_handler::grpc::GrpcQueryHandler; diff --git a/src/servers/tests/http/opentsdb_test.rs b/src/servers/tests/http/opentsdb_test.rs index 9563f254224e..1a719aa93d2c 100644 --- a/src/servers/tests/http/opentsdb_test.rs +++ b/src/servers/tests/http/opentsdb_test.rs @@ -17,13 +17,13 @@ use std::sync::Arc; use api::v1::greptime_request::Request; use async_trait::async_trait; use axum::Router; -use axum_test_helper::TestClient; use common_query::Output; use common_test_util::ports; use query::parser::PromQuery; use query::plan::LogicalPlan; use query::query_engine::DescribeResult; use servers::error::{self, Result}; +use servers::http::test_helpers::TestClient; use servers::http::{HttpOptions, HttpServerBuilder}; use servers::opentsdb::codec::DataPoint; use servers::query_handler::grpc::GrpcQueryHandler; diff --git a/src/servers/tests/http/prom_store_test.rs b/src/servers/tests/http/prom_store_test.rs index 1a005fc0bd77..350dcff65189 100644 --- a/src/servers/tests/http/prom_store_test.rs +++ b/src/servers/tests/http/prom_store_test.rs @@ -21,7 +21,6 @@ use api::v1::greptime_request::Request; use api::v1::RowInsertRequests; use async_trait::async_trait; use axum::Router; -use axum_test_helper::TestClient; use common_query::Output; use common_test_util::ports; use prost::Message; @@ -30,6 +29,7 @@ use query::plan::LogicalPlan; use query::query_engine::DescribeResult; use servers::error::{Error, Result}; use servers::http::header::{CONTENT_ENCODING_SNAPPY, CONTENT_TYPE_PROTOBUF}; +use servers::http::test_helpers::TestClient; use servers::http::{HttpOptions, HttpServerBuilder}; use servers::prom_store; use servers::prom_store::{snappy_compress, Metrics}; diff --git a/tests-integration/Cargo.toml b/tests-integration/Cargo.toml index 99438540e209..f843e6615c1b 100644 --- a/tests-integration/Cargo.toml +++ b/tests-integration/Cargo.toml @@ -16,7 +16,6 @@ arrow-flight.workspace = true async-trait = "0.1" auth.workspace = true axum.workspace = true -axum-test-helper = "0.3.0" catalog.workspace = true chrono.workspace = true client = { workspace = true, features = ["testing"] } diff --git a/tests-integration/tests/http.rs b/tests-integration/tests/http.rs index bc8b458ee5d1..ea8ec073ceed 100644 --- a/tests-integration/tests/http.rs +++ b/tests-integration/tests/http.rs @@ -17,7 +17,6 @@ use std::collections::BTreeMap; use api::prom_store::remote::WriteRequest; use auth::user_provider_from_option; use axum::http::{HeaderName, StatusCode}; -use axum_test_helper::TestClient; use common_error::status_code::StatusCode as ErrorCode; use prost::Message; use serde_json::json; @@ -27,6 +26,7 @@ use servers::http::handler::HealthResponse; use servers::http::header::GREPTIME_TIMEZONE_HEADER_NAME; use servers::http::influxdb_result_v1::{InfluxdbOutput, InfluxdbV1Response}; use servers::http::prometheus::{PrometheusJsonResponse, PrometheusResponse}; +use servers::http::test_helpers::TestClient; use servers::http::GreptimeQueryOutput; use servers::prom_store; use tests_integration::test_util::{