Mirror of https://github.com/idootop/mi-gpt.git (synced 2025-04-07 21:39:20 +00:00)

feat: auto abort for updating long/short memory

Parent: 4beac32a2a
Commit: 57a765af1b
@@ -5,6 +5,7 @@ import { MemoryCRUD } from "../../db/memory";
 import { ShortTermMemoryCRUD } from "../../db/memory-short-term";
 import { LongTermMemoryCRUD } from "../../db/memory-long-term";
 import { ShortTermMemoryAgent } from "./short-term";
+import { openai } from "../../openai";
 
 export class MemoryManager {
   private room: Room;
@@ -48,33 +49,57 @@ export class MemoryManager {
     return [];
   }
 
+  private _currentMemory?: Memory;
+
   async addMessage2Memory(message: Message) {
     // todo create memory embedding
-    const res = await MemoryCRUD.addOrUpdate({
+    const currentMemory = await MemoryCRUD.addOrUpdate({
       text: message.text,
       roomId: this.room.id,
       ownerId: message.senderId,
     });
+    if (currentMemory) {
+      this._onMemory(currentMemory);
+    }
+    return currentMemory;
+  }
+
+  private _onMemory(currentMemory: Memory) {
+    if (this._currentMemory) {
+      // Cancel the previous memory update tasks
+      openai.abort(`update-short-memory-${this._currentMemory.id}`);
+      openai.abort(`update-long-memory-${this._currentMemory.id}`);
+    }
+    this._currentMemory = currentMemory;
     // Update long/short-term memory asynchronously
-    this.updateLongShortTermMemory();
-    return res;
+    this.updateLongShortTermMemory({ currentMemory });
   }
 
   /**
    * Update memory (when the number of new memories exceeds the threshold, automatically update long/short-term memory)
    */
-  async updateLongShortTermMemory(options?: {
+  async updateLongShortTermMemory(options: {
+    currentMemory: Memory;
     shortThreshold?: number;
     longThreshold?: number;
   }) {
-    const { shortThreshold, longThreshold } = options ?? {};
-    const success = await this._updateShortTermMemory(shortThreshold);
+    const { currentMemory, shortThreshold, longThreshold } = options ?? {};
+    const success = await this._updateShortTermMemory({
+      currentMemory,
+      threshold: shortThreshold,
+    });
     if (success) {
-      await this._updateLongTermMemory(longThreshold);
+      await this._updateLongTermMemory({
+        currentMemory,
+        threshold: longThreshold,
+      });
     }
   }
 
-  private async _updateShortTermMemory(threshold = 10) {
+  private async _updateShortTermMemory(options: {
+    currentMemory: Memory;
+    threshold?: number;
+  }) {
+    const { currentMemory, threshold = 10 } = options;
     const lastMemory = firstOf(await this.getShortTermMemories(1));
     const newMemories = await MemoryCRUD.gets({
       cursorId: lastMemory?.cursorId,
@@ -85,10 +110,11 @@ export class MemoryManager {
     if (newMemories.length < 1 || newMemories.length < threshold) {
       return true;
     }
-    const newMemory = await ShortTermMemoryAgent.generate(
+    const newMemory = await ShortTermMemoryAgent.generate({
+      currentMemory,
       newMemories,
-      lastMemory
-    );
+      lastMemory,
+    });
     if (!newMemory) {
       return false;
     }
@@ -101,7 +127,11 @@ export class MemoryManager {
     return res != null;
   }
 
-  private async _updateLongTermMemory(threshold = 10) {
+  private async _updateLongTermMemory(options: {
+    currentMemory: Memory;
+    threshold?: number;
+  }) {
+    const { currentMemory, threshold = 10 } = options;
     const lastMemory = firstOf(await this.getLongTermMemories(1));
     const newMemories = await ShortTermMemoryCRUD.gets({
       cursorId: lastMemory?.cursorId,
@@ -112,10 +142,11 @@ export class MemoryManager {
     if (newMemories.length < 1 || newMemories.length < threshold) {
       return true;
     }
-    const newMemory = await LongTermMemoryAgent.generate(
+    const newMemory = await LongTermMemoryAgent.generate({
+      currentMemory,
      newMemories,
-      lastMemory
-    );
+      lastMemory,
+    });
     if (!newMemory) {
       return false;
     }
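The diff above makes MemoryManager debounce its memory summarization: every new message routes through _onMemory, which aborts the LLM update tasks still running for the previous memory (keyed as update-short-memory-<id> and update-long-memory-<id>) before kicking off a new updateLongShortTermMemory pass. A minimal standalone sketch of this cancel-then-restart pattern, with illustrative names that are not part of this codebase:

// Sketch only: AbortController is the standard Node/Web primitive;
// `pending` and `onNewMemory` are illustrative names, not mi-gpt APIs.
const pending = new Map<string, AbortController>();

async function onNewMemory(
  id: string,
  update: (signal: AbortSignal) => Promise<void>
) {
  // Cancel whatever is still running for earlier memories.
  for (const [prevId, controller] of pending) {
    controller.abort();
    pending.delete(prevId);
  }
  const controller = new AbortController();
  pending.set(id, controller);
  try {
    await update(controller.signal); // e.g. the long/short-term memory update
  } finally {
    pending.delete(id);
  }
}
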
@@ -1,13 +1,18 @@
-import { Memory, ShortTermMemory } from "@prisma/client";
+import { LongTermMemory, Memory, ShortTermMemory } from "@prisma/client";
+import { openai } from "../../openai";
 
 export class LongTermMemoryAgent {
   // todo use an LLM to generate the new long-term memory
-  static async generate(
-    newMemories: Memory[],
-    lastLongTermMemory?: ShortTermMemory
-  ): Promise<string | undefined> {
-    return `count: ${newMemories.length}\n${newMemories
-      .map((e, idx) => idx.toString() + ". " + e.text)
-      .join("\n")}`;
+  static async generate(options: {
+    currentMemory: Memory;
+    newMemories: ShortTermMemory[];
+    lastMemory?: LongTermMemory;
+  }): Promise<string | undefined> {
+    const { currentMemory, newMemories, lastMemory } = options;
+    const res = await openai.chat({
+      user: "todo", // todo prompt
+      requestId: `update-long-memory-${currentMemory.id}`,
+    });
+    return res?.content?.trim();
   }
 }
@@ -1,13 +1,18 @@
-import { Memory } from "@prisma/client";
+import { Memory, ShortTermMemory } from "@prisma/client";
+import { openai } from "../../openai";
 
 export class ShortTermMemoryAgent {
   // todo use an LLM to generate the new short-term memory
-  static async generate(
-    newMemories: Memory[],
-    lastShortTermMemory?: Memory
-  ): Promise<string | undefined> {
-    return `count: ${newMemories.length}\n${newMemories
-      .map((e, idx) => idx.toString() + ". " + e.text)
-      .join("\n")}`;
+  static async generate(options: {
+    currentMemory: Memory;
+    newMemories: Memory[];
+    lastMemory?: ShortTermMemory;
+  }): Promise<string | undefined> {
+    const { currentMemory, newMemories, lastMemory } = options;
+    const res = await openai.chat({
+      user: "todo", // todo prompt
+      requestId: `update-short-memory-${currentMemory.id}`,
+    });
+    return res?.content?.trim();
   }
 }
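Both agents now take a single options object and tag their openai.chat call with a deterministic requestId (update-long-memory-<id> / update-short-memory-<id>), which is what lets MemoryManager._onMemory cancel them when a newer memory arrives. A call site, as already shown in the MemoryManager diff above:

// Values come from MemoryManager._updateShortTermMemory.
const newMemory = await ShortTermMemoryAgent.generate({
  currentMemory, // the memory that triggered this update
  newMemories,   // memories accumulated since the last summary
  lastMemory,    // the previous short-term summary, if any
});
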
@@ -12,6 +12,7 @@ export interface ChatOptions {
   system?: string;
   tools?: Array<ChatCompletionTool>;
   jsonMode?: boolean;
+  requestId?: string;
 }
 
 class OpenAIClient {
@@ -32,17 +33,26 @@ class OpenAIClient {
   }
 
   async chat(options: ChatOptions) {
-    const { user, system, tools, jsonMode } = options;
+    let { user, system, tools, jsonMode, requestId } = options;
     const systemMsg: ChatCompletionMessageParam[] = system
       ? [{ role: "system", content: system }]
       : [];
+    let signal: AbortSignal | undefined;
+    if (requestId) {
+      const controller = new AbortController();
+      this._abortCallbacks[requestId] = () => controller.abort();
+      signal = controller.signal;
+    }
     const chatCompletion = await this._client.chat.completions
-      .create({
-        tools,
-        messages: [...systemMsg, { role: "user", content: user }],
-        model: kEnvs.OPENAI_MODEL ?? "gpt-3.5-turbo-0125",
-        response_format: jsonMode ? { type: "json_object" } : undefined,
-      })
+      .create(
+        {
+          tools,
+          messages: [...systemMsg, { role: "user", content: user }],
+          model: kEnvs.OPENAI_MODEL ?? "gpt-3.5-turbo-0125",
+          response_format: jsonMode ? { type: "json_object" } : undefined,
+        },
+        { signal }
+      )
       .catch((e) => {
         console.error("❌ openai chat failed", e);
         return null;
@@ -52,11 +62,10 @@ class OpenAIClient {
 
   async chatStream(
     options: ChatOptions & {
-      requestId?: string;
       onStream?: (text: string) => void;
     }
   ) {
-    const { user, system, tools, jsonMode, onStream, requestId } = options;
+    let { user, system, tools, jsonMode, requestId, onStream } = options;
     const systemMsg: ChatCompletionMessageParam[] = system
       ? [{ role: "system", content: system }]
       : [];
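With a requestId supplied, chat() now registers an abort callback before issuing the request and passes the controller's signal as the second argument to completions.create (the OpenAI Node SDK accepts per-request options, including signal, in that position). The matching abort entry point is not shown in this diff; a sketch consistent with the _abortCallbacks usage above, assuming the field is a plain callback map:

// Assumed shape; not part of this diff.
private _abortCallbacks: Record<string, VoidFunction> = {};

// Presumably invoked as openai.abort(`update-short-memory-${id}`).
abort(requestId: string) {
  this._abortCallbacks[requestId]?.();
  delete this._abortCallbacks[requestId];
}
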
@@ -12,9 +12,9 @@ dotenv.config();
 async function main() {
   println(kBannerASCII);
   // testDB();
-  testSpeaker();
+  // testSpeaker();
   // testOpenAI();
-  // testMyBot();
+  testMyBot();
 }
 
 runWithDB(main);