Skip to content

Commit

Permalink
🪨 fix: Minor AWS Bedrock/Misc. Improvements (#3974)
Browse files Browse the repository at this point in the history
* refactor(EditMessage): avoid manipulation of native paste handling, leverage react-hook-form for textarea changes

* style: apply better theming for MinimalIcon

* fix(useVoicesQuery/useCustomConfigSpeechQuery): make sure to only try request once per render

* feat: edit message content parts

* fix(useCopyToClipboard): handle both assistants and agents content blocks

* refactor: remove save & submit and update text content correctly

* chore(.env.example/config): exclude unsupported bedrock models

* feat: artifacts for aws bedrock

* fix: export options for bedrock conversations
  • Loading branch information
danny-avila committed Sep 10, 2024
1 parent 341e086 commit 1a1e685
Show file tree
Hide file tree
Showing 23 changed files with 441 additions and 203 deletions.
7 changes: 5 additions & 2 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,11 @@ BINGAI_TOKEN=user_provided
# See all Bedrock model IDs here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html#model-ids-arns

# Notes on specific models:
# 'ai21.j2-mid-v1', # Not supported, as it doesn't support streaming
# 'ai21.j2-ultra-v1', # Not supported, as it doesn't support conversation history
# The following models are not support due to not supporting streaming:
# ai21.j2-mid-v1

# The following models are not support due to not supporting conversation history:
# ai21.j2-ultra-v1, cohere.command-text-v14, cohere.command-light-text-v14

#============#
# Google #
Expand Down
50 changes: 45 additions & 5 deletions api/server/routes/messages.js
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
const express = require('express');
const { ContentTypes } = require('librechat-data-provider');
const { saveConvo, saveMessage, getMessages, updateMessage, deleteMessages } = require('~/models');
const { requireJwtAuth, validateMessageReq } = require('~/server/middleware');
const { countTokens } = require('~/server/utils');
Expand Down Expand Up @@ -54,11 +55,50 @@ router.get('/:conversationId/:messageId', validateMessageReq, async (req, res) =

router.put('/:conversationId/:messageId', validateMessageReq, async (req, res) => {
try {
const { messageId, model } = req.params;
const { text } = req.body;
const tokenCount = await countTokens(text, model);
const result = await updateMessage(req, { messageId, text, tokenCount });
res.status(200).json(result);
const { conversationId, messageId } = req.params;
const { text, index, model } = req.body;

if (index === undefined) {
const tokenCount = await countTokens(text, model);
const result = await updateMessage(req, { messageId, text, tokenCount });
return res.status(200).json(result);
}

if (typeof index !== 'number' || index < 0) {
return res.status(400).json({ error: 'Invalid index' });
}

const message = (await getMessages({ conversationId, messageId }, 'content tokenCount'))?.[0];
if (!message) {
return res.status(404).json({ error: 'Message not found' });
}

const existingContent = message.content;
if (!Array.isArray(existingContent) || index >= existingContent.length) {
return res.status(400).json({ error: 'Invalid index' });
}

const updatedContent = [...existingContent];
if (!updatedContent[index]) {
return res.status(400).json({ error: 'Content part not found' });
}

if (updatedContent[index].type !== ContentTypes.TEXT) {
return res.status(400).json({ error: 'Cannot update non-text content' });
}

const oldText = updatedContent[index].text;
updatedContent[index] = { type: ContentTypes.TEXT, text };

let tokenCount = message.tokenCount;
if (tokenCount !== undefined) {
const oldTokenCount = await countTokens(oldText, model);
const newTokenCount = await countTokens(text, model);
tokenCount = Math.max(0, tokenCount - oldTokenCount) + newTokenCount;
}

const result = await updateMessage(req, { messageId, content: updatedContent, tokenCount });
return res.status(200).json(result);
} catch (error) {
logger.error('Error updating message:', error);
res.status(500).json({ error: 'Internal server error' });
Expand Down
4 changes: 4 additions & 0 deletions api/server/services/Endpoints/bedrock/initialize.js
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ const initializeClient = async ({ req, res, endpointOption }) => {
model_parameters: endpointOption.model_parameters,
};

if (typeof endpointOption.artifactsPrompt === 'string' && endpointOption.artifactsPrompt) {
agent.instructions = `${agent.instructions ?? ''}\n${endpointOption.artifactsPrompt}`.trim();
}

let modelOptions = { model: agent.model };

// TODO: pass-in override settings that are specific to current run
Expand Down
8 changes: 6 additions & 2 deletions client/src/common/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,12 @@ export type TAdditionalProps = {
export type TMessageContentProps = TInitialProps & TAdditionalProps;

export type TText = Pick<TInitialProps, 'text'> & { className?: string };
export type TEditProps = Pick<TInitialProps, 'text' | 'isSubmitting'> &
Omit<TAdditionalProps, 'isCreatedByUser'>;
export type TEditProps = Pick<TInitialProps, 'isSubmitting'> &
Omit<TAdditionalProps, 'isCreatedByUser' | 'siblingIdx'> & {
text?: string;
index?: number;
siblingIdx: number | null;
};
export type TDisplayProps = TText &
Pick<TAdditionalProps, 'isCreatedByUser' | 'message'> & {
showCursor?: boolean;
Expand Down
45 changes: 44 additions & 1 deletion client/src/components/Chat/Messages/Content/ContentParts.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import { memo } from 'react';
import { ContentTypes } from 'librechat-data-provider';
import type { TMessageContentParts } from 'librechat-data-provider';
import EditTextPart from './Parts/EditTextPart';
import Part from './Part';

type ContentPartsProps = {
Expand All @@ -8,13 +10,54 @@ type ContentPartsProps = {
isCreatedByUser: boolean;
isLast: boolean;
isSubmitting: boolean;
edit?: boolean;
enterEdit?: (cancel?: boolean) => void | null | undefined;
siblingIdx?: number;
setSiblingIdx?:
| ((value: number) => void | React.Dispatch<React.SetStateAction<number>>)
| null
| undefined;
};

const ContentParts = memo(
({ content, messageId, isCreatedByUser, isLast, isSubmitting }: ContentPartsProps) => {
({
content,
messageId,
isCreatedByUser,
isLast,
isSubmitting,
edit,
enterEdit,
siblingIdx,
setSiblingIdx,
}: ContentPartsProps) => {
if (!content) {
return null;
}
if (edit === true && enterEdit && setSiblingIdx) {
return (
<>
{content.map((part, idx) => {
if (part?.type !== ContentTypes.TEXT || typeof part.text !== 'string') {
return null;
}

return (
<EditTextPart
index={idx}
text={part.text}
messageId={messageId}
isSubmitting={isSubmitting}
enterEdit={enterEdit}
siblingIdx={siblingIdx ?? null}
setSiblingIdx={setSiblingIdx}
key={`edit-${messageId}-${idx}`}
/>
);
})}
</>
);
}
return (
<>
{content
Expand Down
83 changes: 41 additions & 42 deletions client/src/components/Chat/Messages/Content/EditMessage.tsx
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import { useRecoilState } from 'recoil';
import TextareaAutosize from 'react-textarea-autosize';
import { useRecoilState, useRecoilValue } from 'recoil';
import { EModelEndpoint } from 'librechat-data-provider';
import { useState, useRef, useEffect, useCallback } from 'react';
import { useRef, useEffect, useCallback } from 'react';
import { useForm } from 'react-hook-form';
import { useUpdateMessageMutation } from 'librechat-data-provider/react-query';
import type { TEditProps } from '~/common';
import { useChatContext, useAddedChatContext } from '~/Providers';
import { TextareaAutosize } from '~/components/ui';
import { cn, removeFocusRings } from '~/utils';
import { useLocalize } from '~/hooks';
import Container from './Container';
Expand All @@ -25,7 +26,6 @@ const EditMessage = ({
store.latestMessageFamily(addedIndex),
);

const [editedText, setEditedText] = useState<string>(text ?? '');
const textAreaRef = useRef<HTMLTextAreaElement | null>(null);

const { conversationId, parentMessageId, messageId } = message;
Expand All @@ -34,6 +34,15 @@ const EditMessage = ({
const updateMessageMutation = useUpdateMessageMutation(conversationId ?? '');
const localize = useLocalize();

const chatDirection = useRecoilValue(store.chatDirection).toLowerCase();
const isRTL = chatDirection === 'rtl';

const { register, handleSubmit, setValue } = useForm({
defaultValues: {
text: text ?? '',
},
});

useEffect(() => {
const textArea = textAreaRef.current;
if (textArea) {
Expand All @@ -43,11 +52,11 @@ const EditMessage = ({
}
}, []);

const resubmitMessage = () => {
const resubmitMessage = (data: { text: string }) => {
if (message.isCreatedByUser) {
ask(
{
text: editedText,
text: data.text,
parentMessageId,
conversationId,
},
Expand All @@ -67,7 +76,7 @@ const EditMessage = ({
ask(
{ ...parentMessage },
{
editedText,
editedText: data.text,
editedMessageId: messageId,
isRegenerate: true,
isEdited: true,
Expand All @@ -80,32 +89,32 @@ const EditMessage = ({
enterEdit(true);
};

const updateMessage = () => {
const updateMessage = (data: { text: string }) => {
const messages = getMessages();
if (!messages) {
return;
}
updateMessageMutation.mutate({
conversationId: conversationId ?? '',
model: conversation?.model ?? 'gpt-3.5-turbo',
text: editedText,
text: data.text,
messageId,
});

if (message.messageId === latestMultiMessage?.messageId) {
setLatestMultiMessage({ ...latestMultiMessage, text: editedText });
setLatestMultiMessage({ ...latestMultiMessage, text: data.text });
}

const isInMessages = messages?.some((message) => message?.messageId === messageId);
const isInMessages = messages.some((message) => message.messageId === messageId);
if (!isInMessages) {
message.text = editedText;
message.text = data.text;
} else {
setMessages(
messages.map((msg) =>
msg.messageId === messageId
? {
...msg,
text: editedText,
text: data.text,
isEdited: true,
}
: msg,
Expand All @@ -126,43 +135,33 @@ const EditMessage = ({
[enterEdit],
);

const { ref, ...registerProps } = register('text', {
required: true,
onChange: (e) => {
setValue('text', e.target.value, { shouldValidate: true });
},
});

return (
<Container message={message}>
<div className="bg-token-main-surface-primary relative flex w-full flex-grow flex-col overflow-hidden rounded-2xl border dark:border-gray-600 dark:text-white [&:has(textarea:focus)]:border-gray-300 [&:has(textarea:focus)]:shadow-[0_2px_6px_rgba(0,0,0,.05)] dark:[&:has(textarea:focus)]:border-gray-500">
<div className="bg-token-main-surface-primary relative flex w-full flex-grow flex-col overflow-hidden rounded-2xl border border-border-medium text-text-primary [&:has(textarea:focus)]:border-border-heavy [&:has(textarea:focus)]:shadow-[0_2px_6px_rgba(0,0,0,.05)]">
<TextareaAutosize
ref={textAreaRef}
onChange={(e) => {
setEditedText(e.target.value);
{...registerProps}
ref={(e) => {
ref(e);
textAreaRef.current = e;
}}
onKeyDown={handleKeyDown}
data-testid="message-text-editor"
className={cn(
'markdown prose dark:prose-invert light whitespace-pre-wrap break-words',
'pl-3 md:pl-4',
'markdown prose dark:prose-invert light whitespace-pre-wrap break-words pl-3 md:pl-4',
'm-0 w-full resize-none border-0 bg-transparent py-[10px]',
'placeholder-black/50 focus:ring-0 focus-visible:ring-0 dark:bg-transparent dark:placeholder-white/50 md:py-3.5 ',
'pr-3 md:pr-4',
'max-h-[65vh] md:max-h-[75vh]',
'placeholder-text-secondary focus:ring-0 focus-visible:ring-0 md:py-3.5',
isRTL ? 'text-right' : 'text-left',
'max-h-[65vh] pr-3 md:max-h-[75vh] md:pr-4',
removeFocusRings,
)}
onPaste={(e) => {
e.preventDefault();

const pastedData = e.clipboardData.getData('text/plain');
const textArea = textAreaRef.current;
if (!textArea) {
return;
}
const start = textArea.selectionStart;
const end = textArea.selectionEnd;
const newValue =
textArea.value.substring(0, start) + pastedData + textArea.value.substring(end);
setEditedText(newValue);
}}
contentEditable={true}
value={editedText}
suppressContentEditableWarning={true}
dir="auto"
dir={isRTL ? 'rtl' : 'ltr'}
/>
</div>
<div className="mt-2 flex w-full justify-center text-center">
Expand All @@ -171,14 +170,14 @@ const EditMessage = ({
disabled={
isSubmitting || (endpoint === EModelEndpoint.google && !message.isCreatedByUser)
}
onClick={resubmitMessage}
onClick={handleSubmit(resubmitMessage)}
>
{localize('com_ui_save_submit')}
</button>
<button
className="btn btn-secondary relative mr-2"
disabled={isSubmitting}
onClick={updateMessage}
onClick={handleSubmit(updateMessage)}
>
{localize('com_ui_save')}
</button>
Expand Down
Loading

0 comments on commit 1a1e685

Please sign in to comment.