Skip to content

Commit

Permalink
Small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
InAnYan committed Aug 13, 2024
1 parent f3ffc54 commit 85382a1
Show file tree
Hide file tree
Showing 9 changed files with 33 additions and 26 deletions.
2 changes: 1 addition & 1 deletion src/main/java/org/jabref/gui/JabRefGUI.java
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ public void initialize() {
JabRefGUI.clipBoardManager = new ClipBoardManager();
Injector.setModelOrService(ClipBoardManager.class, clipBoardManager);

JabRefGUI.aiService = new AiService(preferencesService.getAiPreferences(), dialogService, taskExecutor);
JabRefGUI.aiService = new AiService(preferencesService, dialogService, taskExecutor);
Injector.setModelOrService(AiService.class, aiService);
}

Expand Down
4 changes: 3 additions & 1 deletion src/main/java/org/jabref/gui/entryeditor/AiChatTab.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ public class AiChatTab extends EntryEditorTab {
private final BibDatabaseContext bibDatabaseContext;
private final TaskExecutor taskExecutor;
private final CitationKeyGenerator citationKeyGenerator;
private final PreferencesService preferencesService;
private final AiService aiService;

private final List<BibEntry> entriesUnderIngestion = new ArrayList<>();
Expand All @@ -64,6 +65,7 @@ public AiChatTab(LibraryTabContainer libraryTabContainer,
this.bibDatabaseContext = bibDatabaseContext;
this.taskExecutor = taskExecutor;
this.citationKeyGenerator = new CitationKeyGenerator(bibDatabaseContext, preferencesService.getCitationKeyPatternPreferences());
this.preferencesService = preferencesService;

setText(Localization.lang("AI chat"));
setTooltip(new Tooltip(Localization.lang("Chat with AI about content of attached file(s)")));
Expand All @@ -88,7 +90,7 @@ protected void handleFocus() {
protected void bindToEntry(BibEntry entry) {
if (!aiService.getPreferences().getEnableAi()) {
showPrivacyNotice(entry);
} else if (aiService.getPreferences().getSelectedApiKey().isEmpty()) {
} else if (aiService.getPreferences().getSelectedApiKey(preferencesService).isEmpty()) {
showApiKeyMissing();
} else if (entry.getFiles().isEmpty()) {
showErrorNoFiles();
Expand Down
4 changes: 3 additions & 1 deletion src/main/java/org/jabref/gui/entryeditor/AiSummaryTab.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ public class AiSummaryTab extends EntryEditorTab {
private final TaskExecutor taskExecutor;
private final CitationKeyGenerator citationKeyGenerator;
private final AiService aiService;
private final PreferencesService preferencesService;

private final List<BibEntry> entriesUnderSummarization = new ArrayList<>();

Expand All @@ -59,6 +60,7 @@ public AiSummaryTab(LibraryTabContainer libraryTabContainer,
this.bibDatabaseContext = bibDatabaseContext;
this.taskExecutor = taskExecutor;
this.citationKeyGenerator = new CitationKeyGenerator(bibDatabaseContext, preferencesService.getCitationKeyPatternPreferences());
this.preferencesService = preferencesService;

setText(Localization.lang("AI summary"));
setTooltip(new Tooltip(Localization.lang("AI-generated summary of attached file(s)")));
Expand All @@ -82,7 +84,7 @@ protected void handleFocus() {
protected void bindToEntry(BibEntry entry) {
if (!aiService.getPreferences().getEnableAi()) {
showPrivacyNotice(entry);
} else if (aiService.getPreferences().getSelectedApiKey().isEmpty()) {
} else if (aiService.getPreferences().getSelectedApiKey(preferencesService).isEmpty()) {
showApiKeyMissing();
} else if (bibDatabaseContext.getDatabasePath().isEmpty()) {
showErrorNoDatabasePath();
Expand Down
7 changes: 4 additions & 3 deletions src/main/java/org/jabref/logic/ai/AiService.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.jabref.logic.ai.models.JabRefEmbeddingModel;
import org.jabref.logic.ai.summarization.SummariesStorage;
import org.jabref.logic.l10n.Localization;
import org.jabref.preferences.PreferencesService;
import org.jabref.preferences.ai.AiPreferences;

import com.google.common.util.concurrent.ThreadFactoryBuilder;
Expand Down Expand Up @@ -56,8 +57,8 @@ public class AiService implements AutoCloseable {

private final SummariesStorage summariesStorage;

public AiService(AiPreferences aiPreferences, DialogService dialogService, TaskExecutor taskExecutor) {
this.aiPreferences = aiPreferences;
public AiService(PreferencesService preferencesService, DialogService dialogService, TaskExecutor taskExecutor) {
this.aiPreferences = preferencesService.getAiPreferences();

MVStore mvStore;
try {
Expand All @@ -74,7 +75,7 @@ public AiService(AiPreferences aiPreferences, DialogService dialogService, TaskE

this.mvStore = mvStore;

this.jabRefChatLanguageModel = new JabRefChatLanguageModel(aiPreferences);
this.jabRefChatLanguageModel = new JabRefChatLanguageModel(preferencesService);
this.bibDatabaseChatHistoryManager = new BibDatabaseChatHistoryManager(mvStore);
this.jabRefEmbeddingModel = new JabRefEmbeddingModel(aiPreferences, dialogService, taskExecutor);
this.fileEmbeddingsManager = new FileEmbeddingsManager(aiPreferences, shutdownSignal, jabRefEmbeddingModel, mvStore);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.jabref.logic.ai.AiChatLogic;
import org.jabref.logic.l10n.Localization;
import org.jabref.preferences.PreferencesService;
import org.jabref.preferences.ai.AiPreferences;

import com.google.common.util.concurrent.ThreadFactoryBuilder;
Expand All @@ -29,6 +30,7 @@
public class JabRefChatLanguageModel implements ChatLanguageModel, AutoCloseable {
private static final Duration CONNECTION_TIMEOUT = Duration.ofSeconds(5);

private final PreferencesService preferencesService;
private final AiPreferences aiPreferences;

private final HttpClient httpClient;
Expand All @@ -38,8 +40,9 @@ public class JabRefChatLanguageModel implements ChatLanguageModel, AutoCloseable

private Optional<ChatLanguageModel> langchainChatModel = Optional.empty();

public JabRefChatLanguageModel(AiPreferences aiPreferences) {
this.aiPreferences = aiPreferences;
public JabRefChatLanguageModel(PreferencesService preferencesService) {
this.preferencesService = preferencesService;
this.aiPreferences = preferencesService.getAiPreferences();
this.httpClient = HttpClient.newBuilder().connectTimeout(CONNECTION_TIMEOUT).executor(executorService).build();

if (aiPreferences.getEnableAi()) {
Expand All @@ -56,20 +59,20 @@ public JabRefChatLanguageModel(AiPreferences aiPreferences) {
* and using {@link org.jabref.logic.ai.chathistory.BibDatabaseChatHistoryManager}, where messages are stored in {@link MVStore}.
*/
private void rebuild() {
if (!aiPreferences.getEnableAi() || aiPreferences.getSelectedApiKey().isEmpty()) {
if (!aiPreferences.getEnableAi() || aiPreferences.getSelectedApiKey(preferencesService).isEmpty()) {
langchainChatModel = Optional.empty();
return;
}

switch (aiPreferences.getAiProvider()) {
case OPEN_AI -> {
langchainChatModel = Optional.of(new JvmOpenAiChatLanguageModel(aiPreferences, httpClient));
langchainChatModel = Optional.of(new JvmOpenAiChatLanguageModel(preferencesService, httpClient));
}

case MISTRAL_AI -> {
langchainChatModel = Optional.of(MistralAiChatModel
.builder()
.apiKey(aiPreferences.getSelectedApiKey())
.apiKey(aiPreferences.getSelectedApiKey(preferencesService))
.modelName(aiPreferences.getSelectedChatModel())
.temperature(aiPreferences.getTemperature())
.baseUrl(aiPreferences.getSelectedApiBaseUrl())
Expand All @@ -83,7 +86,7 @@ private void rebuild() {
// NOTE: {@link HuggingFaceChatModel} doesn't support API base url :(
langchainChatModel = Optional.of(HuggingFaceChatModel
.builder()
.accessToken(aiPreferences.getSelectedApiKey())
.accessToken(aiPreferences.getSelectedApiKey(preferencesService))
.modelId(aiPreferences.getSelectedChatModel())
.temperature(aiPreferences.getTemperature())
.timeout(Duration.ofMinutes(2))
Expand Down Expand Up @@ -116,7 +119,7 @@ public Response<AiMessage> generate(List<ChatMessage> list) {
if (langchainChatModel.isEmpty()) {
if (!aiPreferences.getEnableAi()) {
throw new RuntimeException(Localization.lang("In order to use AI chat, you need to enable chatting with attached PDF files in JabRef preferences (AI tab)."));
} else if (aiPreferences.getSelectedApiKey().isEmpty()) {
} else if (aiPreferences.getSelectedApiKey(preferencesService).isEmpty()) {
throw new RuntimeException(Localization.lang("In order to use AI chat, set OpenAI API key inside JabRef preferences (AI tab)."));
} else {
throw new RuntimeException(Localization.lang("Unable to chat with AI."));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ public void startRebuildingTask() {
return;
}

if (predictorProperty.get().isPresent()) {
predictorProperty.get().get().close();
}

predictorProperty.set(Optional.empty());

new UpdateEmbeddingModelTask(aiPreferences, predictorProperty)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import java.net.http.HttpClient;
import java.util.List;

import org.jabref.preferences.PreferencesService;
import org.jabref.preferences.ai.AiPreferences;

import dev.langchain4j.data.message.AiMessage;
Expand All @@ -29,11 +30,11 @@ public class JvmOpenAiChatLanguageModel implements ChatLanguageModel {

private final ChatClient chatClient;

public JvmOpenAiChatLanguageModel(AiPreferences aiPreferences, HttpClient httpClient) {
this.aiPreferences = aiPreferences;
public JvmOpenAiChatLanguageModel(PreferencesService preferencesService, HttpClient httpClient) {
this.aiPreferences = preferencesService.getAiPreferences();

OpenAI openAI = OpenAI
.newBuilder(aiPreferences.getSelectedApiKey())
.newBuilder(aiPreferences.getSelectedApiKey(preferencesService))
.httpClient(httpClient)
.baseUrl(aiPreferences.getSelectedApiBaseUrl())
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2790,7 +2790,6 @@ public AiPreferences getAiPreferences() {
boolean aiEnabled = getBoolean(AI_ENABLED);

aiPreferences = new AiPreferences(
this,
aiEnabled,
AiProvider.valueOf(get(AI_PROVIDER)),
get(AI_OPEN_AI_CHAT_MODEL),
Expand Down
13 changes: 4 additions & 9 deletions src/main/java/org/jabref/preferences/ai/AiPreferences.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
import org.jabref.preferences.PreferencesService;

public class AiPreferences {
private final PreferencesService preferencesService;

private final BooleanProperty enableAi;

private final ObjectProperty<AiProvider> aiProvider;
Expand Down Expand Up @@ -46,8 +44,7 @@ public class AiPreferences {
private final IntegerProperty ragMaxResultsCount;
private final DoubleProperty ragMinScore;

public AiPreferences(PreferencesService preferencesService,
boolean enableAi,
public AiPreferences(boolean enableAi,
AiProvider aiProvider,
String openAiChatModel,
String mistralAiChatModel,
Expand All @@ -65,8 +62,6 @@ public AiPreferences(PreferencesService preferencesService,
int ragMaxResultsCount,
double ragMinScore
) {
this.preferencesService = preferencesService;

this.enableAi = new SimpleBooleanProperty(enableAi);

this.aiProvider = new SimpleObjectProperty<>(aiProvider);
Expand Down Expand Up @@ -467,16 +462,16 @@ public String getSelectedChatModel() {
};
}

public String getSelectedApiKey() {
public String getSelectedApiKey(PreferencesService preferencesService) {
if (!enableAi.get()) {
return "";
}

retrieveKeys();
retrieveKeys(preferencesService);
return getKeys();
}

private void retrieveKeys() {
private void retrieveKeys(PreferencesService preferencesService) {
switch (aiProvider.get()) {
case OPEN_AI -> {
if (openAiApiKey.get().isEmpty()) {
Expand Down

0 comments on commit 85382a1

Please sign in to comment.