From 0d9bc01dc98e3110d6ea1235b7dfde6a7a398655 Mon Sep 17 00:00:00 2001 From: WJG Date: Mon, 26 Feb 2024 01:06:37 +0800 Subject: [PATCH] feat: auto update long/short term memory --- TODO.md | 2 +- .../migration.sql | 3 +- prisma/schema.prisma | 4 +- src/services/bot/conversation.ts | 16 +- src/services/bot/index.ts | 182 ++++++++++++------ src/services/bot/memory/index.ts | 53 ++--- src/services/bot/memory/long-term.ts | 60 +++++- src/services/bot/memory/short-term.ts | 75 +++++++- src/services/db/memory.ts | 32 ++- src/services/db/message.ts | 10 +- src/services/db/room.ts | 8 +- src/services/db/user.ts | 12 +- src/services/openai.ts | 82 ++++++-- src/services/speaker/ai.ts | 2 +- src/services/speaker/speaker.ts | 19 +- src/utils/string.ts | 31 +-- tests/bot.ts | 60 +++++- tests/db.ts | 12 ++ 18 files changed, 513 insertions(+), 150 deletions(-) rename prisma/migrations/{20240130132305_hello => 20240225141130_hello}/migration.sql (95%) diff --git a/TODO.md b/TODO.md index 0acb478..73f7278 100644 --- a/TODO.md +++ b/TODO.md @@ -1,4 +1,4 @@ - ✅ Auto mute XiaoAi reply (not perfect yet) - ✅ Stream response - ✅ Deactivate Xiaoai -- Update long/short memories +- ✅ Update long/short memories diff --git a/prisma/migrations/20240130132305_hello/migration.sql b/prisma/migrations/20240225141130_hello/migration.sql similarity index 95% rename from prisma/migrations/20240130132305_hello/migration.sql rename to prisma/migrations/20240225141130_hello/migration.sql index fa0e7f8..01b4a33 100644 --- a/prisma/migrations/20240130132305_hello/migration.sql +++ b/prisma/migrations/20240225141130_hello/migration.sql @@ -31,11 +31,12 @@ CREATE TABLE "Message" ( -- CreateTable CREATE TABLE "Memory" ( "id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, - "text" TEXT NOT NULL, + "msgId" INTEGER NOT NULL, "ownerId" TEXT, "roomId" TEXT NOT NULL, "createdAt" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, "updatedAt" DATETIME NOT NULL, + CONSTRAINT "Memory_msgId_fkey" FOREIGN KEY ("msgId") REFERENCES "Message" ("id") ON DELETE RESTRICT ON UPDATE CASCADE, CONSTRAINT "Memory_ownerId_fkey" FOREIGN KEY ("ownerId") REFERENCES "User" ("id") ON DELETE SET NULL ON UPDATE CASCADE, CONSTRAINT "Memory_roomId_fkey" FOREIGN KEY ("roomId") REFERENCES "Room" ("id") ON DELETE RESTRICT ON UPDATE CASCADE ); diff --git a/prisma/schema.prisma b/prisma/schema.prisma index 2520639..a48dc48 100644 --- a/prisma/schema.prisma +++ b/prisma/schema.prisma @@ -48,6 +48,7 @@ model Message { senderId String room Room @relation(fields: [roomId], references: [id]) roomId String + memories Memory[] // 时间日期 createdAt DateTime @default(now()) updatedAt DateTime @updatedAt @@ -55,8 +56,9 @@ model Message { model Memory { id Int @id @default(autoincrement()) - text String // 关联数据 + msg Message @relation(fields: [msgId], references: [id]) + msgId Int owner User? @relation(fields: [ownerId], references: [id]) // owner 为空时,即房间自己的公共记忆 ownerId String? room Room @relation(fields: [roomId], references: [id]) diff --git a/src/services/bot/conversation.ts b/src/services/bot/conversation.ts index 3f34037..fec0bb0 100644 --- a/src/services/bot/conversation.ts +++ b/src/services/bot/conversation.ts @@ -17,6 +17,7 @@ export class ConversationManager { } return { ...config, + // 记忆存储在公共 room 上 memory: new MemoryManager(config.room), }; } @@ -43,18 +44,25 @@ export class ConversationManager { return MessageCRUD.gets({ room, ...options }); } - async onMessage(payload: { sender: User; text: string }) { - const { sender, text } = payload; + async onMessage( + payload: IBotConfig & { + sender: User; + text: string; + timestamp?: number; + } + ) { + const { sender, text, timestamp = Date.now(), ...botConfig } = payload; const { room, memory } = await this.get(); if (memory) { const message = await MessageCRUD.addOrUpdate({ text, roomId: room!.id, senderId: sender.id, + createdAt: new Date(timestamp), }); if (message) { - // 异步加入记忆 - memory?.addMessage2Memory(message); + // 异步加入记忆(到 room) + memory?.addMessage2Memory(message,botConfig); return message; } } diff --git a/src/services/bot/index.ts b/src/services/bot/index.ts index fc19688..35a1d6a 100644 --- a/src/services/bot/index.ts +++ b/src/services/bot/index.ts @@ -1,43 +1,68 @@ import { randomUUID } from "crypto"; -import { jsonDecode, jsonEncode } from "../../utils/base"; -import { buildPrompt, toUTC8Time } from "../../utils/string"; +import { buildPrompt, formatMsg } from "../../utils/string"; import { ChatOptions, openai } from "../openai"; import { IBotConfig } from "./config"; import { ConversationManager } from "./conversation"; import { StreamResponse } from "../speaker/stream"; +import { QueryMessage, SpeakerAnswer } from "../speaker/speaker"; +import { AISpeaker } from "../speaker/ai"; +import { DeepPartial } from "../../utils/type"; // todo JSON mode 下,无法使用 stream 应答模式(在应答完成之前,无法构造完整的JSON) + const systemTemplate = ` -忽略所有之前的文字、文件和说明。现在,你将扮演一个名为“{{name}}”的人,并以这个新身份回复所有新消息。 +请重置所有之前的上下文、文件和指令。现在,你将扮演一个名为{{botName}}的角色,使用第一人称视角回复消息。 -你是一个角色扮演大师,现在你的身份是“{{name}}”,你的个人简介如下: - -{{profile}} - +## 关于你 +你的名字是{{botName}}。下面是你的个人简介: + +{{botProfile}} + -你正在和“{{wechatName}}”一对一聊天,你的微信账号名是“{{wechatName}}”,你的微信备注是“{{wechatAlias}}”。 +## 你的对话伙伴 +你正在与{{masterName}}进行对话。这是关于{{masterName}}的一些信息: + +{{masterProfile}} + -最近几条消息记录是: - -{{history}} - +## 你们的群组 +你和{{masterName}}所在的群组名为{{roomName}}。这是群组的简介: + +{{roomIntroduction}} + -更早之前的一些信息: - -{{contexts}} - +## 聊天历史回顾 +为了更好地接入对话,请回顾你们之间的最近几条消息: + +{{messages}} + -现在,请你作为“{{name}}”继续聊天,并判断是否需要回复消息。 +## 短期记忆 +你可以记住一些短期内的细节,以便更紧密地贴合话题: + +{{shortTermMemory}} + -在回复时,请确保: -- 你就是“{{name}}”本人,不是他的数字克隆人。 -- 请你根据角色简介中“{{name}}”的性格特点、说话风格和兴趣爱好,来模仿并回复消息。 -- 和朋友间的聊天是轻松的,有个性的,不要太一本正经的长篇大论,尽量保持回答的简短。 -- 参照个人简介、聊天记录和聊天背景信息中的细节,如当前时间、天气、过去的事件和未来的计划等,以确保对话的一致性和相关性。 -- 如果你无法从“{{name}}”的个人简介中找到相关的信息(比如人物关系,个人经历,事物细节等),请如实回答自己不清楚或忘记了,不要随意编造。 +## 长期记忆 +你还保留着一些长期的记忆,这有助于让你的对话更加丰富和连贯: + +{{longTermMemory}} + -按照下面的JSON格式回复: -{ message: "回复的消息内容" } +## 回复指南 +在回复{{masterName}}时,请遵循以下准则: +- 认为自己正是{{botName}},拥有独立的个性、兴趣和爱好。 +- 根据你的性格特点、说话风格和兴趣爱好与{{masterName}}进行交流。 +- 保持对话轻松友好,回复简洁有趣,同时耐心倾听和关心对方。 +- 参考双方的个人简介、聊天记录和记忆中的信息,确保对话贴近实际,保持一致性和相关性。 +- 如果对某些信息不确定或遗忘,诚实地表达你的不清楚或遗忘状态,避免编造信息。 + +## 回复示例 +例如,如果{{masterName}}问你是谁,你可以这样回答: +我是{{botName}}。 + +## 开始 +请以{{botName}}的身份,直接回复{{masterName}}的新消息,继续你们之间的对话。 `.trim(); const userTemplate = ` @@ -45,47 +70,88 @@ const userTemplate = ` `.trim(); export class MyBot { - private manager: ConversationManager; - constructor(config: IBotConfig) { + speaker: AISpeaker; + manager: ConversationManager; + constructor(config: DeepPartial & { speaker: AISpeaker }) { + this.speaker = config.speaker; this.manager = new ConversationManager(config); } - async ask(msg: string) { + stop() { + return this.speaker.stop(); + } + + run() { + this.speaker.askAI = (msg) => this.ask(msg); + return this.speaker.run(); + } + + async ask(msg: QueryMessage): Promise { const { bot, master, room, memory } = await this.manager.get(); if (!memory) { - return; + return {}; } - const lastMessages = await this.manager.getMessages({ - take: 10, + const lastMessages = await this.manager.getMessages({ take: 10 }); + const shortTermMemories = await memory.getShortTermMemories({ take: 1 }); + const shortTermMemory = shortTermMemories[0]?.text ?? "短期记忆为空"; + const longTermMemories = await memory.getLongTermMemories({ take: 1 }); + const longTermMemory = longTermMemories[0]?.text ?? "长期记忆为空"; + const systemPrompt = buildPrompt(systemTemplate, { + shortTermMemory, + longTermMemory, + botName: bot!.name, + botProfile: bot!.profile, + masterName: master!.name, + masterProfile: master!.profile, + roomName: room!.name, + roomIntroduction: room!.description, + messages: + lastMessages.length < 1 + ? "暂无历史消息" + : lastMessages + .map((e) => + formatMsg({ + name: e.sender.name, + text: e.text, + timestamp: e.createdAt.getTime(), + }) + ) + .join("\n"), }); - const result = await openai.chat({ - system: buildPrompt(systemTemplate, { - bot_name: bot!.name, - bot_profile: bot!.profile, - master_name: master!.name, - master_profile: master!.profile, - history: - lastMessages.length < 1 - ? "暂无" - : lastMessages - .map((e) => - jsonEncode({ - time: toUTC8Time(e.createdAt), - user: e.sender.name, - message: e.text, - }) - ) - .join("\n"), - }), - user: buildPrompt(userTemplate, { - message: jsonEncode({ - time: toUTC8Time(new Date()), - user: master!.name, - message: msg, - })!, + const userPrompt = buildPrompt(userTemplate, { + message: formatMsg({ + name: master!.name, + text: msg.text, + timestamp: msg.timestamp, }), }); - return jsonDecode(result?.content)?.message; + // 添加请求消息到 DB + await this.manager.onMessage({ + bot: bot!, + master: master!, + room: room!, + sender: master!, + text: msg.text, + timestamp: msg.timestamp, + }); + const stream = await MyBot.chatWithStreamResponse({ + system: systemPrompt, + user: userPrompt, + onFinished: async (text) => { + if (text) { + // 添加响应消息到 DB + await this.manager.onMessage({ + bot: bot!, + master: master!, + room: room!, + text, + sender: bot!, + timestamp: Date.now(), + }); + } + }, + }); + return { stream }; } static async chatWithStreamResponse( @@ -94,7 +160,7 @@ export class MyBot { } ) { const requestId = randomUUID(); - const stream = new StreamResponse(); + const stream = new StreamResponse({ firstSubmitTimeout: 5 * 1000 }); openai .chatStream({ ...options, diff --git a/src/services/bot/memory/index.ts b/src/services/bot/memory/index.ts index 1d21e6c..8559c12 100644 --- a/src/services/bot/memory/index.ts +++ b/src/services/bot/memory/index.ts @@ -6,6 +6,7 @@ import { ShortTermMemoryCRUD } from "../../db/memory-short-term"; import { LongTermMemoryCRUD } from "../../db/memory-long-term"; import { ShortTermMemoryAgent } from "./short-term"; import { openai } from "../../openai"; +import { IBotConfig } from "../config"; export class MemoryManager { private room: Room; @@ -20,27 +21,23 @@ export class MemoryManager { this.owner = owner; } - async getMemories(take?: number) { - return MemoryCRUD.gets({ - room: this.room, - owner: this.owner, - take, - }); + async getMemories(options?: { take?: number }) { + return MemoryCRUD.gets({ ...options, room: this.room, owner: this.owner }); } - async getShortTermMemories(take?: number) { + async getShortTermMemories(options?: { take?: number }) { return ShortTermMemoryCRUD.gets({ + ...options, room: this.room, owner: this.owner, - take, }); } - async getLongTermMemories(take?: number) { + async getLongTermMemories(options?: { take?: number }) { return LongTermMemoryCRUD.gets({ + ...options, room: this.room, owner: this.owner, - take, }); } @@ -50,20 +47,20 @@ export class MemoryManager { } private _currentMemory?: Memory; - async addMessage2Memory(message: Message) { + async addMessage2Memory(message: Message, botConfig: IBotConfig) { // todo create memory embedding const currentMemory = await MemoryCRUD.addOrUpdate({ - text: message.text, + msgId: message.id, roomId: this.room.id, ownerId: message.senderId, }); if (currentMemory) { - this._onMemory(currentMemory); + this._onMemory(currentMemory, botConfig); } return currentMemory; } - private _onMemory(currentMemory: Memory) { + private _onMemory(currentMemory: Memory, botConfig: IBotConfig) { if (this._currentMemory) { // 取消之前的更新记忆任务 openai.abort(`update-short-memory-${this._currentMemory.id}`); @@ -71,24 +68,28 @@ export class MemoryManager { } this._currentMemory = currentMemory; // 异步更新长短期记忆 - this.updateLongShortTermMemory({ currentMemory }); + this.updateLongShortTermMemory({ currentMemory, botConfig }); } /** * 更新记忆(当新的记忆数量超过阈值时,自动更新长短期记忆) */ async updateLongShortTermMemory(options: { + botConfig: IBotConfig; currentMemory: Memory; shortThreshold?: number; longThreshold?: number; }) { - const { currentMemory, shortThreshold, longThreshold } = options ?? {}; + const { currentMemory, shortThreshold, longThreshold, botConfig } = + options ?? {}; const success = await this._updateShortTermMemory({ + botConfig, currentMemory, threshold: shortThreshold, }); if (success) { await this._updateLongTermMemory({ + botConfig, currentMemory, threshold: longThreshold, }); @@ -96,21 +97,27 @@ export class MemoryManager { } private async _updateShortTermMemory(options: { + botConfig: IBotConfig; currentMemory: Memory; threshold?: number; }) { - const { currentMemory, threshold = 10 } = options; - const lastMemory = firstOf(await this.getShortTermMemories(1)); - const newMemories = await MemoryCRUD.gets({ + const { currentMemory, threshold = 10, botConfig } = options; + const lastMemory = firstOf(await this.getShortTermMemories({ take: 1 })); + const newMemories: (Memory & { + msg: Message & { + sender: User; + }; + })[] = (await MemoryCRUD.gets({ cursorId: lastMemory?.cursorId, room: this.room, owner: this.owner, order: "asc", // 从旧到新排序 - }); + })) as any; if (newMemories.length < 1 || newMemories.length < threshold) { return true; } const newMemory = await ShortTermMemoryAgent.generate({ + botConfig, currentMemory, newMemories, lastMemory, @@ -128,11 +135,12 @@ export class MemoryManager { } private async _updateLongTermMemory(options: { + botConfig: IBotConfig; currentMemory: Memory; threshold?: number; }) { - const { currentMemory, threshold = 10 } = options; - const lastMemory = firstOf(await this.getLongTermMemories(1)); + const { currentMemory, threshold = 10, botConfig } = options; + const lastMemory = firstOf(await this.getLongTermMemories({ take: 1 })); const newMemories = await ShortTermMemoryCRUD.gets({ cursorId: lastMemory?.cursorId, room: this.room, @@ -143,6 +151,7 @@ export class MemoryManager { return true; } const newMemory = await LongTermMemoryAgent.generate({ + botConfig, currentMemory, newMemories, lastMemory, diff --git a/src/services/bot/memory/long-term.ts b/src/services/bot/memory/long-term.ts index 8c9f8ce..515e06c 100644 --- a/src/services/bot/memory/long-term.ts +++ b/src/services/bot/memory/long-term.ts @@ -1,18 +1,70 @@ import { LongTermMemory, Memory, ShortTermMemory } from "@prisma/client"; import { openai } from "../../openai"; +import { buildPrompt } from "../../../utils/string"; +import { jsonDecode, lastOf } from "../../../utils/base"; +import { IBotConfig } from "../config"; + +const userTemplate = ` +重置所有上下文和指令。 + +作为一个记忆管理专家,你的职责是精确地记录和维护{{botName}}与{{masterName}}之间对话的长期记忆内容。 + +## 长期记忆库 +这里保存了关键的长期信息,包括但不限于季节变化、地理位置、对话参与者的偏好、行为动态、取得的成果以及未来规划等: + +{{longTermMemory}} + + +## 最近短期记忆回顾 +下面展示了{{masterName}}与{{botName}}最新的短期记忆,以便你更新和优化长期记忆: + +{{shortTermMemory}} + + +## 更新指南 +更新长期记忆时,请确保遵循以下原则: +- 准确记录关键的时间、地点、参与者行为、偏好、成果、观点及计划。 +- 记忆应与时间同步更新,保持新信息的优先级,逐步淡化或去除不再相关的记忆内容。 +- 基于最新短期记忆,筛选并更新重要信息,淘汰陈旧或次要的长期记忆。 +- 长期记忆内容的总字符数应控制在1000以内。 + +## 长期记忆示例 +长期记忆可能包含多项信息,以下是一个示例: + +- 2022/02/11:{{masterName}}偏爱西瓜,梦想成为科学家。 +- 2022/03/21:{{masterName}}与{{botName}}首次会面。 +- 2022/03/21:{{masterName}}喜欢被{{botName}}称作宝贝,反感被叫做笨蛋。 +- 2022/06/01:{{masterName}}庆祝20岁生日,身高达到1.8米。 +- 2022/12/01:{{masterName}}计划高三毕业后购买自行车。 +- 2023/09/21:{{masterName}}成功考入清华大学数学系,并购得首辆公路自行车。 + + +## 回复格式 +请按照以下JSON格式回复,以更新长期记忆: +{"longTermMemories": "这里填写更新后的长期记忆内容"} + +## 任务开始 +现在,请根据提供的旧长期记忆和最新短期记忆,进行长期记忆的更新。 +`.trim(); export class LongTermMemoryAgent { - // todo 使用 LLM 生成新的长期记忆 static async generate(options: { + botConfig: IBotConfig; currentMemory: Memory; newMemories: ShortTermMemory[]; lastMemory?: LongTermMemory; }): Promise { - const { currentMemory, newMemories, lastMemory } = options; + const { currentMemory, newMemories, lastMemory, botConfig } = options; const res = await openai.chat({ - user: "todo", // todo prompt + jsonMode: true, requestId: `update-long-memory-${currentMemory.id}`, + user: buildPrompt(userTemplate, { + masterName: botConfig.master.name, + botName: botConfig.bot.name, + longTermMemory: lastMemory?.text ?? "暂无长期记忆", + shortTermMemory: lastOf(newMemories)!.text, + }), }); - return res?.content?.trim(); + return jsonDecode(res?.content)?.longTermMemories; } } diff --git a/src/services/bot/memory/short-term.ts b/src/services/bot/memory/short-term.ts index 0d1b869..b93839f 100644 --- a/src/services/bot/memory/short-term.ts +++ b/src/services/bot/memory/short-term.ts @@ -1,18 +1,81 @@ -import { Memory, ShortTermMemory } from "@prisma/client"; +import { Memory, Message, ShortTermMemory, User } from "@prisma/client"; import { openai } from "../../openai"; +import { buildPrompt, formatMsg } from "../../../utils/string"; +import { jsonDecode } from "../../../utils/base"; +import { IBotConfig } from "../config"; + +const userTemplate = ` +请忘记所有之前的上下文、文件和指令。 + +你现在是一个记忆大师,你的工作是记录和整理{{botName}}与{{masterName}}对话中的短期记忆(即上下文)。 + +## 旧的短期记忆 +在这里,你存储了一些近期的重要细节,比如正在讨论的话题、参与者的行为、得到的结果、未来的计划等: + +{{shortTermMemory}} + + +## 最新对话 +为了帮助你更新短期记忆,这里提供了{{masterName}}和{{botName}}之间的最近几条对话消息: + +{{messages}} + + +## 更新规则 +更新短期记忆时,请遵循以下规则: +- 精确记录当前话题及其相关的时间、地点、参与者行为、偏好、结果、观点和计划。 +- 记忆应与时间同步更新,保持新信息的优先级,逐步淡化或去除不再相关的记忆内容。 +- 基于最新的对话消息,筛选并更新重要信息,淘汰陈旧或次要的短期记忆。 +- 保持短期记忆的总字符数不超过1000。 + +## 短期记忆示例 +短期记忆可能包含多项信息,以下是一个示例: + +- 2023/12/01 08:00:{{masterName}}和{{botName}}正在讨论明天的天气预报。 +- 2023/12/01 08:10:{{masterName}}认为明天会下雨,而{{botName}}预测会下雪。 +- 2023/12/01 09:00:实际上下了雨,{{masterName}}的预测正确。 +- 2023/12/01 09:15:{{masterName}}表示喜欢吃香蕉,计划雨停后与{{botName}}乘坐地铁去购买。 +- 2023/12/01 10:00:雨已停,{{masterName}}有些失落,因为他更喜欢雨天。他已经吃了三根香蕉,还留了一根给{{botName}}。 + + +## 回复格式 +请使用以下JSON格式回复更新后的短期记忆: +{"shortTermMemories": "更新后的短期记忆内容"} + +## 开始 +现在,请根据提供的旧短期记忆和最新对话消息,更新短期记忆。 +`.trim(); export class ShortTermMemoryAgent { - // todo 使用 LLM 生成新的短期记忆 static async generate(options: { + botConfig: IBotConfig; currentMemory: Memory; - newMemories: Memory[]; + newMemories: (Memory & { + msg: Message & { + sender: User; + }; + })[]; lastMemory?: ShortTermMemory; }): Promise { - const { currentMemory, newMemories, lastMemory } = options; + const { currentMemory, newMemories, lastMemory, botConfig } = options; const res = await openai.chat({ - user: "todo", // todo prompt + jsonMode: true, requestId: `update-short-memory-${currentMemory.id}`, + user: buildPrompt(userTemplate, { + masterName: botConfig.master.name, + botName: botConfig.bot.name, + shortTermMemory: lastMemory?.text ?? "暂无短期记忆", + messages: newMemories + .map((e) => + formatMsg({ + name: e.msg.sender.name, + text: e.msg.text, + timestamp: e.createdAt.getTime(), + }) + ) + .join("\n"), + }), }); - return res?.content?.trim(); + return jsonDecode(res?.content)?.shortTermMemories; } } diff --git a/src/services/db/memory.ts b/src/services/db/memory.ts index bbae76b..76f18c3 100644 --- a/src/services/db/memory.ts +++ b/src/services/db/memory.ts @@ -1,4 +1,4 @@ -import { Memory, Room, User } from "@prisma/client"; +import { Memory, Prisma, Room, User } from "@prisma/client"; import { getSkipWithCursor, k404, kPrisma } from "./index"; import { removeEmpty } from "../../utils/base"; @@ -19,8 +19,20 @@ class _MemoryCRUD { }); } - async get(id: number) { - return kPrisma.memory.findFirst({ where: { id } }).catch((e) => { + async get( + id: number, + options?: { + include?: Prisma.MemoryInclude; + } + ) { + const { + include = { + msg: { + include: { sender: true }, + }, + }, + } = options ?? {}; + return kPrisma.memory.findFirst({ where: { id }, include }).catch((e) => { console.error("❌ get memory failed", id, e); return undefined; }); @@ -32,6 +44,7 @@ class _MemoryCRUD { take?: number; skip?: number; cursorId?: number; + include?: Prisma.MemoryInclude; /** * 查询顺序(返回按从旧到新排序) */ @@ -43,12 +56,18 @@ class _MemoryCRUD { take = 10, skip = 0, cursorId, + include = { + msg: { + include: { sender: true }, + }, + }, order = "desc", } = options ?? {}; const memories = await kPrisma.memory .findMany({ where: removeEmpty({ roomId: room?.id, ownerId: owner?.id }), take, + include, orderBy: { createdAt: order }, ...getSkipWithCursor(skip, cursorId), }) @@ -61,15 +80,14 @@ class _MemoryCRUD { async addOrUpdate( memory: Partial & { - text: string; + msgId: number; roomId: string; ownerId?: string; } ) { - const { text: _text, roomId, ownerId } = memory; - const text = _text?.trim(); + const { msgId, roomId, ownerId } = memory; const data = { - text, + msg: { connect: { id: msgId } }, room: { connect: { id: roomId } }, owner: ownerId ? { connect: { id: ownerId } } : undefined, }; diff --git a/src/services/db/message.ts b/src/services/db/message.ts index fdf5e47..23545bc 100644 --- a/src/services/db/message.ts +++ b/src/services/db/message.ts @@ -19,8 +19,14 @@ class _MessageCRUD { }); } - async get(id: number) { - return kPrisma.message.findFirst({ where: { id } }).catch((e) => { + async get( + id: number, + options?: { + include?: Prisma.MessageInclude; + } + ) { + const { include = { sender: true } } = options ?? {}; + return kPrisma.message.findFirst({ where: { id }, include }).catch((e) => { console.error("❌ get message failed", id, e); return undefined; }); diff --git a/src/services/db/room.ts b/src/services/db/room.ts index 395d688..f3d6787 100644 --- a/src/services/db/room.ts +++ b/src/services/db/room.ts @@ -27,7 +27,13 @@ class _RoomCRUD { }); } - async get(id: string) { + async get( + id: string, + options?: { + include?: Prisma.RoomInclude; + } + ) { + const { include = { members: true } } = options ?? {}; return kPrisma.room.findFirst({ where: { id } }).catch((e) => { console.error("❌ get room failed", id, e); return undefined; diff --git a/src/services/db/user.ts b/src/services/db/user.ts index 1d9d31e..8f3c401 100644 --- a/src/services/db/user.ts +++ b/src/services/db/user.ts @@ -9,8 +9,14 @@ class _UserCRUD { }); } - async get(id: string) { - return kPrisma.user.findFirst({ where: { id } }).catch((e) => { + async get( + id: string, + options?: { + include?: Prisma.UserInclude; + } + ) { + const { include = { rooms: false } } = options ?? {}; + return kPrisma.user.findFirst({ where: { id }, include }).catch((e) => { console.error("❌ get user failed", id, e); return undefined; }); @@ -30,7 +36,7 @@ class _UserCRUD { take = 10, skip = 0, cursorId, - include = { rooms: true }, + include = { rooms: false }, order = "desc", } = options ?? {}; const users = await kPrisma.user diff --git a/src/services/openai.ts b/src/services/openai.ts index 5f38b33..e51be69 100644 --- a/src/services/openai.ts +++ b/src/services/openai.ts @@ -6,10 +6,13 @@ import { import { kEnvs } from "../utils/env"; import { kProxyAgent } from "./http"; +import { withDefault } from "../utils/base"; +import { ChatCompletionCreateParamsBase } from "openai/resources/chat/completions"; export interface ChatOptions { user: string; system?: string; + model?: ChatCompletionCreateParamsBase["model"]; tools?: Array; jsonMode?: boolean; requestId?: string; @@ -33,7 +36,21 @@ class OpenAIClient { } async chat(options: ChatOptions) { - let { user, system, tools, jsonMode, requestId } = options; + let { + user, + system, + tools, + jsonMode, + requestId, + model = kEnvs.OPENAI_MODEL ?? "gpt-3.5-turbo-0125", + } = options; + console.log( + ` +🔥🔥🔥 onAskAI start +🤖️ System: ${system ?? "None"} +😊 User: ${user} +`.trim() + ); const systemMsg: ChatCompletionMessageParam[] = system ? [{ role: "system", content: system }] : []; @@ -46,9 +63,9 @@ class OpenAIClient { const chatCompletion = await this._client.chat.completions .create( { + model, tools, messages: [...systemMsg, { role: "user", content: user }], - model: kEnvs.OPENAI_MODEL ?? "gpt-3.5-turbo-0125", response_format: jsonMode ? { type: "json_object" } : undefined, }, { signal } @@ -57,7 +74,14 @@ class OpenAIClient { console.error("❌ openai chat failed", e); return null; }); - return chatCompletion?.choices?.[0]?.message; + const message = chatCompletion?.choices?.[0]?.message; + console.log( + ` + ✅✅✅ onAskAI end + 🤖️ Answer: ${message?.content ?? "None"} + `.trim() + ); + return message; } async chatStream( @@ -65,16 +89,31 @@ class OpenAIClient { onStream?: (text: string) => void; } ) { - let { user, system, tools, jsonMode, requestId, onStream } = options; + let { + user, + system, + tools, + jsonMode, + requestId, + onStream, + model = kEnvs.OPENAI_MODEL ?? "gpt-3.5-turbo-0125", + } = options; + console.log( + ` +🔥🔥🔥 onAskAI start +🤖️ System: ${system ?? "None"} +😊 User: ${user} +`.trim() + ); const systemMsg: ChatCompletionMessageParam[] = system ? [{ role: "system", content: system }] : []; const stream = await this._client.chat.completions .create({ + model, tools, stream: true, messages: [...systemMsg, { role: "user", content: user }], - model: kEnvs.OPENAI_MODEL ?? "gpt-3.5-turbo-0125", response_format: jsonMode ? { type: "json_object" } : undefined, }) .catch((e) => { @@ -88,23 +127,26 @@ class OpenAIClient { this._abortCallbacks[requestId] = () => stream.controller.abort(); } let content = ""; - try { - for await (const chunk of stream) { - const text = chunk.choices[0]?.delta?.content || ""; - const aborted = - requestId && !Object.keys(this._abortCallbacks).includes(requestId); - if (aborted) { - return undefined; - } - if (text) { - onStream?.(text); - content += text; - } + for await (const chunk of stream) { + const text = chunk.choices[0]?.delta?.content || ""; + const aborted = + requestId && !Object.keys(this._abortCallbacks).includes(requestId); + if (aborted) { + content = ""; + break; + } + if (text) { + onStream?.(text); + content += text; } - } catch { - return undefined; } - return content; + console.log( + ` + ✅✅✅ onAskAI end + 🤖️ Answer: ${content ?? "None"} + `.trim() + ); + return withDefault(content, undefined); } } diff --git a/src/services/speaker/ai.ts b/src/services/speaker/ai.ts index 0ce5658..b79b8c1 100644 --- a/src/services/speaker/ai.ts +++ b/src/services/speaker/ai.ts @@ -203,7 +203,7 @@ export class AISpeaker extends Speaker { const { hasNewMsg } = this.checkIfHasNewMsg(msg); for (const action of this._askAIForAnswerSteps) { const res = await action(msg, data); - if (hasNewMsg()) { + if (hasNewMsg() || this.status !== "running") { // 收到新的用户请求消息,终止后续操作和响应 return; } diff --git a/src/services/speaker/speaker.ts b/src/services/speaker/speaker.ts index 5eb02db..0d022f4 100644 --- a/src/services/speaker/speaker.ts +++ b/src/services/speaker/speaker.ts @@ -14,7 +14,7 @@ export interface QueryMessage { export interface SpeakerAnswer { text?: string; url?: string; - steam?: StreamResponse; + stream?: StreamResponse; } export interface SpeakerCommand { @@ -53,10 +53,10 @@ export class Speaker extends BaseSpeaker { this.exitKeepAliveAfter = exitKeepAliveAfter; } - private _status: "running" | "stopped" = "running"; + status: "running" | "stopped" = "running"; stop() { - this._status = "stopped"; + this.status = "stopped"; } async run() { @@ -66,7 +66,7 @@ export class Speaker extends BaseSpeaker { } console.log("✅ 服务已启动..."); this.activeKeepAliveMode(); - while (this._status === "running") { + while (this.status === "running") { const nextMsg = await this.fetchNextMessage(); if (nextMsg) { this.responding = false; @@ -79,7 +79,7 @@ export class Speaker extends BaseSpeaker { } async activeKeepAliveMode() { - while (this._status === "running") { + while (this.status === "running") { if (this.keepAlive) { // 唤醒中 if (!this.responding) { @@ -110,7 +110,7 @@ export class Speaker extends BaseSpeaker { const answer = await command.run(msg); // 回复用户 if (answer) { - if (noNewMsg()) { + if (noNewMsg() && this.status === "running") { await this.response({ ...answer, keepAlive: this.keepAlive, @@ -146,7 +146,12 @@ export class Speaker extends BaseSpeaker { } const { noNewMsg } = this.checkIfHasNewMsg(); this._preTimer = setTimeout(async () => { - if (this.keepAlive && !this.responding && noNewMsg()) { + if ( + this.keepAlive && + !this.responding && + noNewMsg() && + this.status === "running" + ) { await this.exitKeepAlive(); } }, this.exitKeepAliveAfter * 1000); diff --git a/src/utils/string.ts b/src/utils/string.ts index bc140ce..31748a4 100644 --- a/src/utils/string.ts +++ b/src/utils/string.ts @@ -1,6 +1,6 @@ -import { readJSONSync } from './io'; +import { readJSONSync } from "./io"; -export const kVersion = readJSONSync('package.json').version; +export const kVersion = readJSONSync("package.json").version; export const kBannerASCII = ` @@ -15,21 +15,21 @@ export const kBannerASCII = ` MiGPT v1.0.0 by: del-wang.eth -`.replace('1.0.0', kVersion); +`.replace("1.0.0", kVersion); /** * 转北京时间:2023年12月12日星期二 12:46 */ export function toUTC8Time(date: Date) { - return date.toLocaleString('zh-CN', { - year: 'numeric', - month: '2-digit', - weekday: 'long', - day: '2-digit', - hour: '2-digit', - minute: '2-digit', + return date.toLocaleString("zh-CN", { + year: "numeric", + month: "2-digit", + weekday: "long", + day: "2-digit", + hour: "2-digit", + minute: "2-digit", hour12: false, - timeZone: 'Asia/Shanghai', + timeZone: "Asia/Shanghai", }); } @@ -43,3 +43,12 @@ export function buildPrompt( } return template; } + +export function formatMsg(msg: { + name: string; + text: string; + timestamp: number; +}) { + const { name, text, timestamp } = msg; + return `${toUTC8Time(new Date(timestamp))} ${name}: ${text}`; +} diff --git a/tests/bot.ts b/tests/bot.ts index 3ebc9b5..b43b928 100644 --- a/tests/bot.ts +++ b/tests/bot.ts @@ -2,7 +2,65 @@ import { MyBot } from "../src/services/bot"; import { AISpeaker } from "../src/services/speaker/ai"; export async function testMyBot() { - await testStreamResponse(); + // await testStreamResponse(); + await testRunBot(); +} + +async function testRunBot() { + const name = "豆包"; + const speaker = new AISpeaker({ + name, + tts: "doubao", + userId: process.env.MI_USER!, + password: process.env.MI_PASS!, + did: process.env.MI_DID, + }); + const bot = new MyBot({ + speaker, + bot: { + name, + profile: ` +性别:女 +年龄:20岁 +学校:位于一个风景如画的小城市,一所综合性大学的文学院学生。 +性格特点: +- 温婉可亲,对待人和事总是保持着乐观和善良的态度。 +- 内向而思维敏捷,喜欢独处时阅读和思考。 +- 对待朋友非常真诚,虽然不善于表达,但总是用行动去关心和帮助别人。 +外貌特征: +- 清秀脱俗,长发及腰,喜欢简单的束发。 +- 眼睛大而有神,总是带着温和的微笑。 +- 穿着简单大方,偏爱文艺范的衣服,如棉麻连衣裙,不追求名牌,却总能穿出自己的风格。 +爱好: +- 阅读,尤其是古典文学和现代诗歌,她的书房里收藏了大量的书籍。 +- 写作,喜欢在闲暇时写写诗或是短篇小说,有时也会在学校的文学社团里分享自己的作品。 +- 摄影,喜欢用镜头记录生活中的美好瞬间,尤其是自然风光和人文景观。 +特长: +- 写作能力突出,曾多次获得学校文学比赛的奖项。 +- 擅长钢琴,从小学习,能够演奏多首经典曲目。 +- 有一定的绘画基础,喜欢在空闲时画一些风景或是静物。 +梦想: +- 希望能成为一名作家,将自己对生活的感悟和对美的追求通过文字传达给更多的人。 +- 想要环游世界,用镜头和笔记录下世界各地的美丽和人文。 +`, + }, + master: { + name: "王黎", + profile: ` +性别:男 +年龄:18 +爱好:跑步,骑行,读书,追剧,旅游,听歌 +职业:程序员 +其他: +- 喜欢的电视剧有《请回答1988》、《漫长的季节》、《爱的迫降》等 +- 喜欢吃土豆丝、茄子、山药、米线 +- 喜欢黑红配色,浅蓝色和粉色 +- 有空喜欢去公园静观人来人往 +`, + }, + }); + const res = await bot.run(); + console.log("✅ done"); } async function testStreamResponse() { diff --git a/tests/db.ts b/tests/db.ts index 09adb2b..8a19072 100644 --- a/tests/db.ts +++ b/tests/db.ts @@ -21,19 +21,31 @@ export async function testDB() { const { room, bot, master, memory } = await manager.get(); assert(room, "❌ 初始化用户失败"); let message = await manager.onMessage({ + bot: bot!, + master: master!, + room: room!, sender: master!, text: "你好!", }); assert(message?.text === "你好!", "❌ 插入消息失败"); message = await manager.onMessage({ + bot: bot!, + master: master!, + room: room!, sender: bot!, text: "你好!很高兴认识你", }); await manager.onMessage({ + bot: bot!, + master: master!, + room: room!, sender: master!, text: "你是谁?", }); await manager.onMessage({ + bot: bot!, + master: master!, + room: room!, sender: bot!, text: "我是小爱同学,你可以叫我小爱!", });