diff --git a/src/config.rs b/src/config.rs index e8d42df..7566b97 100644 --- a/src/config.rs +++ b/src/config.rs @@ -181,6 +181,10 @@ pub struct ConnectionConf { #[clap(long("ssl-key"), value_name = "PATH")] pub ssl_key_file: Option, + /// Datacenter name + #[clap(long("datacenter"), required = false, default_value = "")] + pub datacenter: String, + /// Default CQL query consistency level #[clap(long("consistency"), required = false, default_value = "LOCAL_QUORUM")] pub consistency: Consistency, diff --git a/src/context.rs b/src/context.rs index e1b6776..1968367 100644 --- a/src/context.rs +++ b/src/context.rs @@ -25,6 +25,7 @@ use rune::{Any, Value}; use rust_embed::RustEmbed; use scylla::_macro_internal::ColumnType; use scylla::frame::response::result::CqlValue; +use scylla::load_balancing::DefaultPolicy; use scylla::prepared_statement::PreparedStatement; use scylla::transport::errors::{DbError, NewSessionError, QueryError}; use scylla::transport::session::PoolSize; @@ -57,8 +58,14 @@ fn ssl_context(conf: &&ConnectionConf) -> Result, CassError> /// Configures connection to Cassandra. pub async fn connect(conf: &ConnectionConf) -> Result { + let mut policy_builder = DefaultPolicy::builder().token_aware(true); + let dc = &conf.datacenter; + if !dc.is_empty() { + policy_builder = policy_builder.prefer_datacenter(dc.to_owned()).permit_dc_failover(true); + } let profile = ExecutionProfile::builder() .consistency(conf.consistency.scylla_consistency()) + .load_balancing_policy(policy_builder.build()) .request_timeout(Some(Duration::from_secs(conf.request_timeout.get() as u64))) .build(); diff --git a/src/report.rs b/src/report.rs index cf6a7fc..e1445ba 100644 --- a/src/report.rs +++ b/src/report.rs @@ -499,6 +499,10 @@ impl<'a> Display for RunConfigCmp<'a> { } let lines: Vec> = vec![ + self.line("Datacenter", "", |conf| {conf.connection.datacenter.clone()}), + self.line("Consistency", "", |conf| { + conf.connection.consistency.scylla_consistency().to_string() + }), self.line("Threads", "", |conf| Quantity::from(conf.threads)), self.line("Connections", "", |conf| { Quantity::from(conf.connection.count)