refactor: 调用封装后函数

This commit is contained in:
roitium 2024-07-24 10:18:57 +08:00
parent 208b2db60b
commit 11dc25aa60
3 changed files with 10 additions and 17 deletions

View File

@ -1,5 +1,6 @@
import { Memory, Message, Room, User } from "@prisma/client"; import { Memory, Message, Room, User } from "@prisma/client";
import { firstOf, lastOf } from "../../../utils/base"; import { firstOf, lastOf } from "../../../utils/base";
import { Logger } from "../../../utils/log";
import { MemoryCRUD } from "../../db/memory"; import { MemoryCRUD } from "../../db/memory";
import { LongTermMemoryCRUD } from "../../db/memory-long-term"; import { LongTermMemoryCRUD } from "../../db/memory-long-term";
import { ShortTermMemoryCRUD } from "../../db/memory-short-term"; import { ShortTermMemoryCRUD } from "../../db/memory-short-term";
@ -7,9 +8,7 @@ import { openai } from "../../openai";
import { MessageContext } from "../conversation"; import { MessageContext } from "../conversation";
import { LongTermMemoryAgent } from "./long-term"; import { LongTermMemoryAgent } from "./long-term";
import { ShortTermMemoryAgent } from "./short-term"; import { ShortTermMemoryAgent } from "./short-term";
import {Logger} from "../../../utils/log";
export const memoryLogger = Logger.create({ tag: "Memory" });
export class MemoryManager { export class MemoryManager {
private room: Room; private room: Room;
@ -17,6 +16,7 @@ export class MemoryManager {
* owner * owner
*/ */
private owner?: User; private owner?: User;
private _logger = Logger.create({ tag: "Memory" });
constructor(room: Room, owner?: User) { constructor(room: Room, owner?: User) {
this.room = room; this.room = room;
@ -100,7 +100,7 @@ export class MemoryManager {
threshold?: number; threshold?: number;
} }
) { ) {
const { threshold = 10 } = options; const { threshold = 1 } = options;
const lastMemory = firstOf(await this.getShortTermMemories({ take: 1 })); const lastMemory = firstOf(await this.getShortTermMemories({ take: 1 }));
const newMemories: (Memory & { const newMemories: (Memory & {
msg: Message & { msg: Message & {
@ -120,7 +120,7 @@ export class MemoryManager {
lastMemory, lastMemory,
}); });
if (!newMemory) { if (!newMemory) {
memoryLogger.error("💀 生成短期记忆失败"); this._logger.error("💀 生成短期记忆失败");
return false; return false;
} }
const res = await ShortTermMemoryCRUD.addOrUpdate({ const res = await ShortTermMemoryCRUD.addOrUpdate({
@ -154,7 +154,7 @@ export class MemoryManager {
lastMemory, lastMemory,
}); });
if (!newMemory) { if (!newMemory) {
memoryLogger.error("💀 生成长期记忆失败"); this._logger.error("💀 生成长期记忆失败");
return false; return false;
} }
const res = await LongTermMemoryCRUD.addOrUpdate({ const res = await LongTermMemoryCRUD.addOrUpdate({

View File

@ -1,6 +1,7 @@
import { LongTermMemory, ShortTermMemory } from "@prisma/client"; import { LongTermMemory, ShortTermMemory } from "@prisma/client";
import { jsonDecode, lastOf } from "../../../utils/base"; import { lastOf } from "../../../utils/base";
import { buildPrompt } from "../../../utils/string"; import { buildPrompt } from "../../../utils/string";
import { cleanJsonAndDecode } from "../../../utils/parse";
import { openai } from "../../openai"; import { openai } from "../../openai";
import { MessageContext } from "../conversation"; import { MessageContext } from "../conversation";
@ -67,10 +68,6 @@ export class LongTermMemoryAgent {
shortTermMemory: lastOf(newMemories)!.text, shortTermMemory: lastOf(newMemories)!.text,
}), }),
}); });
// 如果返回内容是个markdown代码块,就让他变回普通json return cleanJsonAndDecode(res?.content)?.longTermMemories?.toString();
res?.content?.trim();
if (res?.content?.startsWith("```json")) {res.content = res?.content?.replace("```json", "");}
if (res?.content?.endsWith("```")) {res.content = res?.content?.replace("```", "");}
return jsonDecode(res?.content)?.longTermMemories?.toString();
} }
} }

View File

@ -1,5 +1,5 @@
import { Memory, Message, ShortTermMemory, User } from "@prisma/client"; import { Memory, Message, ShortTermMemory, User } from "@prisma/client";
import { jsonDecode } from "../../../utils/base"; import { cleanJsonAndDecode } from "../../../utils/parse";
import { buildPrompt, formatMsg } from "../../../utils/string"; import { buildPrompt, formatMsg } from "../../../utils/string";
import { openai } from "../../openai"; import { openai } from "../../openai";
import { MessageContext } from "../conversation"; import { MessageContext } from "../conversation";
@ -78,10 +78,6 @@ export class ShortTermMemoryAgent {
.join("\n"), .join("\n"),
}), }),
}); });
// 如果返回内容是个markdown代码块,就让他变回普通json return cleanJsonAndDecode(res?.content)?.shortTermMemories?.toString();
res?.content?.trim();
if (res?.content?.startsWith("```json")) {res.content = res?.content?.replace("```json", "");}
if (res?.content?.endsWith("```")) {res.content = res?.content?.replace("```", "");}
return jsonDecode(res?.content)?.shortTermMemories?.toString();
} }
} }