Skip to content

Commit

Permalink
fix ToolIntegrationWithLLMTest model undeploy race condition
Browse files Browse the repository at this point in the history
Previously the test class attempted to delete a model without fully knowing if the model was undeployed in time. This change adds a waiting for 5 retries each 1 second to check the status of the model and only when undeployed will it proceed to delete the model. When the number of retries are exceeded it throws a error indicating a deeper problem. Manual testing was done to check that the model is undeployed by searching for the specific model via the checkForModelUndeployedStatus method.

Signed-off-by: Brian Flores <iflorbri@amazon.com>
  • Loading branch information
brianf-aws committed Dec 5, 2024
1 parent f888c27 commit baa09b1
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ public void testPredictRemoteModelFailedWithModelGuardrail() throws IOException,
Map responseMap = parseResponseToMap(response);
String guardrailConnectorId = (String) responseMap.get("connector_id");

//Create the model ID
// Create the model ID
response = registerRemoteModel("guardrail model group", "openAI-GPT-3.5 completions", guardrailConnectorId);
responseMap = parseResponseToMap(response);
String guardrailModelId = (String) responseMap.get("model_id");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.function.Predicate;
Expand All @@ -32,7 +31,7 @@
@Log4j2
public abstract class ToolIntegrationWithLLMTest extends RestBaseAgentToolsIT {

private static final int MAX_TASK_RESULT_QUERY_TIME_IN_SECOND = 30;
private static final int MAX_RETRIES = 5;
private static final int DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND = 1000;

protected HttpServer server;
Expand Down Expand Up @@ -72,16 +71,17 @@ public void stopMockLLM() {
@After
public void deleteModel() throws IOException {
undeployModel(modelId);
checkForModelUndeployedStatus(modelId);
deleteModel(client(), modelId, null);
}

@SneakyThrows
private void waitModelUndeployed(String modelId) {
private void checkForModelUndeployedStatus(String modelId) {
Predicate<Response> condition = response -> {
try {
Map<String, Object> responseInMap = parseResponseToMap(response);
MLModelState state = MLModelState.from(responseInMap.get(MLModel.MODEL_STATE_FIELD).toString());
return Set.of(MLModelState.UNDEPLOYED, MLModelState.DEPLOY_FAILED).contains(state);
return MLModelState.UNDEPLOYED.equals(state);
} catch (Exception e) {
return false;
}
Expand All @@ -91,16 +91,25 @@ private void waitModelUndeployed(String modelId) {

@SneakyThrows
protected Response waitResponseMeetingCondition(String method, String endpoint, String jsonEntity, Predicate<Response> condition) {
for (int i = 0; i < MAX_TASK_RESULT_QUERY_TIME_IN_SECOND; i++) {
for (int attempt = 1; attempt <= MAX_RETRIES; attempt++) {
Response response = TestHelper.makeRequest(client(), method, endpoint, null, jsonEntity, null);
assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
if (condition.test(response)) {
return response;
}
logger.info("The {}-th response: {}", i, response.toString());
logger.info("The {}-th attempt on {}:{} . response: {}", attempt, method, endpoint, response.toString());
Thread.sleep(DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND);
}
fail("The response failed to meet condition after " + MAX_TASK_RESULT_QUERY_TIME_IN_SECOND + " seconds.");
fail(
String
.format(
Locale.ROOT,
"The response failed to meet condition after %d attempts. Attempted to perform %s : %s",
MAX_RETRIES,
method,
endpoint
)
);
return null;
}

Expand Down

0 comments on commit baa09b1

Please sign in to comment.