// Snapshot of AIclinicalresearch/backend/src/services/conversationService.ts
// (2025-10-10 20:33:18 +08:00, 385 lines, 8.7 KiB, TypeScript).
// The hosting UI warned that this file contains Unicode characters that might
// be confused with other characters; if that is intentional, it can be ignored.
import { prisma } from '../config/database.js';
import { LLMFactory } from '../adapters/LLMFactory.js';
import { Message, ModelType, StreamChunk } from '../adapters/types.js';
import { agentService } from './agentService.js';
/** Input accepted by `ConversationService.createConversation`. */
type CreateConversationData = {
  /** Id of the user who will own the conversation. */
  userId: string;
  /** Project the conversation is attached to; it must belong to `userId`. */
  projectId: string;
  /** Id of the agent that answers in this conversation. */
  agentId: string;
  /** Optional display title; a default derived from the agent name is used when absent or empty. */
  title?: string;
};
/** Input accepted by `ConversationService.sendMessage` and `sendMessageStream`. */
type SendMessageData = {
  /** Conversation the message is posted to. */
  conversationId: string;
  /** Raw user text, stored as-is and rendered into the agent's prompt template. */
  content: string;
  /** Which LLM adapter / per-agent model configuration to use. */
  modelType: ModelType;
  /** Knowledge bases the user @-mentioned; a non-empty list enables knowledge-base context. */
  knowledgeBaseIds?: string[];
};
/**
 * Service layer for conversations: creation, listing, retrieval, soft
 * deletion, and sending user messages to an LLM (blocking or streaming).
 *
 * Every query is scoped to `userId`, and soft-deleted rows (`deletedAt` set)
 * are treated as non-existent.
 */
export class ConversationService {
  /** Number of most-recent stored messages replayed to the model as history. */
  private static readonly HISTORY_LIMIT = 10;

  /**
   * Create a new conversation bound to an agent and one of the user's projects.
   *
   * @param data - Owner, project, agent and optional title. A missing or empty
   *   title defaults to `<agent name>的对话`.
   * @returns The created conversation row.
   * @throws Error('智能体不存在') when the agent id is unknown.
   * @throws Error('项目不存在或无权访问') when the project is missing,
   *   soft-deleted, or owned by another user.
   */
  async createConversation(data: CreateConversationData) {
    const { userId, projectId, agentId, title } = data;

    // Make sure the agent exists.
    const agent = agentService.getAgentById(agentId);
    if (!agent) {
      throw new Error('智能体不存在');
    }

    // Make sure the project exists and belongs to this user.
    const project = await prisma.project.findFirst({
      where: { id: projectId, userId, deletedAt: null },
    });
    if (!project) {
      throw new Error('项目不存在或无权访问');
    }

    return prisma.conversation.create({
      data: {
        userId,
        projectId,
        agentId,
        // `||` rather than `??` on purpose: an empty title also gets the default.
        title: title || `${agent.name}的对话`,
        metadata: {
          agentName: agent.name,
          agentCategory: agent.category,
        },
      },
    });
  }

  /**
   * List the user's non-deleted conversations, most recently updated first.
   *
   * @param userId - Owner whose conversations are listed.
   * @param projectId - Optional project filter; ignored when absent or empty.
   * @returns Conversations with their project's id/name and a message count.
   */
  async getConversations(userId: string, projectId?: string) {
    return prisma.conversation.findMany({
      where: {
        userId,
        deletedAt: null,
        // Only narrow to a project when a non-empty id was supplied.
        ...(projectId ? { projectId } : {}),
      },
      include: {
        project: { select: { id: true, name: true } },
        _count: { select: { messages: true } },
      },
      orderBy: { updatedAt: 'desc' },
    });
  }

