From 203f6deaf568ab601a73230a197ae0e77bb5459a Mon Sep 17 00:00:00 2001 From: DenovVasil Date: Thu, 19 Dec 2024 09:41:49 +0200 Subject: [PATCH] feat(doc handle): doc handling for bedrock --- .../aws-bedrock-outbound-connector.json | 53 +++++--- ...aws-bedrock-outbound-connector-hybrid.json | 53 +++++--- connectors/aws/aws-bedrock/pom.xml | 6 + .../aws/bedrock/core/BedrockExecutor.java | 3 +- .../bedrock/mapper/BedrockContentMapper.java | 72 ++++++++++ .../aws/bedrock/mapper/DocumentMapper.java | 76 +++++++++++ .../aws/bedrock/mapper/MessageMapper.java | 45 ++++++ .../aws/bedrock/model/BedrockContent.java | 59 ++++++++ .../aws/bedrock/model/BedrockMessage.java | 53 ++++++++ .../aws/bedrock/model/ConverseData.java | 125 +++++++++++------ .../model/ConverseWrapperResponse.java | 12 -- .../aws/bedrock/model/PreviousMessage.java | 9 -- .../aws/bedrock/model/RequestData.java | 2 +- .../connector/aws/bedrock/util/FileUtil.java | 37 +++++ .../connector/aws/bedrock/BaseTest.java | 11 +- .../mapper/BedrockContentMapperTest.java | 128 ++++++++++++++++++ .../bedrock/mapper/DocumentMapperTest.java | 102 ++++++++++++++ .../aws/bedrock/mapper/MessageMapperTest.java | 68 ++++++++++ .../aws/bedrock/model/ConverseDataTest.java | 37 +++-- .../aws/bedrock/util/FileUtilTest.java | 45 ++++++ .../resources/converse/converseExample.json | 10 +- .../resources/converse/image-document.json | 6 + .../resources/converse/text-document.json | 6 + .../converse/unsupported-document.json | 6 + 24 files changed, 898 insertions(+), 126 deletions(-) create mode 100644 connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/mapper/BedrockContentMapper.java create mode 100644 connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/mapper/DocumentMapper.java create mode 100644 connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/mapper/MessageMapper.java create mode 100644 connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/model/BedrockContent.java create mode 100644 connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/model/BedrockMessage.java delete mode 100644 connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/model/ConverseWrapperResponse.java delete mode 100644 connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/model/PreviousMessage.java create mode 100644 connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/util/FileUtil.java create mode 100644 connectors/aws/aws-bedrock/src/test/java/io/camunda/connector/aws/bedrock/mapper/BedrockContentMapperTest.java create mode 100644 connectors/aws/aws-bedrock/src/test/java/io/camunda/connector/aws/bedrock/mapper/DocumentMapperTest.java create mode 100644 connectors/aws/aws-bedrock/src/test/java/io/camunda/connector/aws/bedrock/mapper/MessageMapperTest.java create mode 100644 connectors/aws/aws-bedrock/src/test/java/io/camunda/connector/aws/bedrock/util/FileUtilTest.java create mode 100644 connectors/aws/aws-bedrock/src/test/resources/converse/image-document.json create mode 100644 connectors/aws/aws-bedrock/src/test/resources/converse/text-document.json create mode 100644 connectors/aws/aws-bedrock/src/test/resources/converse/unsupported-document.json diff --git a/connectors/aws/aws-bedrock/element-templates/aws-bedrock-outbound-connector.json b/connectors/aws/aws-bedrock/element-templates/aws-bedrock-outbound-connector.json index a359fca740..913c04c0ae 100644 --- a/connectors/aws/aws-bedrock/element-templates/aws-bedrock-outbound-connector.json +++ b/connectors/aws/aws-bedrock/element-templates/aws-bedrock-outbound-connector.json @@ -3,6 +3,9 @@ "name" : "AWS BedRock Outbound Connector", "id" : "io.camunda.connectors.aws.bedrock.v1", "description" : "Execute bedrock requests", + "metadata" : { + "keywords" : [ ] + }, "documentationRef" : "https://docs.camunda.io/docs/", "version" : 1, "category" : { @@ -186,6 +189,23 @@ "type" : "simple" }, "type" : "String" + }, { + "id" : "data.messagesHistory", + "label" : "Message History", + "description" : "Specify the message history, when previous context is needed", + "optional" : true, + "feel" : "required", + "group" : "converse", + "binding" : { + "name" : "data.messagesHistory", + "type" : "zeebe:input" + }, + "condition" : { + "property" : "action", + "equals" : "converse", + "type" : "simple" + }, + "type" : "String" }, { "id" : "data.modelId1", "label" : "Model ID", @@ -207,7 +227,7 @@ }, "type" : "String" }, { - "id" : "data.nextMessage", + "id" : "data.newMessage", "label" : "New Message", "description" : "Specify the next message", "optional" : false, @@ -217,7 +237,7 @@ "feel" : "optional", "group" : "converse", "binding" : { - "name" : "data.nextMessage", + "name" : "data.newMessage", "type" : "zeebe:input" }, "condition" : { @@ -227,14 +247,13 @@ }, "type" : "String" }, { - "id" : "data.messages", - "label" : "Message History", - "description" : "Specify the message history, when previous context is needed", + "id" : "data.maxTokens", + "label" : "Max token returned", "optional" : true, - "feel" : "required", + "feel" : "optional", "group" : "converse", "binding" : { - "name" : "data.messages", + "name" : "data.maxTokens", "type" : "zeebe:input" }, "condition" : { @@ -244,13 +263,13 @@ }, "type" : "String" }, { - "id" : "data.maxTokens", - "label" : "Max token returned", + "id" : "data.temperature", + "label" : "Temperature", "optional" : true, "feel" : "optional", "group" : "converse", "binding" : { - "name" : "data.maxTokens", + "name" : "data.temperature", "type" : "zeebe:input" }, "condition" : { @@ -260,13 +279,13 @@ }, "type" : "String" }, { - "id" : "data.temperature", - "label" : "Temperature", + "id" : "data.topP", + "label" : "top P", "optional" : true, "feel" : "optional", "group" : "converse", "binding" : { - "name" : "data.temperature", + "name" : "data.topP", "type" : "zeebe:input" }, "condition" : { @@ -276,13 +295,13 @@ }, "type" : "String" }, { - "id" : "data.topP", - "label" : "top P", + "id" : "data.newDocuments", + "label" : "documents", "optional" : true, - "feel" : "optional", + "feel" : "required", "group" : "converse", "binding" : { - "name" : "data.topP", + "name" : "data.newDocuments", "type" : "zeebe:input" }, "condition" : { diff --git a/connectors/aws/aws-bedrock/element-templates/hybrid/aws-bedrock-outbound-connector-hybrid.json b/connectors/aws/aws-bedrock/element-templates/hybrid/aws-bedrock-outbound-connector-hybrid.json index fa1b2876a6..839ba4d725 100644 --- a/connectors/aws/aws-bedrock/element-templates/hybrid/aws-bedrock-outbound-connector-hybrid.json +++ b/connectors/aws/aws-bedrock/element-templates/hybrid/aws-bedrock-outbound-connector-hybrid.json @@ -3,6 +3,9 @@ "name" : "Hybrid AWS BedRock Outbound Connector", "id" : "io.camunda.connectors.aws.bedrock.v1-hybrid", "description" : "Execute bedrock requests", + "metadata" : { + "keywords" : [ ] + }, "documentationRef" : "https://docs.camunda.io/docs/", "version" : 1, "category" : { @@ -191,6 +194,23 @@ "type" : "simple" }, "type" : "String" + }, { + "id" : "data.messagesHistory", + "label" : "Message History", + "description" : "Specify the message history, when previous context is needed", + "optional" : true, + "feel" : "required", + "group" : "converse", + "binding" : { + "name" : "data.messagesHistory", + "type" : "zeebe:input" + }, + "condition" : { + "property" : "action", + "equals" : "converse", + "type" : "simple" + }, + "type" : "String" }, { "id" : "data.modelId1", "label" : "Model ID", @@ -212,7 +232,7 @@ }, "type" : "String" }, { - "id" : "data.nextMessage", + "id" : "data.newMessage", "label" : "New Message", "description" : "Specify the next message", "optional" : false, @@ -222,7 +242,7 @@ "feel" : "optional", "group" : "converse", "binding" : { - "name" : "data.nextMessage", + "name" : "data.newMessage", "type" : "zeebe:input" }, "condition" : { @@ -232,14 +252,13 @@ }, "type" : "String" }, { - "id" : "data.messages", - "label" : "Message History", - "description" : "Specify the message history, when previous context is needed", + "id" : "data.maxTokens", + "label" : "Max token returned", "optional" : true, - "feel" : "required", + "feel" : "optional", "group" : "converse", "binding" : { - "name" : "data.messages", + "name" : "data.maxTokens", "type" : "zeebe:input" }, "condition" : { @@ -249,13 +268,13 @@ }, "type" : "String" }, { - "id" : "data.maxTokens", - "label" : "Max token returned", + "id" : "data.temperature", + "label" : "Temperature", "optional" : true, "feel" : "optional", "group" : "converse", "binding" : { - "name" : "data.maxTokens", + "name" : "data.temperature", "type" : "zeebe:input" }, "condition" : { @@ -265,13 +284,13 @@ }, "type" : "String" }, { - "id" : "data.temperature", - "label" : "Temperature", + "id" : "data.topP", + "label" : "top P", "optional" : true, "feel" : "optional", "group" : "converse", "binding" : { - "name" : "data.temperature", + "name" : "data.topP", "type" : "zeebe:input" }, "condition" : { @@ -281,13 +300,13 @@ }, "type" : "String" }, { - "id" : "data.topP", - "label" : "top P", + "id" : "data.newDocuments", + "label" : "documents", "optional" : true, - "feel" : "optional", + "feel" : "required", "group" : "converse", "binding" : { - "name" : "data.topP", + "name" : "data.newDocuments", "type" : "zeebe:input" }, "condition" : { diff --git a/connectors/aws/aws-bedrock/pom.xml b/connectors/aws/aws-bedrock/pom.xml index eb40dcc1aa..706e6df02f 100644 --- a/connectors/aws/aws-bedrock/pom.xml +++ b/connectors/aws/aws-bedrock/pom.xml @@ -37,6 +37,12 @@ bedrockruntime ${version.software-aws-java-sdk} + + + org.apache.tika + tika-core + 2.0.0 + diff --git a/connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/core/BedrockExecutor.java b/connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/core/BedrockExecutor.java index ddf4d4f64c..8edcd71801 100644 --- a/connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/core/BedrockExecutor.java +++ b/connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/core/BedrockExecutor.java @@ -9,7 +9,6 @@ import io.camunda.connector.aws.CredentialsProviderSupportV2; import io.camunda.connector.aws.ObjectMapperSupplier; import io.camunda.connector.aws.bedrock.model.BedrockRequest; -import io.camunda.connector.aws.bedrock.model.BedrockResponse; import io.camunda.connector.aws.bedrock.model.RequestData; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; @@ -33,7 +32,7 @@ public static BedrockExecutor create(BedrockRequest bedrockRequest) { bedrockRequest.getData()); } - public BedrockResponse execute() { + public Object execute() { return this.requestData.execute(bedrockRuntimeClient, ObjectMapperSupplier.getMapperInstance()); } } diff --git a/connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/mapper/BedrockContentMapper.java b/connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/mapper/BedrockContentMapper.java new file mode 100644 index 0000000000..e77d457294 --- /dev/null +++ b/connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/mapper/BedrockContentMapper.java @@ -0,0 +1,72 @@ +/* + * Copyright Camunda Services GmbH and/or licensed to Camunda Services GmbH + * under one or more contributor license agreements. Licensed under a proprietary license. + * See the License.txt file for more information. You may not use this file + * except in compliance with the proprietary license. + */ +package io.camunda.connector.aws.bedrock.mapper; + +import io.camunda.connector.aws.bedrock.model.BedrockContent; +import io.camunda.document.Document; +import java.util.List; +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock; +import software.amazon.awssdk.services.bedrockruntime.model.DocumentBlock; +import software.amazon.awssdk.services.bedrockruntime.model.ImageBlock; + +public class BedrockContentMapper { + + private final DocumentMapper documentMapper; + + public BedrockContentMapper(DocumentMapper documentMapper) { + this.documentMapper = documentMapper; + } + + public BedrockContent messageToBedrockContent(String message) { + return new BedrockContent(message); + } + + public List documentsToBedrockContent(List documents) { + if (documents == null) { + return List.of(); + } + return documents.stream().map(this::documentToBedrockContent).toList(); + } + + public BedrockContent documentToBedrockContent(Document document) { + return new BedrockContent(document); + } + + /* + * The current implementation supports ContentBlock containing only text. + * */ + public List mapToBedrockContent(List contentBlocks) { + return contentBlocks.stream().map(ContentBlock::text).map(BedrockContent::new).toList(); + } + + public List mapToContentBlocks(List contentBlocks) { + return contentBlocks.stream() + .map( + content -> { + String text = content.getText(); + if (text != null) { + return mapToContentBlock(text); + } + var document = content.getDocument(); + var docBlock = documentMapper.mapToFileBlock(document); + return mapToContentBlock(docBlock); + }) + .toList(); + } + + private ContentBlock mapToContentBlock(Object content) { + return switch (content) { + case String s -> ContentBlock.fromText(s); + case ImageBlock imageBlock -> ContentBlock.fromImage(imageBlock); + default -> ContentBlock.fromDocument((DocumentBlock) content); + }; + } + + public DocumentMapper getDocumentMapper() { + return documentMapper; + } +} diff --git a/connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/mapper/DocumentMapper.java b/connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/mapper/DocumentMapper.java new file mode 100644 index 0000000000..61541121f7 --- /dev/null +++ b/connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/mapper/DocumentMapper.java @@ -0,0 +1,76 @@ +/* + * Copyright Camunda Services GmbH and/or licensed to Camunda Services GmbH + * under one or more contributor license agreements. Licensed under a proprietary license. + * See the License.txt file for more information. You may not use this file + * except in compliance with the proprietary license. + */ +package io.camunda.connector.aws.bedrock.mapper; + +import io.camunda.connector.aws.bedrock.util.FileUtil; +import io.camunda.document.Document; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.tika.mime.MimeTypeException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.core.SdkPojo; +import software.amazon.awssdk.services.bedrockruntime.model.*; + +public class DocumentMapper { + + public static final String UNSUPPORTED_DOC_TYPE_MSG = "Unsupported document type: "; + public static final String UNSUPPORTED_CONTENT_TYPE_MSG = "Unsupported document content type: "; + private static final Logger LOGGER = LoggerFactory.getLogger(DocumentMapper.class); + + public DocumentMapper() {} + + public SdkPojo mapToFileBlock(Document document) { + Pair nameTypePair = + FileUtil.defineNameAndType(document.metadata().getFileName()); + String fileName = nameTypePair.getLeft(); + String fileType = nameTypePair.getRight(); + String contentType = document.metadata().getContentType(); + + try { + fileType = fileType.isEmpty() ? FileUtil.defineType(contentType) : fileType; + } catch (MimeTypeException e) { + String errorMsg = UNSUPPORTED_CONTENT_TYPE_MSG + contentType; + LOGGER.debug(errorMsg); + throw new RuntimeException(errorMsg, e); + } + + var bytes = document.asByteArray(); + var imageFormat = ImageFormat.fromValue(fileType); + + if (!imageFormat.equals(ImageFormat.UNKNOWN_TO_SDK_VERSION)) { + return mapToImag(bytes, imageFormat); + } + + var documentFormat = DocumentFormat.fromValue(fileType); + + if (!documentFormat.equals(DocumentFormat.UNKNOWN_TO_SDK_VERSION)) { + return mapToDocument(bytes, documentFormat, fileName); + } + + String unsupportedDocumentMsg = UNSUPPORTED_DOC_TYPE_MSG + fileName; + + LOGGER.debug(unsupportedDocumentMsg); + throw new IllegalArgumentException(unsupportedDocumentMsg); + } + + private ImageBlock mapToImag(byte[] bytes, ImageFormat imageFormat) { + return ImageBlock.builder() + .source(ImageSource.builder().bytes(SdkBytes.fromByteArray(bytes)).build()) + .format(imageFormat) + .build(); + } + + private DocumentBlock mapToDocument( + byte[] bytes, DocumentFormat documentFormat, String fileName) { + return DocumentBlock.builder() + .source(DocumentSource.builder().bytes(SdkBytes.fromByteArray(bytes)).build()) + .format(documentFormat) + .name(fileName) + .build(); + } +} diff --git a/connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/mapper/MessageMapper.java b/connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/mapper/MessageMapper.java new file mode 100644 index 0000000000..3702a64926 --- /dev/null +++ b/connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/mapper/MessageMapper.java @@ -0,0 +1,45 @@ +/* + * Copyright Camunda Services GmbH and/or licensed to Camunda Services GmbH + * under one or more contributor license agreements. Licensed under a proprietary license. + * See the License.txt file for more information. You may not use this file + * except in compliance with the proprietary license. + */ +package io.camunda.connector.aws.bedrock.mapper; + +import io.camunda.connector.aws.bedrock.model.BedrockContent; +import io.camunda.connector.aws.bedrock.model.BedrockMessage; +import java.util.List; +import software.amazon.awssdk.services.bedrockruntime.model.Message; + +public class MessageMapper { + + private final BedrockContentMapper bedrockContentMapper; + + public MessageMapper(BedrockContentMapper bedrockContentMapper) { + this.bedrockContentMapper = bedrockContentMapper; + } + + public List mapToMessages(List bedrockMessages) { + if (bedrockMessages != null) { + return bedrockMessages.stream().map(this::mapToMessage).toList(); + } + return List.of(); + } + + public Message mapToMessage(BedrockMessage bedrockMessage) { + return Message.builder() + .content(bedrockContentMapper.mapToContentBlocks(bedrockMessage.getContentList())) + .role(bedrockMessage.getRole()) + .build(); + } + + public BedrockMessage mapToBedrockMessage(Message message) { + List bedrockContents = + bedrockContentMapper.mapToBedrockContent(message.content()); + return new BedrockMessage(message.roleAsString(), bedrockContents); + } + + public BedrockContentMapper getBedrockContentMapper() { + return bedrockContentMapper; + } +} diff --git a/connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/model/BedrockContent.java b/connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/model/BedrockContent.java new file mode 100644 index 0000000000..973aaa2d67 --- /dev/null +++ b/connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/model/BedrockContent.java @@ -0,0 +1,59 @@ +/* + * Copyright Camunda Services GmbH and/or licensed to Camunda Services GmbH + * under one or more contributor license agreements. Licensed under a proprietary license. + * See the License.txt file for more information. You may not use this file + * except in compliance with the proprietary license. + */ +package io.camunda.connector.aws.bedrock.model; + +import com.fasterxml.jackson.annotation.JsonInclude; +import io.camunda.document.Document; +import jakarta.validation.Valid; +import jakarta.validation.constraints.NotBlank; +import java.util.Objects; + +@JsonInclude(JsonInclude.Include.NON_NULL) +public class BedrockContent { + private String text; + + private Document document; + + public BedrockContent(String text) { + this.text = text; + } + + public BedrockContent(Document document) { + this.document = document; + } + + public BedrockContent() {} + + public @Valid @NotBlank String getText() { + return text; + } + + public void setText(@Valid @NotBlank String text) { + this.text = text; + } + + public Document getDocument() { + return document; + } + + public void setDocument(Document document) { + this.document = document; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + BedrockContent that = (BedrockContent) o; + return Objects.equals(text, that.text) && Objects.equals(document, that.document); + } + + @Override + public int hashCode() { + return Objects.hash(text, document); + } +} diff --git a/connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/model/BedrockMessage.java b/connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/model/BedrockMessage.java new file mode 100644 index 0000000000..66fda29727 --- /dev/null +++ b/connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/model/BedrockMessage.java @@ -0,0 +1,53 @@ +/* + * Copyright Camunda Services GmbH and/or licensed to Camunda Services GmbH + * under one or more contributor license agreements. Licensed under a proprietary license. + * See the License.txt file for more information. You may not use this file + * except in compliance with the proprietary license. + */ +package io.camunda.connector.aws.bedrock.model; + +import java.util.List; +import java.util.Objects; + +public class BedrockMessage { + + private String role; + + private List contentList; + + public BedrockMessage(String role, List contentList) { + this.role = role; + this.contentList = contentList; + } + + public BedrockMessage() {} + + public String getRole() { + return role; + } + + public List getContentList() { + return contentList; + } + + public void setRole(String role) { + this.role = role; + } + + public void setContentList(List contentList) { + this.contentList = contentList; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + BedrockMessage that = (BedrockMessage) o; + return Objects.equals(role, that.role) && Objects.equals(contentList, that.contentList); + } + + @Override + public int hashCode() { + return Objects.hash(role, contentList); + } +} diff --git a/connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/model/ConverseData.java b/connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/model/ConverseData.java index 2727923d19..4acb3d5b4f 100644 --- a/connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/model/ConverseData.java +++ b/connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/model/ConverseData.java @@ -9,9 +9,13 @@ import com.fasterxml.jackson.annotation.JsonSetter; import com.fasterxml.jackson.annotation.Nulls; import com.fasterxml.jackson.databind.ObjectMapper; +import io.camunda.connector.aws.bedrock.mapper.BedrockContentMapper; +import io.camunda.connector.aws.bedrock.mapper.DocumentMapper; +import io.camunda.connector.aws.bedrock.mapper.MessageMapper; import io.camunda.connector.generator.dsl.Property; import io.camunda.connector.generator.java.annotation.TemplateProperty; import io.camunda.connector.generator.java.annotation.TemplateSubType; +import io.camunda.document.Document; import jakarta.validation.Valid; import jakarta.validation.constraints.NotBlank; import jakarta.validation.constraints.NotNull; @@ -19,14 +23,23 @@ import java.util.List; import java.util.Objects; import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; -import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock; -import software.amazon.awssdk.services.bedrockruntime.model.ConversationRole; -import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse; -import software.amazon.awssdk.services.bedrockruntime.model.Message; +import software.amazon.awssdk.services.bedrockruntime.model.*; @TemplateSubType(id = "converse", label = "Converse") public final class ConverseData implements RequestData { + @TemplateProperty( + label = "Message History", + group = "converse", + id = "data.messagesHistory", + description = "Specify the message history, when previous context is needed", + feel = Property.FeelMode.required, + optional = true, + binding = @TemplateProperty.PropertyBinding(name = "data.messagesHistory")) + @Valid + @JsonSetter(nulls = Nulls.SKIP) + private List messagesHistory = new ArrayList<>(); + @TemplateProperty( label = "Model ID", group = "converse", @@ -42,25 +55,13 @@ public final class ConverseData implements RequestData { @TemplateProperty( label = "New Message", group = "converse", - id = "data.nextMessage", + id = "data.newMessage", description = "Specify the next message", feel = Property.FeelMode.optional, - binding = @TemplateProperty.PropertyBinding(name = "data.nextMessage")) + binding = @TemplateProperty.PropertyBinding(name = "data.newMessage")) @Valid @NotBlank - private String nextMessage; - - @TemplateProperty( - label = "Message History", - group = "converse", - id = "data.messages", - description = "Specify the message history, when previous context is needed", - feel = Property.FeelMode.required, - optional = true, - binding = @TemplateProperty.PropertyBinding(name = "data.messages")) - @Valid - @JsonSetter(nulls = Nulls.SKIP) - private List messages = new ArrayList<>(); + private String newMessage; @TemplateProperty( label = "Max token returned", @@ -89,20 +90,24 @@ public final class ConverseData implements RequestData { binding = @TemplateProperty.PropertyBinding(name = "data.topP")) private Float topP = 0.9f; + @TemplateProperty( + label = "documents", + group = "converse", + id = "data.newDocuments", + feel = Property.FeelMode.required, + optional = true, + binding = @TemplateProperty.PropertyBinding(name = "data.newDocuments")) + private List newDocuments; + @Override - public BedrockResponse execute( + public List execute( BedrockRuntimeClient bedrockRuntimeClient, ObjectMapper mapperInstance) { - this.messages.add(new PreviousMessage(this.nextMessage, ConversationRole.USER.name())); - Message.Builder messageBuilder = Message.builder(); - List messages = - this.messages.stream() - .map( - message -> - messageBuilder - .role(ConversationRole.valueOf(message.role())) - .content(ContentBlock.fromText(message.message())) - .build()) - .toList(); + var messageMapper = createMessageMapper(); + var bedrockContentMapper = messageMapper.getBedrockContentMapper(); + this.messagesHistory.add(prepareNewBedrockMessage(bedrockContentMapper)); + + List messages = messageMapper.mapToMessages(this.messagesHistory); + ConverseResponse converseResponse = bedrockRuntimeClient.converse( builder -> @@ -116,9 +121,25 @@ public BedrockResponse execute( .maxTokens(this.maxTokens) .topP(this.topP) .build())); - String newMessage = converseResponse.output().message().content().getFirst().text(); - this.messages.add(new PreviousMessage(newMessage, ConversationRole.ASSISTANT.name())); - return new ConverseWrapperResponse(this.messages, newMessage); + + var responseMessage = converseResponse.output().message(); + this.messagesHistory.add(messageMapper.mapToBedrockMessage(responseMessage)); + + return messagesHistory; + } + + private BedrockMessage prepareNewBedrockMessage(BedrockContentMapper bedrockContentMapper) { + String user = ConversationRole.USER.toString(); + List contentList = + new ArrayList<>(bedrockContentMapper.documentsToBedrockContent(this.newDocuments)); + contentList.add(bedrockContentMapper.messageToBedrockContent(this.newMessage)); + + return new BedrockMessage(user, contentList); + } + + private MessageMapper createMessageMapper() { + var bedrockContentMapper = new BedrockContentMapper(new DocumentMapper()); + return new MessageMapper(bedrockContentMapper); } @Override @@ -126,29 +147,27 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; ConverseData that = (ConverseData) o; - return Objects.equals(modelId, that.modelId) - && Objects.equals(nextMessage, that.nextMessage) - && Objects.equals(messages, that.messages) + return Objects.equals(messagesHistory, that.messagesHistory) + && Objects.equals(modelId, that.modelId) + && Objects.equals(newMessage, that.newMessage) && Objects.equals(maxTokens, that.maxTokens) && Objects.equals(temperature, that.temperature) - && Objects.equals(topP, that.topP); + && Objects.equals(topP, that.topP) + && Objects.equals(newDocuments, that.newDocuments); } @Override public int hashCode() { - return Objects.hash(modelId, nextMessage, messages, maxTokens, temperature, topP); + return Objects.hash( + messagesHistory, modelId, newMessage, maxTokens, temperature, topP, newDocuments); } public void setModelId(@Valid @NotNull String modelId) { this.modelId = modelId; } - public void setNextMessage(@Valid @NotBlank String nextMessage) { - this.nextMessage = nextMessage; - } - - public void setMessages(@Valid List messages) { - this.messages = messages; + public void setNewMessage(@Valid @NotBlank String newMessage) { + this.newMessage = newMessage; } public void setMaxTokens(Integer maxTokens) { @@ -163,6 +182,22 @@ public void setTopP(Float topP) { this.topP = topP; } + public List getNewDocuments() { + return newDocuments; + } + + public void setNewDocuments(List newDocuments) { + this.newDocuments = newDocuments; + } + + public List getMessagesHistory() { + return this.messagesHistory; + } + + public void setMessagesHistory(List messagesHistory) { + this.messagesHistory = messagesHistory; + } + @Override public String toString() { return "ConverseData{" diff --git a/connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/model/ConverseWrapperResponse.java b/connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/model/ConverseWrapperResponse.java deleted file mode 100644 index 97eabc2cf2..0000000000 --- a/connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/model/ConverseWrapperResponse.java +++ /dev/null @@ -1,12 +0,0 @@ -/* - * Copyright Camunda Services GmbH and/or licensed to Camunda Services GmbH - * under one or more contributor license agreements. Licensed under a proprietary license. - * See the License.txt file for more information. You may not use this file - * except in compliance with the proprietary license. - */ -package io.camunda.connector.aws.bedrock.model; - -import java.util.List; - -public record ConverseWrapperResponse(List messageHistory, String newMessage) - implements BedrockResponse {} diff --git a/connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/model/PreviousMessage.java b/connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/model/PreviousMessage.java deleted file mode 100644 index 1c6c0aacfb..0000000000 --- a/connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/model/PreviousMessage.java +++ /dev/null @@ -1,9 +0,0 @@ -/* - * Copyright Camunda Services GmbH and/or licensed to Camunda Services GmbH - * under one or more contributor license agreements. Licensed under a proprietary license. - * See the License.txt file for more information. You may not use this file - * except in compliance with the proprietary license. - */ -package io.camunda.connector.aws.bedrock.model; - -public record PreviousMessage(String message, String role) {} diff --git a/connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/model/RequestData.java b/connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/model/RequestData.java index 5b2e1edb79..7d028524ea 100644 --- a/connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/model/RequestData.java +++ b/connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/model/RequestData.java @@ -18,5 +18,5 @@ defaultValue = "invokeModel") @TemplateSubType(id = "action", label = "Action") public sealed interface RequestData permits InvokeModelData, ConverseData { - BedrockResponse execute(BedrockRuntimeClient bedrockRuntimeClient, ObjectMapper mapperInstance); + Object execute(BedrockRuntimeClient bedrockRuntimeClient, ObjectMapper mapperInstance); } diff --git a/connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/util/FileUtil.java b/connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/util/FileUtil.java new file mode 100644 index 0000000000..e691a156c6 --- /dev/null +++ b/connectors/aws/aws-bedrock/src/main/java/io/camunda/connector/aws/bedrock/util/FileUtil.java @@ -0,0 +1,37 @@ +/* + * Copyright Camunda Services GmbH and/or licensed to Camunda Services GmbH + * under one or more contributor license agreements. Licensed under a proprietary license. + * See the License.txt file for more information. You may not use this file + * except in compliance with the proprietary license. + */ +package io.camunda.connector.aws.bedrock.util; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.tika.mime.MimeTypeException; +import org.apache.tika.mime.MimeTypes; + +public final class FileUtil { + + private FileUtil() {} + + public static Pair defineNameAndType(String fileName) { + String separator = "."; + int separatorIndex = fileName.lastIndexOf(separator); + + if (separatorIndex == -1) { + return Pair.of(fileName, ""); + } + + String name = fileName.substring(0, separatorIndex); + String type = fileName.substring(separatorIndex + 1); + return Pair.of(name, type); + } + + public static String defineType(String contentType) throws MimeTypeException { + MimeTypes mimeTypes = MimeTypes.getDefaultMimeTypes(); + String extension = mimeTypes.forName(contentType).getExtension(); + + int dotIndex = extension.indexOf('.'); + return dotIndex == 0 ? extension.substring(1) : extension; + } +} diff --git a/connectors/aws/aws-bedrock/src/test/java/io/camunda/connector/aws/bedrock/BaseTest.java b/connectors/aws/aws-bedrock/src/test/java/io/camunda/connector/aws/bedrock/BaseTest.java index 1a3138a855..858282ec70 100644 --- a/connectors/aws/aws-bedrock/src/test/java/io/camunda/connector/aws/bedrock/BaseTest.java +++ b/connectors/aws/aws-bedrock/src/test/java/io/camunda/connector/aws/bedrock/BaseTest.java @@ -20,6 +20,8 @@ public class BaseTest { + protected static final ObjectMapper mapper = ConnectorsObjectMapperSupplier.getCopy(); + public static Stream loadInvokeModelVariables() { try { return loadTestCasesFromResourceFile( @@ -37,12 +39,15 @@ public static Stream loadConverseVariables() { } } + public static T readData(String path, Class type) throws IOException { + final String cases = readString(new File(path).toPath(), UTF_8); + return mapper.readValue(cases, type); + } + @SuppressWarnings("unchecked") protected static Stream loadTestCasesFromResourceFile(final String fileWithTestCasesUri) throws IOException { - final String cases = readString(new File(fileWithTestCasesUri).toPath(), UTF_8); - final ObjectMapper mapper = ConnectorsObjectMapperSupplier.getCopy(); - var array = mapper.readValue(cases, ArrayList.class); + var array = readData(fileWithTestCasesUri, ArrayList.class); return array.stream() .map( value -> { diff --git a/connectors/aws/aws-bedrock/src/test/java/io/camunda/connector/aws/bedrock/mapper/BedrockContentMapperTest.java b/connectors/aws/aws-bedrock/src/test/java/io/camunda/connector/aws/bedrock/mapper/BedrockContentMapperTest.java new file mode 100644 index 0000000000..0e6d7d3327 --- /dev/null +++ b/connectors/aws/aws-bedrock/src/test/java/io/camunda/connector/aws/bedrock/mapper/BedrockContentMapperTest.java @@ -0,0 +1,128 @@ +/* + * Copyright Camunda Services GmbH and/or licensed to Camunda Services GmbH + * under one or more contributor license agreements. Licensed under a proprietary license. + * See the License.txt file for more information. You may not use this file + * except in compliance with the proprietary license. + */ +package io.camunda.connector.aws.bedrock.mapper; + +import static io.camunda.connector.aws.bedrock.BaseTest.readData; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import io.camunda.connector.aws.bedrock.model.BedrockContent; +import io.camunda.connector.document.annotation.jackson.DocumentReferenceModel; +import io.camunda.document.CamundaDocument; +import io.camunda.document.Document; +import io.camunda.document.reference.DocumentReference; +import io.camunda.document.store.CamundaDocumentStore; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.util.List; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock; +import software.amazon.awssdk.services.bedrockruntime.model.DocumentBlock; + +@ExtendWith(MockitoExtension.class) +class BedrockContentMapperTest { + + @Mock private DocumentMapper documentMapper; + + @InjectMocks private BedrockContentMapper bedrockContentMapper; + + @Test + void messageToBedrockContent() { + String msg = "Hello World!"; + var bedrockContent = bedrockContentMapper.messageToBedrockContent(msg); + + assertThat(bedrockContent.getText()).isEqualTo(msg); + assertThat(bedrockContent.getDocument()).isNull(); + } + + @Test + void documentsToBedrockContent() throws IOException { + String path1 = "src/test/resources/converse/text-document.json"; + String path2 = "src/test/resources/converse/image-document.json"; + var doc1 = prepareDocument(path1); + var doc2 = prepareDocument(path2); + + List expectedContent = + List.of(new BedrockContent(doc1), new BedrockContent(doc2)); + List contentResult = + bedrockContentMapper.documentsToBedrockContent(List.of(doc1, doc2)); + + assertThat(contentResult).isEqualTo(expectedContent); + } + + @Test + void documentToBedrockContent() throws IOException { + String path = "src/test/resources/converse/text-document.json"; + var document = prepareDocument(path); + + var bedrockContent = bedrockContentMapper.documentToBedrockContent(document); + + assertThat(bedrockContent.getText()).isNull(); + assertThat(bedrockContent.getDocument()).isEqualTo(document); + } + + @Test + void mapToBedrockContent() { + String msg = "Hello World!"; + var contentBlock = ContentBlock.fromText(msg); + + List result = bedrockContentMapper.mapToBedrockContent(List.of(contentBlock)); + + assertThat(result).hasSize(1); + assertThat(result.getFirst().getText()).isEqualTo(msg); + } + + @Test + void mapToContentBlocks() throws IOException { + String msg = "Hello World!"; + var textBedrockContent = new BedrockContent(msg); + + String path = "src/test/resources/converse/text-document.json"; + var documentReference = mock(DocumentReference.CamundaDocumentReference.class); + var documentStore = mock(CamundaDocumentStore.class); + var document = prepareDocument(path, documentReference, documentStore); + + var byteInput = new ByteArrayInputStream(new byte[0]); + when(documentStore.getDocumentContent(any())).thenReturn(byteInput); + + var docContent = new BedrockContent(document); + + when(documentMapper.mapToFileBlock(any(Document.class))).thenCallRealMethod(); + + List result = + bedrockContentMapper.mapToContentBlocks(List.of(textBedrockContent, docContent)); + + DocumentBlock documentBlock = (DocumentBlock) documentMapper.mapToFileBlock(document); + ContentBlock documentContent = ContentBlock.fromDocument(documentBlock); + + ContentBlock textContent = ContentBlock.fromText(msg); + + List expected = List.of(textContent, documentContent); + + assertThat(result).isEqualTo(expected); + } + + private Document prepareDocument(String path) throws IOException { + var documentReference = mock(DocumentReference.CamundaDocumentReference.class); + var documentStore = mock(CamundaDocumentStore.class); + + return prepareDocument(path, documentReference, documentStore); + } + + private Document prepareDocument( + String path, DocumentReference.CamundaDocumentReference docRef, CamundaDocumentStore docStore) + throws IOException { + var docMetadata = readData(path, DocumentReferenceModel.CamundaDocumentMetadataModel.class); + return new CamundaDocument(docMetadata, docRef, docStore); + } +} diff --git a/connectors/aws/aws-bedrock/src/test/java/io/camunda/connector/aws/bedrock/mapper/DocumentMapperTest.java b/connectors/aws/aws-bedrock/src/test/java/io/camunda/connector/aws/bedrock/mapper/DocumentMapperTest.java new file mode 100644 index 0000000000..1dca9b4895 --- /dev/null +++ b/connectors/aws/aws-bedrock/src/test/java/io/camunda/connector/aws/bedrock/mapper/DocumentMapperTest.java @@ -0,0 +1,102 @@ +/* + * Copyright Camunda Services GmbH and/or licensed to Camunda Services GmbH + * under one or more contributor license agreements. Licensed under a proprietary license. + * See the License.txt file for more information. You may not use this file + * except in compliance with the proprietary license. + */ +package io.camunda.connector.aws.bedrock.mapper; + +import static io.camunda.connector.aws.bedrock.BaseTest.readData; +import static io.camunda.connector.aws.bedrock.mapper.DocumentMapper.UNSUPPORTED_DOC_TYPE_MSG; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +import io.camunda.connector.document.annotation.jackson.DocumentReferenceModel; +import io.camunda.document.CamundaDocument; +import io.camunda.document.Document; +import io.camunda.document.reference.DocumentReference; +import io.camunda.document.store.CamundaDocumentStore; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.services.bedrockruntime.model.*; + +@ExtendWith(MockitoExtension.class) +class DocumentMapperTest { + + private final DocumentMapper documentMapper = new DocumentMapper(); + @Mock private DocumentReference.CamundaDocumentReference documentReference; + @Mock private CamundaDocumentStore documentStore; + + @Test + void mapToFileBlockShouldCreateDocumentBlock() throws IOException { + String path = "src/test/resources/converse/text-document.json"; + var document = prepareDocument(path); + + var byteInput = new ByteArrayInputStream(new byte[0]); + when(documentStore.getDocumentContent(any())).thenReturn(byteInput); + + String fileName = + ((DocumentReferenceModel.CamundaDocumentMetadataModel) document.metadata()).fileName(); + var expectedDocBlock = + DocumentBlock.builder() + .source( + DocumentSource.builder() + .bytes(SdkBytes.fromByteArray(byteInput.readAllBytes())) + .build()) + .format(DocumentFormat.TXT) + .name(fileName.split("\\.")[0]) + .build(); + + var documentBlockResult = documentMapper.mapToFileBlock(document); + + assertThat(documentBlockResult).isEqualTo(expectedDocBlock); + } + + @Test + void mapToFileShouldCreateImageBlock() throws IOException { + String path = "src/test/resources/converse/image-document.json"; + var document = prepareDocument(path); + + var byteInput = new ByteArrayInputStream(new byte[0]); + when(documentStore.getDocumentContent(any())).thenReturn(byteInput); + + var imageBlockResult = documentMapper.mapToFileBlock(document); + + var expectedImageBlock = + ImageBlock.builder() + .source( + ImageSource.builder() + .bytes(SdkBytes.fromByteArray(byteInput.readAllBytes())) + .build()) + .format(ImageFormat.PNG) + .build(); + + assertThat(imageBlockResult).isEqualTo(expectedImageBlock); + } + + @Test + void mapToFileWithWithUnknownFileTypeShouldThrowException() throws IOException { + String path = "src/test/resources/converse/unsupported-document.json"; + var document = prepareDocument(path); + + var byteInput = new ByteArrayInputStream(new byte[0]); + when(documentStore.getDocumentContent(any())).thenReturn(byteInput); + + var ex = + assertThrows(IllegalArgumentException.class, () -> documentMapper.mapToFileBlock(document)); + + assertThat(ex).hasMessageContaining(UNSUPPORTED_DOC_TYPE_MSG); + } + + private Document prepareDocument(String path) throws IOException { + var docMetadata = readData(path, DocumentReferenceModel.CamundaDocumentMetadataModel.class); + return new CamundaDocument(docMetadata, documentReference, documentStore); + } +} diff --git a/connectors/aws/aws-bedrock/src/test/java/io/camunda/connector/aws/bedrock/mapper/MessageMapperTest.java b/connectors/aws/aws-bedrock/src/test/java/io/camunda/connector/aws/bedrock/mapper/MessageMapperTest.java new file mode 100644 index 0000000000..1d09c14364 --- /dev/null +++ b/connectors/aws/aws-bedrock/src/test/java/io/camunda/connector/aws/bedrock/mapper/MessageMapperTest.java @@ -0,0 +1,68 @@ +/* + * Copyright Camunda Services GmbH and/or licensed to Camunda Services GmbH + * under one or more contributor license agreements. Licensed under a proprietary license. + * See the License.txt file for more information. You may not use this file + * except in compliance with the proprietary license. + */ +package io.camunda.connector.aws.bedrock.mapper; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.camunda.connector.aws.bedrock.model.BedrockMessage; +import java.util.List; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock; +import software.amazon.awssdk.services.bedrockruntime.model.Message; + +class MessageMapperTest { + + private final BedrockContentMapper bedrockContentMapper = + new BedrockContentMapper(new DocumentMapper()); + private final MessageMapper messageMapper = new MessageMapper(bedrockContentMapper); + + @Test + void mapToMessages() { + String msg = "Hello World!"; + String role = "user"; + var bedrockContent = bedrockContentMapper.messageToBedrockContent(msg); + + List messagesResult = + messageMapper.mapToMessages(List.of(new BedrockMessage(role, List.of(bedrockContent)))); + + List messagesExpected = + List.of( + Message.builder() + .content(bedrockContentMapper.mapToContentBlocks(List.of(bedrockContent))) + .role(role) + .build()); + + assertThat(messagesResult).isEqualTo(messagesExpected); + } + + @Test + void mapToMessage() { + String msg = "Hello World!"; + String role = "user"; + var bedrockContent = bedrockContentMapper.messageToBedrockContent(msg); + + var message = messageMapper.mapToMessage(new BedrockMessage(role, List.of(bedrockContent))); + + assertThat(message.role().toString()).isEqualTo(role); + assertThat(message.content()).isEqualTo(List.of(ContentBlock.fromText(msg))); + } + + @Test + void mapToBedrockMessage() { + String msg = "Hello World!"; + String role = "user"; + + Message message = Message.builder().content(ContentBlock.fromText(msg)).role(role).build(); + + var bedrockMessageResult = messageMapper.mapToBedrockMessage(message); + + var bedrockMessageExpected = + new BedrockMessage(role, List.of(bedrockContentMapper.messageToBedrockContent(msg))); + + assertThat(bedrockMessageResult).isIn(bedrockMessageExpected); + } +} diff --git a/connectors/aws/aws-bedrock/src/test/java/io/camunda/connector/aws/bedrock/model/ConverseDataTest.java b/connectors/aws/aws-bedrock/src/test/java/io/camunda/connector/aws/bedrock/model/ConverseDataTest.java index f20b0f54a5..5c4aaa44ca 100644 --- a/connectors/aws/aws-bedrock/src/test/java/io/camunda/connector/aws/bedrock/model/ConverseDataTest.java +++ b/connectors/aws/aws-bedrock/src/test/java/io/camunda/connector/aws/bedrock/model/ConverseDataTest.java @@ -6,6 +6,7 @@ */ package io.camunda.connector.aws.bedrock.model; +import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -14,12 +15,12 @@ import java.util.ArrayList; import java.util.List; import java.util.function.Consumer; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.mockito.Mockito; import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; -import software.amazon.awssdk.services.bedrockruntime.model.ConversationRole; +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock; import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse; +import software.amazon.awssdk.services.bedrockruntime.model.Message; class ConverseDataTest { @@ -33,20 +34,26 @@ void execute_success() { ConverseData converseData = new ConverseData(); converseData.setModelId("random-model-id"); - List previousMessages = new ArrayList<>(); - previousMessages.add(new PreviousMessage("Hey", ConversationRole.USER.name())); - previousMessages.add(new PreviousMessage("How are you?", ConversationRole.ASSISTANT.name())); - converseData.setMessages(previousMessages); - converseData.setNextMessage("I am good thanks, and you?"); + List previousMessages = new ArrayList<>(); + previousMessages.add( + new BedrockMessage("assistant", List.of(new BedrockContent("Hey, How are you?")))); + + converseData.setMessagesHistory(previousMessages); + converseData.setNewMessage("I am good thanks, and you?"); when(bedrockRuntimeClient.converse(any(Consumer.class))).thenReturn(converseResponse); - when(converseResponse.output().message().content().getFirst().text()) - .thenReturn("I am also good"); - BedrockResponse bedrockResponse = converseData.execute(bedrockRuntimeClient, mapper); - - Assertions.assertInstanceOf(ConverseWrapperResponse.class, bedrockResponse); - Assertions.assertEquals( - "I am also good", - ((ConverseWrapperResponse) bedrockResponse).messageHistory().getLast().message()); + Message response = + Message.builder() + .role("assistant") + .content(ContentBlock.fromText("I am also good")) + .build(); + + when(converseResponse.output().message()).thenReturn(response); + + List result = converseData.execute(bedrockRuntimeClient, mapper); + + assertThat(result.size()).isEqualTo(3); + assertThat(result.get(2).getContentList()) + .isEqualTo(List.of(new BedrockContent("I am also good"))); } } diff --git a/connectors/aws/aws-bedrock/src/test/java/io/camunda/connector/aws/bedrock/util/FileUtilTest.java b/connectors/aws/aws-bedrock/src/test/java/io/camunda/connector/aws/bedrock/util/FileUtilTest.java new file mode 100644 index 0000000000..77eb3ebd87 --- /dev/null +++ b/connectors/aws/aws-bedrock/src/test/java/io/camunda/connector/aws/bedrock/util/FileUtilTest.java @@ -0,0 +1,45 @@ +/* + * Copyright Camunda Services GmbH and/or licensed to Camunda Services GmbH + * under one or more contributor license agreements. Licensed under a proprietary license. + * See the License.txt file for more information. You may not use this file + * except in compliance with the proprietary license. + */ +package io.camunda.connector.aws.bedrock.util; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.tika.mime.MimeTypeException; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; + +class FileUtilTest { + + @ParameterizedTest + @CsvSource({"test-file, txt", "test.file, pdf"}) + void defineNameAndType(String fileName, String fileExtension) { + String fullFileName = fileName + "." + fileExtension; + + Pair result = FileUtil.defineNameAndType(fullFileName); + + assertThat(result.getLeft()).isEqualTo(fileName); + assertThat(result.getRight()).isEqualTo(fileExtension); + } + + @Test + void defineNameAndTypeWhenFileWithoutExtensionShouldReturnOnlyName() { + String fileName = "test-file"; + + Pair result = FileUtil.defineNameAndType(fileName); + + assertThat(result.getLeft()).isEqualTo(fileName); + assertThat(result.getRight()).isEmpty(); + } + + @Test + void defineType() throws MimeTypeException { + String result = FileUtil.defineType("text/plain"); + assertThat(result).isEqualTo("txt"); + } +} diff --git a/connectors/aws/aws-bedrock/src/test/resources/converse/converseExample.json b/connectors/aws/aws-bedrock/src/test/resources/converse/converseExample.json index b17a2a2967..c69a58ed77 100644 --- a/connectors/aws/aws-bedrock/src/test/resources/converse/converseExample.json +++ b/connectors/aws/aws-bedrock/src/test/resources/converse/converseExample.json @@ -2,11 +2,11 @@ { "data":{ "modelId":"amazon.titan-text-premier-v1:0", - "nextMessage":"ok", - "messages":[ - {"message":"Hey", "role":"USER"}, {"message":"Hello there! How can I assist you today?", "role":"ASSISTANT"}, - {"message":"You feel good ?", "role":"USER"}, { - "message":"Thank you for asking! As an AI, I don't have personal feelings or experiences, but I'm always here to help with any questions or information you may need. How can I assist you today?", + "newMessage":"ok", + "messagesHistory":[ + {"text":"Hey", "role":"USER"}, {"text":"Hello there! How can I assist you today?", "role":"ASSISTANT"}, + {"text":"You feel good ?", "role":"USER"}, { + "text":"Thank you for asking! As an AI, I don't have personal feelings or experiences, but I'm always here to help with any questions or information you may need. How can I assist you today?", "role":"ASSISTANT" } ] diff --git a/connectors/aws/aws-bedrock/src/test/resources/converse/image-document.json b/connectors/aws/aws-bedrock/src/test/resources/converse/image-document.json new file mode 100644 index 0000000000..ab5815b59c --- /dev/null +++ b/connectors/aws/aws-bedrock/src/test/resources/converse/image-document.json @@ -0,0 +1,6 @@ +{ + "contentType": "image/png", + "fileName": "iphone-image.png", + "size": 49213, + "customProperties": {} +} \ No newline at end of file diff --git a/connectors/aws/aws-bedrock/src/test/resources/converse/text-document.json b/connectors/aws/aws-bedrock/src/test/resources/converse/text-document.json new file mode 100644 index 0000000000..9ba89a96d7 --- /dev/null +++ b/connectors/aws/aws-bedrock/src/test/resources/converse/text-document.json @@ -0,0 +1,6 @@ +{ + "contentType": "text/plain", + "fileName": "test.txt", + "size": 4, + "customProperties": {} +} \ No newline at end of file diff --git a/connectors/aws/aws-bedrock/src/test/resources/converse/unsupported-document.json b/connectors/aws/aws-bedrock/src/test/resources/converse/unsupported-document.json new file mode 100644 index 0000000000..79228ebac4 --- /dev/null +++ b/connectors/aws/aws-bedrock/src/test/resources/converse/unsupported-document.json @@ -0,0 +1,6 @@ +{ + "contentType": "application/trgx", + "fileName": "test", + "size": 4, + "customProperties": {} +} \ No newline at end of file