feat(aia): Implement Protocol Agent MVP with reusable Agent framework
Sprint 1-3 Completed (Backend + Frontend): Backend (Sprint 1-2): - Implement 5-layer Agent framework (Query->Planner->Executor->Tools->Reflection) - Create agent_schema with 6 tables (agent_definitions, stages, prompts, sessions, traces, reflexion_rules) - Create protocol_schema with 2 tables (protocol_contexts, protocol_generations) - Implement Protocol Agent core services (Orchestrator, ContextService, PromptBuilder) - Integrate LLM service adapter (DeepSeek/Qwen/GPT-5/Claude) - 6 API endpoints with full authentication - 10/10 API tests passed Frontend (Sprint 3): - Add Protocol Agent entry in AgentHub (indigo theme card) - Implement ProtocolAgentPage with 3-column layout - Collapsible sidebar (Gemini style, 48px <-> 280px) - StatePanel with 5 stage cards (scientific_question, pico, study_design, sample_size, endpoints) - ChatArea with sync button and action cards integration - 100% prototype design restoration (608 lines CSS) - Detailed endpoints structure: baseline, exposure, outcomes, confounders Features: - 5-stage dialogue flow for research protocol design - Conversation-driven interaction with sync-to-protocol button - Real-time context state management - One-click protocol generation button (UI ready, backend pending) Database: - agent_schema: 6 tables for reusable Agent framework - protocol_schema: 2 tables for Protocol Agent - Seed data: 1 agent + 5 stages + 9 prompts + 4 reflexion rules Code Stats: - Backend: 13 files, 4338 lines - Frontend: 14 files, 2071 lines - Total: 27 files, 6409 lines Status: MVP core functionality completed, pending frontend-backend integration testing Next: Sprint 4 - One-click protocol generation + Word export
This commit is contained in:
@@ -0,0 +1,287 @@
|
||||
/**
|
||||
* Protocol Agent Controller
|
||||
* 处理Protocol Agent的HTTP请求
|
||||
*
|
||||
* @module agent/protocol/controllers/ProtocolAgentController
|
||||
*/
|
||||
|
||||
import { FastifyRequest, FastifyReply } from 'fastify';
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
import { ProtocolOrchestrator } from '../services/ProtocolOrchestrator.js';
|
||||
import { LLMServiceInterface } from '../../services/BaseAgentOrchestrator.js';
|
||||
import { ProtocolStageCode } from '../../types/index.js';
|
||||
|
||||
// Request payload type definitions for the Protocol Agent endpoints.

/** Body of POST /message — one user chat turn addressed to the agent. */
interface SendMessageBody {
  conversationId: string;
  content: string;
  // Optional client-supplied message id — presumably for idempotency/correlation; TODO confirm against frontend usage.
  messageId?: string;
}

/** Body of POST /sync — confirmed data payload for one protocol stage. */
interface SyncDataBody {
  conversationId: string;
  stageCode: ProtocolStageCode;
  data: Record<string, unknown>;
}

/** Body of POST /generate — request to assemble the full research protocol. */
interface GenerateProtocolBody {
  conversationId: string;
  // Optional generation tuning: which sections to include and writing style.
  options?: {
    sections?: string[];
    style?: 'academic' | 'concise';
  };
}

