feat: add MessageContext

This commit is contained in:
WJG 2024-02-26 12:03:36 +08:00
parent 6c4dd301df
commit 7ad50322fe
No known key found for this signature in database
GPG Key ID: 258474EF8590014A
9 changed files with 111 additions and 124 deletions

View File

@ -1,9 +1,9 @@
import { Room, User } from "@prisma/client";
import { readJSON, writeJSON } from "../../utils/io";
import { deepClone, removeEmpty } from "../../utils/base";
import { UserCRUD } from "../db/user";
import { RoomCRUD, getRoomID } from "../db/room";
import { readJSON, writeJSON } from "../../utils/io";
import { DeepPartial } from "../../utils/type";
import { RoomCRUD, getRoomID } from "../db/room";
import { UserCRUD } from "../db/user";
const kDefaultMaster = {
name: "用户",

View File

@ -1,8 +1,17 @@
import { Message, Prisma, User } from "@prisma/client";
import { MemoryManager } from "./memory";
import { Memory, Prisma, User } from "@prisma/client";
import { DeepPartial, MakeOptional } from "../../utils/type";
import { MessageCRUD } from "../db/message";
import { QueryMessage } from "../speaker/speaker";
import { BotConfig, IBotConfig } from "./config";
import { DeepPartial } from "../../utils/type";
import { MemoryManager } from "./memory";
export interface MessageContext extends IBotConfig {
memory?: Memory;
}
export interface MessageWithSender
extends MakeOptional<QueryMessage, "timestamp"> {
sender: User;
}
export class ConversationManager {
private config: DeepPartial<IBotConfig>;
@ -44,14 +53,8 @@ export class ConversationManager {
return MessageCRUD.gets({ room, ...options });
}
async onMessage(
payload: IBotConfig & {
sender: User;
text: string;
timestamp?: number;
}
) {
const { sender, text, timestamp = Date.now(), ...botConfig } = payload;
async onMessage(ctx: MessageContext, msg: MessageWithSender) {
const { sender, text, timestamp = Date.now() } = msg;
const { room, memory } = await this.get();
if (memory) {
const message = await MessageCRUD.addOrUpdate({
@ -62,7 +65,7 @@ export class ConversationManager {
});
if (message) {
// 异步加入记忆(到 room
memory?.addMessage2Memory(message,botConfig);
memory?.addMessage2Memory(ctx, message);
return message;
}
}

View File

@ -1,14 +1,12 @@
import { randomUUID } from "crypto";
import { buildPrompt, formatMsg } from "../../utils/string";
import { ChatOptions, openai } from "../openai";
import { IBotConfig } from "./config";
import { ConversationManager } from "./conversation";
import { StreamResponse } from "../speaker/stream";
import { QueryMessage, SpeakerAnswer } from "../speaker/speaker";
import { AISpeaker } from "../speaker/ai";
import { DeepPartial } from "../../utils/type";
// todo JSON mode 下,无法使用 stream 应答模式在应答完成之前无法构造完整的JSON
import { ChatOptions, openai } from "../openai";
import { AISpeaker } from "../speaker/ai";
import { QueryMessage, SpeakerAnswer } from "../speaker/speaker";
import { StreamResponse } from "../speaker/stream";
import { IBotConfig } from "./config";
import { ConversationManager, MessageContext } from "./conversation";
const systemTemplate = `
{{botName}}使
@ -91,6 +89,7 @@ export class MyBot {
if (!memory) {
return {};
}
const ctx = { bot, master, room } as MessageContext;
const lastMessages = await this.manager.getMessages({ take: 10 });
const shortTermMemories = await memory.getShortTermMemories({ take: 1 });
const shortTermMemory = shortTermMemories[0]?.text ?? "短期记忆为空";
@ -126,24 +125,14 @@ export class MyBot {
}),
});
// 添加请求消息到 DB
await this.manager.onMessage({
bot: bot!,
master: master!,
room: room!,
sender: master!,
text: msg.text,
timestamp: msg.timestamp,
});
await this.manager.onMessage(ctx, { ...msg, sender: master! });
const stream = await MyBot.chatWithStreamResponse({
system: systemPrompt,
user: userPrompt,
onFinished: async (text) => {
if (text) {
// 添加响应消息到 DB
await this.manager.onMessage({
bot: bot!,
master: master!,
room: room!,
await this.manager.onMessage(ctx, {
text,
sender: bot!,
timestamp: Date.now(),

View File

@ -1,12 +1,12 @@
import { Memory, Message, Room, User } from "@prisma/client";
import { firstOf, lastOf } from "../../../utils/base";
import { LongTermMemoryAgent } from "./long-term";
import { MemoryCRUD } from "../../db/memory";
import { ShortTermMemoryCRUD } from "../../db/memory-short-term";
import { LongTermMemoryCRUD } from "../../db/memory-long-term";
import { ShortTermMemoryAgent } from "./short-term";
import { ShortTermMemoryCRUD } from "../../db/memory-short-term";
import { openai } from "../../openai";
import { IBotConfig } from "../config";
import { MessageContext } from "../conversation";
import { LongTermMemoryAgent } from "./long-term";
import { ShortTermMemoryAgent } from "./short-term";
export class MemoryManager {
private room: Room;
@ -47,7 +47,7 @@ export class MemoryManager {
}
private _currentMemory?: Memory;
async addMessage2Memory(message: Message, botConfig: IBotConfig) {
async addMessage2Memory(ctx: MessageContext, message: Message) {
// todo create memory embedding
const currentMemory = await MemoryCRUD.addOrUpdate({
msgId: message.id,
@ -55,12 +55,12 @@ export class MemoryManager {
ownerId: message.senderId,
});
if (currentMemory) {
this._onMemory(currentMemory, botConfig);
this._onMemory(ctx, currentMemory);
}
return currentMemory;
}
private _onMemory(currentMemory: Memory, botConfig: IBotConfig) {
private _onMemory(ctx: MessageContext, currentMemory: Memory) {
if (this._currentMemory) {
// 取消之前的更新记忆任务
openai.abort(`update-short-memory-${this._currentMemory.id}`);
@ -68,40 +68,37 @@ export class MemoryManager {
}
this._currentMemory = currentMemory;
// 异步更新长短期记忆
this.updateLongShortTermMemory({ currentMemory, botConfig });
this.updateLongShortTermMemory(ctx);
}
/**
*
*/
async updateLongShortTermMemory(options: {
botConfig: IBotConfig;
currentMemory: Memory;
shortThreshold?: number;
longThreshold?: number;
}) {
const { currentMemory, shortThreshold, longThreshold, botConfig } =
options ?? {};
const success = await this._updateShortTermMemory({
botConfig,
currentMemory,
async updateLongShortTermMemory(
ctx: MessageContext,
options?: {
shortThreshold?: number;
longThreshold?: number;
}
) {
const { shortThreshold, longThreshold } = options ?? {};
const success = await this._updateShortTermMemory(ctx, {
threshold: shortThreshold,
});
if (success) {
await this._updateLongTermMemory({
botConfig,
currentMemory,
await this._updateLongTermMemory(ctx, {
threshold: longThreshold,
});
}
}
private async _updateShortTermMemory(options: {
botConfig: IBotConfig;
currentMemory: Memory;
threshold?: number;
}) {
const { currentMemory, threshold = 10, botConfig } = options;
private async _updateShortTermMemory(
ctx: MessageContext,
options: {
threshold?: number;
}
) {
const { threshold = 10 } = options;
const lastMemory = firstOf(await this.getShortTermMemories({ take: 1 }));
const newMemories: (Memory & {
msg: Message & {
@ -116,9 +113,7 @@ export class MemoryManager {
if (newMemories.length < 1 || newMemories.length < threshold) {
return true;
}
const newMemory = await ShortTermMemoryAgent.generate({
botConfig,
currentMemory,
const newMemory = await ShortTermMemoryAgent.generate(ctx, {
newMemories,
lastMemory,
});
@ -134,12 +129,13 @@ export class MemoryManager {
return res != null;
}
private async _updateLongTermMemory(options: {
botConfig: IBotConfig;
currentMemory: Memory;
threshold?: number;
}) {
const { currentMemory, threshold = 10, botConfig } = options;
private async _updateLongTermMemory(
ctx: MessageContext,
options: {
threshold?: number;
}
) {
const { threshold = 10 } = options;
const lastMemory = firstOf(await this.getLongTermMemories({ take: 1 }));
const newMemories = await ShortTermMemoryCRUD.gets({
cursorId: lastMemory?.cursorId,
@ -150,9 +146,7 @@ export class MemoryManager {
if (newMemories.length < 1 || newMemories.length < threshold) {
return true;
}
const newMemory = await LongTermMemoryAgent.generate({
botConfig,
currentMemory,
const newMemory = await LongTermMemoryAgent.generate(ctx, {
newMemories,
lastMemory,
});

View File

@ -1,8 +1,8 @@
import { LongTermMemory, Memory, ShortTermMemory } from "@prisma/client";
import { openai } from "../../openai";
import { buildPrompt } from "../../../utils/string";
import { LongTermMemory, ShortTermMemory } from "@prisma/client";
import { jsonDecode, lastOf } from "../../../utils/base";
import { IBotConfig } from "../config";
import { buildPrompt } from "../../../utils/string";
import { openai } from "../../openai";
import { MessageContext } from "../conversation";
const userTemplate = `
@ -48,19 +48,21 @@ const userTemplate = `
`.trim();
export class LongTermMemoryAgent {
static async generate(options: {
botConfig: IBotConfig;
currentMemory: Memory;
newMemories: ShortTermMemory[];
lastMemory?: LongTermMemory;
}): Promise<string | undefined> {
const { currentMemory, newMemories, lastMemory, botConfig } = options;
static async generate(
ctx: MessageContext,
options: {
newMemories: ShortTermMemory[];
lastMemory?: LongTermMemory;
}
): Promise<string | undefined> {
const { newMemories, lastMemory } = options;
const { bot, master, memory } = ctx;
const res = await openai.chat({
jsonMode: true,
requestId: `update-long-memory-${currentMemory.id}`,
requestId: `update-long-memory-${memory?.id}`,
user: buildPrompt(userTemplate, {
masterName: botConfig.master.name,
botName: botConfig.bot.name,
masterName: master.name,
botName: bot.name,
longTermMemory: lastMemory?.text ?? "暂无长期记忆",
shortTermMemory: lastOf(newMemories)!.text,
}),

View File

@ -1,8 +1,8 @@
import { Memory, Message, ShortTermMemory, User } from "@prisma/client";
import { openai } from "../../openai";
import { buildPrompt, formatMsg } from "../../../utils/string";
import { jsonDecode } from "../../../utils/base";
import { IBotConfig } from "../config";
import { buildPrompt, formatMsg } from "../../../utils/string";
import { openai } from "../../openai";
import { MessageContext } from "../conversation";
const userTemplate = `
@ -47,23 +47,25 @@ const userTemplate = `
`.trim();
export class ShortTermMemoryAgent {
static async generate(options: {
botConfig: IBotConfig;
currentMemory: Memory;
newMemories: (Memory & {
msg: Message & {
sender: User;
};
})[];
lastMemory?: ShortTermMemory;
}): Promise<string | undefined> {
const { currentMemory, newMemories, lastMemory, botConfig } = options;
static async generate(
ctx: MessageContext,
options: {
newMemories: (Memory & {
msg: Message & {
sender: User;
};
})[];
lastMemory?: ShortTermMemory;
}
): Promise<string | undefined> {
const { newMemories, lastMemory } = options;
const { bot, master, memory } = ctx;
const res = await openai.chat({
jsonMode: true,
requestId: `update-short-memory-${currentMemory.id}`,
requestId: `update-short-memory-${memory?.id}`,
user: buildPrompt(userTemplate, {
masterName: botConfig.master.name,
botName: botConfig.bot.name,
masterName: master.name,
botName: bot.name,
shortTermMemory: lastMemory?.text ?? "暂无短期记忆",
messages: newMemories
.map((e) =>

View File

@ -4,7 +4,7 @@ import { StreamResponse } from "./stream";
export interface QueryMessage {
text: string;
answer: string;
answer?: string;
/**
*
*/

View File

@ -1,3 +1,8 @@
export type DeepPartial<T> = {
[P in keyof T]?: T[P] extends object ? DeepPartial<T[P]> : T[P];
};
export type MakeOptional<T, K extends keyof T> = Omit<T, K> &
Partial<Pick<T, K>>;
export type MakeRequired<T, K extends keyof T> = T & Required<Pick<T, K>>;

View File

@ -1,5 +1,8 @@
import { assert } from "console";
import { ConversationManager } from "../src/services/bot/conversation";
import {
ConversationManager,
MessageContext,
} from "../src/services/bot/conversation";
import { println } from "../src/utils/base";
import { MessageCRUD } from "../src/services/db/message";
@ -20,32 +23,21 @@ export async function testDB() {
});
const { room, bot, master, memory } = await manager.get();
assert(room, "❌ 初始化用户失败");
let message = await manager.onMessage({
bot: bot!,
master: master!,
room: room!,
const ctx = { bot, master, room } as MessageContext;
let message = await manager.onMessage(ctx, {
sender: master!,
text: "你好!",
});
assert(message?.text === "你好!", "❌ 插入消息失败");
message = await manager.onMessage({
bot: bot!,
master: master!,
room: room!,
message = await manager.onMessage(ctx, {
sender: bot!,
text: "你好!很高兴认识你",
});
await manager.onMessage({
bot: bot!,
master: master!,
room: room!,
await manager.onMessage(ctx, {
sender: master!,
text: "你是谁?",
});
await manager.onMessage({
bot: bot!,
master: master!,
room: room!,
await manager.onMessage(ctx, {
sender: bot!,
text: "我是小爱同学,你可以叫我小爱!",
});