From 59d15c95be24099a18abb3c97c9a307e04e5f336 Mon Sep 17 00:00:00 2001 From: krishy91 Date: Tue, 16 Jan 2024 18:43:21 +0100 Subject: [PATCH] Add unit tests + small fixes Signed-off-by: krishy91 --- .../processor/InferenceProcessor.java | 90 ++++++++++--------- .../TextEmbeddingProcessorTests.java | 33 +++++++ 2 files changed, 80 insertions(+), 43 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java index a8687d8ac..a584b42e3 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java @@ -4,7 +4,11 @@ */ package org.opensearch.neuralsearch.processor; -import java.util.*; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; import java.util.function.BiConsumer; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -48,14 +52,14 @@ public abstract class InferenceProcessor extends AbstractProcessor { private final Environment environment; public InferenceProcessor( - String tag, - String description, - String type, - String listTypeNestedMapKey, - String modelId, - Map fieldMap, - MLCommonsClientAccessor clientAccessor, - Environment environment + String tag, + String description, + String type, + String listTypeNestedMapKey, + String modelId, + Map fieldMap, + MLCommonsClientAccessor clientAccessor, + Environment environment ) { super(tag, description); this.type = type; @@ -71,21 +75,21 @@ public InferenceProcessor( private void validateEmbeddingConfiguration(Map fieldMap) { if (fieldMap == null - || fieldMap.size() == 0 - || fieldMap.entrySet() + || fieldMap.size() == 0 + || fieldMap.entrySet() .stream() .anyMatch( - x -> StringUtils.isBlank(x.getKey()) || Objects.isNull(x.getValue()) || StringUtils.isBlank(x.getValue().toString()) + x -> StringUtils.isBlank(x.getKey()) || Objects.isNull(x.getValue()) || StringUtils.isBlank(x.getValue().toString()) )) { throw new IllegalArgumentException("Unable to create the processor as field_map has invalid key or value"); } } public abstract void doExecute( - IngestDocument ingestDocument, - Map ProcessMap, - List inferenceList, - BiConsumer handler + IngestDocument ingestDocument, + Map ProcessMap, + List inferenceList, + BiConsumer handler ); @Override @@ -162,10 +166,10 @@ Map buildMapWithProcessorKeyAndOriginalValue(IngestDocument inge } private void buildMapWithProcessorKeyAndOriginalValueForMapType( - String parentKey, - Object processorKey, - Map sourceAndMetadataMap, - Map treeRes + String parentKey, + Object processorKey, + Map sourceAndMetadataMap, + Map treeRes ) { if (processorKey == null || sourceAndMetadataMap == null) return; if (processorKey instanceof Map) { @@ -173,23 +177,23 @@ private void buildMapWithProcessorKeyAndOriginalValueForMapType( if (sourceAndMetadataMap.get(parentKey) instanceof Map) { for (Map.Entry nestedFieldMapEntry : ((Map) processorKey).entrySet()) { buildMapWithProcessorKeyAndOriginalValueForMapType( - nestedFieldMapEntry.getKey(), - nestedFieldMapEntry.getValue(), - (Map) sourceAndMetadataMap.get(parentKey), - next + nestedFieldMapEntry.getKey(), + nestedFieldMapEntry.getValue(), + (Map) sourceAndMetadataMap.get(parentKey), + next ); } } else if (sourceAndMetadataMap.get(parentKey) instanceof List) { for (Map.Entry nestedFieldMapEntry : ((Map) processorKey).entrySet()) { List> list = (List>) sourceAndMetadataMap.get(parentKey); List listOfStrings = list.stream().map(x -> x.get(nestedFieldMapEntry.getKey())).collect(Collectors.toList()); - Map map = new HashMap(); + Map map = new LinkedHashMap<>(); map.put(nestedFieldMapEntry.getKey(), listOfStrings); buildMapWithProcessorKeyAndOriginalValueForMapType( - nestedFieldMapEntry.getKey(), - nestedFieldMapEntry.getValue(), - map, - next + nestedFieldMapEntry.getKey(), + nestedFieldMapEntry.getValue(), + map, + next ); } } @@ -234,9 +238,9 @@ private void validateNestedTypeValue(String sourceKey, Object sourceValue, Suppl validateListTypeValue(sourceKey, sourceValue, maxDepthSupplier); } else if (Map.class.isAssignableFrom(sourceValue.getClass())) { ((Map) sourceValue).values() - .stream() - .filter(Objects::nonNull) - .forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1)); + .stream() + .filter(Objects::nonNull) + .forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1)); } else if (!String.class.isAssignableFrom(sourceValue.getClass())) { throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, cannot process it"); } else if (StringUtils.isBlank(sourceValue.toString())) { @@ -287,11 +291,11 @@ Map buildNLPResult(Map processorMap, List res @SuppressWarnings({ "unchecked" }) private void putNLPResultToSourceMapForMapType( - String processorKey, - Object sourceValue, - List results, - IndexWrapper indexWrapper, - Map sourceAndMetadataMap + String processorKey, + Object sourceValue, + List results, + IndexWrapper indexWrapper, + Map sourceAndMetadataMap ) { if (processorKey == null || sourceAndMetadataMap == null || sourceValue == null) return; if (sourceValue instanceof Map) { @@ -303,11 +307,11 @@ private void putNLPResultToSourceMapForMapType( } } else { putNLPResultToSourceMapForMapType( - inputNestedMapEntry.getKey(), - inputNestedMapEntry.getValue(), - results, - indexWrapper, - (Map) sourceAndMetadataMap.get(processorKey) + inputNestedMapEntry.getKey(), + inputNestedMapEntry.getValue(), + results, + indexWrapper, + (Map) sourceAndMetadataMap.get(processorKey) ); } } @@ -321,7 +325,7 @@ private void putNLPResultToSourceMapForMapType( private List> buildNLPResultForListType(List sourceValue, List results, IndexWrapper indexWrapper) { List> keyToResult = new ArrayList<>(); IntStream.range(0, sourceValue.size()) - .forEachOrdered(x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++)))); + .forEachOrdered(x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++)))); return keyToResult; } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java index 8c2f1c1be..6c0cd6dcb 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java @@ -20,6 +20,7 @@ import java.util.LinkedList; import java.util.List; import java.util.Map; +import java.util.Arrays; import java.util.function.BiConsumer; import java.util.function.Supplier; @@ -404,6 +405,20 @@ public void testBuildVectorOutput_withNestedMap_successful() { assertNotNull(actionGamesKnn); } + public void testBuildVectorOutput_withNestedList_successful() { + Map config = createNestedListConfiguration(); + IngestDocument ingestDocument = createNestedListIngestDocument(); + TextEmbeddingProcessor textEmbeddingProcessor = createInstanceWithNestedMapConfiguration(config); + Map knnMap = textEmbeddingProcessor.buildMapWithProcessorKeyAndOriginalValue(ingestDocument); + List> modelTensorList = createMockVectorResult(); + textEmbeddingProcessor.buildNLPResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); + List> nestedObj = (List>) ingestDocument.getSourceAndMetadata().get("nestedField"); + assertTrue(nestedObj.get(0).containsKey("vectorField")); + assertTrue(nestedObj.get(1).containsKey("vectorField")); + assertNotNull(nestedObj.get(0).get("vectorField")); + assertNotNull(nestedObj.get(1).get("vectorField")); + } + public void test_updateDocument_appendVectorFieldsToDocument_successful() { Map config = createPlainStringConfiguration(); IngestDocument ingestDocument = createPlainIngestDocument(); @@ -520,4 +535,22 @@ private IngestDocument createNestedMapIngestDocument() { result.put("favorites", favorite); return new IngestDocument(result, new HashMap<>()); } + + private Map createNestedListConfiguration() { + Map nestedConfig = new HashMap<>(); + nestedConfig.put("textField", "vectorField"); + Map result = new HashMap<>(); + result.put("nestedField", nestedConfig); + return result; + } + + private IngestDocument createNestedListIngestDocument() { + HashMap nestedObj1 = new HashMap<>(); + nestedObj1.put("textField", "This is a text field"); + HashMap nestedObj2 = new HashMap<>(); + nestedObj2.put("textField", "This is another text field"); + HashMap nestedList = new HashMap<>(); + nestedList.put("nestedField", Arrays.asList(nestedObj1, nestedObj2)); + return new IngestDocument(nestedList, new HashMap<>()); + } }