feat: auto abort for updating long/short memory

This commit is contained in:
WJG 2024-02-25 18:05:11 +08:00
parent 4beac32a2a
commit 57a765af1b
No known key found for this signature in database
GPG Key ID: 258474EF8590014A
5 changed files with 92 additions and 42 deletions

View File

@ -5,6 +5,7 @@ import { MemoryCRUD } from "../../db/memory";
import { ShortTermMemoryCRUD } from "../../db/memory-short-term"; import { ShortTermMemoryCRUD } from "../../db/memory-short-term";
import { LongTermMemoryCRUD } from "../../db/memory-long-term"; import { LongTermMemoryCRUD } from "../../db/memory-long-term";
import { ShortTermMemoryAgent } from "./short-term"; import { ShortTermMemoryAgent } from "./short-term";
import { openai } from "../../openai";
export class MemoryManager { export class MemoryManager {
private room: Room; private room: Room;
@ -48,33 +49,57 @@ export class MemoryManager {
return []; return [];
} }
private _currentMemory?: Memory;
async addMessage2Memory(message: Message) { async addMessage2Memory(message: Message) {
// todo create memory embedding // todo create memory embedding
const res = await MemoryCRUD.addOrUpdate({ const currentMemory = await MemoryCRUD.addOrUpdate({
text: message.text, text: message.text,
roomId: this.room.id, roomId: this.room.id,
ownerId: message.senderId, ownerId: message.senderId,
}); });
if (currentMemory) {
this._onMemory(currentMemory);
}
return currentMemory;
}
private _onMemory(currentMemory: Memory) {
if (this._currentMemory) {
// 取消之前的更新记忆任务
openai.abort(`update-short-memory-${this._currentMemory.id}`);
openai.abort(`update-long-memory-${this._currentMemory.id}`);
}
this._currentMemory = currentMemory;
// 异步更新长短期记忆 // 异步更新长短期记忆
this.updateLongShortTermMemory(); this.updateLongShortTermMemory({ currentMemory });
return res;
} }
/** /**
* *
*/ */
async updateLongShortTermMemory(options?: { async updateLongShortTermMemory(options: {
currentMemory: Memory;
shortThreshold?: number; shortThreshold?: number;
longThreshold?: number; longThreshold?: number;
}) { }) {
const { shortThreshold, longThreshold } = options ?? {}; const { currentMemory, shortThreshold, longThreshold } = options ?? {};
const success = await this._updateShortTermMemory(shortThreshold); const success = await this._updateShortTermMemory({
currentMemory,
threshold: shortThreshold,
});
if (success) { if (success) {
await this._updateLongTermMemory(longThreshold); await this._updateLongTermMemory({
currentMemory,
threshold: longThreshold,
});
} }
} }
private async _updateShortTermMemory(threshold = 10) { private async _updateShortTermMemory(options: {
currentMemory: Memory;
threshold?: number;
}) {
const { currentMemory, threshold = 10 } = options;
const lastMemory = firstOf(await this.getShortTermMemories(1)); const lastMemory = firstOf(await this.getShortTermMemories(1));
const newMemories = await MemoryCRUD.gets({ const newMemories = await MemoryCRUD.gets({
cursorId: lastMemory?.cursorId, cursorId: lastMemory?.cursorId,
@ -85,10 +110,11 @@ export class MemoryManager {
if (newMemories.length < 1 || newMemories.length < threshold) { if (newMemories.length < 1 || newMemories.length < threshold) {
return true; return true;
} }
const newMemory = await ShortTermMemoryAgent.generate( const newMemory = await ShortTermMemoryAgent.generate({
currentMemory,
newMemories, newMemories,
lastMemory lastMemory,
); });
if (!newMemory) { if (!newMemory) {
return false; return false;
} }
@ -101,7 +127,11 @@ export class MemoryManager {
return res != null; return res != null;
} }
private async _updateLongTermMemory(threshold = 10) { private async _updateLongTermMemory(options: {
currentMemory: Memory;
threshold?: number;
}) {
const { currentMemory, threshold = 10 } = options;
const lastMemory = firstOf(await this.getLongTermMemories(1)); const lastMemory = firstOf(await this.getLongTermMemories(1));
const newMemories = await ShortTermMemoryCRUD.gets({ const newMemories = await ShortTermMemoryCRUD.gets({
cursorId: lastMemory?.cursorId, cursorId: lastMemory?.cursorId,
@ -112,10 +142,11 @@ export class MemoryManager {
if (newMemories.length < 1 || newMemories.length < threshold) { if (newMemories.length < 1 || newMemories.length < threshold) {
return true; return true;
} }
const newMemory = await LongTermMemoryAgent.generate( const newMemory = await LongTermMemoryAgent.generate({
currentMemory,
newMemories, newMemories,
lastMemory lastMemory,
); });
if (!newMemory) { if (!newMemory) {
return false; return false;
} }

View File

@ -1,13 +1,18 @@
import { Memory, ShortTermMemory } from "@prisma/client"; import { LongTermMemory, Memory, ShortTermMemory } from "@prisma/client";
import { openai } from "../../openai";
export class LongTermMemoryAgent { export class LongTermMemoryAgent {
// todo 使用 LLM 生成新的长期记忆 // todo 使用 LLM 生成新的长期记忆
static async generate( static async generate(options: {
newMemories: Memory[], currentMemory: Memory;
lastLongTermMemory?: ShortTermMemory newMemories: ShortTermMemory[];
): Promise<string | undefined> { lastMemory?: LongTermMemory;
return `count: ${newMemories.length}\n${newMemories }): Promise<string | undefined> {
.map((e, idx) => idx.toString() + ". " + e.text) const { currentMemory, newMemories, lastMemory } = options;
.join("\n")}`; const res = await openai.chat({
user: "todo", // todo prompt
requestId: `update-long-memory-${currentMemory.id}`,
});
return res?.content?.trim();
} }
} }

View File

@ -1,13 +1,18 @@
import { Memory } from "@prisma/client"; import { Memory, ShortTermMemory } from "@prisma/client";
import { openai } from "../../openai";
export class ShortTermMemoryAgent { export class ShortTermMemoryAgent {
// todo 使用 LLM 生成新的短期记忆 // todo 使用 LLM 生成新的短期记忆
static async generate( static async generate(options: {
newMemories: Memory[], currentMemory: Memory;
lastShortTermMemory?: Memory newMemories: Memory[];
): Promise<string | undefined> { lastMemory?: ShortTermMemory;
return `count: ${newMemories.length}\n${newMemories }): Promise<string | undefined> {
.map((e, idx) => idx.toString() + ". " + e.text) const { currentMemory, newMemories, lastMemory } = options;
.join("\n")}`; const res = await openai.chat({
user: "todo", // todo prompt
requestId: `update-short-memory-${currentMemory.id}`,
});
return res?.content?.trim();
} }
} }

View File

@ -12,6 +12,7 @@ export interface ChatOptions {
system?: string; system?: string;
tools?: Array<ChatCompletionTool>; tools?: Array<ChatCompletionTool>;
jsonMode?: boolean; jsonMode?: boolean;
requestId?: string;
} }
class OpenAIClient { class OpenAIClient {
@ -32,17 +33,26 @@ class OpenAIClient {
} }
async chat(options: ChatOptions) { async chat(options: ChatOptions) {
const { user, system, tools, jsonMode } = options; let { user, system, tools, jsonMode, requestId } = options;
const systemMsg: ChatCompletionMessageParam[] = system const systemMsg: ChatCompletionMessageParam[] = system
? [{ role: "system", content: system }] ? [{ role: "system", content: system }]
: []; : [];
let signal: AbortSignal | undefined;
if (requestId) {
const controller = new AbortController();
this._abortCallbacks[requestId] = () => controller.abort();
signal = controller.signal;
}
const chatCompletion = await this._client.chat.completions const chatCompletion = await this._client.chat.completions
.create({ .create(
tools, {
messages: [...systemMsg, { role: "user", content: user }], tools,
model: kEnvs.OPENAI_MODEL ?? "gpt-3.5-turbo-0125", messages: [...systemMsg, { role: "user", content: user }],
response_format: jsonMode ? { type: "json_object" } : undefined, model: kEnvs.OPENAI_MODEL ?? "gpt-3.5-turbo-0125",
}) response_format: jsonMode ? { type: "json_object" } : undefined,
},
{ signal }
)
.catch((e) => { .catch((e) => {
console.error("❌ openai chat failed", e); console.error("❌ openai chat failed", e);
return null; return null;
@ -52,11 +62,10 @@ class OpenAIClient {
async chatStream( async chatStream(
options: ChatOptions & { options: ChatOptions & {
requestId?: string;
onStream?: (text: string) => void; onStream?: (text: string) => void;
} }
) { ) {
const { user, system, tools, jsonMode, onStream, requestId } = options; let { user, system, tools, jsonMode, requestId, onStream } = options;
const systemMsg: ChatCompletionMessageParam[] = system const systemMsg: ChatCompletionMessageParam[] = system
? [{ role: "system", content: system }] ? [{ role: "system", content: system }]
: []; : [];

View File

@ -12,9 +12,9 @@ dotenv.config();
async function main() { async function main() {
println(kBannerASCII); println(kBannerASCII);
// testDB(); // testDB();
testSpeaker(); // testSpeaker();
// testOpenAI(); // testOpenAI();
// testMyBot(); testMyBot();
} }
runWithDB(main); runWithDB(main);