diff --git a/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/PersonalizeRankingResponseProcessor.java b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/PersonalizeRankingResponseProcessor.java index 3bbc9d6..670915c 100644 --- a/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/PersonalizeRankingResponseProcessor.java +++ b/src/main/java/org/opensearch/search/relevance/transformer/personalizeintelligentranking/PersonalizeRankingResponseProcessor.java @@ -154,7 +154,7 @@ public Factory(PersonalizeClientSettings settings) { } @Override - public PersonalizeRankingResponseProcessor create(Map> processorFactories, String tag, String description, boolean ignoreFailure, Map config, PipelineContext pipelineContext) throws Exception { + public PersonalizeRankingResponseProcessor create(Map> processorFactories, String tag, String description, boolean ignoreFailure, Map config, PipelineContext pipelineContext) { String personalizeCampaign = ConfigurationUtils.readStringProperty(TYPE, tag, config, CAMPAIGN_ARN_CONFIG_NAME); String iamRoleArn = ConfigurationUtils.readOptionalStringProperty(TYPE, tag, config, IAM_ROLE_ARN_CONFIG_NAME); String recipe = ConfigurationUtils.readStringProperty(TYPE, tag, config, RECIPE_CONFIG_NAME); @@ -165,9 +165,25 @@ public PersonalizeRankingResponseProcessor create(Map configuration = new HashMap<>(); - configuration.put("item_id_field", itemIdField); + configuration.put("item_id_field", ITEM_ID_FIELD); configuration.put("recipe", AMAZON_PERSONALIZED_RANKING_RECIPE_NAME); - configuration.put("weight", String.valueOf(weight)); - configuration.put("iam_role_arn", iamRoleArn); - configuration.put("aws_region", region); + configuration.put("weight", String.valueOf(WEIGHT)); + configuration.put("iam_role_arn", IAM_ROLE_ARN); + configuration.put("aws_region", REGION); expectThrows(OpenSearchParseException.class, () -> factory.create( Collections.emptyMap(), @@ -82,16 +86,16 @@ public void testFactory() { null, false, configuration, - null + VALIDATE_CONTEXT )); configuration.clear(); // Test config without recipe - configuration.put("campaign_arn", personalizeCampaign); - configuration.put("item_id_field", itemIdField); - configuration.put("weight", String.valueOf(weight)); - configuration.put("iam_role_arn", iamRoleArn); - configuration.put("aws_region", region); + configuration.put("campaign_arn", PERSONALIZE_CAMPAIGN); + configuration.put("item_id_field", ITEM_ID_FIELD); + configuration.put("weight", String.valueOf(WEIGHT)); + configuration.put("iam_role_arn", IAM_ROLE_ARN); + configuration.put("aws_region", REGION); expectThrows(OpenSearchParseException.class, () -> factory.create( Collections.emptyMap(), @@ -99,59 +103,79 @@ public void testFactory() { null, false, configuration, - null + VALIDATE_CONTEXT )); configuration.clear(); // Test config without region - configuration.put("campaign_arn", personalizeCampaign); - configuration.put("item_id_field", itemIdField); + configuration.put("campaign_arn", PERSONALIZE_CAMPAIGN); + configuration.put("item_id_field", ITEM_ID_FIELD); configuration.put("recipe", AMAZON_PERSONALIZED_RANKING_RECIPE_NAME); - configuration.put("weight", String.valueOf(weight)); - configuration.put("iam_role_arn", iamRoleArn); + configuration.put("weight", String.valueOf(WEIGHT)); + configuration.put("iam_role_arn", IAM_ROLE_ARN); expectThrows(OpenSearchParseException.class, () -> factory.create( Collections.emptyMap(), null, null, false, configuration, - null + VALIDATE_CONTEXT )); configuration.clear(); // Test config without weight - configuration.put("campaign_arn", personalizeCampaign); - configuration.put("item_id_field", itemIdField); + configuration.put("campaign_arn", PERSONALIZE_CAMPAIGN); + configuration.put("item_id_field", ITEM_ID_FIELD); configuration.put("recipe", AMAZON_PERSONALIZED_RANKING_RECIPE_NAME); - configuration.put("iam_role_arn", iamRoleArn); - configuration.put("aws_region", region); + configuration.put("iam_role_arn", IAM_ROLE_ARN); + configuration.put("aws_region", REGION); expectThrows(OpenSearchParseException.class, () -> factory.create( Collections.emptyMap(), null, null, false, configuration, - null + VALIDATE_CONTEXT )); configuration.clear(); // Test configuration with invalid weight value - configuration.put("campaign_arn", personalizeCampaign); - configuration.put("item_id_field", itemIdField); + configuration.put("campaign_arn", PERSONALIZE_CAMPAIGN); + configuration.put("item_id_field", ITEM_ID_FIELD); configuration.put("recipe", AMAZON_PERSONALIZED_RANKING_RECIPE_NAME); configuration.put("weight", "invalid"); - configuration.put("iam_role_arn", iamRoleArn); - configuration.put("aws_region", region); + configuration.put("iam_role_arn", IAM_ROLE_ARN); + configuration.put("aws_region", REGION); expectThrows(OpenSearchParseException.class, () -> factory.create( Collections.emptyMap(), null, null, false, configuration, - null + VALIDATE_CONTEXT )); configuration.clear(); - IdleConnectionReaper.shutdown(); + + configuration.put("campaign_arn", PERSONALIZE_CAMPAIGN); + configuration.put("item_id_field", ITEM_ID_FIELD); + configuration.put("recipe", AMAZON_PERSONALIZED_RANKING_RECIPE_NAME); + configuration.put("weight", String.valueOf(WEIGHT)); + configuration.put("iam_role_arn", IAM_ROLE_ARN); + configuration.put("aws_region", REGION); + + // Test that we don't create client on validation + configuration.putAll(buildPersonalizeResponseProcessorConfig()); + PersonalizeRankingResponseProcessor processor = factory.create(Collections.emptyMap(), null, null, false, configuration, VALIDATE_CONTEXT); + assertNull(processor.getPersonalizeClient()); + + // Test that we fail on valid configuration in search request context + expectThrows(IllegalStateException.class, () -> factory.create( + Collections.emptyMap(), + null, + null, + false, + buildPersonalizeResponseProcessorConfig(), + new Processor.PipelineContext(Processor.PipelineSource.SEARCH_REQUEST))); } public void testCreateFactoryWithAllPersonalizeConfig() throws Exception { @@ -161,7 +185,7 @@ public void testCreateFactoryWithAllPersonalizeConfig() throws Exception { Map configuration = buildPersonalizeResponseProcessorConfig(); PersonalizeRankingResponseProcessor personalizeResponseProcessor = - factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration, null); + factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration, UPDATE_CONTEXT); assertEquals(TYPE, personalizeResponseProcessor.getType()); assertEquals("testTag", personalizeResponseProcessor.getTag()); @@ -177,7 +201,7 @@ public void testProcessorWithNoHits() throws Exception { Map configuration = buildPersonalizeResponseProcessorConfig(); PersonalizeRankingResponseProcessor personalizeResponseProcessor = - factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration, null); + factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration, UPDATE_CONTEXT); SearchRequest searchRequest = new SearchRequest(); SearchHits hits = new SearchHits(new SearchHit[0], new TotalHits(0, TotalHits.Relation.EQUAL_TO), 0.0f); SearchResponseSections searchResponseSections = new SearchResponseSections(hits, null, null, false, false, null, 0); @@ -196,22 +220,22 @@ public void testProcessorWithPersonalizeContext() throws Exception { Map configuration = buildPersonalizeResponseProcessorConfig(); PersonalizeRankingResponseProcessor personalizeResponseProcessor = - factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration, null); + factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration, UPDATE_CONTEXT); Map personalizeContext = new HashMap<>(); personalizeContext.put("contextKey2", "contextValue2"); SearchResponse personalizedResponse = - getPersonalizedRankingProcessorResponse(personalizeResponseProcessor, personalizeContext, numHits); + createPersonalizedRankingProcessorResponse(personalizeResponseProcessor, personalizeContext, NUM_HITS); List transformedHits = Arrays.asList(personalizedResponse.getHits().getHits()); List rerankedDocumentIds; rerankedDocumentIds = transformedHits.stream() - .filter(h -> h.getSourceAsMap().get(itemIdField) != null) - .map(h -> h.getSourceAsMap().get(itemIdField).toString()) + .filter(h -> h.getSourceAsMap().get(ITEM_ID_FIELD) != null) + .map(h -> h.getSourceAsMap().get(ITEM_ID_FIELD).toString()) .collect(Collectors.toList()); - ArrayList expectedRankedDocumentIds = PersonalizeRuntimeTestUtil.expectedRankedItemIdsForGivenWeight(numHits, 1); + ArrayList expectedRankedDocumentIds = PersonalizeRuntimeTestUtil.expectedRankedItemIdsForGivenWeight(NUM_HITS, 1); assertEquals(expectedRankedDocumentIds, rerankedDocumentIds); IdleConnectionReaper.shutdown(); } @@ -224,13 +248,13 @@ public void testProcessorWithHitsWithInvalidPersonalizeContext() throws Exceptio Map configuration = buildPersonalizeResponseProcessorConfig(); PersonalizeRankingResponseProcessor personalizeResponseProcessor = - factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration, null); + factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration, UPDATE_CONTEXT); Map personalizeContext = new HashMap<>(); personalizeContext.put("contextKey2", 5); expectThrows(OpenSearchParseException.class, () -> - getPersonalizedRankingProcessorResponse(personalizeResponseProcessor, personalizeContext, numHits)); + createPersonalizedRankingProcessorResponse(personalizeResponseProcessor, personalizeContext, NUM_HITS)); IdleConnectionReaper.shutdown(); } @@ -244,9 +268,9 @@ public void testPersonalizeRankingResponse() throws Exception { Map configuration = buildPersonalizeResponseProcessorConfig(); PersonalizeRankingResponseProcessor responseProcessor = - factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration, null); + factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration, UPDATE_CONTEXT); - SearchResponse personalizedResponse = getPersonalizedRankingProcessorResponse(responseProcessor, null, numHits); + SearchResponse personalizedResponse = createPersonalizedRankingProcessorResponse(responseProcessor, null, NUM_HITS); List transformedHits = Arrays.asList(personalizedResponse.getHits().getHits()); List rerankedDocumentIds; @@ -255,7 +279,7 @@ public void testPersonalizeRankingResponse() throws Exception { .map(h -> h.getSourceAsMap().get(itemField).toString()) .collect(Collectors.toList()); - ArrayList expectedRankedDocumentIds = PersonalizeRuntimeTestUtil.expectedRankedItemIdsForGivenWeight(numHits, 1); + ArrayList expectedRankedDocumentIds = PersonalizeRuntimeTestUtil.expectedRankedItemIdsForGivenWeight(NUM_HITS, 1); assertEquals(expectedRankedDocumentIds, rerankedDocumentIds); IdleConnectionReaper.shutdown(); } @@ -271,10 +295,10 @@ public void testPersonalizeRankingResponseWithInvalidItemIdFieldName() throws Ex configuration.put("item_id_field", itemFieldInvalid); PersonalizeRankingResponseProcessor responseProcessor = - factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration, null); + factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration, UPDATE_CONTEXT); expectThrows(OpenSearchParseException.class, () -> - getPersonalizedRankingProcessorResponse(responseProcessor, null, numHits)); + createPersonalizedRankingProcessorResponse(responseProcessor, null, NUM_HITS)); IdleConnectionReaper.shutdown(); } @@ -289,26 +313,26 @@ public void testPersonalizeRankingResponseWithDefaultItemIdField() throws Except configuration.put("item_id_field", itemIdFieldEmpty); PersonalizeRankingResponseProcessor responseProcessor = - factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration, null); + factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration, UPDATE_CONTEXT); - SearchResponse personalizedResponse = getPersonalizedRankingProcessorResponse(responseProcessor, null, numHits); + SearchResponse personalizedResponse = createPersonalizedRankingProcessorResponse(responseProcessor, null, NUM_HITS); List transformedHits = Arrays.asList(personalizedResponse.getHits().getHits()); List rerankedDocumentIds; rerankedDocumentIds = transformedHits.stream() - .filter(h -> h.getId() != null) - .map(h -> h.getId()) + .map(SearchHit::getId) + .filter(Objects::nonNull) .collect(Collectors.toList()); - ArrayList expectedRankedDocumentIds = PersonalizeRuntimeTestUtil.expectedRankedItemIdsForGivenWeight(numHits, 1); + ArrayList expectedRankedDocumentIds = PersonalizeRuntimeTestUtil.expectedRankedItemIdsForGivenWeight(NUM_HITS, 1); assertEquals(expectedRankedDocumentIds, rerankedDocumentIds); IdleConnectionReaper.shutdown(); } - private SearchResponse getPersonalizedRankingProcessorResponse(PersonalizeRankingResponseProcessor responseProcessor, - Map personalizeContext, - int numHits) throws Exception { + private SearchResponse createPersonalizedRankingProcessorResponse(PersonalizeRankingResponseProcessor responseProcessor, + Map personalizeContext, + int numHits) throws Exception { PersonalizeRequestParameters personalizeRequestParams = new PersonalizeRequestParameters("user_1", personalizeContext); SearchRequest request = SearchTestUtil.createSearchRequestWithPersonalizeRequest(personalizeRequestParams); @@ -324,12 +348,12 @@ private SearchResponse getPersonalizedRankingProcessorResponse(PersonalizeRankin private Map buildPersonalizeResponseProcessorConfig() { Map configuration = new HashMap<>(); - configuration.put("campaign_arn", personalizeCampaign); - configuration.put("item_id_field", itemIdField); + configuration.put("campaign_arn", PERSONALIZE_CAMPAIGN); + configuration.put("item_id_field", ITEM_ID_FIELD); configuration.put("recipe", AMAZON_PERSONALIZED_RANKING_RECIPE_NAME); - configuration.put("weight", String.valueOf(weight)); - configuration.put("iam_role_arn", iamRoleArn); - configuration.put("aws_region", region); + configuration.put("weight", String.valueOf(WEIGHT)); + configuration.put("iam_role_arn", IAM_ROLE_ARN); + configuration.put("aws_region", REGION); return configuration; } }