Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TSPS-303 validate method config and update if needed #121

Merged
merged 8 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rawls-client/swagger.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def artifactGroup = "${group}.rawls"

// Downloads the Rawls swagger api-docs.yaml from a pinned commit of broadinstitute/rawls into
// ${buildDir}/rawls-api-docs.yaml.
task downloadRawlsYaml(type: Download) {
// this was the latest git commit hash at development time
// NOTE(review): two src lines are shown because this is a rendered diff (1 addition & 1
// deletion); presumably only the second one exists in the merged file - confirm.
src "https://raw.githubusercontent.com/broadinstitute/rawls/6bdd0d3d2371af6afb8df962819cf7ef48567d3c/core/src/main/resources/swagger/api-docs.yaml"
src "https://raw.githubusercontent.com/broadinstitute/rawls/eec65d1f93c80ef312104a86141f2b95f3e919b2/core/src/main/resources/swagger/api-docs.yaml"
dest "${buildDir}/rawls-api-docs.yaml"
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
package bio.terra.pipelines.dependencies.rawls;

import bio.terra.pipelines.app.configuration.external.RawlsConfiguration;
import bio.terra.rawls.api.EntitiesApi;
import bio.terra.rawls.api.StatusApi;
import bio.terra.rawls.api.SubmissionsApi;
import bio.terra.rawls.api.WorkspacesApi;
import bio.terra.rawls.api.*;
import bio.terra.rawls.client.ApiClient;
import jakarta.ws.rs.client.Client;
import org.springframework.stereotype.Component;
Expand Down Expand Up @@ -44,6 +41,10 @@ WorkspacesApi getWorkspacesApi(String accessToken) {
return new WorkspacesApi(getApiClient(accessToken));
}

/**
 * Builds a Rawls method-configs API handle on top of the client returned by {@code
 * getApiClient(accessToken)}.
 *
 * @param accessToken - token handed to {@code getApiClient}
 * @return a new {@code MethodconfigsApi} wrapping that client
 */
MethodconfigsApi getMethodConfigsApi(String accessToken) {
  ApiClient authorizedClient = getApiClient(accessToken);
  return new MethodconfigsApi(authorizedClient);
}

/**
 * Builds a Rawls status API handle on top of the client returned by {@code
 * getUnauthorizedApiClient()}; no access token is involved.
 *
 * @return a new {@code StatusApi} wrapping that client
 */
StatusApi getStatusApi() {
  ApiClient unauthorizedClient = getUnauthorizedApiClient();
  return new StatusApi(unauthorizedClient);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,25 +1,37 @@
package bio.terra.pipelines.dependencies.rawls;

import bio.terra.common.exception.InternalServerErrorException;
import bio.terra.pipelines.db.entities.PipelineInputDefinition;
import bio.terra.pipelines.db.entities.PipelineOutputDefinition;
import bio.terra.pipelines.dependencies.common.HealthCheck;
import bio.terra.rawls.client.ApiException;
import bio.terra.rawls.model.*;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.HashMap;
import java.util.List;
import java.util.UUID;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.stereotype.Service;

@Service
public class RawlsService implements HealthCheck {
private final RawlsClient rawlsClient;
private final RetryTemplate listenerResetRetryTemplate;
private final ObjectMapper objectMapper;

private static final List<SubmissionStatus> FINAL_RUN_STATES =
List.of(SubmissionStatus.ABORTED, SubmissionStatus.DONE);
private static final String DATA_TABLE_REFERENCE_PREFIX = "this.";

public RawlsService(RawlsClient rawlsClient, RetryTemplate listenerResetRetryTemplate) {
/**
 * Creates the service and stores its collaborators.
 *
 * @param rawlsClient - client used to obtain the Rawls API handles
 * @param listenerResetRetryTemplate - retry template passed along with each Rawls call
 * @param objectMapper - mapper used to convert method config inputs/outputs into maps
 */
@Autowired
public RawlsService(
    RawlsClient rawlsClient,
    RetryTemplate listenerResetRetryTemplate,
    ObjectMapper objectMapper) {
  this.objectMapper = objectMapper;
  this.listenerResetRetryTemplate = listenerResetRetryTemplate;
  this.rawlsClient = rawlsClient;
}

@Override
Expand Down Expand Up @@ -108,11 +120,161 @@ public Entity getDataTableEntity(
.getEntity(workspaceNamespace, workspaceName, entityType, entityName, null, null));
}

/**
 * Fetches the method configuration for a method in a workspace, running the Rawls call through
 * the retry template.
 *
 * @param accessToken - token used to build the method-configs API handle
 * @param workspaceNamespace - namespace of the workspace
 * @param workspaceName - name of the workspace
 * @param methodName - name of the method whose configuration is requested
 * @return the method configuration returned by Rawls
 */
public MethodConfiguration getCurrentMethodConfigForMethod(
    String accessToken, String workspaceNamespace, String workspaceName, String methodName) {
  // the workspace namespace is passed a second time as the third argument; presumably that is
  // the method config namespace - confirm against the Rawls API
  RawlsAction<MethodConfiguration> fetchMethodConfig =
      () ->
          rawlsClient
              .getMethodConfigsApi(accessToken)
              .getMethodConfiguration(
                  workspaceNamespace, workspaceName, workspaceNamespace, methodName);
  return executionWithRetryTemplate(listenerResetRetryTemplate, fetchMethodConfig);
}

/**
 * Sends the given method configuration to Rawls as the configuration for a method in a
 * workspace, running the call through the retry template.
 *
 * @param accessToken - token used to build the method-configs API handle
 * @param methodConfiguration - method configuration to store
 * @param workspaceNamespace - namespace of the workspace
 * @param workspaceName - name of the workspace
 * @param methodName - name of the method whose configuration is updated
 * @return the validated method configuration returned by Rawls
 */
public ValidatedMethodConfiguration setMethodConfigForMethod(
    String accessToken,
    MethodConfiguration methodConfiguration,
    String workspaceNamespace,
    String workspaceName,
    String methodName) {
  // the workspace namespace is passed a second time as the fourth argument; presumably that is
  // the method config namespace - confirm against the Rawls API
  RawlsAction<ValidatedMethodConfiguration> storeMethodConfig =
      () ->
          rawlsClient
              .getMethodConfigsApi(accessToken)
              .updateMethodConfiguration(
                  methodConfiguration,
                  workspaceNamespace,
                  workspaceName,
                  workspaceNamespace,
                  methodName);
  return executionWithRetryTemplate(listenerResetRetryTemplate, storeMethodConfig);
}

/**
 * Tells whether a submission is still running, i.e. its status is not one of the final run
 * states (ABORTED, DONE).
 *
 * @param submission - submission whose status is inspected
 * @return true if the submission status is not a final run state
 */
public static boolean submissionIsRunning(Submission submission) {
  SubmissionStatus currentStatus = submission.getStatus();
  boolean reachedFinalState = FINAL_RUN_STATES.contains(currentStatus);
  return !reachedFinalState;
}

/**
* Validates a method config against the expected version, data table entity, inputs, and outputs.
* It does this by checking that the wdl method version and root entity type match what we expect,
* that all expected inputs/outputs exist in the method config's inputs/outputs, and that the data
* table reference for each input/output is what we expect.
*
* @param methodConfiguration - method config to validate against
* @param dataTableEntityName - data table entity name to use, should be the pipeline name
* @param wdlWorkflowName - name of the wdl workflow, used to construct the full wdl variable name
* @param pipelineInputDefinitions - list of input definitions for a pipeline
* @param pipelineOutputDefinitions - list of output definitions for a pipeline
* @param wdlMethodVersion - version of wdl we should be submitting
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also validate rootEntity (or whatever it's called) that defines which data table it should read from

* @return - true if the method config matches what we expect, false otherwise
*/
public boolean validateMethodConfig(
    MethodConfiguration methodConfiguration,
    String dataTableEntityName,
    String wdlWorkflowName,
    List<PipelineInputDefinition> pipelineInputDefinitions,
    List<PipelineOutputDefinition> pipelineOutputDefinitions,
    String wdlMethodVersion) {
  // validate wdl method version; a config without a method or version cannot match, so report it
  // as not matching instead of throwing
  if (methodConfiguration.getMethodRepoMethod() == null) {
    return false;
  }
  String configuredVersion = methodConfiguration.getMethodRepoMethod().getMethodVersion();
  if (configuredVersion == null || !configuredVersion.equals(wdlMethodVersion)) {
    return false;
  }

  // validate data table entity name; a config with no root entity type set does not match
  String configuredRootEntityType = methodConfiguration.getRootEntityType();
  if (configuredRootEntityType == null
      || !configuredRootEntityType.equals(dataTableEntityName)) {
    return false;
  }

  // validate inputs
  HashMap<?, ?> methodConfigInputs =
      objectMapper.convertValue(methodConfiguration.getInputs(), HashMap.class);
  for (PipelineInputDefinition pipelineInputDefinition : pipelineInputDefinitions) {
    if (!hasExpectedDataTableReference(
        methodConfigInputs, wdlWorkflowName, pipelineInputDefinition.getWdlVariableName())) {
      return false;
    }
  }

  // validate outputs
  HashMap<?, ?> methodConfigOutputs =
      objectMapper.convertValue(methodConfiguration.getOutputs(), HashMap.class);
  for (PipelineOutputDefinition pipelineOutputDefinition : pipelineOutputDefinitions) {
    if (!hasExpectedDataTableReference(
        methodConfigOutputs, wdlWorkflowName, pipelineOutputDefinition.getWdlVariableName())) {
      return false;
    }
  }

  // all validation checks passed, return true
  return true;
}

/**
 * Checks that the configured inputs or outputs map the full wdl variable name
 * (workflowName.variableName) to the expected data table reference ("this." + variableName).
 *
 * @param configuredValues - inputs or outputs of the method config as a map, may be null
 * @param wdlWorkflowName - name of the wdl workflow, used to construct the full wdl variable name
 * @param wdlVariableName - wdl variable name of the expected input/output
 * @return - true only if the key is present and its value equals the expected reference; a
 *     missing map, missing key, or null value yields false
 */
private static boolean hasExpectedDataTableReference(
    HashMap<?, ?> configuredValues, String wdlWorkflowName, String wdlVariableName) {
  String fullWdlVariableName = wdlWorkflowName + "." + wdlVariableName;
  String fullDataTableReference = DATA_TABLE_REFERENCE_PREFIX + wdlVariableName;
  // comparing from the expected side keeps a null or absent configured value from throwing
  return configuredValues != null
      && fullDataTableReference.equals(configuredValues.get(fullWdlVariableName));
}

/**
 * Takes a pre-existing method configuration and updates it in place to match what the service
 * expects. The inputs and outputs are replaced with maps built from the pipeline definitions:
 * each key is the full wdl variable name (workflowName.variableName) and each value is the data
 * table reference ("this." + variableName).
 *
 * <p>It also sets the root entity type to the data table entity name, sets the wdl method version
 * to the version we should be running, and updates the methodRepoMethod method uri by swapping
 * the last occurrence of the old version string with the expected version string. Only the last
 * occurrence is swapped so that the same text appearing earlier in the uri is left alone. If the
 * old uri or old version is missing or empty, the uri is left unchanged.
 *
 * @param methodConfiguration - method configuration to update (mutated and returned)
 * @param dataTableEntityName - data table entity name to use, should be the pipeline name
 * @param wdlWorkflowName - name of the wdl workflow, used to construct the full wdl variable name
 * @param pipelineInputDefinitions - list of input definitions for a pipeline
 * @param pipelineOutputDefinitions - list of output definitions for a pipeline
 * @param wdlMethodVersion - version of wdl we should be submitting
 * @return - the same method configuration with its methodRepoMethod's methodVersion and
 *     methodUri, its root entity type, and its inputs and outputs updated
 */
public MethodConfiguration updateMethodConfigToBeValid(
    MethodConfiguration methodConfiguration,
    String dataTableEntityName,
    String wdlWorkflowName,
    List<PipelineInputDefinition> pipelineInputDefinitions,
    List<PipelineOutputDefinition> pipelineOutputDefinitions,
    String wdlMethodVersion) {
  // set wdl method version and uri to match the expected version
  String oldVersion = methodConfiguration.getMethodRepoMethod().getMethodVersion();
  String oldUri = methodConfiguration.getMethodRepoMethod().getMethodUri();
  methodConfiguration.getMethodRepoMethod().setMethodVersion(wdlMethodVersion);
  // an empty old version would match between every character of the uri, and a null one cannot
  // be searched for, so only rewrite the uri when there is a real version string to swap
  if (oldUri != null && oldVersion != null && !oldVersion.isEmpty()) {
    methodConfiguration
        .getMethodRepoMethod()
        .setMethodUri(replaceLastOccurrence(oldUri, oldVersion, wdlMethodVersion));
  }

  // set data table entity
  methodConfiguration.setRootEntityType(dataTableEntityName);

  // set inputs
  HashMap<String, String> expectedInputs = new HashMap<>();
  for (PipelineInputDefinition pipelineInputDefinition : pipelineInputDefinitions) {
    String fullWdlVariableName =
        wdlWorkflowName + "." + pipelineInputDefinition.getWdlVariableName();
    String fullDataTableReference =
        DATA_TABLE_REFERENCE_PREFIX + pipelineInputDefinition.getWdlVariableName();
    expectedInputs.put(fullWdlVariableName, fullDataTableReference);
  }
  methodConfiguration.setInputs(expectedInputs);

  // set outputs
  HashMap<String, String> expectedOutputs = new HashMap<>();
  for (PipelineOutputDefinition pipelineOutputDefinition : pipelineOutputDefinitions) {
    String fullWdlVariableName =
        wdlWorkflowName + "." + pipelineOutputDefinition.getWdlVariableName();
    String fullDataTableReference =
        DATA_TABLE_REFERENCE_PREFIX + pipelineOutputDefinition.getWdlVariableName();
    expectedOutputs.put(fullWdlVariableName, fullDataTableReference);
  }
  methodConfiguration.setOutputs(expectedOutputs);
  return methodConfiguration;
}

/**
 * Replaces only the last occurrence of {@code target} in {@code text}.
 *
 * @param text - string to search, must not be null
 * @param target - non-empty substring to look for
 * @param replacement - string to put in place of the last occurrence
 * @return - text with the last occurrence replaced, or text unchanged if target does not occur
 */
private static String replaceLastOccurrence(String text, String target, String replacement) {
  int start = text.lastIndexOf(target);
  if (start < 0) {
    return text;
  }
  return new StringBuilder(text).replace(start, start + target.length(), replacement).toString();
}

/**
 * A single Rawls client call that produces a result and may throw the client's {@link
 * ApiException}. Implemented with lambdas by the methods of this service.
 *
 * @param <T> type of the value returned by the call
 */
@FunctionalInterface
interface RawlsAction<T> {
  /**
   * Performs the call.
   *
   * @return the value produced by the call
   * @throws ApiException if the Rawls client reports a failure
   */
  T execute() throws ApiException;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,8 @@ public PipelineRun startPipelineRun(
pipelineRun.getWorkspaceStorageContainerName())
.addParameter(
RunImputationJobFlightMapKeys.WDL_METHOD_NAME, pipeline.getWdlMethodName())
.addParameter(
RunImputationJobFlightMapKeys.WDL_METHOD_VERSION, pipeline.getWdlMethodVersion())
.addParameter(JobMapKeys.RESULT_PATH.getKeyName(), resultPath);

jobBuilder.submit();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ public RunImputationGcpJobFlight(FlightMap inputParameters, Object beanBag) {
RunImputationJobFlightMapKeys.CONTROL_WORKSPACE_NAME,
RunImputationJobFlightMapKeys.CONTROL_WORKSPACE_STORAGE_CONTAINER_NAME,
RunImputationJobFlightMapKeys.WDL_METHOD_NAME,
RunImputationJobFlightMapKeys.WDL_METHOD_VERSION,
JobMapKeys.RESULT_PATH.getKeyName());

PipelinesEnum pipelinesEnum =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ public abstract class RunImputationJobFlightMapKeys {
public static final String CONTROL_WORKSPACE_STORAGE_CONTAINER_NAME =
"control_workspace_storage_container_name";
public static final String WDL_METHOD_NAME = "wdl_method_name";
public static final String WDL_METHOD_VERSION = "wdl_method_version";
public static final String PIPELINE_RUN_OUTPUTS = "pipeline_run_outputs";

// GCP specific keys
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,21 @@
import bio.terra.pipelines.app.configuration.internal.ImputationConfiguration;
import bio.terra.pipelines.common.utils.FlightUtils;
import bio.terra.pipelines.common.utils.PipelinesEnum;
import bio.terra.pipelines.db.entities.PipelineInputDefinition;
import bio.terra.pipelines.db.entities.PipelineOutputDefinition;
import bio.terra.pipelines.dependencies.rawls.RawlsService;
import bio.terra.pipelines.dependencies.rawls.RawlsServiceApiException;
import bio.terra.pipelines.dependencies.sam.SamService;
import bio.terra.pipelines.dependencies.stairway.JobMapKeys;
import bio.terra.pipelines.stairway.imputation.RunImputationJobFlightMapKeys;
import bio.terra.rawls.model.MethodConfiguration;
import bio.terra.rawls.model.SubmissionReport;
import bio.terra.rawls.model.SubmissionRequest;
import bio.terra.stairway.*;
import com.fasterxml.jackson.core.type.TypeReference;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* This step submits a submission to cromwell using the rawls submission endpoint. It uses
Expand All @@ -25,6 +32,8 @@ public class SubmitCromwellSubmissionStep implements Step {
private final RawlsService rawlsService;
private final ImputationConfiguration imputationConfiguration;

private final Logger logger = LoggerFactory.getLogger(SubmitCromwellSubmissionStep.class);

public SubmitCromwellSubmissionStep(
RawlsService rawlsService,
SamService samService,
Expand All @@ -46,8 +55,11 @@ public StepResult doStep(FlightContext flightContext) {
JobMapKeys.PIPELINE_NAME.getKeyName(),
JobMapKeys.DESCRIPTION.getKeyName(),
RunImputationJobFlightMapKeys.WDL_METHOD_NAME,
RunImputationJobFlightMapKeys.WDL_METHOD_VERSION,
RunImputationJobFlightMapKeys.CONTROL_WORKSPACE_BILLING_PROJECT,
RunImputationJobFlightMapKeys.CONTROL_WORKSPACE_NAME);
RunImputationJobFlightMapKeys.CONTROL_WORKSPACE_NAME,
RunImputationJobFlightMapKeys.PIPELINE_INPUT_DEFINITIONS,
RunImputationJobFlightMapKeys.PIPELINE_OUTPUT_DEFINITIONS);

String controlWorkspaceName =
inputParameters.get(RunImputationJobFlightMapKeys.CONTROL_WORKSPACE_NAME, String.class);
Expand All @@ -59,10 +71,73 @@ public StepResult doStep(FlightContext flightContext) {
String description = inputParameters.get(JobMapKeys.DESCRIPTION.getKeyName(), String.class);
String wdlMethodName =
inputParameters.get(RunImputationJobFlightMapKeys.WDL_METHOD_NAME, String.class);
String wdlMethodVersion =
inputParameters.get(RunImputationJobFlightMapKeys.WDL_METHOD_VERSION, String.class);

List<PipelineInputDefinition> inputDefinitions =
inputParameters.get(
RunImputationJobFlightMapKeys.PIPELINE_INPUT_DEFINITIONS, new TypeReference<>() {});
List<PipelineOutputDefinition> outputDefinitions =
inputParameters.get(
RunImputationJobFlightMapKeys.PIPELINE_OUTPUT_DEFINITIONS, new TypeReference<>() {});

// validate and extract parameters from working map
FlightMap workingMap = flightContext.getWorkingMap();

MethodConfiguration methodConfiguration;
try {
// grab current method config and validate it
methodConfiguration =
rawlsService.getCurrentMethodConfigForMethod(
samService.getTeaspoonsServiceAccountToken(),
controlWorkspaceProject,
controlWorkspaceName,
wdlMethodName);
} catch (RawlsServiceApiException e) {
// if we fail to grab the method config then retry
return new StepResult(StepStatus.STEP_RESULT_FAILURE_RETRY, e);
}
boolean validMethodConfig =
rawlsService.validateMethodConfig(
methodConfiguration,
pipelineName.getValue(),
wdlMethodName,
inputDefinitions,
outputDefinitions,
wdlMethodVersion);

// if the method config does not match what we expect, set it to what we think it should be. This
// should be rare, e.g. the first run after the expected wdl method version changes
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't happen often, but it will happen once every time we change the wdlMethodVersion... unless we update it in Terra as well as in our service, which actually we could do. 🤷 it's nice to have the warning log here. wonder if we should set up a slack channel to capture warnings.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is that something we can do -- is that a sentry thingy?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah probably sentry

if (!validMethodConfig) {
logger.warn(
"found method config that was not valid for billing project: {}, workspace: {}, method name: {}, methodConfigVersion: {}",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: it's not necessarily "not valid" but it's not matching our expectations. super minor but to me "not valid" suggests there's an actual error in the setup

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but to us there is an error in the setup no?

controlWorkspaceProject,
controlWorkspaceName,
wdlMethodName,
methodConfiguration.getMethodConfigVersion());

MethodConfiguration updatedMethodConfiguration =
rawlsService.updateMethodConfigToBeValid(
methodConfiguration,
pipelineName.getValue(),
wdlMethodName,
inputDefinitions,
outputDefinitions,
wdlMethodVersion);
try {
// update method config version, inputs, and outputs
rawlsService.setMethodConfigForMethod(
samService.getTeaspoonsServiceAccountToken(),
updatedMethodConfiguration,
controlWorkspaceProject,
controlWorkspaceName,
wdlMethodName);
} catch (RawlsServiceApiException e) {
// if we fail to update the method config then retry
return new StepResult(StepStatus.STEP_RESULT_FAILURE_RETRY, e);
}
}

// create submission request
SubmissionRequest submissionRequest =
new SubmissionRequest()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@
import static org.junit.jupiter.api.Assertions.assertTrue;

import bio.terra.pipelines.testutils.BaseEmbeddedDbTest;
import bio.terra.rawls.api.EntitiesApi;
import bio.terra.rawls.api.StatusApi;
import bio.terra.rawls.api.SubmissionsApi;
import bio.terra.rawls.api.WorkspacesApi;
import bio.terra.rawls.api.*;
import bio.terra.rawls.client.ApiClient;
import bio.terra.rawls.client.auth.Authentication;
import bio.terra.rawls.client.auth.OAuth;
Expand Down Expand Up @@ -54,5 +51,9 @@ void testWorkspaceClientApis() {
WorkspacesApi workspacesApi = rawlsClient.getWorkspacesApi(authToken);
assertEquals(rawlsBaseUri, workspacesApi.getApiClient().getBasePath());
assertTrue(workspacesApi.getApiClient().isDebugging());

MethodconfigsApi methodconfigsApi = rawlsClient.getMethodConfigsApi(authToken);
assertEquals(rawlsBaseUri, methodconfigsApi.getApiClient().getBasePath());
assertTrue(methodconfigsApi.getApiClient().isDebugging());
}
}
Loading
Loading