diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index bdef0ff..8ffef2d 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -10,15 +10,16 @@ env: CARGO_TERM_COLOR: always jobs: - build_and_test: + build_toolchains: + name: Clippy on ${{ matrix.toolchain }} runs-on: ubuntu-latest strategy: + fail-fast: false matrix: toolchain: - stable - beta - nightly - steps: - uses: actions/checkout@v4 @@ -45,24 +46,85 @@ jobs: sarif_file: rust-clippy-results.sarif wait-for-processing: true - - name: Build - run: cargo build --verbose + build_and_test: + name: Build and test ${{matrix.os}} on ${{ matrix.toolchain }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + toolchain: + - stable + - beta + steps: + - uses: actions/checkout@v4 - - name: install Spice (https://install.spiceai.org) + - name: Install Rust toolchain + uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: ${{ matrix.toolchain }} + components: clippy + override: true + + - name: Cargo clippy + run: cargo clippy --all-features + + - name: Install Spice (https://install.spiceai.org) (Linux) + if: matrix.os == 'ubuntu-latest' + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: | curl https://install.spiceai.org | /bin/bash echo "$HOME/.spice/bin" >> $GITHUB_PATH + $HOME/.spice/bin/spice install + + - name: Install Spice (https://install.spiceai.org) (MacOS) + if: matrix.os == 'macos-latest' + run: | + brew install spiceai/spiceai/spice + brew install spiceai/spiceai/spiced + + - name: install Spice (Windows) + if: matrix.os == 'windows-latest' + run: | + curl -L "https://install.spiceai.org/Install.ps1" -o Install.ps1 && PowerShell -ExecutionPolicy Bypass -File ./Install.ps1 + + - name: add Spice bin to PATH (Windows) + if: matrix.os == 'windows-latest' + run: | + Add-Content $env:GITHUB_PATH (Join-Path $HOME ".spice\bin") + shell: pwsh - name: Init and start spice app + if: matrix.os != 'windows-latest' run: | spice init spice_qs cd spice_qs spice add spiceai/quickstart - spice run &> spice.log & + spiced &> spice.log & # time to initialize added dataset sleep 10 + - name: Init and start spice app (Windows) + if: matrix.os == 'windows-latest' + run: | + spice init spice_qs + cd spice_qs + spice add spiceai/quickstart + Start-Process -FilePath spice run + # time to initialize added dataset + Start-Sleep -Seconds 10 + shell: pwsh + - name: Run tests - run: cargo test --verbose + run: cargo test env: API_KEY: ${{ secrets.TEST_API_KEY }} + + - name: Stop spice and check logs + working-directory: spice_qs + if: matrix.os != 'windows-latest' && always() + run: | + killall spice || true + cat spice.log diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 63fde69..7eaa77e 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -20,8 +20,8 @@ jobs: toolchain: stable override: true - - run: cargo build --verbose + - run: cargo build - - run: cargo publish --verbose + - run: cargo publish env: CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }} diff --git a/Cargo.toml b/Cargo.toml index 86cc3c6..2ba9993 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,3 +27,9 @@ dotenv = "0.15.0" arrow = "51.0.0" futures = "0.3.30" base64 = "0.22.0" + +[target.'cfg(windows)'.dependencies] +winver = "1.0.0" + +[dev-dependencies] +regex = "1.10.6" \ No newline at end of file diff --git a/README.md b/README.md index 0448660..2e56390 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ async fn main() { } ``` -### New client with https://spice.ai cloud +### New client with cloud ```rust use spiceai::ClientBuilder; diff --git a/src/client.rs b/src/client.rs index 034a9f1..35e3989 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,11 +1,12 @@ use crate::{ - config::{SPICE_CLOUD_FIRECACHE_ADDR, SPICE_CLOUD_FLIGHT_ADDR, SPICE_LOCAL_FLIGHT_ADDR}, + config::{ + GenericError, SPICE_CLOUD_FIRECACHE_ADDR, SPICE_CLOUD_FLIGHT_ADDR, SPICE_LOCAL_FLIGHT_ADDR, + }, flight::SqlFlightClient, tls::new_tls_flight_channel, }; use arrow_flight::decode::FlightRecordBatchStream; use futures::try_join; -use std::error::Error; use tonic::transport::Channel; struct SpiceClientConfig { @@ -21,7 +22,7 @@ impl SpiceClientConfig { } } - pub async fn load_from_default() -> Result> { + pub async fn load_from_default() -> Result { let (flight_chan, firecache_chan) = try_join!( new_tls_flight_channel(SPICE_CLOUD_FLIGHT_ADDR), new_tls_flight_channel(SPICE_CLOUD_FIRECACHE_ADDR) @@ -49,7 +50,11 @@ impl SpiceClient { /// let mut client = Client::new("API_KEY").await.unwrap(); /// } /// ``` - pub async fn new(api_key: &str) -> Result> { + /// + /// ## Errors + /// + /// - `Box` for any query error + pub async fn new(api_key: &str) -> Result { let config = SpiceClientConfig::load_from_default().await?; Ok(Self { @@ -58,6 +63,7 @@ impl SpiceClient { }) } + #[must_use] pub fn builder() -> SpiceClientBuilder { SpiceClientBuilder::new() } @@ -72,7 +78,11 @@ impl SpiceClient { /// let data = client.query("SELECT * FROM eth.recent_blocks LIMIT 10;").await; /// # } /// ```` - pub async fn query(&mut self, query: &str) -> Result> { + /// + /// ## Errors + /// + /// - `Box` for any query error + pub async fn query(&mut self, query: &str) -> Result { self.flight.query(query).await } @@ -86,10 +96,14 @@ impl SpiceClient { /// let data = client.fire_query("SELECT * FROM eth.recent_blocks LIMIT 10;").await; /// # } /// ```` + /// + /// ## Errors + /// + /// - `Box` for any query error pub async fn fire_query( &mut self, query: &str, - ) -> Result> { + ) -> Result { self.firecache.query(query).await } } @@ -138,6 +152,7 @@ impl Default for SpiceClientBuilder { } impl SpiceClientBuilder { + #[must_use] pub fn new() -> Self { Self { api_key: None, @@ -147,18 +162,21 @@ impl SpiceClientBuilder { } /// Configures the `SpiceClient` to use the given API key. + #[must_use] pub fn api_key(mut self, api_key: &str) -> Self { self.api_key = Some(api_key.to_string()); self } /// Configures the `SpiceClient` to use the given Spice Firecache endpoint. + #[must_use] pub fn firecache_url(mut self, firecache_url: &str) -> Self { self.firecache_url = Some(firecache_url.to_string()); self } /// Configures the `SpiceClient` to use the given Spice Flight endpoint. + #[must_use] pub fn flight_url(mut self, flight_url: &str) -> Self { self.flight_url = Some(flight_url.to_string()); self @@ -166,6 +184,7 @@ impl SpiceClientBuilder { /// Configures the `SpiceClient` to use default Spice.ai Cloud endpoints. /// Equivalent to calling `.firecache_url("https://firecache.spiceai.io")` and `.flight_url("https://flight.spiceai.io")`. + #[must_use] pub fn use_spiceai_cloud(mut self) -> Self { self.flight_url = Some(SPICE_CLOUD_FLIGHT_ADDR.to_string()); self.firecache_url = Some(SPICE_CLOUD_FIRECACHE_ADDR.to_string()); @@ -173,7 +192,11 @@ impl SpiceClientBuilder { } /// Builds the `SpiceClient` with the specified configuration. - pub async fn build(self) -> Result> { + /// + /// ## Errors + /// + /// - `Box` if flight or firecache channel creation fails + pub async fn build(self) -> Result { let flight_channel = match self.flight_url { Some(url) => new_tls_flight_channel(&url).await?, None => new_tls_flight_channel(SPICE_LOCAL_FLIGHT_ADDR).await?, diff --git a/src/config.rs b/src/config.rs index 3580e3d..9cafb19 100644 --- a/src/config.rs +++ b/src/config.rs @@ -3,3 +3,80 @@ pub const SPICE_CLOUD_FIRECACHE_ADDR: &str = "https://firecache.spiceai.io"; // default address for local spice runtime pub const SPICE_LOCAL_FLIGHT_ADDR: &str = "http://localhost:50051"; + +pub type GenericError = Box; + +#[cfg(target_family = "unix")] +fn get_os_release() -> Result { + // call uname -r to get release text + use std::process::Command; + let output = Command::new("uname").arg("-r").output()?; + let release = String::from_utf8(output.stdout)?; + + Ok(release) +} + +#[cfg(target_family = "windows")] +fn get_os_release() -> Result { + use winver::WindowsVersion; + if let Some(version) = WindowsVersion::detect() { + Ok(version.to_string()) + } else { + Ok("unknown".to_string()) + } +} + +pub(crate) fn get_user_agent() -> String { + let os_type = std::env::consts::OS; + let os_type = match os_type { + "" => "unknown".to_string(), + "macos" => "Darwin".to_string(), + "linux" => "Linux".to_string(), + "windows" => "Windows".to_string(), + "ios" => "iOS".to_string(), + "android" => "Android".to_string(), + "freebsd" => "FreeBSD".to_string(), + "dragonfly" => "DragonFlyBSD".to_string(), + "netbsd" => "NetBSD".to_string(), + "openbsd" => "OpenBSD".to_string(), + "solaris" => "Solaris".to_string(), + _ => os_type.to_string(), + }; + + let os_arch = std::env::consts::ARCH; + let os_arch = match os_arch { + "" => "unknown".to_string(), + "x86" => "i386".to_string(), + _ => os_arch.to_string(), + }; + + let os_release = get_os_release() + .unwrap_or_else(|_| "unknown".to_string()) + .trim() + .to_string(); + + format!( + "spice-rs {} ({os_type}/{os_release} {os_arch})", + env!("CARGO_PKG_VERSION") + ) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_get_user_agent() { + let matching_regex = regex::Regex::new( + r"spice-rs \d+\.\d+\.\d+ \((Linux|Windows|Darwin)/[\d\w\.\-\_]+ (x86_64|aarch64|i386)\)", + ) + .expect("regex should be constructed"); + + let user_agent = get_user_agent(); + let agent_matches = matching_regex.is_match(&user_agent); + assert!( + agent_matches, + "expected user agent to match regex, but got {user_agent}" + ); + } +} diff --git a/src/flight.rs b/src/flight.rs index 8414bd4..e569a18 100644 --- a/src/flight.rs +++ b/src/flight.rs @@ -1,3 +1,5 @@ +use crate::config::get_user_agent; +use crate::config::GenericError; use arrow::error::ArrowError; use arrow_flight::decode::FlightRecordBatchStream; use arrow_flight::error::FlightError; @@ -11,7 +13,6 @@ use bytes::Bytes; use futures::stream; use futures::TryStreamExt; use std::collections::HashMap; -use std::error::Error; use std::str::FromStr; use tonic::metadata::AsciiMetadataKey; use tonic::transport::Channel; @@ -24,16 +25,20 @@ pub struct SqlFlightClient { api_key: Option, } +#[allow(clippy::needless_pass_by_value)] fn status_to_arrow_error(status: tonic::Status) -> ArrowError { ArrowError::IpcError(format!("{status:?}")) } impl SqlFlightClient { pub fn new(chan: Channel, api_key: Option) -> Self { + let mut headers = HashMap::new(); + headers.insert("x-spice-user-agent".to_string(), get_user_agent()); + SqlFlightClient { api_key, + headers, client: FlightServiceClient::new(chan), - headers: HashMap::default(), token: None, } } @@ -41,7 +46,7 @@ impl SqlFlightClient { async fn handshake(&mut self, username: &str, password: &str) -> Result { let cmd = HandshakeRequest { protocol_version: 0, - payload: Default::default(), + payload: Bytes::default(), }; let mut req = tonic::Request::new(stream::iter(vec![cmd])); let val = BASE64_STANDARD.encode(format!("{username}:{password}")); @@ -81,7 +86,7 @@ impl SqlFlightClient { Ok(resp) } - async fn authenticate(&mut self, api_key: &str) -> std::result::Result<(), Box> { + async fn authenticate(&mut self, api_key: &str) -> std::result::Result<(), GenericError> { if api_key.split('|').collect::().len() < 2 { return Err("Invalid API key format".into()); } @@ -114,7 +119,7 @@ impl SqlFlightClient { pub async fn query( &mut self, query: &str, - ) -> std::result::Result> { + ) -> std::result::Result { let api_key = self.api_key.clone(); if let Some(api_key) = api_key { self.authenticate(&api_key).await?; diff --git a/src/tls.rs b/src/tls.rs index 3e043b8..c72ad5c 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -1,9 +1,10 @@ -use std::error::Error; use std::str::FromStr; use tonic::transport::channel::{ClientTlsConfig, Endpoint}; use tonic::transport::Channel; -pub fn system_tls_certificate() -> Result> { +use crate::config::GenericError; + +pub fn system_tls_certificate() -> Result { // Load root certificates found in the platform’s native certificate store. let certs = rustls_native_certs::load_native_certs()?; @@ -19,7 +20,7 @@ pub fn system_tls_certificate() -> Result Result> { +pub async fn new_tls_flight_channel(https_url: &str) -> Result { let mut endpoint = Endpoint::from_str(https_url)?; if https_url.starts_with("https://") { diff --git a/tests/client_test.rs b/tests/client_test.rs index b6c7e1e..d1e30bf 100644 --- a/tests/client_test.rs +++ b/tests/client_test.rs @@ -41,13 +41,13 @@ mod tests { assert_eq!(batch.num_rows(), 10); }, Err(e) => { - panic!("Error: {}", e) + panic!("Error: {e}") }, }; } } Err(e) => { - panic!("Error: {}", e); + panic!("Error: {e}"); } }; } @@ -67,13 +67,13 @@ mod tests { assert_eq!(batch.num_rows(), 10); }, Err(e) => { - panic!("Error: {}", e) + panic!("Error: {e}") } }; } } Err(e) => { - panic!("Error: {}", e); + panic!("Error: {e}"); } }; } @@ -95,7 +95,7 @@ mod tests { total_rows += batch.num_rows(); }, Err(e) => { - panic!("Error: {}", e) + panic!("Error: {e}") }, }; } @@ -103,7 +103,7 @@ mod tests { assert_ne!(num_batches, 1); } Err(e) => { - panic!("Error: {}", e); + panic!("Error: {e}"); } }; } diff --git a/tests/readme_test.rs b/tests/readme_test.rs index bb287c3..4340a12 100644 --- a/tests/readme_test.rs +++ b/tests/readme_test.rs @@ -11,13 +11,17 @@ mod tests { dotenv::from_path(Path::new(".env.local")).ok(); let api_key = env::var("API_KEY").expect("API_KEY not found"); - let mut client = Client::new(&api_key).await.unwrap(); + let mut client = Client::new(&api_key) + .await + .expect("SpiceClient should be created"); let data = client .query("SELECT * FROM eth.recent_blocks LIMIT 10;") .await; - if data.is_err() { - panic!("failed to query: {:#?}", data.expect_err("")) - } + assert!( + data.is_ok(), + "failed to query: {:#?}", + data.expect_err("should be an error") + ); } #[tokio::test] @@ -31,24 +35,31 @@ mod tests { .use_spiceai_cloud() .build() .await - .unwrap(); + .expect("SpiceClient should be created"); let data = client .query("SELECT * FROM eth.recent_blocks LIMIT 10;") .await; - if data.is_err() { - panic!("failed to query: {:#?}", data.expect_err("")) - } + assert!( + data.is_ok(), + "failed to query: {:#?}", + data.expect_err("should be an error") + ); } #[tokio::test] async fn test_readme_builder_local() { // NOTE: If you're changing the code below, make sure you update the README.md. - let mut client = ClientBuilder::new().build().await.unwrap(); + let mut client = ClientBuilder::new() + .build() + .await + .expect("SpiceClient should be created"); let data = client.query("select * from taxi_trips limit 3;").await; - if data.is_err() { - panic!("failed to query: {:#?}", data.expect_err("")) - } + assert!( + data.is_ok(), + "failed to query: {:#?}", + data.expect_err("should be an error") + ); } }