Skip to content

Commit

Permalink
Rename PreparedStatement to ServerPreparedStatement (#21577)
Browse files Browse the repository at this point in the history
* Rename PreparedStatement to ServerPreparedStatement

* Rename PreparedStatementRegistry to ServerPreparedStatementRegistry

* Revise javadoc of ServerPreparedStatement
  • Loading branch information
TeslaCN authored Oct 15, 2022
1 parent aa246af commit 9e3bb7e
Show file tree
Hide file tree
Showing 35 changed files with 133 additions and 131 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public final class ConnectionSession {

private final ExecutorStatementManager statementManager;

private final PreparedStatementRegistry preparedStatementRegistry = new PreparedStatementRegistry();
private final ServerPreparedStatementRegistry serverPreparedStatementRegistry = new ServerPreparedStatementRegistry();

private final ConnectionContext connectionContext;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,26 +23,26 @@
import java.util.Optional;

/**
* Logic prepared statement for clients of ShardingSphere-Proxy.
* Server prepared statement for clients of ShardingSphere-Proxy.
*/
public interface PreparedStatement {
public interface ServerPreparedStatement {

/**
* Get SQL of prepared statement.
* Get SQL of server prepared statement.
*
* @return SQL
*/
String getSql();

/**
* Get {@link SQLStatement} of prepared statement.
* Get {@link SQLStatement} of server prepared statement.
*
* @return {@link SQLStatement}
*/
SQLStatement getSqlStatement();

/**
* Get optional {@link SQLStatementContext} of prepared statement.
* Get optional {@link SQLStatementContext} of server prepared statement.
*
* @return optional {@link SQLStatementContext}
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,36 +21,36 @@
import java.util.concurrent.ConcurrentHashMap;

/**
* {@link PreparedStatement} registry for {@link ConnectionSession}.
* {@link ServerPreparedStatement} registry for {@link ConnectionSession}.
*/
public final class PreparedStatementRegistry {
public final class ServerPreparedStatementRegistry {

private final Map<Object, PreparedStatement> preparedStatements = new ConcurrentHashMap<>();
private final Map<Object, ServerPreparedStatement> preparedStatements = new ConcurrentHashMap<>();

/**
* Add {@link PreparedStatement} into registry.
* Add {@link ServerPreparedStatement} into registry.
*
* @param statementId statement ID
* @param preparedStatement prepared statement
* @param serverPreparedStatement server prepared statement
*/
public void addPreparedStatement(final Object statementId, final PreparedStatement preparedStatement) {
preparedStatements.put(statementId, preparedStatement);
public void addPreparedStatement(final Object statementId, final ServerPreparedStatement serverPreparedStatement) {
preparedStatements.put(statementId, serverPreparedStatement);
}

/**
* Get prepared statement by statement ID.
* Get {@link ServerPreparedStatement} by statement ID.
*
* @param <T> implementation of {@link PreparedStatement}
* @param <T> implementation of {@link ServerPreparedStatement}
* @param statementId statement ID
* @return {@link PreparedStatement}
* @return {@link ServerPreparedStatement}
*/
@SuppressWarnings("unchecked")
public <T extends PreparedStatement> T getPreparedStatement(final Object statementId) {
public <T extends ServerPreparedStatement> T getPreparedStatement(final Object statementId) {
return (T) preparedStatements.get(statementId);
}

/**
* Remove {@link PreparedStatement} from registry.
* Remove {@link ServerPreparedStatement} from registry.
*
* @param statementId statement ID
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,19 @@
import static org.junit.Assert.assertNull;
import static org.hamcrest.MatcherAssert.assertThat;

public final class PreparedStatementRegistryTest {
public final class ServerServerPreparedStatementRegistryTest {

@Test
public void assertAddAndGetAndClosePreparedStatement() {
PreparedStatement expected = new DummyPreparedStatement();
PreparedStatementRegistry registry = new PreparedStatementRegistry();
ServerPreparedStatement expected = new DummyServerPreparedStatement();
ServerPreparedStatementRegistry registry = new ServerPreparedStatementRegistry();
registry.addPreparedStatement(1, expected);
assertThat(registry.getPreparedStatement(1), is(expected));
registry.removePreparedStatement(1);
assertNull(registry.getPreparedStatement(1));
}

private static class DummyPreparedStatement implements PreparedStatement {
private static class DummyServerPreparedStatement implements ServerPreparedStatement {

@Override
public String getSql() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.text.query.MySQLComQueryPacket;
import org.apache.shardingsphere.db.protocol.mysql.payload.MySQLPacketPayload;
import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
import org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.MySQLPreparedStatement;
import org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.MySQLServerPreparedStatement;

/**
* Command packet factory for MySQL.
Expand Down Expand Up @@ -64,8 +64,9 @@ public static MySQLCommandPacket newInstance(final MySQLCommandPacketType comman
case COM_STMT_PREPARE:
return new MySQLComStmtPreparePacket(payload);
case COM_STMT_EXECUTE:
MySQLPreparedStatement preparedStatement = connectionSession.getPreparedStatementRegistry().getPreparedStatement(payload.getByteBuf().getIntLE(payload.getByteBuf().readerIndex()));
return new MySQLComStmtExecutePacket(payload, preparedStatement.getSqlStatement().getParameterCount());
MySQLServerPreparedStatement serverPreparedStatement =
connectionSession.getServerPreparedStatementRegistry().getPreparedStatement(payload.getByteBuf().getIntLE(payload.getByteBuf().readerIndex()));
return new MySQLComStmtExecutePacket(payload, serverPreparedStatement.getSqlStatement().getParameterCount());
case COM_STMT_SEND_LONG_DATA:
return new MySQLComStmtSendLongDataPacket(payload);
case COM_STMT_RESET:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public final class MySQLComStmtSendLongDataExecutor implements CommandExecutor {

@Override
public Collection<DatabasePacket<?>> execute() {
MySQLPreparedStatement preparedStatement = connectionSession.getPreparedStatementRegistry().getPreparedStatement(packet.getStatementId());
MySQLServerPreparedStatement preparedStatement = connectionSession.getServerPreparedStatementRegistry().getPreparedStatement(packet.getStatementId());
preparedStatement.getLongData().put(packet.getParamId(), packet.getData());
return Collections.emptyList();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import lombok.Setter;
import org.apache.shardingsphere.db.protocol.mysql.packet.command.query.binary.MySQLPreparedStatementParameterType;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
import org.apache.shardingsphere.proxy.backend.session.PreparedStatement;
import org.apache.shardingsphere.proxy.backend.session.ServerPreparedStatement;
import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;

import java.util.Collections;
Expand All @@ -37,7 +37,7 @@
@RequiredArgsConstructor
@Getter
@Setter
public final class MySQLPreparedStatement implements PreparedStatement {
public final class MySQLServerPreparedStatement implements ServerPreparedStatement {

private final String sql;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public final class MySQLComStmtCloseExecutor implements CommandExecutor {

@Override
public Collection<DatabasePacket<?>> execute() {
connectionSession.getPreparedStatementRegistry().removePreparedStatement(packet.getStatementId());
connectionSession.getServerPreparedStatementRegistry().removePreparedStatement(packet.getStatementId());
return Collections.emptyList();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
import org.apache.shardingsphere.proxy.frontend.command.executor.QueryCommandExecutor;
import org.apache.shardingsphere.proxy.frontend.command.executor.ResponseType;
import org.apache.shardingsphere.proxy.frontend.mysql.command.ServerStatusFlagCalculator;
import org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.MySQLPreparedStatement;
import org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.MySQLServerPreparedStatement;
import org.apache.shardingsphere.proxy.frontend.mysql.command.query.builder.ResponsePacketBuilder;

import java.sql.SQLException;
Expand All @@ -70,7 +70,7 @@ public final class MySQLComStmtExecuteExecutor implements QueryCommandExecutor {

@Override
public Collection<DatabasePacket<?>> execute() throws SQLException {
MySQLPreparedStatement preparedStatement = updateAndGetPreparedStatement();
MySQLServerPreparedStatement preparedStatement = updateAndGetPreparedStatement();
List<Object> parameters = packet.readParameters(preparedStatement.getParameterTypes(), preparedStatement.getLongData().keySet());
preparedStatement.getLongData().forEach(parameters::set);
SQLStatementContext<?> sqlStatementContext = preparedStatement.getSqlStatementContext().get();
Expand All @@ -84,8 +84,8 @@ public Collection<DatabasePacket<?>> execute() throws SQLException {
return responseHeader instanceof QueryResponseHeader ? processQuery((QueryResponseHeader) responseHeader) : processUpdate((UpdateResponseHeader) responseHeader);
}

private MySQLPreparedStatement updateAndGetPreparedStatement() {
MySQLPreparedStatement result = connectionSession.getPreparedStatementRegistry().getPreparedStatement(packet.getStatementId());
private MySQLServerPreparedStatement updateAndGetPreparedStatement() {
MySQLServerPreparedStatement result = connectionSession.getServerPreparedStatementRegistry().getPreparedStatement(packet.getStatementId());
if (MySQLNewParametersBoundFlag.PARAMETER_TYPE_EXIST == packet.getNewParametersBoundFlag()) {
result.setParameterTypes(packet.getNewParameterTypes());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
import org.apache.shardingsphere.proxy.frontend.command.executor.CommandExecutor;
import org.apache.shardingsphere.proxy.frontend.mysql.command.ServerStatusFlagCalculator;
import org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.MySQLPreparedStatement;
import org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.MySQLServerPreparedStatement;
import org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.MySQLStatementIDGenerator;
import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement;
Expand Down Expand Up @@ -73,7 +73,7 @@ public Collection<DatabasePacket<?>> execute() {
int statementId = MySQLStatementIDGenerator.getInstance().nextStatementId(connectionSession.getConnectionId());
SQLStatementContext<?> sqlStatementContext = SQLStatementContextFactory.newInstance(ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData().getDatabases(),
sqlStatement, connectionSession.getDefaultDatabaseName());
connectionSession.getPreparedStatementRegistry().addPreparedStatement(statementId, new MySQLPreparedStatement(packet.getSql(), sqlStatement, sqlStatementContext));
connectionSession.getServerPreparedStatementRegistry().addPreparedStatement(statementId, new MySQLServerPreparedStatement(packet.getSql(), sqlStatement, sqlStatementContext));
return createPackets(statementId, projectionCount, sqlStatement.getParameterCount());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
import org.apache.shardingsphere.proxy.frontend.command.executor.CommandExecutor;
import org.apache.shardingsphere.proxy.frontend.mysql.command.ServerStatusFlagCalculator;
import org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.MySQLPreparedStatement;
import org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.MySQLServerPreparedStatement;

import java.util.Collection;
import java.util.Collections;
Expand All @@ -41,7 +41,7 @@ public final class MySQLComStmtResetExecutor implements CommandExecutor {

@Override
public Collection<DatabasePacket<?>> execute() {
connectionSession.getPreparedStatementRegistry().<MySQLPreparedStatement>getPreparedStatement(packet.getStatementId()).getLongData().clear();
connectionSession.getServerPreparedStatementRegistry().<MySQLServerPreparedStatement>getPreparedStatement(packet.getStatementId()).getLongData().clear();
return Collections.singleton(new MySQLOKPacket(1, ServerStatusFlagCalculator.calculateFor(connectionSession)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
import org.apache.shardingsphere.db.protocol.mysql.payload.MySQLPacketPayload;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
import org.apache.shardingsphere.proxy.backend.session.PreparedStatementRegistry;
import org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.MySQLPreparedStatement;
import org.apache.shardingsphere.proxy.backend.session.ServerPreparedStatementRegistry;
import org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.MySQLServerPreparedStatement;
import org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dml.MySQLSelectStatement;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand Down Expand Up @@ -91,9 +91,9 @@ public void assertNewInstanceWithComStmtExecutePacket() throws SQLException {
when(payload.readInt1()).thenReturn(MySQLNewParametersBoundFlag.PARAMETER_TYPE_EXIST.getValue());
when(payload.readInt4()).thenReturn(1);
when(payload.getByteBuf().getIntLE(anyInt())).thenReturn(1);
PreparedStatementRegistry preparedStatementRegistry = new PreparedStatementRegistry();
when(connectionSession.getPreparedStatementRegistry()).thenReturn(preparedStatementRegistry);
preparedStatementRegistry.addPreparedStatement(1, new MySQLPreparedStatement("select 1", new MySQLSelectStatement(), mock(SQLStatementContext.class)));
ServerPreparedStatementRegistry serverPreparedStatementRegistry = new ServerPreparedStatementRegistry();
when(connectionSession.getServerPreparedStatementRegistry()).thenReturn(serverPreparedStatementRegistry);
serverPreparedStatementRegistry.addPreparedStatement(1, new MySQLServerPreparedStatement("select 1", new MySQLSelectStatement(), mock(SQLStatementContext.class)));
assertThat(MySQLCommandPacketFactory.newInstance(MySQLCommandPacketType.COM_STMT_EXECUTE, payload, connectionSession), instanceOf(MySQLComStmtExecutePacket.class));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import org.apache.shardingsphere.db.protocol.packet.DatabasePacket;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
import org.apache.shardingsphere.proxy.backend.session.PreparedStatementRegistry;
import org.apache.shardingsphere.proxy.backend.session.ServerPreparedStatementRegistry;
import org.junit.Test;

import java.nio.charset.StandardCharsets;
Expand All @@ -43,9 +43,9 @@ public void assertExecute() {
byte[] data = "data".getBytes(StandardCharsets.US_ASCII);
when(packet.getData()).thenReturn(data);
ConnectionSession connectionSession = mock(ConnectionSession.class);
when(connectionSession.getPreparedStatementRegistry()).thenReturn(new PreparedStatementRegistry());
MySQLPreparedStatement preparedStatement = new MySQLPreparedStatement("insert into t (b) values (?)", null, mock(SQLStatementContext.class));
connectionSession.getPreparedStatementRegistry().addPreparedStatement(1, preparedStatement);
when(connectionSession.getServerPreparedStatementRegistry()).thenReturn(new ServerPreparedStatementRegistry());
MySQLServerPreparedStatement preparedStatement = new MySQLServerPreparedStatement("insert into t (b) values (?)", null, mock(SQLStatementContext.class));
connectionSession.getServerPreparedStatementRegistry().addPreparedStatement(1, preparedStatement);
MySQLComStmtSendLongDataExecutor executor = new MySQLComStmtSendLongDataExecutor(packet, connectionSession);
Collection<DatabasePacket<?>> actual = executor.execute();
assertThat(actual, is(Collections.emptyList()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,6 @@ public void assertExecute() {
MySQLComStmtClosePacket packet = new MySQLComStmtClosePacket(new MySQLPacketPayload(Unpooled.wrappedBuffer(new byte[]{0x01, 0x00, 0x00, 0x00}), StandardCharsets.UTF_8));
ConnectionSession connectionSession = mock(ConnectionSession.class, RETURNS_DEEP_STUBS);
assertThat(new MySQLComStmtCloseExecutor(packet, connectionSession).execute(), is(Collections.emptyList()));
verify(connectionSession.getPreparedStatementRegistry()).removePreparedStatement(1);
verify(connectionSession.getServerPreparedStatementRegistry()).removePreparedStatement(1);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
import org.apache.shardingsphere.proxy.backend.session.ConnectionSession;
import org.apache.shardingsphere.proxy.frontend.command.executor.ResponseType;
import org.apache.shardingsphere.proxy.frontend.mysql.ProxyContextRestorer;
import org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.MySQLPreparedStatement;
import org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.MySQLServerPreparedStatement;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.ColumnAssignmentSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.SetAssignmentSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
Expand Down Expand Up @@ -119,13 +119,13 @@ public void setUp() {
when(connectionSession.getAttributeMap().attr(MySQLConstants.MYSQL_CHARACTER_SET_ATTRIBUTE_KEY).get()).thenReturn(MySQLCharacterSet.UTF8MB4_GENERAL_CI);
when(connectionSession.getBackendConnection()).thenReturn(backendConnection);
SQLStatementContext<?> selectStatementContext = prepareSelectStatementContext();
when(connectionSession.getPreparedStatementRegistry().getPreparedStatement(1))
.thenReturn(new MySQLPreparedStatement("select * from tbl where id = ?", prepareSelectStatement(), selectStatementContext));
when(connectionSession.getServerPreparedStatementRegistry().getPreparedStatement(1))
.thenReturn(new MySQLServerPreparedStatement("select * from tbl where id = ?", prepareSelectStatement(), selectStatementContext));
UpdateStatementContext updateStatementContext = mock(UpdateStatementContext.class, RETURNS_DEEP_STUBS);
when(connectionSession.getPreparedStatementRegistry().getPreparedStatement(2))
.thenReturn(new MySQLPreparedStatement("update tbl set col=1 where id = ?", prepareUpdateStatement(), updateStatementContext));
when(connectionSession.getPreparedStatementRegistry().getPreparedStatement(3))
.thenReturn(new MySQLPreparedStatement("commit", new MySQLCommitStatement(), new CommonSQLStatementContext<>(new MySQLCommitStatement())));
when(connectionSession.getServerPreparedStatementRegistry().getPreparedStatement(2))
.thenReturn(new MySQLServerPreparedStatement("update tbl set col=1 where id = ?", prepareUpdateStatement(), updateStatementContext));
when(connectionSession.getServerPreparedStatementRegistry().getPreparedStatement(3))
.thenReturn(new MySQLServerPreparedStatement("commit", new MySQLCommitStatement(), new CommonSQLStatementContext<>(new MySQLCommitStatement())));
}

private ShardingSphereDatabase mockDatabase() {
Expand Down
Loading

0 comments on commit 9e3bb7e

Please sign in to comment.