/** Route params of GET /context/:conversationId. */
interface GetContextParams {
  conversationId: string;
}
|
||||
|
||||
export class ProtocolAgentController {
|
||||
private orchestrator: ProtocolOrchestrator;
|
||||
|
||||
constructor(prisma: PrismaClient, llmService: LLMServiceInterface) {
|
||||
this.orchestrator = new ProtocolOrchestrator({ prisma, llmService });
|
||||
}
|
||||
|
||||
/**
|
||||
* 发送消息
|
||||
* POST /api/aia/protocol-agent/message
|
||||
*/
|
||||
async sendMessage(
|
||||
request: FastifyRequest<{ Body: SendMessageBody }>,
|
||||
reply: FastifyReply
|
||||
): Promise<void> {
|
||||
try {
|
||||
const { conversationId, content, messageId } = request.body;
|
||||
const userId = (request as any).user?.userId;
|
||||
|
||||
if (!userId) {
|
||||
reply.code(401).send({ error: 'Unauthorized' });
|
||||
return;
|
||||
}
|
||||
|
||||
if (!conversationId || !content) {
|
||||
reply.code(400).send({ error: 'Missing required fields: conversationId, content' });
|
||||
return;
|
||||
}
|
||||
|
||||
const response = await this.orchestrator.handleMessage({
|
||||
conversationId,
|
||||
userId,
|
||||
content,
|
||||
messageId,
|
||||
});
|
||||
|
||||
reply.send({
|
||||
success: true,
|
||||
data: response,
|
||||
});
|
||||
} catch (error) {
|
||||
console.error('[ProtocolAgentController] sendMessage error:', error);
|
||||
reply.code(500).send({
|
||||
success: false,
|
||||
error: error instanceof Error ? error.message : 'Internal server error',
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 同步阶段数据
|
||||
* POST /api/aia/protocol-agent/sync
|
||||
*/
|
||||
async syncData(
|
||||
request: FastifyRequest<{ Body: SyncDataBody }>,
|
||||
reply: FastifyReply
|
||||
): Promise<void> {
|
||||
try {
|
||||
const { conversationId, stageCode, data } = request.body;
|
||||
const userId = (request as any).user?.userId;
|
||||
|
||||
if (!userId) {
|
||||
reply.code(401).send({ error: 'Unauthorized' });
|
||||
return;
|
||||
}
|
||||
|
||||
if (!conversationId || !stageCode) {
|
||||
reply.code(400).send({ error: 'Missing required fields: conversationId, stageCode' });
|
||||
return;
|
||||
}
|
||||
|
||||
const result = await this.orchestrator.handleProtocolSync(
|
||||
conversationId,
|
||||
userId,
|
||||
stageCode,
|
||||
data
|
||||
);
|
||||
|
||||
reply.send({
|
||||
success: true,
|
||||
data: result,
|
||||
});
|
||||
} catch (error) {
|
||||
console.error('[ProtocolAgentController] syncData error:', error);
|
||||
reply.code(500).send({
|
||||
success: false,
|
||||
error: error instanceof Error ? error.message : 'Internal server error',
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取上下文状态
|
||||
* GET /api/aia/protocol-agent/context/:conversationId
|
||||
*/
|
||||
async getContext(
|
||||
request: FastifyRequest<{ Params: GetContextParams }>,
|
||||
reply: FastifyReply
|
||||
): Promise<void> {
|
||||
try {
|
||||
const { conversationId } = request.params;
|
||||
|
||||
if (!conversationId) {
|
||||
reply.code(400).send({ error: 'Missing conversationId' });
|
||||
return;
|
||||
}
|
||||
|
||||
const summary = await this.orchestrator.getContextSummary(conversationId);
|
||||
|
||||
if (!summary) {
|
||||
reply.code(404).send({ error: 'Context not found' });
|
||||
return;
|
||||
}
|
||||
|
||||
reply.send({
|
||||
success: true,
|
||||
data: summary,
|
||||
});
|
||||
} catch (error) {
|
||||
console.error('[ProtocolAgentController] getContext error:', error);
|
||||
reply.code(500).send({
|
||||
success: false,
|
||||
error: error instanceof Error ? error.message : 'Internal server error',
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 一键生成研究方案
|
||||
* POST /api/aia/protocol-agent/generate
|
||||
*/
|
||||
async generateProtocol(
|
||||
request: FastifyRequest<{ Body: GenerateProtocolBody }>,
|
||||
reply: FastifyReply
|
||||
): Promise<void> {
|
||||
try {
|
||||
const { conversationId, options } = request.body;
|
||||
const userId = (request as any).user?.userId;
|
||||
|
||||
if (!userId) {
|
||||
reply.code(401).send({ error: 'Unauthorized' });
|
||||
return;
|
||||
}
|
||||
|
||||
if (!conversationId) {
|
||||
reply.code(400).send({ error: 'Missing conversationId' });
|
||||
return;
|
||||
}
|
||||
|
||||
// 获取上下文
|
||||
const contextService = this.orchestrator.getContextService();
|
||||
const context = await contextService.getContext(conversationId);
|
||||
|
||||
if (!context) {
|
||||
reply.code(404).send({ error: 'Context not found' });
|
||||
return;
|
||||
}
|
||||
|
||||
// 检查是否所有阶段都已完成
|
||||
if (!contextService.isAllStagesCompleted(context)) {
|
||||
reply.code(400).send({
|
||||
error: '请先完成所有5个阶段(科学问题、PICO、研究设计、样本量、观察指标)'
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO: 实现方案生成逻辑
|
||||
// 这里先返回占位响应,实际应该调用LLM生成完整方案
|
||||
reply.send({
|
||||
success: true,
|
||||
data: {
|
||||
generationId: 'placeholder',
|
||||
status: 'generating',
|
||||
message: '研究方案生成中...',
|
||||
estimatedTime: 30,
|
||||
},
|
||||
});
|
||||
} catch (error) {
|
||||
console.error('[ProtocolAgentController] generateProtocol error:', error);
|
||||
reply.code(500).send({
|
||||
success: false,
|
||||
error: error instanceof Error ? error.message : 'Internal server error',
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取生成的方案
|
||||
* GET /api/aia/protocol-agent/generation/:generationId
|
||||
*/
|
||||
async getGeneration(
|
||||
request: FastifyRequest<{ Params: { generationId: string } }>,
|
||||
reply: FastifyReply
|
||||
): Promise<void> {
|
||||
try {
|
||||
const { generationId } = request.params;
|
||||
|
||||
// TODO: 实现获取生成结果逻辑
|
||||
reply.send({
|
||||
success: true,
|
||||
data: {
|
||||
id: generationId,
|
||||
status: 'completed',
|
||||
content: '# 研究方案\n\n(生成中...)',
|
||||
contentVersion: 1,
|
||||
},
|
||||
});
|
||||
} catch (error) {
|
||||
console.error('[ProtocolAgentController] getGeneration error:', error);
|
||||
reply.code(500).send({
|
||||
success: false,
|
||||
error: error instanceof Error ? error.message : 'Internal server error',
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 导出Word文档
|
||||
* POST /api/aia/protocol-agent/generation/:generationId/export
|
||||
*/
|
||||
async exportWord(
|
||||
request: FastifyRequest<{
|
||||
Params: { generationId: string };
|
||||
Body: { format: 'docx' | 'pdf' };
|
||||
}>,
|
||||
reply: FastifyReply
|
||||
): Promise<void> {
|
||||
try {
|
||||
const { generationId } = request.params;
|
||||
const { format } = request.body;
|
||||
|
||||
// TODO: 实现导出逻辑
|
||||
reply.send({
|
||||
success: true,
|
||||
data: {
|
||||
downloadUrl: `/api/aia/protocol-agent/download/${generationId}.${format}`,
|
||||
expiresAt: new Date(Date.now() + 3600 * 1000).toISOString(),
|
||||
},
|
||||
});
|
||||
} catch (error) {
|
||||
console.error('[ProtocolAgentController] exportWord error:', error);
|
||||
reply.code(500).send({
|
||||
success: false,
|
||||
error: error instanceof Error ? error.message : 'Internal server error',
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
22
backend/src/modules/agent/protocol/index.ts
Normal file
22
backend/src/modules/agent/protocol/index.ts
Normal file
@@ -0,0 +1,22 @@
|
||||
/**
 * Protocol Agent Module
 * Research-protocol design agent — barrel file re-exporting the module's
 * public surface (services, routes, controllers).
 *
 * @module agent/protocol
 */

// Services
export {
  ProtocolContextService,
  ProtocolOrchestrator,
  PromptBuilder,
  LLMServiceAdapter,
  createLLMServiceAdapter,
} from './services/index.js';

// Routes
export { protocolAgentRoutes } from './routes/index.js';

// Controllers
export { ProtocolAgentController } from './controllers/ProtocolAgentController.js';
|
||||
|
||||
153
backend/src/modules/agent/protocol/routes/index.ts
Normal file
153
backend/src/modules/agent/protocol/routes/index.ts
Normal file
@@ -0,0 +1,153 @@
|
||||
/**
|
||||
* Protocol Agent Routes
|
||||
* Protocol Agent的API路由定义
|
||||
*
|
||||
* @module agent/protocol/routes
|
||||
*/
|
||||
|
||||
import { FastifyInstance } from 'fastify';
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
import { ProtocolAgentController } from '../controllers/ProtocolAgentController.js';
|
||||
import { createLLMServiceAdapter } from '../services/LLMServiceAdapter.js';
|
||||
import { authenticate } from '../../../../common/auth/auth.middleware.js';
|
||||
|
||||
export async function protocolAgentRoutes(
|
||||
fastify: FastifyInstance,
|
||||
options: { prisma: PrismaClient }
|
||||
): Promise<void> {
|
||||
const { prisma } = options;
|
||||
|
||||
// 创建LLM服务(使用真实的LLM适配器)
|
||||
const llmService = createLLMServiceAdapter();
|
||||
|
||||
// 创建控制器
|
||||
const controller = new ProtocolAgentController(prisma, llmService);
|
||||
|
||||
// ============================================
|
||||
// Protocol Agent API Routes (需要认证)
|
||||
// ============================================
|
||||
|
||||
// 发送消息
|
||||
fastify.post<{
|
||||
Body: {
|
||||
conversationId: string;
|
||||
content: string;
|
||||
messageId?: string;
|
||||
};
|
||||
}>('/message', {
|
||||
preHandler: [authenticate],
|
||||
schema: {
|
||||
body: {
|
||||
type: 'object',
|
||||
required: ['conversationId', 'content'],
|
||||
properties: {
|
||||
conversationId: { type: 'string' },
|
||||
content: { type: 'string' },
|
||||
messageId: { type: 'string' },
|
||||
},
|
||||
},
|
||||
},
|
||||
}, (request, reply) => controller.sendMessage(request, reply));
|
||||
|
||||
// 同步阶段数据
|
||||
fastify.post('/sync', {
|
||||
preHandler: [authenticate],
|
||||
schema: {
|
||||
body: {
|
||||
type: 'object',
|
||||
required: ['conversationId', 'stageCode'],
|
||||
properties: {
|
||||
conversationId: { type: 'string' },
|
||||
stageCode: { type: 'string' },
|
||||
data: { type: 'object' },
|
||||
},
|
||||
},
|
||||
},
|
||||
}, (request, reply) => controller.syncData(request as never, reply));
|
||||
|
||||
// 获取上下文状态
|
||||
fastify.get<{
|
||||
Params: { conversationId: string };
|
||||
}>('/context/:conversationId', {
|
||||
preHandler: [authenticate],
|
||||
schema: {
|
||||
params: {
|
||||
type: 'object',
|
||||
required: ['conversationId'],
|
||||
properties: {
|
||||
conversationId: { type: 'string' },
|
||||
},
|
||||
},
|
||||
},
|
||||
}, (request, reply) => controller.getContext(request, reply));
|
||||
|
||||
// 一键生成研究方案
|
||||
fastify.post<{
|
||||
Body: {
|
||||
conversationId: string;
|
||||
options?: {
|
||||
sections?: string[];
|
||||
style?: 'academic' | 'concise';
|
||||
};
|
||||
};
|
||||
}>('/generate', {
|
||||
preHandler: [authenticate],
|
||||
schema: {
|
||||
body: {
|
||||
type: 'object',
|
||||
required: ['conversationId'],
|
||||
properties: {
|
||||
conversationId: { type: 'string' },
|
||||
options: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
sections: { type: 'array', items: { type: 'string' } },
|
||||
style: { type: 'string', enum: ['academic', 'concise'] },
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, (request, reply) => controller.generateProtocol(request, reply));
|
||||
|
||||
// 获取生成的方案
|
||||
fastify.get<{
|
||||
Params: { generationId: string };
|
||||
}>('/generation/:generationId', {
|
||||
preHandler: [authenticate],
|
||||
schema: {
|
||||
params: {
|
||||
type: 'object',
|
||||
required: ['generationId'],
|
||||
properties: {
|
||||
generationId: { type: 'string' },
|
||||
},
|
||||
},
|
||||
},
|
||||
}, (request, reply) => controller.getGeneration(request, reply));
|
||||
|
||||
// 导出Word文档
|
||||
fastify.post<{
|
||||
Params: { generationId: string };
|
||||
Body: { format: 'docx' | 'pdf' };
|
||||
}>('/generation/:generationId/export', {
|
||||
preHandler: [authenticate],
|
||||
schema: {
|
||||
params: {
|
||||
type: 'object',
|
||||
required: ['generationId'],
|
||||
properties: {
|
||||
generationId: { type: 'string' },
|
||||
},
|
||||
},
|
||||
body: {
|
||||
type: 'object',
|
||||
required: ['format'],
|
||||
properties: {
|
||||
format: { type: 'string', enum: ['docx', 'pdf'] },
|
||||
},
|
||||
},
|
||||
},
|
||||
}, (request, reply) => controller.exportWord(request, reply));
|
||||
}
|
||||
|
||||
184
backend/src/modules/agent/protocol/services/LLMServiceAdapter.ts
Normal file
184
backend/src/modules/agent/protocol/services/LLMServiceAdapter.ts
Normal file
@@ -0,0 +1,184 @@
|
||||
/**
|
||||
* LLM Service Adapter
|
||||
* 将现有的LLM服务适配为Protocol Agent所需的接口
|
||||
*
|
||||
* @module agent/protocol/services/LLMServiceAdapter
|
||||
*/
|
||||
|
||||
import { LLMFactory } from '../../../../common/llm/adapters/LLMFactory.js';
|
||||
import { ModelType, Message as LLMMessage, LLMOptions } from '../../../../common/llm/adapters/types.js';
|
||||
import { LLMServiceInterface } from '../../services/BaseAgentOrchestrator.js';
|
||||
|
||||
export class LLMServiceAdapter implements LLMServiceInterface {
|
||||
private defaultModel: ModelType;
|
||||
private defaultTemperature: number;
|
||||
private defaultMaxTokens: number;
|
||||
|
||||
constructor(options?: {
|
||||
defaultModel?: ModelType;
|
||||
defaultTemperature?: number;
|
||||
defaultMaxTokens?: number;
|
||||
}) {
|
||||
this.defaultModel = options?.defaultModel ?? 'deepseek-v3';
|
||||
this.defaultTemperature = options?.defaultTemperature ?? 0.7;
|
||||
this.defaultMaxTokens = options?.defaultMaxTokens ?? 4096;
|
||||
}
|
||||
|
||||
/**
|
||||
* 调用LLM进行对话
|
||||
*/
|
||||
async chat(params: {
|
||||
messages: Array<{ role: string; content: string }>;
|
||||
model?: string;
|
||||
temperature?: number;
|
||||
maxTokens?: number;
|
||||
}): Promise<{
|
||||
content: string;
|
||||
thinkingContent?: string;
|
||||
tokensUsed: number;
|
||||
model: string;
|
||||
}> {
|
||||
// 获取模型类型
|
||||
const modelType = this.parseModelType(params.model);
|
||||
|
||||
// 获取LLM适配器
|
||||
const adapter = LLMFactory.getAdapter(modelType);
|
||||
|
||||
// 转换消息格式
|
||||
const messages: LLMMessage[] = params.messages.map(m => ({
|
||||
role: m.role as 'system' | 'user' | 'assistant',
|
||||
content: m.content,
|
||||
}));
|
||||
|
||||
// 调用LLM
|
||||
const options: LLMOptions = {
|
||||
temperature: params.temperature ?? this.defaultTemperature,
|
||||
maxTokens: params.maxTokens ?? this.defaultMaxTokens,
|
||||
};
|
||||
|
||||
try {
|
||||
const response = await adapter.chat(messages, options);
|
||||
|
||||
// 提取思考内容(如果有)
|
||||
const { content, thinkingContent } = this.extractThinkingContent(response.content);
|
||||
|
||||
return {
|
||||
content,
|
||||
thinkingContent,
|
||||
tokensUsed: response.usage?.totalTokens ?? 0,
|
||||
model: response.model,
|
||||
};
|
||||
} catch (error) {
|
||||
console.error('[LLMServiceAdapter] chat error:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 流式调用LLM(返回AsyncGenerator)
|
||||
*/
|
||||
async *chatStream(params: {
|
||||
messages: Array<{ role: string; content: string }>;
|
||||
model?: string;
|
||||
temperature?: number;
|
||||
maxTokens?: number;
|
||||
}): AsyncGenerator<{
|
||||
content: string;
|
||||
done: boolean;
|
||||
tokensUsed?: number;
|
||||
}> {
|
||||
const modelType = this.parseModelType(params.model);
|
||||
const adapter = LLMFactory.getAdapter(modelType);
|
||||
|
||||
const messages: LLMMessage[] = params.messages.map(m => ({
|
||||
role: m.role as 'system' | 'user' | 'assistant',
|
||||
content: m.content,
|
||||
}));
|
||||
|
||||
const options: LLMOptions = {
|
||||
temperature: params.temperature ?? this.defaultTemperature,
|
||||
maxTokens: params.maxTokens ?? this.defaultMaxTokens,
|
||||
stream: true,
|
||||
};
|
||||
|
||||
try {
|
||||
for await (const chunk of adapter.chatStream(messages, options)) {
|
||||
yield {
|
||||
content: chunk.content,
|
||||
done: chunk.done,
|
||||
tokensUsed: chunk.usage?.totalTokens,
|
||||
};
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('[LLMServiceAdapter] chatStream error:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析模型类型
|
||||
*/
|
||||
private parseModelType(model?: string): ModelType {
|
||||
if (!model) return this.defaultModel;
|
||||
|
||||
// 映射模型名称到ModelType
|
||||
const modelMap: Record<string, ModelType> = {
|
||||
'deepseek-v3': 'deepseek-v3',
|
||||
'deepseek-chat': 'deepseek-v3',
|
||||
'qwen-max': 'qwen3-72b',
|
||||
'qwen3-72b': 'qwen3-72b',
|
||||
'qwen-long': 'qwen-long',
|
||||
'gpt-5': 'gpt-5',
|
||||
'gpt-5-pro': 'gpt-5',
|
||||
'claude-4.5': 'claude-4.5',
|
||||
'claude-sonnet': 'claude-4.5',
|
||||
};
|
||||
|
||||
return modelMap[model.toLowerCase()] ?? this.defaultModel;
|
||||
}
|
||||
|
||||
/**
|
||||
* 从响应中提取思考内容(<think>...</think>)
|
||||
*/
|
||||
private extractThinkingContent(content: string): {
|
||||
content: string;
|
||||
thinkingContent?: string;
|
||||
} {
|
||||
// 匹配 <think>...</think> 或 <thinking>...</thinking>
|
||||
const thinkingPattern = /<(?:think|thinking)>([\s\S]*?)<\/(?:think|thinking)>/gi;
|
||||
const matches = content.matchAll(thinkingPattern);
|
||||
|
||||
let thinkingContent = '';
|
||||
let cleanContent = content;
|
||||
|
||||
for (const match of matches) {
|
||||
thinkingContent += match[1].trim() + '\n';
|
||||
cleanContent = cleanContent.replace(match[0], '').trim();
|
||||
}
|
||||
|
||||
return {
|
||||
content: cleanContent,
|
||||
thinkingContent: thinkingContent ? thinkingContent.trim() : undefined,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取支持的模型列表
|
||||
*/
|
||||
getSupportedModels(): string[] {
|
||||
return LLMFactory.getSupportedModels();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建默认的LLM服务适配器
|
||||
*/
|
||||
export function createLLMServiceAdapter(): LLMServiceInterface {
|
||||
return new LLMServiceAdapter({
|
||||
defaultModel: 'deepseek-v3',
|
||||
defaultTemperature: 0.7,
|
||||
defaultMaxTokens: 4096,
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
286
backend/src/modules/agent/protocol/services/PromptBuilder.ts
Normal file
286
backend/src/modules/agent/protocol/services/PromptBuilder.ts
Normal file
@@ -0,0 +1,286 @@
|
||||
/**
|
||||
* Prompt Builder
|
||||
* 构建和渲染Protocol Agent的Prompt
|
||||
*
|
||||
* @module agent/protocol/services/PromptBuilder
|
||||
*/
|
||||
|
||||
import { ConfigLoader } from '../../services/ConfigLoader.js';
|
||||
import {
|
||||
AgentPrompt,
|
||||
ProtocolContextData,
|
||||
PromptRenderContext,
|
||||
} from '../../types/index.js';
|
||||
|
||||
export class PromptBuilder {
|
||||
private configLoader: ConfigLoader;
|
||||
|
||||
constructor(configLoader: ConfigLoader) {
|
||||
this.configLoader = configLoader;
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建完整的消息列表
|
||||
*/
|
||||
async buildMessages(
|
||||
context: ProtocolContextData,
|
||||
userMessage: string,
|
||||
conversationHistory?: Array<{ role: string; content: string }>
|
||||
): Promise<Array<{ role: string; content: string }>> {
|
||||
const messages: Array<{ role: string; content: string }> = [];
|
||||
|
||||
// 1. 系统Prompt
|
||||
const systemPrompt = await this.buildSystemPrompt(context);
|
||||
if (systemPrompt) {
|
||||
messages.push({ role: 'system', content: systemPrompt });
|
||||
}
|
||||
|
||||
// 2. 阶段Prompt
|
||||
const stagePrompt = await this.buildStagePrompt(context);
|
||||
if (stagePrompt) {
|
||||
messages.push({ role: 'system', content: stagePrompt });
|
||||
}
|
||||
|
||||
// 3. 对话历史(最近5轮)
|
||||
if (conversationHistory?.length) {
|
||||
const recentHistory = conversationHistory.slice(-10); // 最近10条消息(5轮)
|
||||
messages.push(...recentHistory);
|
||||
}
|
||||
|
||||
// 4. 用户消息
|
||||
messages.push({ role: 'user', content: userMessage });
|
||||
|
||||
return messages;
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建系统Prompt
|
||||
*/
|
||||
async buildSystemPrompt(context: ProtocolContextData): Promise<string | null> {
|
||||
const prompt = await this.configLoader.getSystemPrompt('protocol_agent');
|
||||
if (!prompt) return null;
|
||||
|
||||
return this.renderTemplate(prompt.content, { context });
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建当前阶段Prompt
|
||||
*/
|
||||
async buildStagePrompt(context: ProtocolContextData): Promise<string | null> {
|
||||
const prompt = await this.configLoader.getStagePrompt(
|
||||
'protocol_agent',
|
||||
context.currentStage
|
||||
);
|
||||
if (!prompt) return null;
|
||||
|
||||
return this.renderTemplate(prompt.content, { context });
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建数据提取Prompt
|
||||
*/
|
||||
async buildExtractionPrompt(
|
||||
context: ProtocolContextData,
|
||||
userMessage: string
|
||||
): Promise<string | null> {
|
||||
const prompt = await this.configLoader.getExtractionPrompt(
|
||||
'protocol_agent',
|
||||
context.currentStage
|
||||
);
|
||||
if (!prompt) return null;
|
||||
|
||||
return this.renderTemplate(prompt.content, {
|
||||
userMessage,
|
||||
context,
|
||||
currentPico: context.pico,
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建方案生成Prompt
|
||||
*/
|
||||
async buildGenerationPrompt(context: ProtocolContextData): Promise<string | null> {
|
||||
const config = await this.configLoader.loadAgentConfig('protocol_agent');
|
||||
const prompt = config.prompts.find(p => p.promptCode === 'generate_protocol');
|
||||
|
||||
if (!prompt) return null;
|
||||
|
||||
return this.renderTemplate(prompt.content, {
|
||||
scientificQuestion: context.scientificQuestion,
|
||||
pico: context.pico,
|
||||
studyDesign: context.studyDesign,
|
||||
sampleSize: context.sampleSize,
|
||||
endpoints: context.endpoints,
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* 渲染模板
|
||||
* 支持 {{variable}} 和 {{#if variable}}...{{/if}} 语法
|
||||
*/
|
||||
private renderTemplate(
|
||||
template: string,
|
||||
variables: Record<string, unknown>
|
||||
): string {
|
||||
let result = template;
|
||||
|
||||
// 处理 {{#if variable}}...{{/if}} 条件块
|
||||
result = this.processConditionals(result, variables);
|
||||
|
||||
// 处理 {{#each array}}...{{/each}} 循环块
|
||||
result = this.processEachBlocks(result, variables);
|
||||
|
||||
// 处理简单变量替换 {{variable}}
|
||||
result = this.processVariables(result, variables);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理条件块
|
||||
*/
|
||||
private processConditionals(
|
||||
template: string,
|
||||
variables: Record<string, unknown>
|
||||
): string {
|
||||
const ifPattern = /\{\{#if\s+(\S+)\}\}([\s\S]*?)\{\{\/if\}\}/g;
|
||||
|
||||
return template.replace(ifPattern, (match, condition, content) => {
|
||||
const value = this.getNestedValue(variables, condition);
|
||||
if (value && value !== false && value !== null && value !== undefined) {
|
||||
return content;
|
||||
}
|
||||
return '';
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理循环块
|
||||
*/
|
||||
private processEachBlocks(
|
||||
template: string,
|
||||
variables: Record<string, unknown>
|
||||
): string {
|
||||
const eachPattern = /\{\{#each\s+(\S+)\}\}([\s\S]*?)\{\{\/each\}\}/g;
|
||||
|
||||
return template.replace(eachPattern, (match, arrayPath, content) => {
|
||||
const array = this.getNestedValue(variables, arrayPath);
|
||||
if (!Array.isArray(array)) return '';
|
||||
|
||||
return array.map((item, index) => {
|
||||
let itemContent = content;
|
||||
|
||||
// 替换 {{this}} 为当前项
|
||||
itemContent = itemContent.replace(/\{\{this\}\}/g, String(item));
|
||||
|
||||
// 替换 {{@index}} 为索引
|
||||
itemContent = itemContent.replace(/\{\{@index\}\}/g, String(index));
|
||||
|
||||
// 替换项属性 {{name}}, {{definition}} 等
|
||||
if (typeof item === 'object' && item !== null) {
|
||||
for (const [key, value] of Object.entries(item)) {
|
||||
const regex = new RegExp(`\\{\\{${key}\\}\\}`, 'g');
|
||||
itemContent = itemContent.replace(regex, String(value ?? ''));
|
||||
}
|
||||
}
|
||||
|
||||
return itemContent;
|
||||
}).join('\n');
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理变量替换
|
||||
*/
|
||||
private processVariables(
|
||||
template: string,
|
||||
variables: Record<string, unknown>
|
||||
): string {
|
||||
const varPattern = /\{\{([^#/][^}]*)\}\}/g;
|
||||
|
||||
return template.replace(varPattern, (match, varPath) => {
|
||||
const value = this.getNestedValue(variables, varPath.trim());
|
||||
|
||||
if (value === undefined || value === null) {
|
||||
return '';
|
||||
}
|
||||
|
||||
if (typeof value === 'object') {
|
||||
return JSON.stringify(value, null, 2);
|
||||
}
|
||||
|
||||
return String(value);
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取嵌套属性值
|
||||
*/
|
||||
private getNestedValue(
|
||||
obj: Record<string, unknown>,
|
||||
path: string
|
||||
): unknown {
|
||||
const parts = path.split('.');
|
||||
let current: unknown = obj;
|
||||
|
||||
for (const part of parts) {
|
||||
if (current === null || current === undefined) {
|
||||
return undefined;
|
||||
}
|
||||
if (typeof current !== 'object') {
|
||||
return undefined;
|
||||
}
|
||||
current = (current as Record<string, unknown>)[part];
|
||||
}
|
||||
|
||||
return current;
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建欢迎消息
|
||||
*/
|
||||
buildWelcomeMessage(): string {
|
||||
return `您好!我是研究方案制定助手,将帮助您系统地完成临床研究方案的核心要素设计。
|
||||
|
||||
我们将一起完成以下5个关键步骤:
|
||||
|
||||
1️⃣ **科学问题梳理** - 明确研究要解决的核心问题
|
||||
2️⃣ **PICO要素** - 确定研究人群、干预、对照和结局
|
||||
3️⃣ **研究设计** - 选择合适的研究类型和方法
|
||||
4️⃣ **样本量计算** - 估算所需的样本量
|
||||
5️⃣ **观察指标** - 定义主要和次要结局指标
|
||||
|
||||
完成这5个要素后,您可以**一键生成完整的研究方案**并下载为Word文档。
|
||||
|
||||
让我们开始吧!请先告诉我,您想研究什么问题?或者描述一下您的研究背景和想法。`;
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建阶段完成消息
|
||||
*/
|
||||
buildStageCompleteMessage(
|
||||
stageName: string,
|
||||
nextStageName?: string,
|
||||
isAllCompleted: boolean = false
|
||||
): string {
|
||||
if (isAllCompleted) {
|
||||
return `✅ ${stageName}已同步到方案!
|
||||
|
||||
🎉 **恭喜!您已完成所有5个核心要素的梳理!**
|
||||
|
||||
您现在可以:
|
||||
- 点击「🚀 一键生成研究方案」生成完整方案
|
||||
- 或者回顾修改任何阶段的内容
|
||||
|
||||
需要我帮您生成研究方案吗?`;
|
||||
}
|
||||
|
||||
return `✅ ${stageName}已同步到方案!
|
||||
|
||||
接下来我们进入**${nextStageName}**阶段。准备好了吗?
|
||||
|
||||
说"继续"我们就开始,或者您也可以先问我任何问题。`;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,289 @@
|
||||
/**
|
||||
* Protocol Context Service
|
||||
* 管理研究方案的上下文数据
|
||||
*
|
||||
* @module agent/protocol/services/ProtocolContextService
|
||||
*/
|
||||
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
import {
|
||||
ProtocolContextData,
|
||||
ProtocolStageCode,
|
||||
ScientificQuestionData,
|
||||
PICOData,
|
||||
StudyDesignData,
|
||||
SampleSizeData,
|
||||
EndpointsData,
|
||||
} from '../../types/index.js';
|
||||
|
||||
export class ProtocolContextService {
|
||||
private prisma: PrismaClient;
|
||||
|
||||
/** @param prisma - Prisma client used for all protocol_contexts persistence. */
constructor(prisma: PrismaClient) {
  this.prisma = prisma;
}
|
||||
|
||||
/**
|
||||
* 获取或创建上下文
|
||||
*/
|
||||
async getOrCreateContext(
|
||||
conversationId: string,
|
||||
userId: string
|
||||
): Promise<ProtocolContextData> {
|
||||
let context = await this.getContext(conversationId);
|
||||
|
||||
if (!context) {
|
||||
context = await this.createContext(conversationId, userId);
|
||||
}
|
||||
|
||||
return context;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取上下文
|
||||
*/
|
||||
async getContext(conversationId: string): Promise<ProtocolContextData | null> {
|
||||
const result = await this.prisma.protocolContext.findUnique({
|
||||
where: { conversationId },
|
||||
});
|
||||
|
||||
if (!result) return null;
|
||||
|
||||
return this.mapToContextData(result);
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建新上下文
|
||||
*/
|
||||
async createContext(
|
||||
conversationId: string,
|
||||
userId: string
|
||||
): Promise<ProtocolContextData> {
|
||||
const result = await this.prisma.protocolContext.create({
|
||||
data: {
|
||||
conversationId,
|
||||
userId,
|
||||
currentStage: 'scientific_question',
|
||||
status: 'in_progress',
|
||||
completedStages: [],
|
||||
},
|
||||
});
|
||||
|
||||
return this.mapToContextData(result);
|
||||
}
|
||||
|
||||
/**
|
||||
* 更新阶段数据
|
||||
*/
|
||||
async updateStageData(
|
||||
conversationId: string,
|
||||
stageCode: ProtocolStageCode,
|
||||
data: Record<string, unknown>
|
||||
): Promise<ProtocolContextData> {
|
||||
const updateData: Record<string, unknown> = {};
|
||||
|
||||
switch (stageCode) {
|
||||
case 'scientific_question':
|
||||
updateData.scientificQuestion = data;
|
||||
break;
|
||||
case 'pico':
|
||||
updateData.pico = data;
|
||||
break;
|
||||
case 'study_design':
|
||||
updateData.studyDesign = data;
|
||||
break;
|
||||
case 'sample_size':
|
||||
updateData.sampleSize = data;
|
||||
break;
|
||||
case 'endpoints':
|
||||
updateData.endpoints = data;
|
||||
break;
|
||||
}
|
||||
|
||||
const result = await this.prisma.protocolContext.update({
|
||||
where: { conversationId },
|
||||
data: {
|
||||
...updateData,
|
||||
lastActiveAt: new Date(),
|
||||
},
|
||||
});
|
||||
|
||||
return this.mapToContextData(result);
|
||||
}
|
||||
|
||||
/**
|
||||
* 标记阶段完成并更新当前阶段
|
||||
*/
|
||||
async completeStage(
|
||||
conversationId: string,
|
||||
stageCode: ProtocolStageCode,
|
||||
nextStage?: ProtocolStageCode
|
||||
): Promise<ProtocolContextData> {
|
||||
const context = await this.getContext(conversationId);
|
||||
if (!context) {
|
||||
throw new Error('Context not found');
|
||||
}
|
||||
|
||||
const completedStages = [...context.completedStages];
|
||||
if (!completedStages.includes(stageCode)) {
|
||||
completedStages.push(stageCode);
|
||||
}
|
||||
|
||||
const result = await this.prisma.protocolContext.update({
|
||||
where: { conversationId },
|
||||
data: {
|
||||
completedStages,
|
||||
currentStage: nextStage ?? context.currentStage,
|
||||
lastActiveAt: new Date(),
|
||||
},
|
||||
});
|
||||
|
||||
return this.mapToContextData(result);
|
||||
}
|
||||
|
||||
/**
|
||||
* 更新当前阶段
|
||||
*/
|
||||
async updateCurrentStage(
|
||||
conversationId: string,
|
||||
stageCode: ProtocolStageCode
|
||||
): Promise<ProtocolContextData> {
|
||||
const result = await this.prisma.protocolContext.update({
|
||||
where: { conversationId },
|
||||
data: {
|
||||
currentStage: stageCode,
|
||||
lastActiveAt: new Date(),
|
||||
},
|
||||
});
|
||||
|
||||
return this.mapToContextData(result);
|
||||
}
|
||||
|
||||
/**
|
||||
* 标记方案完成
|
||||
*/
|
||||
async markCompleted(conversationId: string): Promise<ProtocolContextData> {
|
||||
const result = await this.prisma.protocolContext.update({
|
||||
where: { conversationId },
|
||||
data: {
|
||||
status: 'completed',
|
||||
lastActiveAt: new Date(),
|
||||
},
|
||||
});
|
||||
|
||||
return this.mapToContextData(result);
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否所有阶段都已完成
|
||||
*/
|
||||
isAllStagesCompleted(context: ProtocolContextData): boolean {
|
||||
const requiredStages: ProtocolStageCode[] = [
|
||||
'scientific_question',
|
||||
'pico',
|
||||
'study_design',
|
||||
'sample_size',
|
||||
'endpoints',
|
||||
];
|
||||
|
||||
return requiredStages.every(stage => context.completedStages.includes(stage));
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取进度百分比
|
||||
*/
|
||||
getProgress(context: ProtocolContextData): number {
|
||||
const totalStages = 5;
|
||||
return Math.round((context.completedStages.length / totalStages) * 100);
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取阶段状态列表
|
||||
*/
|
||||
getStagesStatus(context: ProtocolContextData): Array<{
|
||||
stageCode: ProtocolStageCode;
|
||||
stageName: string;
|
||||
status: 'completed' | 'current' | 'pending';
|
||||
data: Record<string, unknown> | null;
|
||||
}> {
|
||||
const stages: Array<{
|
||||
code: ProtocolStageCode;
|
||||
name: string;
|
||||
dataKey: keyof ProtocolContextData;
|
||||
}> = [
|
||||
{ code: 'scientific_question', name: '科学问题梳理', dataKey: 'scientificQuestion' },
|
||||
{ code: 'pico', name: 'PICO要素', dataKey: 'pico' },
|
||||
{ code: 'study_design', name: '研究设计', dataKey: 'studyDesign' },
|
||||
{ code: 'sample_size', name: '样本量计算', dataKey: 'sampleSize' },
|
||||
{ code: 'endpoints', name: '观察指标', dataKey: 'endpoints' },
|
||||
];
|
||||
|
||||
return stages.map(stage => ({
|
||||
stageCode: stage.code,
|
||||
stageName: stage.name,
|
||||
status: context.completedStages.includes(stage.code)
|
||||
? 'completed' as const
|
||||
: context.currentStage === stage.code
|
||||
? 'current' as const
|
||||
: 'pending' as const,
|
||||
data: context[stage.dataKey] as unknown as Record<string, unknown> | null ?? null,
|
||||
}));
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取用于生成方案的完整数据
|
||||
*/
|
||||
getGenerationData(context: ProtocolContextData): {
|
||||
scientificQuestion: ScientificQuestionData | null;
|
||||
pico: PICOData | null;
|
||||
studyDesign: StudyDesignData | null;
|
||||
sampleSize: SampleSizeData | null;
|
||||
endpoints: EndpointsData | null;
|
||||
} {
|
||||
return {
|
||||
scientificQuestion: context.scientificQuestion ?? null,
|
||||
pico: context.pico ?? null,
|
||||
studyDesign: context.studyDesign ?? null,
|
||||
sampleSize: context.sampleSize ?? null,
|
||||
endpoints: context.endpoints ?? null,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* 将数据库结果映射为上下文数据
|
||||
*/
|
||||
private mapToContextData(result: {
|
||||
id: string;
|
||||
conversationId: string;
|
||||
userId: string;
|
||||
currentStage: string;
|
||||
status: string;
|
||||
scientificQuestion: unknown;
|
||||
pico: unknown;
|
||||
studyDesign: unknown;
|
||||
sampleSize: unknown;
|
||||
endpoints: unknown;
|
||||
completedStages: string[];
|
||||
lastActiveAt: Date;
|
||||
createdAt: Date;
|
||||
updatedAt: Date;
|
||||
}): ProtocolContextData {
|
||||
return {
|
||||
id: result.id,
|
||||
conversationId: result.conversationId,
|
||||
userId: result.userId,
|
||||
currentStage: result.currentStage as ProtocolStageCode,
|
||||
status: result.status as ProtocolContextData['status'],
|
||||
scientificQuestion: result.scientificQuestion as ScientificQuestionData | undefined,
|
||||
pico: result.pico as PICOData | undefined,
|
||||
studyDesign: result.studyDesign as StudyDesignData | undefined,
|
||||
sampleSize: result.sampleSize as SampleSizeData | undefined,
|
||||
endpoints: result.endpoints as EndpointsData | undefined,
|
||||
completedStages: result.completedStages as ProtocolStageCode[],
|
||||
lastActiveAt: result.lastActiveAt,
|
||||
createdAt: result.createdAt,
|
||||
updatedAt: result.updatedAt,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,321 @@
|
||||
/**
|
||||
* Protocol Orchestrator
|
||||
* Protocol Agent的具体实现
|
||||
*
|
||||
* @module agent/protocol/services/ProtocolOrchestrator
|
||||
*/
|
||||
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
import {
|
||||
BaseAgentOrchestrator,
|
||||
OrchestratorDependencies,
|
||||
} from '../../services/BaseAgentOrchestrator.js';
|
||||
import {
|
||||
AgentSession,
|
||||
AgentResponse,
|
||||
ProtocolContextData,
|
||||
ProtocolStageCode,
|
||||
SyncButtonData,
|
||||
ActionCard,
|
||||
UserMessageInput,
|
||||
} from '../../types/index.js';
|
||||
import { ProtocolContextService } from './ProtocolContextService.js';
|
||||
import { PromptBuilder } from './PromptBuilder.js';
|
||||
|
||||
/** Human-readable display name for each protocol stage (used as UI labels). */
const STAGE_NAMES: Record<ProtocolStageCode, string> = {
  scientific_question: '科学问题梳理',
  pico: 'PICO要素',
  study_design: '研究设计',
  sample_size: '样本量计算',
  endpoints: '观察指标',
};

/** Canonical ordering of the five protocol stages; drives stage advancement. */
const STAGE_ORDER: ProtocolStageCode[] = [
  'scientific_question',
  'pico',
  'study_design',
  'sample_size',
  'endpoints',
];
|
||||
|
||||
/**
 * Concrete orchestrator for the Protocol Agent.
 *
 * Extends BaseAgentOrchestrator with protocol-specific behaviour:
 * per-conversation context persistence, stage-aware responses with a
 * "sync to protocol" button, action cards, and a state-panel summary.
 */
export class ProtocolOrchestrator extends BaseAgentOrchestrator {
  private contextService: ProtocolContextService;
  // NOTE(review): promptBuilder is initialised but never referenced inside
  // this class — presumably consumed by the base orchestrator flow or by
  // code outside this view; verify before removing.
  private promptBuilder: PromptBuilder;

  constructor(deps: OrchestratorDependencies) {
    super(deps);
    this.contextService = new ProtocolContextService(deps.prisma);
    this.promptBuilder = new PromptBuilder(this.configLoader);
  }

  /**
   * Agent code identifying this orchestrator in the agent registry.
   */
  getAgentCode(): string {
    return 'protocol_agent';
  }

  /**
   * Override of the base handleMessage: guarantees a protocol context row
   * exists for the conversation before the base class processes the message.
   */
  async handleMessage(input: UserMessageInput): Promise<AgentResponse> {
    // Ensure the context exists (creates it on first contact).
    await this.contextService.getOrCreateContext(input.conversationId, input.userId);

    // Delegate the actual message handling to the base class.
    return super.handleMessage(input);
  }

  /**
   * Fetch the context for a conversation; when none exists yet, return a
   * default structure (first stage, nothing completed) instead of null so
   * downstream consumers never see undefined fields.
   */
  async getContext(conversationId: string): Promise<Record<string, unknown> | null> {
    const context = await this.contextService.getContext(conversationId);
    if (!context) {
      // Default context shape to avoid undefined errors in callers.
      return {
        currentStage: 'scientific_question',
        completedStages: [],
        status: 'in_progress',
      };
    }
    return context as unknown as Record<string, unknown>;
  }

  /**
   * Persist context data for the conversation's *current* stage.
   * Creates the context first when it does not exist yet.
   */
  async saveContext(
    conversationId: string,
    userId: string,
    data: Record<string, unknown>
  ): Promise<void> {
    const context = await this.contextService.getOrCreateContext(conversationId, userId);
    const stageCode = context.currentStage;

    await this.contextService.updateStageData(conversationId, stageCode, data);
  }

  /**
   * Assemble the stage-aware agent response sent to the frontend:
   * the raw LLM text plus stage metadata, an optional sync button, and
   * any applicable action cards.
   */
  async buildStageResponse(
    session: AgentSession,
    llmResponse: string,
    contextData: Record<string, unknown>
  ): Promise<AgentResponse> {
    const context = contextData as unknown as ProtocolContextData;
    const stageCode = session.currentStage as ProtocolStageCode;
    // Fall back to the raw stage code when no display name is mapped.
    const stageName = STAGE_NAMES[stageCode] || session.currentStage;

    // Decide whether a "sync to protocol" button should be shown.
    const syncButton = this.buildSyncButton(llmResponse, stageCode, context);

    // Build action cards (tools / generate button) for this stage.
    const actionCards = this.buildActionCards(stageCode, context);

    return {
      content: llmResponse,
      stage: stageCode,
      stageName,
      syncButton,
      actionCards,
    };
  }

  /**
   * Handle a "sync to protocol" request from the frontend: store the
   * confirmed stage payload, mark the stage completed, advance to the
   * next stage in STAGE_ORDER (if any), and return a user-facing message.
   *
   * @param stageCode - Incoming as a plain string from HTTP; cast without
   *   validation — an unknown value would be written as-is. NOTE(review):
   *   consider validating against STAGE_ORDER at the controller layer.
   */
  async handleProtocolSync(
    conversationId: string,
    userId: string,
    stageCode: string,
    data: Record<string, unknown>
  ): Promise<{
    success: boolean;
    context: ProtocolContextData;
    nextStage?: ProtocolStageCode;
    message?: string;
  }> {
    const stage = stageCode as ProtocolStageCode;

    // Persist the confirmed payload, stamping confirmation metadata.
    await this.contextService.updateStageData(conversationId, stage, {
      ...data,
      confirmed: true,
      confirmedAt: new Date(),
    });

    // Determine the next stage (undefined when this was the last one).
    const currentIndex = STAGE_ORDER.indexOf(stage);
    const nextStage = currentIndex < STAGE_ORDER.length - 1
      ? STAGE_ORDER[currentIndex + 1]
      : undefined;

    // Mark the current stage completed and move to the next stage.
    const context = await this.contextService.completeStage(
      conversationId,
      stage,
      nextStage
    );

    // Check whether all five stages are now complete.
    const allCompleted = this.contextService.isAllStagesCompleted(context);

    return {
      success: true,
      context,
      nextStage,
      message: allCompleted
        ? '🎉 所有核心要素已完成!您可以点击「一键生成研究方案」生成完整方案。'
        : nextStage
          ? `已同步${STAGE_NAMES[stage]},进入${STAGE_NAMES[nextStage]}阶段`
          : `已同步${STAGE_NAMES[stage]}`,
    };
  }

  /**
   * Expose the underlying ProtocolContextService (used by controllers).
   */
  getContextService(): ProtocolContextService {
    return this.contextService;
  }

  /**
   * Decide whether the response warrants a sync button and build its data.
   *
   * Detection is keyword-based: if the LLM response contains any of the
   * "ready" phrases, the stage data is considered syncable. Returns a
   * disabled "already synced" button when the stage is already completed,
   * and undefined when no keyword matches.
   */
  private buildSyncButton(
    llmResponse: string,
    stageCode: ProtocolStageCode,
    context: ProtocolContextData
  ): SyncButtonData | undefined {
    // Simple keyword heuristic for "the LLM has produced organised data";
    // acknowledged in the original as a placeholder for something smarter.
    const readyPatterns = [
      '整理',
      '总结',
      '您的科学问题',
      '您的PICO',
      '您的研究设计',
      '样本量',
      '观察指标',
      '同步到方案',
    ];

    const hasReadyData = readyPatterns.some(p => llmResponse.includes(p));

    if (!hasReadyData) {
      return undefined;
    }

    // Guard against a context row without completedStages populated.
    const completedStages = context.completedStages || [];

    // Already-synced stages get a disabled button instead of a live one.
    if (completedStages.includes(stageCode)) {
      return {
        stageCode,
        extractedData: {},
        label: '已同步',
        disabled: true,
      };
    }

    return {
      stageCode,
      extractedData: this.extractDataFromResponse(llmResponse, stageCode),
      label: '✅ 同步到方案',
      disabled: false,
    };
  }

  /**
   * Extract structured data from an LLM response.
   *
   * Prefers an explicit <extracted_data>…</extracted_data> JSON block;
   * when absent or unparseable, falls back to a crude per-stage default
   * (first 500 chars for the scientific-question stage, a bare
   * readyToSync flag otherwise).
   */
  private extractDataFromResponse(
    response: string,
    stageCode: ProtocolStageCode
  ): Record<string, unknown> {
    // Try the tagged JSON block first.
    const jsonMatch = response.match(/<extracted_data>([\s\S]*?)<\/extracted_data>/);
    if (jsonMatch) {
      try {
        return JSON.parse(jsonMatch[1]);
      } catch {
        // Malformed JSON — deliberately fall through to the default logic.
      }
    }

    // Crude fallback (original notes an LLM-based extractor is intended).
    switch (stageCode) {
      case 'scientific_question':
        return { content: response.substring(0, 500), readyToSync: true };
      default:
        return { readyToSync: true };
    }
  }

  /**
   * Build the action cards shown alongside a response:
   * - sample-size stage: a link card to the sample-size calculator tool;
   * - all stages completed: the one-click protocol-generation card.
   */
  private buildActionCards(
    stageCode: ProtocolStageCode,
    context: ProtocolContextData
  ): ActionCard[] {
    const cards: ActionCard[] = [];

    // Sample-size stage: offer the calculator tool.
    if (stageCode === 'sample_size') {
      cards.push({
        id: 'sample_size_calculator',
        type: 'tool',
        title: '📊 样本量计算器',
        description: '使用专业计算器进行样本量估算',
        actionUrl: '/tools/sample-size-calculator',
      });
    }

    // Once every stage is done, offer one-click generation.
    // The completedStages truthiness check guards partially-shaped contexts.
    if (context.completedStages && this.contextService.isAllStagesCompleted(context)) {
      cards.push({
        id: 'generate_protocol',
        type: 'action',
        title: '🚀 一键生成研究方案',
        description: '基于5个核心要素生成完整研究方案',
        actionUrl: '/api/aia/protocol-agent/generate',
      });
    }

    return cards;
  }

  /**
   * Context summary for the frontend State Panel: current stage, display
   * name, overall progress percentage, per-stage status list, and whether
   * generation can be triggered. Returns null when no context exists.
   */
  async getContextSummary(conversationId: string): Promise<{
    currentStage: string;
    stageName: string;
    progress: number;
    stages: Array<{
      stageCode: string;
      stageName: string;
      status: 'completed' | 'current' | 'pending';
      data: Record<string, unknown> | null;
    }>;
    canGenerate: boolean;
  } | null> {
    const context = await this.contextService.getContext(conversationId);
    if (!context) return null;

    return {
      currentStage: context.currentStage,
      stageName: STAGE_NAMES[context.currentStage] || context.currentStage,
      progress: this.contextService.getProgress(context),
      stages: this.contextService.getStagesStatus(context),
      canGenerate: this.contextService.isAllStagesCompleted(context),
    };
  }
}
|
||||
|
||||
11
backend/src/modules/agent/protocol/services/index.ts
Normal file
11
backend/src/modules/agent/protocol/services/index.ts
Normal file
@@ -0,0 +1,11 @@
|
||||
/**
 * Protocol Agent Services Export
 *
 * Barrel file re-exporting the Protocol Agent service layer.
 *
 * @module agent/protocol/services
 */

// Context persistence and stage-state management.
export { ProtocolContextService } from './ProtocolContextService.js';
// Concrete orchestrator implementing the Protocol Agent dialogue flow.
export { ProtocolOrchestrator } from './ProtocolOrchestrator.js';
// Prompt construction service.
export { PromptBuilder } from './PromptBuilder.js';
// LLM service adapter and its factory function.
export { LLMServiceAdapter, createLLMServiceAdapter } from './LLMServiceAdapter.js';
|
||||
|
||||
Reference in New Issue
Block a user