Mirror of https://github.com/idootop/mi-gpt.git (synced 2025-04-07 18:43:08 +00:00)

feat: auto abort for updating long/short memory
parent 4beac32a2a
commit 57a765af1b
@@ -5,6 +5,7 @@ import { MemoryCRUD } from "../../db/memory";
 import { ShortTermMemoryCRUD } from "../../db/memory-short-term";
 import { LongTermMemoryCRUD } from "../../db/memory-long-term";
 import { ShortTermMemoryAgent } from "./short-term";
+import { openai } from "../../openai";
 
 export class MemoryManager {
   private room: Room;
@@ -48,33 +49,57 @@ export class MemoryManager {
     return [];
   }
 
+  private _currentMemory?: Memory;
   async addMessage2Memory(message: Message) {
     // todo create memory embedding
-    const res = await MemoryCRUD.addOrUpdate({
+    const currentMemory = await MemoryCRUD.addOrUpdate({
       text: message.text,
       roomId: this.room.id,
       ownerId: message.senderId,
     });
+    if (currentMemory) {
+      this._onMemory(currentMemory);
+    }
+    return currentMemory;
+  }
+
+  private _onMemory(currentMemory: Memory) {
+    if (this._currentMemory) {
+      // cancel the previous memory-update tasks
+      openai.abort(`update-short-memory-${this._currentMemory.id}`);
+      openai.abort(`update-long-memory-${this._currentMemory.id}`);
+    }
+    this._currentMemory = currentMemory;
     // asynchronously update the long/short-term memory
-    this.updateLongShortTermMemory();
-    return res;
+    this.updateLongShortTermMemory({ currentMemory });
   }
 
   /**
    * Update memory (when the number of new memories exceeds the threshold, automatically update the long/short-term memory)
    */
-  async updateLongShortTermMemory(options?: {
+  async updateLongShortTermMemory(options: {
+    currentMemory: Memory;
     shortThreshold?: number;
     longThreshold?: number;
   }) {
-    const { shortThreshold, longThreshold } = options ?? {};
-    const success = await this._updateShortTermMemory(shortThreshold);
+    const { currentMemory, shortThreshold, longThreshold } = options ?? {};
+    const success = await this._updateShortTermMemory({
+      currentMemory,
+      threshold: shortThreshold,
+    });
     if (success) {
-      await this._updateLongTermMemory(longThreshold);
+      await this._updateLongTermMemory({
+        currentMemory,
+        threshold: longThreshold,
+      });
     }
   }
 
-  private async _updateShortTermMemory(threshold = 10) {
+  private async _updateShortTermMemory(options: {
+    currentMemory: Memory;
+    threshold?: number;
+  }) {
+    const { currentMemory, threshold = 10 } = options;
     const lastMemory = firstOf(await this.getShortTermMemories(1));
     const newMemories = await MemoryCRUD.gets({
       cursorId: lastMemory?.cursorId,
@@ -85,10 +110,11 @@ export class MemoryManager {
     if (newMemories.length < 1 || newMemories.length < threshold) {
       return true;
     }
-    const newMemory = await ShortTermMemoryAgent.generate(
+    const newMemory = await ShortTermMemoryAgent.generate({
+      currentMemory,
       newMemories,
-      lastMemory
-    );
+      lastMemory,
+    });
     if (!newMemory) {
       return false;
     }
@@ -101,7 +127,11 @@ export class MemoryManager {
     return res != null;
   }
 
-  private async _updateLongTermMemory(threshold = 10) {
+  private async _updateLongTermMemory(options: {
+    currentMemory: Memory;
+    threshold?: number;
+  }) {
+    const { currentMemory, threshold = 10 } = options;
     const lastMemory = firstOf(await this.getLongTermMemories(1));
     const newMemories = await ShortTermMemoryCRUD.gets({
       cursorId: lastMemory?.cursorId,
@@ -112,10 +142,11 @@ export class MemoryManager {
     if (newMemories.length < 1 || newMemories.length < threshold) {
      return true;
     }
-    const newMemory = await LongTermMemoryAgent.generate(
+    const newMemory = await LongTermMemoryAgent.generate({
+      currentMemory,
       newMemories,
-      lastMemory
-    );
+      lastMemory,
+    });
     if (!newMemory) {
       return false;
     }
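
The heart of the change is _onMemory: every new message that lands in memory first aborts whatever LLM requests are still running for the previous memory (tagged update-short-memory-<id> and update-long-memory-<id>) and only then kicks off the next update round. Below is a minimal, self-contained sketch of that supersede pattern; AbortRegistry and slowTask are hypothetical names used for illustration only and are not part of the repository.

// Minimal sketch (hypothetical names): each update run registers an
// AbortController under a request id derived from the memory id; a newer
// memory aborts whatever the previous memory left running.
class AbortRegistry {
  private callbacks = new Map<string, () => void>();

  register(requestId: string): AbortSignal {
    const controller = new AbortController();
    this.callbacks.set(requestId, () => controller.abort());
    return controller.signal;
  }

  abort(requestId: string) {
    this.callbacks.get(requestId)?.(); // no-op if the request already finished
    this.callbacks.delete(requestId);
  }
}

// Stand-in for a slow LLM call: resolves after ms, rejects if aborted first.
function slowTask(ms: number, signal: AbortSignal): Promise<string> {
  return new Promise((resolve, reject) => {
    const timer = setTimeout(() => resolve("done"), ms);
    signal.addEventListener("abort", () => {
      clearTimeout(timer);
      reject(new Error("aborted"));
    });
  });
}

const registry = new AbortRegistry();

// Memory 1 arrives and starts its (slow) short-term update.
slowTask(5_000, registry.register("update-short-memory-1")).catch((e) =>
  console.log("memory 1 update:", e.message) // "aborted"
);

// Memory 2 arrives before memory 1 finished: cancel the stale work first.
registry.abort("update-short-memory-1");
slowTask(10, registry.register("update-short-memory-2")).then((r) =>
  console.log("memory 2 update:", r) // "done"
);

Keying the abort on the request id means a request that already finished is a harmless no-op, while an in-flight one rejects immediately instead of overwriting newer state.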
@@ -1,13 +1,18 @@
-import { Memory, ShortTermMemory } from "@prisma/client";
+import { LongTermMemory, Memory, ShortTermMemory } from "@prisma/client";
+import { openai } from "../../openai";
 
 export class LongTermMemoryAgent {
   // todo: use an LLM to generate the new long-term memory
-  static async generate(
-    newMemories: Memory[],
-    lastLongTermMemory?: ShortTermMemory
-  ): Promise<string | undefined> {
-    return `count: ${newMemories.length}\n${newMemories
-      .map((e, idx) => idx.toString() + ". " + e.text)
-      .join("\n")}`;
+  static async generate(options: {
+    currentMemory: Memory;
+    newMemories: ShortTermMemory[];
+    lastMemory?: LongTermMemory;
+  }): Promise<string | undefined> {
+    const { currentMemory, newMemories, lastMemory } = options;
+    const res = await openai.chat({
+      user: "todo", // todo prompt
+      requestId: `update-long-memory-${currentMemory.id}`,
+    });
+    return res?.content?.trim();
   }
 }
@@ -1,13 +1,18 @@
-import { Memory } from "@prisma/client";
+import { Memory, ShortTermMemory } from "@prisma/client";
+import { openai } from "../../openai";
 
 export class ShortTermMemoryAgent {
   // todo: use an LLM to generate the new short-term memory
-  static async generate(
-    newMemories: Memory[],
-    lastShortTermMemory?: Memory
-  ): Promise<string | undefined> {
-    return `count: ${newMemories.length}\n${newMemories
-      .map((e, idx) => idx.toString() + ". " + e.text)
-      .join("\n")}`;
+  static async generate(options: {
+    currentMemory: Memory;
+    newMemories: Memory[];
+    lastMemory?: ShortTermMemory;
+  }): Promise<string | undefined> {
+    const { currentMemory, newMemories, lastMemory } = options;
+    const res = await openai.chat({
+      user: "todo", // todo prompt
+      requestId: `update-short-memory-${currentMemory.id}`,
+    });
+    return res?.content?.trim();
   }
 }
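
Both agents now funnel their LLM call through openai.chat with a requestId derived from the current memory id, which is exactly the id MemoryManager._onMemory aborts when a newer memory arrives. When that happens, chat's .catch resolves to null, generate therefore resolves to undefined, and the short/long-term update methods return false without persisting anything. A small sketch of that propagation, with generate and updateOnce as simplified, hypothetical stand-ins for the real methods:

// Illustrative sketch only (hypothetical helper names).
type ChatResult = { content?: string } | null;

async function generate(chat: () => Promise<ChatResult>): Promise<string | undefined> {
  const res = await chat();     // null when the request was aborted
  return res?.content?.trim();  // -> undefined on abort
}

async function updateOnce(chat: () => Promise<ChatResult>): Promise<boolean> {
  const newMemory = await generate(chat);
  if (!newMemory) {
    return false;               // nothing is written for an aborted run
  }
  // ...persist newMemory here...
  return true;
}

// Example: an aborted chat yields null, so the update round is skipped.
updateOnce(async () => null).then((ok) => console.log("persisted:", ok)); // false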
@@ -12,6 +12,7 @@ export interface ChatOptions {
   system?: string;
   tools?: Array<ChatCompletionTool>;
   jsonMode?: boolean;
+  requestId?: string;
 }
 
 class OpenAIClient {
@@ -32,17 +33,26 @@ class OpenAIClient {
   }
 
   async chat(options: ChatOptions) {
-    const { user, system, tools, jsonMode } = options;
+    let { user, system, tools, jsonMode, requestId } = options;
     const systemMsg: ChatCompletionMessageParam[] = system
       ? [{ role: "system", content: system }]
       : [];
+    let signal: AbortSignal | undefined;
+    if (requestId) {
+      const controller = new AbortController();
+      this._abortCallbacks[requestId] = () => controller.abort();
+      signal = controller.signal;
+    }
     const chatCompletion = await this._client.chat.completions
-      .create({
-        tools,
-        messages: [...systemMsg, { role: "user", content: user }],
-        model: kEnvs.OPENAI_MODEL ?? "gpt-3.5-turbo-0125",
-        response_format: jsonMode ? { type: "json_object" } : undefined,
-      })
+      .create(
+        {
+          tools,
+          messages: [...systemMsg, { role: "user", content: user }],
+          model: kEnvs.OPENAI_MODEL ?? "gpt-3.5-turbo-0125",
+          response_format: jsonMode ? { type: "json_object" } : undefined,
+        },
+        { signal }
+      )
       .catch((e) => {
         console.error("❌ openai chat failed", e);
         return null;
@@ -52,11 +62,10 @@ class OpenAIClient {
 
   async chatStream(
     options: ChatOptions & {
-      requestId?: string;
       onStream?: (text: string) => void;
     }
   ) {
-    const { user, system, tools, jsonMode, onStream, requestId } = options;
+    let { user, system, tools, jsonMode, requestId, onStream } = options;
     const systemMsg: ChatCompletionMessageParam[] = system
       ? [{ role: "system", content: system }]
       : [];
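
chat now registers an AbortController under the caller-supplied requestId and passes its signal as the second, per-request options argument of the SDK's completions.create call, so an abort rejects the pending request and the existing .catch turns it into null. The _abortCallbacks map and the abort(requestId) method that MemoryManager calls are not shown in this diff; the sketch below is only an assumption about how that side plausibly looks, including cleanup that the hunk above does not show.

// Assumed shape of the abort side of OpenAIClient (not part of this diff).
class OpenAIClientAbortSketch {
  // requestId -> callback that aborts the in-flight request
  private _abortCallbacks: Record<string, () => void> = {};

  // Called by MemoryManager._onMemory, e.g. abort(`update-short-memory-${id}`).
  abort(requestId: string) {
    this._abortCallbacks[requestId]?.();
    delete this._abortCallbacks[requestId];
  }

  // Mirror of the registration done inside chat().
  protected registerAbort(requestId: string): AbortSignal {
    const controller = new AbortController();
    this._abortCallbacks[requestId] = () => controller.abort();
    return controller.signal;
  }

  // Assumed cleanup once a request settles, so the map does not grow unbounded.
  protected clearAbort(requestId: string) {
    delete this._abortCallbacks[requestId];
  }
}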
@@ -12,9 +12,9 @@ dotenv.config();
 async function main() {
   println(kBannerASCII);
   // testDB();
-  testSpeaker();
+  // testSpeaker();
   // testOpenAI();
-  // testMyBot();
+  testMyBot();
 }
 
 runWithDB(main);