feat: add MessageContext

This commit is contained in:
WJG 2024-02-26 12:03:36 +08:00
parent 6c4dd301df
commit 7ad50322fe
No known key found for this signature in database
GPG Key ID: 258474EF8590014A
9 changed files with 111 additions and 124 deletions

View File

@ -1,9 +1,9 @@
import { Room, User } from "@prisma/client"; import { Room, User } from "@prisma/client";
import { readJSON, writeJSON } from "../../utils/io";
import { deepClone, removeEmpty } from "../../utils/base"; import { deepClone, removeEmpty } from "../../utils/base";
import { UserCRUD } from "../db/user"; import { readJSON, writeJSON } from "../../utils/io";
import { RoomCRUD, getRoomID } from "../db/room";
import { DeepPartial } from "../../utils/type"; import { DeepPartial } from "../../utils/type";
import { RoomCRUD, getRoomID } from "../db/room";
import { UserCRUD } from "../db/user";
const kDefaultMaster = { const kDefaultMaster = {
name: "用户", name: "用户",

View File

@ -1,8 +1,17 @@
import { Message, Prisma, User } from "@prisma/client"; import { Memory, Prisma, User } from "@prisma/client";
import { MemoryManager } from "./memory"; import { DeepPartial, MakeOptional } from "../../utils/type";
import { MessageCRUD } from "../db/message"; import { MessageCRUD } from "../db/message";
import { QueryMessage } from "../speaker/speaker";
import { BotConfig, IBotConfig } from "./config"; import { BotConfig, IBotConfig } from "./config";
import { DeepPartial } from "../../utils/type"; import { MemoryManager } from "./memory";
export interface MessageContext extends IBotConfig {
memory?: Memory;
}
export interface MessageWithSender
extends MakeOptional<QueryMessage, "timestamp"> {
sender: User;
}
export class ConversationManager { export class ConversationManager {
private config: DeepPartial<IBotConfig>; private config: DeepPartial<IBotConfig>;
@ -44,14 +53,8 @@ export class ConversationManager {
return MessageCRUD.gets({ room, ...options }); return MessageCRUD.gets({ room, ...options });
} }
async onMessage( async onMessage(ctx: MessageContext, msg: MessageWithSender) {
payload: IBotConfig & { const { sender, text, timestamp = Date.now() } = msg;
sender: User;
text: string;
timestamp?: number;
}
) {
const { sender, text, timestamp = Date.now(), ...botConfig } = payload;
const { room, memory } = await this.get(); const { room, memory } = await this.get();
if (memory) { if (memory) {
const message = await MessageCRUD.addOrUpdate({ const message = await MessageCRUD.addOrUpdate({
@ -62,7 +65,7 @@ export class ConversationManager {
}); });
if (message) { if (message) {
// 异步加入记忆(到 room // 异步加入记忆(到 room
memory?.addMessage2Memory(message,botConfig); memory?.addMessage2Memory(ctx, message);
return message; return message;
} }
} }

View File

@ -1,14 +1,12 @@
import { randomUUID } from "crypto"; import { randomUUID } from "crypto";
import { buildPrompt, formatMsg } from "../../utils/string"; import { buildPrompt, formatMsg } from "../../utils/string";
import { ChatOptions, openai } from "../openai";
import { IBotConfig } from "./config";
import { ConversationManager } from "./conversation";
import { StreamResponse } from "../speaker/stream";
import { QueryMessage, SpeakerAnswer } from "../speaker/speaker";
import { AISpeaker } from "../speaker/ai";
import { DeepPartial } from "../../utils/type"; import { DeepPartial } from "../../utils/type";
import { ChatOptions, openai } from "../openai";
// todo JSON mode 下,无法使用 stream 应答模式在应答完成之前无法构造完整的JSON import { AISpeaker } from "../speaker/ai";
import { QueryMessage, SpeakerAnswer } from "../speaker/speaker";
import { StreamResponse } from "../speaker/stream";
import { IBotConfig } from "./config";
import { ConversationManager, MessageContext } from "./conversation";
const systemTemplate = ` const systemTemplate = `
{{botName}}使 {{botName}}使
@ -91,6 +89,7 @@ export class MyBot {
if (!memory) { if (!memory) {
return {}; return {};
} }
const ctx = { bot, master, room } as MessageContext;
const lastMessages = await this.manager.getMessages({ take: 10 }); const lastMessages = await this.manager.getMessages({ take: 10 });
const shortTermMemories = await memory.getShortTermMemories({ take: 1 }); const shortTermMemories = await memory.getShortTermMemories({ take: 1 });
const shortTermMemory = shortTermMemories[0]?.text ?? "短期记忆为空"; const shortTermMemory = shortTermMemories[0]?.text ?? "短期记忆为空";
@ -126,24 +125,14 @@ export class MyBot {
}), }),
}); });
// 添加请求消息到 DB // 添加请求消息到 DB
await this.manager.onMessage({ await this.manager.onMessage(ctx, { ...msg, sender: master! });
bot: bot!,
master: master!,
room: room!,
sender: master!,
text: msg.text,
timestamp: msg.timestamp,
});
const stream = await MyBot.chatWithStreamResponse({ const stream = await MyBot.chatWithStreamResponse({
system: systemPrompt, system: systemPrompt,
user: userPrompt, user: userPrompt,
onFinished: async (text) => { onFinished: async (text) => {
if (text) { if (text) {
// 添加响应消息到 DB // 添加响应消息到 DB
await this.manager.onMessage({ await this.manager.onMessage(ctx, {
bot: bot!,
master: master!,
room: room!,
text, text,
sender: bot!, sender: bot!,
timestamp: Date.now(), timestamp: Date.now(),

View File

@ -1,12 +1,12 @@
import { Memory, Message, Room, User } from "@prisma/client"; import { Memory, Message, Room, User } from "@prisma/client";
import { firstOf, lastOf } from "../../../utils/base"; import { firstOf, lastOf } from "../../../utils/base";
import { LongTermMemoryAgent } from "./long-term";
import { MemoryCRUD } from "../../db/memory"; import { MemoryCRUD } from "../../db/memory";
import { ShortTermMemoryCRUD } from "../../db/memory-short-term";
import { LongTermMemoryCRUD } from "../../db/memory-long-term"; import { LongTermMemoryCRUD } from "../../db/memory-long-term";
import { ShortTermMemoryAgent } from "./short-term"; import { ShortTermMemoryCRUD } from "../../db/memory-short-term";
import { openai } from "../../openai"; import { openai } from "../../openai";
import { IBotConfig } from "../config"; import { MessageContext } from "../conversation";
import { LongTermMemoryAgent } from "./long-term";
import { ShortTermMemoryAgent } from "./short-term";
export class MemoryManager { export class MemoryManager {
private room: Room; private room: Room;
@ -47,7 +47,7 @@ export class MemoryManager {
} }
private _currentMemory?: Memory; private _currentMemory?: Memory;
async addMessage2Memory(message: Message, botConfig: IBotConfig) { async addMessage2Memory(ctx: MessageContext, message: Message) {
// todo create memory embedding // todo create memory embedding
const currentMemory = await MemoryCRUD.addOrUpdate({ const currentMemory = await MemoryCRUD.addOrUpdate({
msgId: message.id, msgId: message.id,
@ -55,12 +55,12 @@ export class MemoryManager {
ownerId: message.senderId, ownerId: message.senderId,
}); });
if (currentMemory) { if (currentMemory) {
this._onMemory(currentMemory, botConfig); this._onMemory(ctx, currentMemory);
} }
return currentMemory; return currentMemory;
} }
private _onMemory(currentMemory: Memory, botConfig: IBotConfig) { private _onMemory(ctx: MessageContext, currentMemory: Memory) {
if (this._currentMemory) { if (this._currentMemory) {
// 取消之前的更新记忆任务 // 取消之前的更新记忆任务
openai.abort(`update-short-memory-${this._currentMemory.id}`); openai.abort(`update-short-memory-${this._currentMemory.id}`);
@ -68,40 +68,37 @@ export class MemoryManager {
} }
this._currentMemory = currentMemory; this._currentMemory = currentMemory;
// 异步更新长短期记忆 // 异步更新长短期记忆
this.updateLongShortTermMemory({ currentMemory, botConfig }); this.updateLongShortTermMemory(ctx);
} }
/** /**
* *
*/ */
async updateLongShortTermMemory(options: { async updateLongShortTermMemory(
botConfig: IBotConfig; ctx: MessageContext,
currentMemory: Memory; options?: {
shortThreshold?: number; shortThreshold?: number;
longThreshold?: number; longThreshold?: number;
}) { }
const { currentMemory, shortThreshold, longThreshold, botConfig } = ) {
options ?? {}; const { shortThreshold, longThreshold } = options ?? {};
const success = await this._updateShortTermMemory({ const success = await this._updateShortTermMemory(ctx, {
botConfig,
currentMemory,
threshold: shortThreshold, threshold: shortThreshold,
}); });
if (success) { if (success) {
await this._updateLongTermMemory({ await this._updateLongTermMemory(ctx, {
botConfig,
currentMemory,
threshold: longThreshold, threshold: longThreshold,
}); });
} }
} }
private async _updateShortTermMemory(options: { private async _updateShortTermMemory(
botConfig: IBotConfig; ctx: MessageContext,
currentMemory: Memory; options: {
threshold?: number; threshold?: number;
}) { }
const { currentMemory, threshold = 10, botConfig } = options; ) {
const { threshold = 10 } = options;
const lastMemory = firstOf(await this.getShortTermMemories({ take: 1 })); const lastMemory = firstOf(await this.getShortTermMemories({ take: 1 }));
const newMemories: (Memory & { const newMemories: (Memory & {
msg: Message & { msg: Message & {
@ -116,9 +113,7 @@ export class MemoryManager {
if (newMemories.length < 1 || newMemories.length < threshold) { if (newMemories.length < 1 || newMemories.length < threshold) {
return true; return true;
} }
const newMemory = await ShortTermMemoryAgent.generate({ const newMemory = await ShortTermMemoryAgent.generate(ctx, {
botConfig,
currentMemory,
newMemories, newMemories,
lastMemory, lastMemory,
}); });
@ -134,12 +129,13 @@ export class MemoryManager {
return res != null; return res != null;
} }
private async _updateLongTermMemory(options: { private async _updateLongTermMemory(
botConfig: IBotConfig; ctx: MessageContext,
currentMemory: Memory; options: {
threshold?: number; threshold?: number;
}) { }
const { currentMemory, threshold = 10, botConfig } = options; ) {
const { threshold = 10 } = options;
const lastMemory = firstOf(await this.getLongTermMemories({ take: 1 })); const lastMemory = firstOf(await this.getLongTermMemories({ take: 1 }));
const newMemories = await ShortTermMemoryCRUD.gets({ const newMemories = await ShortTermMemoryCRUD.gets({
cursorId: lastMemory?.cursorId, cursorId: lastMemory?.cursorId,
@ -150,9 +146,7 @@ export class MemoryManager {
if (newMemories.length < 1 || newMemories.length < threshold) { if (newMemories.length < 1 || newMemories.length < threshold) {
return true; return true;
} }
const newMemory = await LongTermMemoryAgent.generate({ const newMemory = await LongTermMemoryAgent.generate(ctx, {
botConfig,
currentMemory,
newMemories, newMemories,
lastMemory, lastMemory,
}); });

View File

@ -1,8 +1,8 @@
import { LongTermMemory, Memory, ShortTermMemory } from "@prisma/client"; import { LongTermMemory, ShortTermMemory } from "@prisma/client";
import { openai } from "../../openai";
import { buildPrompt } from "../../../utils/string";
import { jsonDecode, lastOf } from "../../../utils/base"; import { jsonDecode, lastOf } from "../../../utils/base";
import { IBotConfig } from "../config"; import { buildPrompt } from "../../../utils/string";
import { openai } from "../../openai";
import { MessageContext } from "../conversation";
const userTemplate = ` const userTemplate = `
@ -48,19 +48,21 @@ const userTemplate = `
`.trim(); `.trim();
export class LongTermMemoryAgent { export class LongTermMemoryAgent {
static async generate(options: { static async generate(
botConfig: IBotConfig; ctx: MessageContext,
currentMemory: Memory; options: {
newMemories: ShortTermMemory[]; newMemories: ShortTermMemory[];
lastMemory?: LongTermMemory; lastMemory?: LongTermMemory;
}): Promise<string | undefined> { }
const { currentMemory, newMemories, lastMemory, botConfig } = options; ): Promise<string | undefined> {
const { newMemories, lastMemory } = options;
const { bot, master, memory } = ctx;
const res = await openai.chat({ const res = await openai.chat({
jsonMode: true, jsonMode: true,
requestId: `update-long-memory-${currentMemory.id}`, requestId: `update-long-memory-${memory?.id}`,
user: buildPrompt(userTemplate, { user: buildPrompt(userTemplate, {
masterName: botConfig.master.name, masterName: master.name,
botName: botConfig.bot.name, botName: bot.name,
longTermMemory: lastMemory?.text ?? "暂无长期记忆", longTermMemory: lastMemory?.text ?? "暂无长期记忆",
shortTermMemory: lastOf(newMemories)!.text, shortTermMemory: lastOf(newMemories)!.text,
}), }),

View File

@ -1,8 +1,8 @@
import { Memory, Message, ShortTermMemory, User } from "@prisma/client"; import { Memory, Message, ShortTermMemory, User } from "@prisma/client";
import { openai } from "../../openai";
import { buildPrompt, formatMsg } from "../../../utils/string";
import { jsonDecode } from "../../../utils/base"; import { jsonDecode } from "../../../utils/base";
import { IBotConfig } from "../config"; import { buildPrompt, formatMsg } from "../../../utils/string";
import { openai } from "../../openai";
import { MessageContext } from "../conversation";
const userTemplate = ` const userTemplate = `
@ -47,23 +47,25 @@ const userTemplate = `
`.trim(); `.trim();
export class ShortTermMemoryAgent { export class ShortTermMemoryAgent {
static async generate(options: { static async generate(
botConfig: IBotConfig; ctx: MessageContext,
currentMemory: Memory; options: {
newMemories: (Memory & { newMemories: (Memory & {
msg: Message & { msg: Message & {
sender: User; sender: User;
}; };
})[]; })[];
lastMemory?: ShortTermMemory; lastMemory?: ShortTermMemory;
}): Promise<string | undefined> { }
const { currentMemory, newMemories, lastMemory, botConfig } = options; ): Promise<string | undefined> {
const { newMemories, lastMemory } = options;
const { bot, master, memory } = ctx;
const res = await openai.chat({ const res = await openai.chat({
jsonMode: true, jsonMode: true,
requestId: `update-short-memory-${currentMemory.id}`, requestId: `update-short-memory-${memory?.id}`,
user: buildPrompt(userTemplate, { user: buildPrompt(userTemplate, {
masterName: botConfig.master.name, masterName: master.name,
botName: botConfig.bot.name, botName: bot.name,
shortTermMemory: lastMemory?.text ?? "暂无短期记忆", shortTermMemory: lastMemory?.text ?? "暂无短期记忆",
messages: newMemories messages: newMemories
.map((e) => .map((e) =>

View File

@ -4,7 +4,7 @@ import { StreamResponse } from "./stream";
export interface QueryMessage { export interface QueryMessage {
text: string; text: string;
answer: string; answer?: string;
/** /**
* *
*/ */

View File

@ -1,3 +1,8 @@
export type DeepPartial<T> = { export type DeepPartial<T> = {
[P in keyof T]?: T[P] extends object ? DeepPartial<T[P]> : T[P]; [P in keyof T]?: T[P] extends object ? DeepPartial<T[P]> : T[P];
}; };
export type MakeOptional<T, K extends keyof T> = Omit<T, K> &
Partial<Pick<T, K>>;
export type MakeRequired<T, K extends keyof T> = T & Required<Pick<T, K>>;

View File

@ -1,5 +1,8 @@
import { assert } from "console"; import { assert } from "console";
import { ConversationManager } from "../src/services/bot/conversation"; import {
ConversationManager,
MessageContext,
} from "../src/services/bot/conversation";
import { println } from "../src/utils/base"; import { println } from "../src/utils/base";
import { MessageCRUD } from "../src/services/db/message"; import { MessageCRUD } from "../src/services/db/message";
@ -20,32 +23,21 @@ export async function testDB() {
}); });
const { room, bot, master, memory } = await manager.get(); const { room, bot, master, memory } = await manager.get();
assert(room, "❌ 初始化用户失败"); assert(room, "❌ 初始化用户失败");
let message = await manager.onMessage({ const ctx = { bot, master, room } as MessageContext;
bot: bot!, let message = await manager.onMessage(ctx, {
master: master!,
room: room!,
sender: master!, sender: master!,
text: "你好!", text: "你好!",
}); });
assert(message?.text === "你好!", "❌ 插入消息失败"); assert(message?.text === "你好!", "❌ 插入消息失败");
message = await manager.onMessage({ message = await manager.onMessage(ctx, {
bot: bot!,
master: master!,
room: room!,
sender: bot!, sender: bot!,
text: "你好!很高兴认识你", text: "你好!很高兴认识你",
}); });
await manager.onMessage({ await manager.onMessage(ctx, {
bot: bot!,
master: master!,
room: room!,
sender: master!, sender: master!,
text: "你是谁?", text: "你是谁?",
}); });
await manager.onMessage({ await manager.onMessage(ctx, {
bot: bot!,
master: master!,
room: room!,
sender: bot!, sender: bot!,
text: "我是小爱同学,你可以叫我小爱!", text: "我是小爱同学,你可以叫我小爱!",
}); });