diff --git a/databend-jdbc/pom.xml b/databend-jdbc/pom.xml index b59b417d..15b5c18d 100644 --- a/databend-jdbc/pom.xml +++ b/databend-jdbc/pom.xml @@ -97,6 +97,17 @@ gson 2.6.2 + + junit + junit + 4.13.2 + test + + + org.junit.jupiter + junit-jupiter + test + diff --git a/databend-jdbc/src/main/java/com/databend/jdbc/DatabendPreparedStatement.java b/databend-jdbc/src/main/java/com/databend/jdbc/DatabendPreparedStatement.java index 4e2b084a..30cdef88 100644 --- a/databend-jdbc/src/main/java/com/databend/jdbc/DatabendPreparedStatement.java +++ b/databend-jdbc/src/main/java/com/databend/jdbc/DatabendPreparedStatement.java @@ -52,6 +52,7 @@ import static com.databend.jdbc.ObjectCasts.castToInt; import static com.databend.jdbc.ObjectCasts.castToLong; import static com.databend.jdbc.ObjectCasts.castToShort; +import static com.databend.jdbc.StatementUtil.replaceParameterMarksWithValues; import static java.lang.String.format; import static java.time.format.DateTimeFormatter.ISO_LOCAL_DATE; import static java.time.format.DateTimeFormatter.ISO_LOCAL_TIME; @@ -346,7 +347,8 @@ public int[] executeBatch() throws SQLException { @Override public ResultSet executeQuery() throws SQLException { - this.executeBatch(); + String sql = replaceParameterMarksWithValues(batchInsertUtils.get().getProvideParams(), this.originalSql).get(0).getSql(); + internalExecute(sql, null); return getResultSet(); } diff --git a/databend-jdbc/src/main/java/com/databend/jdbc/StatementUtil.java b/databend-jdbc/src/main/java/com/databend/jdbc/StatementUtil.java index d5c41e11..9d0052f2 100644 --- a/databend-jdbc/src/main/java/com/databend/jdbc/StatementUtil.java +++ b/databend-jdbc/src/main/java/com/databend/jdbc/StatementUtil.java @@ -19,8 +19,8 @@ public class StatementUtil { private static final String SET_PREFIX = "set"; private static final Pattern SET_WITH_SPACE_REGEX = Pattern.compile(SET_PREFIX + " ", Pattern.CASE_INSENSITIVE); - private static final String[] SELECT_KEYWORDS = new String[] { "show", "select", "describe", "exists", "explain", - "with", "call" }; + private static final String[] SELECT_KEYWORDS = new String[]{"show", "select", "describe", "exists", "explain", + "with", "call"}; /** * Returns true if the statement is a query (eg: SELECT, SHOW). @@ -41,7 +41,7 @@ public static boolean isQuery(String cleanSql) { * Extracts parameter from statement (eg: SET x=y) * * @param cleanSql the clean version of the sql (sql statement without comments) - * @param sql the sql statement + * @param sql the sql statement * @return an optional parameter represented with a pair of key/value */ public Optional> extractParamFromSetStatement(@NonNull String cleanSql, String sql) { @@ -200,7 +200,7 @@ public Pair, Optional> extractDbNameAndTableNamePairFro * constructed with the sql statement and the parameters provided * * @param params the parameters - * @param sql the sql statement + * @param sql the sql statement * @return a list of sql statements containing the provided parameters */ public static List replaceParameterMarksWithValues(@NonNull Map params, @@ -213,7 +213,7 @@ public static List replaceParameterMarksWithValues(@NonNul * Returns a list of {@link StatementInfoWrapper} containing sql statements * constructed with the {@link RawStatementWrapper} and the parameters provided * - * @param params the parameters + * @param params the parameters * @param rawStatement the rawStatement * @return a list of sql statements containing the provided parameters */ @@ -292,7 +292,7 @@ private Optional> extractPropertyPair(String cleanStatement String[] values = StringUtils.split(setQuery, "="); if (values.length == 2) { String value = StringUtils.removeEnd(values[1], ";").trim(); - if (StringUtils.isNumeric(value)){ + if (StringUtils.isNumeric(value)) { return Optional.of(Pair.of(values[0].trim(), value.trim())); } else { return Optional.of(Pair.of(values[0].trim(), StringUtils.removeEnd(StringUtils.removeStart(value, "'"), "'"))); diff --git a/databend-jdbc/src/main/java/com/databend/jdbc/parser/BatchInsertUtils.java b/databend-jdbc/src/main/java/com/databend/jdbc/parser/BatchInsertUtils.java index fb0012fc..ef393933 100644 --- a/databend-jdbc/src/main/java/com/databend/jdbc/parser/BatchInsertUtils.java +++ b/databend-jdbc/src/main/java/com/databend/jdbc/parser/BatchInsertUtils.java @@ -1,6 +1,5 @@ package com.databend.jdbc.parser; -import com.databend.jdbc.DatabendPreparedStatement; import de.siegmar.fastcsv.writer.CsvWriter; import de.siegmar.fastcsv.writer.LineDelimiter; @@ -50,6 +49,14 @@ public String getSql() { return sql; } + public Map getProvideParams() { + Map m = new TreeMap<>(); + for (Map.Entry elem : placeHolderEntries.entrySet()) { + m.put(elem.getKey() + 1, elem.getValue()); + } + return m; + } + public String getDatabaseTableName() { Pattern pattern = Pattern.compile("^INSERT INTO\\s+((?:[\\w-]+\\.)?([\\w-]+))(?:\\s*\\((?:[^()]|\\([^()]*\\))*\\))?", Pattern.CASE_INSENSITIVE); Matcher matcher = pattern.matcher(sql.replace("`", "")); diff --git a/databend-jdbc/src/test/java/com/databend/jdbc/StatementUtilTest.java b/databend-jdbc/src/test/java/com/databend/jdbc/StatementUtilTest.java new file mode 100644 index 00000000..d374e309 --- /dev/null +++ b/databend-jdbc/src/test/java/com/databend/jdbc/StatementUtilTest.java @@ -0,0 +1,41 @@ +package com.databend.jdbc; + +import com.google.common.collect.ImmutableMap; +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.junit.jupiter.api.Test; + +import java.util.Map; + +import static com.databend.jdbc.StatementUtil.replaceParameterMarksWithValues; +import static org.junit.jupiter.api.Assertions.*; +public class StatementUtilTest { + @Test + void shouldGetAllQueryParamsFromIn() { + String sql = "SElECT * FROM EMPLOYEES WHERE id IN (?,?)"; + assertEquals(ImmutableMap.of(1, 37, 2, 39), StatementUtil.getParamMarketsPositions(sql)); + assertEquals(1, StatementUtil.parseToRawStatementWrapper(sql).getSubStatements().size()); + } + @Test + void shouldGetAllQueryParams() { + String sql = "SElECT * FROM EMPLOYEES WHERE id = ?"; + assertEquals(ImmutableMap.of(1, 35), StatementUtil.getParamMarketsPositions(sql)); + assertEquals(1, StatementUtil.parseToRawStatementWrapper(sql).getSubStatements().size()); + } + + @Test + void shouldReplaceAQueryParam() { + String sql = "SElECT * FROM EMPLOYEES WHERE id is ?"; + String expectedSql = "SElECT * FROM EMPLOYEES WHERE id is 5"; + Map params = ImmutableMap.of(1, "5"); + System.out.println(replaceParameterMarksWithValues(params, sql)); + assertEquals(expectedSql, replaceParameterMarksWithValues(params, sql).get(0).getSql()); + } + + @Test + void shouldReplaceMultipleQueryParams() { + String sql = "SElECT * FROM EMPLOYEES WHERE id = ? AND name LIKE ? AND dob = ? "; + String expectedSql = "SElECT * FROM EMPLOYEES WHERE id = 5 AND name LIKE 'George' AND dob = '1980-05-22' "; + Map params = ImmutableMap.of(1, "5", 2, "'George'", 3, "'1980-05-22'"); + assertEquals(expectedSql, replaceParameterMarksWithValues(params, sql).get(0).getSql()); + } +} diff --git a/databend-jdbc/src/test/java/com/databend/jdbc/TestBasicDriver.java b/databend-jdbc/src/test/java/com/databend/jdbc/TestBasicDriver.java index 8b4b519f..68d94959 100644 --- a/databend-jdbc/src/test/java/com/databend/jdbc/TestBasicDriver.java +++ b/databend-jdbc/src/test/java/com/databend/jdbc/TestBasicDriver.java @@ -77,19 +77,6 @@ public void testQueryUpdateCount() } } -// @Test -// public void testPrepareStatementQuery() throws SQLException { -// String sql = "SELECT number from numbers(100) where number = ?"; -// Connection connection = createConnection("test_basic_driver"); -// try(PreparedStatement statement = connection.prepareStatement(sql)) { -// statement.setInt(1, 1); -// ResultSet r = statement.executeQuery(); -// statement.execute(); -// r.next(); -// System.out.println(r.getLong("number")); -// } -// } - @Test(groups = {"IT"}) public void testBasicWithProperties() throws SQLException { Properties p = new Properties(); @@ -112,6 +99,19 @@ public void testBasicWithProperties() throws SQLException { } } + @Test + public void testPrepareStatementQuery() throws SQLException { + String sql = "SELECT number from numbers(100) where number = ? or number = ?"; + Connection conn = createConnection("test_basic_driver"); + try (PreparedStatement statement = conn.prepareStatement(sql)) { + statement.setInt(1, 1); + statement.setInt(2, 2); + ResultSet r = statement.executeQuery(); + r.next(); + System.out.println(r.getLong("number")); + } + } + @Test(groups = {"IT"}) public void testBasicWithDatabase() throws SQLException {