feat: add db crud and conversation manager

This commit is contained in:
WJG 2024-01-28 18:41:40 +08:00
parent ef72a14ef6
commit eb9c69334d
No known key found for this signature in database
GPG Key ID: 258474EF8590014A
15 changed files with 815 additions and 472 deletions

View File

@ -59,7 +59,7 @@ model Memory {
// 关联数据 // 关联数据
owner User? @relation(fields: [ownerId], references: [id]) // owner 为空时,即房间自己的公共记忆 owner User? @relation(fields: [ownerId], references: [id]) // owner 为空时,即房间自己的公共记忆
ownerId String? ownerId String?
Room Room @relation(fields: [roomId], references: [id]) room Room @relation(fields: [roomId], references: [id])
roomId String roomId String
shortTermMemories ShortTermMemory[] shortTermMemories ShortTermMemory[]
// 时间日期 // 时间日期
@ -75,7 +75,7 @@ model ShortTermMemory {
cursorId Int cursorId Int
owner User? @relation(fields: [ownerId], references: [id]) // owner 为空时,即房间自己的公共记忆 owner User? @relation(fields: [ownerId], references: [id]) // owner 为空时,即房间自己的公共记忆
ownerId String? ownerId String?
Room Room @relation(fields: [roomId], references: [id]) room Room @relation(fields: [roomId], references: [id])
roomId String roomId String
longTermMemories LongTermMemory[] longTermMemories LongTermMemory[]
// 时间日期 // 时间日期

View File

@ -0,0 +1,139 @@
import { Message, Prisma, Room, User } from "@prisma/client";
import { UserCRUD } from "../db/user";
import { RoomCRUD, getRoomID } from "../db/room";
import { MemoryManager } from "./memory";
import { MessageCRUD } from "../db/message";
export interface IPerson {
/**
*
*/
name: string;
/**
*
*/
profile: string;
}
const kDefaultBot: IPerson = {
name: "用户",
profile: "",
};
const kDefaultMaster: IPerson = {
name: "小爱同学",
profile: "",
};
export type IBotConfig = {
bot?: IPerson;
master?: IPerson;
room?: {
name: string;
description: string;
};
};
export class ConversationManager {
private config: IBotConfig;
constructor(config: IBotConfig) {
this.config = config;
}
async getMemory() {
const isReady = await this.loadConfig();
if (!isReady) {
return undefined;
}
return this.memory;
}
async getRoom() {
const isReady = await this.loadConfig();
if (!isReady) {
return undefined;
}
return this.room;
}
async getUser(key: "bot" | "master") {
const isReady = await this.loadConfig();
if (!isReady) {
return undefined;
}
return this.users[key];
}
async getMessages(options?: {
sender?: User;
take?: number;
skip?: number;
cursorId?: number;
include?: Prisma.MessageInclude;
/**
*
*/
order?: "asc" | "desc";
}) {
const isReady = await this.loadConfig();
if (!isReady) {
return [];
}
return MessageCRUD.gets({
room: this.room,
...options,
});
}
async onMessage(message: Message) {
const memory = await this.getMemory();
return memory?.addMessage2Memory(message);
}
private users: Record<string, User> = {};
private room?: Room;
private memory?: MemoryManager;
get ready() {
const { bot, master } = this.users;
return bot && master && this.room && this.memory;
}
private async loadConfig() {
if (this.ready) {
return true;
}
let { bot, master } = this.users;
if (!bot) {
await this.addOrUpdateUser("bot", this.config.bot ?? kDefaultBot);
}
if (!master) {
await this.addOrUpdateUser(
"master",
this.config.master ?? kDefaultMaster
);
}
if (!this.room && bot && master) {
const defaultRoomName = `${master.name}${bot.name}的私聊`;
this.room = await RoomCRUD.addOrUpdate({
id: getRoomID([bot, master]),
name: this.config.room?.name ?? defaultRoomName,
description: this.config.room?.description ?? defaultRoomName,
});
}
if (bot && master && this.room && !this.memory) {
this.memory = new MemoryManager(this.room!);
}
return this.ready;
}
private async addOrUpdateUser(type: "bot" | "master", user: IPerson) {
const oldUser = this.users[type];
const res = await UserCRUD.addOrUpdate({
id: oldUser?.id,
...user,
});
if (res) {
this.users[type] = res;
}
}
}

View File

@ -1,8 +1,7 @@
import { User } from "@prisma/client";
import { jsonDecode, jsonEncode } from "../../utils/base"; import { jsonDecode, jsonEncode } from "../../utils/base";
import { buildPrompt, toUTC8Time } from "../../utils/string"; import { buildPrompt, toUTC8Time } from "../../utils/string";
import { openai } from "../openai"; import { openai } from "../openai";
import { kPrisma } from "../db"; import { ConversationManager, IBotConfig } from "./conversation";
const systemTemplate = ` const systemTemplate = `
{{name}} {{name}}
@ -41,43 +40,29 @@ const userTemplate = `
{{message}} {{message}}
`.trim(); `.trim();
export interface IPerson {
/**
*
*/
name: string;
/**
*
*/
profile: string;
}
export class MyBot { export class MyBot {
private users: Record<string, User | undefined> = { private manager: ConversationManager;
bot: undefined, constructor(config: IBotConfig) {
// 主人的个人信息 this.manager = new ConversationManager(config);
master: undefined,
};
constructor(config: { bot: IPerson; master: IPerson }) {
this.createOrUpdateUser("bot", config.bot);
this.createOrUpdateUser("master", config.master);
} }
async ask(msg: string) { async ask(msg: string) {
const { bot, master } = this.users; const memory = await this.manager.getMemory();
if (!bot || !master) { const room = await this.manager.getRoom();
console.error("❌ ask bot failed", bot, master); const bot = await this.manager.getUser("bot");
return undefined; const master = await this.manager.getUser("master");
const lastMessages = await this.manager.getMessages({
take: 10,
});
if (!this.manager.ready) {
return;
} }
const botMemory = new UserMemory(bot);
const result = await openai.chat({ const result = await openai.chat({
system: buildPrompt(systemTemplate, { system: buildPrompt(systemTemplate, {
bot_name: this.bot.name, bot_name: bot!.name,
bot_profile: this.bot.profile, bot_profile: bot!.profile,
master_name: this.master.name, master_name: master!.name,
master_profile: this.master.profile, master_profile: master!.profile,
history: history:
lastMessages.length < 1 lastMessages.length < 1
? "暂无" ? "暂无"
@ -85,7 +70,7 @@ export class MyBot {
.map((e) => .map((e) =>
jsonEncode({ jsonEncode({
time: toUTC8Time(e.createdAt), time: toUTC8Time(e.createdAt),
user: e.user.name, user: e.sender.name,
message: e.text, message: e.text,
}) })
) )
@ -94,39 +79,11 @@ export class MyBot {
user: buildPrompt(userTemplate, { user: buildPrompt(userTemplate, {
message: jsonEncode({ message: jsonEncode({
time: toUTC8Time(new Date()), time: toUTC8Time(new Date()),
user: this.master.name, user: master!.name,
message: msg, message: msg,
})!, })!,
}), }),
tools: [
{
type: "function",
function: {
name: "reply",
description: "回复一条消息",
parameters: {
type: "object",
properties: {
message: { type: "string", description: "回复的消息内容" },
},
},
},
},
],
}); });
return jsonDecode(result?.content)?.message; return jsonDecode(result?.content)?.message;
} }
private async createOrUpdateUser(type: "bot" | "master", user: IPerson) {
this.users[type] = await kPrisma.user
.upsert({
where: { id: this.users[type]?.id },
create: user,
update: user,
})
.catch((e) => {
console.error("❌ update user failed", type, user, e);
return undefined;
});
}
} }

View File

@ -1,155 +0,0 @@
import { kPrisma } from "../../db";
import { Memory, Message, User } from "@prisma/client";
import { jsonDecode, jsonEncode, lastOf } from "../../../utils/base";
import { ShortTermMemory } from "./short-term";
import { LongTermMemory } from "./long-term";
// todo 在会话中,向会话的参与者分发消息(记忆),公共记忆,个人记忆
// todo 通知会话参与者
export class UserMemory {
private user: User;
constructor(user: User) {
this.user = user;
}
async getRelatedMemories(limit: number): Promise<Memory[]> {
// todo search memory embeddings
return [];
}
async count(options?: { cursor?: Memory; ownerId?: "all" | string }) {
const { cursor, ownerId = this.user.id } = options ?? {};
return kPrisma.memory
.count({
where: {
ownerId: ownerId.toLowerCase() === "all" ? undefined : ownerId,
id: {
gt: cursor?.id,
},
},
})
.catch((e) => {
console.error("❌ get memory count failed", e);
return -1;
});
}
async gets(options?: {
ownerId?: "all" | string;
limit?: number;
offset?: number;
cursor?: Memory;
order?: "asc" | "desc";
}) {
const {
cursor,
limit = 10,
offset = 0,
order = "desc",
ownerId = this.user.id,
} = options ?? {};
const memories = await kPrisma.memory
.findMany({
cursor,
where: {
ownerId: ownerId.toLowerCase() === "all" ? undefined : ownerId,
},
take: limit,
skip: offset,
orderBy: { createdAt: order },
})
.catch((e) => {
console.error("❌ get memories failed", options, e);
return [];
});
const orderedMemories = order === "desc" ? memories.reverse() : memories;
return orderedMemories.map((e) => {
return { ...e, content: jsonDecode(e.content)!.data };
});
}
async add(message: Message) {
// todo create memory embedding
const data = {
ownerId: this.user.id,
type: "message",
content: jsonEncode({ data: message.text })!,
};
const memory = await kPrisma.memory.create({ data }).catch((e) => {
console.error("❌ add memory to db failed", data, e);
return undefined;
});
if (memory) {
// 异步更新
new ShortTermMemory(this.user).update(message, memory);
}
return memory;
}
}
export class MemoryHelper {
static async updateAndConnectRelations(config: {
user: User;
message?: Message;
memory?: Memory;
shortTermMemory?: ShortTermMemory;
longTermMemory?: LongTermMemory;
}) {
const { user, message, memory, shortTermMemory, longTermMemory } = config;
const connect = (key: any, value: any) => {
if (value) {
return {
[key]: {
connect: [{ id: value.id }],
},
};
}
return {};
};
await kPrisma.user
.update({
where: { id: user.id },
data: {
...connect("messages", message),
...connect("memories", memory),
...connect("shortTermMemories", shortTermMemory),
...connect("longTermMemories", longTermMemory),
},
})
.catch((e) => {
console.error("❌ updateAndConnectRelations failed: user", e);
});
if (memory && shortTermMemory) {
await kPrisma.memory
.update({
where: { id: memory.id },
data: {
shortTermMemories: {
connect: [{ id: shortTermMemory.id }],
},
},
})
.catch((e) => {
console.error("❌ updateAndConnectRelations failed: memory", e);
});
}
if (shortTermMemory && longTermMemory) {
await kPrisma.shortTermMemory
.update({
where: { id: shortTermMemory?.id },
data: {
longTermMemories: {
connect: [{ id: longTermMemory?.id }],
},
},
})
.catch((e) => {
console.error(
"❌ updateAndConnectRelations failed: shortTermMemory",
e
);
});
}
}
}

View File

@ -0,0 +1,130 @@
import { Memory, Message, Room, User } from "@prisma/client";
import { firstOf, lastOf } from "../../../utils/base";
import { LongTermMemoryAgent } from "./long-term";
import { MemoryCRUD } from "../../db/memory";
import { ShortTermMemoryCRUD } from "../../db/memory-short-term";
import { LongTermMemoryCRUD } from "../../db/memory-long-term";
import { ShortTermMemoryAgent } from "./short-term";
export class MemoryManager {
private room: Room;
/**
* owner
*/
private owner?: User;
constructor(room: Room, owner?: User) {
this.room = room;
this.owner = owner;
}
async getMemories(take?: number) {
return MemoryCRUD.gets({
room: this.room,
owner: this.owner,
take,
});
}
async getShortTermMemories(take?: number) {
return ShortTermMemoryCRUD.gets({
room: this.room,
owner: this.owner,
take,
});
}
async getLongTermMemories(take?: number) {
return LongTermMemoryCRUD.gets({
room: this.room,
owner: this.owner,
take,
});
}
async getRelatedMemories(limit: number): Promise<Memory[]> {
// todo search memory embeddings
return [];
}
async addMessage2Memory(message: Message) {
// todo create memory embedding
const res = await MemoryCRUD.addOrUpdate({
text: message.text,
roomId: this.room.id,
ownerId: message.senderId,
});
// 更新长短期记忆
this.updateLongShortTermMemory();
return res;
}
/**
*
*/
async updateLongShortTermMemory(options?: {
shortThreshold?: number;
longThreshold?: number;
}) {
const { shortThreshold, longThreshold } = options ?? {};
const success = await this._updateShortTermMemory(shortThreshold);
if (success) {
await this._updateLongTermMemory(longThreshold);
}
}
private async _updateShortTermMemory(threshold = 10) {
const lastMemory = firstOf(await this.getShortTermMemories(1));
const newMemories = await MemoryCRUD.gets({
cursorId: lastMemory?.cursorId,
room: this.room,
owner: this.owner,
order: "asc", // 从旧到新排序
});
if (newMemories.length < 1 || newMemories.length < threshold) {
return true;
}
const newMemory = await ShortTermMemoryAgent.generate(
newMemories,
lastMemory
);
if (!newMemory) {
return false;
}
const res = await ShortTermMemoryCRUD.addOrUpdate({
text: newMemory,
roomId: this.room.id,
ownerId: this.owner?.id,
cursorId: lastOf(newMemories)!.id,
});
return res != null;
}
private async _updateLongTermMemory(threshold = 10) {
const lastMemory = firstOf(await this.getLongTermMemories(1));
const newMemories = await ShortTermMemoryCRUD.gets({
cursorId: lastMemory?.cursorId,
room: this.room,
owner: this.owner,
order: "asc", // 从旧到新排序
});
if (newMemories.length < 1 || newMemories.length < threshold) {
return true;
}
const newMemory = await LongTermMemoryAgent.generate(
newMemories,
lastMemory
);
if (!newMemory) {
return false;
}
const res = await LongTermMemoryCRUD.addOrUpdate({
text: newMemory,
roomId: this.room.id,
ownerId: this.owner?.id,
cursorId: lastOf(newMemories)!.id,
});
return res != null;
}
}

View File

@ -1,67 +1,11 @@
import { Memory, Message, User } from "@prisma/client"; import { Memory, ShortTermMemory } from "@prisma/client";
import { kPrisma } from "../../db";
import { MemoryHelper, UserMemory } from "./base";
import { lastOf } from "../../../utils/base";
import { ShortTermMemory } from "./long-term";
export class LongTermMemory { export class LongTermMemoryAgent {
static async get(user: User) { // todo 使用 LLM 生成新的长期记忆
return kPrisma.longTermMemory static async generate(
.findFirst({ newMemories: Memory[],
include: { cursor: true }, lastLongTermMemory?: ShortTermMemory
where: { ownerId: user.id }, ): Promise<string | undefined> {
}) return "";
.catch((e) => {
console.error("❌ get long memory failed", user, e);
return undefined;
});
}
static async update(
user: User,
message: Message,
memory: Memory,
shortTermMemory: ShortTermMemory,
threshold = 10 // 每隔 10 条记忆更新一次短期记忆
) {
const current = await LongTermMemory.get(user);
const newMemories = await new ShortTermMemory(user).gets({
ownerId: "all",
cursor: current?.cursor,
order: "asc", // 从旧到新排序
});
if (newMemories.length < threshold) {
return undefined;
}
// todo update memory
const content = "todo";
const data = {
ownerId: user.id,
content,
cursorId: lastOf(newMemories)!.id,
};
// 直接插入新的长期记忆,不更新旧的长期记忆记录
const longTermMemory = await kPrisma.longTermMemory
.create({ data })
.catch((e) => {
console.error(
"❌ add or update longTermMemory failed",
current,
data,
e
);
return undefined;
});
if (longTermMemory) {
// 异步更新
MemoryHelper.updateAndConnectRelations({
user,
message,
memory,
shortTermMemory,
longTermMemory,
});
}
return memory;
} }
} }

View File

@ -1,109 +1,11 @@
import { import { Memory } from "@prisma/client";
Memory,
Message,
User,
ShortTermMemory as _ShortTermMemory,
} from "@prisma/client";
import { kPrisma } from "../../db";
import { UserMemory } from "./base";
import { lastOf } from "../../../utils/base";
import { LongTermMemory } from "./long-term";
export class ShortTermMemory { export class ShortTermMemoryAgent {
private user: User; // todo 使用 LLM 生成新的短期记忆
static async generate(
constructor(user: User) { newMemories: Memory[],
this.user = user; lastShortTermMemory?: Memory
} ): Promise<string | undefined> {
return "";
async getRelatedMemories(limit: number): Promise<_ShortTermMemory[]> {
// todo search memory embeddings
return [];
}
async count(options?: {
cursor?: _ShortTermMemory;
ownerId?: "all" | string;
}) {
const { cursor, ownerId = this.user.id } = options ?? {};
return kPrisma.memory
.count({
where: {
ownerId: ownerId.toLowerCase() === "all" ? undefined : ownerId,
id: {
gt: cursor?.id,
},
},
})
.catch((e) => {
console.error("❌ get memory count failed", e);
return -1;
});
}
async get() {
return kPrisma.shortTermMemory
.findFirst({
include: { cursor: true },
where: { ownerId: this.user.id },
orderBy: { createdAt: "desc" },
})
.catch((e) => {
console.error("❌ get short memory failed", this.user, e);
return undefined;
});
}
async gets() {
return kPrisma.shortTermMemory
.findFirst({
include: { cursor: true },
where: { ownerId: this.user.id },
orderBy: { createdAt: "desc" },
})
.catch((e) => {
console.error("❌ get short memory failed", this.user, e);
return undefined;
});
}
async update(
message: Message,
memory: Memory,
threshold = 10 // 每隔 10 条记忆更新一次短期记忆
) {
const current = await this.get();
const newMemories = await new UserMemory(this.user).gets({
ownerId: "all",
cursor: current?.cursor,
order: "asc", // 从旧到新排序
});
if (newMemories.length < threshold) {
return undefined;
}
// todo update memory
const content = "todo";
const data = {
ownerId: this.user.id,
content,
cursorId: lastOf(newMemories)!.id,
};
// 直接插入新的短期记忆,不更新旧的短期记忆记录
const shortTermMemory = await kPrisma.shortTermMemory
.create({ data })
.catch((e) => {
console.error(
"❌ add or update shortTermMemory failed",
current,
data,
e
);
return undefined;
});
if (shortTermMemory) {
// 异步更新
LongTermMemory.update(this.user, message, memory, shortTermMemory);
}
return memory;
} }
} }

View File

@ -1,83 +0,0 @@
import { Message, User } from "@prisma/client";
import { kPrisma } from "../db";
import { UserMemory } from "./memory";
class _MessageHistory {
async count(options?: { cursor?: Message; sender?: User }) {
const { sender, cursor } = options ?? {};
return kPrisma.message
.count({
where: {
senderId: sender?.id,
id: {
gt: cursor?.id,
},
},
})
.catch((e) => {
console.error("❌ get msg count failed", e);
return -1;
});
}
/**
*
*/
async gets(options?: {
sender?: User;
limit?: number;
offset?: number;
cursor?: Message;
order?: "asc" | "desc";
}) {
const {
limit = 10,
offset = 0,
order = "desc",
sender,
cursor,
} = options ?? {};
const msgs = await kPrisma.message
.findMany({
cursor,
where: { senderId: sender?.id },
take: limit,
skip: offset,
orderBy: { createdAt: order },
})
.catch((e) => {
console.error("❌ get msgs failed", options, e);
return [];
});
return order === "desc" ? msgs.reverse() : msgs;
}
async addOrUpdate(
msg: Partial<Message> & {
text: string;
sender: User;
}
) {
const data = {
text: msg.text,
senderId: msg.sender?.id,
};
const message = await kPrisma.message
.upsert({
where: { id: msg.id },
create: data,
update: data,
})
.catch((e) => {
console.error("❌ add msg to db failed", msg, e);
return undefined;
});
if (message) {
// 异步更新记忆
new UserMemory(msg.sender).add(message);
}
return message;
}
}
export const MessageHistory = new _MessageHistory();

View File

@ -0,0 +1,90 @@
import { LongTermMemory, Room, User } from "@prisma/client";
import { kPrisma } from ".";
class _LongTermMemoryCRUD {
async count(options?: { cursorId?: number; room?: Room; owner?: User }) {
const { cursorId, owner, room } = options ?? {};
return kPrisma.longTermMemory
.count({
where: {
id: { gt: cursorId },
roomId: room?.id,
ownerId: owner?.id,
},
})
.catch((e) => {
console.error("❌ get longTermMemory count failed", e);
return -1;
});
}
async gets(options?: {
room?: Room;
owner?: User;
take?: number;
skip?: number;
cursorId?: number;
/**
*
*/
order?: "asc" | "desc";
}) {
const {
room,
owner,
take = 10,
skip = 0,
cursorId,
order = "desc",
} = options ?? {};
const memories = await kPrisma.longTermMemory
.findMany({
where: { roomId: room?.id, ownerId: owner?.id },
take,
skip,
cursor: { id: cursorId },
orderBy: { createdAt: order },
})
.catch((e) => {
console.error("❌ get long term memories failed", options, e);
return [];
});
return order === "desc" ? memories.reverse() : memories;
}
async addOrUpdate(
longTermMemory: Partial<LongTermMemory> & {
text: string;
cursorId: number;
roomId: string;
ownerId?: string;
}
) {
const { text: _text, cursorId, roomId, ownerId } = longTermMemory;
const text = _text?.trim();
const data = {
text,
cursor: {
connect: { id: cursorId },
},
room: {
connect: { id: roomId },
},
owner: {
connect: { id: ownerId },
},
};
return kPrisma.longTermMemory
.upsert({
where: { id: longTermMemory.id },
create: data,
update: data,
})
.catch((e) => {
console.error("❌ add longTermMemory to db failed", longTermMemory, e);
return undefined;
});
}
}
export const LongTermMemoryCRUD = new _LongTermMemoryCRUD();

View File

@ -0,0 +1,94 @@
import { ShortTermMemory, Room, User } from "@prisma/client";
import { kPrisma } from ".";
class _ShortTermMemoryCRUD {
async count(options?: { cursorId?: number; room?: Room; owner?: User }) {
const { cursorId, owner, room } = options ?? {};
return kPrisma.shortTermMemory
.count({
where: {
id: { gt: cursorId },
roomId: room?.id,
ownerId: owner?.id,
},
})
.catch((e) => {
console.error("❌ get shortTermMemory count failed", e);
return -1;
});
}
async gets(options?: {
room?: Room;
owner?: User;
take?: number;
skip?: number;
cursorId?: number;
/**
*
*/
order?: "asc" | "desc";
}) {
const {
room,
owner,
take = 10,
skip = 0,
cursorId,
order = "desc",
} = options ?? {};
const memories = await kPrisma.shortTermMemory
.findMany({
where: { roomId: room?.id, ownerId: owner?.id },
take,
skip,
cursor: { id: cursorId },
orderBy: { createdAt: order },
})
.catch((e) => {
console.error("❌ get short term memories failed", options, e);
return [];
});
return order === "desc" ? memories.reverse() : memories;
}
async addOrUpdate(
shortTermMemory: Partial<ShortTermMemory> & {
text: string;
cursorId: number;
roomId: string;
ownerId?: string;
}
) {
const { text: _text, cursorId, roomId, ownerId } = shortTermMemory;
const text = _text?.trim();
const data = {
text,
cursor: {
connect: { id: cursorId },
},
room: {
connect: { id: roomId },
},
owner: {
connect: { id: ownerId },
},
};
return kPrisma.shortTermMemory
.upsert({
where: { id: shortTermMemory.id },
create: data,
update: data,
})
.catch((e) => {
console.error(
"❌ add shortTermMemory to db failed",
shortTermMemory,
e
);
return undefined;
});
}
}
export const ShortTermMemoryCRUD = new _ShortTermMemoryCRUD();

86
src/services/db/memory.ts Normal file
View File

@ -0,0 +1,86 @@
import { Memory, Room, User } from "@prisma/client";
import { kPrisma } from ".";
class _MemoryCRUD {
async count(options?: { cursorId?: number; room?: Room; owner?: User }) {
const { cursorId, owner, room } = options ?? {};
return kPrisma.memory
.count({
where: {
id: { gt: cursorId },
roomId: room?.id,
ownerId: owner?.id,
},
})
.catch((e) => {
console.error("❌ get memory count failed", e);
return -1;
});
}
async gets(options?: {
room?: Room;
owner?: User;
take?: number;
skip?: number;
cursorId?: number;
/**
*
*/
order?: "asc" | "desc";
}) {
const {
room,
owner,
take = 10,
skip = 0,
cursorId,
order = "desc",
} = options ?? {};
const memories = await kPrisma.memory
.findMany({
where: { roomId: room?.id, ownerId: owner?.id },
take,
skip,
cursor: { id: cursorId },
orderBy: { createdAt: order },
})
.catch((e) => {
console.error("❌ get memories failed", options, e);
return [];
});
return order === "desc" ? memories.reverse() : memories;
}
async addOrUpdate(
memory: Partial<Memory> & {
text: string;
roomId: string;
ownerId?: string;
}
) {
const { text: _text, roomId, ownerId } = memory;
const text = _text?.trim();
const data = {
text,
room: {
connect: { id: roomId },
},
owner: {
connect: { id: ownerId },
},
};
return kPrisma.memory
.upsert({
where: { id: memory.id },
create: data,
update: data,
})
.catch((e) => {
console.error("❌ add memory to db failed", memory, e);
return undefined;
});
}
}
export const MemoryCRUD = new _MemoryCRUD();

View File

@ -0,0 +1,89 @@
import { Message, Prisma, Room, User } from "@prisma/client";
import { kPrisma } from ".";
class _MessageCRUD {
async count(options?: { cursorId?: number; room?: Room; sender?: User }) {
const { cursorId, sender, room } = options ?? {};
return kPrisma.message
.count({
where: {
id: { gt: cursorId },
roomId: room?.id,
senderId: sender?.id,
},
})
.catch((e) => {
console.error("❌ get message count failed", e);
return -1;
});
}
async gets(options?: {
room?: Room;
sender?: User;
take?: number;
skip?: number;
cursorId?: number;
include?: Prisma.MessageInclude;
/**
*
*/
order?: "asc" | "desc";
}) {
const {
room,
sender,
take = 10,
skip = 0,
cursorId,
include = { sender: true },
order = "desc",
} = options ?? {};
const messages = await kPrisma.message
.findMany({
where: { roomId: room?.id, senderId: sender?.id },
take,
skip,
include,
cursor: { id: cursorId },
orderBy: { createdAt: order },
})
.catch((e) => {
console.error("❌ get messages failed", options, e);
return [];
});
return order === "desc" ? messages.reverse() : messages;
}
async addOrUpdate(
message: Partial<Message> & {
text: string;
roomId: string;
senderId: string;
}
) {
const { text: _text, roomId, senderId } = message;
const text = _text?.trim();
const data = {
text,
room: {
connect: { id: roomId },
},
sender: {
connect: { id: senderId },
},
};
return kPrisma.message
.upsert({
where: { id: message.id },
create: data,
update: data,
})
.catch((e) => {
console.error("❌ add message to db failed", message, e);
return undefined;
});
}
}
export const MessageCRUD = new _MessageCRUD();

86
src/services/db/room.ts Normal file
View File

@ -0,0 +1,86 @@
import { Prisma, Room, User } from "@prisma/client";
import { kPrisma } from ".";
export function getRoomID(users: User[]) {
return users
.map((e) => e.id)
.sort()
.join("_");
}
class _RoomCRUD {
async count(options?: { user?: User }) {
const { user } = options ?? {};
return kPrisma.room
.count({
where: {
members: {
some: {
id: user?.id,
},
},
},
})
.catch((e) => {
console.error("❌ get room count failed", e);
return -1;
});
}
async gets(options?: {
user?: User;
take?: number;
skip?: number;
cursorId?: string;
include?: Prisma.RoomInclude;
/**
*
*/
order?: "asc" | "desc";
}) {
const {
user,
take = 10,
skip = 0,
cursorId,
include = { members: true },
order = "desc",
} = options ?? {};
const rooms = await kPrisma.room
.findMany({
where: { members: { some: { id: user?.id } } },
take,
skip,
cursor: { id: cursorId },
include,
orderBy: { createdAt: order },
})
.catch((e) => {
console.error("❌ get rooms failed", options, e);
return [];
});
return order === "desc" ? rooms.reverse() : rooms;
}
async addOrUpdate(
room: Partial<Room> & {
name: string;
description: string;
}
) {
room.name = room.name.trim();
room.description = room.description.trim();
return kPrisma.room
.upsert({
where: { id: room.id },
create: room,
update: room,
})
.catch((e) => {
console.error("❌ add room to db failed", room, e);
return undefined;
});
}
}
export const RoomCRUD = new _RoomCRUD();

64
src/services/db/user.ts Normal file
View File

@ -0,0 +1,64 @@
import { Prisma, User } from "@prisma/client";
import { kPrisma } from ".";
class _UserCRUD {
async count() {
return kPrisma.user.count().catch((e) => {
console.error("❌ get user count failed", e);
return -1;
});
}
async gets(options?: {
take?: number;
skip?: number;
cursorId?: string;
include?: Prisma.UserInclude;
/**
*
*/
order?: "asc" | "desc";
}) {
const {
take = 10,
skip = 0,
cursorId,
include = { members: true },
order = "desc",
} = options ?? {};
const users = await kPrisma.user
.findMany({
take,
skip,
cursor: { id: cursorId },
orderBy: { createdAt: order },
})
.catch((e) => {
console.error("❌ get users failed", options, e);
return [];
});
return order === "desc" ? users.reverse() : users;
}
async addOrUpdate(
user: Partial<User> & {
name: string;
profile: string;
}
) {
user.name = user.name.trim();
user.profile = user.profile.trim();
return kPrisma.user
.upsert({
where: { id: user.id },
create: user,
update: user,
})
.catch((e) => {
console.error("❌ add user to db failed", user, e);
return undefined;
});
}
}
export const UserCRUD = new _UserCRUD();