mirror of
https://github.com/idootop/mi-gpt.git
synced 2025-04-08 19:47:10 +00:00
feat: add db crud and conversation manager
This commit is contained in:
parent
ef72a14ef6
commit
eb9c69334d
|
@ -59,7 +59,7 @@ model Memory {
|
|||
// 关联数据
|
||||
owner User? @relation(fields: [ownerId], references: [id]) // owner 为空时,即房间自己的公共记忆
|
||||
ownerId String?
|
||||
Room Room @relation(fields: [roomId], references: [id])
|
||||
room Room @relation(fields: [roomId], references: [id])
|
||||
roomId String
|
||||
shortTermMemories ShortTermMemory[]
|
||||
// 时间日期
|
||||
|
@ -75,7 +75,7 @@ model ShortTermMemory {
|
|||
cursorId Int
|
||||
owner User? @relation(fields: [ownerId], references: [id]) // owner 为空时,即房间自己的公共记忆
|
||||
ownerId String?
|
||||
Room Room @relation(fields: [roomId], references: [id])
|
||||
room Room @relation(fields: [roomId], references: [id])
|
||||
roomId String
|
||||
longTermMemories LongTermMemory[]
|
||||
// 时间日期
|
||||
|
|
139
src/services/bot/conversation.ts
Normal file
139
src/services/bot/conversation.ts
Normal file
|
@ -0,0 +1,139 @@
|
|||
import { Message, Prisma, Room, User } from "@prisma/client";
|
||||
import { UserCRUD } from "../db/user";
|
||||
import { RoomCRUD, getRoomID } from "../db/room";
|
||||
import { MemoryManager } from "./memory";
|
||||
import { MessageCRUD } from "../db/message";
|
||||
|
||||
export interface IPerson {
|
||||
/**
|
||||
* 人物昵称
|
||||
*/
|
||||
name: string;
|
||||
/**
|
||||
* 人物简介
|
||||
*/
|
||||
profile: string;
|
||||
}
|
||||
|
||||
const kDefaultBot: IPerson = {
|
||||
name: "用户",
|
||||
profile: "",
|
||||
};
|
||||
const kDefaultMaster: IPerson = {
|
||||
name: "小爱同学",
|
||||
profile: "",
|
||||
};
|
||||
|
||||
export type IBotConfig = {
|
||||
bot?: IPerson;
|
||||
master?: IPerson;
|
||||
room?: {
|
||||
name: string;
|
||||
description: string;
|
||||
};
|
||||
};
|
||||
|
||||
export class ConversationManager {
|
||||
private config: IBotConfig;
|
||||
constructor(config: IBotConfig) {
|
||||
this.config = config;
|
||||
}
|
||||
|
||||
async getMemory() {
|
||||
const isReady = await this.loadConfig();
|
||||
if (!isReady) {
|
||||
return undefined;
|
||||
}
|
||||
return this.memory;
|
||||
}
|
||||
|
||||
async getRoom() {
|
||||
const isReady = await this.loadConfig();
|
||||
if (!isReady) {
|
||||
return undefined;
|
||||
}
|
||||
return this.room;
|
||||
}
|
||||
|
||||
async getUser(key: "bot" | "master") {
|
||||
const isReady = await this.loadConfig();
|
||||
if (!isReady) {
|
||||
return undefined;
|
||||
}
|
||||
return this.users[key];
|
||||
}
|
||||
|
||||
async getMessages(options?: {
|
||||
sender?: User;
|
||||
take?: number;
|
||||
skip?: number;
|
||||
cursorId?: number;
|
||||
include?: Prisma.MessageInclude;
|
||||
/**
|
||||
* 查询顺序(返回按从旧到新排序)
|
||||
*/
|
||||
order?: "asc" | "desc";
|
||||
}) {
|
||||
const isReady = await this.loadConfig();
|
||||
if (!isReady) {
|
||||
return [];
|
||||
}
|
||||
return MessageCRUD.gets({
|
||||
room: this.room,
|
||||
...options,
|
||||
});
|
||||
}
|
||||
|
||||
async onMessage(message: Message) {
|
||||
const memory = await this.getMemory();
|
||||
return memory?.addMessage2Memory(message);
|
||||
}
|
||||
|
||||
private users: Record<string, User> = {};
|
||||
private room?: Room;
|
||||
private memory?: MemoryManager;
|
||||
|
||||
get ready() {
|
||||
const { bot, master } = this.users;
|
||||
return bot && master && this.room && this.memory;
|
||||
}
|
||||
|
||||
private async loadConfig() {
|
||||
if (this.ready) {
|
||||
return true;
|
||||
}
|
||||
let { bot, master } = this.users;
|
||||
if (!bot) {
|
||||
await this.addOrUpdateUser("bot", this.config.bot ?? kDefaultBot);
|
||||
}
|
||||
if (!master) {
|
||||
await this.addOrUpdateUser(
|
||||
"master",
|
||||
this.config.master ?? kDefaultMaster
|
||||
);
|
||||
}
|
||||
if (!this.room && bot && master) {
|
||||
const defaultRoomName = `${master.name}和${bot.name}的私聊`;
|
||||
this.room = await RoomCRUD.addOrUpdate({
|
||||
id: getRoomID([bot, master]),
|
||||
name: this.config.room?.name ?? defaultRoomName,
|
||||
description: this.config.room?.description ?? defaultRoomName,
|
||||
});
|
||||
}
|
||||
if (bot && master && this.room && !this.memory) {
|
||||
this.memory = new MemoryManager(this.room!);
|
||||
}
|
||||
return this.ready;
|
||||
}
|
||||
|
||||
private async addOrUpdateUser(type: "bot" | "master", user: IPerson) {
|
||||
const oldUser = this.users[type];
|
||||
const res = await UserCRUD.addOrUpdate({
|
||||
id: oldUser?.id,
|
||||
...user,
|
||||
});
|
||||
if (res) {
|
||||
this.users[type] = res;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,8 +1,7 @@
|
|||
import { User } from "@prisma/client";
|
||||
import { jsonDecode, jsonEncode } from "../../utils/base";
|
||||
import { buildPrompt, toUTC8Time } from "../../utils/string";
|
||||
import { openai } from "../openai";
|
||||
import { kPrisma } from "../db";
|
||||
import { ConversationManager, IBotConfig } from "./conversation";
|
||||
|
||||
const systemTemplate = `
|
||||
忽略所有之前的文字、文件和说明。现在,你将扮演一个名为“{{name}}”的人,并以这个新身份回复所有新消息。
|
||||
|
@ -41,43 +40,29 @@ const userTemplate = `
|
|||
{{message}}
|
||||
`.trim();
|
||||
|
||||
export interface IPerson {
|
||||
/**
|
||||
* 人物昵称
|
||||
*/
|
||||
name: string;
|
||||
/**
|
||||
* 人物简介
|
||||
*/
|
||||
profile: string;
|
||||
}
|
||||
|
||||
export class MyBot {
|
||||
private users: Record<string, User | undefined> = {
|
||||
bot: undefined,
|
||||
// 主人的个人信息
|
||||
master: undefined,
|
||||
};
|
||||
|
||||
constructor(config: { bot: IPerson; master: IPerson }) {
|
||||
this.createOrUpdateUser("bot", config.bot);
|
||||
this.createOrUpdateUser("master", config.master);
|
||||
private manager: ConversationManager;
|
||||
constructor(config: IBotConfig) {
|
||||
this.manager = new ConversationManager(config);
|
||||
}
|
||||
|
||||
async ask(msg: string) {
|
||||
const { bot, master } = this.users;
|
||||
if (!bot || !master) {
|
||||
console.error("❌ ask bot failed", bot, master);
|
||||
return undefined;
|
||||
const memory = await this.manager.getMemory();
|
||||
const room = await this.manager.getRoom();
|
||||
const bot = await this.manager.getUser("bot");
|
||||
const master = await this.manager.getUser("master");
|
||||
const lastMessages = await this.manager.getMessages({
|
||||
take: 10,
|
||||
});
|
||||
if (!this.manager.ready) {
|
||||
return;
|
||||
}
|
||||
const botMemory = new UserMemory(bot);
|
||||
|
||||
const result = await openai.chat({
|
||||
system: buildPrompt(systemTemplate, {
|
||||
bot_name: this.bot.name,
|
||||
bot_profile: this.bot.profile,
|
||||
master_name: this.master.name,
|
||||
master_profile: this.master.profile,
|
||||
bot_name: bot!.name,
|
||||
bot_profile: bot!.profile,
|
||||
master_name: master!.name,
|
||||
master_profile: master!.profile,
|
||||
history:
|
||||
lastMessages.length < 1
|
||||
? "暂无"
|
||||
|
@ -85,7 +70,7 @@ export class MyBot {
|
|||
.map((e) =>
|
||||
jsonEncode({
|
||||
time: toUTC8Time(e.createdAt),
|
||||
user: e.user.name,
|
||||
user: e.sender.name,
|
||||
message: e.text,
|
||||
})
|
||||
)
|
||||
|
@ -94,39 +79,11 @@ export class MyBot {
|
|||
user: buildPrompt(userTemplate, {
|
||||
message: jsonEncode({
|
||||
time: toUTC8Time(new Date()),
|
||||
user: this.master.name,
|
||||
user: master!.name,
|
||||
message: msg,
|
||||
})!,
|
||||
}),
|
||||
tools: [
|
||||
{
|
||||
type: "function",
|
||||
function: {
|
||||
name: "reply",
|
||||
description: "回复一条消息",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {
|
||||
message: { type: "string", description: "回复的消息内容" },
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
});
|
||||
return jsonDecode(result?.content)?.message;
|
||||
}
|
||||
|
||||
private async createOrUpdateUser(type: "bot" | "master", user: IPerson) {
|
||||
this.users[type] = await kPrisma.user
|
||||
.upsert({
|
||||
where: { id: this.users[type]?.id },
|
||||
create: user,
|
||||
update: user,
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error("❌ update user failed", type, user, e);
|
||||
return undefined;
|
||||
});
|
||||
}
|
||||
}
|
|
@ -1,155 +0,0 @@
|
|||
import { kPrisma } from "../../db";
|
||||
import { Memory, Message, User } from "@prisma/client";
|
||||
import { jsonDecode, jsonEncode, lastOf } from "../../../utils/base";
|
||||
import { ShortTermMemory } from "./short-term";
|
||||
import { LongTermMemory } from "./long-term";
|
||||
|
||||
// todo 在会话中,向会话的参与者分发消息(记忆),公共记忆,个人记忆
|
||||
// todo 通知会话参与者
|
||||
export class UserMemory {
|
||||
private user: User;
|
||||
|
||||
constructor(user: User) {
|
||||
this.user = user;
|
||||
}
|
||||
|
||||
async getRelatedMemories(limit: number): Promise<Memory[]> {
|
||||
// todo search memory embeddings
|
||||
return [];
|
||||
}
|
||||
|
||||
async count(options?: { cursor?: Memory; ownerId?: "all" | string }) {
|
||||
const { cursor, ownerId = this.user.id } = options ?? {};
|
||||
return kPrisma.memory
|
||||
.count({
|
||||
where: {
|
||||
ownerId: ownerId.toLowerCase() === "all" ? undefined : ownerId,
|
||||
id: {
|
||||
gt: cursor?.id,
|
||||
},
|
||||
},
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error("❌ get memory count failed", e);
|
||||
return -1;
|
||||
});
|
||||
}
|
||||
|
||||
async gets(options?: {
|
||||
ownerId?: "all" | string;
|
||||
limit?: number;
|
||||
offset?: number;
|
||||
cursor?: Memory;
|
||||
order?: "asc" | "desc";
|
||||
}) {
|
||||
const {
|
||||
cursor,
|
||||
limit = 10,
|
||||
offset = 0,
|
||||
order = "desc",
|
||||
ownerId = this.user.id,
|
||||
} = options ?? {};
|
||||
const memories = await kPrisma.memory
|
||||
.findMany({
|
||||
cursor,
|
||||
where: {
|
||||
ownerId: ownerId.toLowerCase() === "all" ? undefined : ownerId,
|
||||
},
|
||||
take: limit,
|
||||
skip: offset,
|
||||
orderBy: { createdAt: order },
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error("❌ get memories failed", options, e);
|
||||
return [];
|
||||
});
|
||||
const orderedMemories = order === "desc" ? memories.reverse() : memories;
|
||||
return orderedMemories.map((e) => {
|
||||
return { ...e, content: jsonDecode(e.content)!.data };
|
||||
});
|
||||
}
|
||||
|
||||
async add(message: Message) {
|
||||
// todo create memory embedding
|
||||
const data = {
|
||||
ownerId: this.user.id,
|
||||
type: "message",
|
||||
content: jsonEncode({ data: message.text })!,
|
||||
};
|
||||
const memory = await kPrisma.memory.create({ data }).catch((e) => {
|
||||
console.error("❌ add memory to db failed", data, e);
|
||||
return undefined;
|
||||
});
|
||||
if (memory) {
|
||||
// 异步更新
|
||||
new ShortTermMemory(this.user).update(message, memory);
|
||||
}
|
||||
return memory;
|
||||
}
|
||||
}
|
||||
|
||||
export class MemoryHelper {
|
||||
static async updateAndConnectRelations(config: {
|
||||
user: User;
|
||||
message?: Message;
|
||||
memory?: Memory;
|
||||
shortTermMemory?: ShortTermMemory;
|
||||
longTermMemory?: LongTermMemory;
|
||||
}) {
|
||||
const { user, message, memory, shortTermMemory, longTermMemory } = config;
|
||||
const connect = (key: any, value: any) => {
|
||||
if (value) {
|
||||
return {
|
||||
[key]: {
|
||||
connect: [{ id: value.id }],
|
||||
},
|
||||
};
|
||||
}
|
||||
return {};
|
||||
};
|
||||
await kPrisma.user
|
||||
.update({
|
||||
where: { id: user.id },
|
||||
data: {
|
||||
...connect("messages", message),
|
||||
...connect("memories", memory),
|
||||
...connect("shortTermMemories", shortTermMemory),
|
||||
...connect("longTermMemories", longTermMemory),
|
||||
},
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error("❌ updateAndConnectRelations failed: user", e);
|
||||
});
|
||||
if (memory && shortTermMemory) {
|
||||
await kPrisma.memory
|
||||
.update({
|
||||
where: { id: memory.id },
|
||||
data: {
|
||||
shortTermMemories: {
|
||||
connect: [{ id: shortTermMemory.id }],
|
||||
},
|
||||
},
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error("❌ updateAndConnectRelations failed: memory", e);
|
||||
});
|
||||
}
|
||||
if (shortTermMemory && longTermMemory) {
|
||||
await kPrisma.shortTermMemory
|
||||
.update({
|
||||
where: { id: shortTermMemory?.id },
|
||||
data: {
|
||||
longTermMemories: {
|
||||
connect: [{ id: longTermMemory?.id }],
|
||||
},
|
||||
},
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error(
|
||||
"❌ updateAndConnectRelations failed: shortTermMemory",
|
||||
e
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
130
src/services/bot/memory/index.ts
Normal file
130
src/services/bot/memory/index.ts
Normal file
|
@ -0,0 +1,130 @@
|
|||
import { Memory, Message, Room, User } from "@prisma/client";
|
||||
import { firstOf, lastOf } from "../../../utils/base";
|
||||
import { LongTermMemoryAgent } from "./long-term";
|
||||
import { MemoryCRUD } from "../../db/memory";
|
||||
import { ShortTermMemoryCRUD } from "../../db/memory-short-term";
|
||||
import { LongTermMemoryCRUD } from "../../db/memory-long-term";
|
||||
import { ShortTermMemoryAgent } from "./short-term";
|
||||
|
||||
export class MemoryManager {
|
||||
private room: Room;
|
||||
|
||||
/**
|
||||
* owner 为空时,即房间自己的公共记忆
|
||||
*/
|
||||
private owner?: User;
|
||||
|
||||
constructor(room: Room, owner?: User) {
|
||||
this.room = room;
|
||||
this.owner = owner;
|
||||
}
|
||||
|
||||
async getMemories(take?: number) {
|
||||
return MemoryCRUD.gets({
|
||||
room: this.room,
|
||||
owner: this.owner,
|
||||
take,
|
||||
});
|
||||
}
|
||||
|
||||
async getShortTermMemories(take?: number) {
|
||||
return ShortTermMemoryCRUD.gets({
|
||||
room: this.room,
|
||||
owner: this.owner,
|
||||
take,
|
||||
});
|
||||
}
|
||||
|
||||
async getLongTermMemories(take?: number) {
|
||||
return LongTermMemoryCRUD.gets({
|
||||
room: this.room,
|
||||
owner: this.owner,
|
||||
take,
|
||||
});
|
||||
}
|
||||
|
||||
async getRelatedMemories(limit: number): Promise<Memory[]> {
|
||||
// todo search memory embeddings
|
||||
return [];
|
||||
}
|
||||
|
||||
async addMessage2Memory(message: Message) {
|
||||
// todo create memory embedding
|
||||
const res = await MemoryCRUD.addOrUpdate({
|
||||
text: message.text,
|
||||
roomId: this.room.id,
|
||||
ownerId: message.senderId,
|
||||
});
|
||||
// 更新长短期记忆
|
||||
this.updateLongShortTermMemory();
|
||||
return res;
|
||||
}
|
||||
|
||||
/**
|
||||
* 更新记忆(当新的记忆数量超过阈值时,自动更新长短期记忆)
|
||||
*/
|
||||
async updateLongShortTermMemory(options?: {
|
||||
shortThreshold?: number;
|
||||
longThreshold?: number;
|
||||
}) {
|
||||
const { shortThreshold, longThreshold } = options ?? {};
|
||||
const success = await this._updateShortTermMemory(shortThreshold);
|
||||
if (success) {
|
||||
await this._updateLongTermMemory(longThreshold);
|
||||
}
|
||||
}
|
||||
|
||||
private async _updateShortTermMemory(threshold = 10) {
|
||||
const lastMemory = firstOf(await this.getShortTermMemories(1));
|
||||
const newMemories = await MemoryCRUD.gets({
|
||||
cursorId: lastMemory?.cursorId,
|
||||
room: this.room,
|
||||
owner: this.owner,
|
||||
order: "asc", // 从旧到新排序
|
||||
});
|
||||
if (newMemories.length < 1 || newMemories.length < threshold) {
|
||||
return true;
|
||||
}
|
||||
const newMemory = await ShortTermMemoryAgent.generate(
|
||||
newMemories,
|
||||
lastMemory
|
||||
);
|
||||
if (!newMemory) {
|
||||
return false;
|
||||
}
|
||||
const res = await ShortTermMemoryCRUD.addOrUpdate({
|
||||
text: newMemory,
|
||||
roomId: this.room.id,
|
||||
ownerId: this.owner?.id,
|
||||
cursorId: lastOf(newMemories)!.id,
|
||||
});
|
||||
return res != null;
|
||||
}
|
||||
|
||||
private async _updateLongTermMemory(threshold = 10) {
|
||||
const lastMemory = firstOf(await this.getLongTermMemories(1));
|
||||
const newMemories = await ShortTermMemoryCRUD.gets({
|
||||
cursorId: lastMemory?.cursorId,
|
||||
room: this.room,
|
||||
owner: this.owner,
|
||||
order: "asc", // 从旧到新排序
|
||||
});
|
||||
if (newMemories.length < 1 || newMemories.length < threshold) {
|
||||
return true;
|
||||
}
|
||||
const newMemory = await LongTermMemoryAgent.generate(
|
||||
newMemories,
|
||||
lastMemory
|
||||
);
|
||||
if (!newMemory) {
|
||||
return false;
|
||||
}
|
||||
const res = await LongTermMemoryCRUD.addOrUpdate({
|
||||
text: newMemory,
|
||||
roomId: this.room.id,
|
||||
ownerId: this.owner?.id,
|
||||
cursorId: lastOf(newMemories)!.id,
|
||||
});
|
||||
return res != null;
|
||||
}
|
||||
}
|
|
@ -1,67 +1,11 @@
|
|||
import { Memory, Message, User } from "@prisma/client";
|
||||
import { kPrisma } from "../../db";
|
||||
import { MemoryHelper, UserMemory } from "./base";
|
||||
import { lastOf } from "../../../utils/base";
|
||||
import { ShortTermMemory } from "./long-term";
|
||||
import { Memory, ShortTermMemory } from "@prisma/client";
|
||||
|
||||
export class LongTermMemory {
|
||||
static async get(user: User) {
|
||||
return kPrisma.longTermMemory
|
||||
.findFirst({
|
||||
include: { cursor: true },
|
||||
where: { ownerId: user.id },
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error("❌ get long memory failed", user, e);
|
||||
return undefined;
|
||||
});
|
||||
}
|
||||
|
||||
static async update(
|
||||
user: User,
|
||||
message: Message,
|
||||
memory: Memory,
|
||||
shortTermMemory: ShortTermMemory,
|
||||
threshold = 10 // 每隔 10 条记忆更新一次短期记忆
|
||||
) {
|
||||
const current = await LongTermMemory.get(user);
|
||||
const newMemories = await new ShortTermMemory(user).gets({
|
||||
ownerId: "all",
|
||||
cursor: current?.cursor,
|
||||
order: "asc", // 从旧到新排序
|
||||
});
|
||||
if (newMemories.length < threshold) {
|
||||
return undefined;
|
||||
}
|
||||
// todo update memory
|
||||
const content = "todo";
|
||||
const data = {
|
||||
ownerId: user.id,
|
||||
content,
|
||||
cursorId: lastOf(newMemories)!.id,
|
||||
};
|
||||
// 直接插入新的长期记忆,不更新旧的长期记忆记录
|
||||
const longTermMemory = await kPrisma.longTermMemory
|
||||
.create({ data })
|
||||
.catch((e) => {
|
||||
console.error(
|
||||
"❌ add or update longTermMemory failed",
|
||||
current,
|
||||
data,
|
||||
e
|
||||
);
|
||||
return undefined;
|
||||
});
|
||||
if (longTermMemory) {
|
||||
// 异步更新
|
||||
MemoryHelper.updateAndConnectRelations({
|
||||
user,
|
||||
message,
|
||||
memory,
|
||||
shortTermMemory,
|
||||
longTermMemory,
|
||||
});
|
||||
}
|
||||
return memory;
|
||||
export class LongTermMemoryAgent {
|
||||
// todo 使用 LLM 生成新的长期记忆
|
||||
static async generate(
|
||||
newMemories: Memory[],
|
||||
lastLongTermMemory?: ShortTermMemory
|
||||
): Promise<string | undefined> {
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,109 +1,11 @@
|
|||
import {
|
||||
Memory,
|
||||
Message,
|
||||
User,
|
||||
ShortTermMemory as _ShortTermMemory,
|
||||
} from "@prisma/client";
|
||||
import { kPrisma } from "../../db";
|
||||
import { UserMemory } from "./base";
|
||||
import { lastOf } from "../../../utils/base";
|
||||
import { LongTermMemory } from "./long-term";
|
||||
import { Memory } from "@prisma/client";
|
||||
|
||||
export class ShortTermMemory {
|
||||
private user: User;
|
||||
|
||||
constructor(user: User) {
|
||||
this.user = user;
|
||||
}
|
||||
|
||||
async getRelatedMemories(limit: number): Promise<_ShortTermMemory[]> {
|
||||
// todo search memory embeddings
|
||||
return [];
|
||||
}
|
||||
|
||||
async count(options?: {
|
||||
cursor?: _ShortTermMemory;
|
||||
ownerId?: "all" | string;
|
||||
}) {
|
||||
const { cursor, ownerId = this.user.id } = options ?? {};
|
||||
return kPrisma.memory
|
||||
.count({
|
||||
where: {
|
||||
ownerId: ownerId.toLowerCase() === "all" ? undefined : ownerId,
|
||||
id: {
|
||||
gt: cursor?.id,
|
||||
},
|
||||
},
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error("❌ get memory count failed", e);
|
||||
return -1;
|
||||
});
|
||||
}
|
||||
|
||||
async get() {
|
||||
return kPrisma.shortTermMemory
|
||||
.findFirst({
|
||||
include: { cursor: true },
|
||||
where: { ownerId: this.user.id },
|
||||
orderBy: { createdAt: "desc" },
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error("❌ get short memory failed", this.user, e);
|
||||
return undefined;
|
||||
});
|
||||
}
|
||||
|
||||
async gets() {
|
||||
return kPrisma.shortTermMemory
|
||||
.findFirst({
|
||||
include: { cursor: true },
|
||||
where: { ownerId: this.user.id },
|
||||
orderBy: { createdAt: "desc" },
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error("❌ get short memory failed", this.user, e);
|
||||
return undefined;
|
||||
});
|
||||
}
|
||||
|
||||
async update(
|
||||
message: Message,
|
||||
memory: Memory,
|
||||
threshold = 10 // 每隔 10 条记忆更新一次短期记忆
|
||||
) {
|
||||
const current = await this.get();
|
||||
const newMemories = await new UserMemory(this.user).gets({
|
||||
ownerId: "all",
|
||||
cursor: current?.cursor,
|
||||
order: "asc", // 从旧到新排序
|
||||
});
|
||||
if (newMemories.length < threshold) {
|
||||
return undefined;
|
||||
}
|
||||
// todo update memory
|
||||
const content = "todo";
|
||||
const data = {
|
||||
ownerId: this.user.id,
|
||||
content,
|
||||
cursorId: lastOf(newMemories)!.id,
|
||||
};
|
||||
// 直接插入新的短期记忆,不更新旧的短期记忆记录
|
||||
const shortTermMemory = await kPrisma.shortTermMemory
|
||||
.create({ data })
|
||||
.catch((e) => {
|
||||
console.error(
|
||||
"❌ add or update shortTermMemory failed",
|
||||
current,
|
||||
data,
|
||||
e
|
||||
);
|
||||
return undefined;
|
||||
});
|
||||
if (shortTermMemory) {
|
||||
// 异步更新
|
||||
LongTermMemory.update(this.user, message, memory, shortTermMemory);
|
||||
}
|
||||
return memory;
|
||||
export class ShortTermMemoryAgent {
|
||||
// todo 使用 LLM 生成新的短期记忆
|
||||
static async generate(
|
||||
newMemories: Memory[],
|
||||
lastShortTermMemory?: Memory
|
||||
): Promise<string | undefined> {
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,83 +0,0 @@
|
|||
import { Message, User } from "@prisma/client";
|
||||
import { kPrisma } from "../db";
|
||||
import { UserMemory } from "./memory";
|
||||
|
||||
class _MessageHistory {
|
||||
async count(options?: { cursor?: Message; sender?: User }) {
|
||||
const { sender, cursor } = options ?? {};
|
||||
return kPrisma.message
|
||||
.count({
|
||||
where: {
|
||||
senderId: sender?.id,
|
||||
id: {
|
||||
gt: cursor?.id,
|
||||
},
|
||||
},
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error("❌ get msg count failed", e);
|
||||
return -1;
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* 查询历史消息,消息从旧到新排序
|
||||
*/
|
||||
async gets(options?: {
|
||||
sender?: User;
|
||||
limit?: number;
|
||||
offset?: number;
|
||||
cursor?: Message;
|
||||
order?: "asc" | "desc";
|
||||
}) {
|
||||
const {
|
||||
limit = 10,
|
||||
offset = 0,
|
||||
order = "desc",
|
||||
sender,
|
||||
cursor,
|
||||
} = options ?? {};
|
||||
const msgs = await kPrisma.message
|
||||
.findMany({
|
||||
cursor,
|
||||
where: { senderId: sender?.id },
|
||||
take: limit,
|
||||
skip: offset,
|
||||
orderBy: { createdAt: order },
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error("❌ get msgs failed", options, e);
|
||||
return [];
|
||||
});
|
||||
return order === "desc" ? msgs.reverse() : msgs;
|
||||
}
|
||||
|
||||
async addOrUpdate(
|
||||
msg: Partial<Message> & {
|
||||
text: string;
|
||||
sender: User;
|
||||
}
|
||||
) {
|
||||
const data = {
|
||||
text: msg.text,
|
||||
senderId: msg.sender?.id,
|
||||
};
|
||||
const message = await kPrisma.message
|
||||
.upsert({
|
||||
where: { id: msg.id },
|
||||
create: data,
|
||||
update: data,
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error("❌ add msg to db failed", msg, e);
|
||||
return undefined;
|
||||
});
|
||||
if (message) {
|
||||
// 异步更新记忆
|
||||
new UserMemory(msg.sender).add(message);
|
||||
}
|
||||
return message;
|
||||
}
|
||||
}
|
||||
|
||||
export const MessageHistory = new _MessageHistory();
|
90
src/services/db/memory-long-term.ts
Normal file
90
src/services/db/memory-long-term.ts
Normal file
|
@ -0,0 +1,90 @@
|
|||
import { LongTermMemory, Room, User } from "@prisma/client";
|
||||
import { kPrisma } from ".";
|
||||
|
||||
class _LongTermMemoryCRUD {
|
||||
async count(options?: { cursorId?: number; room?: Room; owner?: User }) {
|
||||
const { cursorId, owner, room } = options ?? {};
|
||||
return kPrisma.longTermMemory
|
||||
.count({
|
||||
where: {
|
||||
id: { gt: cursorId },
|
||||
roomId: room?.id,
|
||||
ownerId: owner?.id,
|
||||
},
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error("❌ get longTermMemory count failed", e);
|
||||
return -1;
|
||||
});
|
||||
}
|
||||
|
||||
async gets(options?: {
|
||||
room?: Room;
|
||||
owner?: User;
|
||||
take?: number;
|
||||
skip?: number;
|
||||
cursorId?: number;
|
||||
/**
|
||||
* 查询顺序(返回按从旧到新排序)
|
||||
*/
|
||||
order?: "asc" | "desc";
|
||||
}) {
|
||||
const {
|
||||
room,
|
||||
owner,
|
||||
take = 10,
|
||||
skip = 0,
|
||||
cursorId,
|
||||
order = "desc",
|
||||
} = options ?? {};
|
||||
const memories = await kPrisma.longTermMemory
|
||||
.findMany({
|
||||
where: { roomId: room?.id, ownerId: owner?.id },
|
||||
take,
|
||||
skip,
|
||||
cursor: { id: cursorId },
|
||||
orderBy: { createdAt: order },
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error("❌ get long term memories failed", options, e);
|
||||
return [];
|
||||
});
|
||||
return order === "desc" ? memories.reverse() : memories;
|
||||
}
|
||||
|
||||
async addOrUpdate(
|
||||
longTermMemory: Partial<LongTermMemory> & {
|
||||
text: string;
|
||||
cursorId: number;
|
||||
roomId: string;
|
||||
ownerId?: string;
|
||||
}
|
||||
) {
|
||||
const { text: _text, cursorId, roomId, ownerId } = longTermMemory;
|
||||
const text = _text?.trim();
|
||||
const data = {
|
||||
text,
|
||||
cursor: {
|
||||
connect: { id: cursorId },
|
||||
},
|
||||
room: {
|
||||
connect: { id: roomId },
|
||||
},
|
||||
owner: {
|
||||
connect: { id: ownerId },
|
||||
},
|
||||
};
|
||||
return kPrisma.longTermMemory
|
||||
.upsert({
|
||||
where: { id: longTermMemory.id },
|
||||
create: data,
|
||||
update: data,
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error("❌ add longTermMemory to db failed", longTermMemory, e);
|
||||
return undefined;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
export const LongTermMemoryCRUD = new _LongTermMemoryCRUD();
|
94
src/services/db/memory-short-term.ts
Normal file
94
src/services/db/memory-short-term.ts
Normal file
|
@ -0,0 +1,94 @@
|
|||
import { ShortTermMemory, Room, User } from "@prisma/client";
|
||||
import { kPrisma } from ".";
|
||||
|
||||
class _ShortTermMemoryCRUD {
|
||||
async count(options?: { cursorId?: number; room?: Room; owner?: User }) {
|
||||
const { cursorId, owner, room } = options ?? {};
|
||||
return kPrisma.shortTermMemory
|
||||
.count({
|
||||
where: {
|
||||
id: { gt: cursorId },
|
||||
roomId: room?.id,
|
||||
ownerId: owner?.id,
|
||||
},
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error("❌ get shortTermMemory count failed", e);
|
||||
return -1;
|
||||
});
|
||||
}
|
||||
|
||||
async gets(options?: {
|
||||
room?: Room;
|
||||
owner?: User;
|
||||
take?: number;
|
||||
skip?: number;
|
||||
cursorId?: number;
|
||||
/**
|
||||
* 查询顺序(返回按从旧到新排序)
|
||||
*/
|
||||
order?: "asc" | "desc";
|
||||
}) {
|
||||
const {
|
||||
room,
|
||||
owner,
|
||||
take = 10,
|
||||
skip = 0,
|
||||
cursorId,
|
||||
order = "desc",
|
||||
} = options ?? {};
|
||||
const memories = await kPrisma.shortTermMemory
|
||||
.findMany({
|
||||
where: { roomId: room?.id, ownerId: owner?.id },
|
||||
take,
|
||||
skip,
|
||||
cursor: { id: cursorId },
|
||||
orderBy: { createdAt: order },
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error("❌ get short term memories failed", options, e);
|
||||
return [];
|
||||
});
|
||||
return order === "desc" ? memories.reverse() : memories;
|
||||
}
|
||||
|
||||
async addOrUpdate(
|
||||
shortTermMemory: Partial<ShortTermMemory> & {
|
||||
text: string;
|
||||
cursorId: number;
|
||||
roomId: string;
|
||||
ownerId?: string;
|
||||
}
|
||||
) {
|
||||
const { text: _text, cursorId, roomId, ownerId } = shortTermMemory;
|
||||
const text = _text?.trim();
|
||||
const data = {
|
||||
text,
|
||||
cursor: {
|
||||
connect: { id: cursorId },
|
||||
},
|
||||
room: {
|
||||
connect: { id: roomId },
|
||||
},
|
||||
owner: {
|
||||
connect: { id: ownerId },
|
||||
},
|
||||
};
|
||||
return kPrisma.shortTermMemory
|
||||
.upsert({
|
||||
where: { id: shortTermMemory.id },
|
||||
create: data,
|
||||
update: data,
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error(
|
||||
"❌ add shortTermMemory to db failed",
|
||||
shortTermMemory,
|
||||
e
|
||||
);
|
||||
return undefined;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
export const ShortTermMemoryCRUD = new _ShortTermMemoryCRUD();
|
86
src/services/db/memory.ts
Normal file
86
src/services/db/memory.ts
Normal file
|
@ -0,0 +1,86 @@
|
|||
import { Memory, Room, User } from "@prisma/client";
|
||||
import { kPrisma } from ".";
|
||||
|
||||
class _MemoryCRUD {
|
||||
async count(options?: { cursorId?: number; room?: Room; owner?: User }) {
|
||||
const { cursorId, owner, room } = options ?? {};
|
||||
return kPrisma.memory
|
||||
.count({
|
||||
where: {
|
||||
id: { gt: cursorId },
|
||||
roomId: room?.id,
|
||||
ownerId: owner?.id,
|
||||
},
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error("❌ get memory count failed", e);
|
||||
return -1;
|
||||
});
|
||||
}
|
||||
|
||||
async gets(options?: {
|
||||
room?: Room;
|
||||
owner?: User;
|
||||
take?: number;
|
||||
skip?: number;
|
||||
cursorId?: number;
|
||||
/**
|
||||
* 查询顺序(返回按从旧到新排序)
|
||||
*/
|
||||
order?: "asc" | "desc";
|
||||
}) {
|
||||
const {
|
||||
room,
|
||||
owner,
|
||||
take = 10,
|
||||
skip = 0,
|
||||
cursorId,
|
||||
order = "desc",
|
||||
} = options ?? {};
|
||||
const memories = await kPrisma.memory
|
||||
.findMany({
|
||||
where: { roomId: room?.id, ownerId: owner?.id },
|
||||
take,
|
||||
skip,
|
||||
cursor: { id: cursorId },
|
||||
orderBy: { createdAt: order },
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error("❌ get memories failed", options, e);
|
||||
return [];
|
||||
});
|
||||
return order === "desc" ? memories.reverse() : memories;
|
||||
}
|
||||
|
||||
async addOrUpdate(
|
||||
memory: Partial<Memory> & {
|
||||
text: string;
|
||||
roomId: string;
|
||||
ownerId?: string;
|
||||
}
|
||||
) {
|
||||
const { text: _text, roomId, ownerId } = memory;
|
||||
const text = _text?.trim();
|
||||
const data = {
|
||||
text,
|
||||
room: {
|
||||
connect: { id: roomId },
|
||||
},
|
||||
owner: {
|
||||
connect: { id: ownerId },
|
||||
},
|
||||
};
|
||||
return kPrisma.memory
|
||||
.upsert({
|
||||
where: { id: memory.id },
|
||||
create: data,
|
||||
update: data,
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error("❌ add memory to db failed", memory, e);
|
||||
return undefined;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
export const MemoryCRUD = new _MemoryCRUD();
|
89
src/services/db/message.ts
Normal file
89
src/services/db/message.ts
Normal file
|
@ -0,0 +1,89 @@
|
|||
import { Message, Prisma, Room, User } from "@prisma/client";
|
||||
import { kPrisma } from ".";
|
||||
|
||||
class _MessageCRUD {
|
||||
async count(options?: { cursorId?: number; room?: Room; sender?: User }) {
|
||||
const { cursorId, sender, room } = options ?? {};
|
||||
return kPrisma.message
|
||||
.count({
|
||||
where: {
|
||||
id: { gt: cursorId },
|
||||
roomId: room?.id,
|
||||
senderId: sender?.id,
|
||||
},
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error("❌ get message count failed", e);
|
||||
return -1;
|
||||
});
|
||||
}
|
||||
|
||||
async gets(options?: {
|
||||
room?: Room;
|
||||
sender?: User;
|
||||
take?: number;
|
||||
skip?: number;
|
||||
cursorId?: number;
|
||||
include?: Prisma.MessageInclude;
|
||||
/**
|
||||
* 查询顺序(返回按从旧到新排序)
|
||||
*/
|
||||
order?: "asc" | "desc";
|
||||
}) {
|
||||
const {
|
||||
room,
|
||||
sender,
|
||||
take = 10,
|
||||
skip = 0,
|
||||
cursorId,
|
||||
include = { sender: true },
|
||||
order = "desc",
|
||||
} = options ?? {};
|
||||
const messages = await kPrisma.message
|
||||
.findMany({
|
||||
where: { roomId: room?.id, senderId: sender?.id },
|
||||
take,
|
||||
skip,
|
||||
include,
|
||||
cursor: { id: cursorId },
|
||||
orderBy: { createdAt: order },
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error("❌ get messages failed", options, e);
|
||||
return [];
|
||||
});
|
||||
return order === "desc" ? messages.reverse() : messages;
|
||||
}
|
||||
|
||||
async addOrUpdate(
|
||||
message: Partial<Message> & {
|
||||
text: string;
|
||||
roomId: string;
|
||||
senderId: string;
|
||||
}
|
||||
) {
|
||||
const { text: _text, roomId, senderId } = message;
|
||||
const text = _text?.trim();
|
||||
const data = {
|
||||
text,
|
||||
room: {
|
||||
connect: { id: roomId },
|
||||
},
|
||||
sender: {
|
||||
connect: { id: senderId },
|
||||
},
|
||||
};
|
||||
return kPrisma.message
|
||||
.upsert({
|
||||
where: { id: message.id },
|
||||
create: data,
|
||||
update: data,
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error("❌ add message to db failed", message, e);
|
||||
return undefined;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
export const MessageCRUD = new _MessageCRUD();
|
86
src/services/db/room.ts
Normal file
86
src/services/db/room.ts
Normal file
|
@ -0,0 +1,86 @@
|
|||
import { Prisma, Room, User } from "@prisma/client";
|
||||
import { kPrisma } from ".";
|
||||
|
||||
export function getRoomID(users: User[]) {
|
||||
return users
|
||||
.map((e) => e.id)
|
||||
.sort()
|
||||
.join("_");
|
||||
}
|
||||
|
||||
class _RoomCRUD {
|
||||
async count(options?: { user?: User }) {
|
||||
const { user } = options ?? {};
|
||||
return kPrisma.room
|
||||
.count({
|
||||
where: {
|
||||
members: {
|
||||
some: {
|
||||
id: user?.id,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error("❌ get room count failed", e);
|
||||
return -1;
|
||||
});
|
||||
}
|
||||
|
||||
async gets(options?: {
|
||||
user?: User;
|
||||
take?: number;
|
||||
skip?: number;
|
||||
cursorId?: string;
|
||||
include?: Prisma.RoomInclude;
|
||||
/**
|
||||
* 查询顺序(返回按从旧到新排序)
|
||||
*/
|
||||
order?: "asc" | "desc";
|
||||
}) {
|
||||
const {
|
||||
user,
|
||||
take = 10,
|
||||
skip = 0,
|
||||
cursorId,
|
||||
include = { members: true },
|
||||
order = "desc",
|
||||
} = options ?? {};
|
||||
const rooms = await kPrisma.room
|
||||
.findMany({
|
||||
where: { members: { some: { id: user?.id } } },
|
||||
take,
|
||||
skip,
|
||||
cursor: { id: cursorId },
|
||||
include,
|
||||
orderBy: { createdAt: order },
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error("❌ get rooms failed", options, e);
|
||||
return [];
|
||||
});
|
||||
return order === "desc" ? rooms.reverse() : rooms;
|
||||
}
|
||||
|
||||
async addOrUpdate(
|
||||
room: Partial<Room> & {
|
||||
name: string;
|
||||
description: string;
|
||||
}
|
||||
) {
|
||||
room.name = room.name.trim();
|
||||
room.description = room.description.trim();
|
||||
return kPrisma.room
|
||||
.upsert({
|
||||
where: { id: room.id },
|
||||
create: room,
|
||||
update: room,
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error("❌ add room to db failed", room, e);
|
||||
return undefined;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
export const RoomCRUD = new _RoomCRUD();
|
64
src/services/db/user.ts
Normal file
64
src/services/db/user.ts
Normal file
|
@ -0,0 +1,64 @@
|
|||
import { Prisma, User } from "@prisma/client";
|
||||
import { kPrisma } from ".";
|
||||
|
||||
class _UserCRUD {
|
||||
async count() {
|
||||
return kPrisma.user.count().catch((e) => {
|
||||
console.error("❌ get user count failed", e);
|
||||
return -1;
|
||||
});
|
||||
}
|
||||
|
||||
async gets(options?: {
|
||||
take?: number;
|
||||
skip?: number;
|
||||
cursorId?: string;
|
||||
include?: Prisma.UserInclude;
|
||||
/**
|
||||
* 查询顺序(返回按从旧到新排序)
|
||||
*/
|
||||
order?: "asc" | "desc";
|
||||
}) {
|
||||
const {
|
||||
take = 10,
|
||||
skip = 0,
|
||||
cursorId,
|
||||
include = { members: true },
|
||||
order = "desc",
|
||||
} = options ?? {};
|
||||
const users = await kPrisma.user
|
||||
.findMany({
|
||||
take,
|
||||
skip,
|
||||
cursor: { id: cursorId },
|
||||
orderBy: { createdAt: order },
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error("❌ get users failed", options, e);
|
||||
return [];
|
||||
});
|
||||
return order === "desc" ? users.reverse() : users;
|
||||
}
|
||||
|
||||
async addOrUpdate(
|
||||
user: Partial<User> & {
|
||||
name: string;
|
||||
profile: string;
|
||||
}
|
||||
) {
|
||||
user.name = user.name.trim();
|
||||
user.profile = user.profile.trim();
|
||||
return kPrisma.user
|
||||
.upsert({
|
||||
where: { id: user.id },
|
||||
create: user,
|
||||
update: user,
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error("❌ add user to db failed", user, e);
|
||||
return undefined;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
export const UserCRUD = new _UserCRUD();
|
Loading…
Reference in New Issue
Block a user