Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

conn: copy a batch type in prepare_batch #1038

Merged
merged 1 commit into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions scylla/src/statement/batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,17 @@ impl Batch {
}
}

/// Creates an empty batch, with the configuration of existing batch.
pub(crate) fn new_from(batch: &Batch) -> Batch {
let batch_type = batch.get_type();
let config = batch.config.clone();
Batch {
batch_type,
config,
..Default::default()
}
}

/// Creates a new, empty `Batch` of `batch_type` type with the provided statements.
pub fn new_with_statements(batch_type: BatchType, statements: Vec<BatchStatement>) -> Self {
Self {
Expand Down
3 changes: 1 addition & 2 deletions scylla/src/transport/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -914,8 +914,7 @@ impl Connection {
prepared_queries.insert(query, prepared);
}

let mut batch: Cow<Batch> = Cow::Owned(Default::default());
batch.to_mut().config = init_batch.config.clone();
let mut batch: Cow<Batch> = Cow::Owned(Batch::new_from(init_batch));
for stmt in &init_batch.statements {
match stmt {
BatchStatement::Query(query) => match prepared_queries.get(query.contents.as_str())
Expand Down
51 changes: 51 additions & 0 deletions scylla/src/transport/session_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,57 @@ async fn test_prepared_statement() {
}
}

#[tokio::test]
async fn test_counter_batch() {
use crate::frame::value::Counter;
use scylla_cql::frame::request::batch::BatchType;

setup_tracing();
let session = Arc::new(create_new_session_builder().build().await.unwrap());
let ks = unique_keyspace_name();

session.query(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}", ks), &[]).await.unwrap();
session
.query(
format!(
"CREATE TABLE IF NOT EXISTS {}.t_batch (key int PRIMARY KEY, value counter)",
ks
),
&[],
)
.await
.unwrap();

let statement_str = format!("UPDATE {}.t_batch SET value = value + ? WHERE key = ?", ks);
let query = Query::from(statement_str);
let prepared = session.prepare(query.clone()).await.unwrap();

let mut counter_batch = Batch::new(BatchType::Counter);
counter_batch.append_statement(query.clone());
counter_batch.append_statement(prepared.clone());
counter_batch.append_statement(query.clone());
counter_batch.append_statement(prepared.clone());
counter_batch.append_statement(query.clone());
counter_batch.append_statement(prepared.clone());

// Check that we do not get a server error - the driver
// should send a COUNTER batch instead of a LOGGED (default) one.
session
.batch(
&counter_batch,
(
(Counter(1), 1),
(Counter(2), 2),
(Counter(3), 3),
(Counter(4), 4),
(Counter(5), 5),
(Counter(6), 6),
),
)
.await
.unwrap();
}

#[tokio::test]
async fn test_batch() {
setup_tracing();
Expand Down
Loading