From 314f421c9af593d52bbc54f2b355a1c117947437 Mon Sep 17 00:00:00 2001 From: "Mateusz \"Serafin\" Gajewski" Date: Tue, 8 Oct 2024 15:41:44 +0200 Subject: [PATCH] Simplify JDBC result set iterator implementation --- .../io/trino/jdbc/AbstractTrinoResultSet.java | 26 +- .../io/trino/jdbc/AsyncResultIterator.java | 169 ++++++++ .../io/trino/jdbc/CancellableIterator.java | 45 +++ .../jdbc/CancellableLimitingIterator.java | 58 +++ .../io/trino/jdbc/InMemoryTrinoResultSet.java | 4 +- .../main/java/io/trino/jdbc/ResultUtils.java | 35 ++ .../java/io/trino/jdbc/TrinoResultSet.java | 178 +-------- .../java/io/trino/jdbc/TrinoStatement.java | 2 +- .../trino/jdbc/TestAsyncResultIterator.java | 361 ++++++++++++++++++ .../io/trino/jdbc/TestTrinoResultSet.java | 287 -------------- 10 files changed, 689 insertions(+), 476 deletions(-) create mode 100644 client/trino-jdbc/src/main/java/io/trino/jdbc/AsyncResultIterator.java create mode 100644 client/trino-jdbc/src/main/java/io/trino/jdbc/CancellableIterator.java create mode 100644 client/trino-jdbc/src/main/java/io/trino/jdbc/CancellableLimitingIterator.java create mode 100644 client/trino-jdbc/src/main/java/io/trino/jdbc/ResultUtils.java create mode 100644 client/trino-jdbc/src/test/java/io/trino/jdbc/TestAsyncResultIterator.java delete mode 100644 client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoResultSet.java diff --git a/client/trino-jdbc/src/main/java/io/trino/jdbc/AbstractTrinoResultSet.java b/client/trino-jdbc/src/main/java/io/trino/jdbc/AbstractTrinoResultSet.java index b4c7c358cc058..0fcffb219942c 100644 --- a/client/trino-jdbc/src/main/java/io/trino/jdbc/AbstractTrinoResultSet.java +++ b/client/trino-jdbc/src/main/java/io/trino/jdbc/AbstractTrinoResultSet.java @@ -24,8 +24,6 @@ import io.trino.client.Column; import io.trino.client.IntervalDayTime; import io.trino.client.IntervalYearMonth; -import io.trino.client.QueryError; -import io.trino.client.QueryStatusInfo; import io.trino.jdbc.ColumnInfo.Nullable; import io.trino.jdbc.TypeConversions.NoConversionRegisteredException; import org.joda.time.DateTimeZone; @@ -62,7 +60,6 @@ import java.time.ZonedDateTime; import java.util.Calendar; import java.util.GregorianCalendar; -import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Optional; @@ -184,7 +181,7 @@ abstract class AbstractTrinoResultSet return result; }) .build(); - protected final Iterator> results; + protected final CancellableIterator> results; private final Map fieldMap; private final List columnInfoList; private final ResultSetMetaData resultSetMetaData; @@ -193,7 +190,9 @@ abstract class AbstractTrinoResultSet private final AtomicBoolean wasNull = new AtomicBoolean(); private final Optional statement; - AbstractTrinoResultSet(Optional statement, List columns, Iterator> results) + private final AtomicBoolean closed = new AtomicBoolean(); + + AbstractTrinoResultSet(Optional statement, List columns, CancellableIterator> results) { this.statement = requireNonNull(statement, "statement is null"); requireNonNull(columns, "columns is null"); @@ -1827,6 +1826,15 @@ public T getObject(String columnLabel, Class type) return getObject(columnIndex(columnLabel), type); } + @Override + public void close() + throws SQLException + { + if (closed.compareAndSet(false, true)) { + results.cancel(); + } + } + @SuppressWarnings("unchecked") @Override public T unwrap(Class iface) @@ -1929,14 +1937,6 @@ private static Optional toBigDecimal(String value) } } - static SQLException resultsException(QueryStatusInfo results) - { - QueryError error = requireNonNull(results.getError()); - String message = format("Query failed (#%s): %s", results.getId(), error.getMessage()); - Throwable cause = (error.getFailureInfo() == null) ? null : error.getFailureInfo().toException(); - return new SQLException(message, error.getSqlState(), error.getErrorCode(), cause); - } - private static Map getFieldMap(List columns) { Map map = Maps.newHashMapWithExpectedSize(columns.size()); diff --git a/client/trino-jdbc/src/main/java/io/trino/jdbc/AsyncResultIterator.java b/client/trino-jdbc/src/main/java/io/trino/jdbc/AsyncResultIterator.java new file mode 100644 index 0000000000000..b7df1a1aed419 --- /dev/null +++ b/client/trino-jdbc/src/main/java/io/trino/jdbc/AsyncResultIterator.java @@ -0,0 +1,169 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.jdbc; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.AbstractIterator; +import com.google.common.util.concurrent.ThreadFactoryBuilder; +import io.trino.client.QueryStatusInfo; +import io.trino.client.StatementClient; + +import java.sql.SQLException; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CancellationException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.concurrent.Semaphore; +import java.util.function.Consumer; + +import static com.google.common.base.Throwables.throwIfUnchecked; +import static com.google.common.base.Verify.verify; +import static io.trino.jdbc.ResultUtils.resultsException; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.Executors.newCachedThreadPool; + +public class AsyncResultIterator + extends AbstractIterator> + implements CancellableIterator> +{ + private static final int MAX_QUEUED_ROWS = 50_000; + private static final ExecutorService executorService = newCachedThreadPool( + new ThreadFactoryBuilder().setNameFormat("Trino JDBC worker-%s").setDaemon(true).build()); + + private final StatementClient client; + private final BlockingQueue> rowQueue; + // Semaphore to indicate that some data is ready. + // Each permit represents a row of data (or that the underlying iterator is exhausted). + private final Semaphore semaphore = new Semaphore(0); + private final Future future; + + private volatile boolean cancelled; + private volatile boolean finished; + + AsyncResultIterator(StatementClient client, Consumer progressCallback, WarningsManager warningsManager, Optional>> queue) + { + requireNonNull(progressCallback, "progressCallback is null"); + requireNonNull(warningsManager, "warningsManager is null"); + + this.client = client; + this.rowQueue = queue.orElseGet(() -> new ArrayBlockingQueue<>(MAX_QUEUED_ROWS)); + this.cancelled = false; + this.finished = false; + this.future = executorService.submit(() -> { + try { + do { + QueryStatusInfo results = client.currentStatusInfo(); + progressCallback.accept(QueryStats.create(results.getId(), results.getStats())); + warningsManager.addWarnings(results.getWarnings()); + Iterable> data = client.currentData().getData(); + if (data != null) { + for (List row : data) { + rowQueue.put(row); + semaphore.release(); + } + } + } + while (!cancelled && client.advance()); + + verify(client.isFinished()); + QueryStatusInfo results = client.finalStatusInfo(); + progressCallback.accept(QueryStats.create(results.getId(), results.getStats())); + warningsManager.addWarnings(results.getWarnings()); + if (results.getError() != null) { + throw new RuntimeException(resultsException(results)); + } + } + catch (CancellationException | InterruptedException e) { + close(); + throw new RuntimeException(new SQLException("ResultSet thread was interrupted", e)); + } + finally { + finished = true; + semaphore.release(); + } + }); + } + + @Override + public void cancel() + { + synchronized (this) { + if (cancelled) { + return; + } + cancelled = true; + } + future.cancel(true); + close(); + } + + private void close() + { + // When thread interruption is mis-handled by underlying implementation of `client`, the thread which + // is working for `future` may be blocked by `rowQueue.put` (`rowQueue` is full) and will never finish + // its work. It is necessary to close `client` and drain `rowQueue` to avoid such leaks. + client.close(); + rowQueue.clear(); + } + + @VisibleForTesting + Future getFuture() + { + return future; + } + + @VisibleForTesting + boolean isBackgroundThreadFinished() + { + return finished; + } + + @Override + protected List computeNext() + { + try { + semaphore.acquire(); + } + catch (InterruptedException e) { + handleInterrupt(e); + } + if (rowQueue.isEmpty()) { + // If we got here and the queue is empty the thread fetching from the underlying iterator is done. + // Wait for Future to marked done and check status. + try { + future.get(); + } + catch (InterruptedException e) { + handleInterrupt(e); + } + catch (ExecutionException e) { + throwIfUnchecked(e.getCause()); + throw new RuntimeException(e.getCause()); + } + return endOfData(); + } + return rowQueue.poll(); + } + + private void handleInterrupt(InterruptedException e) + { + cancel(); + Thread.currentThread().interrupt(); + throw new RuntimeException(new SQLException("Interrupted", e)); + } +} diff --git a/client/trino-jdbc/src/main/java/io/trino/jdbc/CancellableIterator.java b/client/trino-jdbc/src/main/java/io/trino/jdbc/CancellableIterator.java new file mode 100644 index 0000000000000..e57a464722c9f --- /dev/null +++ b/client/trino-jdbc/src/main/java/io/trino/jdbc/CancellableIterator.java @@ -0,0 +1,45 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.jdbc; + +import java.util.Iterator; + +interface CancellableIterator + extends Iterator +{ + void cancel(); + + static CancellableIterator wrap(Iterator iterator) + { + return new CancellableIterator() { + @Override + public void cancel() + { + // noop + } + + @Override + public boolean hasNext() + { + return iterator.hasNext(); + } + + @Override + public T next() + { + return iterator.next(); + } + }; + } +} diff --git a/client/trino-jdbc/src/main/java/io/trino/jdbc/CancellableLimitingIterator.java b/client/trino-jdbc/src/main/java/io/trino/jdbc/CancellableLimitingIterator.java new file mode 100644 index 0000000000000..30507097395a1 --- /dev/null +++ b/client/trino-jdbc/src/main/java/io/trino/jdbc/CancellableLimitingIterator.java @@ -0,0 +1,58 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.jdbc; + +import com.google.common.collect.AbstractIterator; + +import static java.util.Objects.requireNonNull; + +public class CancellableLimitingIterator + extends AbstractIterator + implements CancellableIterator +{ + private final long maxRows; + private final CancellableIterator delegate; + private long currentRow; + + CancellableLimitingIterator(CancellableIterator delegate, long maxRows) + { + this.delegate = requireNonNull(delegate, "delegate is null"); + this.maxRows = maxRows; + } + + @Override + public void cancel() + { + delegate.cancel(); + } + + @Override + protected T computeNext() + { + if (maxRows > 0 && currentRow >= maxRows) { + cancel(); + return endOfData(); + } + currentRow++; + if (delegate.hasNext()) { + return delegate.next(); + } + return endOfData(); + } + + static CancellableIterator limit(CancellableIterator delegate, long maxRows) + { + return new CancellableLimitingIterator<>(delegate, maxRows); + } +} diff --git a/client/trino-jdbc/src/main/java/io/trino/jdbc/InMemoryTrinoResultSet.java b/client/trino-jdbc/src/main/java/io/trino/jdbc/InMemoryTrinoResultSet.java index a704010ea72f8..225b9d7e75f52 100644 --- a/client/trino-jdbc/src/main/java/io/trino/jdbc/InMemoryTrinoResultSet.java +++ b/client/trino-jdbc/src/main/java/io/trino/jdbc/InMemoryTrinoResultSet.java @@ -20,6 +20,8 @@ import java.util.Optional; import java.util.concurrent.atomic.AtomicBoolean; +import static io.trino.jdbc.CancellableIterator.wrap; + public class InMemoryTrinoResultSet extends AbstractTrinoResultSet { @@ -27,7 +29,7 @@ public class InMemoryTrinoResultSet public InMemoryTrinoResultSet(List columns, List> results) { - super(Optional.empty(), columns, results.iterator()); + super(Optional.empty(), columns, wrap(results.iterator())); } @Override diff --git a/client/trino-jdbc/src/main/java/io/trino/jdbc/ResultUtils.java b/client/trino-jdbc/src/main/java/io/trino/jdbc/ResultUtils.java new file mode 100644 index 0000000000000..975869e8aa7b9 --- /dev/null +++ b/client/trino-jdbc/src/main/java/io/trino/jdbc/ResultUtils.java @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.jdbc; + +import io.trino.client.QueryError; +import io.trino.client.QueryStatusInfo; + +import java.sql.SQLException; + +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +class ResultUtils +{ + private ResultUtils() {} + + static SQLException resultsException(QueryStatusInfo results) + { + QueryError error = requireNonNull(results.getError()); + String message = format("Query failed (#%s): %s", results.getId(), error.getMessage()); + Throwable cause = (error.getFailureInfo() == null) ? null : error.getFailureInfo().toException(); + return new SQLException(message, error.getSqlState(), error.getErrorCode(), cause); + } +} diff --git a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoResultSet.java b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoResultSet.java index 7f193cc55dba7..a28d6bf62a566 100644 --- a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoResultSet.java +++ b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoResultSet.java @@ -13,10 +13,6 @@ */ package io.trino.jdbc; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.AbstractIterator; -import com.google.common.collect.Streams; -import com.google.common.util.concurrent.ThreadFactoryBuilder; import com.google.errorprone.annotations.concurrent.GuardedBy; import io.trino.client.Column; import io.trino.client.QueryStatusInfo; @@ -24,23 +20,15 @@ import java.sql.SQLException; import java.sql.Statement; -import java.util.Iterator; import java.util.List; import java.util.Optional; -import java.util.concurrent.ArrayBlockingQueue; -import java.util.concurrent.BlockingQueue; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Future; -import java.util.concurrent.Semaphore; import java.util.function.Consumer; -import java.util.stream.Stream; -import static com.google.common.base.Throwables.throwIfUnchecked; import static com.google.common.base.Verify.verify; +import static io.trino.jdbc.CancellableLimitingIterator.limit; +import static io.trino.jdbc.ResultUtils.resultsException; import static java.lang.String.format; import static java.util.Objects.requireNonNull; -import static java.util.concurrent.Executors.newCachedThreadPool; public class TrinoResultSet extends AbstractTrinoResultSet @@ -68,7 +56,7 @@ private TrinoResultSet(Statement statement, StatementClient client, List super( Optional.of(requireNonNull(statement, "statement is null")), columns, - new AsyncIterator<>(flatten(new ResultsPageIterator(requireNonNull(client, "client is null"), progressCallback, warningsManager), maxRows), client)); + limit(new AsyncResultIterator(requireNonNull(client, "client is null"), progressCallback, warningsManager, Optional.empty()), maxRows)); this.statement = statement; this.client = requireNonNull(client, "client is null"); @@ -115,7 +103,7 @@ public void close() closeStatement = closeStatementOnClose; } - ((AsyncIterator) results).cancel(); + super.close(); client.close(); if (closeStatement) { statement.close(); @@ -134,164 +122,6 @@ void partialCancel() client.cancelLeafStage(); } - private static Iterator flatten(Iterator> iterator, long maxRows) - { - Stream stream = Streams.stream(iterator) - .flatMap(Streams::stream); - if (maxRows > 0) { - stream = stream.limit(maxRows); - } - return stream.iterator(); - } - - @VisibleForTesting - static class AsyncIterator - extends AbstractIterator - { - private static final int MAX_QUEUED_ROWS = 50_000; - private static final ExecutorService executorService = newCachedThreadPool( - new ThreadFactoryBuilder().setNameFormat("Trino JDBC worker-%s").setDaemon(true).build()); - - private final StatementClient client; - private final BlockingQueue rowQueue; - // Semaphore to indicate that some data is ready. - // Each permit represents a row of data (or that the underlying iterator is exhausted). - private final Semaphore semaphore = new Semaphore(0); - private final Future future; - private volatile boolean cancelled; - private volatile boolean finished; - - public AsyncIterator(Iterator dataIterator, StatementClient client) - { - this(dataIterator, client, Optional.empty()); - } - - @VisibleForTesting - AsyncIterator(Iterator dataIterator, StatementClient client, Optional> queue) - { - requireNonNull(dataIterator, "dataIterator is null"); - this.client = client; - this.rowQueue = queue.orElseGet(() -> new ArrayBlockingQueue<>(MAX_QUEUED_ROWS)); - this.cancelled = false; - this.finished = false; - this.future = executorService.submit(() -> { - try { - while (!cancelled && dataIterator.hasNext()) { - rowQueue.put(dataIterator.next()); - semaphore.release(); - } - } - catch (InterruptedException e) { - client.close(); - rowQueue.clear(); - throw new RuntimeException(new SQLException("ResultSet thread was interrupted", e)); - } - finally { - semaphore.release(); - finished = true; - } - }); - } - - public void cancel() - { - cancelled = true; - future.cancel(true); - // When thread interruption is mis-handled by underlying implementation of `client`, the thread which - // is working for `future` may be blocked by `rowQueue.put` (`rowQueue` is full) and will never finish - // its work. It is necessary to close `client` and drain `rowQueue` to avoid such leaks. - client.close(); - rowQueue.clear(); - } - - @VisibleForTesting - Future getFuture() - { - return future; - } - - @VisibleForTesting - boolean isBackgroundThreadFinished() - { - return finished; - } - - @Override - protected T computeNext() - { - try { - semaphore.acquire(); - } - catch (InterruptedException e) { - handleInterrupt(e); - } - if (rowQueue.isEmpty()) { - // If we got here and the queue is empty the thread fetching from the underlying iterator is done. - // Wait for Future to marked done and check status. - try { - future.get(); - } - catch (InterruptedException e) { - handleInterrupt(e); - } - catch (ExecutionException e) { - throwIfUnchecked(e.getCause()); - throw new RuntimeException(e.getCause()); - } - return endOfData(); - } - return rowQueue.poll(); - } - - private void handleInterrupt(InterruptedException e) - { - cancel(); - Thread.currentThread().interrupt(); - throw new RuntimeException(new SQLException("Interrupted", e)); - } - } - - private static class ResultsPageIterator - extends AbstractIterator>> - { - private final StatementClient client; - private final Consumer progressCallback; - private final WarningsManager warningsManager; - - private ResultsPageIterator(StatementClient client, Consumer progressCallback, WarningsManager warningsManager) - { - this.client = requireNonNull(client, "client is null"); - this.progressCallback = requireNonNull(progressCallback, "progressCallback is null"); - this.warningsManager = requireNonNull(warningsManager, "warningsManager is null"); - } - - @Override - protected Iterable> computeNext() - { - while (client.isRunning()) { - QueryStatusInfo results = client.currentStatusInfo(); - progressCallback.accept(QueryStats.create(results.getId(), results.getStats())); - warningsManager.addWarnings(results.getWarnings()); - Iterable> data = client.currentData().getData(); - if (!client.advance() && data == null) { - break; // No more rows, query finished - } - if (data != null) { - return data; - } - } - - verify(client.isFinished()); - QueryStatusInfo results = client.finalStatusInfo(); - progressCallback.accept(QueryStats.create(results.getId(), results.getStats())); - warningsManager.addWarnings(results.getWarnings()); - if (results.getError() != null) { - throw new RuntimeException(resultsException(results)); - } - return endOfData(); - } - } - private static List getColumns(StatementClient client, Consumer progressCallback) throws SQLException { diff --git a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoStatement.java b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoStatement.java index 62f1470c341ad..736c72dc946f7 100644 --- a/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoStatement.java +++ b/client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoStatement.java @@ -33,7 +33,7 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; -import static io.trino.jdbc.AbstractTrinoResultSet.resultsException; +import static io.trino.jdbc.ResultUtils.resultsException; import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; diff --git a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestAsyncResultIterator.java b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestAsyncResultIterator.java new file mode 100644 index 0000000000000..596ef4c88abe2 --- /dev/null +++ b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestAsyncResultIterator.java @@ -0,0 +1,361 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.jdbc; + +import com.google.common.collect.ImmutableList; +import io.trino.client.ClientSelectedRole; +import io.trino.client.Column; +import io.trino.client.QueryData; +import io.trino.client.QueryError; +import io.trino.client.QueryStatusInfo; +import io.trino.client.StageStats; +import io.trino.client.StatementClient; +import io.trino.client.StatementStats; +import io.trino.client.Warning; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.net.URI; +import java.time.ZoneId; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalDouble; +import java.util.Set; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; + +import static java.util.Objects.requireNonNull; +import static org.assertj.core.api.Assertions.assertThat; + +class TestAsyncResultIterator +{ + @Test + @Timeout(10) + public void testIteratorCancelWhenQueueNotFull() + throws Exception + { + AtomicReference thread = new AtomicReference<>(); + CountDownLatch interruptedButSwallowedLatch = new CountDownLatch(1); + + AsyncResultIterator iterator = new AsyncResultIterator( + new MockStatementClient(() -> () -> { + thread.compareAndSet(null, Thread.currentThread()); + try { + TimeUnit.MILLISECONDS.sleep(1000); + } + catch (InterruptedException e) { + interruptedButSwallowedLatch.countDown(); + } + return ImmutableList.of(ImmutableList.of(new Object())); + }), ignored -> {}, + new WarningsManager(), + Optional.of(new ArrayBlockingQueue<>(100))); + + while (thread.get() == null || thread.get().getState() != Thread.State.TIMED_WAITING) { + // wait for thread being waiting + } + iterator.cancel(); + while (!iterator.getFuture().isDone() || !iterator.isBackgroundThreadFinished()) { + TimeUnit.MILLISECONDS.sleep(10); + } + boolean interruptedButSwallowed = interruptedButSwallowedLatch.await(5000, TimeUnit.MILLISECONDS); + assertThat(interruptedButSwallowed).isTrue(); + } + + @Test + @Timeout(10) + public void testIteratorCancelWhenQueueIsFull() + throws Exception + { + BlockingQueue> queue = new ArrayBlockingQueue<>(1); + queue.put(ImmutableList.of()); + // queue is full at the beginning + AtomicReference thread = new AtomicReference<>(); + + AsyncResultIterator iterator = new AsyncResultIterator( + new MockStatementClient(() -> () -> { + thread.compareAndSet(null, Thread.currentThread()); + return ImmutableList.of(ImmutableList.of(new Object())); + }), ignored -> {}, + new WarningsManager(), + Optional.of(queue)); + + while (thread.get() == null || thread.get().getState() != Thread.State.WAITING) { + // wait for thread being waiting (for queue being not full) + TimeUnit.MILLISECONDS.sleep(10); + } + iterator.cancel(); + while (!iterator.isBackgroundThreadFinished()) { + TimeUnit.MILLISECONDS.sleep(10); + } + } + + private static class MockStatementClient + implements StatementClient + { + private final Supplier queryData; + + public MockStatementClient(Supplier queryData) + { + this.queryData = requireNonNull(queryData, "queryData is null"); + } + + @Override + public String getQuery() + { + throw new UnsupportedOperationException(); + } + + @Override + public ZoneId getTimeZone() + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isRunning() + { + return true; + } + + @Override + public boolean isClientAborted() + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isClientError() + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isFinished() + { + return true; + } + + @Override + public StatementStats getStats() + { + throw new UnsupportedOperationException(); + } + + @Override + public QueryStatusInfo currentStatusInfo() + { + return statusInfo("RUNNING"); + } + + @Override + public QueryData currentData() + { + return queryData.get(); + } + + @Override + public QueryStatusInfo finalStatusInfo() + { + return statusInfo("FINISHED"); + } + + @Override + public Optional getSetCatalog() + { + throw new UnsupportedOperationException(); + } + + @Override + public Optional getSetSchema() + { + throw new UnsupportedOperationException(); + } + + @Override + public Optional> getSetPath() + { + throw new UnsupportedOperationException(); + } + + @Override + public Optional getSetAuthorizationUser() + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isResetAuthorizationUser() + { + throw new UnsupportedOperationException(); + } + + @Override + public Map getSetSessionProperties() + { + throw new UnsupportedOperationException(); + } + + @Override + public Set getResetSessionProperties() + { + throw new UnsupportedOperationException(); + } + + @Override + public Map getSetRoles() + { + throw new UnsupportedOperationException(); + } + + @Override + public Map getAddedPreparedStatements() + { + throw new UnsupportedOperationException(); + } + + @Override + public Set getDeallocatedPreparedStatements() + { + throw new UnsupportedOperationException(); + } + + @Override + public String getStartedTransactionId() + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isClearTransactionId() + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean advance() + { + return true; + } + + @Override + public void cancelLeafStage() + { + throw new UnsupportedOperationException(); + } + + @Override + public void close() + { + // do nothing + } + } + + private static QueryStatusInfo statusInfo(String status) + { + return new QueryStatusInfo() + { + @Override + public String getId() + { + return ""; + } + + @Override + public URI getInfoUri() + { + return null; + } + + @Override + public URI getPartialCancelUri() + { + return null; + } + + @Override + public URI getNextUri() + { + return null; + } + + @Override + public List getColumns() + { + return ImmutableList.of(); + } + + @Override + public StatementStats getStats() + { + return new StatementStats( + status, + false, + true, + OptionalDouble.of(50), + OptionalDouble.of(50), + 1, + 100, + 50, + 25, + 50, + 100, + 100, + 100, + 100, + 100, + 100, + 100, + 100, + 100, + 100, + StageStats.builder() + .setStageId("id") + .setDone(false) + .setState(status) + .setSubStages(ImmutableList.of()) + .build()); + } + + @Override + public QueryError getError() + { + return null; + } + + @Override + public List getWarnings() + { + return ImmutableList.of(); + } + + @Override + public String getUpdateType() + { + throw new UnsupportedOperationException(); + } + + @Override + public Long getUpdateCount() + { + throw new UnsupportedOperationException(); + } + }; + } +} diff --git a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoResultSet.java b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoResultSet.java deleted file mode 100644 index c90ed0e08d397..0000000000000 --- a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoResultSet.java +++ /dev/null @@ -1,287 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.jdbc; - -import com.google.common.collect.ImmutableList; -import io.trino.client.ClientSelectedRole; -import io.trino.client.QueryData; -import io.trino.client.QueryStatusInfo; -import io.trino.client.StatementClient; -import io.trino.client.StatementStats; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.Timeout; - -import java.time.ZoneId; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.concurrent.ArrayBlockingQueue; -import java.util.concurrent.BlockingQueue; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; - -import static org.assertj.core.api.Assertions.assertThat; - -/** - * A unit test for {@link TrinoResultSet}. - * - * @see TestJdbcResultSet for an integration test. - */ -public class TestTrinoResultSet -{ - @Test - @Timeout(10) - public void testIteratorCancelWhenQueueNotFull() - throws Exception - { - AtomicReference thread = new AtomicReference<>(); - CountDownLatch interruptedButSwallowedLatch = new CountDownLatch(1); - MockAsyncIterator>> iterator = new MockAsyncIterator<>( - new Iterator>>() - { - @Override - public boolean hasNext() - { - return true; - } - - @Override - public Iterable> next() - { - thread.compareAndSet(null, Thread.currentThread()); - try { - TimeUnit.MILLISECONDS.sleep(1000); - } - catch (InterruptedException e) { - interruptedButSwallowedLatch.countDown(); - } - return ImmutableList.of(ImmutableList.of(new Object())); - } - }, - new ArrayBlockingQueue<>(100)); - - while (thread.get() == null || thread.get().getState() != Thread.State.TIMED_WAITING) { - // wait for thread being waiting - } - iterator.cancel(); - while (!iterator.getFuture().isDone() || !iterator.isBackgroundThreadFinished()) { - TimeUnit.MILLISECONDS.sleep(10); - } - boolean interruptedButSwallowed = interruptedButSwallowedLatch.await(5000, TimeUnit.MILLISECONDS); - assertThat(interruptedButSwallowed).isTrue(); - } - - @Test - @Timeout(10) - public void testIteratorCancelWhenQueueIsFull() - throws Exception - { - BlockingQueue>> queue = new ArrayBlockingQueue<>(1); - queue.put(ImmutableList.of()); - // queue is full at the beginning - AtomicReference thread = new AtomicReference<>(); - MockAsyncIterator>> iterator = new MockAsyncIterator<>( - new Iterator>>() - { - @Override - public boolean hasNext() - { - return true; - } - - @Override - public Iterable> next() - { - thread.compareAndSet(null, Thread.currentThread()); - return ImmutableList.of(ImmutableList.of(new Object())); - } - }, - queue); - - while (thread.get() == null || thread.get().getState() != Thread.State.WAITING) { - // wait for thread being waiting (for queue being not full) - TimeUnit.MILLISECONDS.sleep(10); - } - iterator.cancel(); - while (!iterator.isBackgroundThreadFinished()) { - TimeUnit.MILLISECONDS.sleep(10); - } - } - - private static class MockAsyncIterator - extends TrinoResultSet.AsyncIterator - { - public MockAsyncIterator(Iterator dataIterator, BlockingQueue queue) - { - super( - dataIterator, - new StatementClient() - { - @Override - public String getQuery() - { - throw new UnsupportedOperationException(); - } - - @Override - public ZoneId getTimeZone() - { - throw new UnsupportedOperationException(); - } - - @Override - public boolean isRunning() - { - throw new UnsupportedOperationException(); - } - - @Override - public boolean isClientAborted() - { - throw new UnsupportedOperationException(); - } - - @Override - public boolean isClientError() - { - throw new UnsupportedOperationException(); - } - - @Override - public boolean isFinished() - { - throw new UnsupportedOperationException(); - } - - @Override - public StatementStats getStats() - { - throw new UnsupportedOperationException(); - } - - @Override - public QueryStatusInfo currentStatusInfo() - { - throw new UnsupportedOperationException(); - } - - @Override - public QueryData currentData() - { - throw new UnsupportedOperationException(); - } - - @Override - public QueryStatusInfo finalStatusInfo() - { - throw new UnsupportedOperationException(); - } - - @Override - public Optional getSetCatalog() - { - throw new UnsupportedOperationException(); - } - - @Override - public Optional getSetSchema() - { - throw new UnsupportedOperationException(); - } - - @Override - public Optional> getSetPath() - { - throw new UnsupportedOperationException(); - } - - @Override - public Optional getSetAuthorizationUser() - { - throw new UnsupportedOperationException(); - } - - @Override - public boolean isResetAuthorizationUser() - { - throw new UnsupportedOperationException(); - } - - @Override - public Map getSetSessionProperties() - { - throw new UnsupportedOperationException(); - } - - @Override - public Set getResetSessionProperties() - { - throw new UnsupportedOperationException(); - } - - @Override - public Map getSetRoles() - { - throw new UnsupportedOperationException(); - } - - @Override - public Map getAddedPreparedStatements() - { - throw new UnsupportedOperationException(); - } - - @Override - public Set getDeallocatedPreparedStatements() - { - throw new UnsupportedOperationException(); - } - - @Override - public String getStartedTransactionId() - { - throw new UnsupportedOperationException(); - } - - @Override - public boolean isClearTransactionId() - { - throw new UnsupportedOperationException(); - } - - @Override - public boolean advance() - { - throw new UnsupportedOperationException(); - } - - @Override - public void cancelLeafStage() - { - throw new UnsupportedOperationException(); - } - - @Override - public void close() - { - // do nothing - } - }, - Optional.of(queue)); - } - } -}