Skip to content

Commit

Permalink
TSPS-303 validate method config and update if needed (#121)
Browse files Browse the repository at this point in the history
Co-authored-by: Jose Soto <jsoto@broadinstitute.org>
  • Loading branch information
jsotobroad and Jose Soto authored Sep 11, 2024
1 parent d1285a6 commit b22ca5e
Show file tree
Hide file tree
Showing 13 changed files with 584 additions and 63 deletions.
2 changes: 1 addition & 1 deletion rawls-client/swagger.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def artifactGroup = "${group}.rawls"

task downloadRawlsYaml(type: Download) {
// Downloads the Rawls api-docs.yaml from a pinned rawls commit into the build directory.
// this was the latest git commit hash at development time
// NOTE(review): two src lines appear below; this looks like a diff view showing the old and the
// new pinned commit hash, so presumably only the second one exists in the real file - confirm.
src "https://raw.githubusercontent.com/broadinstitute/rawls/6bdd0d3d2371af6afb8df962819cf7ef48567d3c/core/src/main/resources/swagger/api-docs.yaml"
src "https://raw.githubusercontent.com/broadinstitute/rawls/eec65d1f93c80ef312104a86141f2b95f3e919b2/core/src/main/resources/swagger/api-docs.yaml"
dest "${buildDir}/rawls-api-docs.yaml"
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
package bio.terra.pipelines.dependencies.rawls;

import bio.terra.pipelines.app.configuration.external.RawlsConfiguration;
import bio.terra.rawls.api.EntitiesApi;
import bio.terra.rawls.api.StatusApi;
import bio.terra.rawls.api.SubmissionsApi;
import bio.terra.rawls.api.WorkspacesApi;
import bio.terra.rawls.api.*;
import bio.terra.rawls.client.ApiClient;
import jakarta.ws.rs.client.Client;
import org.springframework.stereotype.Component;
Expand Down Expand Up @@ -44,6 +41,10 @@ WorkspacesApi getWorkspacesApi(String accessToken) {
return new WorkspacesApi(getApiClient(accessToken));
}

/**
 * Builds a Rawls method configurations API handle for the given access token.
 *
 * @param accessToken token passed to {@code getApiClient} when creating the underlying client
 * @return a new {@code MethodconfigsApi} wrapping that client
 */
MethodconfigsApi getMethodConfigsApi(String accessToken) {
  ApiClient authorizedClient = getApiClient(accessToken);
  return new MethodconfigsApi(authorizedClient);
}

/**
 * Builds a Rawls status API handle. No access token is involved: the client comes from {@code
 * getUnauthorizedApiClient}.
 *
 * @return a new {@code StatusApi} wrapping the unauthorized client
 */
StatusApi getStatusApi() {
  ApiClient unauthorizedClient = getUnauthorizedApiClient();
  return new StatusApi(unauthorizedClient);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,25 +1,37 @@
package bio.terra.pipelines.dependencies.rawls;

import bio.terra.common.exception.InternalServerErrorException;
import bio.terra.pipelines.db.entities.PipelineInputDefinition;
import bio.terra.pipelines.db.entities.PipelineOutputDefinition;
import bio.terra.pipelines.dependencies.common.HealthCheck;
import bio.terra.rawls.client.ApiException;
import bio.terra.rawls.model.*;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.HashMap;
import java.util.List;
import java.util.UUID;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.stereotype.Service;

@Service
public class RawlsService implements HealthCheck {
private final RawlsClient rawlsClient;
private final RetryTemplate listenerResetRetryTemplate;
private final ObjectMapper objectMapper;

private static final List<SubmissionStatus> FINAL_RUN_STATES =
List.of(SubmissionStatus.ABORTED, SubmissionStatus.DONE);
private static final String DATA_TABLE_REFERENCE_PREFIX = "this.";

public RawlsService(RawlsClient rawlsClient, RetryTemplate listenerResetRetryTemplate) {
/**
 * Creates the service from its collaborators.
 *
 * @param rawlsClient client used to obtain the Rawls API handles
 * @param listenerResetRetryTemplate retry template handed to the retried Rawls calls
 * @param objectMapper mapper used to convert method config inputs/outputs into maps
 */
@Autowired
public RawlsService(
    RawlsClient rawlsClient,
    RetryTemplate listenerResetRetryTemplate,
    ObjectMapper objectMapper) {
  this.objectMapper = objectMapper;
  this.listenerResetRetryTemplate = listenerResetRetryTemplate;
  this.rawlsClient = rawlsClient;
}

@Override
Expand Down Expand Up @@ -108,11 +120,161 @@ public Entity getDataTableEntity(
.getEntity(workspaceNamespace, workspaceName, entityType, entityName, null, null));
}

/**
 * Fetches the method configuration currently stored in a workspace for the given method, going
 * through the retry template.
 *
 * @param accessToken token used to build the method configs API
 * @param workspaceNamespace namespace of the workspace
 * @param workspaceName name of the workspace
 * @param methodName name of the method whose configuration is fetched
 * @return the method configuration returned by Rawls
 */
public MethodConfiguration getCurrentMethodConfigForMethod(
    String accessToken, String workspaceNamespace, String workspaceName, String methodName) {
  return executionWithRetryTemplate(
      listenerResetRetryTemplate,
      () -> {
        // workspaceNamespace is passed twice: for the workspace and again as the third argument
        // (presumably the method config namespace - confirm against the Rawls API)
        return rawlsClient
            .getMethodConfigsApi(accessToken)
            .getMethodConfiguration(
                workspaceNamespace, workspaceName, workspaceNamespace, methodName);
      });
}

/**
 * Sends the given method configuration to Rawls as an update for the given method, going through
 * the retry template.
 *
 * @param accessToken token used to build the method configs API
 * @param methodConfiguration method configuration to send
 * @param workspaceNamespace namespace of the workspace
 * @param workspaceName name of the workspace
 * @param methodName name of the method whose configuration is updated
 * @return the validated method configuration returned by Rawls
 */
public ValidatedMethodConfiguration setMethodConfigForMethod(
    String accessToken,
    MethodConfiguration methodConfiguration,
    String workspaceNamespace,
    String workspaceName,
    String methodName) {
  return executionWithRetryTemplate(
      listenerResetRetryTemplate,
      () -> {
        // workspaceNamespace is passed twice: for the workspace and again as the fourth argument
        // (presumably the method config namespace - confirm against the Rawls API)
        return rawlsClient
            .getMethodConfigsApi(accessToken)
            .updateMethodConfiguration(
                methodConfiguration,
                workspaceNamespace,
                workspaceName,
                workspaceNamespace,
                methodName);
      });
}

/**
 * Tells whether a submission is still running, i.e. its status is not one of the final run
 * states.
 *
 * @param submission submission whose status is inspected
 * @return false when the status is in {@code FINAL_RUN_STATES}, true otherwise
 */
public static boolean submissionIsRunning(Submission submission) {
  if (FINAL_RUN_STATES.contains(submission.getStatus())) {
    return false;
  }
  return true;
}

/**
 * Validates a method config against the expected version, data table entity, inputs, and outputs.
 * It does this by checking that the wdl method version matches what we expect and that all
 * expected inputs/outputs both exist in the method config's list of inputs/outputs and that the
 * data table reference for each input/output is what we expect.
 *
 * <p>Missing pieces of the method config (no method repo method, no version, no root entity
 * type, no inputs/outputs, or a null value for an expected key) are treated as a mismatch and
 * yield {@code false} rather than a NullPointerException.
 *
 * @param methodConfiguration - method config to validate against
 * @param dataTableEntityName - data table entity name to use, should be the pipeline name
 * @param wdlWorkflowName - name of the wdl workflow, used to construct the full wdl variable name
 * @param pipelineInputDefinitions - list of input definitions for a pipeline
 * @param pipelineOutputDefinitions - list of output definitions for a pipeline
 * @param wdlMethodVersion - version of wdl we should be submitting
 * @return - true if the method config matches what we expect, false otherwise
 */
public boolean validateMethodConfig(
    MethodConfiguration methodConfiguration,
    String dataTableEntityName,
    String wdlWorkflowName,
    List<PipelineInputDefinition> pipelineInputDefinitions,
    List<PipelineOutputDefinition> pipelineOutputDefinitions,
    String wdlMethodVersion) {
  // validate wdl method version; comparing from the expected side keeps a missing version in the
  // method config from throwing
  if (methodConfiguration.getMethodRepoMethod() == null
      || wdlMethodVersion == null
      || !wdlMethodVersion.equals(methodConfiguration.getMethodRepoMethod().getMethodVersion())) {
    return false;
  }
  // validate data table entity name
  if (dataTableEntityName == null
      || !dataTableEntityName.equals(methodConfiguration.getRootEntityType())) {
    return false;
  }
  // validate inputs
  HashMap<?, ?> methodConfigInputs =
      objectMapper.convertValue(methodConfiguration.getInputs(), HashMap.class);
  for (PipelineInputDefinition pipelineInputDefinition : pipelineInputDefinitions) {
    if (!hasExpectedDataTableReference(
        methodConfigInputs, wdlWorkflowName, pipelineInputDefinition.getWdlVariableName())) {
      return false;
    }
  }
  // validate outputs
  HashMap<?, ?> methodConfigOutputs =
      objectMapper.convertValue(methodConfiguration.getOutputs(), HashMap.class);
  for (PipelineOutputDefinition pipelineOutputDefinition : pipelineOutputDefinitions) {
    if (!hasExpectedDataTableReference(
        methodConfigOutputs, wdlWorkflowName, pipelineOutputDefinition.getWdlVariableName())) {
      return false;
    }
  }

  // all validation checks passed, return true
  return true;
}

/**
 * Checks that the given method config inputs/outputs map holds the expected data table reference
 * under the full wdl variable name built from the workflow and variable names.
 */
private static boolean hasExpectedDataTableReference(
    HashMap<?, ?> methodConfigValues, String wdlWorkflowName, String wdlVariableName) {
  if (methodConfigValues == null) {
    // the method config had no inputs/outputs at all, so nothing expected can be present
    return false;
  }
  String fullWdlVariableName = wdlWorkflowName + "." + wdlVariableName;
  String fullDataTableReference = DATA_TABLE_REFERENCE_PREFIX + wdlVariableName;
  // equals is called on the expected (never null) string, so a missing key and a key mapped to
  // null both count as a mismatch
  return fullDataTableReference.equals(methodConfigValues.get(fullWdlVariableName));
}

/**
 * This method takes a pre-existing method configuration and updates it to match what the service
 * expects. It does this by setting the input and output fields to match what we expect, both in
 * the keys we expect and the data table reference we expect for each key.
 *
 * <p>It also sets the wdl method version to the version we should be running as well as updating
 * the methodRepoMethod method uri by swapping the old version string with the expected version
 * string. The uri is left untouched when there is no old uri or no non-empty old version to swap
 * out.
 *
 * <p>The passed-in method configuration is modified in place and returned.
 *
 * @param methodConfiguration - method configuration to update
 * @param dataTableEntityName - data table entity name to use, should be the pipeline name
 * @param wdlWorkflowName - name of the wdl workflow, used to construct the full wdl variable name
 * @param pipelineInputDefinitions - list of input definitions for a pipeline
 * @param pipelineOutputDefinitions - list of output definitions for a pipeline
 * @param wdlMethodVersion - version of wdl we should be submitting
 * @return - the same method configuration with its methodRepoMethod's methodVersion and
 *     methodUri, its root entity type, and its inputs and outputs updated
 */
public MethodConfiguration updateMethodConfigToBeValid(
    MethodConfiguration methodConfiguration,
    String dataTableEntityName,
    String wdlWorkflowName,
    List<PipelineInputDefinition> pipelineInputDefinitions,
    List<PipelineOutputDefinition> pipelineOutputDefinitions,
    String wdlMethodVersion) {
  // set wdl method version and uri to match the expected version
  String oldVersion = methodConfiguration.getMethodRepoMethod().getMethodVersion();
  String oldUri = methodConfiguration.getMethodRepoMethod().getMethodUri();
  methodConfiguration.getMethodRepoMethod().setMethodVersion(wdlMethodVersion);
  // String.replace throws on a null uri/target, and with an empty target it would insert the new
  // version between every character of the uri, so only swap when there is something to swap
  if (oldUri != null && oldVersion != null && !oldVersion.isEmpty()) {
    methodConfiguration
        .getMethodRepoMethod()
        .setMethodUri(oldUri.replace(oldVersion, wdlMethodVersion));
  }

  // set data table entity
  methodConfiguration.setRootEntityType(dataTableEntityName);

  // set inputs
  HashMap<String, String> expectedInputs = new HashMap<>();
  for (PipelineInputDefinition pipelineInputDefinition : pipelineInputDefinitions) {
    putExpectedDataTableReference(
        expectedInputs, wdlWorkflowName, pipelineInputDefinition.getWdlVariableName());
  }
  methodConfiguration.setInputs(expectedInputs);

  // set outputs
  HashMap<String, String> expectedOutputs = new HashMap<>();
  for (PipelineOutputDefinition pipelineOutputDefinition : pipelineOutputDefinitions) {
    putExpectedDataTableReference(
        expectedOutputs, wdlWorkflowName, pipelineOutputDefinition.getWdlVariableName());
  }
  methodConfiguration.setOutputs(expectedOutputs);
  return methodConfiguration;
}

/**
 * Adds the entry expected for one wdl variable: full wdl variable name mapped to its data table
 * reference.
 */
private static void putExpectedDataTableReference(
    HashMap<String, String> expectedReferences, String wdlWorkflowName, String wdlVariableName) {
  String fullWdlVariableName = wdlWorkflowName + "." + wdlVariableName;
  String fullDataTableReference = DATA_TABLE_REFERENCE_PREFIX + wdlVariableName;
  expectedReferences.put(fullWdlVariableName, fullDataTableReference);
}

/**
 * A single Rawls call that produces a result and may throw the client's {@link ApiException}.
 * Implemented with lambdas in this class, hence the functional interface marker.
 *
 * @param <T> type of the value the call returns
 */
@FunctionalInterface
interface RawlsAction<T> {
  /**
   * Runs the call.
   *
   * @return the value produced by the call
   * @throws ApiException if the underlying Rawls client call fails
   */
  T execute() throws ApiException;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,8 @@ public PipelineRun startPipelineRun(
pipelineRun.getWorkspaceStorageContainerName())
.addParameter(
RunImputationJobFlightMapKeys.WDL_METHOD_NAME, pipeline.getWdlMethodName())
.addParameter(
RunImputationJobFlightMapKeys.WDL_METHOD_VERSION, pipeline.getWdlMethodVersion())
.addParameter(JobMapKeys.RESULT_PATH.getKeyName(), resultPath);

jobBuilder.submit();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ public RunImputationGcpJobFlight(FlightMap inputParameters, Object beanBag) {
RunImputationJobFlightMapKeys.CONTROL_WORKSPACE_NAME,
RunImputationJobFlightMapKeys.CONTROL_WORKSPACE_STORAGE_CONTAINER_NAME,
RunImputationJobFlightMapKeys.WDL_METHOD_NAME,
RunImputationJobFlightMapKeys.WDL_METHOD_VERSION,
JobMapKeys.RESULT_PATH.getKeyName());

PipelinesEnum pipelinesEnum =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ public abstract class RunImputationJobFlightMapKeys {
public static final String CONTROL_WORKSPACE_STORAGE_CONTAINER_NAME =
"control_workspace_storage_container_name";
public static final String WDL_METHOD_NAME = "wdl_method_name";
public static final String WDL_METHOD_VERSION = "wdl_method_version";
public static final String PIPELINE_RUN_OUTPUTS = "pipeline_run_outputs";

// GCP specific keys
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,21 @@
import bio.terra.pipelines.app.configuration.internal.ImputationConfiguration;
import bio.terra.pipelines.common.utils.FlightUtils;
import bio.terra.pipelines.common.utils.PipelinesEnum;
import bio.terra.pipelines.db.entities.PipelineInputDefinition;
import bio.terra.pipelines.db.entities.PipelineOutputDefinition;
import bio.terra.pipelines.dependencies.rawls.RawlsService;
import bio.terra.pipelines.dependencies.rawls.RawlsServiceApiException;
import bio.terra.pipelines.dependencies.sam.SamService;
import bio.terra.pipelines.dependencies.stairway.JobMapKeys;
import bio.terra.pipelines.stairway.imputation.RunImputationJobFlightMapKeys;
import bio.terra.rawls.model.MethodConfiguration;
import bio.terra.rawls.model.SubmissionReport;
import bio.terra.rawls.model.SubmissionRequest;
import bio.terra.stairway.*;
import com.fasterxml.jackson.core.type.TypeReference;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* This step submits a submission to cromwell using the rawls submission endpoint. It uses
Expand All @@ -25,6 +32,8 @@ public class SubmitCromwellSubmissionStep implements Step {
private final RawlsService rawlsService;
private final ImputationConfiguration imputationConfiguration;

// one logger shared by all instances of this step; loggers are looked up per class, so there is
// no need to hold a separate reference in every step object
private static final Logger logger = LoggerFactory.getLogger(SubmitCromwellSubmissionStep.class);

public SubmitCromwellSubmissionStep(
RawlsService rawlsService,
SamService samService,
Expand All @@ -46,8 +55,11 @@ public StepResult doStep(FlightContext flightContext) {
JobMapKeys.PIPELINE_NAME.getKeyName(),
JobMapKeys.DESCRIPTION.getKeyName(),
RunImputationJobFlightMapKeys.WDL_METHOD_NAME,
RunImputationJobFlightMapKeys.WDL_METHOD_VERSION,
RunImputationJobFlightMapKeys.CONTROL_WORKSPACE_BILLING_PROJECT,
RunImputationJobFlightMapKeys.CONTROL_WORKSPACE_NAME);
RunImputationJobFlightMapKeys.CONTROL_WORKSPACE_NAME,
RunImputationJobFlightMapKeys.PIPELINE_INPUT_DEFINITIONS,
RunImputationJobFlightMapKeys.PIPELINE_OUTPUT_DEFINITIONS);

String controlWorkspaceName =
inputParameters.get(RunImputationJobFlightMapKeys.CONTROL_WORKSPACE_NAME, String.class);
Expand All @@ -59,10 +71,73 @@ public StepResult doStep(FlightContext flightContext) {
String description = inputParameters.get(JobMapKeys.DESCRIPTION.getKeyName(), String.class);
String wdlMethodName =
inputParameters.get(RunImputationJobFlightMapKeys.WDL_METHOD_NAME, String.class);
String wdlMethodVersion =
inputParameters.get(RunImputationJobFlightMapKeys.WDL_METHOD_VERSION, String.class);

List<PipelineInputDefinition> inputDefinitions =
inputParameters.get(
RunImputationJobFlightMapKeys.PIPELINE_INPUT_DEFINITIONS, new TypeReference<>() {});
List<PipelineOutputDefinition> outputDefinitions =
inputParameters.get(
RunImputationJobFlightMapKeys.PIPELINE_OUTPUT_DEFINITIONS, new TypeReference<>() {});

// validate and extract parameters from working map
FlightMap workingMap = flightContext.getWorkingMap();

MethodConfiguration methodConfiguration;
try {
// grab current method config and validate it
methodConfiguration =
rawlsService.getCurrentMethodConfigForMethod(
samService.getTeaspoonsServiceAccountToken(),
controlWorkspaceProject,
controlWorkspaceName,
wdlMethodName);
} catch (RawlsServiceApiException e) {
// if we fail to grab the method config then retry
return new StepResult(StepStatus.STEP_RESULT_FAILURE_RETRY, e);
}
boolean validMethodConfig =
rawlsService.validateMethodConfig(
methodConfiguration,
pipelineName.getValue(),
wdlMethodName,
inputDefinitions,
outputDefinitions,
wdlMethodVersion);

// if not a valid method config, set the method config to what we think it should be. This
// shouldn't happen
if (!validMethodConfig) {
logger.warn(
"found method config that was not valid for billing project: {}, workspace: {}, method name: {}, methodConfigVersion: {}",
controlWorkspaceProject,
controlWorkspaceName,
wdlMethodName,
methodConfiguration.getMethodConfigVersion());

MethodConfiguration updatedMethodConfiguration =
rawlsService.updateMethodConfigToBeValid(
methodConfiguration,
pipelineName.getValue(),
wdlMethodName,
inputDefinitions,
outputDefinitions,
wdlMethodVersion);
try {
// update method config version, inputs, and outputs
rawlsService.setMethodConfigForMethod(
samService.getTeaspoonsServiceAccountToken(),
updatedMethodConfiguration,
controlWorkspaceProject,
controlWorkspaceName,
wdlMethodName);
} catch (RawlsServiceApiException e) {
// if we fail to update the method config then retry
return new StepResult(StepStatus.STEP_RESULT_FAILURE_RETRY, e);
}
}

// create submission request
SubmissionRequest submissionRequest =
new SubmissionRequest()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@
import static org.junit.jupiter.api.Assertions.assertTrue;

import bio.terra.pipelines.testutils.BaseEmbeddedDbTest;
import bio.terra.rawls.api.EntitiesApi;
import bio.terra.rawls.api.StatusApi;
import bio.terra.rawls.api.SubmissionsApi;
import bio.terra.rawls.api.WorkspacesApi;
import bio.terra.rawls.api.*;
import bio.terra.rawls.client.ApiClient;
import bio.terra.rawls.client.auth.Authentication;
import bio.terra.rawls.client.auth.OAuth;
Expand Down Expand Up @@ -54,5 +51,9 @@ void testWorkspaceClientApis() {
WorkspacesApi workspacesApi = rawlsClient.getWorkspacesApi(authToken);
assertEquals(rawlsBaseUri, workspacesApi.getApiClient().getBasePath());
assertTrue(workspacesApi.getApiClient().isDebugging());

MethodconfigsApi methodconfigsApi = rawlsClient.getMethodConfigsApi(authToken);
assertEquals(rawlsBaseUri, methodconfigsApi.getApiClient().getBasePath());
assertTrue(methodconfigsApi.getApiClient().isDebugging());
}
}
Loading

0 comments on commit b22ca5e

Please sign in to comment.