Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix/fix chat bugs #77

Merged
merged 7 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ LETO_ADMIN_URL=http://localhost:9000/
LIBRARY_HOST_WHITELIST=https://github.com/ditrit/
CSRF_TOKEN_TIMEOUT=3600
USER_SESSION_TIMEOUT=3600
AI_HOST=http://locahost:8585/api/
AI_HOST=http://locahost:8585/
```

See Configuration section for more details.
Expand Down
14 changes: 7 additions & 7 deletions build.gradle
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
plugins {
id 'java'
id 'org.springframework.boot' version '3.3.3'
id 'org.springframework.boot' version '3.3.4'
id 'io.spring.dependency-management' version '1.1.6'
id 'checkstyle'
id 'com.github.ben-manes.versions' version '0.51.0'
Expand Down Expand Up @@ -45,18 +45,18 @@ dependencies {
implementation 'org.springframework.boot:spring-boot-starter-security'
implementation 'org.springframework.boot:spring-boot-starter-oauth2-client'
implementation 'org.springframework.session:spring-session-jdbc:3.3.2'
implementation 'org.flywaydb:flyway-core:10.17.3'
implementation "org.flywaydb:flyway-database-postgresql:10.17.3"
implementation 'org.flywaydb:flyway-core:10.18.1'
implementation "org.flywaydb:flyway-database-postgresql:10.18.1"
implementation 'commons-lang:commons-lang:2.6'
implementation 'commons-beanutils:commons-beanutils:1.9.4'
implementation 'com.github.erosb:json-sKema:0.16.0'
implementation 'com.github.erosb:json-sKema:0.18.0'
compileOnly 'org.projectlombok:lombok'
runtimeOnly 'org.postgresql:postgresql:42.7.4'
annotationProcessor 'org.projectlombok:lombok'
testImplementation 'org.springframework.boot:spring-boot-starter-test'
testImplementation 'io.cucumber:cucumber-java:7.18.1'
testImplementation 'io.cucumber:cucumber-junit:7.18.1'
testImplementation 'org.junit.vintage:junit-vintage-engine:5.11.0'
testImplementation 'io.cucumber:cucumber-java:7.19.0'
testImplementation 'io.cucumber:cucumber-junit:7.19.0'
testImplementation 'org.junit.vintage:junit-vintage-engine:5.11.1'
}

tasks.named('test') {
Expand Down
10 changes: 10 additions & 0 deletions src/main/java/com/ditrit/letomodelizerapi/config/Constants.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,16 @@ public final class Constants {
*/
public static final String DEFAULT_UPDATE_DATE_PROPERTY = "updateDate";

/**
* The constant representing the default message property.
*/
public static final String DEFAULT_MESSAGE_PROPERTY = "message";

/**
* The constant representing the default plugin name property.
*/
public static final String DEFAULT_PLUGIN_NAME_PROPERTY = "pluginName";

/**
* Private constructor.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import com.ditrit.letomodelizerapi.model.ai.AIConversationRecord;
import com.ditrit.letomodelizerapi.model.ai.AICreateFileRecord;
import com.ditrit.letomodelizerapi.model.ai.AIMessageDTO;
import com.ditrit.letomodelizerapi.model.ai.AIMessageRecord;
import com.ditrit.letomodelizerapi.model.mapper.ai.AIMessageToDTOMapper;
import com.ditrit.letomodelizerapi.model.permission.ActionPermission;
import com.ditrit.letomodelizerapi.model.permission.EntityPermission;
Expand Down Expand Up @@ -237,22 +238,22 @@ public Response deleteConversationById(final @Context HttpServletRequest request
*
* @param request the HttpServletRequest used to access the user's session.
* @param id the ID of the AI conversation to which the message is sent.
* @param message the content of the message to send to the AI.
* @param aiMessage the record that contains the message to send to the AI.
* @return a Response object containing the AI's reply in plain text with a status of CREATED (201).
* @throws JsonProcessingException if there is an error processing the request data.
*/
@POST
@Path("/conversations/{id}/messages")
public Response createConversationMessage(final @Context HttpServletRequest request,
final @PathParam("id") @Valid @NotNull UUID id,
final @Valid @NotNull String message) throws IOException {
final @Valid AIMessageRecord aiMessage) throws IOException {
HttpSession session = request.getSession();
User user = userService.getFromSession(session);

log.info("[{}] Received POST request to send message to conversation id {}",
user.getLogin(), id.toString());

AIMessageDTO aiMessageDTO = new AIMessageToDTOMapper().apply(aiService.sendMessage(user, id, message));
AIMessageDTO aiMessageDTO = new AIMessageToDTOMapper().apply(aiService.sendMessage(user, id, aiMessage));

return Response.status(HttpStatus.CREATED.value()).entity(aiMessageDTO).build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
*/
public record AIConversationRecord(
@NotBlank String project,
@NotBlank String diagram,
@NotNull String diagram,
@NotBlank String plugin,
String checksum,
@NotNull List<FileRecord> files
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package com.ditrit.letomodelizerapi.model.ai;

import jakarta.validation.constraints.NotBlank;

/**
* A record representing the data required to create a new conversation with AI.
*
* @param message Message to send to AI. Cannot be blank.
* @param plugin The name of the plugin involved in the conversation. Cannot be blank.
*/
public record AIMessageRecord(
@NotBlank String message,
@NotBlank String plugin
) {
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.ditrit.letomodelizerapi.model.ai.AIConversationRecord;
import com.ditrit.letomodelizerapi.model.ai.AICreateFileRecord;
import com.ditrit.letomodelizerapi.model.ai.AIMessageRecord;
import com.ditrit.letomodelizerapi.persistence.model.AIConversation;
import com.ditrit.letomodelizerapi.persistence.model.AIMessage;
import com.ditrit.letomodelizerapi.persistence.model.User;
Expand Down Expand Up @@ -80,11 +81,11 @@ AIConversation updateConversationById(User user, UUID id, AIConversationRecord a
*
* @param user the user sending the message.
* @param id the ID of the conversation to which the message is sent.
* @param message the content of the message to be sent.
* @param aiMessage the record that contains the message to be sent.
* @return the AI's response to the message.
* @throws JsonProcessingException if there is an error processing the JSON data.
*/
AIMessage sendMessage(User user, UUID id, String message) throws IOException;
AIMessage sendMessage(User user, UUID id, AIMessageRecord aiMessage) throws IOException;

/**
* Finds all AI conversations for a user with optional filtering and pagination.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.ditrit.letomodelizerapi.config.Constants;
import com.ditrit.letomodelizerapi.model.ai.AIConversationRecord;
import com.ditrit.letomodelizerapi.model.ai.AICreateFileRecord;
import com.ditrit.letomodelizerapi.model.ai.AIMessageRecord;
import com.ditrit.letomodelizerapi.model.error.ApiException;
import com.ditrit.letomodelizerapi.model.error.ErrorType;
import com.ditrit.letomodelizerapi.model.file.FileRecord;
Expand Down Expand Up @@ -109,7 +110,6 @@ public String sendRequest(final String endpoint, final String body) {
.uri(uri)
.header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON)
.headers(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON)
.version(HttpClient.Version.HTTP_1_1)
.POST(HttpRequest.BodyPublishers.ofString(body))
.build();

Expand Down Expand Up @@ -158,11 +158,10 @@ public String sendFiles(final AIConversation conversation, final List<FileRecord
});

ObjectNode json = JsonNodeFactory.instance.objectNode();
json.put("pluginName", conversation.getKey().split("/")[2]);
json.put(Constants.DEFAULT_PLUGIN_NAME_PROPERTY, conversation.getKey().split("/")[2]);
json.set("files", arrayNode);

JsonNode response = mapper.readTree(sendRequest("chat", json.toString()));

JsonNode response = mapper.readTree(sendRequest(Constants.DEFAULT_MESSAGE_PROPERTY, json.toString()));
return response.get(Constants.DEFAULT_CONTEXT_PROPERTY).asText();
}

Expand All @@ -185,7 +184,7 @@ public byte[] compress(final String message) throws IOException {
@Override
public String createFile(final AICreateFileRecord createFileRecord) {
ObjectNode json = JsonNodeFactory.instance.objectNode();
json.put("pluginName", createFileRecord.plugin());
json.put(Constants.DEFAULT_PLUGIN_NAME_PROPERTY, createFileRecord.plugin());
json.put("description", createFileRecord.description());

return sendRequest(createFileRecord.type(), json.toString());
Expand Down Expand Up @@ -252,12 +251,12 @@ public void deleteConversationById(final User user, final UUID id) {
}

@Override
public AIMessage sendMessage(final User user, final UUID id, final String message) throws IOException {
public AIMessage sendMessage(final User user, final UUID id, final AIMessageRecord aiMessage) throws IOException {
AIConversation conversation = aiConversationRepository.findByIdAndUserId(id, user.getId())
.orElseThrow(() -> new ApiException(ErrorType.ENTITY_NOT_FOUND, "id", id.toString()));

byte[] compressedMessage = compress(message);
Long size = conversation.getSize() + compressedMessage.length;
byte[] compressedMessage = compress(aiMessage.message());
long size = conversation.getSize() + compressedMessage.length;

AIMessage userMessage = new AIMessage();
userMessage.setAiConversation(conversation.getId());
Expand All @@ -267,24 +266,26 @@ public AIMessage sendMessage(final User user, final UUID id, final String messag

ObjectNode json = JsonNodeFactory.instance.objectNode();
json.put(Constants.DEFAULT_CONTEXT_PROPERTY, conversation.getContext());
json.put("message", message);
json.put(Constants.DEFAULT_MESSAGE_PROPERTY, aiMessage.message());
json.put(Constants.DEFAULT_PLUGIN_NAME_PROPERTY, aiMessage.plugin());

JsonNode response = new ObjectMapper().readTree(sendRequest("chat", json.toString()));
JsonNode response = new ObjectMapper()
.readTree(sendRequest(Constants.DEFAULT_MESSAGE_PROPERTY, json.toString()));

compressedMessage = compress(response.get("message").asText());
compressedMessage = compress(response.get(Constants.DEFAULT_MESSAGE_PROPERTY).asText());
size += compressedMessage.length;

AIMessage aiMessage = new AIMessage();
aiMessage.setAiConversation(conversation.getId());
aiMessage.setIsUser(false);
aiMessage.setMessage(compressedMessage);
aiMessage = aiMessageRepository.save((aiMessage));
AIMessage aiMessageResponse = new AIMessage();
aiMessageResponse.setAiConversation(conversation.getId());
aiMessageResponse.setIsUser(false);
aiMessageResponse.setMessage(compressedMessage);
aiMessageResponse = aiMessageRepository.save((aiMessageResponse));

conversation.setContext(response.get(Constants.DEFAULT_CONTEXT_PROPERTY).asText());
conversation.setSize(size);
aiConversationRepository.save(conversation);

return aiMessage;
return aiMessageResponse;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -353,15 +353,15 @@ public Page<Library> findAll(final User user, final Map<String, String> immutabl

return userLibraryViewRepository.findAll(
new SpecificationHelper<>(UserLibraryView.class, filters),
PageRequest.of(pageable.getPageNumber(), pageable.getPageSize())
PageRequest.of(pageable.getPageNumber(), pageable.getPageSize(), pageable.getSort())
).map(new UserLibraryViewToLibraryFunction());
}

@Override
public Page<LibraryTemplate> findAllTemplates(final Map<String, String> filters, final Pageable pageable) {
return libraryTemplateRepository.findAll(
new SpecificationHelper<>(LibraryTemplate.class, filters),
PageRequest.of(pageable.getPageNumber(), pageable.getPageSize())
PageRequest.of(pageable.getPageNumber(), pageable.getPageSize(), pageable.getSort())
);
}

Expand All @@ -372,7 +372,7 @@ public Page<LibraryTemplate> findAllTemplates(final User user,
return userLibraryTemplateViewRepository.findAllByUserId(
user.getId(),
new SpecificationHelper<>(UserLibraryTemplateView.class, filters),
PageRequest.of(pageable.getPageNumber(), pageable.getPageSize())
PageRequest.of(pageable.getPageNumber(), pageable.getPageSize(), pageable.getSort())
).map(new UserLibraryTemplateViewToLibraryTemplateFunction());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.ditrit.letomodelizerapi.helper.MockHelper;
import com.ditrit.letomodelizerapi.model.ai.AIConversationRecord;
import com.ditrit.letomodelizerapi.model.ai.AICreateFileRecord;
import com.ditrit.letomodelizerapi.model.ai.AIMessageRecord;
import com.ditrit.letomodelizerapi.persistence.model.AIConversation;
import com.ditrit.letomodelizerapi.persistence.model.AIMessage;
import com.ditrit.letomodelizerapi.persistence.model.User;
Expand Down Expand Up @@ -217,7 +218,7 @@ void testCreateConversationMessage() throws IOException {
Mockito.when(userService.getFromSession(Mockito.any())).thenReturn(user);
Mockito.when(aiService.sendMessage(Mockito.any(), Mockito.any(), Mockito.any())).thenReturn(message);

Response response = this.controller.createConversationMessage(request, UUID.randomUUID(), "ok");
Response response = this.controller.createConversationMessage(request, UUID.randomUUID(), new AIMessageRecord("ok", "plugin"));

assertNotNull(response);
assertEquals(HttpStatus.CREATED.value(), response.getStatus());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.ditrit.letomodelizerapi.model.ai.AIConversationRecord;
import com.ditrit.letomodelizerapi.model.ai.AICreateFileRecord;
import com.ditrit.letomodelizerapi.model.ai.AIMessageRecord;
import com.ditrit.letomodelizerapi.model.error.ApiException;
import com.ditrit.letomodelizerapi.model.error.ErrorType;
import com.ditrit.letomodelizerapi.persistence.model.AIConversation;
Expand Down Expand Up @@ -365,7 +366,7 @@ void testSendMessage() throws IOException, InterruptedException {
User user = new User();
user.setId(UUID.randomUUID());

AIMessage message = service.sendMessage(user, UUID.randomUUID(), "ok");
AIMessage message = service.sendMessage(user, UUID.randomUUID(), new AIMessageRecord("ok", "plugin"));

assertEquals(message, expectedMessage);

Expand All @@ -384,7 +385,7 @@ void testSendMessageThrowException() throws IOException {
ApiException exception = null;

try {
service.sendMessage(user, UUID.randomUUID(), "test");
service.sendMessage(user, UUID.randomUUID(), new AIMessageRecord("test", "plugin"));
} catch (ApiException e) {
exception = e;
}
Expand Down
2 changes: 1 addition & 1 deletion src/test/resources/ai/index.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
error_log($type);
error_log(file_get_contents('php://input'));

if ($type == "chat") {
if ($type == "message") {
$contextValue = isset($requestBody["context"]) ? (int) $requestBody["context"] : 0;
$data = [
"context" =>1 + $contextValue,
Expand Down
7 changes: 4 additions & 3 deletions src/test/resources/features/AI.feature
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,10 @@ Feature: ai feature
And I extract resources from response
And I expect one resource contains "id" equals to "[conversation_id]"

When I request "/ai/conversations/[conversation_id]/messages" with method "POST" with body
| body | type |
| test | String |
When I request "/ai/conversations/[conversation_id]/messages" with method "POST" with json
| key | value | type |
| message | TEST | string |
| plugin | @ditrit/terrator-plugin | string |
Then I expect "201" as status code
And I expect response fields length is "5"
And I expect response field "id" is "NOT_NULL"
Expand Down
4 changes: 2 additions & 2 deletions src/test/resources/features/Library.feature
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ Feature: Library feature

# Check if library templates are created
When I request "/libraries/templates" with method "GET"
Then I expect "200" as status code
And I expect response field "totalElements" is "3" as "number"
Then I expect "206" as status code
And I expect response field "totalElements" is "33" as "number"

# Delete library
When I request "/libraries/[libraryId]" with method "DELETE"
Expand Down
Loading
Loading