Skip to content

Commit

Permalink
Enable Job management for the Prism runner (#32091)
Browse files Browse the repository at this point in the history
  • Loading branch information
damondouglas authored Aug 7, 2024
1 parent 07e692b commit 656a296
Show file tree
Hide file tree
Showing 2 changed files with 371 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.prism;

import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;

import com.google.auto.value.AutoValue;
import java.io.Closeable;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import org.apache.beam.model.jobmanagement.v1.JobApi;
import org.apache.beam.model.jobmanagement.v1.JobServiceGrpc;
import org.apache.beam.model.pipeline.v1.Endpoints;
import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.sdk.fn.channel.ManagedChannelFactory;
import org.apache.beam.sdk.options.PortablePipelineOptions;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
import org.joda.time.Duration;

/**
* A wrapper for {@link JobServiceGrpc.JobServiceBlockingStub} that {@link #close}es when {@link
* StateListener#onStateChanged} is invoked with a {@link PipelineResult.State} that is {@link
* PipelineResult.State#isTerminal}.
*/
@AutoValue
abstract class PrismJobManager implements StateListener, Closeable {

/**
* Instantiate a {@link PrismJobManager} with {@param options}, assigning {@link #getEndpoint}
* from {@link PortablePipelineOptions#getJobEndpoint} and {@link #getTimeout} from {@link
* PortablePipelineOptions#getJobServerTimeout}. Defaults the instantiations of {@link
* #getManagedChannel} and {@link #getBlockingStub}. See respective getters for more details.
*/
static PrismJobManager of(PortablePipelineOptions options) {
return builder()
.setEndpoint(options.getJobEndpoint())
.setTimeout(Duration.standardSeconds(options.getJobServerTimeout()))
.build();
}

static Builder builder() {
return new AutoValue_PrismJobManager.Builder();
}

/**
* Executes {@link #getBlockingStub()}'s {@link JobServiceGrpc.JobServiceBlockingStub#prepare}
* method.
*/
JobApi.PrepareJobResponse prepare(JobApi.PrepareJobRequest request) {
return getBlockingStub().prepare(request);
}

/**
* Executes {@link #getBlockingStub()}'s {@link JobServiceGrpc.JobServiceBlockingStub#run} method.
*/
JobApi.RunJobResponse run(JobApi.RunJobRequest request) {
return getBlockingStub().run(request);
}

/** The {@link JobServiceGrpc} endpoint. */
abstract String getEndpoint();

/** The {@link JobServiceGrpc} timeout. */
abstract Duration getTimeout();

/** The {@link #getBlockingStub}'s channel. Defaulted from the {@link #getEndpoint()}. */
abstract ManagedChannel getManagedChannel();

/** The wrapped service defaulted using the {@link #getManagedChannel}. */
abstract JobServiceGrpc.JobServiceBlockingStub getBlockingStub();

/** Shuts down {@link #getManagedChannel}, if not {@link #isShutdown}. */
@Override
public void close() {
if (isShutdown()) {
return;
}
getManagedChannel().shutdown();
try {
getManagedChannel().awaitTermination(3000L, TimeUnit.MILLISECONDS);
} catch (InterruptedException ignored) {
}
}

/** Queries whether {@link #getManagedChannel} {@link ManagedChannel#isShutdown}. */
boolean isShutdown() {
return getManagedChannel().isShutdown();
}

/**
* Override of {@link StateListener#onStateChanged}. Invokes {@link #close} when {@link
* PipelineResult.State} {@link PipelineResult.State#isTerminal}.
*/
@Override
public void onStateChanged(PipelineResult.State state) {
if (state.isTerminal()) {
close();
}
}

@AutoValue.Builder
abstract static class Builder {

abstract Builder setEndpoint(String endpoint);

abstract Optional<String> getEndpoint();

abstract Builder setTimeout(Duration timeout);

abstract Optional<Duration> getTimeout();

abstract Builder setManagedChannel(ManagedChannel managedChannel);

abstract Optional<ManagedChannel> getManagedChannel();

abstract Builder setBlockingStub(JobServiceGrpc.JobServiceBlockingStub blockingStub);

abstract Optional<JobServiceGrpc.JobServiceBlockingStub> getBlockingStub();

abstract PrismJobManager autoBuild();

final PrismJobManager build() {

checkState(getEndpoint().isPresent(), "endpoint is not set");
checkState(getTimeout().isPresent(), "timeout is not set");

if (!getManagedChannel().isPresent()) {
ManagedChannelFactory channelFactory = ManagedChannelFactory.createDefault();

setManagedChannel(
channelFactory.forDescriptor(
Endpoints.ApiServiceDescriptor.newBuilder().setUrl(getEndpoint().get()).build()));
}

if (!getBlockingStub().isPresent()) {
setBlockingStub(
JobServiceGrpc.newBlockingStub(getManagedChannel().get())
.withDeadlineAfter(getTimeout().get().getMillis(), TimeUnit.MILLISECONDS)
.withWaitForReady());
}

return autoBuild();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.prism;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;

import java.io.IOException;
import java.util.Optional;
import org.apache.beam.model.jobmanagement.v1.JobApi;
import org.apache.beam.model.jobmanagement.v1.JobServiceGrpc;
import org.apache.beam.model.pipeline.v1.Endpoints;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.sdk.transforms.Impulse;
import org.apache.beam.sdk.util.construction.PipelineTranslation;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule;
import org.joda.time.Duration;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TestName;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/** Tests for {@link PrismJobManager}. */
@RunWith(JUnit4.class)
public class PrismJobManagerTest {
@Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();

@Rule public TestName testName = new TestName();

@Test
public void givenPrepareError_forwardsException_canGracefulShutdown() {
TestJobService service =
new TestJobService().withErrorResponse(new RuntimeException(testName.getMethodName()));
PrismJobManager underTest = prismJobManager(service);
assertThat(underTest.isShutdown()).isFalse();
assertThrows(
RuntimeException.class,
() ->
underTest.prepare(
JobApi.PrepareJobRequest.newBuilder().setPipeline(pipelineOf()).build()));
assertThat(underTest.isShutdown()).isFalse();
underTest.close();
assertThat(underTest.isShutdown()).isTrue();
}

@Test
public void givenPrepareSuccess_forwardsResponse_canGracefulShutdown() {
TestJobService service =
new TestJobService()
.withPrepareJobResponse(
JobApi.PrepareJobResponse.newBuilder()
.setStagingSessionToken("token")
.setPreparationId("preparationId")
.setArtifactStagingEndpoint(
Endpoints.ApiServiceDescriptor.newBuilder()
.setUrl("localhost:1234")
.build())
.build());
PrismJobManager underTest = prismJobManager(service);
assertThat(underTest.isShutdown()).isFalse();
JobApi.PrepareJobResponse response =
underTest.prepare(JobApi.PrepareJobRequest.newBuilder().setPipeline(pipelineOf()).build());
assertThat(underTest.isShutdown()).isFalse();
assertThat(response.getStagingSessionToken()).isEqualTo("token");
assertThat(response.getPreparationId()).isEqualTo("preparationId");
underTest.close();
assertThat(underTest.isShutdown()).isTrue();
}

@Test
public void givenRunError_forwardsException_canGracefulShutdown() {
TestJobService service =
new TestJobService().withErrorResponse(new RuntimeException(testName.getMethodName()));
PrismJobManager underTest = prismJobManager(service);
assertThat(underTest.isShutdown()).isFalse();
assertThrows(
RuntimeException.class,
() ->
underTest.run(JobApi.RunJobRequest.newBuilder().setPreparationId("prepareId").build()));
assertThat(underTest.isShutdown()).isFalse();
underTest.close();
assertThat(underTest.isShutdown()).isTrue();
}

@Test
public void givenRunSuccess_forwardsResponse_canGracefulShutdown() {
TestJobService service =
new TestJobService()
.withRunJobResponse(JobApi.RunJobResponse.newBuilder().setJobId("jobId").build());
PrismJobManager underTest = prismJobManager(service);
assertThat(underTest.isShutdown()).isFalse();
JobApi.RunJobResponse runJobResponse =
underTest.run(JobApi.RunJobRequest.newBuilder().setPreparationId("preparationId").build());
assertThat(underTest.isShutdown()).isFalse();
assertThat(runJobResponse.getJobId()).isEqualTo("jobId");
underTest.close();
assertThat(underTest.isShutdown()).isTrue();
}

@Test
public void givenTerminalState_closes() {
PrismJobManager underTest = prismJobManager(new TestJobService());
assertThat(underTest.isShutdown()).isFalse();
underTest.onStateChanged(PipelineResult.State.RUNNING);
assertThat(underTest.isShutdown()).isFalse();
underTest.onStateChanged(PipelineResult.State.RUNNING);
assertThat(underTest.isShutdown()).isFalse();
underTest.onStateChanged(PipelineResult.State.CANCELLED);
assertThat(underTest.isShutdown()).isTrue();

underTest.close();
}

private PrismJobManager prismJobManager(TestJobService service) {
String serverName = InProcessServerBuilder.generateName();
try {
grpcCleanup.register(
InProcessServerBuilder.forName(serverName)
.directExecutor()
.addService(service)
.build()
.start());
} catch (IOException e) {
throw new RuntimeException(e);
}

ManagedChannel channel =
grpcCleanup.register(InProcessChannelBuilder.forName(serverName).build());

return PrismJobManager.builder()
.setTimeout(Duration.millis(3000L))
.setEndpoint("ignore")
.setManagedChannel(channel)
.build();
}

private static class TestJobService extends JobServiceGrpc.JobServiceImplBase {

private Optional<JobApi.PrepareJobResponse> prepareJobResponse = Optional.empty();
private Optional<JobApi.RunJobResponse> runJobResponse = Optional.empty();
private Optional<RuntimeException> error = Optional.empty();

TestJobService withPrepareJobResponse(JobApi.PrepareJobResponse prepareJobResponse) {
this.prepareJobResponse = Optional.of(prepareJobResponse);
return this;
}

TestJobService withRunJobResponse(JobApi.RunJobResponse runJobResponse) {
this.runJobResponse = Optional.of(runJobResponse);
return this;
}

TestJobService withErrorResponse(RuntimeException error) {
this.error = Optional.of(error);
return this;
}

@Override
public void prepare(
JobApi.PrepareJobRequest request,
StreamObserver<JobApi.PrepareJobResponse> responseObserver) {
if (prepareJobResponse.isPresent()) {
responseObserver.onNext(prepareJobResponse.get());
responseObserver.onCompleted();
}
if (error.isPresent()) {
responseObserver.onError(error.get());
}
}

@Override
public void run(
JobApi.RunJobRequest request, StreamObserver<JobApi.RunJobResponse> responseObserver) {
if (runJobResponse.isPresent()) {
responseObserver.onNext(runJobResponse.get());
responseObserver.onCompleted();
}
if (error.isPresent()) {
responseObserver.onError(error.get());
}
}
}

private static RunnerApi.Pipeline pipelineOf() {
Pipeline pipeline = Pipeline.create();
pipeline.apply(Impulse.create());
return PipelineTranslation.toProto(pipeline);
}
}

0 comments on commit 656a296

Please sign in to comment.