Skip to content

Commit

Permalink
Return wrapped connection from Statement.getConnection (#10554)
Browse files Browse the repository at this point in the history
  • Loading branch information
laurit authored Feb 15, 2024
1 parent 205100e commit d739628
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,12 @@ public class OpenTelemetryCallableStatement<S extends CallableStatement>
extends OpenTelemetryPreparedStatement<S> implements CallableStatement {

public OpenTelemetryCallableStatement(
S delegate, DbInfo dbInfo, String query, Instrumenter<DbRequest, Void> instrumenter) {
super(delegate, dbInfo, query, instrumenter);
S delegate,
OpenTelemetryConnection connection,
DbInfo dbInfo,
String query,
Instrumenter<DbRequest, Void> instrumenter) {
super(delegate, connection, dbInfo, query, instrumenter);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,36 +61,38 @@ public OpenTelemetryConnection(
@Override
public Statement createStatement() throws SQLException {
Statement statement = delegate.createStatement();
return new OpenTelemetryStatement<>(statement, dbInfo, statementInstrumenter);
return new OpenTelemetryStatement<>(statement, this, dbInfo, statementInstrumenter);
}

@Override
public Statement createStatement(int resultSetType, int resultSetConcurrency)
throws SQLException {
Statement statement = delegate.createStatement(resultSetType, resultSetConcurrency);
return new OpenTelemetryStatement<>(statement, dbInfo, statementInstrumenter);
return new OpenTelemetryStatement<>(statement, this, dbInfo, statementInstrumenter);
}

@Override
public Statement createStatement(
int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException {
Statement statement =
delegate.createStatement(resultSetType, resultSetConcurrency, resultSetHoldability);
return new OpenTelemetryStatement<>(statement, dbInfo, statementInstrumenter);
return new OpenTelemetryStatement<>(statement, this, dbInfo, statementInstrumenter);
}

@Override
public PreparedStatement prepareStatement(String sql) throws SQLException {
PreparedStatement statement = delegate.prepareStatement(sql);
return new OpenTelemetryPreparedStatement<>(statement, dbInfo, sql, statementInstrumenter);
return new OpenTelemetryPreparedStatement<>(
statement, this, dbInfo, sql, statementInstrumenter);
}

@Override
public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency)
throws SQLException {
PreparedStatement statement =
delegate.prepareStatement(sql, resultSetType, resultSetConcurrency);
return new OpenTelemetryPreparedStatement<>(statement, dbInfo, sql, statementInstrumenter);
return new OpenTelemetryPreparedStatement<>(
statement, this, dbInfo, sql, statementInstrumenter);
}

@Override
Expand All @@ -99,38 +101,44 @@ public PreparedStatement prepareStatement(
throws SQLException {
PreparedStatement statement =
delegate.prepareStatement(sql, resultSetType, resultSetConcurrency, resultSetHoldability);
return new OpenTelemetryPreparedStatement<>(statement, dbInfo, sql, statementInstrumenter);
return new OpenTelemetryPreparedStatement<>(
statement, this, dbInfo, sql, statementInstrumenter);
}

@Override
public PreparedStatement prepareStatement(String sql, int autoGeneratedKeys) throws SQLException {
PreparedStatement statement = delegate.prepareStatement(sql, autoGeneratedKeys);
return new OpenTelemetryPreparedStatement<>(statement, dbInfo, sql, statementInstrumenter);
return new OpenTelemetryPreparedStatement<>(
statement, this, dbInfo, sql, statementInstrumenter);
}

@Override
public PreparedStatement prepareStatement(String sql, int[] columnIndexes) throws SQLException {
PreparedStatement statement = delegate.prepareStatement(sql, columnIndexes);
return new OpenTelemetryPreparedStatement<>(statement, dbInfo, sql, statementInstrumenter);
return new OpenTelemetryPreparedStatement<>(
statement, this, dbInfo, sql, statementInstrumenter);
}

@Override
public PreparedStatement prepareStatement(String sql, String[] columnNames) throws SQLException {
PreparedStatement statement = delegate.prepareStatement(sql, columnNames);
return new OpenTelemetryPreparedStatement<>(statement, dbInfo, sql, statementInstrumenter);
return new OpenTelemetryPreparedStatement<>(
statement, this, dbInfo, sql, statementInstrumenter);
}

@Override
public CallableStatement prepareCall(String sql) throws SQLException {
CallableStatement statement = delegate.prepareCall(sql);
return new OpenTelemetryCallableStatement<>(statement, dbInfo, sql, statementInstrumenter);
return new OpenTelemetryCallableStatement<>(
statement, this, dbInfo, sql, statementInstrumenter);
}

@Override
public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency)
throws SQLException {
CallableStatement statement = delegate.prepareCall(sql, resultSetType, resultSetConcurrency);
return new OpenTelemetryCallableStatement<>(statement, dbInfo, sql, statementInstrumenter);
return new OpenTelemetryCallableStatement<>(
statement, this, dbInfo, sql, statementInstrumenter);
}

@Override
Expand All @@ -139,7 +147,8 @@ public CallableStatement prepareCall(
throws SQLException {
CallableStatement statement =
delegate.prepareCall(sql, resultSetType, resultSetConcurrency, resultSetHoldability);
return new OpenTelemetryCallableStatement<>(statement, dbInfo, sql, statementInstrumenter);
return new OpenTelemetryCallableStatement<>(
statement, this, dbInfo, sql, statementInstrumenter);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,12 @@ public class OpenTelemetryPreparedStatement<S extends PreparedStatement>
extends OpenTelemetryStatement<S> implements PreparedStatement {

public OpenTelemetryPreparedStatement(
S delegate, DbInfo dbInfo, String query, Instrumenter<DbRequest, Void> instrumenter) {
super(delegate, dbInfo, query, instrumenter);
S delegate,
OpenTelemetryConnection connection,
DbInfo dbInfo,
String query,
Instrumenter<DbRequest, Void> instrumenter) {
super(delegate, connection, dbInfo, query, instrumenter);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,29 @@
public class OpenTelemetryStatement<S extends Statement> implements Statement {

protected final S delegate;
protected final OpenTelemetryConnection connection;
protected final DbInfo dbInfo;
protected final String query;
protected final Instrumenter<DbRequest, Void> instrumenter;

private final ArrayList<String> batchCommands = new ArrayList<>();

OpenTelemetryStatement(S delegate, DbInfo dbInfo, Instrumenter<DbRequest, Void> instrumenter) {
this(delegate, dbInfo, null, instrumenter);
OpenTelemetryStatement(
S delegate,
OpenTelemetryConnection connection,
DbInfo dbInfo,
Instrumenter<DbRequest, Void> instrumenter) {
this(delegate, connection, dbInfo, null, instrumenter);
}

OpenTelemetryStatement(
S delegate, DbInfo dbInfo, String query, Instrumenter<DbRequest, Void> instrumenter) {
S delegate,
OpenTelemetryConnection connection,
DbInfo dbInfo,
String query,
Instrumenter<DbRequest, Void> instrumenter) {
this.delegate = delegate;
this.connection = connection;
this.dbInfo = dbInfo;
this.query = query;
this.instrumenter = instrumenter;
Expand Down Expand Up @@ -230,8 +240,8 @@ public void clearBatch() throws SQLException {
}

@Override
public Connection getConnection() throws SQLException {
return delegate.getConnection();
public Connection getConnection() {
return connection;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,17 @@
package io.opentelemetry.instrumentation.jdbc.datasource;

import static io.opentelemetry.sdk.testing.assertj.OpenTelemetryAssertions.equalTo;
import static org.assertj.core.api.Assertions.assertThat;

import io.opentelemetry.instrumentation.jdbc.internal.OpenTelemetryConnection;
import io.opentelemetry.instrumentation.testing.junit.InstrumentationExtension;
import io.opentelemetry.instrumentation.testing.junit.LibraryInstrumentationExtension;
import io.opentelemetry.semconv.SemanticAttributes;
import java.sql.CallableStatement;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Statement;
import javax.sql.DataSource;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
Expand Down Expand Up @@ -113,4 +119,17 @@ void buildWithSanitizationDisabled() throws SQLException {
span.hasName("SELECT dbname")
.hasAttribute(equalTo(SemanticAttributes.DB_STATEMENT, "SELECT 1;"))));
}

@Test
void statementReturnsWrappedConnection() throws SQLException {
JdbcTelemetry telemetry = JdbcTelemetry.builder(testing.getOpenTelemetry()).build();
DataSource dataSource = telemetry.wrap(new TestDataSource());
Connection connection = dataSource.getConnection();
Statement statement = connection.createStatement();
assertThat(statement.getConnection()).isInstanceOf(OpenTelemetryConnection.class);
PreparedStatement preparedStatement = connection.prepareStatement("SELECT 1");
assertThat(preparedStatement.getConnection()).isInstanceOf(OpenTelemetryConnection.class);
CallableStatement callableStatement = connection.prepareCall("SELECT 1");
assertThat(callableStatement.getConnection()).isInstanceOf(OpenTelemetryConnection.class);
}
}

0 comments on commit d739628

Please sign in to comment.