Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support passing credentials from pipeline options into SpannerIO.readChangeStream #30361

Merged
merged 13 commits into from
Feb 29, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.google.api.gax.rpc.ServerStreamingCallSettings;
import com.google.api.gax.rpc.StatusCode.Code;
import com.google.api.gax.rpc.UnaryCallSettings;
import com.google.auth.Credentials;
import com.google.cloud.NoCredentials;
import com.google.cloud.ServiceFactory;
import com.google.cloud.spanner.BatchClient;
Expand All @@ -41,6 +42,7 @@
import java.util.concurrent.ConcurrentHashMap;
import org.apache.beam.sdk.options.ValueProvider;
import org.apache.beam.sdk.util.ReleaseInfo;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.joda.time.Duration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -100,7 +102,8 @@ public static SpannerAccessor getOrCreate(SpannerConfig spannerConfig) {
}
}

private static SpannerAccessor createAndConnect(SpannerConfig spannerConfig) {
@VisibleForTesting
static SpannerOptions buildSpannerOptions(SpannerConfig spannerConfig) {
SpannerOptions.Builder builder = SpannerOptions.newBuilder();

Set<Code> retryableCodes = new HashSet<>();
Expand Down Expand Up @@ -222,8 +225,16 @@ private static SpannerAccessor createAndConnect(SpannerConfig spannerConfig) {
if (databaseRole != null && databaseRole.get() != null && !databaseRole.get().isEmpty()) {
builder.setDatabaseRole(databaseRole.get());
}
SpannerOptions options = builder.build();
ValueProvider<Credentials> credentials = spannerConfig.getCredentials();
if (credentials != null && credentials.get() != null) {
builder.setCredentials(credentials.get());
}

return builder.build();
}

private static SpannerAccessor createAndConnect(SpannerConfig spannerConfig) {
SpannerOptions options = buildSpannerOptions(spannerConfig);
Spanner spanner = options.getService();
String instanceId = spannerConfig.getInstanceId().get();
String databaseId = spannerConfig.getDatabaseId().get();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import com.google.api.gax.retrying.RetrySettings;
import com.google.api.gax.rpc.StatusCode.Code;
import com.google.auth.Credentials;
import com.google.auto.value.AutoValue;
import com.google.cloud.ServiceFactory;
import com.google.cloud.spanner.Options.RpcPriority;
Expand Down Expand Up @@ -84,6 +85,8 @@ public abstract class SpannerConfig implements Serializable {

public abstract @Nullable ValueProvider<Boolean> getDataBoostEnabled();

public abstract @Nullable ValueProvider<Credentials> getCredentials();

abstract Builder toBuilder();

public static SpannerConfig create() {
Expand Down Expand Up @@ -161,6 +164,8 @@ abstract Builder setExecuteStreamingSqlRetrySettings(

abstract Builder setPartitionReadTimeout(ValueProvider<Duration> partitionReadTimeout);

abstract Builder setCredentials(ValueProvider<Credentials> credentials);

public abstract SpannerConfig build();
}

Expand Down Expand Up @@ -302,4 +307,14 @@ public SpannerConfig withPartitionReadTimeout(Duration partitionReadTimeout) {
public SpannerConfig withPartitionReadTimeout(ValueProvider<Duration> partitionReadTimeout) {
return toBuilder().setPartitionReadTimeout(partitionReadTimeout).build();
}

/** Specifies the credentials. */
public SpannerConfig withCredentials(Credentials credentials) {
return withCredentials(ValueProvider.StaticValueProvider.of(credentials));
}

/** Specifies the credentials. */
public SpannerConfig withCredentials(ValueProvider<Credentials> credentials) {
return toBuilder().setCredentials(credentials).build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

import com.google.api.gax.retrying.RetrySettings;
import com.google.api.gax.rpc.StatusCode.Code;
import com.google.auth.Credentials;
import com.google.auto.value.AutoValue;
import com.google.cloud.ServiceFactory;
import com.google.cloud.Timestamp;
Expand Down Expand Up @@ -68,6 +69,7 @@
import org.apache.beam.runners.core.metrics.ServiceCallMetric;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.SerializableCoder;
import org.apache.beam.sdk.extensions.gcp.options.GcpOptions;
import org.apache.beam.sdk.io.gcp.spanner.changestreams.ChangeStreamMetrics;
import org.apache.beam.sdk.io.gcp.spanner.changestreams.ChangeStreamsConstants;
import org.apache.beam.sdk.io.gcp.spanner.changestreams.MetadataSpannerConfigFactory;
Expand All @@ -86,6 +88,7 @@
import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.Distribution;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.StreamingOptions;
import org.apache.beam.sdk.options.ValueProvider;
import org.apache.beam.sdk.schemas.Schema;
Expand Down Expand Up @@ -1667,26 +1670,9 @@ && getInclusiveStartAt().toSqlTimestamp().after(getInclusiveEndAt().toSqlTimesta
getSpannerConfig().getProjectId().get(),
partitionMetadataInstanceId,
partitionMetadataDatabaseId);
SpannerConfig changeStreamSpannerConfig = getSpannerConfig();
// Set default retryable errors for ReadChangeStream
if (changeStreamSpannerConfig.getRetryableCodes() == null) {
ImmutableSet<Code> defaultRetryableCodes = ImmutableSet.of(Code.UNAVAILABLE, Code.ABORTED);
changeStreamSpannerConfig =
changeStreamSpannerConfig.toBuilder().setRetryableCodes(defaultRetryableCodes).build();
}
// Set default retry timeouts for ReadChangeStream
if (changeStreamSpannerConfig.getExecuteStreamingSqlRetrySettings() == null) {
changeStreamSpannerConfig =
changeStreamSpannerConfig
.toBuilder()
.setExecuteStreamingSqlRetrySettings(
RetrySettings.newBuilder()
.setTotalTimeout(org.threeten.bp.Duration.ofMinutes(5))
.setInitialRpcTimeout(org.threeten.bp.Duration.ofMinutes(1))
.setMaxRpcTimeout(org.threeten.bp.Duration.ofMinutes(1))
.build())
.build();
}

final SpannerConfig changeStreamSpannerConfig =
buildChangeStreamSpannerConfig(input.getPipeline().getOptions());
final SpannerConfig partitionMetadataSpannerConfig =
MetadataSpannerConfigFactory.create(
changeStreamSpannerConfig, partitionMetadataInstanceId, partitionMetadataDatabaseId);
Expand Down Expand Up @@ -1773,6 +1759,37 @@ && getInclusiveStartAt().toSqlTimestamp().after(getInclusiveEndAt().toSqlTimesta
.apply(ParDo.of(new CleanUpReadChangeStreamDoFn(daoFactory)));
return dataChangeRecordsOut;
}

public SpannerConfig buildChangeStreamSpannerConfig(PipelineOptions pipelineOptions) {
SpannerConfig changeStreamSpannerConfig = getSpannerConfig();
// Set default retryable errors for ReadChangeStream
if (changeStreamSpannerConfig.getRetryableCodes() == null) {
ImmutableSet<Code> defaultRetryableCodes = ImmutableSet.of(Code.UNAVAILABLE, Code.ABORTED);
changeStreamSpannerConfig =
changeStreamSpannerConfig.toBuilder().setRetryableCodes(defaultRetryableCodes).build();
}
// Set default retry timeouts for ReadChangeStream
if (changeStreamSpannerConfig.getExecuteStreamingSqlRetrySettings() == null) {
changeStreamSpannerConfig =
changeStreamSpannerConfig
.toBuilder()
.setExecuteStreamingSqlRetrySettings(
RetrySettings.newBuilder()
.setTotalTimeout(org.threeten.bp.Duration.ofMinutes(5))
.setInitialRpcTimeout(org.threeten.bp.Duration.ofMinutes(1))
.setMaxRpcTimeout(org.threeten.bp.Duration.ofMinutes(1))
.build())
.build();
}
// If credentials are not set in SpannerConfig, check pipeline options for credentials.
if (changeStreamSpannerConfig.getCredentials() == null) {
final Credentials credentials = pipelineOptions.as(GcpOptions.class).getGcpCredential();
if (credentials != null) {
changeStreamSpannerConfig = changeStreamSpannerConfig.withCredentials(credentials);
}
}
dengwe1 marked this conversation as resolved.
Show resolved Hide resolved
return changeStreamSpannerConfig;
}
}

private static Dialect getDialect(SpannerConfig spannerConfig) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import com.google.api.gax.retrying.RetrySettings;
import com.google.api.gax.rpc.StatusCode.Code;
import com.google.auth.Credentials;
import com.google.cloud.spanner.Options.RpcPriority;
import org.apache.beam.sdk.io.gcp.spanner.SpannerConfig;
import org.apache.beam.sdk.options.ValueProvider;
Expand Down Expand Up @@ -113,6 +114,11 @@ public static SpannerConfig create(
config = config.withRpcPriority(StaticValueProvider.of(rpcPriority.get()));
}

ValueProvider<Credentials> credentials = primaryConfig.getCredentials();
if (credentials != null) {
config = config.withCredentials(credentials);
}

return config;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@
*/
package org.apache.beam.sdk.io.gcp.spanner;

import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

import com.google.cloud.spanner.DatabaseId;
import com.google.cloud.spanner.SpannerOptions;
import org.apache.beam.sdk.extensions.gcp.auth.TestCredential;
import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider;
import org.junit.Before;
import org.junit.Test;
Expand Down Expand Up @@ -141,4 +144,24 @@ public void testCreateWithEmptyDatabaseRole() {
.getDatabaseClient(DatabaseId.of("project", "test1", "test1"));
verify(serviceFactory.mockSpanner(), times(1)).close();
}

@Test
public void testBuildSpannerOptionsWithCredential() {
TestCredential testCredential = new TestCredential();
SpannerConfig config1 =
SpannerConfig.create()
.toBuilder()
.setServiceFactory(serviceFactory)
.setProjectId(StaticValueProvider.of("project"))
.setInstanceId(StaticValueProvider.of("test-instance"))
.setDatabaseId(StaticValueProvider.of("test-db"))
.setDatabaseRole(StaticValueProvider.of("test-role"))
.setCredentials(StaticValueProvider.of(testCredential))
.build();

SpannerOptions options = SpannerAccessor.buildSpannerOptions(config1);
assertEquals("project", options.getProjectId());
assertEquals("test-role", options.getDatabaseRole());
assertEquals(testCredential, options.getCredentials());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
dengwe1 marked this conversation as resolved.
Show resolved Hide resolved
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.io.gcp.spanner;

import static org.junit.Assert.assertEquals;

import com.google.auth.Credentials;
import com.google.cloud.Timestamp;
import com.google.cloud.spanner.Options.RpcPriority;
import org.apache.beam.sdk.extensions.gcp.auth.TestCredential;
import org.apache.beam.sdk.extensions.gcp.options.GcpOptions;
import org.apache.beam.sdk.io.gcp.spanner.changestreams.MetadataSpannerConfigFactory;
import org.apache.beam.sdk.testing.TestPipeline;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
public class SpannerIOReadChangeStreamTest {

private static final String TEST_PROJECT = "my-project";
private static final String TEST_INSTANCE = "my-instance";
private static final String TEST_DATABASE = "my-database";
private static final String TEST_METADATA_INSTANCE = "my-metadata-instance";
private static final String TEST_METADATA_DATABASE = "my-metadata-database";
private static final String TEST_METADATA_TABLE = "my-metadata-table";
private static final String TEST_CHANGE_STREAM = "my-change-stream";

@Rule public final transient TestPipeline testPipeline = TestPipeline.create();

private SpannerConfig spannerConfig;
private SpannerIO.ReadChangeStream readChangeStream;

@Before
public void setUp() throws Exception {
spannerConfig =
SpannerConfig.create()
.withProjectId(TEST_PROJECT)
.withInstanceId(TEST_INSTANCE)
.withDatabaseId(TEST_DATABASE);

Timestamp startTimestamp = Timestamp.now();
Timestamp endTimestamp =
Timestamp.ofTimeSecondsAndNanos(
startTimestamp.getSeconds() + 10, startTimestamp.getNanos());
readChangeStream =
SpannerIO.readChangeStream()
.withSpannerConfig(spannerConfig)
.withChangeStreamName(TEST_CHANGE_STREAM)
.withMetadataInstance(TEST_METADATA_INSTANCE)
.withMetadataDatabase(TEST_METADATA_DATABASE)
.withMetadataTable(TEST_METADATA_TABLE)
.withRpcPriority(RpcPriority.MEDIUM)
.withInclusiveStartAt(startTimestamp)
.withInclusiveEndAt(endTimestamp);
}

@Test
public void testSetPipelineCredential() {
TestCredential testCredential = new TestCredential();
// Set the credential in the pipeline options.
testPipeline.getOptions().as(GcpOptions.class).setGcpCredential(testCredential);
SpannerConfig changeStreamSpannerConfig =
readChangeStream.buildChangeStreamSpannerConfig(testPipeline.getOptions());
SpannerConfig metadataSpannerConfig =
MetadataSpannerConfigFactory.create(
changeStreamSpannerConfig, TEST_METADATA_INSTANCE, TEST_METADATA_DATABASE);
// Verify that the credential is propagated in ReadChangeStream.
assertEquals(testCredential, changeStreamSpannerConfig.getCredentials().get());
assertEquals(testCredential, metadataSpannerConfig.getCredentials().get());
}

@Test
public void testSetSpannerConfigCredential() {
TestCredential testCredential = new TestCredential();
// Set the credential in the SpannerConfig.
spannerConfig = spannerConfig.withCredentials(testCredential);
readChangeStream = readChangeStream.withSpannerConfig(spannerConfig);
SpannerConfig changeStreamSpannerConfig =
readChangeStream.buildChangeStreamSpannerConfig(testPipeline.getOptions());
SpannerConfig metadataSpannerConfig =
MetadataSpannerConfigFactory.create(
changeStreamSpannerConfig, TEST_METADATA_INSTANCE, TEST_METADATA_DATABASE);
// Verify that the credential is propagated in ReadChangeStream.
assertEquals(testCredential, changeStreamSpannerConfig.getCredentials().get());
assertEquals(testCredential, metadataSpannerConfig.getCredentials().get());
}

@Test
public void testWithDefaultCredential() {
// Get the default credential, without setting any credentials in the pipeline options or
// SpannerConfig.
Credentials defaultCredential =
testPipeline.getOptions().as(GcpOptions.class).getGcpCredential();
SpannerConfig changeStreamSpannerConfig =
readChangeStream.buildChangeStreamSpannerConfig(testPipeline.getOptions());
SpannerConfig metadataSpannerConfig =
MetadataSpannerConfigFactory.create(
changeStreamSpannerConfig, TEST_METADATA_INSTANCE, TEST_METADATA_DATABASE);
// Verify that the default credential will be used in ReadChangeStream.
assertEquals(defaultCredential, changeStreamSpannerConfig.getCredentials().get());
assertEquals(defaultCredential, metadataSpannerConfig.getCredentials().get());
}
}
Loading