Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v1] Implement Zustand to fix history states #54

Merged
merged 1 commit into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added bun.lockb
Binary file not shown.
244 changes: 166 additions & 78 deletions index.js
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
import { Client, GatewayIntentBits } from 'discord.js';
import { joinVoiceChannel, createAudioPlayer, createAudioResource } from '@discordjs/voice';

import fs from 'fs';
import path from 'path';
import { DateTime } from 'luxon';
import fetch from 'node-fetch';
import { Client, GatewayIntentBits } from "discord.js";
import {
joinVoiceChannel,
createAudioPlayer,
createAudioResource,
} from "@discordjs/voice";

import fs from "fs";
import path from "path";
import { DateTime } from "luxon";
import fetch from "node-fetch";

import { MsEdgeTTS } from "msedge-tts";

import * as dotenv from 'dotenv'
import * as dotenv from "dotenv";
import { createStore } from "zustand/vanilla";

dotenv.config()
dotenv.config();

const channels = process.env.CHANNELIDS.split(",");

Expand All @@ -22,10 +27,10 @@ const client = new Client({
GatewayIntentBits.GuildMembers,
GatewayIntentBits.GuildVoiceStates,
],
allowedMentions: { parse: [], repliedUser: false }
allowedMentions: { parse: [], repliedUser: false },
});

if (!fs.existsSync('./temp')) fs.mkdirSync('./temp');
if (!fs.existsSync("./temp")) fs.mkdirSync("./temp");

