Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batch DML support #107

Merged
merged 3 commits into from
Jun 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 23 additions & 19 deletions src/main/java/com/google/cloud/spanner/r2dbc/SpannerStatement.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import com.google.cloud.spanner.r2dbc.statement.TypedNull;
import com.google.cloud.spanner.r2dbc.util.Assert;
import com.google.protobuf.Struct;
import com.google.spanner.v1.ExecuteBatchDmlResponse;
import com.google.spanner.v1.PartialResultSet;
import com.google.spanner.v1.Session;
import io.r2dbc.spi.Result;
Expand Down Expand Up @@ -52,6 +53,8 @@ public class SpannerStatement implements Statement {

private StatementBindings statementBindings;

private StatementType statementType;

/**
* Creates a Spanner statement for a given SQL statement.
*
Expand All @@ -73,6 +76,7 @@ public SpannerStatement(
this.transaction = transaction;
this.sql = Assert.requireNonNull(sql, "SQL string can not be null");
this.statementBindings = new StatementBindings();
this.statementType = StatementParser.getStatementType(this.sql);
}

@Override
Expand Down Expand Up @@ -108,34 +112,34 @@ public Statement bindNull(int i, Class<?> type) {

@Override
public Publisher<? extends Result> execute() {
Flux<Struct> structFlux = Flux.fromIterable(this.statementBindings.getBindings());
StatementType statementType = StatementParser.getStatementType(this.sql);

if (statementType == StatementType.SELECT) {
return structFlux.flatMap(struct -> runSingleStatement(struct, statementType));
switch (this.statementType) {
case DML:
return this.client
.executeBatchDml(this.session, this.transaction, this.sql,
this.statementBindings.getBindings(),
this.statementBindings.getTypes())
.flatMapIterable(ExecuteBatchDmlResponse::getResultSetsList)
.map(resultSet -> new SpannerResult(Flux.empty(),
Mono.just(Math.toIntExact(resultSet.getStats().getRowCountExact()))));
case SELECT:
Flux<Struct> structFlux = Flux.fromIterable(this.statementBindings.getBindings());
return structFlux.flatMap(this::runSelectStatement);
default:
throw new UnsupportedOperationException("Unsupported statement type " + this.statementType);
}
// DML statements have to be executed sequentially because they need seqNo to be in order
return structFlux.concatMapDelayError(struct -> runSingleStatement(struct, statementType));
}

private Mono<? extends Result> runSingleStatement(Struct params, StatementType statementType) {
private Mono<? extends Result> runSelectStatement(Struct params) {
PartialResultRowExtractor partialResultRowExtractor = new PartialResultRowExtractor();

Flux<PartialResultSet> resultSetFlux =
this.client.executeStreamingSql(
this.session, this.transaction, this.sql, params, this.statementBindings.getTypes());

if (statementType == StatementType.SELECT) {
return resultSetFlux
.flatMapIterable(partialResultRowExtractor, getPartialResultSetFetchSize())
.transform(result -> Mono.just(new SpannerResult(result, Mono.just(0))))
.next();
} else {
return resultSetFlux
.last()
.map(partialResultSet -> Math.toIntExact(partialResultSet.getStats().getRowCountExact()))
.map(rowCount -> new SpannerResult(Flux.empty(), Mono.just(rowCount)));
}
return resultSetFlux
.flatMapIterable(partialResultRowExtractor, getPartialResultSetFetchSize())
.transform(result -> Mono.just(new SpannerResult(result, Mono.just(0))))
.next();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
import com.google.cloud.spanner.r2dbc.SpannerTransactionContext;
import com.google.protobuf.Struct;
import com.google.spanner.v1.CommitResponse;
import com.google.spanner.v1.ExecuteBatchDmlResponse;
import com.google.spanner.v1.PartialResultSet;
import com.google.spanner.v1.Session;
import com.google.spanner.v1.Transaction;
import com.google.spanner.v1.Type;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import reactor.core.publisher.Flux;
Expand Down Expand Up @@ -83,6 +85,13 @@ default Flux<PartialResultSet> executeStreamingSql(
return executeStreamingSql(session, transaction, sql, null, null);
}

/**
* Execute DML batch.
*/
Mono<ExecuteBatchDmlResponse> executeBatchDml(Session session,
@Nullable SpannerTransactionContext transactionContext, String sql,
List<Struct> params, Map<String, Type> types);
meltsufin marked this conversation as resolved.
Show resolved Hide resolved

/**
* Release any resources held by the {@link Client}.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import com.google.spanner.v1.CommitResponse;
import com.google.spanner.v1.CreateSessionRequest;
import com.google.spanner.v1.DeleteSessionRequest;
import com.google.spanner.v1.ExecuteBatchDmlRequest;
import com.google.spanner.v1.ExecuteBatchDmlResponse;
import com.google.spanner.v1.ExecuteSqlRequest;
import com.google.spanner.v1.PartialResultSet;
import com.google.spanner.v1.RollbackRequest;
Expand All @@ -42,6 +44,7 @@
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.auth.MoreCallCredentials;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import reactor.core.publisher.Flux;
Expand Down Expand Up @@ -159,6 +162,31 @@ public Mono<Void> deleteSession(Session session) {
});
}

@Override
public Mono<ExecuteBatchDmlResponse> executeBatchDml(Session session,
@Nullable SpannerTransactionContext transactionContext, String sql,
List<Struct> params, Map<String, Type> types) {

ExecuteBatchDmlRequest.Builder request = ExecuteBatchDmlRequest.newBuilder()
.setSession(session.getName());
if (transactionContext != null && transactionContext.getTransaction() != null) {
request.setTransaction(
TransactionSelector.newBuilder().setId(transactionContext.getTransaction().getId())
.build())
.setSeqno(transactionContext.nextSeqNum());

}
for (Struct paramsStruct : params) {
ExecuteBatchDmlRequest.Statement statement = ExecuteBatchDmlRequest.Statement.newBuilder()
.setSql(sql).setParams(paramsStruct).putAllParamTypes(types)
.build();
request.addStatements(statement);
}

return ObservableReactiveUtil
.unaryCall(obs -> this.spanner.executeBatchDml(request.build(), obs));
}

@Override
public Flux<PartialResultSet> executeStreamingSql(
Session session, @Nullable SpannerTransactionContext transactionContext, String sql,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
import com.google.cloud.spanner.r2dbc.client.Client;
import com.google.protobuf.Struct;
import com.google.protobuf.Value;
import com.google.spanner.v1.ExecuteBatchDmlResponse;
import com.google.spanner.v1.PartialResultSet;
import com.google.spanner.v1.ResultSet;
import com.google.spanner.v1.ResultSetMetadata;
import com.google.spanner.v1.ResultSetStats;
import com.google.spanner.v1.Session;
Expand All @@ -37,6 +39,7 @@
import com.google.spanner.v1.Type;
import com.google.spanner.v1.TypeCode;
import io.r2dbc.spi.Result;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
Expand Down Expand Up @@ -195,23 +198,28 @@ public void readMultiResultSetQueryTest() {

when(this.mockClient.executeStreamingSql(any(), any(), any(), any(), any())).thenReturn(inputs);

StepVerifier.create(Flux.from(new SpannerStatement(this.mockClient, null, null, "").execute())
StepVerifier
.create(Flux.from(new SpannerStatement(this.mockClient, null, null, "SELECT").execute())
.flatMap(r -> Mono.from(r.getRowsUpdated())))
.expectNext(0)
.verifyComplete();
}

@Test
public void readDmlQueryTest() {
PartialResultSet p1 = PartialResultSet.newBuilder().setStats(
ResultSetStats.newBuilder().setRowCountExact(555).build()
).build();
ResultSet resultSet = ResultSet.newBuilder()
.setStats(ResultSetStats.newBuilder().setRowCountExact(555).build())
.build();

Flux<PartialResultSet> inputs = Flux.just(p1);
ExecuteBatchDmlResponse executeBatchDmlResponse = ExecuteBatchDmlResponse.newBuilder()
.addResultSets(resultSet)
.build();

when(this.mockClient.executeStreamingSql(any(), any(), any(), any(), any())).thenReturn(inputs);
when(this.mockClient.executeBatchDml(any(), any(), any(), any(), any()))
.thenReturn(Mono.just(executeBatchDmlResponse));

StepVerifier.create(Flux.from(new SpannerStatement(this.mockClient, null, null, "").execute())
StepVerifier.create(
Flux.from(new SpannerStatement(this.mockClient, null, null, "Insert into books").execute())
.flatMap(r -> Mono.from(r.getRowsUpdated())))
.expectNext(555)
.verifyComplete();
Expand All @@ -221,13 +229,19 @@ public void readDmlQueryTest() {
public void noopMapOnUpdateQueriesWhenNoRowsAffected() {
Client mockClient = mock(Client.class);
String sql = "delete from Books where true";
PartialResultSet partialResultSet = PartialResultSet.newBuilder()

ResultSet resultSet = ResultSet.newBuilder()
.setMetadata(ResultSetMetadata.getDefaultInstance())
.setStats(ResultSetStats.getDefaultInstance())
.build();
when(mockClient.executeStreamingSql(TEST_SESSION, null, sql,
Struct.newBuilder().build(), Collections.EMPTY_MAP))
.thenReturn(Flux.just(partialResultSet));

ExecuteBatchDmlResponse executeBatchDmlResponse = ExecuteBatchDmlResponse.newBuilder()
.addResultSets(resultSet)
.build();

when(mockClient.executeBatchDml(TEST_SESSION, null, sql,
Arrays.asList(Struct.newBuilder().build()), Collections.EMPTY_MAP))
.thenReturn(Mono.just(executeBatchDmlResponse));

SpannerStatement statement
= new SpannerStatement(mockClient, TEST_SESSION, null, sql);
Expand All @@ -244,7 +258,7 @@ public void noopMapOnUpdateQueriesWhenNoRowsAffected() {
.expectNext(0)
.verifyComplete();

verify(mockClient, times(2)).executeStreamingSql(TEST_SESSION, null, sql,
Struct.newBuilder().build(), Collections.EMPTY_MAP);
verify(mockClient, times(1)).executeBatchDml(TEST_SESSION, null, sql,
Arrays.asList(Struct.newBuilder().build()), Collections.EMPTY_MAP);
}
}
67 changes: 46 additions & 21 deletions src/test/java/com/google/cloud/spanner/r2dbc/it/SpannerIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;

/**
* Integration test for connecting to a real Spanner instance.
Expand Down Expand Up @@ -171,25 +172,48 @@ public void testQuerying() {

Mono.from(this.connectionFactory.create())
.delayUntil(c -> c.beginTransaction())
.delayUntil(c -> Flux.from(c.createStatement(
"INSERT BOOKS (UUID, TITLE, AUTHOR, CATEGORY, FICTION, PUBLISHED, WORDS_PER_SENTENCE)"
+ " VALUES (@uuid, @title, @author, @category, @fiction, @published, @wps);")
.bind("uuid", "2b2cbd78-ecd8-430e-b685-fa7910f8a4c7")
.bind("author", "Douglas Crockford")
.bind("category", 100L)
.bind("title", "JavaScript: The Good Parts")
.bind("fiction", true)
.bind("published", LocalDate.of(2008, 5, 1))
.bind("wps", 20.8)
.add()
.bind("uuid", "df0e3d06-2743-4691-8e51-6d33d90c5cb9")
.bind("author", "Joshua Bloch")
.bind("category", 100L)
.bind("title", "Effective Java")
.bind("fiction", false)
.bind("published", LocalDate.of(2018, 1, 6))
.bind("wps", 15.1)
.execute()).flatMapSequential(r -> Mono.from(r.getRowsUpdated())))
.delayUntil(c ->
Mono.fromRunnable(() ->
StepVerifier.create(Flux.from(c.createStatement(
"INSERT BOOKS "
+ "(UUID, TITLE, AUTHOR, CATEGORY, FICTION, PUBLISHED, WORDS_PER_SENTENCE)"
+ " VALUES "
+ "(@uuid, @title, @author, @category, @fiction, @published, @wps);")
.bind("uuid", "2b2cbd78-ecd8-430e-b685-fa7910f8a4c7")
.bind("author", "Douglas Crockford")
.bind("category", 100L)
.bind("title", "JavaScript: The Good Parts")
.bind("fiction", true)
.bind("published", LocalDate.of(2008, 5, 1))
.bind("wps", 20.8)
.add()
.bind("uuid", "df0e3d06-2743-4691-8e51-6d33d90c5cb9")
.bind("author", "Joshua Bloch")
.bind("category", 100L)
.bind("title", "Effective Java")
.bind("fiction", false)
.bind("published", LocalDate.of(2018, 1, 6))
.bind("wps", 15.1)
.execute())
.flatMapSequential(r -> Mono.from(r.getRowsUpdated())))
.expectNext(1).expectNext(1).verifyComplete())
)
.delayUntil(c -> c.commitTransaction())
.block();

Mono.from(this.connectionFactory.create())
.delayUntil(c -> c.beginTransaction())
.delayUntil(c ->
Mono.fromRunnable(() ->
StepVerifier
.create(Flux.from(c.createStatement(
"UPDATE BOOKS SET CATEGORY = @new_cat WHERE CATEGORY = @old_cat")
.bind("new_cat", 101L)
.bind("old_cat", 100L)
.execute())
.flatMap(r -> Mono.from(r.getRowsUpdated())))
.expectNext(2).verifyComplete())
)
.delayUntil(c -> c.commitTransaction())
.block();

Expand Down Expand Up @@ -272,12 +296,13 @@ private int executeDmlQuery(String sql) {
Connection connection = Mono.from(connectionFactory.create()).block();

Mono.from(connection.beginTransaction()).block();
int rowsUpdated = Mono.from(connection.createStatement(sql).execute())
List<Integer> rowsUpdatedPerStatement = Flux.from(connection.createStatement(sql).execute())
.flatMap(result -> Mono.from(result.getRowsUpdated()))
.collectList()
.block();
Mono.from(connection.commitTransaction()).block();

return rowsUpdated;
return rowsUpdatedPerStatement.get(0);
}

/**
Expand Down