  /**
   * Fetch one conversation with its project summary and all of its messages
   * (oldest first).
   *
   * @throws Error('对话不存在或无权访问') when the conversation is missing,
   *   soft-deleted, or owned by another user.
   */
  async getConversationById(conversationId: string, userId: string) {
    const conversation = await prisma.conversation.findFirst({
      where: { id: conversationId, userId, deletedAt: null },
      include: {
        project: {
          select: {
            id: true,
            name: true,
            background: true,
            researchType: true,
          },
        },
        messages: { orderBy: { createdAt: 'asc' } },
      },
    });
    if (!conversation) {
      throw new Error('对话不存在或无权访问');
    }
    return conversation;
  }

  /**
   * Build the message list sent to the LLM: the agent's system prompt, up to
   * `HISTORY_LIMIT` most recent stored messages (oldest first), then the
   * current input rendered through the agent's user-prompt template.
   *
   * Must run before the current user message is persisted, otherwise that
   * message would appear twice (raw in history and rendered at the end).
   */
  private async assembleContext(
    conversationId: string,
    agentId: string,
    projectBackground: string,
    userInput: string,
    knowledgeBaseContext?: string
  ): Promise<Message[]> {
    const systemPrompt = agentService.getSystemPrompt(agentId);

    // Take the newest N messages, then flip them into chronological order.
    const history = await prisma.message.findMany({
      where: { conversationId },
      orderBy: { createdAt: 'desc' },
      take: ConversationService.HISTORY_LIMIT,
    });
    history.reverse();

    const renderedUserPrompt = agentService.renderUserPrompt(agentId, {
      projectBackground,
      userInput,
      knowledgeBaseContext,
    });

    return [
      { role: 'system', content: systemPrompt },
      // NOTE(review): stored roles are assumed to be 'user' | 'assistant' only.
      ...history.map(
        (msg): Message => ({
          role: msg.role as 'user' | 'assistant',
          content: msg.content,
        })
      ),
      { role: 'user', content: renderedUserPrompt },
    ];
  }

  /**
   * Knowledge-base context for the @-mentioned knowledge bases.
   * Placeholder: returns a fixed stub until Dify RAG retrieval is wired in.
   */
  private getKnowledgeBaseContext(knowledgeBaseIds?: string[]): string {
    if (!knowledgeBaseIds || knowledgeBaseIds.length === 0) {
      return '';
    }
    // TODO: call Dify RAG to retrieve the real knowledge-base context.
    return '相关文献内容...';
  }

  /**
   * Setup shared by both send paths. It checks access, assembles the prompt
   * context, and resolves the adapter plus the agent's per-model parameters.
   *
   * @throws Error('对话不存在或无权访问') when the conversation is missing,
   *   soft-deleted, or owned by another user.
   */
  private async prepareChat(data: SendMessageData, userId: string) {
    const { conversationId, content, modelType, knowledgeBaseIds } = data;

    // Only the agent id and project background are needed here, so skip
    // getConversationById(), which would also load the full message history.
    const conversation = await prisma.conversation.findFirst({
      where: { id: conversationId, userId, deletedAt: null },
      include: { project: { select: { background: true } } },
    });
    if (!conversation) {
      throw new Error('对话不存在或无权访问');
    }

    const knowledgeBaseContext = this.getKnowledgeBaseContext(knowledgeBaseIds);

    const messages = await this.assembleContext(
      conversationId,
      conversation.agentId,
      conversation.project?.background || '',
      content,
      knowledgeBaseContext
    );

    const adapter = LLMFactory.getAdapter(modelType);

    // Per-agent model parameters; unset values are passed through as undefined.
    const agent = agentService.getAgentById(conversation.agentId);
    const modelConfig = agent?.models?.[modelType];

    return {
      messages,
      adapter,
      chatOptions: {
        temperature: modelConfig?.temperature,
        maxTokens: modelConfig?.maxTokens,
        topP: modelConfig?.topP,
      },
    };
  }