// Map to store the last message timestamp per person
const cooldowns = new Map();
Expand All @@ -37,10 +42,14 @@ client.on("ready", async () => {

function shouldIReply(message) {
if (!message.content && !message.attachments) return false;
if (Math.random() < process.env.REPLY_CHANCE && !channels.includes(message.channel.id) && !message.mentions.has(client.user.id)) return false;
if (
Math.random() < process.env.REPLY_CHANCE &&
!channels.includes(message.channel.id) &&
!message.mentions.has(client.user.id)
)
return false;
if (message.content.startsWith("!!")) return false;


// Check cooldown for the person who sent the message
const lastMessageTime = cooldowns.get(message.author.id);
if (lastMessageTime && Date.now() - lastMessageTime < 1000) return false; // Ignore the message if the cooldown hasn't expired
Expand All @@ -54,20 +63,21 @@ function shouldIReply(message) {
async function getPronouns(userid) {
// this is spagetti i'm sorry
try {
let response = await fetch(`https://pronoundb.org/api/v2/lookup?platform=discord&ids=${userid}`);
let response = await fetch(
`https://pronoundb.org/api/v2/lookup?platform=discord&ids=${userid}`
);

response = await response.json();


for (let userId in response) {
if (response[userId].sets.hasOwnProperty('en')) {
response[userId] = response[userId].sets['en'].join('/');
if (response[userId].sets.hasOwnProperty("en")) {
response[userId] = response[userId].sets["en"].join("/");
} else {
response[userId] = 'they/them';
response[userId] = "they/them";
}
}
if (!response.hasOwnProperty(userid)) {
response[userid] = 'they/them';
response[userid] = "they/them";
}

return response[userid];
Expand All @@ -76,29 +86,73 @@ async function getPronouns(userid) {
}
}

const character = fs.readFileSync('./character.txt', 'utf8').replace("\n", ' ');
const character = fs.readFileSync("./character.txt", "utf8").replace("\n", " ");

const initialHistory = [
{ role: "system", content: character },
{ role: "user", content: "lily (she/her) on May 14, 2024 at 12:55 AM UTC: hi sponge" },
{
role: "user",
content: "lily (she/her) on May 14, 2024 at 12:55 AM UTC: hi sponge",
},
{ role: "assistant", content: "hi lily! how are you today :3" },
{ role: "user", content: "zynsterr (they/them) on May 14, 2024 at 3:02 PM UTC: generate me an image of cheese puffs" },
{ role: "assistant", content: "sure thing!\n!gen picture of a bowl of cheese puffs on a table" }
{
role: "user",
content:
"zynsterr (they/them) on May 14, 2024 at 3:02 PM UTC: generate me an image of cheese puffs",
},
{
role: "assistant",
content: "sure thing!\n!gen picture of a bowl of cheese puffs on a table",
},
];
let lastMessage = "";
let history = initialHistory;

client.on("messageCreate", async message => {
global.initialHistory = initialHistory;
let lastMessage = "";
// let history = initialHistory;

const historyStore = createStore((set) => ({
history: [...initialHistory],
addMessage: ({ role, content }) =>
set((s) => ({
history: [
...s?.history,
{
role,
content,
},
],
})),
reset: () =>
set((s) => ({
history: [...initialHistory],
})),
}));

client.on("messageCreate", async (message) => {
if (message.author.bot) return;
if (!shouldIReply(message)) return;

try {
message.channel.sendTyping();
const store = historyStore.getState();

try {
// Conversation reset
if (message.content.startsWith("%reset")) {
message.reply(`♻️ Conversation history reset.`);
history = initialHistory
const totalAmountToReset =
store?.history?.length - initialHistory?.length;
const initialResetReply = await message.reply(
`♻️ Conversation history reset (clearing ${totalAmountToReset} entries).`
);
store?.reset();
initialResetReply.edit({
content: `♻️ Conversation history reset (cleared ${totalAmountToReset} entries).`,
});
return;
}

if (message.content.startsWith("%messages")) {
message.reply(
`ℹ️ ${store?.history?.length - initialHistory?.length} entries exist.`
);
return;
}

Expand All @@ -107,10 +161,17 @@ client.on("messageCreate", async message => {
return;
}

const imageDetails = await imageRecognition(message)
message.channel.sendTyping();
const imageDetails = await imageRecognition(message);

// Send message to CharacterAI
let formattedUserMessage = `${message.author.username} (${await getPronouns(message.author.id)}) on ${DateTime.now().setZone('utc').toLocaleString(DateTime.DATETIME_FULL)}: ${message.content}\n${imageDetails}`;
let formattedUserMessage = `${message.author.username} (${await getPronouns(
message.author.id
)}) on ${DateTime.now()
.setZone("utc")
.toLocaleString(DateTime.DATETIME_FULL)}: ${
message.content
}\n${imageDetails}`;
if (message.reference) {
await message.fetchReference().then(async (reply) => {
if (reply.author.id == client.user.id) {
Expand All @@ -119,50 +180,66 @@ client.on("messageCreate", async message => {
formattedUserMessage = `> ${reply.author.username}: ${reply}\n${formattedUserMessage}`;
}
});
};
}
lastMessage = formattedUserMessage;

message.channel.sendTyping();
history.push({ role: "user", content: formattedUserMessage });

store.addMessage({ role: "user", content: formattedUserMessage });
const input = {
messages: history,
messages: store.history,
max_tokens: 512,
};
let response = await fetch(`https://api.cloudflare.com/client/v4/accounts/${process.env.CF_ACCOUNT}/ai/run/@cf/meta/llama-3-8b-instruct`, {
method: 'POST',
headers: {
'Authorization': `Bearer ${process.env.CF_TOKEN}`,
'Content-Type': 'application/json'
},
body: JSON.stringify(input)
});
let response = await fetch(
`https://api.cloudflare.com/client/v4/accounts/${process.env.CF_ACCOUNT}/ai/run/@cf/meta/llama-3-8b-instruct`,
{
method: "POST",
headers: {
Authorization: `Bearer ${process.env.CF_TOKEN}`,
"Content-Type": "application/json",
},
body: JSON.stringify(input),
}
);
response = await response.json();
response = response.result.response
history.push({ role: "assistant", content: response });
if (history.length % 20 == 0) {
history.push({ role: "system", content: `reminder: ${character}` },);
response = response.result.response;
store.addMessage({ role: "assistant", content: response });
if (store.history.length % 20 == 0) {
store.addMessage({ role: "system", content: `reminder: ${character}` });
}

if (response == "") return message.reply(`❌ AI returned an empty response! Yell at someone idk.`);
if (response == "")
return message.reply(
`❌ AI returned an empty response! Yell at someone idk.`
);

let parts = response.split("!gen");

let trimmedResponse = parts[0].trim();

// Handle long responses
if (trimmedResponse.length >= 2000) {
fs.writeFileSync(path.resolve('./temp/how.txt'), trimmedResponse);
message.reply({ content: "", files: ["./temp/how.txt"], failIfNotExists: false });
fs.writeFileSync(path.resolve("./temp/how.txt"), trimmedResponse);
message.reply({
content: "",
files: ["./temp/how.txt"],
failIfNotExists: false,
});
return;
}

// Send AI response
let sentMessage
let sentMessage;
try {
sentMessage = await message.reply({ content: `${trimmedResponse}`, failIfNotExists: true });
sentMessage = await message.reply({
content: `${trimmedResponse}`,
failIfNotExists: true,
});
} catch (e) {
console.log(e);
sentMessage = await message.channel.send({ content: `\`\`\`\n${message.author.username}: ${message.content}\n\`\`\`\n\n${trimmedResponse}` });
sentMessage = await message.channel.send({
content: `\`\`\`\n${message.author.username}: ${message.content}\n\`\`\`\n\n${trimmedResponse}`,
});
}
if (response.includes("!gen")) {
selfImageGen(message, response, sentMessage);
Expand All @@ -180,7 +257,7 @@ client.on("messageCreate", async message => {

async function imageRecognition(message) {
if (message.attachments.size > 0) {
let imageDetails = '';
let imageDetails = "";

const res = await fetch(message.attachments.first().url);
const blob = await res.arrayBuffer();
Expand All @@ -191,56 +268,67 @@ async function imageRecognition(message) {
};

try {
let response = await fetch(`https://api.cloudflare.com/client/v4/accounts/${process.env.CF_ACCOUNT}/ai/run/@cf/llava-hf/llava-1.5-7b-hf`, {
method: 'POST',
headers: {
'Authorization': `Bearer ${process.env.CF_TOKEN}`,
'Content-Type': 'application/json'
},
body: JSON.stringify(input)
});
let response = await fetch(
`https://api.cloudflare.com/client/v4/accounts/${process.env.CF_ACCOUNT}/ai/run/@cf/llava-hf/llava-1.5-7b-hf`,
{
method: "POST",
headers: {
Authorization: `Bearer ${process.env.CF_TOKEN}`,
"Content-Type": "application/json",
},
body: JSON.stringify(input),
}
);
response = await response.json();

imageDetails += `Attached: image of${String(response.result.description).toLowerCase()}\n`;
imageDetails += `Attached: image of${String(
response.result.description
).toLowerCase()}\n`;
} catch (error) {
console.error(error);
return message.reply(`❌ Error in image recognition! Try again later.`);
};
}

return imageDetails;
} else {
return '';
return "";
}
}

async function selfImageGen(message, response, sentMessage) {
let parts = response.split("!gen");

let partBeforeGen = parts[0].trim();
let partAfterGen = parts[1].trim().replace('[', '').replace(']', '');
let partAfterGen = parts[1].trim().replace("[", "").replace("]", "");

try {
let response = await fetch(`https://api.cloudflare.com/client/v4/accounts/${process.env.CF_ACCOUNT}/ai/run/@cf/lykon/dreamshaper-8-lcm`, {
method: 'POST',
headers: {
'Authorization': `Bearer ${process.env.CF_TOKEN}`,
},
body: JSON.stringify({
prompt: partAfterGen
})
});
let response = await fetch(
`https://api.cloudflare.com/client/v4/accounts/${process.env.CF_ACCOUNT}/ai/run/@cf/lykon/dreamshaper-8-lcm`,
{
method: "POST",
headers: {
Authorization: `Bearer ${process.env.CF_TOKEN}`,
},
body: JSON.stringify({
prompt: partAfterGen,
}),
}
);
response = await response.arrayBuffer();
const imageBuffer = Buffer.from(response);
sentMessage.edit({ content: partBeforeGen, files: [imageBuffer] });
} catch (error) {
console.error(error);
return message.reply(`❌ Error in image generation! Try again later.`);
};
}
}

async function tts(message, text) {
const tts = new MsEdgeTTS();
await tts.setMetadata("en-US-AnaNeural", MsEdgeTTS.OUTPUT_FORMAT.AUDIO_24KHZ_96KBITRATE_MONO_MP3);
await tts.setMetadata(
"en-US-AnaNeural",
MsEdgeTTS.OUTPUT_FORMAT.AUDIO_24KHZ_96KBITRATE_MONO_MP3
);
const filePath = await tts.toFile("./temp/audio.mp3", text);

const channel = message.member.voice.channel;
Expand All @@ -257,7 +345,7 @@ async function tts(message, text) {
connection.subscribe(player);
player.play(resource);

player.on('error', error => {
player.on("error", (error) => {
console.error(`Audio Error: ${error.message}`);
});
}
Expand Down
Loading