Skip to content

Commit

Permalink
Simplify JDBC result set iterator implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
wendigo committed Oct 9, 2024
1 parent aabf405 commit 2084281
Show file tree
Hide file tree
Showing 10 changed files with 689 additions and 476 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
import io.trino.client.Column;
import io.trino.client.IntervalDayTime;
import io.trino.client.IntervalYearMonth;
import io.trino.client.QueryError;
import io.trino.client.QueryStatusInfo;
import io.trino.jdbc.ColumnInfo.Nullable;
import io.trino.jdbc.TypeConversions.NoConversionRegisteredException;
import org.joda.time.DateTimeZone;
Expand Down Expand Up @@ -62,7 +60,6 @@
import java.time.ZonedDateTime;
import java.util.Calendar;
import java.util.GregorianCalendar;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -184,7 +181,7 @@ abstract class AbstractTrinoResultSet
return result;
})
.build();
protected final Iterator<List<Object>> results;
protected final CancellableIterator<List<Object>> results;
private final Map<String, Integer> fieldMap;
private final List<ColumnInfo> columnInfoList;
private final ResultSetMetaData resultSetMetaData;
Expand All @@ -193,7 +190,9 @@ abstract class AbstractTrinoResultSet
private final AtomicBoolean wasNull = new AtomicBoolean();
private final Optional<Statement> statement;

AbstractTrinoResultSet(Optional<Statement> statement, List<Column> columns, Iterator<List<Object>> results)
private final AtomicBoolean closed = new AtomicBoolean();

AbstractTrinoResultSet(Optional<Statement> statement, List<Column> columns, CancellableIterator<List<Object>> results)
{
this.statement = requireNonNull(statement, "statement is null");
requireNonNull(columns, "columns is null");
Expand Down Expand Up @@ -1827,6 +1826,15 @@ public <T> T getObject(String columnLabel, Class<T> type)
return getObject(columnIndex(columnLabel), type);
}

@Override
public void close()
throws SQLException
{
if (closed.compareAndSet(false, true)) {
results.cancel();
}
}

@SuppressWarnings("unchecked")
@Override
public <T> T unwrap(Class<T> iface)
Expand Down Expand Up @@ -1929,14 +1937,6 @@ private static Optional<BigDecimal> toBigDecimal(String value)
}
}

static SQLException resultsException(QueryStatusInfo results)
{
QueryError error = requireNonNull(results.getError());
String message = format("Query failed (#%s): %s", results.getId(), error.getMessage());
Throwable cause = (error.getFailureInfo() == null) ? null : error.getFailureInfo().toException();
return new SQLException(message, error.getSqlState(), error.getErrorCode(), cause);
}