  /** Bump the conversation's `updatedAt` so it sorts first in listings. */
  private async touchConversation(conversationId: string): Promise<void> {
    await prisma.conversation.update({
      where: { id: conversationId },
      data: { updatedAt: new Date() },
    });
  }

  /**
   * Send a message and wait for the complete LLM reply.
   *
   * The user message and the assistant reply are persisted only after the
   * model call succeeds, so a failed call leaves no trace in the database.
   *
   * @returns The stored user and assistant messages plus the token usage.
   * @throws Error('对话不存在或无权访问') when the conversation is not accessible.
   */
  async sendMessage(data: SendMessageData, userId: string) {
    const { conversationId, content, knowledgeBaseIds } = data;
    const { messages, adapter, chatOptions } = await this.prepareChat(data, userId);

    const response = await adapter.chat(messages, chatOptions);

    // Persist sequentially so the user message gets the earlier createdAt.
    const userMessage = await prisma.message.create({
      data: {
        conversationId,
        role: 'user',
        content,
        metadata: { knowledgeBaseIds },
      },
    });

    const assistantMessage = await prisma.message.create({
      data: {
        conversationId,
        role: 'assistant',
        content: response.content,
        model: response.model,
        tokens: response.usage?.totalTokens,
        metadata: {
          usage: response.usage,
          finishReason: response.finishReason,
        },
      },
    });

    await this.touchConversation(conversationId);

    return {
      userMessage,
      assistantMessage,
      usage: response.usage,
    };
  }

  /**
   * Send a message and yield the LLM reply chunk by chunk.
   *
   * The user message is persisted before the model is called. The assistant
   * reply and the conversation timestamp are written only after the stream
   * has been fully consumed. If the adapter throws, or the consumer stops
   * iterating early, the user message is left without a stored reply.
   *
   * @throws Error('对话不存在或无权访问') when the conversation is not accessible.
   */
  async *sendMessageStream(
    data: SendMessageData,
    userId: string
  ): AsyncGenerator<StreamChunk, void, unknown> {
    const { conversationId, content, modelType, knowledgeBaseIds } = data;
    const { messages, adapter, chatOptions } = await this.prepareChat(data, userId);

    await prisma.message.create({
      data: {
        conversationId,
        role: 'user',
        content,
        metadata: { knowledgeBaseIds },
      },
    });

    // Accumulate the full reply text and the last usage report seen.
    let fullContent = '';
    let usage: StreamChunk['usage'] | null = null;

    for await (const chunk of adapter.chatStream(messages, chatOptions)) {
      // Guard in case a chunk carries no text (e.g. a usage-only chunk).
      fullContent += chunk.content ?? '';
      if (chunk.usage) {
        usage = chunk.usage;
      }
      yield chunk;
    }

    await prisma.message.create({
      data: {
        conversationId,
        role: 'assistant',
        content: fullContent,
        // NOTE(review): the blocking path stores `response.model`; this stores
        // the requested modelType — confirm whether chunks expose the model name.
        model: modelType,
        tokens: usage?.totalTokens,
        metadata: { usage },
      },
    });

    await this.touchConversation(conversationId);
  }

  /**
   * Soft-delete a conversation by setting `deletedAt`.
   *
   * The ownership check and the write are a single atomic `updateMany`, so a
   * concurrent delete cannot slip in between a separate check and update.
   *
   * @returns `{ success: true }` when a row was marked deleted.
   * @throws Error('对话不存在或无权访问') when no live conversation with this id
   *   belongs to the user.
   */
  async deleteConversation(conversationId: string, userId: string) {
    const { count } = await prisma.conversation.updateMany({
      where: { id: conversationId, userId, deletedAt: null },
      data: { deletedAt: new Date() },
    });
    if (count === 0) {
      throw new Error('对话不存在或无权访问');
    }
    return { success: true };
  }
}
/** Module-level `ConversationService` instance shared by every importer of this module. */
export const conversationService = new ConversationService();