diff --git a/Cargo.toml b/Cargo.toml index 2338ca69..77074dc3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] authors = ["Nicolas Grislain "] name = "qrlew" -version = "0.4.10" +version = "0.4.11" edition = "2021" description = "Sarus Qrlew Engine" documentation = "https://docs.rs/qrlew" @@ -29,6 +29,8 @@ dot = "0.1" base64 = "0.21" rusqlite = { version = "0.29", features = ["chrono"], optional = true } postgres = { version = "0.19", features = ["with-chrono-0_4"] } +r2d2 = "0.8" +r2d2_postgres = "0.18" rust_decimal = { version = "1.29", features = [ "tokio-pg" ] } statrs = "0.16.0" diff --git a/src/differential_privacy/aggregates.rs b/src/differential_privacy/aggregates.rs index 0ddbc2a8..a66469ae 100644 --- a/src/differential_privacy/aggregates.rs +++ b/src/differential_privacy/aggregates.rs @@ -309,11 +309,7 @@ mod tests { let query: &str = &ast::Query::from(&relation).to_string(); println!("{query}"); - _ = database - .query(query) - .unwrap() - .iter() - .map(ToString::to_string); + _ = database.query(query).unwrap(); } #[test] @@ -365,11 +361,7 @@ mod tests { let query: &str = &ast::Query::from(&relation).to_string(); println!("{query}"); - _ = database - .query(query) - .unwrap() - .iter() - .map(ToString::to_string); + _ = database.query(query).unwrap(); } #[test] @@ -426,11 +418,7 @@ mod tests { let query: &str = &ast::Query::from(&relation).to_string(); println!("{query}"); - _ = database - .query(query) - .unwrap() - .iter() - .map(ToString::to_string); + _ = database.query(query).unwrap(); } #[test] @@ -524,10 +512,6 @@ mod tests { let query: &str = &ast::Query::from(&relation).to_string(); println!("{query}"); - _ = database - .query(query) - .unwrap() - .iter() - .map(ToString::to_string); + _ = database.query(query).unwrap(); } } diff --git a/src/differential_privacy/mod.rs b/src/differential_privacy/mod.rs index 61c23e7c..b5a3b653 100644 --- a/src/differential_privacy/mod.rs +++ b/src/differential_privacy/mod.rs @@ -224,11 +224,7 @@ mod tests { let query: &str = &ast::Query::from(&dp_relation).to_string(); println!("{query}"); - _ = database - .query(query) - .unwrap() - .iter() - .map(ToString::to_string); + _ = database.query(query).unwrap(); } #[test] @@ -298,11 +294,7 @@ mod tests { let query: &str = &ast::Query::from(&dp_relation).to_string(); println!("{query}"); - _ = database - .query(query) - .unwrap() - .iter() - .map(ToString::to_string); + _ = database.query(query).unwrap(); } #[test] @@ -372,11 +364,7 @@ mod tests { let query: &str = &ast::Query::from(&dp_relation).to_string(); println!("{query}"); - _ = database - .query(query) - .unwrap() - .iter() - .map(ToString::to_string); + _ = database.query(query).unwrap(); } #[test] @@ -462,10 +450,6 @@ mod tests { let query: &str = &ast::Query::from(&dp_relation).to_string(); println!("{query}"); - _ = database - .query(query) - .unwrap() - .iter() - .map(ToString::to_string); + _ = database.query(query).unwrap(); } } diff --git a/src/io/mod.rs b/src/io/mod.rs index 7acc8c77..781bc557 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -83,6 +83,11 @@ impl From for Error { Error::Other(err.to_string()) } } +impl From for Error { + fn from(err: r2d2::Error) -> Self { + Error::Other(err.to_string()) + } +} pub type Result = result::Result; diff --git a/src/io/postgresql.rs b/src/io/postgresql.rs index 2dc5116b..6645443b 100644 --- a/src/io/postgresql.rs +++ b/src/io/postgresql.rs @@ -1,7 +1,7 @@ //! An object creating a docker container and releasing it after use //! -use super::{Database as DatabaseTrait, Error, Result, DATA_GENERATION_SEED, try_some_times}; +use super::{Database as DatabaseTrait, Error, Result, DATA_GENERATION_SEED}; use crate::{ data_type::{ generator::Generator, @@ -11,14 +11,18 @@ use crate::{ namer, relation::{Table, Variant as _}, }; +use std::{env, fmt, process::Command, str::FromStr, sync::Arc, sync::Mutex, thread, time, ops::Deref}; + use colored::Colorize; +use rand::{rngs::StdRng, SeedableRng}; +use rust_decimal::{prelude::ToPrimitive, Decimal}; use postgres::{ self, types::{FromSql, ToSql, Type}, }; -use rand::{rngs::StdRng, SeedableRng}; -use rust_decimal::{prelude::ToPrimitive, Decimal}; -use std::{env, fmt, process::Command, str::FromStr, sync::Arc, sync::Mutex, thread, time}; +use r2d2_postgres::{postgres::NoTls, PostgresConnectionManager}; +use r2d2::Pool; + const DB: &str = "qrlew-test"; const PORT: usize = 5432; @@ -35,10 +39,12 @@ impl From for Error { pub struct Database { name: String, tables: Vec, - client: postgres::Client, + pool: Pool>, drop: bool, } +/// Only one pool +pub static POSTGRES_POOL: Mutex>>> = Mutex::new(None); /// Only one thread start a container pub static POSTGRES_CONTAINER: Mutex = Mutex::new(false); @@ -62,47 +68,26 @@ impl Database { env::var("POSTGRES_PASSWORD").unwrap_or(PASSWORD.into()) } + /// Try to build a pool from an existing DB /// A postgresql instance must exist /// `docker run --name qrlew-test -p 5432:5432 -e POSTGRES_PASSWORD=qrlew-test -d postgres` - fn try_get_existing(name: String, tables: Vec
) -> Result { - log::info!("Try to get an existing DB"); - let mut client = postgres::Client::connect( - &format!( + fn build_pool_from_existing() -> Result>> { + let manager = PostgresConnectionManager::new( + format!( "host=localhost port={} user={} password={}", Database::port(), Database::user(), Database::password() - ), - postgres::NoTls, - )?; - let table_names: Vec = client - .query( - "SELECT * FROM pg_catalog.pg_tables WHERE schemaname='public'", - &[], - )? - .into_iter() - .map(|row| row.get("tablename")) - .collect(); - if table_names.is_empty() { - Database { - name, - tables: vec![], - client, - drop: false, - } - .with_tables(tables) - } else { - Ok(Database { - name, - tables, - client, - drop: false, - }) - } + ).parse()?, + NoTls, + ); + Ok(r2d2::Pool::builder() + .max_size(10) + .build(manager)?) } - /// Get a Database from a container - fn try_get_container(name: String, tables: Vec
) -> Result { + /// Try to build a pool from a DB in a container + fn build_pool_from_container(name: String) -> Result>> { let mut postgres_container = POSTGRES_CONTAINER.lock().unwrap(); if *postgres_container == false { // A new container will be started @@ -147,19 +132,15 @@ impl Database { } log::info!("{}", "DB ready".red()); } - let client = postgres::Client::connect( - &format!("host=localhost port={port} user={USER} password={PASSWORD}"), - postgres::NoTls, - )?; - Ok(Database { - name, - tables: vec![], - client, - drop: false, - } - .with_tables(tables)?) + let manager = PostgresConnectionManager::new( + format!("host=localhost port={port} user={USER} password={PASSWORD}").parse()?, + NoTls, + ); + Ok(r2d2::Pool::builder() + .max_size(10) + .build(manager)?) } else { - Database::try_get_existing(name, tables) + Database::build_pool_from_existing() } } } @@ -175,8 +156,35 @@ impl fmt::Debug for Database { impl DatabaseTrait for Database { fn new(name: String, tables: Vec
) -> Result { - try_some_times(100, || Database::try_get_existing(name.clone(), tables.clone())) - .or_else(|_| Database::try_get_container(name, tables)) + let mut postgres_pool = POSTGRES_POOL.lock().unwrap(); + if let None = *postgres_pool { + *postgres_pool = Some(Database::build_pool_from_existing().or_else(|_| Database::build_pool_from_container(name.clone()))?); + } + let pool = postgres_pool.as_ref().unwrap().clone(); + let table_names: Vec = pool.get()? + .query( + "SELECT * FROM pg_catalog.pg_tables WHERE schemaname='public'", + &[], + )? + .into_iter() + .map(|row| row.get("tablename")) + .collect(); + if table_names.is_empty() { + Database { + name, + tables: vec![], + pool, + drop: false, + } + .with_tables(tables) + } else { + Ok(Database { + name, + tables, + pool, + drop: false, + }) + } } fn name(&self) -> &str { @@ -192,13 +200,15 @@ impl DatabaseTrait for Database { } fn create_table(&mut self, table: &Table) -> Result { - Ok(self.client.execute(&table.create().to_string(), &[])? as usize) + let mut connection = self.pool.get()?; + Ok(connection.execute(&table.create().to_string(), &[])? as usize) } fn insert_data(&mut self, table: &Table) -> Result<()> { let mut rng = StdRng::seed_from_u64(DATA_GENERATION_SEED); let size = Database::MAX_SIZE.min(table.size().generate(&mut rng) as usize); - let statement = self.client.prepare(&table.insert('$').to_string())?; + let mut connection = self.pool.get()?; + let statement = connection.prepare(&table.insert('$').to_string())?; for _ in 0..size { let structured: value::Struct = table.schema().data_type().generate(&mut rng).try_into()?; @@ -209,14 +219,18 @@ impl DatabaseTrait for Database { let values = values?; let params: Vec<&(dyn ToSql + Sync)> = values.iter().map(|v| v as &(dyn ToSql + Sync)).collect(); - self.client.execute(&statement, ¶ms)?; + connection.execute(&statement, ¶ms)?; } Ok(()) } fn query(&mut self, query: &str) -> Result> { - let statement = self.client.prepare(query)?; - let rows = self.client.query(&statement, &[])?; + let rows: Vec<_>; + { + let mut connection = self.pool.get()?; + let statement = connection.prepare(query)?; + rows = connection.query(&statement, &[])?; + } Ok(rows .into_iter() .map(|r| { @@ -397,6 +411,7 @@ mod tests { #[test] fn database_test() -> Result<()> { let mut database = test_database(); + println!("Pool {}", database.pool.max_size()); assert!(!database.eq("SELECT * FROM table_1", "SELECT * FROM table_2")); assert!(database.eq( "SELECT * FROM table_1",