private static Map<String, Integer> getFieldMap(List<Column> columns)
{
Map<String, Integer> map = Maps.newHashMapWithExpectedSize(columns.size());
Expand Down
169 changes: 169 additions & 0 deletions client/trino-jdbc/src/main/java/io/trino/jdbc/AsyncResultIterator.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.jdbc;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.AbstractIterator;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.trino.client.QueryStatusInfo;
import io.trino.client.StatementClient;

import java.sql.SQLException;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.Semaphore;
import java.util.function.Consumer;

import static com.google.common.base.Throwables.throwIfUnchecked;
import static com.google.common.base.Verify.verify;
import static io.trino.jdbc.ResultUtils.resultsException;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.Executors.newCachedThreadPool;

public class AsyncResultIterator
extends AbstractIterator<List<Object>>
implements CancellableIterator<List<Object>>
{
private static final int MAX_QUEUED_ROWS = 50_000;
private static final ExecutorService executorService = newCachedThreadPool(
new ThreadFactoryBuilder().setNameFormat("Trino JDBC worker-%s").setDaemon(true).build());

private final StatementClient client;
private final BlockingQueue<List<Object>> rowQueue;
// Semaphore to indicate that some data is ready.
// Each permit represents a row of data (or that the underlying iterator is exhausted).
private final Semaphore semaphore = new Semaphore(0);
private final Future<?> future;

private volatile boolean cancelled;
private volatile boolean finished;

AsyncResultIterator(StatementClient client, Consumer<QueryStats> progressCallback, WarningsManager warningsManager, Optional<BlockingQueue<List<Object>>> queue)
{
requireNonNull(progressCallback, "progressCallback is null");
requireNonNull(warningsManager, "warningsManager is null");

this.client = client;
this.rowQueue = queue.orElseGet(() -> new ArrayBlockingQueue<>(MAX_QUEUED_ROWS));
this.cancelled = false;
this.finished = false;
this.future = executorService.submit(() -> {
try {
do {
QueryStatusInfo results = client.currentStatusInfo();
progressCallback.accept(QueryStats.create(results.getId(), results.getStats()));
warningsManager.addWarnings(results.getWarnings());
Iterable<List<Object>> data = client.currentData().getData();
if (data != null) {
for (List<Object> row : data) {
rowQueue.put(row);
semaphore.release();
}
}
}
while (!cancelled && client.advance());

verify(client.isFinished());
QueryStatusInfo results = client.finalStatusInfo();
progressCallback.accept(QueryStats.create(results.getId(), results.getStats()));
warningsManager.addWarnings(results.getWarnings());
if (results.getError() != null) {
throw new RuntimeException(resultsException(results));
}
}
catch (CancellationException | InterruptedException e) {
close();
throw new RuntimeException(new SQLException("ResultSet thread was interrupted", e));
}
finally {
finished = true;
semaphore.release();
}
});
}

@Override
public void cancel()
{
synchronized (this) {
if (cancelled) {
return;
}
cancelled = true;
}
future.cancel(true);
close();
}

private void close()
{
// When thread interruption is mis-handled by underlying implementation of `client`, the thread which
// is working for `future` may be blocked by `rowQueue.put` (`rowQueue` is full) and will never finish
// its work. It is necessary to close `client` and drain `rowQueue` to avoid such leaks.
client.close();
rowQueue.clear();
}

@VisibleForTesting
Future<?> getFuture()
{
return future;
}

@VisibleForTesting
boolean isBackgroundThreadFinished()
{
return finished;
}

@Override
protected List<Object> computeNext()
{
try {
semaphore.acquire();
}
catch (InterruptedException e) {
handleInterrupt(e);
}
if (rowQueue.isEmpty()) {
// If we got here and the queue is empty the thread fetching from the underlying iterator is done.
// Wait for Future to marked done and check status.
try {
future.get();
}
catch (InterruptedException e) {
handleInterrupt(e);
}
catch (ExecutionException e) {
throwIfUnchecked(e.getCause());
throw new RuntimeException(e.getCause());
}
return endOfData();
}
return rowQueue.poll();
}

private void handleInterrupt(InterruptedException e)
{
cancel();
Thread.currentThread().interrupt();
throw new RuntimeException(new SQLException("Interrupted", e));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.jdbc;

import java.util.Iterator;

interface CancellableIterator<T>
extends Iterator<T>
{
void cancel();

static <T> CancellableIterator<T> wrap(Iterator<T> iterator)
{
return new CancellableIterator<T>() {
@Override
public void cancel()
{
// noop
}

@Override
public boolean hasNext()
{
return iterator.hasNext();
}

@Override
public T next()
{
return iterator.next();
}
};
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.jdbc;

import com.google.common.collect.AbstractIterator;

import static java.util.Objects.requireNonNull;

public class CancellableLimitingIterator<T>
extends AbstractIterator<T>
implements CancellableIterator<T>
{
private final long maxRows;
private final CancellableIterator<T> delegate;
private long currentRow;

CancellableLimitingIterator(CancellableIterator<T> delegate, long maxRows)
{
this.delegate = requireNonNull(delegate, "delegate is null");
this.maxRows = maxRows;
}

@Override
public void cancel()
{
delegate.cancel();
}

@Override
protected T computeNext()
{
if (maxRows > 0 && currentRow >= maxRows) {
cancel();
return endOfData();
}
currentRow++;
if (delegate.hasNext()) {
return delegate.next();
}
return endOfData();
}

static <T> CancellableIterator<T> limit(CancellableIterator<T> delegate, long maxRows)
{
return new CancellableLimitingIterator<>(delegate, maxRows);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,16 @@
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;

import static io.trino.jdbc.CancellableIterator.wrap;

public class InMemoryTrinoResultSet
extends AbstractTrinoResultSet
{
private final AtomicBoolean closed = new AtomicBoolean();

public InMemoryTrinoResultSet(List<Column> columns, List<List<Object>> results)
{
super(Optional.empty(), columns, results.iterator());
super(Optional.empty(), columns, wrap(results.iterator()));
}

@Override
Expand Down
35 changes: 35 additions & 0 deletions client/trino-jdbc/src/main/java/io/trino/jdbc/ResultUtils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.jdbc;

import io.trino.client.QueryError;
import io.trino.client.QueryStatusInfo;

import java.sql.SQLException;

import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

class ResultUtils
{
private ResultUtils() {}

static SQLException resultsException(QueryStatusInfo results)
{
QueryError error = requireNonNull(results.getError());
String message = format("Query failed (#%s): %s", results.getId(), error.getMessage());
Throwable cause = (error.getFailureInfo() == null) ? null : error.getFailureInfo().toException();
return new SQLException(message, error.getSqlState(), error.getErrorCode(), cause);
}
}
Loading

0 comments on commit 2084281

Please sign in to comment.