feat: auto abort for updating long/short memory

WJG 2024-02-25 18:05:11 +08:00
parent 4beac32a2a
commit 57a765af1b
5 changed files with 92 additions and 42 deletions
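
In outline, the change works like this: every LLM call made while summarizing memories is registered under a requestId derived from the triggering memory's id ("update-short-memory-<id>" and "update-long-memory-<id>"), and when a newer message arrives, MemoryManager aborts the previous memory's in-flight requests before starting its own round. A minimal, self-contained sketch of the registry pattern (the standalone class and the delete-after-abort cleanup are assumptions; the diff only shows the registration side, inside OpenAIClient):

// Sketch of the requestId -> AbortController registry this commit wires
// into OpenAIClient. The standalone class and the cleanup are assumed.
class AbortRegistry {
  private _abortCallbacks: Record<string, () => void> = {};

  // Create a controller for this requestId and hand back its signal.
  register(requestId: string): AbortSignal {
    const controller = new AbortController();
    this._abortCallbacks[requestId] = () => controller.abort();
    return controller.signal;
  }

  // Fire the stored callback; a no-op if nothing registered under this id.
  abort(requestId: string) {
    this._abortCallbacks[requestId]?.();
    delete this._abortCallbacks[requestId]; // assumed cleanup, not in the diff
  }
}

// Usage: a newer memory round cancels the older one by naming convention.
const registry = new AbortRegistry();
const signal = registry.register("update-short-memory-1");
registry.abort("update-short-memory-1");
console.log(signal.aborted); // true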

View File

@@ -5,6 +5,7 @@ import { MemoryCRUD } from "../../db/memory";
 import { ShortTermMemoryCRUD } from "../../db/memory-short-term";
 import { LongTermMemoryCRUD } from "../../db/memory-long-term";
 import { ShortTermMemoryAgent } from "./short-term";
+import { openai } from "../../openai";

 export class MemoryManager {
   private room: Room;
@@ -48,33 +49,57 @@ export class MemoryManager {
     return [];
   }

+  private _currentMemory?: Memory;
+
   async addMessage2Memory(message: Message) {
     // todo create memory embedding
-    const res = await MemoryCRUD.addOrUpdate({
+    const currentMemory = await MemoryCRUD.addOrUpdate({
       text: message.text,
       roomId: this.room.id,
       ownerId: message.senderId,
     });
+    if (currentMemory) {
+      this._onMemory(currentMemory);
+    }
+    return currentMemory;
+  }
+
+  private _onMemory(currentMemory: Memory) {
+    if (this._currentMemory) {
+      // Cancel the previous memory update tasks
+      openai.abort(`update-short-memory-${this._currentMemory.id}`);
+      openai.abort(`update-long-memory-${this._currentMemory.id}`);
+    }
+    this._currentMemory = currentMemory;
     // Update long/short-term memory asynchronously
-    this.updateLongShortTermMemory();
-    return res;
+    this.updateLongShortTermMemory({ currentMemory });
   }

   /**
    * Update the short-term and long-term memories for the current room.
    */
-  async updateLongShortTermMemory(options?: {
+  async updateLongShortTermMemory(options: {
+    currentMemory: Memory;
     shortThreshold?: number;
     longThreshold?: number;
   }) {
-    const { shortThreshold, longThreshold } = options ?? {};
-    const success = await this._updateShortTermMemory(shortThreshold);
+    const { currentMemory, shortThreshold, longThreshold } = options ?? {};
+    const success = await this._updateShortTermMemory({
+      currentMemory,
+      threshold: shortThreshold,
+    });
     if (success) {
-      await this._updateLongTermMemory(longThreshold);
+      await this._updateLongTermMemory({
+        currentMemory,
+        threshold: longThreshold,
+      });
     }
   }

-  private async _updateShortTermMemory(threshold = 10) {
+  private async _updateShortTermMemory(options: {
+    currentMemory: Memory;
+    threshold?: number;
+  }) {
+    const { currentMemory, threshold = 10 } = options;
     const lastMemory = firstOf(await this.getShortTermMemories(1));
     const newMemories = await MemoryCRUD.gets({
       cursorId: lastMemory?.cursorId,
@@ -85,10 +110,11 @@ export class MemoryManager {
     if (newMemories.length < 1 || newMemories.length < threshold) {
       return true;
     }
-    const newMemory = await ShortTermMemoryAgent.generate(
+    const newMemory = await ShortTermMemoryAgent.generate({
+      currentMemory,
       newMemories,
-      lastMemory
-    );
+      lastMemory,
+    });
     if (!newMemory) {
       return false;
     }
@@ -101,7 +127,11 @@ export class MemoryManager {
     return res != null;
   }

-  private async _updateLongTermMemory(threshold = 10) {
+  private async _updateLongTermMemory(options: {
+    currentMemory: Memory;
+    threshold?: number;
+  }) {
+    const { currentMemory, threshold = 10 } = options;
     const lastMemory = firstOf(await this.getLongTermMemories(1));
     const newMemories = await ShortTermMemoryCRUD.gets({
       cursorId: lastMemory?.cursorId,
@@ -112,10 +142,11 @@ export class MemoryManager {
     if (newMemories.length < 1 || newMemories.length < threshold) {
       return true;
     }
-    const newMemory = await LongTermMemoryAgent.generate(
+    const newMemory = await LongTermMemoryAgent.generate({
+      currentMemory,
       newMemories,
-      lastMemory
-    );
+      lastMemory,
+    });
     if (!newMemory) {
       return false;
     }
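
A detail of the two _update* methods above that is easy to misread: having fewer than `threshold` new rows returns true ("nothing to do" counts as success), so a quiet short-term pass still lets the long-term pass run; only a failed or cancelled generation returns false and stops the chain. A condensed model of that gating (hypothetical helper names; only the control flow mirrors the diff):

// Condensed model of one update round; names are hypothetical.
async function updateRound(
  countNew: () => Promise<number>,
  generate: () => Promise<string | undefined>,
  threshold = 10
): Promise<boolean> {
  const newCount = await countNew();
  // Note: newCount < 1 is subsumed by newCount < threshold for threshold >= 1.
  if (newCount < threshold) {
    return true; // too few new memories: success, nothing to summarize
  }
  const summary = await generate();
  return !!summary; // undefined (aborted/failed) halts the long-term pass
}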

View File

@@ -1,13 +1,18 @@
-import { Memory, ShortTermMemory } from "@prisma/client";
+import { LongTermMemory, Memory, ShortTermMemory } from "@prisma/client";
+import { openai } from "../../openai";

 export class LongTermMemoryAgent {
   // todo: use an LLM to generate new long-term memory
-  static async generate(
-    newMemories: Memory[],
-    lastLongTermMemory?: ShortTermMemory
-  ): Promise<string | undefined> {
-    return `count: ${newMemories.length}\n${newMemories
-      .map((e, idx) => idx.toString() + ". " + e.text)
-      .join("\n")}`;
+  static async generate(options: {
+    currentMemory: Memory;
+    newMemories: ShortTermMemory[];
+    lastMemory?: LongTermMemory;
+  }): Promise<string | undefined> {
+    const { currentMemory, newMemories, lastMemory } = options;
+    const res = await openai.chat({
+      user: "todo", // todo prompt
+      requestId: `update-long-memory-${currentMemory.id}`,
+    });
+    return res?.content?.trim();
   }
 }

View File

@@ -1,13 +1,18 @@
-import { Memory } from "@prisma/client";
+import { Memory, ShortTermMemory } from "@prisma/client";
+import { openai } from "../../openai";

 export class ShortTermMemoryAgent {
   // todo: use an LLM to generate new short-term memory
-  static async generate(
-    newMemories: Memory[],
-    lastShortTermMemory?: Memory
-  ): Promise<string | undefined> {
-    return `count: ${newMemories.length}\n${newMemories
-      .map((e, idx) => idx.toString() + ". " + e.text)
-      .join("\n")}`;
+  static async generate(options: {
+    currentMemory: Memory;
+    newMemories: Memory[];
+    lastMemory?: ShortTermMemory;
+  }): Promise<string | undefined> {
+    const { currentMemory, newMemories, lastMemory } = options;
+    const res = await openai.chat({
+      user: "todo", // todo prompt
+      requestId: `update-short-memory-${currentMemory.id}`,
+    });
+    return res?.content?.trim();
   }
 }
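
How a cancellation reaches the two agents above: the aborted HTTP request rejects inside OpenAIClient.chat, the .catch there maps it to null, generate() consequently returns undefined, and _updateShortTermMemory returns false, which skips the long-term update. A runnable sketch of that propagation with the network call stubbed by a timer (all names here are stand-ins):

// Stub of a chat call that honors an AbortSignal, showing how an abort
// surfaces as `undefined` to the memory agents. Stand-in names throughout.
function stubChat(signal: AbortSignal): Promise<string | null> {
  return new Promise<string>((resolve, reject) => {
    const timer = setTimeout(() => resolve("new memory summary"), 1000);
    signal.addEventListener("abort", () => {
      clearTimeout(timer);
      reject(new Error("aborted"));
    });
  }).catch(() => null); // mirrors the .catch in OpenAIClient.chat
}

async function demo() {
  const controller = new AbortController();
  const pending = stubChat(controller.signal);
  controller.abort(); // a newer message superseded this round
  const res = await pending;
  console.log(res); // null -> generate() returns undefined -> round stops
}

demo();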

View File

@@ -12,6 +12,7 @@ export interface ChatOptions {
   system?: string;
   tools?: Array<ChatCompletionTool>;
   jsonMode?: boolean;
+  requestId?: string;
 }

 class OpenAIClient {
@@ -32,17 +33,26 @@ class OpenAIClient {
   }

   async chat(options: ChatOptions) {
-    const { user, system, tools, jsonMode } = options;
+    let { user, system, tools, jsonMode, requestId } = options;
     const systemMsg: ChatCompletionMessageParam[] = system
       ? [{ role: "system", content: system }]
       : [];
+    let signal: AbortSignal | undefined;
+    if (requestId) {
+      const controller = new AbortController();
+      this._abortCallbacks[requestId] = () => controller.abort();
+      signal = controller.signal;
+    }
     const chatCompletion = await this._client.chat.completions
-      .create({
-        tools,
-        messages: [...systemMsg, { role: "user", content: user }],
-        model: kEnvs.OPENAI_MODEL ?? "gpt-3.5-turbo-0125",
-        response_format: jsonMode ? { type: "json_object" } : undefined,
-      })
+      .create(
+        {
+          tools,
+          messages: [...systemMsg, { role: "user", content: user }],
+          model: kEnvs.OPENAI_MODEL ?? "gpt-3.5-turbo-0125",
+          response_format: jsonMode ? { type: "json_object" } : undefined,
+        },
+        { signal }
+      )
       .catch((e) => {
         console.error("❌ openai chat failed", e);
         return null;
@@ -52,11 +62,10 @@ class OpenAIClient {
   async chatStream(
     options: ChatOptions & {
-      requestId?: string;
       onStream?: (text: string) => void;
     }
   ) {
-    const { user, system, tools, jsonMode, onStream, requestId } = options;
+    let { user, system, tools, jsonMode, requestId, onStream } = options;
     const systemMsg: ChatCompletionMessageParam[] = system
       ? [{ role: "system", content: system }]
       : [];
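
Two details here are worth spelling out. First, the second argument to chat.completions.create is the openai v4 SDK's per-request options object; passing { signal } there ties the HTTP request to the AbortController, so firing the controller makes create reject, and the .catch above turns that rejection into null. Second, the diff registers callbacks in _abortCallbacks but never shows the abort(requestId) method that MemoryManager calls; presumably it invokes and discards the stored callback, as in the registry sketch near the top of this page. Nothing shown here removes a callback once a request completes normally, so entries may accumulate unless chat() cleans up after itself. A small sketch of the per-request abort as the diff uses it (assumes OPENAI_API_KEY is set; the model name mirrors the default above):

// Per-request abort with the openai v4 SDK, matching the diff's usage.
import OpenAI from "openai";

async function demo() {
  const client = new OpenAI();
  const controller = new AbortController();
  const pending = client.chat.completions.create(
    { model: "gpt-3.5-turbo-0125", messages: [{ role: "user", content: "hi" }] },
    { signal: controller.signal }
  );
  controller.abort(); // superseded: reject the in-flight request
  const res = await pending.catch(() => null);
  console.log(res); // null, the same shape the .catch above produces
}

demo();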

View File

@@ -12,9 +12,9 @@ dotenv.config();
 async function main() {
   println(kBannerASCII);
   // testDB();
-  testSpeaker();
+  // testSpeaker();
   // testOpenAI();
-  // testMyBot();
+  testMyBot();
 }

 runWithDB(main);