diff --git a/src/services/bot/memory/index.ts b/src/services/bot/memory/index.ts
index d85e678..1d21e6c 100644
--- a/src/services/bot/memory/index.ts
+++ b/src/services/bot/memory/index.ts
@@ -5,6 +5,7 @@ import { MemoryCRUD } from "../../db/memory";
 import { ShortTermMemoryCRUD } from "../../db/memory-short-term";
 import { LongTermMemoryCRUD } from "../../db/memory-long-term";
 import { ShortTermMemoryAgent } from "./short-term";
+import { openai } from "../../openai";
 
 export class MemoryManager {
   private room: Room;
@@ -48,33 +49,57 @@ export class MemoryManager {
     return [];
   }
 
+  private _currentMemory?: Memory;
   async addMessage2Memory(message: Message) {
     // todo create memory embedding
-    const res = await MemoryCRUD.addOrUpdate({
+    const currentMemory = await MemoryCRUD.addOrUpdate({
       text: message.text,
       roomId: this.room.id,
       ownerId: message.senderId,
     });
+    if (currentMemory) {
+      this._onMemory(currentMemory);
+    }
+    return currentMemory;
+  }
+
+  private _onMemory(currentMemory: Memory) {
+    if (this._currentMemory) {
+      // cancel the pending memory-update tasks for the previous memory
+      openai.abort(`update-short-memory-${this._currentMemory.id}`);
+      openai.abort(`update-long-memory-${this._currentMemory.id}`);
+    }
+    this._currentMemory = currentMemory;
     // asynchronously update the short-term and long-term memories
-    this.updateLongShortTermMemory();
-    return res;
+    this.updateLongShortTermMemory({ currentMemory });
   }
 
   /**
    * Update memories (when the number of new memories exceeds the threshold, automatically update the short-term and long-term memories)
    */
-  async updateLongShortTermMemory(options?: {
+  async updateLongShortTermMemory(options: {
+    currentMemory: Memory;
     shortThreshold?: number;
     longThreshold?: number;
   }) {
-    const { shortThreshold, longThreshold } = options ?? {};
-    const success = await this._updateShortTermMemory(shortThreshold);
+    const { currentMemory, shortThreshold, longThreshold } = options ?? {};
+    const success = await this._updateShortTermMemory({
+      currentMemory,
+      threshold: shortThreshold,
+    });
     if (success) {
-      await this._updateLongTermMemory(longThreshold);
+      await this._updateLongTermMemory({
+        currentMemory,
+        threshold: longThreshold,
+      });
     }
   }
 
-  private async _updateShortTermMemory(threshold = 10) {
+  private async _updateShortTermMemory(options: {
+    currentMemory: Memory;
+    threshold?: number;
+  }) {
+    const { currentMemory, threshold = 10 } = options;
     const lastMemory = firstOf(await this.getShortTermMemories(1));
     const newMemories = await MemoryCRUD.gets({
       cursorId: lastMemory?.cursorId,
@@ -85,10 +110,11 @@ export class MemoryManager {
     if (newMemories.length < 1 || newMemories.length < threshold) {
       return true;
     }
-    const newMemory = await ShortTermMemoryAgent.generate(
+    const newMemory = await ShortTermMemoryAgent.generate({
+      currentMemory,
       newMemories,
-      lastMemory
-    );
+      lastMemory,
+    });
     if (!newMemory) {
       return false;
     }
@@ -101,7 +127,11 @@ export class MemoryManager {
     return res != null;
   }
 
-  private async _updateLongTermMemory(threshold = 10) {
+  private async _updateLongTermMemory(options: {
+    currentMemory: Memory;
+    threshold?: number;
+  }) {
+    const { currentMemory, threshold = 10 } = options;
     const lastMemory = firstOf(await this.getLongTermMemories(1));
     const newMemories = await ShortTermMemoryCRUD.gets({
       cursorId: lastMemory?.cursorId,
@@ -112,10 +142,11 @@ export class MemoryManager {
     if (newMemories.length < 1 || newMemories.length < threshold) {
       return true;
     }
-    const newMemory = await LongTermMemoryAgent.generate(
+    const newMemory = await LongTermMemoryAgent.generate({
+      currentMemory,
       newMemories,
-      lastMemory
-    );
+      lastMemory,
+    });
     if (!newMemory) {
       return false;
     }
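Note on the cancellation flow above: each new message supersedes the last, so _onMemory aborts the in-flight LLM summarization requests keyed by the previous memory's id before kicking off a new update. The openai.abort method itself is not part of this diff; a minimal sketch of the shape it would need, assuming the _abortCallbacks registry that chat() writes into (see src/services/openai.ts further down), is:

    // Sketch only: the real openai.abort() is not shown in this diff.
    // Assumes: private _abortCallbacks: Record<string, VoidFunction> = {};
    abort(requestId: string) {
      const cancel = this._abortCallbacks[requestId];
      if (cancel) {
        cancel(); // fires controller.abort() registered in chat()
        delete this._abortCallbacks[requestId]; // drop the stale entry
      }
    }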
diff --git a/src/services/bot/memory/long-term.ts b/src/services/bot/memory/long-term.ts
index 3bafb43..8c9f8ce 100644
--- a/src/services/bot/memory/long-term.ts
+++ b/src/services/bot/memory/long-term.ts
@@ -1,13 +1,18 @@
-import { Memory, ShortTermMemory } from "@prisma/client";
+import { LongTermMemory, Memory, ShortTermMemory } from "@prisma/client";
+import { openai } from "../../openai";
 
 export class LongTermMemoryAgent {
   // todo use an LLM to generate the new long-term memory
-  static async generate(
-    newMemories: Memory[],
-    lastLongTermMemory?: ShortTermMemory
-  ): Promise<string | undefined> {
-    return `count: ${newMemories.length}\n${newMemories
-      .map((e, idx) => idx.toString() + ". " + e.text)
-      .join("\n")}`;
+  static async generate(options: {
+    currentMemory: Memory;
+    newMemories: ShortTermMemory[];
+    lastMemory?: LongTermMemory;
+  }): Promise<string | undefined> {
+    const { currentMemory, newMemories, lastMemory } = options;
+    const res = await openai.chat({
+      user: "todo", // todo prompt
+      requestId: `update-long-memory-${currentMemory.id}`,
+    });
+    return res?.content?.trim();
   }
 }
diff --git a/src/services/bot/memory/short-term.ts b/src/services/bot/memory/short-term.ts
index d29e99a..0d1b869 100644
--- a/src/services/bot/memory/short-term.ts
+++ b/src/services/bot/memory/short-term.ts
@@ -1,13 +1,18 @@
-import { Memory } from "@prisma/client";
+import { Memory, ShortTermMemory } from "@prisma/client";
+import { openai } from "../../openai";
 
 export class ShortTermMemoryAgent {
   // todo use an LLM to generate the new short-term memory
-  static async generate(
-    newMemories: Memory[],
-    lastShortTermMemory?: Memory
-  ): Promise<string | undefined> {
-    return `count: ${newMemories.length}\n${newMemories
-      .map((e, idx) => idx.toString() + ". " + e.text)
-      .join("\n")}`;
+  static async generate(options: {
+    currentMemory: Memory;
+    newMemories: Memory[];
+    lastMemory?: ShortTermMemory;
+  }): Promise<string | undefined> {
+    const { currentMemory, newMemories, lastMemory } = options;
+    const res = await openai.chat({
+      user: "todo", // todo prompt
+      requestId: `update-short-memory-${currentMemory.id}`,
+    });
+    return res?.content?.trim();
   }
 }
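Both agents now take a single options object instead of positional arguments, and both derive a deterministic requestId from the memory that triggered the update. That naming convention is what lets MemoryManager._onMemory target exactly the right in-flight request when it calls openai.abort. A hypothetical call site (memory, recent, and previousSummary are placeholders, not names from this codebase) would look like:

    // Hypothetical usage of the new options-object signature.
    const summary = await ShortTermMemoryAgent.generate({
      currentMemory: memory, // the message that triggered this update
      newMemories: recent, // memories accumulated since the last summary
      lastMemory: previousSummary, // the previous ShortTermMemory row, if any
    });
    // Internally this issues openai.chat with
    // requestId: `update-short-memory-${memory.id}`, so a newer message
    // can cancel it before it completes.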
diff --git a/src/services/openai.ts b/src/services/openai.ts
index efa4c98..5f38b33 100644
--- a/src/services/openai.ts
+++ b/src/services/openai.ts
@@ -12,6 +12,7 @@ export interface ChatOptions {
   system?: string;
   tools?: Array<ChatCompletionTool>;
   jsonMode?: boolean;
+  requestId?: string;
 }
 
 class OpenAIClient {
@@ -32,17 +33,26 @@ class OpenAIClient {
   }
 
   async chat(options: ChatOptions) {
-    const { user, system, tools, jsonMode } = options;
+    let { user, system, tools, jsonMode, requestId } = options;
     const systemMsg: ChatCompletionMessageParam[] = system
       ? [{ role: "system", content: system }]
       : [];
+    let signal: AbortSignal | undefined;
+    if (requestId) {
+      const controller = new AbortController();
+      this._abortCallbacks[requestId] = () => controller.abort();
+      signal = controller.signal;
+    }
     const chatCompletion = await this._client.chat.completions
-      .create({
-        tools,
-        messages: [...systemMsg, { role: "user", content: user }],
-        model: kEnvs.OPENAI_MODEL ?? "gpt-3.5-turbo-0125",
-        response_format: jsonMode ? { type: "json_object" } : undefined,
-      })
+      .create(
+        {
+          tools,
+          messages: [...systemMsg, { role: "user", content: user }],
+          model: kEnvs.OPENAI_MODEL ?? "gpt-3.5-turbo-0125",
+          response_format: jsonMode ? { type: "json_object" } : undefined,
+        },
+        { signal }
+      )
       .catch((e) => {
         console.error("❌ openai chat failed", e);
         return null;
@@ -52,11 +62,10 @@ class OpenAIClient {
 
   async chatStream(
     options: ChatOptions & {
-      requestId?: string;
       onStream?: (text: string) => void;
     }
   ) {
-    const { user, system, tools, jsonMode, onStream, requestId } = options;
+    let { user, system, tools, jsonMode, requestId, onStream } = options;
     const systemMsg: ChatCompletionMessageParam[] = system
       ? [{ role: "system", content: system }]
       : [];
diff --git a/tests/index.ts b/tests/index.ts
index 6b933fe..3063642 100644
--- a/tests/index.ts
+++ b/tests/index.ts
@@ -12,9 +12,9 @@ dotenv.config();
 async function main() {
   println(kBannerASCII);
   // testDB();
-  testSpeaker();
+  // testSpeaker();
   // testOpenAI();
-  // testMyBot();
+  testMyBot();
 }
 
 runWithDB(main);
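Two details in the openai.ts change are worth spelling out. First, the new second argument to chat.completions.create is the per-request options bag of the official openai Node SDK; passing { signal } there means controller.abort() rejects the pending request (with an APIUserAbortError in the v4 SDK), which the existing .catch turns into a null result. Second, chat() registers the abort callback but never removes it once the request settles, so entries accumulate in _abortCallbacks. A self-cleaning variant (a sketch of an alternative, not what this diff does) could look like:

    // Sketch: same AbortController wiring, but the callback registry is
    // cleaned up no matter how the request ends. `payload` stands for the
    // { tools, messages, model, response_format } object built above.
    let signal: AbortSignal | undefined;
    if (requestId) {
      const controller = new AbortController();
      this._abortCallbacks[requestId] = () => controller.abort();
      signal = controller.signal;
    }
    try {
      return await this._client.chat.completions.create(payload, { signal });
    } catch (e) {
      console.error("❌ openai chat failed", e);
      return null;
    } finally {
      if (requestId) delete this._abortCallbacks[requestId];
    }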