Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CLI] Generate entity Postgres initialize connection with schema search path #1212

Merged
merged 1 commit into from
Nov 9, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 25 additions & 9 deletions sea-orm-cli/src/commands/generate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ pub async fn run_generate_command(
use sea_schema::mysql::discovery::SchemaDiscovery;
use sqlx::MySql;

let connection = connect::<MySql>(max_connections, url.as_str()).await?;
let connection = connect::<MySql>(max_connections, url.as_str(), None).await?;
let schema_discovery = SchemaDiscovery::new(connection, database_name);
let schema = schema_discovery.discover().await;
let table_stmts = schema
Expand All @@ -132,7 +132,7 @@ pub async fn run_generate_command(
use sea_schema::sqlite::discovery::SchemaDiscovery;
use sqlx::Sqlite;

let connection = connect::<Sqlite>(max_connections, url.as_str()).await?;
let connection = connect::<Sqlite>(max_connections, url.as_str(), None).await?;
let schema_discovery = SchemaDiscovery::new(connection);
let schema = schema_discovery.discover().await?;
let table_stmts = schema
Expand All @@ -150,7 +150,8 @@ pub async fn run_generate_command(
use sqlx::Postgres;

let schema = &database_schema;
let connection = connect::<Postgres>(max_connections, url.as_str()).await?;
let connection =
connect::<Postgres>(max_connections, url.as_str(), Some(schema)).await?;
let schema_discovery = SchemaDiscovery::new(connection, schema);
let schema = schema_discovery.discover().await;
let table_stmts = schema
Expand Down Expand Up @@ -198,15 +199,30 @@ pub async fn run_generate_command(
Ok(())
}

async fn connect<DB>(max_connections: u32, url: &str) -> Result<sqlx::Pool<DB>, Box<dyn Error>>
async fn connect<DB>(
max_connections: u32,
url: &str,
schema: Option<&str>,
) -> Result<sqlx::Pool<DB>, Box<dyn Error>>
where
DB: sqlx::Database,
for<'a> &'a mut <DB as sqlx::Database>::Connection: sqlx::Executor<'a>,
{
sqlx::pool::PoolOptions::<DB>::new()
.max_connections(max_connections)
.connect(url)
.await
.map_err(Into::into)
let mut pool_options = sqlx::pool::PoolOptions::<DB>::new().max_connections(max_connections);
// Set search_path for Postgres, E.g. Some("public") by default
// MySQL & SQLite connection initialize with schema `None`
if let Some(schema) = schema {
let sql = format!("SET search_path = '{}'", schema);
pool_options = pool_options.after_connect(move |conn, _| {
let sql = sql.clone();
Box::pin(async move {
sqlx::Executor::execute(conn, sql.as_str())
.await
.map(|_| ())
})
});
}
pool_options.connect(url).await.map_err(Into::into)
}

impl From<DateTimeCrate> for CodegenDateTimeCrate {
Expand Down