feat: auto update long/short term memory

WJG 2024-02-26 01:06:37 +08:00
parent 57a765af1b
commit 0d9bc01dc9
No known key found for this signature in database
GPG Key ID: 258474EF8590014A
18 changed files with 513 additions and 150 deletions

View File

@@ -1,4 +1,4 @@
 - ✅ Auto mute XiaoAi reply (not perfect yet)
 - ✅ Stream response
 - ✅ Deactivate Xiaoai
-- Update long/short memories
+- ✅ Update long/short memories

View File

@@ -31,11 +31,12 @@ CREATE TABLE "Message" (
 -- CreateTable
 CREATE TABLE "Memory" (
     "id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
-    "text" TEXT NOT NULL,
+    "msgId" INTEGER NOT NULL,
     "ownerId" TEXT,
     "roomId" TEXT NOT NULL,
     "createdAt" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
     "updatedAt" DATETIME NOT NULL,
+    CONSTRAINT "Memory_msgId_fkey" FOREIGN KEY ("msgId") REFERENCES "Message" ("id") ON DELETE RESTRICT ON UPDATE CASCADE,
     CONSTRAINT "Memory_ownerId_fkey" FOREIGN KEY ("ownerId") REFERENCES "User" ("id") ON DELETE SET NULL ON UPDATE CASCADE,
     CONSTRAINT "Memory_roomId_fkey" FOREIGN KEY ("roomId") REFERENCES "Room" ("id") ON DELETE RESTRICT ON UPDATE CASCADE
 );

View File

@@ -48,6 +48,7 @@ model Message {
   senderId  String
   room      Room     @relation(fields: [roomId], references: [id])
   roomId    String
+  memories  Memory[]
   // date & time
   createdAt DateTime @default(now())
   updatedAt DateTime @updatedAt
@@ -55,8 +56,9 @@ model Message {
 model Memory {
   id      Int      @id @default(autoincrement())
-  text    String
   // relations
+  msg     Message  @relation(fields: [msgId], references: [id])
+  msgId   Int
   owner   User?    @relation(fields: [ownerId], references: [id]) // when owner is null, this is the room's own shared memory
   ownerId String?
   room    Room     @relation(fields: [roomId], references: [id])
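With this change a Memory row no longer duplicates the message text; it references the Message via msgId. A minimal sketch of reading memory text through the relation (the standalone PrismaClient setup is illustrative; the project uses its shared kPrisma instance from the db helpers below):

import { PrismaClient } from "@prisma/client";

const prisma = new PrismaClient();

// fetch the latest memory of a room together with its source message and sender
async function latestMemoryText(roomId: string) {
  const memory = await prisma.memory.findFirst({
    where: { roomId },
    orderBy: { createdAt: "desc" },
    include: { msg: { include: { sender: true } } },
  });
  return memory ? `${memory.msg.sender.name}: ${memory.msg.text}` : undefined;
}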

View File

@@ -17,6 +17,7 @@ export class ConversationManager {
   }
   return {
     ...config,
+    // memories are stored on the shared room
     memory: new MemoryManager(config.room),
   };
 }
@@ -43,18 +44,25 @@ export class ConversationManager {
   return MessageCRUD.gets({ room, ...options });
 }

-  async onMessage(payload: { sender: User; text: string }) {
-    const { sender, text } = payload;
+  async onMessage(
+    payload: IBotConfig & {
+      sender: User;
+      text: string;
+      timestamp?: number;
+    }
+  ) {
+    const { sender, text, timestamp = Date.now(), ...botConfig } = payload;
     const { room, memory } = await this.get();
     if (memory) {
       const message = await MessageCRUD.addOrUpdate({
         text,
         roomId: room!.id,
         senderId: sender.id,
+        createdAt: new Date(timestamp),
       });
       if (message) {
-        // asynchronously add to memory
-        memory?.addMessage2Memory(message);
+        // asynchronously add to memory (onto the room)
+        memory?.addMessage2Memory(message, botConfig);
         return message;
       }
     }
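Callers now pass the IBotConfig fields (bot, master, room) alongside the sender, since they are forwarded into the memory pipeline. A usage sketch mirroring the updated tests at the bottom of this diff:

const { room, bot, master } = await manager.get();
await manager.onMessage({
  bot: bot!,
  master: master!,
  room: room!,
  sender: master!,
  text: "Hello!",
  // timestamp is optional and defaults to Date.now()
  timestamp: Date.now(),
});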

View File

@@ -1,43 +1,68 @@
 import { randomUUID } from "crypto";
-import { jsonDecode, jsonEncode } from "../../utils/base";
-import { buildPrompt, toUTC8Time } from "../../utils/string";
+import { buildPrompt, formatMsg } from "../../utils/string";
 import { ChatOptions, openai } from "../openai";
 import { IBotConfig } from "./config";
 import { ConversationManager } from "./conversation";
 import { StreamResponse } from "../speaker/stream";
+import { QueryMessage, SpeakerAnswer } from "../speaker/speaker";
+import { AISpeaker } from "../speaker/ai";
+import { DeepPartial } from "../../utils/type";

 // todo in JSON mode, the stream response mode can't be used (the complete JSON can't be built before the answer finishes)
 const systemTemplate = `
-You are playing a character named {{name}}.
-
-Profile of {{name}}:
-<profile>
-{{profile}}
-</profile>
-
-… {{wechatName}} … {{wechatAlias}} …
-… {{masterName}} …
-
-Chat history:
-<history>
-{{history}}
-<history>
-
-Contexts:
-<contexts>
-{{contexts}}
-</contexts>
-
-… {{name}} …
-
-Rules:
-- {{name}} …
-- Mimic {{name}}'s tone …
-- …
-- …
-- Reply as {{name}} …
-
-Reply in JSON format:
-{ message: "the reply content" }
+Please role-play a character named {{botName}} and reply in the first person.
+
+## About you
+Your name is {{botName}}. Your profile:
+<start>
+{{botProfile}}
+</end>
+
+## About {{masterName}}
+You are chatting with {{masterName}}. Their profile:
+<start>
+{{masterProfile}}
+</end>
+
+## Your room
+You and {{masterName}} are in a room named {{roomName}}:
+<start>
+{{roomIntroduction}}
+</end>
+
+## Recent messages
+<start>
+{{messages}}
+</end>
+
+## Short-term memory
+Recent details, so your replies stay close to the topic:
+<start>
+{{shortTermMemory}}
+</end>
+
+## Long-term memory
+Older facts you still remember:
+<start>
+{{longTermMemory}}
+</end>
+
+## Reply rules
+- Stay in character as {{botName}} …
+- … {{masterName}} …
+- …
+- …
+- …
+
+## Response style
+… {{masterName}} …
+… {{botName}} …
+
+## Start
+As {{botName}}, reply directly to {{masterName}}'s new message and continue the conversation.
 `.trim();
 const userTemplate = `
@@ -45,47 +70,88 @@ const userTemplate = `
 `.trim();
 export class MyBot {
-  private manager: ConversationManager;
-  constructor(config: IBotConfig) {
+  speaker: AISpeaker;
+  manager: ConversationManager;
+  constructor(config: DeepPartial<IBotConfig> & { speaker: AISpeaker }) {
+    this.speaker = config.speaker;
     this.manager = new ConversationManager(config);
   }

-  async ask(msg: string) {
+  stop() {
+    return this.speaker.stop();
+  }
+
+  run() {
+    this.speaker.askAI = (msg) => this.ask(msg);
+    return this.speaker.run();
+  }
+
+  async ask(msg: QueryMessage): Promise<SpeakerAnswer> {
     const { bot, master, room, memory } = await this.manager.get();
     if (!memory) {
-      return;
+      return {};
     }
-    const lastMessages = await this.manager.getMessages({
-      take: 10,
-    });
-    const result = await openai.chat({
-      system: buildPrompt(systemTemplate, {
-        bot_name: bot!.name,
-        bot_profile: bot!.profile,
-        master_name: master!.name,
-        master_profile: master!.profile,
-        history:
-          lastMessages.length < 1
-            ? "none yet"
-            : lastMessages
-                .map((e) =>
-                  jsonEncode({
-                    time: toUTC8Time(e.createdAt),
-                    user: e.sender.name,
-                    message: e.text,
-                  })
-                )
-                .join("\n"),
-      }),
-      user: buildPrompt(userTemplate, {
-        message: jsonEncode({
-          time: toUTC8Time(new Date()),
-          user: master!.name,
-          message: msg,
-        })!,
-      }),
-    });
-    return jsonDecode(result?.content)?.message;
+    const lastMessages = await this.manager.getMessages({ take: 10 });
+    const shortTermMemories = await memory.getShortTermMemories({ take: 1 });
+    const shortTermMemory =
+      shortTermMemories[0]?.text ?? "short-term memory is empty";
+    const longTermMemories = await memory.getLongTermMemories({ take: 1 });
+    const longTermMemory =
+      longTermMemories[0]?.text ?? "long-term memory is empty";
+    const systemPrompt = buildPrompt(systemTemplate, {
+      shortTermMemory,
+      longTermMemory,
+      botName: bot!.name,
+      botProfile: bot!.profile,
+      masterName: master!.name,
+      masterProfile: master!.profile,
+      roomName: room!.name,
+      roomIntroduction: room!.description,
+      messages:
+        lastMessages.length < 1
+          ? "no history yet"
+          : lastMessages
+              .map((e) =>
+                formatMsg({
+                  name: e.sender.name,
+                  text: e.text,
+                  timestamp: e.createdAt.getTime(),
+                })
+              )
+              .join("\n"),
+    });
+    const userPrompt = buildPrompt(userTemplate, {
+      message: formatMsg({
+        name: master!.name,
+        text: msg.text,
+        timestamp: msg.timestamp,
+      }),
+    });
+    // save the request message to the DB
+    await this.manager.onMessage({
+      bot: bot!,
+      master: master!,
+      room: room!,
+      sender: master!,
+      text: msg.text,
+      timestamp: msg.timestamp,
+    });
+    const stream = await MyBot.chatWithStreamResponse({
+      system: systemPrompt,
+      user: userPrompt,
+      onFinished: async (text) => {
+        if (text) {
+          // save the response message to the DB
+          await this.manager.onMessage({
+            bot: bot!,
+            master: master!,
+            room: room!,
+            text,
+            sender: bot!,
+            timestamp: Date.now(),
+          });
+        }
+      },
+    });
+    return { stream };
   }
 static async chatWithStreamResponse(
@@ -94,7 +160,7 @@ export class MyBot {
   }
 ) {
   const requestId = randomUUID();
-  const stream = new StreamResponse();
+  const stream = new StreamResponse({ firstSubmitTimeout: 5 * 1000 });
   openai
     .chatStream({
       ...options,
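ask() now persists both sides of the exchange and answers with a stream instead of a decoded JSON message. A hedged sketch of the overall wiring (it mirrors testRunBot at the bottom of this diff; the AISpeaker option values are illustrative):

const speaker = new AISpeaker({ name: "豆包", tts: "doubao", userId: "...", password: "..." });
const bot = new MyBot({ speaker, bot: { name: "豆包" }, master: { name: "王黎" } });
// run() points speaker.askAI at bot.ask, so each voice query becomes:
// build prompts -> save request message -> stream LLM reply -> save response on finish
await bot.run();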

View File

@@ -6,6 +6,7 @@ import { ShortTermMemoryCRUD } from "../../db/memory-short-term";
 import { LongTermMemoryCRUD } from "../../db/memory-long-term";
 import { ShortTermMemoryAgent } from "./short-term";
 import { openai } from "../../openai";
+import { IBotConfig } from "../config";

 export class MemoryManager {
   private room: Room;
@@ -20,27 +21,23 @@ export class MemoryManager {
     this.owner = owner;
   }

-  async getMemories(take?: number) {
-    return MemoryCRUD.gets({
-      room: this.room,
-      owner: this.owner,
-      take,
-    });
+  async getMemories(options?: { take?: number }) {
+    return MemoryCRUD.gets({ ...options, room: this.room, owner: this.owner });
   }

-  async getShortTermMemories(take?: number) {
+  async getShortTermMemories(options?: { take?: number }) {
     return ShortTermMemoryCRUD.gets({
+      ...options,
       room: this.room,
       owner: this.owner,
-      take,
     });
   }

-  async getLongTermMemories(take?: number) {
+  async getLongTermMemories(options?: { take?: number }) {
     return LongTermMemoryCRUD.gets({
+      ...options,
       room: this.room,
       owner: this.owner,
-      take,
     });
   }
@@ -50,20 +47,20 @@ export class MemoryManager {
   }

   private _currentMemory?: Memory;
-  async addMessage2Memory(message: Message) {
+  async addMessage2Memory(message: Message, botConfig: IBotConfig) {
     // todo create memory embedding
     const currentMemory = await MemoryCRUD.addOrUpdate({
-      text: message.text,
+      msgId: message.id,
       roomId: this.room.id,
       ownerId: message.senderId,
     });
     if (currentMemory) {
-      this._onMemory(currentMemory);
+      this._onMemory(currentMemory, botConfig);
     }
     return currentMemory;
   }

-  private _onMemory(currentMemory: Memory) {
+  private _onMemory(currentMemory: Memory, botConfig: IBotConfig) {
     if (this._currentMemory) {
       // cancel the previous memory-update tasks
       openai.abort(`update-short-memory-${this._currentMemory.id}`);
@@ -71,24 +68,28 @@ export class MemoryManager {
     }
     this._currentMemory = currentMemory;
     // asynchronously update long/short-term memories
-    this.updateLongShortTermMemory({ currentMemory });
+    this.updateLongShortTermMemory({ currentMemory, botConfig });
   }

   /**
    * Update long/short-term memories
    */
   async updateLongShortTermMemory(options: {
+    botConfig: IBotConfig;
     currentMemory: Memory;
     shortThreshold?: number;
     longThreshold?: number;
   }) {
-    const { currentMemory, shortThreshold, longThreshold } = options ?? {};
+    const { currentMemory, shortThreshold, longThreshold, botConfig } =
+      options ?? {};
     const success = await this._updateShortTermMemory({
+      botConfig,
       currentMemory,
       threshold: shortThreshold,
     });
     if (success) {
       await this._updateLongTermMemory({
+        botConfig,
         currentMemory,
         threshold: longThreshold,
       });
@@ -96,21 +97,27 @@ export class MemoryManager {
   }

   private async _updateShortTermMemory(options: {
+    botConfig: IBotConfig;
     currentMemory: Memory;
     threshold?: number;
   }) {
-    const { currentMemory, threshold = 10 } = options;
-    const lastMemory = firstOf(await this.getShortTermMemories(1));
-    const newMemories = await MemoryCRUD.gets({
+    const { currentMemory, threshold = 10, botConfig } = options;
+    const lastMemory = firstOf(await this.getShortTermMemories({ take: 1 }));
+    const newMemories: (Memory & {
+      msg: Message & {
+        sender: User;
+      };
+    })[] = (await MemoryCRUD.gets({
       cursorId: lastMemory?.cursorId,
       room: this.room,
       owner: this.owner,
       order: "asc", // oldest first
-    });
+    })) as any;
     if (newMemories.length < 1 || newMemories.length < threshold) {
       return true;
     }
     const newMemory = await ShortTermMemoryAgent.generate({
+      botConfig,
       currentMemory,
       newMemories,
       lastMemory,
@@ -128,11 +135,12 @@ export class MemoryManager {
   }

   private async _updateLongTermMemory(options: {
+    botConfig: IBotConfig;
     currentMemory: Memory;
     threshold?: number;
   }) {
-    const { currentMemory, threshold = 10 } = options;
-    const lastMemory = firstOf(await this.getLongTermMemories(1));
+    const { currentMemory, threshold = 10, botConfig } = options;
+    const lastMemory = firstOf(await this.getLongTermMemories({ take: 1 }));
     const newMemories = await ShortTermMemoryCRUD.gets({
       cursorId: lastMemory?.cursorId,
       room: this.room,
@@ -143,6 +151,7 @@ export class MemoryManager {
       return true;
     }
     const newMemory = await LongTermMemoryAgent.generate({
+      botConfig,
       currentMemory,
       newMemories,
       lastMemory,
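The cascade bottoms out in two thresholds, both defaulting to 10: raw memories roll up into a short-term memory, and short-term memories roll up into a long-term memory. A hedged sketch of what one incoming message triggers (the driver function is hypothetical; addMessage2Memory already fires the update internally):

async function remember(
  manager: MemoryManager,
  message: Message,
  botConfig: IBotConfig
) {
  // one Memory row per message, linked by msgId instead of copied text
  const current = await manager.addMessage2Memory(message, botConfig);
  if (current) {
    await manager.updateLongShortTermMemory({
      botConfig,
      currentMemory: current,
      shortThreshold: 10, // >= 10 new memories -> new ShortTermMemory
      longThreshold: 10, // >= 10 new short-term memories -> new LongTermMemory
    });
  }
}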

View File

@@ -1,18 +1,70 @@
 import { LongTermMemory, Memory, ShortTermMemory } from "@prisma/client";
 import { openai } from "../../openai";
+import { buildPrompt } from "../../../utils/string";
+import { jsonDecode, lastOf } from "../../../utils/base";
+import { IBotConfig } from "../config";
+
+const userTemplate = `
+You are {{botName}}. Please update the long-term memory about {{masterName}}.
+
+## Existing long-term memory
+<start>
+{{longTermMemory}}
+</end>
+
+## New short-term memories between {{masterName}} and {{botName}}
+<start>
+{{shortTermMemory}}
+</end>
+
+## Rules
+- …
+- …
+- …
+- Keep it within 1000 words.
+
+## Example
+<start>
+- 2022/02/11: {{masterName}} …
+- 2022/03/21: {{masterName}} … {{botName}} …
+- 2022/03/21: {{masterName}} … {{botName}} …
+- 2022/06/01: {{masterName}} … 201.8 …
+- 2022/12/01: {{masterName}} …
+- 2023/09/21: {{masterName}} …
+</end>
+
+## Response format
+Reply in JSON:
+{"longTermMemories": "the updated long-term memory"}
+
+## …
+…
+`.trim();

 export class LongTermMemoryAgent {
-  // todo use an LLM to generate the new long-term memory
   static async generate(options: {
+    botConfig: IBotConfig;
     currentMemory: Memory;
     newMemories: ShortTermMemory[];
     lastMemory?: LongTermMemory;
   }): Promise<string | undefined> {
-    const { currentMemory, newMemories, lastMemory } = options;
+    const { currentMemory, newMemories, lastMemory, botConfig } = options;
     const res = await openai.chat({
-      user: "todo", // todo prompt
+      jsonMode: true,
       requestId: `update-long-memory-${currentMemory.id}`,
+      user: buildPrompt(userTemplate, {
+        masterName: botConfig.master.name,
+        botName: botConfig.bot.name,
+        longTermMemory: lastMemory?.text ?? "no long-term memory yet",
+        shortTermMemory: lastOf(newMemories)!.text,
+      }),
     });
-    return res?.content?.trim();
+    return jsonDecode(res?.content)?.longTermMemories;
   }
 }

View File

@@ -1,18 +1,81 @@
-import { Memory, ShortTermMemory } from "@prisma/client";
+import { Memory, Message, ShortTermMemory, User } from "@prisma/client";
 import { openai } from "../../openai";
+import { buildPrompt, formatMsg } from "../../../utils/string";
+import { jsonDecode } from "../../../utils/base";
+import { IBotConfig } from "../config";
+
+const userTemplate = `
+You are {{botName}}. Please update the short-term memory of your conversation with {{masterName}}.
+
+## Existing short-term memory
+<start>
+{{shortTermMemory}}
+</end>
+
+## New messages between {{masterName}} and {{botName}}
+<start>
+{{messages}}
+</end>
+
+## Rules
+- …
+- …
+- …
+- Keep it within 1000 words.
+
+## Example
+<start>
+- 2023/12/01 08:00: {{masterName}} … {{botName}} …
+- 2023/12/01 08:10: {{masterName}} … {{botName}} …
+- 2023/12/01 09:00: {{masterName}} …
+- 2023/12/01 09:15: {{masterName}} … {{botName}} …
+- 2023/12/01 10:00: {{masterName}} … {{botName}} …
+</end>
+
+## Response format
+Reply with the updated short-term memory in JSON:
+{"shortTermMemories": "the updated short-term memory"}
+
+## …
+…
+`.trim();

 export class ShortTermMemoryAgent {
-  // todo use an LLM to generate the new short-term memory
   static async generate(options: {
+    botConfig: IBotConfig;
     currentMemory: Memory;
-    newMemories: Memory[];
+    newMemories: (Memory & {
+      msg: Message & {
+        sender: User;
+      };
+    })[];
     lastMemory?: ShortTermMemory;
   }): Promise<string | undefined> {
-    const { currentMemory, newMemories, lastMemory } = options;
+    const { currentMemory, newMemories, lastMemory, botConfig } = options;
     const res = await openai.chat({
-      user: "todo", // todo prompt
+      jsonMode: true,
       requestId: `update-short-memory-${currentMemory.id}`,
+      user: buildPrompt(userTemplate, {
+        masterName: botConfig.master.name,
+        botName: botConfig.bot.name,
+        shortTermMemory: lastMemory?.text ?? "no short-term memory yet",
+        messages: newMemories
+          .map((e) =>
+            formatMsg({
+              name: e.msg.sender.name,
+              text: e.msg.text,
+              timestamp: e.createdAt.getTime(),
+            })
+          )
+          .join("\n"),
+      }),
     });
-    return res?.content?.trim();
+    return jsonDecode(res?.content)?.shortTermMemories;
   }
 }

View File

@@ -1,4 +1,4 @@
-import { Memory, Room, User } from "@prisma/client";
+import { Memory, Prisma, Room, User } from "@prisma/client";
 import { getSkipWithCursor, k404, kPrisma } from "./index";
 import { removeEmpty } from "../../utils/base";
@@ -19,8 +19,20 @@ class _MemoryCRUD {
   });
 }

-  async get(id: number) {
-    return kPrisma.memory.findFirst({ where: { id } }).catch((e) => {
+  async get(
+    id: number,
+    options?: {
+      include?: Prisma.MemoryInclude;
+    }
+  ) {
+    const {
+      include = {
+        msg: {
+          include: { sender: true },
+        },
+      },
+    } = options ?? {};
+    return kPrisma.memory.findFirst({ where: { id }, include }).catch((e) => {
       console.error("❌ get memory failed", id, e);
       return undefined;
     });
@@ -32,6 +44,7 @@ class _MemoryCRUD {
     take?: number;
     skip?: number;
     cursorId?: number;
+    include?: Prisma.MemoryInclude;
     /**
      * Sort order
      */
@@ -43,12 +56,18 @@ class _MemoryCRUD {
       take = 10,
       skip = 0,
       cursorId,
+      include = {
+        msg: {
+          include: { sender: true },
+        },
+      },
       order = "desc",
     } = options ?? {};
     const memories = await kPrisma.memory
       .findMany({
         where: removeEmpty({ roomId: room?.id, ownerId: owner?.id }),
         take,
+        include,
         orderBy: { createdAt: order },
         ...getSkipWithCursor(skip, cursorId),
       })
@@ -61,15 +80,14 @@ class _MemoryCRUD {
   async addOrUpdate(
     memory: Partial<Memory> & {
-      text: string;
+      msgId: number;
       roomId: string;
       ownerId?: string;
     }
   ) {
-    const { text: _text, roomId, ownerId } = memory;
-    const text = _text?.trim();
+    const { msgId, roomId, ownerId } = memory;
     const data = {
-      text,
+      msg: { connect: { id: msgId } },
       room: { connect: { id: roomId } },
       owner: ownerId ? { connect: { id: ownerId } } : undefined,
    };
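Because gets() now eagerly includes the linked message and its sender by default, callers can format memories without follow-up queries. Note that the static return type does not carry the include, which is why memory.ts above casts the result through any. A minimal hedged sketch:

const memories = await MemoryCRUD.gets({ room, take: 20, order: "asc" });
for (const m of memories as any[]) {
  // msg and msg.sender are present thanks to the default include
  console.log(`${m.msg.sender.name}: ${m.msg.text}`);
}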

View File

@@ -19,8 +19,14 @@ class _MessageCRUD {
   });
 }

-  async get(id: number) {
-    return kPrisma.message.findFirst({ where: { id } }).catch((e) => {
+  async get(
+    id: number,
+    options?: {
+      include?: Prisma.MessageInclude;
+    }
+  ) {
+    const { include = { sender: true } } = options ?? {};
+    return kPrisma.message.findFirst({ where: { id }, include }).catch((e) => {
       console.error("❌ get message failed", id, e);
       return undefined;
     });

View File

@@ -27,7 +27,13 @@ class _RoomCRUD {
   });
 }

-  async get(id: string) {
-    return kPrisma.room.findFirst({ where: { id } }).catch((e) => {
+  async get(
+    id: string,
+    options?: {
+      include?: Prisma.RoomInclude;
+    }
+  ) {
+    const { include = { members: true } } = options ?? {};
+    return kPrisma.room.findFirst({ where: { id }, include }).catch((e) => {
       console.error("❌ get room failed", id, e);
       return undefined;

View File

@@ -9,8 +9,14 @@ class _UserCRUD {
   });
 }

-  async get(id: string) {
-    return kPrisma.user.findFirst({ where: { id } }).catch((e) => {
+  async get(
+    id: string,
+    options?: {
+      include?: Prisma.UserInclude;
+    }
+  ) {
+    const { include = { rooms: false } } = options ?? {};
+    return kPrisma.user.findFirst({ where: { id }, include }).catch((e) => {
       console.error("❌ get user failed", id, e);
       return undefined;
     });
@@ -30,7 +36,7 @@ class _UserCRUD {
       take = 10,
       skip = 0,
       cursorId,
-      include = { rooms: true },
+      include = { rooms: false },
       order = "desc",
     } = options ?? {};
     const users = await kPrisma.user

View File

@@ -6,10 +6,13 @@ import {
 import { kEnvs } from "../utils/env";
 import { kProxyAgent } from "./http";
+import { withDefault } from "../utils/base";
+import { ChatCompletionCreateParamsBase } from "openai/resources/chat/completions";

 export interface ChatOptions {
   user: string;
   system?: string;
+  model?: ChatCompletionCreateParamsBase["model"];
   tools?: Array<ChatCompletionTool>;
   jsonMode?: boolean;
   requestId?: string;
@@ -33,7 +36,21 @@ class OpenAIClient {
   }

   async chat(options: ChatOptions) {
-    let { user, system, tools, jsonMode, requestId } = options;
+    let {
+      user,
+      system,
+      tools,
+      jsonMode,
+      requestId,
+      model = kEnvs.OPENAI_MODEL ?? "gpt-3.5-turbo-0125",
+    } = options;
+    console.log(
+      `
+🔥🔥🔥 onAskAI start
+🤖 System: ${system ?? "None"}
+😊 User: ${user}
+`.trim()
+    );
     const systemMsg: ChatCompletionMessageParam[] = system
       ? [{ role: "system", content: system }]
       : [];
@@ -46,9 +63,9 @@ class OpenAIClient {
     const chatCompletion = await this._client.chat.completions
       .create(
         {
+          model,
           tools,
           messages: [...systemMsg, { role: "user", content: user }],
-          model: kEnvs.OPENAI_MODEL ?? "gpt-3.5-turbo-0125",
           response_format: jsonMode ? { type: "json_object" } : undefined,
         },
         { signal }
@@ -57,7 +74,14 @@ class OpenAIClient {
       console.error("❌ openai chat failed", e);
       return null;
     });
-    return chatCompletion?.choices?.[0]?.message;
+    const message = chatCompletion?.choices?.[0]?.message;
+    console.log(
+      `
+onAskAI end
+🤖 Answer: ${message?.content ?? "None"}
+`.trim()
+    );
+    return message;
   }

   async chatStream(
@@ -65,16 +89,31 @@ class OpenAIClient {
       onStream?: (text: string) => void;
     }
   ) {
-    let { user, system, tools, jsonMode, requestId, onStream } = options;
+    let {
+      user,
+      system,
+      tools,
+      jsonMode,
+      requestId,
+      onStream,
+      model = kEnvs.OPENAI_MODEL ?? "gpt-3.5-turbo-0125",
+    } = options;
+    console.log(
+      `
+🔥🔥🔥 onAskAI start
+🤖 System: ${system ?? "None"}
+😊 User: ${user}
+`.trim()
+    );
     const systemMsg: ChatCompletionMessageParam[] = system
       ? [{ role: "system", content: system }]
       : [];
     const stream = await this._client.chat.completions
       .create({
+        model,
         tools,
         stream: true,
         messages: [...systemMsg, { role: "user", content: user }],
-        model: kEnvs.OPENAI_MODEL ?? "gpt-3.5-turbo-0125",
         response_format: jsonMode ? { type: "json_object" } : undefined,
       })
       .catch((e) => {
@@ -88,23 +127,26 @@ class OpenAIClient {
       this._abortCallbacks[requestId] = () => stream.controller.abort();
     }
     let content = "";
-    try {
-      for await (const chunk of stream) {
-        const text = chunk.choices[0]?.delta?.content || "";
-        const aborted =
-          requestId && !Object.keys(this._abortCallbacks).includes(requestId);
-        if (aborted) {
-          return undefined;
-        }
-        if (text) {
-          onStream?.(text);
-          content += text;
-        }
-      }
-    } catch {
-      return undefined;
+    for await (const chunk of stream) {
+      const text = chunk.choices[0]?.delta?.content || "";
+      const aborted =
+        requestId && !Object.keys(this._abortCallbacks).includes(requestId);
+      if (aborted) {
+        content = "";
+        break;
+      }
+      if (text) {
+        onStream?.(text);
+        content += text;
+      }
     }
-    return content;
+    console.log(
+      `
+onAskAI end
+🤖 Answer: ${content ?? "None"}
+`.trim()
+    );
+    return withDefault(content, undefined);
   }
 }
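The per-call model override, jsonMode, and requestId combine into the contract the memory agents rely on: JSON-shaped output that can be aborted mid-flight. A hedged usage sketch (the model value is illustrative; jsonDecode comes from utils/base, as in the agents above):

const res = await openai.chat({
  jsonMode: true, // request a JSON object response
  requestId: "update-short-memory-42", // openai.abort("update-short-memory-42") cancels it
  model: "gpt-4-turbo", // optional; falls back to kEnvs.OPENAI_MODEL
  user: 'Summarize the new messages as {"shortTermMemories": "..."}',
});
const text = jsonDecode(res?.content)?.shortTermMemories;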

View File

@@ -203,7 +203,7 @@ export class AISpeaker extends Speaker {
   const { hasNewMsg } = this.checkIfHasNewMsg(msg);
   for (const action of this._askAIForAnswerSteps) {
     const res = await action(msg, data);
-    if (hasNewMsg()) {
+    if (hasNewMsg() || this.status !== "running") {
       // a new user request arrived: abort the remaining steps and the response
       return;
     }

View File

@@ -14,7 +14,7 @@ export interface QueryMessage {
 export interface SpeakerAnswer {
   text?: string;
   url?: string;
-  steam?: StreamResponse;
+  stream?: StreamResponse;
 }

 export interface SpeakerCommand {
@@ -53,10 +53,10 @@ export class Speaker extends BaseSpeaker {
   this.exitKeepAliveAfter = exitKeepAliveAfter;
 }

-  private _status: "running" | "stopped" = "running";
+  status: "running" | "stopped" = "running";

   stop() {
-    this._status = "stopped";
+    this.status = "stopped";
   }

   async run() {
@@ -66,7 +66,7 @@ export class Speaker extends BaseSpeaker {
   }
   console.log("✅ Service started...");
   this.activeKeepAliveMode();
-  while (this._status === "running") {
+  while (this.status === "running") {
     const nextMsg = await this.fetchNextMessage();
     if (nextMsg) {
       this.responding = false;
@@ -79,7 +79,7 @@ export class Speaker extends BaseSpeaker {
 }

 async activeKeepAliveMode() {
-  while (this._status === "running") {
+  while (this.status === "running") {
     if (this.keepAlive) {
       // awake
       if (!this.responding) {
@@ -110,7 +110,7 @@ export class Speaker extends BaseSpeaker {
   const answer = await command.run(msg);
   // reply to the user
   if (answer) {
-    if (noNewMsg()) {
+    if (noNewMsg() && this.status === "running") {
       await this.response({
         ...answer,
         keepAlive: this.keepAlive,
@@ -146,7 +146,12 @@ export class Speaker extends BaseSpeaker {
   }
   const { noNewMsg } = this.checkIfHasNewMsg();
   this._preTimer = setTimeout(async () => {
-    if (this.keepAlive && !this.responding && noNewMsg()) {
+    if (
+      this.keepAlive &&
+      !this.responding &&
+      noNewMsg() &&
+      this.status === "running"
+    ) {
       await this.exitKeepAlive();
     }
   }, this.exitKeepAliveAfter * 1000);
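Promoting the private _status field to a public status lets collaborators (MyBot.stop(), the answer steps in ai.ts above) halt the run and keep-alive loops from outside the class. A hedged sketch of external shutdown wiring (the SIGINT handler is hypothetical):

const bot = new MyBot({ speaker /* , bot, master */ });
const finished = bot.run(); // loops while speaker.status === "running"
process.on("SIGINT", () => {
  bot.stop(); // sets speaker.status = "stopped"; loops and pending replies drain out
});
await finished;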

View File

@@ -1,6 +1,6 @@
-import { readJSONSync } from './io';
+import { readJSONSync } from "./io";

-export const kVersion = readJSONSync('package.json').version;
+export const kVersion = readJSONSync("package.json").version;

 export const kBannerASCII = `
@@ -15,21 +15,21 @@ export const kBannerASCII = `
        MiGPT v1.0.0  by: del-wang.eth
-`.replace('1.0.0', kVersion);
+`.replace("1.0.0", kVersion);

 /**
  * Format a date as, e.g., 2023/12/12 12:46 (UTC+8)
  */
 export function toUTC8Time(date: Date) {
-  return date.toLocaleString('zh-CN', {
-    year: 'numeric',
-    month: '2-digit',
-    weekday: 'long',
-    day: '2-digit',
-    hour: '2-digit',
-    minute: '2-digit',
+  return date.toLocaleString("zh-CN", {
+    year: "numeric",
+    month: "2-digit",
+    weekday: "long",
+    day: "2-digit",
+    hour: "2-digit",
+    minute: "2-digit",
     hour12: false,
-    timeZone: 'Asia/Shanghai',
+    timeZone: "Asia/Shanghai",
   });
 }
@@ -43,3 +43,12 @@ export function buildPrompt(
   }
   return template;
 }
+
+export function formatMsg(msg: {
+  name: string;
+  text: string;
+  timestamp: number;
+}) {
+  const { name, text, timestamp } = msg;
+  return `${toUTC8Time(new Date(timestamp))} ${name}: ${text}`;
+}
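formatMsg is the plain-text line format now fed into every prompt, replacing the old jsonEncode rows; an example of its output (the exact rendering depends on the runtime's zh-CN locale data):

formatMsg({
  name: "王黎",
  text: "Hello!",
  timestamp: Date.UTC(2024, 1, 25, 17, 6), // = 2024/02/26 01:06 in UTC+8
});
// -> roughly "2024/02/26 星期一 01:06 王黎: Hello!"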

View File

@@ -2,7 +2,65 @@ import { MyBot } from "../src/services/bot";
 import { AISpeaker } from "../src/services/speaker/ai";

 export async function testMyBot() {
-  await testStreamResponse();
+  // await testStreamResponse();
+  await testRunBot();
 }
+
+async function testRunBot() {
+  const name = "豆包";
+  const speaker = new AISpeaker({
+    name,
+    tts: "doubao",
+    userId: process.env.MI_USER!,
+    password: process.env.MI_PASS!,
+    did: process.env.MI_DID,
+  });
+  const bot = new MyBot({
+    speaker,
+    bot: {
+      name,
+      profile: `
+A 20-year-old girl …
+- …
+- …
+- …
+`,
+    },
+    master: {
+      name: "王黎",
+      profile: `
+… 18 …
+- Born in 1988 …
+- …
+- …
+`,
+    },
+  });
+  const res = await bot.run();
+  console.log("✅ done");
+}

 async function testStreamResponse() {
async function testStreamResponse() { async function testStreamResponse() {

View File

@@ -21,19 +21,31 @@ export async function testDB() {
   const { room, bot, master, memory } = await manager.get();
   assert(room, "❌ failed to init users");
   let message = await manager.onMessage({
+    bot: bot!,
+    master: master!,
+    room: room!,
     sender: master!,
     text: "Hello!",
   });
   assert(message?.text === "Hello!", "❌ failed to insert message");
   message = await manager.onMessage({
+    bot: bot!,
+    master: master!,
+    room: room!,
     sender: bot!,
     text: "Hello! Nice to meet you",
   });
   await manager.onMessage({
+    bot: bot!,
+    master: master!,
+    room: room!,
     sender: master!,
     text: "Who are you?",
   });
   await manager.onMessage({
+    bot: bot!,
+    master: master!,
+    room: room!,
     sender: bot!,
     text: "I'm XiaoAi. You can call me XiaoAi!",
   });