diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/Llama2Client.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/Llama2Client.java index e1a389199..99ec3bc02 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/Llama2Client.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/Llama2Client.java @@ -34,13 +34,6 @@ public EdgeChain> createChatCompletion( logger.info("==============REQUEST DATA================"); logger.info(request.toString()); - // Llama2ChatCompletionRequest llamaRequest = new - // Llama2ChatCompletionRequest(); - // - // llamaRequest.setInputs(request.getInputs()); - // - // llamaRequest.setParameters(request.getParameters()); - // Create headers HttpHeaders headers = new HttpHeaders(); headers.setContentType(MediaType.APPLICATION_JSON); @@ -51,7 +44,8 @@ public EdgeChain> createChatCompletion( List chatCompletionResponse = objectMapper.readValue( - response, new TypeReference>() {}); + response, new TypeReference<>() { + }); emitter.onNext(chatCompletionResponse); emitter.onComplete(); diff --git a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/request/Llama2ChatCompletionRequest.java b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/request/Llama2ChatCompletionRequest.java index 17f190e93..684ce2cca 100644 --- a/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/request/Llama2ChatCompletionRequest.java +++ b/FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/request/Llama2ChatCompletionRequest.java @@ -18,9 +18,9 @@ public Llama2ChatCompletionRequest(String inputs, JSONObject parameters) { @Override public String toString() { - return new StringJoiner(", ", Llama2ChatCompletionRequest.class.getSimpleName() + "[{", "}]") - .add("\"inputs:\"" + inputs) - .add("\"parameters:\"" + parameters) + return new StringJoiner(", ", Llama2ChatCompletionRequest.class.getSimpleName() + "{", "}") + .add("\"inputs\":" + inputs) + .add("\"parameters\":" + parameters) .toString(); } diff --git a/FlySpring/edgechain-app/src/test/java/com/edgechain/llama/Llama2ClientTest.java b/FlySpring/edgechain-app/src/test/java/com/edgechain/llama/Llama2ClientTest.java new file mode 100644 index 000000000..792279ab2 --- /dev/null +++ b/FlySpring/edgechain-app/src/test/java/com/edgechain/llama/Llama2ClientTest.java @@ -0,0 +1,53 @@ +package com.edgechain.llama; + +import com.edgechain.lib.llama2.request.Llama2ChatCompletionRequest; +import org.json.JSONObject; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.boot.test.context.SpringBootTest; + +import java.util.List; + +@SpringBootTest +public class Llama2ClientTest { + + Logger logger = LoggerFactory.getLogger(getClass()); + + @Test + @DisplayName("Test LLamaClient - Request Format") + void TestLLamaClient_LLamaRequest_ShouldMatchFormat() { + + Llama2ChatCompletionRequest llama2ChatCompletionRequest = Llama2ChatCompletionRequest.builder() + .inputs("[INST]<>What is the color of sky?<>") + .parameters(getJsonObject()) + .build(); + + logger.info("llama completion request data {} ", llama2ChatCompletionRequest); + + + String result = String.valueOf(new JSONObject(llama2ChatCompletionRequest)); + String expected = getRequestBody(); + + Assertions.assertEquals(expected, result); + } + private static JSONObject getJsonObject() { + JSONObject parameters = new JSONObject(); + parameters.put("do_sample", true); + parameters.put("top_p", 50); + parameters.put("temperature", 0.7); + parameters.put("top_k", 2); + parameters.put("max_new_tokens", 512); + parameters.put("repetition_penalty", 0.6); + parameters.put("stop", List.of("")); + return parameters; + } + + private static String getRequestBody(){ + return "{\"inputs\":\"[INST]<>What is the color of sky?<<\\/SYS>>\"," + + "\"parameters\":{\"top_p\":50,\"stop\":[\"<\\/s>\"],\"max_new_tokens\":512," + + "\"top_k\":2,\"temperature\":0.7,\"do_sample\":true,\"repetition_penalty\":0.6}}"; + } +} \ No newline at end of file