feat(aia): Implement Protocol Agent MVP with reusable Agent framework

Sprint 1-3 Completed (Backend + Frontend):

Backend (Sprint 1-2):
- Implement 5-layer Agent framework (Query->Planner->Executor->Tools->Reflection)
- Create agent_schema with 6 tables (agent_definitions, stages, prompts, sessions, traces, reflexion_rules)
- Create protocol_schema with 2 tables (protocol_contexts, protocol_generations)
- Implement Protocol Agent core services (Orchestrator, ContextService, PromptBuilder)
- Integrate LLM service adapter (DeepSeek/Qwen/GPT-5/Claude)
- 6 API endpoints with full authentication
- 10/10 API tests passed

Frontend (Sprint 3):
- Add Protocol Agent entry in AgentHub (indigo theme card)
- Implement ProtocolAgentPage with 3-column layout
- Collapsible sidebar (Gemini style, 48px <-> 280px)
- StatePanel with 5 stage cards (scientific_question, pico, study_design, sample_size, endpoints)
- ChatArea with sync button and action cards integration
- 100% prototype design restoration (608 lines CSS)
- Detailed endpoints structure: baseline, exposure, outcomes, confounders

Features:
- 5-stage dialogue flow for research protocol design
- Conversation-driven interaction with sync-to-protocol button
- Real-time context state management
- One-click protocol generation button (UI ready, backend pending)

Database:
- agent_schema: 6 tables for reusable Agent framework
- protocol_schema: 2 tables for Protocol Agent
- Seed data: 1 agent + 5 stages + 9 prompts + 4 reflexion rules

Code Stats:
- Backend: 13 files, 4338 lines
- Frontend: 14 files, 2071 lines
- Total: 27 files, 6409 lines

Status: MVP core functionality completed, pending frontend-backend integration testing

Next: Sprint 4 - One-click protocol generation + Word export
This commit is contained in:
2026-01-24 17:29:24 +08:00
parent 61cdc97eeb
commit 96290d2f76
345 changed files with 13945 additions and 47 deletions

View File

@@ -201,5 +201,7 @@ export const jwtService = new JWTService();

View File

@@ -332,6 +332,8 @@ export function getBatchItems<T>(

View File

@@ -84,5 +84,7 @@ export interface VariableValidation {

View File

@@ -356,3 +356,5 @@ export default ChunkService;

View File

@@ -52,3 +52,5 @@ export const DifyClient = DeprecatedDifyClient;

View File

@@ -207,3 +207,5 @@ export function createOpenAIStreamAdapter(

View File

@@ -213,3 +213,5 @@ export async function streamChat(

View File

@@ -31,3 +31,5 @@ export { THINKING_TAGS } from './types';

View File

@@ -106,3 +106,5 @@ export type SSEEventType =

View File

@@ -159,6 +159,15 @@ logger.info('✅ PKB个人知识库路由已注册: /api/v1/pkb');
await fastify.register(aiaRoutes, { prefix: '/api/v1/aia' });
logger.info('✅ AIA智能问答路由已注册: /api/v1/aia');
// ============================================
// 【业务模块】Protocol Agent - 研究方案制定Agent
// ============================================
import { protocolAgentRoutes } from './modules/agent/protocol/index.js';
await fastify.register((instance, opts, done) => {
protocolAgentRoutes(instance, { prisma, ...opts }).then(() => done()).catch(done);
}, { prefix: '/api/v1/aia/protocol-agent' });
logger.info('✅ Protocol Agent路由已注册: /api/v1/aia/protocol-agent');
// ============================================
// 【业务模块】ASL - AI智能文献筛选
// ============================================

View File

@@ -92,3 +92,5 @@ export async function moduleRoutes(fastify: FastifyInstance) {

View File

@@ -122,3 +122,5 @@ export interface PaginatedResponse<T> {

View File

@@ -169,3 +169,5 @@ export const ROLE_DISPLAY_NAMES: Record<UserRole, string> = {

View File

@@ -0,0 +1,25 @@
/**
* Agent Framework Module
*
* 通用Agent框架入口
*
* @module agent
*/
// Types
export * from './types/index.js';
// Services
export {
ConfigLoader,
BaseAgentOrchestrator,
QueryAnalyzer,
StageManager,
TraceLogger,
type OrchestratorDependencies,
type LLMServiceInterface,
type IntentType,
type TransitionCondition,
type TraceLogInput,
} from './services/index.js';

View File

@@ -0,0 +1,287 @@
/**
* Protocol Agent Controller
* 处理Protocol Agent的HTTP请求
*
* @module agent/protocol/controllers/ProtocolAgentController
*/
import { FastifyRequest, FastifyReply } from 'fastify';
import { PrismaClient } from '@prisma/client';
import { ProtocolOrchestrator } from '../services/ProtocolOrchestrator.js';
import { LLMServiceInterface } from '../../services/BaseAgentOrchestrator.js';
import { ProtocolStageCode } from '../../types/index.js';
// 请求类型定义
interface SendMessageBody {
conversationId: string;
content: string;
messageId?: string;
}
interface SyncDataBody {
conversationId: string;
stageCode: ProtocolStageCode;
data: Record<string, unknown>;
}
interface GenerateProtocolBody {
conversationId: string;
options?: {
sections?: string[];
style?: 'academic' | 'concise';
};
}
interface GetContextParams {
conversationId: string;
}
export class ProtocolAgentController {
private orchestrator: ProtocolOrchestrator;
constructor(prisma: PrismaClient, llmService: LLMServiceInterface) {
this.orchestrator = new ProtocolOrchestrator({ prisma, llmService });
}
/**
* 发送消息
* POST /api/aia/protocol-agent/message
*/
async sendMessage(
request: FastifyRequest<{ Body: SendMessageBody }>,
reply: FastifyReply
): Promise<void> {
try {
const { conversationId, content, messageId } = request.body;
const userId = (request as any).user?.userId;
if (!userId) {
reply.code(401).send({ error: 'Unauthorized' });
return;
}
if (!conversationId || !content) {
reply.code(400).send({ error: 'Missing required fields: conversationId, content' });
return;
}
const response = await this.orchestrator.handleMessage({
conversationId,
userId,
content,
messageId,
});
reply.send({
success: true,
data: response,
});
} catch (error) {
console.error('[ProtocolAgentController] sendMessage error:', error);
reply.code(500).send({
success: false,
error: error instanceof Error ? error.message : 'Internal server error',
});
}
}
/**
* 同步阶段数据
* POST /api/aia/protocol-agent/sync
*/
async syncData(
request: FastifyRequest<{ Body: SyncDataBody }>,
reply: FastifyReply
): Promise<void> {
try {
const { conversationId, stageCode, data } = request.body;
const userId = (request as any).user?.userId;
if (!userId) {
reply.code(401).send({ error: 'Unauthorized' });
return;
}
if (!conversationId || !stageCode) {
reply.code(400).send({ error: 'Missing required fields: conversationId, stageCode' });
return;
}
const result = await this.orchestrator.handleProtocolSync(
conversationId,
userId,
stageCode,
data
);
reply.send({
success: true,
data: result,
});
} catch (error) {
console.error('[ProtocolAgentController] syncData error:', error);
reply.code(500).send({
success: false,
error: error instanceof Error ? error.message : 'Internal server error',
});
}
}
/**
* 获取上下文状态
* GET /api/aia/protocol-agent/context/:conversationId
*/
async getContext(
request: FastifyRequest<{ Params: GetContextParams }>,
reply: FastifyReply
): Promise<void> {
try {
const { conversationId } = request.params;
if (!conversationId) {
reply.code(400).send({ error: 'Missing conversationId' });
return;
}
const summary = await this.orchestrator.getContextSummary(conversationId);
if (!summary) {
reply.code(404).send({ error: 'Context not found' });
return;
}
reply.send({
success: true,
data: summary,
});
} catch (error) {
console.error('[ProtocolAgentController] getContext error:', error);
reply.code(500).send({
success: false,
error: error instanceof Error ? error.message : 'Internal server error',
});
}
}
/**
* 一键生成研究方案
* POST /api/aia/protocol-agent/generate
*/
async generateProtocol(
request: FastifyRequest<{ Body: GenerateProtocolBody }>,
reply: FastifyReply
): Promise<void> {
try {
const { conversationId, options } = request.body;
const userId = (request as any).user?.userId;
if (!userId) {
reply.code(401).send({ error: 'Unauthorized' });
return;
}
if (!conversationId) {
reply.code(400).send({ error: 'Missing conversationId' });
return;
}
// 获取上下文
const contextService = this.orchestrator.getContextService();
const context = await contextService.getContext(conversationId);
if (!context) {
reply.code(404).send({ error: 'Context not found' });
return;
}
// 检查是否所有阶段都已完成
if (!contextService.isAllStagesCompleted(context)) {
reply.code(400).send({
error: '请先完成所有5个阶段科学问题、PICO、研究设计、样本量、观察指标'
});
return;
}
// TODO: 实现方案生成逻辑
// 这里先返回占位响应实际应该调用LLM生成完整方案
reply.send({
success: true,
data: {
generationId: 'placeholder',
status: 'generating',
message: '研究方案生成中...',
estimatedTime: 30,
},
});
} catch (error) {
console.error('[ProtocolAgentController] generateProtocol error:', error);
reply.code(500).send({
success: false,
error: error instanceof Error ? error.message : 'Internal server error',
});
}
}
/**
* 获取生成的方案
* GET /api/aia/protocol-agent/generation/:generationId
*/
async getGeneration(
request: FastifyRequest<{ Params: { generationId: string } }>,
reply: FastifyReply
): Promise<void> {
try {
const { generationId } = request.params;
// TODO: 实现获取生成结果逻辑
reply.send({
success: true,
data: {
id: generationId,
status: 'completed',
content: '# 研究方案\n\n生成中...',
contentVersion: 1,
},
});
} catch (error) {
console.error('[ProtocolAgentController] getGeneration error:', error);
reply.code(500).send({
success: false,
error: error instanceof Error ? error.message : 'Internal server error',
});
}
}
/**
* 导出Word文档
* POST /api/aia/protocol-agent/generation/:generationId/export
*/
async exportWord(
request: FastifyRequest<{
Params: { generationId: string };
Body: { format: 'docx' | 'pdf' };
}>,
reply: FastifyReply
): Promise<void> {
try {
const { generationId } = request.params;
const { format } = request.body;
// TODO: 实现导出逻辑
reply.send({
success: true,
data: {
downloadUrl: `/api/aia/protocol-agent/download/${generationId}.${format}`,
expiresAt: new Date(Date.now() + 3600 * 1000).toISOString(),
},
});
} catch (error) {
console.error('[ProtocolAgentController] exportWord error:', error);
reply.code(500).send({
success: false,
error: error instanceof Error ? error.message : 'Internal server error',
});
}
}
}

View File

@@ -0,0 +1,22 @@
/**
* Protocol Agent Module
* 研究方案制定Agent
*
* @module agent/protocol
*/
// Services
export {
ProtocolContextService,
ProtocolOrchestrator,
PromptBuilder,
LLMServiceAdapter,
createLLMServiceAdapter,
} from './services/index.js';
// Routes
export { protocolAgentRoutes } from './routes/index.js';
// Controllers
export { ProtocolAgentController } from './controllers/ProtocolAgentController.js';

View File

@@ -0,0 +1,153 @@
/**
* Protocol Agent Routes
* Protocol Agent的API路由定义
*
* @module agent/protocol/routes
*/
import { FastifyInstance } from 'fastify';
import { PrismaClient } from '@prisma/client';
import { ProtocolAgentController } from '../controllers/ProtocolAgentController.js';
import { createLLMServiceAdapter } from '../services/LLMServiceAdapter.js';
import { authenticate } from '../../../../common/auth/auth.middleware.js';
export async function protocolAgentRoutes(
fastify: FastifyInstance,
options: { prisma: PrismaClient }
): Promise<void> {
const { prisma } = options;
// 创建LLM服务使用真实的LLM适配器
const llmService = createLLMServiceAdapter();
// 创建控制器
const controller = new ProtocolAgentController(prisma, llmService);
// ============================================
// Protocol Agent API Routes (需要认证)
// ============================================
// 发送消息
fastify.post<{
Body: {
conversationId: string;
content: string;
messageId?: string;
};
}>('/message', {
preHandler: [authenticate],
schema: {
body: {
type: 'object',
required: ['conversationId', 'content'],
properties: {
conversationId: { type: 'string' },
content: { type: 'string' },
messageId: { type: 'string' },
},
},
},
}, (request, reply) => controller.sendMessage(request, reply));
// 同步阶段数据
fastify.post('/sync', {
preHandler: [authenticate],
schema: {
body: {
type: 'object',
required: ['conversationId', 'stageCode'],
properties: {
conversationId: { type: 'string' },
stageCode: { type: 'string' },
data: { type: 'object' },
},
},
},
}, (request, reply) => controller.syncData(request as never, reply));
// 获取上下文状态
fastify.get<{
Params: { conversationId: string };
}>('/context/:conversationId', {
preHandler: [authenticate],
schema: {
params: {
type: 'object',
required: ['conversationId'],
properties: {
conversationId: { type: 'string' },
},
},
},
}, (request, reply) => controller.getContext(request, reply));
// 一键生成研究方案
fastify.post<{
Body: {
conversationId: string;
options?: {
sections?: string[];
style?: 'academic' | 'concise';
};
};
}>('/generate', {
preHandler: [authenticate],
schema: {
body: {
type: 'object',
required: ['conversationId'],
properties: {
conversationId: { type: 'string' },
options: {
type: 'object',
properties: {
sections: { type: 'array', items: { type: 'string' } },
style: { type: 'string', enum: ['academic', 'concise'] },
},
},
},
},
},
}, (request, reply) => controller.generateProtocol(request, reply));
// 获取生成的方案
fastify.get<{
Params: { generationId: string };
}>('/generation/:generationId', {
preHandler: [authenticate],
schema: {
params: {
type: 'object',
required: ['generationId'],
properties: {
generationId: { type: 'string' },
},
},
},
}, (request, reply) => controller.getGeneration(request, reply));
// 导出Word文档
fastify.post<{
Params: { generationId: string };
Body: { format: 'docx' | 'pdf' };
}>('/generation/:generationId/export', {
preHandler: [authenticate],
schema: {
params: {
type: 'object',
required: ['generationId'],
properties: {
generationId: { type: 'string' },
},
},
body: {
type: 'object',
required: ['format'],
properties: {
format: { type: 'string', enum: ['docx', 'pdf'] },
},
},
},
}, (request, reply) => controller.exportWord(request, reply));
}

View File

@@ -0,0 +1,184 @@
/**
* LLM Service Adapter
* 将现有的LLM服务适配为Protocol Agent所需的接口
*
* @module agent/protocol/services/LLMServiceAdapter
*/
import { LLMFactory } from '../../../../common/llm/adapters/LLMFactory.js';
import { ModelType, Message as LLMMessage, LLMOptions } from '../../../../common/llm/adapters/types.js';
import { LLMServiceInterface } from '../../services/BaseAgentOrchestrator.js';
export class LLMServiceAdapter implements LLMServiceInterface {
private defaultModel: ModelType;
private defaultTemperature: number;
private defaultMaxTokens: number;
constructor(options?: {
defaultModel?: ModelType;
defaultTemperature?: number;
defaultMaxTokens?: number;
}) {
this.defaultModel = options?.defaultModel ?? 'deepseek-v3';
this.defaultTemperature = options?.defaultTemperature ?? 0.7;
this.defaultMaxTokens = options?.defaultMaxTokens ?? 4096;
}
/**
* 调用LLM进行对话
*/
async chat(params: {
messages: Array<{ role: string; content: string }>;
model?: string;
temperature?: number;
maxTokens?: number;
}): Promise<{
content: string;
thinkingContent?: string;
tokensUsed: number;
model: string;
}> {
// 获取模型类型
const modelType = this.parseModelType(params.model);
// 获取LLM适配器
const adapter = LLMFactory.getAdapter(modelType);
// 转换消息格式
const messages: LLMMessage[] = params.messages.map(m => ({
role: m.role as 'system' | 'user' | 'assistant',
content: m.content,
}));
// 调用LLM
const options: LLMOptions = {
temperature: params.temperature ?? this.defaultTemperature,
maxTokens: params.maxTokens ?? this.defaultMaxTokens,
};
try {
const response = await adapter.chat(messages, options);
// 提取思考内容(如果有)
const { content, thinkingContent } = this.extractThinkingContent(response.content);
return {
content,
thinkingContent,
tokensUsed: response.usage?.totalTokens ?? 0,
model: response.model,
};
} catch (error) {
console.error('[LLMServiceAdapter] chat error:', error);
throw error;
}
}
/**
* 流式调用LLM返回AsyncGenerator
*/
async *chatStream(params: {
messages: Array<{ role: string; content: string }>;
model?: string;
temperature?: number;
maxTokens?: number;
}): AsyncGenerator<{
content: string;
done: boolean;
tokensUsed?: number;
}> {
const modelType = this.parseModelType(params.model);
const adapter = LLMFactory.getAdapter(modelType);
const messages: LLMMessage[] = params.messages.map(m => ({
role: m.role as 'system' | 'user' | 'assistant',
content: m.content,
}));
const options: LLMOptions = {
temperature: params.temperature ?? this.defaultTemperature,
maxTokens: params.maxTokens ?? this.defaultMaxTokens,
stream: true,
};
try {
for await (const chunk of adapter.chatStream(messages, options)) {
yield {
content: chunk.content,
done: chunk.done,
tokensUsed: chunk.usage?.totalTokens,
};
}
} catch (error) {
console.error('[LLMServiceAdapter] chatStream error:', error);
throw error;
}
}
/**
* 解析模型类型
*/
private parseModelType(model?: string): ModelType {
if (!model) return this.defaultModel;
// 映射模型名称到ModelType
const modelMap: Record<string, ModelType> = {
'deepseek-v3': 'deepseek-v3',
'deepseek-chat': 'deepseek-v3',
'qwen-max': 'qwen3-72b',
'qwen3-72b': 'qwen3-72b',
'qwen-long': 'qwen-long',
'gpt-5': 'gpt-5',
'gpt-5-pro': 'gpt-5',
'claude-4.5': 'claude-4.5',
'claude-sonnet': 'claude-4.5',
};
return modelMap[model.toLowerCase()] ?? this.defaultModel;
}
/**
* 从响应中提取思考内容(<think>...</think>
*/
private extractThinkingContent(content: string): {
content: string;
thinkingContent?: string;
} {
// 匹配 <think>...</think> 或 <thinking>...</thinking>
const thinkingPattern = /<(?:think|thinking)>([\s\S]*?)<\/(?:think|thinking)>/gi;
const matches = content.matchAll(thinkingPattern);
let thinkingContent = '';
let cleanContent = content;
for (const match of matches) {
thinkingContent += match[1].trim() + '\n';
cleanContent = cleanContent.replace(match[0], '').trim();
}
return {
content: cleanContent,
thinkingContent: thinkingContent ? thinkingContent.trim() : undefined,
};
}
/**
* 获取支持的模型列表
*/
getSupportedModels(): string[] {
return LLMFactory.getSupportedModels();
}
}
/**
* 创建默认的LLM服务适配器
*/
export function createLLMServiceAdapter(): LLMServiceInterface {
return new LLMServiceAdapter({
defaultModel: 'deepseek-v3',
defaultTemperature: 0.7,
defaultMaxTokens: 4096,
});
}

View File

@@ -0,0 +1,286 @@
/**
* Prompt Builder
* 构建和渲染Protocol Agent的Prompt
*
* @module agent/protocol/services/PromptBuilder
*/
import { ConfigLoader } from '../../services/ConfigLoader.js';
import {
AgentPrompt,
ProtocolContextData,
PromptRenderContext,
} from '../../types/index.js';
export class PromptBuilder {
private configLoader: ConfigLoader;
constructor(configLoader: ConfigLoader) {
this.configLoader = configLoader;
}
/**
* 构建完整的消息列表
*/
async buildMessages(
context: ProtocolContextData,
userMessage: string,
conversationHistory?: Array<{ role: string; content: string }>
): Promise<Array<{ role: string; content: string }>> {
const messages: Array<{ role: string; content: string }> = [];
// 1. 系统Prompt
const systemPrompt = await this.buildSystemPrompt(context);
if (systemPrompt) {
messages.push({ role: 'system', content: systemPrompt });
}
// 2. 阶段Prompt
const stagePrompt = await this.buildStagePrompt(context);
if (stagePrompt) {
messages.push({ role: 'system', content: stagePrompt });
}
// 3. 对话历史最近5轮
if (conversationHistory?.length) {
const recentHistory = conversationHistory.slice(-10); // 最近10条消息5轮
messages.push(...recentHistory);
}
// 4. 用户消息
messages.push({ role: 'user', content: userMessage });
return messages;
}
/**
* 构建系统Prompt
*/
async buildSystemPrompt(context: ProtocolContextData): Promise<string | null> {
const prompt = await this.configLoader.getSystemPrompt('protocol_agent');
if (!prompt) return null;
return this.renderTemplate(prompt.content, { context });
}
/**
* 构建当前阶段Prompt
*/
async buildStagePrompt(context: ProtocolContextData): Promise<string | null> {
const prompt = await this.configLoader.getStagePrompt(
'protocol_agent',
context.currentStage
);
if (!prompt) return null;
return this.renderTemplate(prompt.content, { context });
}
/**
* 构建数据提取Prompt
*/
async buildExtractionPrompt(
context: ProtocolContextData,
userMessage: string
): Promise<string | null> {
const prompt = await this.configLoader.getExtractionPrompt(
'protocol_agent',
context.currentStage
);
if (!prompt) return null;
return this.renderTemplate(prompt.content, {
userMessage,
context,
currentPico: context.pico,
});
}
/**
* 构建方案生成Prompt
*/
async buildGenerationPrompt(context: ProtocolContextData): Promise<string | null> {
const config = await this.configLoader.loadAgentConfig('protocol_agent');
const prompt = config.prompts.find(p => p.promptCode === 'generate_protocol');
if (!prompt) return null;
return this.renderTemplate(prompt.content, {
scientificQuestion: context.scientificQuestion,
pico: context.pico,
studyDesign: context.studyDesign,
sampleSize: context.sampleSize,
endpoints: context.endpoints,
});
}
/**
* 渲染模板
* 支持 {{variable}} 和 {{#if variable}}...{{/if}} 语法
*/
private renderTemplate(
template: string,
variables: Record<string, unknown>
): string {
let result = template;
// 处理 {{#if variable}}...{{/if}} 条件块
result = this.processConditionals(result, variables);
// 处理 {{#each array}}...{{/each}} 循环块
result = this.processEachBlocks(result, variables);
// 处理简单变量替换 {{variable}}
result = this.processVariables(result, variables);
return result;
}
/**
* 处理条件块
*/
private processConditionals(
template: string,
variables: Record<string, unknown>
): string {
const ifPattern = /\{\{#if\s+(\S+)\}\}([\s\S]*?)\{\{\/if\}\}/g;
return template.replace(ifPattern, (match, condition, content) => {
const value = this.getNestedValue(variables, condition);
if (value && value !== false && value !== null && value !== undefined) {
return content;
}
return '';
});
}
/**
* 处理循环块
*/
private processEachBlocks(
template: string,
variables: Record<string, unknown>
): string {
const eachPattern = /\{\{#each\s+(\S+)\}\}([\s\S]*?)\{\{\/each\}\}/g;
return template.replace(eachPattern, (match, arrayPath, content) => {
const array = this.getNestedValue(variables, arrayPath);
if (!Array.isArray(array)) return '';
return array.map((item, index) => {
let itemContent = content;
// 替换 {{this}} 为当前项
itemContent = itemContent.replace(/\{\{this\}\}/g, String(item));
// 替换 {{@index}} 为索引
itemContent = itemContent.replace(/\{\{@index\}\}/g, String(index));
// 替换项属性 {{name}}, {{definition}} 等
if (typeof item === 'object' && item !== null) {
for (const [key, value] of Object.entries(item)) {
const regex = new RegExp(`\\{\\{${key}\\}\\}`, 'g');
itemContent = itemContent.replace(regex, String(value ?? ''));
}
}
return itemContent;
}).join('\n');
});
}
/**
* 处理变量替换
*/
private processVariables(
template: string,
variables: Record<string, unknown>
): string {
const varPattern = /\{\{([^#/][^}]*)\}\}/g;
return template.replace(varPattern, (match, varPath) => {
const value = this.getNestedValue(variables, varPath.trim());
if (value === undefined || value === null) {
return '';
}
if (typeof value === 'object') {
return JSON.stringify(value, null, 2);
}
return String(value);
});
}
/**
* 获取嵌套属性值
*/
private getNestedValue(
obj: Record<string, unknown>,
path: string
): unknown {
const parts = path.split('.');
let current: unknown = obj;
for (const part of parts) {
if (current === null || current === undefined) {
return undefined;
}
if (typeof current !== 'object') {
return undefined;
}
current = (current as Record<string, unknown>)[part];
}
return current;
}
/**
* 构建欢迎消息
*/
buildWelcomeMessage(): string {
return `您好!我是研究方案制定助手,将帮助您系统地完成临床研究方案的核心要素设计。
我们将一起完成以下5个关键步骤
1⃣ **科学问题梳理** - 明确研究要解决的核心问题
2⃣ **PICO要素** - 确定研究人群、干预、对照和结局
3⃣ **研究设计** - 选择合适的研究类型和方法
4⃣ **样本量计算** - 估算所需的样本量
5⃣ **观察指标** - 定义主要和次要结局指标
完成这5个要素后您可以**一键生成完整的研究方案**并下载为Word文档。
让我们开始吧!请先告诉我,您想研究什么问题?或者描述一下您的研究背景和想法。`;
}
/**
* 构建阶段完成消息
*/
buildStageCompleteMessage(
stageName: string,
nextStageName?: string,
isAllCompleted: boolean = false
): string {
if (isAllCompleted) {
return `${stageName}已同步到方案!
🎉 **恭喜您已完成所有5个核心要素的梳理**
您现在可以:
- 点击「🚀 一键生成研究方案」生成完整方案
- 或者回顾修改任何阶段的内容
需要我帮您生成研究方案吗?`;
}
return `${stageName}已同步到方案!
接下来我们进入**${nextStageName}**阶段。准备好了吗?
说"继续"我们就开始,或者您也可以先问我任何问题。`;
}
}

View File

@@ -0,0 +1,289 @@
/**
* Protocol Context Service
* 管理研究方案的上下文数据
*
* @module agent/protocol/services/ProtocolContextService
*/
import { PrismaClient } from '@prisma/client';
import {
ProtocolContextData,
ProtocolStageCode,
ScientificQuestionData,
PICOData,
StudyDesignData,
SampleSizeData,
EndpointsData,
} from '../../types/index.js';
export class ProtocolContextService {
private prisma: PrismaClient;
constructor(prisma: PrismaClient) {
this.prisma = prisma;
}
/**
* 获取或创建上下文
*/
async getOrCreateContext(
conversationId: string,
userId: string
): Promise<ProtocolContextData> {
let context = await this.getContext(conversationId);
if (!context) {
context = await this.createContext(conversationId, userId);
}
return context;
}
/**
* 获取上下文
*/
async getContext(conversationId: string): Promise<ProtocolContextData | null> {
const result = await this.prisma.protocolContext.findUnique({
where: { conversationId },
});
if (!result) return null;
return this.mapToContextData(result);
}
/**
* 创建新上下文
*/
async createContext(
conversationId: string,
userId: string
): Promise<ProtocolContextData> {
const result = await this.prisma.protocolContext.create({
data: {
conversationId,
userId,
currentStage: 'scientific_question',
status: 'in_progress',
completedStages: [],
},
});
return this.mapToContextData(result);
}
/**
* 更新阶段数据
*/
async updateStageData(
conversationId: string,
stageCode: ProtocolStageCode,
data: Record<string, unknown>
): Promise<ProtocolContextData> {
const updateData: Record<string, unknown> = {};
switch (stageCode) {
case 'scientific_question':
updateData.scientificQuestion = data;
break;
case 'pico':
updateData.pico = data;
break;
case 'study_design':
updateData.studyDesign = data;
break;
case 'sample_size':
updateData.sampleSize = data;
break;
case 'endpoints':
updateData.endpoints = data;
break;
}
const result = await this.prisma.protocolContext.update({
where: { conversationId },
data: {
...updateData,
lastActiveAt: new Date(),
},
});
return this.mapToContextData(result);
}
/**
* 标记阶段完成并更新当前阶段
*/
async completeStage(
conversationId: string,
stageCode: ProtocolStageCode,
nextStage?: ProtocolStageCode
): Promise<ProtocolContextData> {
const context = await this.getContext(conversationId);
if (!context) {
throw new Error('Context not found');
}
const completedStages = [...context.completedStages];
if (!completedStages.includes(stageCode)) {
completedStages.push(stageCode);
}
const result = await this.prisma.protocolContext.update({
where: { conversationId },
data: {
completedStages,
currentStage: nextStage ?? context.currentStage,
lastActiveAt: new Date(),
},
});
return this.mapToContextData(result);
}
/**
* 更新当前阶段
*/
async updateCurrentStage(
conversationId: string,
stageCode: ProtocolStageCode
): Promise<ProtocolContextData> {
const result = await this.prisma.protocolContext.update({
where: { conversationId },
data: {
currentStage: stageCode,
lastActiveAt: new Date(),
},
});
return this.mapToContextData(result);
}
/**
* 标记方案完成
*/
async markCompleted(conversationId: string): Promise<ProtocolContextData> {
const result = await this.prisma.protocolContext.update({
where: { conversationId },
data: {
status: 'completed',
lastActiveAt: new Date(),
},
});
return this.mapToContextData(result);
}
/**
* 检查是否所有阶段都已完成
*/
isAllStagesCompleted(context: ProtocolContextData): boolean {
const requiredStages: ProtocolStageCode[] = [
'scientific_question',
'pico',
'study_design',
'sample_size',
'endpoints',
];
return requiredStages.every(stage => context.completedStages.includes(stage));
}
/**
* 获取进度百分比
*/
getProgress(context: ProtocolContextData): number {
const totalStages = 5;
return Math.round((context.completedStages.length / totalStages) * 100);
}
/**
* 获取阶段状态列表
*/
getStagesStatus(context: ProtocolContextData): Array<{
stageCode: ProtocolStageCode;
stageName: string;
status: 'completed' | 'current' | 'pending';
data: Record<string, unknown> | null;
}> {
const stages: Array<{
code: ProtocolStageCode;
name: string;
dataKey: keyof ProtocolContextData;
}> = [
{ code: 'scientific_question', name: '科学问题梳理', dataKey: 'scientificQuestion' },
{ code: 'pico', name: 'PICO要素', dataKey: 'pico' },
{ code: 'study_design', name: '研究设计', dataKey: 'studyDesign' },
{ code: 'sample_size', name: '样本量计算', dataKey: 'sampleSize' },
{ code: 'endpoints', name: '观察指标', dataKey: 'endpoints' },
];
return stages.map(stage => ({
stageCode: stage.code,
stageName: stage.name,
status: context.completedStages.includes(stage.code)
? 'completed' as const
: context.currentStage === stage.code
? 'current' as const
: 'pending' as const,
data: context[stage.dataKey] as unknown as Record<string, unknown> | null ?? null,
}));
}
/**
* 获取用于生成方案的完整数据
*/
getGenerationData(context: ProtocolContextData): {
scientificQuestion: ScientificQuestionData | null;
pico: PICOData | null;
studyDesign: StudyDesignData | null;
sampleSize: SampleSizeData | null;
endpoints: EndpointsData | null;
} {
return {
scientificQuestion: context.scientificQuestion ?? null,
pico: context.pico ?? null,
studyDesign: context.studyDesign ?? null,
sampleSize: context.sampleSize ?? null,
endpoints: context.endpoints ?? null,
};
}
/**
* 将数据库结果映射为上下文数据
*/
private mapToContextData(result: {
id: string;
conversationId: string;
userId: string;
currentStage: string;
status: string;
scientificQuestion: unknown;
pico: unknown;
studyDesign: unknown;
sampleSize: unknown;
endpoints: unknown;
completedStages: string[];
lastActiveAt: Date;
createdAt: Date;
updatedAt: Date;
}): ProtocolContextData {
return {
id: result.id,
conversationId: result.conversationId,
userId: result.userId,
currentStage: result.currentStage as ProtocolStageCode,
status: result.status as ProtocolContextData['status'],
scientificQuestion: result.scientificQuestion as ScientificQuestionData | undefined,
pico: result.pico as PICOData | undefined,
studyDesign: result.studyDesign as StudyDesignData | undefined,
sampleSize: result.sampleSize as SampleSizeData | undefined,
endpoints: result.endpoints as EndpointsData | undefined,
completedStages: result.completedStages as ProtocolStageCode[],
lastActiveAt: result.lastActiveAt,
createdAt: result.createdAt,
updatedAt: result.updatedAt,
};
}
}

View File

@@ -0,0 +1,321 @@
/**
* Protocol Orchestrator
* Protocol Agent的具体实现
*
* @module agent/protocol/services/ProtocolOrchestrator
*/
import { PrismaClient } from '@prisma/client';
import {
BaseAgentOrchestrator,
OrchestratorDependencies,
} from '../../services/BaseAgentOrchestrator.js';
import {
AgentSession,
AgentResponse,
ProtocolContextData,
ProtocolStageCode,
SyncButtonData,
ActionCard,
UserMessageInput,
} from '../../types/index.js';
import { ProtocolContextService } from './ProtocolContextService.js';
import { PromptBuilder } from './PromptBuilder.js';
/** 阶段名称映射 */
const STAGE_NAMES: Record<ProtocolStageCode, string> = {
scientific_question: '科学问题梳理',
pico: 'PICO要素',
study_design: '研究设计',
sample_size: '样本量计算',
endpoints: '观察指标',
};
/** 阶段顺序 */
const STAGE_ORDER: ProtocolStageCode[] = [
'scientific_question',
'pico',
'study_design',
'sample_size',
'endpoints',
];
export class ProtocolOrchestrator extends BaseAgentOrchestrator {
private contextService: ProtocolContextService;
private promptBuilder: PromptBuilder;
constructor(deps: OrchestratorDependencies) {
super(deps);
this.contextService = new ProtocolContextService(deps.prisma);
this.promptBuilder = new PromptBuilder(this.configLoader);
}
/**
* 获取Agent代码
*/
getAgentCode(): string {
return 'protocol_agent';
}
/**
* 覆盖父类handleMessage确保上下文在处理消息前创建
*/
async handleMessage(input: UserMessageInput): Promise<AgentResponse> {
// 确保上下文存在
await this.contextService.getOrCreateContext(input.conversationId, input.userId);
// 调用父类方法处理消息
return super.handleMessage(input);
}
/**
* 获取上下文数据(如果不存在则返回默认结构)
*/
async getContext(conversationId: string): Promise<Record<string, unknown> | null> {
const context = await this.contextService.getContext(conversationId);
if (!context) {
// 返回默认上下文结构,避免 undefined 错误
return {
currentStage: 'scientific_question',
completedStages: [],
status: 'in_progress',
};
}
return context as unknown as Record<string, unknown>;
}
/**
* 保存上下文数据
*/
async saveContext(
conversationId: string,
userId: string,
data: Record<string, unknown>
): Promise<void> {
const context = await this.contextService.getOrCreateContext(conversationId, userId);
const stageCode = context.currentStage;
await this.contextService.updateStageData(conversationId, stageCode, data);
}
/**
* 构建阶段响应
*/
async buildStageResponse(
session: AgentSession,
llmResponse: string,
contextData: Record<string, unknown>
): Promise<AgentResponse> {
const context = contextData as unknown as ProtocolContextData;
const stageCode = session.currentStage as ProtocolStageCode;
const stageName = STAGE_NAMES[stageCode] || session.currentStage;
// 检测是否应该显示同步按钮
const syncButton = this.buildSyncButton(llmResponse, stageCode, context);
// 构建动作卡片
const actionCards = this.buildActionCards(stageCode, context);
return {
content: llmResponse,
stage: stageCode,
stageName,
syncButton,
actionCards,
};
}
/**
* 处理Protocol同步请求
*/
async handleProtocolSync(
conversationId: string,
userId: string,
stageCode: string,
data: Record<string, unknown>
): Promise<{
success: boolean;
context: ProtocolContextData;
nextStage?: ProtocolStageCode;
message?: string;
}> {
const stage = stageCode as ProtocolStageCode;
// 保存阶段数据
await this.contextService.updateStageData(conversationId, stage, {
...data,
confirmed: true,
confirmedAt: new Date(),
});
// 获取下一阶段
const currentIndex = STAGE_ORDER.indexOf(stage);
const nextStage = currentIndex < STAGE_ORDER.length - 1
? STAGE_ORDER[currentIndex + 1]
: undefined;
// 标记当前阶段完成,更新到下一阶段
const context = await this.contextService.completeStage(
conversationId,
stage,
nextStage
);
// 检查是否所有阶段都已完成
const allCompleted = this.contextService.isAllStagesCompleted(context);
return {
success: true,
context,
nextStage,
message: allCompleted
? '🎉 所有核心要素已完成!您可以点击「一键生成研究方案」生成完整方案。'
: nextStage
? `已同步${STAGE_NAMES[stage]},进入${STAGE_NAMES[nextStage]}阶段`
: `已同步${STAGE_NAMES[stage]}`,
};
}
/**
* 获取Protocol上下文服务
*/
getContextService(): ProtocolContextService {
return this.contextService;
}
/**
* 构建同步按钮数据
*/
private buildSyncButton(
llmResponse: string,
stageCode: ProtocolStageCode,
context: ProtocolContextData
): SyncButtonData | undefined {
// 检测LLM响应中是否有已整理好的数据
// 这里用简单的关键词检测,实际可以用更复杂的方式
const readyPatterns = [
'整理',
'总结',
'您的科学问题',
'您的PICO',
'您的研究设计',
'样本量',
'观察指标',
'同步到方案',
];
const hasReadyData = readyPatterns.some(p => llmResponse.includes(p));
if (!hasReadyData) {
return undefined;
}
// 确保 completedStages 存在
const completedStages = context.completedStages || [];
// 检查当前阶段是否已完成
if (completedStages.includes(stageCode)) {
return {
stageCode,
extractedData: {},
label: '已同步',
disabled: true,
};
}
return {
stageCode,
extractedData: this.extractDataFromResponse(llmResponse, stageCode),
label: '✅ 同步到方案',
disabled: false,
};
}
/**
* 从LLM响应中提取结构化数据
*/
private extractDataFromResponse(
response: string,
stageCode: ProtocolStageCode
): Record<string, unknown> {
// 尝试从响应中提取JSON格式的数据
const jsonMatch = response.match(/<extracted_data>([\s\S]*?)<\/extracted_data>/);
if (jsonMatch) {
try {
return JSON.parse(jsonMatch[1]);
} catch {
// 解析失败,继续使用默认逻辑
}
}
// 简单提取实际应该用LLM来提取
switch (stageCode) {
case 'scientific_question':
return { content: response.substring(0, 500), readyToSync: true };
default:
return { readyToSync: true };
}
}
/**
* 构建动作卡片
*/
private buildActionCards(
stageCode: ProtocolStageCode,
context: ProtocolContextData
): ActionCard[] {
const cards: ActionCard[] = [];
// 样本量阶段:添加样本量计算器卡片
if (stageCode === 'sample_size') {
cards.push({
id: 'sample_size_calculator',
type: 'tool',
title: '📊 样本量计算器',
description: '使用专业计算器进行样本量估算',
actionUrl: '/tools/sample-size-calculator',
});
}
// 所有阶段完成后添加一键生成按钮确保context有效
if (context.completedStages && this.contextService.isAllStagesCompleted(context)) {
cards.push({
id: 'generate_protocol',
type: 'action',
title: '🚀 一键生成研究方案',
description: '基于5个核心要素生成完整研究方案',
actionUrl: '/api/aia/protocol-agent/generate',
});
}
return cards;
}
/**
* 获取上下文状态摘要用于前端State Panel
*/
async getContextSummary(conversationId: string): Promise<{
currentStage: string;
stageName: string;
progress: number;
stages: Array<{
stageCode: string;
stageName: string;
status: 'completed' | 'current' | 'pending';
data: Record<string, unknown> | null;
}>;
canGenerate: boolean;
} | null> {
const context = await this.contextService.getContext(conversationId);
if (!context) return null;
return {
currentStage: context.currentStage,
stageName: STAGE_NAMES[context.currentStage] || context.currentStage,
progress: this.contextService.getProgress(context),
stages: this.contextService.getStagesStatus(context),
canGenerate: this.contextService.isAllStagesCompleted(context),
};
}
}

View File

@@ -0,0 +1,11 @@
/**
* Protocol Agent Services Export
*
* @module agent/protocol/services
*/
export { ProtocolContextService } from './ProtocolContextService.js';
export { ProtocolOrchestrator } from './ProtocolOrchestrator.js';
export { PromptBuilder } from './PromptBuilder.js';
export { LLMServiceAdapter, createLLMServiceAdapter } from './LLMServiceAdapter.js';

View File

@@ -0,0 +1,561 @@
/**
* Base Agent Orchestrator
* Agent框架的核心编排器抽象类
*
* @module agent/services/BaseAgentOrchestrator
*/
import { PrismaClient } from '@prisma/client';
import { v4 as uuidv4 } from 'uuid';
import {
AgentSession,
AgentResponse,
UserMessageInput,
StageTransitionResult,
IntentAnalysis,
AgentFullConfig,
AgentStage,
} from '../types/index.js';
import { ConfigLoader } from './ConfigLoader.js';
import { QueryAnalyzer } from './QueryAnalyzer.js';
import { StageManager } from './StageManager.js';
import { TraceLogger } from './TraceLogger.js';
/** Orchestrator依赖注入 */
export interface OrchestratorDependencies {
prisma: PrismaClient;
llmService: LLMServiceInterface;
}
/** LLM服务接口由具体实现提供 */
export interface LLMServiceInterface {
chat(params: {
messages: Array<{ role: string; content: string }>;
model?: string;
temperature?: number;
maxTokens?: number;
}): Promise<{
content: string;
thinkingContent?: string;
tokensUsed: number;
model: string;
}>;
}
/**
* 抽象基类Agent编排器
* 子类需要实现具体Agent的逻辑
*/
export abstract class BaseAgentOrchestrator {
protected prisma: PrismaClient;
protected llmService: LLMServiceInterface;
protected configLoader: ConfigLoader;
protected queryAnalyzer: QueryAnalyzer;
protected stageManager: StageManager;
protected traceLogger: TraceLogger;
protected config: AgentFullConfig | null = null;
constructor(deps: OrchestratorDependencies) {
this.prisma = deps.prisma;
this.llmService = deps.llmService;
this.configLoader = new ConfigLoader(deps.prisma);
this.queryAnalyzer = new QueryAnalyzer(deps.llmService);
this.stageManager = new StageManager(deps.prisma);
this.traceLogger = new TraceLogger(deps.prisma);
}
/**
* 获取Agent唯一标识子类必须实现
*/
abstract getAgentCode(): string;
/**
* 获取上下文数据(子类必须实现)
*/
abstract getContext(conversationId: string): Promise<Record<string, unknown> | null>;
/**
* 保存上下文数据(子类必须实现)
*/
abstract saveContext(
conversationId: string,
userId: string,
data: Record<string, unknown>
): Promise<void>;
/**
* 构建阶段响应(子类必须实现)
*/
abstract buildStageResponse(
session: AgentSession,
llmResponse: string,
context: Record<string, unknown>
): Promise<AgentResponse>;
/**
* 初始化配置
*/
protected async ensureConfig(): Promise<AgentFullConfig> {
if (!this.config) {
this.config = await this.configLoader.loadAgentConfig(this.getAgentCode());
}
return this.config;
}
/**
* 处理用户消息 - 主入口
*/
async handleMessage(input: UserMessageInput): Promise<AgentResponse> {
const traceId = uuidv4();
let stepIndex = 0;
try {
await this.ensureConfig();
// 1. 获取或创建会话
const session = await this.getOrCreateSession(input.conversationId, input.userId);
// 2. 记录输入追踪
await this.traceLogger.log({
sessionId: session.id,
traceId,
stepIndex: stepIndex++,
stepType: 'query',
input: { content: input.content },
stageCode: session.currentStage,
});
// 3. 意图识别
const intent = await this.analyzeIntent(input.content, session);
await this.traceLogger.log({
sessionId: session.id,
traceId,
stepIndex: stepIndex++,
stepType: 'plan',
input: { userMessage: input.content },
output: { intent },
stageCode: session.currentStage,
});
// 4. 根据意图处理
const context = await this.getContext(input.conversationId) || {};
// 5. 执行LLM对话
const llmResponse = await this.executeDialogue(session, input.content, context, intent);
await this.traceLogger.log({
sessionId: session.id,
traceId,
stepIndex: stepIndex++,
stepType: 'execute',
input: { intent, context },
output: { response: llmResponse.content.substring(0, 500) },
stageCode: session.currentStage,
modelUsed: llmResponse.model,
tokensUsed: llmResponse.tokensUsed,
});
// 6. 更新会话统计
await this.updateSessionStats(session.id, llmResponse.tokensUsed);
// 7. 构建响应
const response = await this.buildStageResponse(session, llmResponse.content, context);
response.thinkingContent = llmResponse.thinkingContent;
response.tokensUsed = llmResponse.tokensUsed;
response.modelUsed = llmResponse.model;
return response;
} catch (error) {
// 记录错误
const session = await this.getSession(input.conversationId);
if (session) {
await this.traceLogger.log({
sessionId: session.id,
traceId,
stepIndex,
stepType: 'execute',
stageCode: session.currentStage,
errorType: 'execution_error',
errorMsg: error instanceof Error ? error.message : String(error),
});
}
throw error;
}
}
/**
* 处理同步请求
*/
async handleSync(
conversationId: string,
userId: string,
stageCode: string,
data: Record<string, unknown>
): Promise<StageTransitionResult> {
const traceId = uuidv4();
await this.ensureConfig();
const session = await this.getSession(conversationId);
if (!session) {
throw new Error('Session not found');
}
// 记录同步操作
await this.traceLogger.log({
sessionId: session.id,
traceId,
stepIndex: 0,
stepType: 'sync',
input: { stageCode, data },
stageCode: session.currentStage,
});
// 保存上下文数据
await this.saveContext(conversationId, userId, {
...data,
confirmed: true,
confirmedAt: new Date(),
});
// 执行Reflexion检查如果启用
const reflexionResults = await this.runReflexion(session, stageCode, 'on_sync', data);
// 检查是否有阻塞性错误
const blockingError = reflexionResults.find(r => !r.passed && r.severity === 'error');
if (blockingError) {
return {
success: false,
fromStage: session.currentStage,
toStage: session.currentStage,
message: blockingError.message,
reflexionResults,
};
}
// 获取下一阶段
const nextStage = await this.configLoader.getNextStage(this.getAgentCode(), stageCode);
if (nextStage) {
// 更新会话状态
await this.updateSessionStage(session.id, nextStage.stageCode);
}
return {
success: true,
fromStage: session.currentStage,
toStage: nextStage?.stageCode ?? session.currentStage,
message: nextStage ? `进入${nextStage.stageName}阶段` : '所有阶段已完成',
reflexionResults,
};
}
/**
* 获取会话
*/
protected async getSession(conversationId: string): Promise<AgentSession | null> {
const result = await this.prisma.agentSession.findUnique({
where: { conversationId },
});
if (!result) return null;
return {
id: result.id,
agentId: result.agentId,
conversationId: result.conversationId,
userId: result.userId,
currentStage: result.currentStage,
status: result.status as AgentSession['status'],
contextRef: result.contextRef ?? undefined,
turnCount: result.turnCount,
totalTokens: result.totalTokens,
createdAt: result.createdAt,
updatedAt: result.updatedAt,
};
}
/**
* 获取或创建会话
*/
protected async getOrCreateSession(
conversationId: string,
userId: string
): Promise<AgentSession> {
let session = await this.getSession(conversationId);
if (!session) {
const config = await this.ensureConfig();
const initialStage = config.stages.find((s: AgentStage) => s.isInitial);
const result = await this.prisma.agentSession.create({
data: {
agentId: config.definition.id,
conversationId,
userId,
currentStage: initialStage?.stageCode ?? config.stages[0]?.stageCode ?? 'unknown',
status: 'active',
},
});
session = {
id: result.id,
agentId: result.agentId,
conversationId: result.conversationId,
userId: result.userId,
currentStage: result.currentStage,
status: result.status as AgentSession['status'],
contextRef: result.contextRef ?? undefined,
turnCount: result.turnCount,
totalTokens: result.totalTokens,
createdAt: result.createdAt,
updatedAt: result.updatedAt,
};
}
return session;
}
/**
* 分析用户意图
*/
protected async analyzeIntent(
userMessage: string,
session: AgentSession
): Promise<IntentAnalysis> {
// 简单的意图识别(后续可增强)
const lowerMessage = userMessage.toLowerCase();
// 检测阶段切换意图
if (lowerMessage.includes('继续') ||
lowerMessage.includes('下一步') ||
lowerMessage.includes('next')) {
return {
intent: 'proceed_next_stage',
confidence: 0.9,
entities: {},
suggestedAction: 'stage_transition',
};
}
// 检测同步意图
if (lowerMessage.includes('确认') ||
lowerMessage.includes('同步') ||
lowerMessage.includes('保存')) {
return {
intent: 'confirm_sync',
confidence: 0.85,
entities: {},
suggestedAction: 'sync_data',
};
}
// 默认为对话意图
return {
intent: 'dialogue',
confidence: 0.7,
entities: {},
suggestedAction: 'continue_dialogue',
};
}
/**
* 执行对话
*/
protected async executeDialogue(
session: AgentSession,
userMessage: string,
context: Record<string, unknown>,
intent: IntentAnalysis
): Promise<{
content: string;
thinkingContent?: string;
tokensUsed: number;
model: string;
}> {
const config = await this.ensureConfig();
// 获取系统Prompt
const systemPrompt = await this.configLoader.getSystemPrompt(this.getAgentCode());
// 获取当前阶段Prompt
const stagePrompt = await this.configLoader.getStagePrompt(
this.getAgentCode(),
session.currentStage
);
// 构建消息
const messages: Array<{ role: string; content: string }> = [];
if (systemPrompt) {
messages.push({
role: 'system',
content: this.renderPrompt(systemPrompt.content, { context, intent }),
});
}
if (stagePrompt) {
messages.push({
role: 'system',
content: this.renderPrompt(stagePrompt.content, { context, intent }),
});
}
messages.push({
role: 'user',
content: userMessage,
});
// 调用LLM
return await this.llmService.chat({
messages,
model: config.definition.config?.defaultModel,
temperature: 0.7,
});
}
/**
* 渲染Prompt模板
*/
protected renderPrompt(
template: string,
variables: Record<string, unknown>
): string {
let result = template;
// 简单的变量替换 {{variable}}
for (const [key, value] of Object.entries(variables)) {
const regex = new RegExp(`\\{\\{${key}\\}\\}`, 'g');
result = result.replace(regex, JSON.stringify(value, null, 2));
}
return result;
}
/**
* 运行Reflexion检查
*/
protected async runReflexion(
session: AgentSession,
stageCode: string,
timing: 'on_sync' | 'on_stage_complete' | 'on_generate',
data: Record<string, unknown>
): Promise<Array<{
ruleCode: string;
ruleName: string;
passed: boolean;
severity: 'error' | 'warning' | 'info';
message?: string;
}>> {
const rules = await this.configLoader.getReflexionRules(
this.getAgentCode(),
stageCode,
timing
);
const results = [];
for (const rule of rules) {
try {
let passed = true;
let message: string | undefined;
if (rule.ruleType === 'rule_based' && rule.conditions) {
// 基于规则的检查
passed = this.evaluateRuleConditions(rule.conditions, data);
if (!passed) {
message = `${rule.ruleName}检查未通过`;
}
} else if (rule.ruleType === 'prompt_based' && rule.promptTemplate) {
// 基于Prompt的检查简化版实际应调用LLM
// TODO: 实现LLM-based检查
passed = true;
}
results.push({
ruleCode: rule.ruleCode,
ruleName: rule.ruleName,
passed,
severity: rule.severity,
message,
});
} catch (error) {
results.push({
ruleCode: rule.ruleCode,
ruleName: rule.ruleName,
passed: false,
severity: 'warning' as const,
message: `规则执行出错: ${error instanceof Error ? error.message : String(error)}`,
});
}
}
return results;
}
/**
* 评估规则条件
*/
protected evaluateRuleConditions(
conditions: Record<string, unknown>,
data: Record<string, unknown>
): boolean {
// 简单的规则评估逻辑
for (const [field, requirement] of Object.entries(conditions)) {
const value = data[field];
if (requirement === 'required' && !value) {
return false;
}
if (typeof requirement === 'object' && requirement !== null) {
const req = requirement as Record<string, unknown>;
if (req.minLength && typeof value === 'string' && value.length < (req.minLength as number)) {
return false;
}
if (req.notEmpty && (!value || (Array.isArray(value) && value.length === 0))) {
return false;
}
}
}
return true;
}
/**
* 更新会话统计
*/
protected async updateSessionStats(sessionId: string, tokensUsed: number): Promise<void> {
await this.prisma.agentSession.update({
where: { id: sessionId },
data: {
turnCount: { increment: 1 },
totalTokens: { increment: tokensUsed },
},
});
}
/**
* 更新会话阶段
*/
protected async updateSessionStage(sessionId: string, stageCode: string): Promise<void> {
await this.prisma.agentSession.update({
where: { id: sessionId },
data: { currentStage: stageCode },
});
}
/**
* 获取当前阶段信息
*/
protected async getCurrentStageInfo(stageCode: string): Promise<AgentStage | null> {
const config = await this.ensureConfig();
return config.stages.find((s: AgentStage) => s.stageCode === stageCode) ?? null;
}
}

View File

@@ -0,0 +1,266 @@
/**
* Agent Configuration Loader
* 从数据库加载Agent配置
*
* @module agent/services/ConfigLoader
*/
import { PrismaClient } from '@prisma/client';
import {
AgentDefinition,
AgentStage,
AgentPrompt,
ReflexionRule,
AgentFullConfig,
AgentConfig,
} from '../types/index.js';
export class ConfigLoader {
private prisma: PrismaClient;
// 配置缓存
private configCache: Map<string, { config: AgentFullConfig; loadedAt: Date }> = new Map();
private cacheTTL: number = 5 * 60 * 1000; // 5分钟缓存
constructor(prisma: PrismaClient) {
this.prisma = prisma;
}
/**
* 加载Agent完整配置
* @param agentCode Agent代码如 'protocol_agent'
* @param useCache 是否使用缓存
*/
async loadAgentConfig(agentCode: string, useCache: boolean = true): Promise<AgentFullConfig> {
// 检查缓存
if (useCache) {
const cached = this.configCache.get(agentCode);
if (cached && Date.now() - cached.loadedAt.getTime() < this.cacheTTL) {
return cached.config;
}
}
// 从数据库加载
const definition = await this.loadDefinition(agentCode);
if (!definition) {
throw new Error(`Agent not found: ${agentCode}`);
}
const [stages, prompts, reflexionRules] = await Promise.all([
this.loadStages(definition.id),
this.loadPrompts(definition.id),
this.loadReflexionRules(definition.id),
]);
const config: AgentFullConfig = {
definition,
stages,
prompts,
reflexionRules,
};
// 更新缓存
this.configCache.set(agentCode, { config, loadedAt: new Date() });
return config;
}
/**
* 加载Agent定义
*/
private async loadDefinition(agentCode: string): Promise<AgentDefinition | null> {
const result = await this.prisma.agentDefinition.findUnique({
where: { code: agentCode },
});
if (!result) return null;
return {
id: result.id,
code: result.code,
name: result.name,
description: result.description ?? undefined,
version: result.version,
config: result.config as unknown as AgentConfig | undefined,
isActive: result.isActive,
createdAt: result.createdAt,
updatedAt: result.updatedAt,
};
}
/**
* 加载Agent阶段配置
*/
private async loadStages(agentId: string): Promise<AgentStage[]> {
const results = await this.prisma.agentStage.findMany({
where: { agentId },
orderBy: { sortOrder: 'asc' },
});
return results.map((r) => ({
id: r.id,
agentId: r.agentId,
stageCode: r.stageCode,
stageName: r.stageName,
sortOrder: r.sortOrder,
config: r.config as Record<string, unknown> | undefined,
nextStages: r.nextStages,
isInitial: r.isInitial,
isFinal: r.isFinal,
}));
}
/**
* 加载Agent Prompt模板
*/
private async loadPrompts(agentId: string): Promise<AgentPrompt[]> {
const results = await this.prisma.agentPrompt.findMany({
where: {
agentId,
isActive: true,
},
});
return results.map((r) => ({
id: r.id,
agentId: r.agentId,
stageId: r.stageId ?? undefined,
promptType: r.promptType as AgentPrompt['promptType'],
promptCode: r.promptCode,
content: r.content,
variables: r.variables,
version: r.version,
isActive: r.isActive,
}));
}
/**
* 加载Reflexion规则
*/
private async loadReflexionRules(agentId: string): Promise<ReflexionRule[]> {
const results = await this.prisma.reflexionRule.findMany({
where: {
agentId,
isActive: true,
},
orderBy: { sortOrder: 'asc' },
});
return results.map((r) => ({
id: r.id,
agentId: r.agentId,
ruleCode: r.ruleCode,
ruleName: r.ruleName,
triggerStage: r.triggerStage ?? undefined,
triggerTiming: r.triggerTiming as ReflexionRule['triggerTiming'],
ruleType: r.ruleType as ReflexionRule['ruleType'],
conditions: r.conditions as Record<string, unknown> | undefined,
promptTemplate: r.promptTemplate ?? undefined,
severity: r.severity as ReflexionRule['severity'],
failureAction: r.failureAction as ReflexionRule['failureAction'],
isActive: r.isActive,
sortOrder: r.sortOrder,
}));
}
/**
* 获取指定阶段的Prompt
*/
async getStagePrompt(agentCode: string, stageCode: string): Promise<AgentPrompt | null> {
const config = await this.loadAgentConfig(agentCode);
// 先找阶段特定的Prompt
const stage = config.stages.find((s: AgentStage) => s.stageCode === stageCode);
if (stage) {
const stagePrompt = config.prompts.find(
(p: AgentPrompt) => p.stageId === stage.id && p.promptType === 'stage'
);
if (stagePrompt) return stagePrompt;
}
return null;
}
/**
* 获取系统Prompt
*/
async getSystemPrompt(agentCode: string): Promise<AgentPrompt | null> {
const config = await this.loadAgentConfig(agentCode);
return config.prompts.find(p => p.promptType === 'system' && !p.stageId) ?? null;
}
/**
* 获取提取Prompt
*/
async getExtractionPrompt(agentCode: string, stageCode: string): Promise<AgentPrompt | null> {
const config = await this.loadAgentConfig(agentCode);
const stage = config.stages.find(s => s.stageCode === stageCode);
if (stage) {
return config.prompts.find(
p => p.stageId === stage.id && p.promptType === 'extraction'
) ?? null;
}
return null;
}
/**
* 获取阶段的Reflexion规则
*/
async getReflexionRules(
agentCode: string,
stageCode: string,
timing: ReflexionRule['triggerTiming']
): Promise<ReflexionRule[]> {
const config = await this.loadAgentConfig(agentCode);
return config.reflexionRules.filter(
r => (r.triggerStage === stageCode || r.triggerStage === null) &&
r.triggerTiming === timing
);
}
/**
* 清除缓存
*/
clearCache(agentCode?: string): void {
if (agentCode) {
this.configCache.delete(agentCode);
} else {
this.configCache.clear();
}
}
/**
* 获取阶段流程顺序
*/
async getStageFlow(agentCode: string): Promise<AgentStage[]> {
const config = await this.loadAgentConfig(agentCode);
return config.stages.sort((a, b) => a.sortOrder - b.sortOrder);
}
/**
* 获取下一个阶段
*/
async getNextStage(agentCode: string, currentStageCode: string): Promise<AgentStage | null> {
const stages = await this.getStageFlow(agentCode);
const currentIndex = stages.findIndex(s => s.stageCode === currentStageCode);
if (currentIndex >= 0 && currentIndex < stages.length - 1) {
return stages[currentIndex + 1];
}
return null;
}
/**
* 检查是否为最终阶段
*/
async isFinalStage(agentCode: string, stageCode: string): Promise<boolean> {
const config = await this.loadAgentConfig(agentCode);
const stage = config.stages.find(s => s.stageCode === stageCode);
return stage?.isFinal ?? false;
}
}

View File

@@ -0,0 +1,270 @@
/**
* Query Analyzer
* 用户意图识别和查询分析
*
* @module agent/services/QueryAnalyzer
*/
import { IntentAnalysis } from '../types/index.js';
import { LLMServiceInterface } from './BaseAgentOrchestrator.js';
/** 意图类型定义 */
export type IntentType =
| 'dialogue' // 普通对话
| 'proceed_next_stage' // 进入下一阶段
| 'confirm_sync' // 确认同步
| 'edit_data' // 编辑已有数据
| 'ask_question' // 提问
| 'clarification' // 澄清/补充信息
| 'generate_protocol' // 生成研究方案
| 'export_document'; // 导出文档
/** 意图识别规则 */
interface IntentRule {
type: IntentType;
keywords: string[];
patterns?: RegExp[];
confidence: number;
}
export class QueryAnalyzer {
private llmService: LLMServiceInterface;
// 预定义的意图规则
private intentRules: IntentRule[] = [
{
type: 'proceed_next_stage',
keywords: ['继续', '下一步', '下一个', 'next', '进入下一', '开始下一'],
confidence: 0.9,
},
{
type: 'confirm_sync',
keywords: ['确认', '同步', '保存', '确定', 'confirm', 'save'],
confidence: 0.85,
},
{
type: 'edit_data',
keywords: ['修改', '编辑', '更改', '调整', '更新', 'edit', 'modify'],
confidence: 0.8,
},
{
type: 'generate_protocol',
keywords: ['生成方案', '生成研究方案', '一键生成', '生成全文'],
confidence: 0.95,
},
{
type: 'export_document',
keywords: ['导出', '下载', '保存word', '保存文档', 'export', 'download'],
confidence: 0.9,
},
{
type: 'ask_question',
keywords: ['什么是', '如何', '怎么', '为什么', '是否', '?', ''],
patterns: [/^(什么|如何|怎么|为什么|是否|能不能|可以|请问)/],
confidence: 0.7,
},
{
type: 'clarification',
keywords: ['补充', '另外', '还有', '再说', '顺便'],
confidence: 0.6,
},
];
constructor(llmService: LLMServiceInterface) {
this.llmService = llmService;
}
/**
* 分析用户输入的意图
*/
async analyze(
userMessage: string,
context?: {
currentStage?: string;
completedStages?: string[];
conversationHistory?: Array<{ role: string; content: string }>;
}
): Promise<IntentAnalysis> {
// 1. 基于规则的快速识别
const ruleBasedResult = this.ruleBasedAnalysis(userMessage);
if (ruleBasedResult.confidence >= 0.85) {
return ruleBasedResult;
}
// 2. 上下文增强
const contextEnhanced = this.contextEnhancedAnalysis(
userMessage,
ruleBasedResult,
context
);
// 3. 如果信心度仍然较低可以考虑使用LLM此处简化
// TODO: 实现LLM-based意图识别
return contextEnhanced;
}
/**
* 基于规则的意图分析
*/
private ruleBasedAnalysis(userMessage: string): IntentAnalysis {
const lowerMessage = userMessage.toLowerCase();
const entities: Record<string, unknown> = {};
// 遍历规则匹配
for (const rule of this.intentRules) {
// 关键词匹配
const keywordMatch = rule.keywords.some(keyword =>
lowerMessage.includes(keyword.toLowerCase())
);
// 正则匹配
const patternMatch = rule.patterns?.some(pattern =>
pattern.test(userMessage)
);
if (keywordMatch || patternMatch) {
// 提取实体
this.extractEntities(userMessage, rule.type, entities);
return {
intent: rule.type,
confidence: keywordMatch && patternMatch ?
Math.min(rule.confidence + 0.1, 1) : rule.confidence,
entities,
suggestedAction: this.getSuggestedAction(rule.type),
};
}
}
// 默认为对话意图
return {
intent: 'dialogue',
confidence: 0.5,
entities,
suggestedAction: 'continue_dialogue',
};
}
/**
* 上下文增强分析
*/
private contextEnhancedAnalysis(
userMessage: string,
baseResult: IntentAnalysis,
context?: {
currentStage?: string;
completedStages?: string[];
conversationHistory?: Array<{ role: string; content: string }>;
}
): IntentAnalysis {
if (!context) return baseResult;
const result = { ...baseResult };
// 根据当前阶段调整意图
if (context.currentStage) {
// 如果已完成所有阶段,倾向于生成方案
if (context.completedStages?.length === 5) {
if (userMessage.includes('方案') || userMessage.includes('生成')) {
result.intent = 'generate_protocol';
result.confidence = Math.max(result.confidence, 0.8);
result.suggestedAction = 'generate_protocol';
}
}
}
// 根据对话历史调整
if (context.conversationHistory?.length) {
const lastAssistantMsg = context.conversationHistory
.filter(m => m.role === 'assistant')
.pop();
// 如果上一条AI消息询问是否继续
if (lastAssistantMsg?.content.includes('继续') ||
lastAssistantMsg?.content.includes('下一阶段')) {
if (userMessage.match(/^(好|是|对|可以|行|ok|yes)/i)) {
result.intent = 'proceed_next_stage';
result.confidence = Math.max(result.confidence, 0.85);
result.suggestedAction = 'stage_transition';
}
}
}
return result;
}
/**
* 提取实体
*/
private extractEntities(
message: string,
intentType: IntentType,
entities: Record<string, unknown>
): void {
// 阶段名称实体
const stagePatterns = [
{ pattern: /科学问题/, entity: 'scientific_question' },
{ pattern: /PICO|pico/, entity: 'pico' },
{ pattern: /研究设计/, entity: 'study_design' },
{ pattern: /样本量/, entity: 'sample_size' },
{ pattern: /(终点|结局|观察)指标/, entity: 'endpoints' },
];
for (const { pattern, entity } of stagePatterns) {
if (pattern.test(message)) {
entities['targetStage'] = entity;
break;
}
}
// 数字实体
const numberMatch = message.match(/(\d+)/g);
if (numberMatch) {
entities['numbers'] = numberMatch.map(n => parseInt(n, 10));
}
}
/**
* 获取建议的动作
*/
private getSuggestedAction(intentType: IntentType): string {
const actionMap: Record<IntentType, string> = {
'dialogue': 'continue_dialogue',
'proceed_next_stage': 'stage_transition',
'confirm_sync': 'sync_data',
'edit_data': 'edit_context',
'ask_question': 'answer_question',
'clarification': 'request_clarification',
'generate_protocol': 'generate_protocol',
'export_document': 'export_document',
};
return actionMap[intentType] ?? 'continue_dialogue';
}
/**
* 检测用户是否表达了同意/肯定
*/
isAffirmative(message: string): boolean {
const affirmativePatterns = [
/^(好|是|对|可以|行|ok|yes|yeah|sure|当然|没问题|同意|确认)/i,
/^(嗯|恩|en|um)/i,
];
return affirmativePatterns.some(p => p.test(message.trim()));
}
/**
* 检测用户是否表达了否定
*/
isNegative(message: string): boolean {
const negativePatterns = [
/^(不|否|no|nope|算了|取消|等等|暂时不|还没)/i,
];
return negativePatterns.some(p => p.test(message.trim()));
}
}

View File

@@ -0,0 +1,283 @@
/**
* Stage Manager
* Agent阶段状态管理
*
* @module agent/services/StageManager
*/
import { PrismaClient } from '@prisma/client';
import {
AgentSession,
AgentStage,
StageTransitionResult,
ProtocolStageCode,
} from '../types/index.js';
/** 阶段转换条件 */
export interface TransitionCondition {
fromStage: string;
toStage: string;
conditions: Record<string, unknown>;
}
export class StageManager {
private prisma: PrismaClient;
constructor(prisma: PrismaClient) {
this.prisma = prisma;
}
/**
* 执行阶段转换
*/
async transition(
session: AgentSession,
targetStage: string,
stages: AgentStage[]
): Promise<StageTransitionResult> {
const currentStageConfig = stages.find(s => s.stageCode === session.currentStage);
const targetStageConfig = stages.find(s => s.stageCode === targetStage);
if (!targetStageConfig) {
return {
success: false,
fromStage: session.currentStage,
toStage: targetStage,
message: `目标阶段 ${targetStage} 不存在`,
};
}
// 检查是否允许转换
if (currentStageConfig && !currentStageConfig.nextStages.includes(targetStage)) {
// 检查是否是返回之前的阶段(允许回退)
const targetOrder = targetStageConfig.sortOrder;
const currentOrder = currentStageConfig.sortOrder;
if (targetOrder > currentOrder) {
return {
success: false,
fromStage: session.currentStage,
toStage: targetStage,
message: `不能从 ${currentStageConfig.stageName} 直接跳转到 ${targetStageConfig.stageName}`,
};
}
}
// 更新会话
await this.prisma.agentSession.update({
where: { id: session.id },
data: { currentStage: targetStage },
});
return {
success: true,
fromStage: session.currentStage,
toStage: targetStage,
message: `已进入${targetStageConfig.stageName}阶段`,
};
}
/**
* 获取下一个阶段
*/
getNextStage(
currentStageCode: string,
stages: AgentStage[]
): AgentStage | null {
const sortedStages = [...stages].sort((a, b) => a.sortOrder - b.sortOrder);
const currentIndex = sortedStages.findIndex(s => s.stageCode === currentStageCode);
if (currentIndex >= 0 && currentIndex < sortedStages.length - 1) {
return sortedStages[currentIndex + 1];
}
return null;
}
/**
* 获取上一个阶段
*/
getPreviousStage(
currentStageCode: string,
stages: AgentStage[]
): AgentStage | null {
const sortedStages = [...stages].sort((a, b) => a.sortOrder - b.sortOrder);
const currentIndex = sortedStages.findIndex(s => s.stageCode === currentStageCode);
if (currentIndex > 0) {
return sortedStages[currentIndex - 1];
}
return null;
}
/**
* 检查阶段是否完成
*/
isStageCompleted(
stageCode: ProtocolStageCode,
completedStages: ProtocolStageCode[]
): boolean {
return completedStages.includes(stageCode);
}
/**
* 检查是否所有阶段都已完成
*/
areAllStagesCompleted(
stages: AgentStage[],
completedStages: string[]
): boolean {
const requiredStages = stages
.filter(s => !s.isFinal)
.map(s => s.stageCode);
return requiredStages.every(code => completedStages.includes(code));
}
/**
* 获取当前进度百分比
*/
getProgressPercentage(
stages: AgentStage[],
completedStages: string[]
): number {
const totalStages = stages.filter(s => !s.isFinal).length;
const completed = completedStages.length;
return totalStages > 0 ? Math.round((completed / totalStages) * 100) : 0;
}
/**
* 获取阶段状态摘要
*/
getStagesSummary(
stages: AgentStage[],
currentStageCode: string,
completedStages: string[]
): Array<{
stageCode: string;
stageName: string;
status: 'completed' | 'current' | 'pending';
order: number;
}> {
return stages
.sort((a, b) => a.sortOrder - b.sortOrder)
.map(stage => ({
stageCode: stage.stageCode,
stageName: stage.stageName,
status: completedStages.includes(stage.stageCode)
? 'completed' as const
: stage.stageCode === currentStageCode
? 'current' as const
: 'pending' as const,
order: stage.sortOrder,
}));
}
/**
* 验证阶段转换是否有效
*/
validateTransition(
fromStage: string,
toStage: string,
stages: AgentStage[],
completedStages: string[]
): { valid: boolean; reason?: string } {
const fromStageConfig = stages.find(s => s.stageCode === fromStage);
const toStageConfig = stages.find(s => s.stageCode === toStage);
if (!fromStageConfig || !toStageConfig) {
return { valid: false, reason: '阶段配置不存在' };
}
// 允许回退到已完成的阶段
if (completedStages.includes(toStage)) {
return { valid: true };
}
// 前向转换:检查当前阶段是否已完成
if (toStageConfig.sortOrder > fromStageConfig.sortOrder) {
if (!completedStages.includes(fromStage)) {
return {
valid: false,
reason: `请先完成${fromStageConfig.stageName}阶段`
};
}
}
// 检查是否是允许的转换
if (!fromStageConfig.nextStages.includes(toStage)) {
// 如果是回退,允许
if (toStageConfig.sortOrder < fromStageConfig.sortOrder) {
return { valid: true };
}
return {
valid: false,
reason: `不能从${fromStageConfig.stageName}转换到${toStageConfig.stageName}`
};
}
return { valid: true };
}
/**
* 获取可用的下一步操作
*/
getAvailableActions(
currentStageCode: string,
stages: AgentStage[],
completedStages: string[]
): Array<{
action: string;
label: string;
targetStage?: string;
enabled: boolean;
}> {
const currentStage = stages.find(s => s.stageCode === currentStageCode);
const actions = [];
// 同步当前阶段数据
if (currentStage && !completedStages.includes(currentStageCode)) {
actions.push({
action: 'sync',
label: '同步到方案',
enabled: true,
});
}
// 进入下一阶段
const nextStage = this.getNextStage(currentStageCode, stages);
if (nextStage) {
actions.push({
action: 'next_stage',
label: `进入${nextStage.stageName}`,
targetStage: nextStage.stageCode,
enabled: completedStages.includes(currentStageCode),
});
}
// 生成研究方案(仅当所有阶段完成时)
if (this.areAllStagesCompleted(stages, completedStages)) {
actions.push({
action: 'generate_protocol',
label: '一键生成研究方案',
enabled: true,
});
}
// 返回上一阶段
const prevStage = this.getPreviousStage(currentStageCode, stages);
if (prevStage) {
actions.push({
action: 'prev_stage',
label: `返回${prevStage.stageName}`,
targetStage: prevStage.stageCode,
enabled: true,
});
}
return actions;
}
}

View File

@@ -0,0 +1,349 @@
/**
* Trace Logger
* Agent执行追踪日志记录
*
* @module agent/services/TraceLogger
*/
import { PrismaClient } from '@prisma/client';
import { AgentTraceRecord, StepType } from '../types/index.js';
/** 追踪日志输入 */
export interface TraceLogInput {
sessionId: string;
traceId: string;
stepIndex: number;
stepType: StepType;
input?: Record<string, unknown>;
output?: Record<string, unknown>;
stageCode?: string;
modelUsed?: string;
tokensUsed?: number;
durationMs?: number;
errorType?: string;
errorMsg?: string;
}
export class TraceLogger {
private prisma: PrismaClient;
private enabled: boolean = true;
constructor(prisma: PrismaClient) {
this.prisma = prisma;
}
/**
* 启用/禁用追踪
*/
setEnabled(enabled: boolean): void {
this.enabled = enabled;
}
/**
* 记录追踪日志
*/
async log(input: TraceLogInput): Promise<AgentTraceRecord | null> {
if (!this.enabled) {
return null;
}
try {
const result = await this.prisma.agentTrace.create({
data: {
sessionId: input.sessionId,
traceId: input.traceId,
stepIndex: input.stepIndex,
stepType: input.stepType,
input: input.input ? JSON.parse(JSON.stringify(input.input)) : undefined,
output: input.output ? JSON.parse(JSON.stringify(input.output)) : undefined,
stageCode: input.stageCode ?? undefined,
modelUsed: input.modelUsed ?? undefined,
tokensUsed: input.tokensUsed ?? undefined,
durationMs: input.durationMs ?? undefined,
errorType: input.errorType ?? undefined,
errorMsg: input.errorMsg ?? undefined,
},
});
return {
id: result.id,
sessionId: result.sessionId,
traceId: result.traceId,
stepIndex: result.stepIndex,
stepType: result.stepType as StepType,
input: result.input as Record<string, unknown> | undefined,
output: result.output as Record<string, unknown> | undefined,
stageCode: result.stageCode ?? undefined,
modelUsed: result.modelUsed ?? undefined,
tokensUsed: result.tokensUsed ?? undefined,
durationMs: result.durationMs ?? undefined,
errorType: result.errorType ?? undefined,
errorMsg: result.errorMsg ?? undefined,
createdAt: result.createdAt,
};
} catch (error) {
console.error('[TraceLogger] Failed to log trace:', error);
return null;
}
}
/**
* 批量记录追踪日志
*/
async logBatch(inputs: TraceLogInput[]): Promise<number> {
if (!this.enabled || inputs.length === 0) {
return 0;
}
try {
const result = await this.prisma.agentTrace.createMany({
data: inputs.map(input => ({
sessionId: input.sessionId,
traceId: input.traceId,
stepIndex: input.stepIndex,
stepType: input.stepType,
input: input.input ? JSON.parse(JSON.stringify(input.input)) : undefined,
output: input.output ? JSON.parse(JSON.stringify(input.output)) : undefined,
stageCode: input.stageCode ?? undefined,
modelUsed: input.modelUsed ?? undefined,
tokensUsed: input.tokensUsed ?? undefined,
durationMs: input.durationMs ?? undefined,
errorType: input.errorType ?? undefined,
errorMsg: input.errorMsg ?? undefined,
})),
});
return result.count;
} catch (error) {
console.error('[TraceLogger] Failed to batch log traces:', error);
return 0;
}
}
/**
* 获取会话的所有追踪记录
*/
async getSessionTraces(sessionId: string): Promise<AgentTraceRecord[]> {
const results = await this.prisma.agentTrace.findMany({
where: { sessionId },
orderBy: [
{ traceId: 'asc' },
{ stepIndex: 'asc' },
],
});
return results.map(r => ({
id: r.id,
sessionId: r.sessionId,
traceId: r.traceId,
stepIndex: r.stepIndex,
stepType: r.stepType as StepType,
input: r.input as Record<string, unknown> | undefined,
output: r.output as Record<string, unknown> | undefined,
stageCode: r.stageCode ?? undefined,
modelUsed: r.modelUsed ?? undefined,
tokensUsed: r.tokensUsed ?? undefined,
durationMs: r.durationMs ?? undefined,
errorType: r.errorType ?? undefined,
errorMsg: r.errorMsg ?? undefined,
createdAt: r.createdAt,
}));
}
/**
* 获取单个请求的追踪记录
*/
async getTraceById(traceId: string): Promise<AgentTraceRecord[]> {
const results = await this.prisma.agentTrace.findMany({
where: { traceId },
orderBy: { stepIndex: 'asc' },
});
return results.map(r => ({
id: r.id,
sessionId: r.sessionId,
traceId: r.traceId,
stepIndex: r.stepIndex,
stepType: r.stepType as StepType,
input: r.input as Record<string, unknown> | undefined,
output: r.output as Record<string, unknown> | undefined,
stageCode: r.stageCode ?? undefined,
modelUsed: r.modelUsed ?? undefined,
tokensUsed: r.tokensUsed ?? undefined,
durationMs: r.durationMs ?? undefined,
errorType: r.errorType ?? undefined,
errorMsg: r.errorMsg ?? undefined,
createdAt: r.createdAt,
}));
}
/**
* 获取会话的错误追踪
*/
async getSessionErrors(sessionId: string): Promise<AgentTraceRecord[]> {
const results = await this.prisma.agentTrace.findMany({
where: {
sessionId,
errorType: { not: null },
},
orderBy: { createdAt: 'desc' },
});
return results.map(r => ({
id: r.id,
sessionId: r.sessionId,
traceId: r.traceId,
stepIndex: r.stepIndex,
stepType: r.stepType as StepType,
input: r.input as Record<string, unknown> | undefined,
output: r.output as Record<string, unknown> | undefined,
stageCode: r.stageCode ?? undefined,
modelUsed: r.modelUsed ?? undefined,
tokensUsed: r.tokensUsed ?? undefined,
durationMs: r.durationMs ?? undefined,
errorType: r.errorType ?? undefined,
errorMsg: r.errorMsg ?? undefined,
createdAt: r.createdAt,
}));
}
/**
* 计算会话的Token统计
*/
async getSessionTokenStats(sessionId: string): Promise<{
totalTokens: number;
byModel: Record<string, number>;
byStepType: Record<string, number>;
}> {
const traces = await this.prisma.agentTrace.findMany({
where: { sessionId },
select: {
tokensUsed: true,
modelUsed: true,
stepType: true,
},
});
const stats = {
totalTokens: 0,
byModel: {} as Record<string, number>,
byStepType: {} as Record<string, number>,
};
for (const trace of traces) {
const tokens = trace.tokensUsed ?? 0;
stats.totalTokens += tokens;
if (trace.modelUsed) {
stats.byModel[trace.modelUsed] = (stats.byModel[trace.modelUsed] ?? 0) + tokens;
}
stats.byStepType[trace.stepType] = (stats.byStepType[trace.stepType] ?? 0) + tokens;
}
return stats;
}
/**
* 计算平均响应时间
*/
async getAverageResponseTime(sessionId: string): Promise<number> {
const result = await this.prisma.agentTrace.aggregate({
where: {
sessionId,
durationMs: { not: null },
},
_avg: {
durationMs: true,
},
});
return result._avg.durationMs ?? 0;
}
/**
* 删除旧的追踪记录(清理)
*/
async cleanupOldTraces(daysToKeep: number = 30): Promise<number> {
const cutoffDate = new Date();
cutoffDate.setDate(cutoffDate.getDate() - daysToKeep);
const result = await this.prisma.agentTrace.deleteMany({
where: {
createdAt: { lt: cutoffDate },
},
});
return result.count;
}
/**
* 生成追踪报告
*/
async generateTraceReport(sessionId: string): Promise<{
sessionId: string;
totalSteps: number;
totalTokens: number;
totalDuration: number;
stageBreakdown: Array<{
stage: string;
steps: number;
tokens: number;
duration: number;
}>;
errors: Array<{
traceId: string;
stepType: string;
errorType: string;
errorMsg: string;
}>;
}> {
const traces = await this.getSessionTraces(sessionId);
const stageMap = new Map<string, { steps: number; tokens: number; duration: number }>();
let totalTokens = 0;
let totalDuration = 0;
const errors: Array<{
traceId: string;
stepType: string;
errorType: string;
errorMsg: string;
}> = [];
for (const trace of traces) {
totalTokens += trace.tokensUsed ?? 0;
totalDuration += trace.durationMs ?? 0;
if (trace.stageCode) {
const existing = stageMap.get(trace.stageCode) ?? { steps: 0, tokens: 0, duration: 0 };
stageMap.set(trace.stageCode, {
steps: existing.steps + 1,
tokens: existing.tokens + (trace.tokensUsed ?? 0),
duration: existing.duration + (trace.durationMs ?? 0),
});
}
if (trace.errorType) {
errors.push({
traceId: trace.traceId,
stepType: trace.stepType,
errorType: trace.errorType,
errorMsg: trace.errorMsg ?? '',
});
}
}
return {
sessionId,
totalSteps: traces.length,
totalTokens,
totalDuration,
stageBreakdown: Array.from(stageMap.entries()).map(([stage, stats]) => ({
stage,
...stats,
})),
errors,
};
}
}

View File

@@ -0,0 +1,12 @@
/**
* Agent Services Export
*
* @module agent/services
*/
export { ConfigLoader } from './ConfigLoader.js';
export { BaseAgentOrchestrator, type OrchestratorDependencies, type LLMServiceInterface } from './BaseAgentOrchestrator.js';
export { QueryAnalyzer, type IntentType } from './QueryAnalyzer.js';
export { StageManager, type TransitionCondition } from './StageManager.js';
export { TraceLogger, type TraceLogInput } from './TraceLogger.js';

View File

@@ -0,0 +1,382 @@
/**
* Protocol Agent Framework - Core Types
* 通用Agent框架类型定义
*
* @module agent/types
*/
// ============================================================
// 基础类型
// ============================================================
/** Agent状态枚举 */
export type AgentSessionStatus = 'active' | 'completed' | 'paused' | 'error';
/** 阶段状态枚举 */
export type StageStatus = 'pending' | 'in_progress' | 'completed' | 'skipped';
/** 执行步骤类型 */
export type StepType = 'query' | 'plan' | 'execute' | 'reflect' | 'tool_call' | 'sync';
/** Prompt类型 */
export type PromptType = 'system' | 'stage' | 'extraction' | 'reflexion' | 'generation';
/** Reflexion规则类型 */
export type ReflexionRuleType = 'rule_based' | 'prompt_based';
/** 规则严重性 */
export type ReflexionSeverity = 'error' | 'warning' | 'info';
/** 失败处理动作 */
export type FailureAction = 'block' | 'warn' | 'log';
// ============================================================
// Agent定义相关
// ============================================================
/** Agent全局配置 */
export interface AgentConfig {
defaultModel: string;
maxTurns: number;
timeout: number;
enableTrace: boolean;
enableReflexion: boolean;
}
/** Agent定义 */
export interface AgentDefinition {
id: string;
code: string;
name: string;
description?: string;
version: string;
config?: AgentConfig;
isActive: boolean;
createdAt: Date;
updatedAt: Date;
}
/** Agent阶段定义 */
export interface AgentStage {
id: string;
agentId: string;
stageCode: string;
stageName: string;
sortOrder: number;
config?: Record<string, unknown>;
nextStages: string[];
isInitial: boolean;
isFinal: boolean;
}
/** Agent Prompt定义 */
export interface AgentPrompt {
id: string;
agentId: string;
stageId?: string;
promptType: PromptType;
promptCode: string;
content: string;
variables: string[];
version: number;
isActive: boolean;
}
// ============================================================
// 会话与追踪
// ============================================================
/** Agent会话 */
export interface AgentSession {
id: string;
agentId: string;
conversationId: string;
userId: string;
currentStage: string;
status: AgentSessionStatus;
contextRef?: string;
turnCount: number;
totalTokens: number;
createdAt: Date;
updatedAt: Date;
}
/** 执行追踪记录 */
export interface AgentTraceRecord {
id: string;
sessionId: string;
traceId: string;
stepIndex: number;
stepType: StepType;
input?: Record<string, unknown>;
output?: Record<string, unknown>;
stageCode?: string;
modelUsed?: string;
tokensUsed?: number;
durationMs?: number;
errorType?: string;
errorMsg?: string;
createdAt: Date;
}
// ============================================================
// Reflexion规则
// ============================================================
/** Reflexion规则定义 */
export interface ReflexionRule {
id: string;
agentId: string;
ruleCode: string;
ruleName: string;
triggerStage?: string;
triggerTiming: 'on_sync' | 'on_stage_complete' | 'on_generate';
ruleType: ReflexionRuleType;
conditions?: Record<string, unknown>;
promptTemplate?: string;
severity: ReflexionSeverity;
failureAction: FailureAction;
isActive: boolean;
sortOrder: number;
}
/** Reflexion检查结果 */
export interface ReflexionResult {
ruleCode: string;
ruleName: string;
passed: boolean;
severity: ReflexionSeverity;
message?: string;
details?: Record<string, unknown>;
}
// ============================================================
// Protocol Agent专用类型
// ============================================================
/** Protocol Context状态 */
export type ProtocolContextStatus = 'in_progress' | 'completed' | 'abandoned';
/** Protocol阶段代码 */
export type ProtocolStageCode =
| 'scientific_question'
| 'pico'
| 'study_design'
| 'sample_size'
| 'endpoints';
/** 科学问题数据 */
export interface ScientificQuestionData {
content: string;
background?: string;
significance?: string;
confirmed: boolean;
confirmedAt?: Date;
}
/** PICO要素 */
export interface PICOElement {
value: string;
details?: string;
}
/** PICO数据 */
export interface PICOData {
P: PICOElement; // Population
I: PICOElement; // Intervention
C: PICOElement; // Comparison
O: PICOElement; // Outcome
confirmed: boolean;
confirmedAt?: Date;
}
/** 研究设计数据 */
export interface StudyDesignData {
type: string; // RCT, Cohort, Case-Control, etc.
blinding?: string; // Open, Single-blind, Double-blind
randomization?: string; // Simple, Block, Stratified
duration?: string; // 研究周期
multiCenter?: boolean;
centerCount?: number;
confirmed: boolean;
confirmedAt?: Date;
}
/** 样本量数据 */
export interface SampleSizeData {
total: number;
perGroup?: number;
alpha?: number; // 显著性水平
power?: number; // 统计效力
effectSize?: number; // 效应量
dropoutRate?: number; // 脱落率
justification?: string; // 计算依据
confirmed: boolean;
confirmedAt?: Date;
}
/** 终点指标项 */
export interface EndpointItem {
name: string;
definition?: string;
method?: string;
timePoint?: string;
}
/** 终点指标数据 */
export interface EndpointsData {
primary: EndpointItem[];
secondary: EndpointItem[];
safety: EndpointItem[];
exploratory?: EndpointItem[];
confirmed: boolean;
confirmedAt?: Date;
}
/** Protocol上下文完整数据 */
export interface ProtocolContextData {
id: string;
conversationId: string;
userId: string;
currentStage: ProtocolStageCode;
status: ProtocolContextStatus;
// 5个核心阶段数据
scientificQuestion?: ScientificQuestionData;
pico?: PICOData;
studyDesign?: StudyDesignData;
sampleSize?: SampleSizeData;
endpoints?: EndpointsData;
// 元数据
completedStages: ProtocolStageCode[];
lastActiveAt: Date;
createdAt: Date;
updatedAt: Date;
}
/** Protocol生成记录 */
export interface ProtocolGenerationRecord {
id: string;
contextId: string;
userId: string;
generatedContent: string;
contentVersion: number;
promptUsed: string;
modelUsed: string;
tokensUsed?: number;
durationMs?: number;
wordFileKey?: string;
lastExportedAt?: Date;
status: 'generating' | 'completed' | 'failed';
errorMessage?: string;
createdAt: Date;
updatedAt: Date;
}
// ============================================================
// 服务接口类型
// ============================================================
/** 用户消息输入 */
export interface UserMessageInput {
conversationId: string;
userId: string;
content: string;
messageId?: string;
attachments?: unknown[];
}
/** Agent响应 */
export interface AgentResponse {
content: string;
thinkingContent?: string;
// 元数据
stage: string;
stageName: string;
// 动作卡片
actionCards?: ActionCard[];
// 同步按钮
syncButton?: SyncButtonData;
// Token统计
tokensUsed?: number;
modelUsed?: string;
}
/** 动作卡片 */
export interface ActionCard {
id: string;
type: string;
title: string;
description?: string;
actionUrl?: string;
actionParams?: Record<string, unknown>;
}
/** 同步按钮数据 */
export interface SyncButtonData {
stageCode: string;
extractedData: Record<string, unknown>;
label: string;
disabled?: boolean;
}
/** 同步请求 */
export interface SyncRequest {
conversationId: string;
userId: string;
stageCode: ProtocolStageCode;
data: Record<string, unknown>;
}
/** 同步响应 */
export interface SyncResponse {
success: boolean;
context: ProtocolContextData;
nextStage?: ProtocolStageCode;
message?: string;
}
/** 意图识别结果 */
export interface IntentAnalysis {
intent: string;
confidence: number;
entities: Record<string, unknown>;
suggestedAction?: string;
}
/** 阶段转换结果 */
export interface StageTransitionResult {
success: boolean;
fromStage: string;
toStage: string;
message?: string;
reflexionResults?: ReflexionResult[];
}
// ============================================================
// 配置加载相关
// ============================================================
/** Agent完整配置含所有关联数据 */
export interface AgentFullConfig {
definition: AgentDefinition;
stages: AgentStage[];
prompts: AgentPrompt[];
reflexionRules: ReflexionRule[];
}
/** Prompt渲染上下文 */
export interface PromptRenderContext {
stage: string;
context: ProtocolContextData;
conversationHistory?: Array<{ role: string; content: string }>;
userMessage: string;
additionalVariables?: Record<string, unknown>;
}

View File

@@ -244,3 +244,5 @@ async function matchIntent(query: string): Promise<{

View File

@@ -98,3 +98,5 @@ export async function uploadAttachment(

View File

@@ -27,3 +27,5 @@ export { aiaRoutes };

View File

@@ -368,6 +368,8 @@ runTests().catch((error) => {

View File

@@ -347,6 +347,8 @@ Content-Type: application/json

View File

@@ -283,6 +283,8 @@ export const conflictDetectionService = new ConflictDetectionService();

View File

@@ -233,6 +233,8 @@ curl -X POST http://localhost:3000/api/v1/dc/tool-c/test/execute \

View File

@@ -287,6 +287,8 @@ export const streamAIController = new StreamAIController();

View File

@@ -196,6 +196,8 @@ logger.info('[SessionMemory] 会话记忆管理器已启动', {

View File

@@ -130,6 +130,8 @@ checkTableStructure();

View File

@@ -117,6 +117,8 @@ checkProjectConfig().catch(console.error);

View File

@@ -99,6 +99,8 @@ main();

View File

@@ -556,6 +556,8 @@ URL: https://iit.xunzhengyixue.com/api/v1/iit/patient-wechat/callback

View File

@@ -191,6 +191,8 @@ console.log('');

View File

@@ -508,6 +508,8 @@ export const patientWechatService = new PatientWechatService();

View File

@@ -153,6 +153,8 @@ testDifyIntegration().catch(error => {

View File

@@ -182,6 +182,8 @@ testIitDatabase()

View File

@@ -168,6 +168,8 @@ if (hasError) {

View File

@@ -194,6 +194,8 @@ async function testUrlVerification() {

View File

@@ -275,6 +275,8 @@ main().catch((error) => {

View File

@@ -159,6 +159,8 @@ Write-Host ""

View File

@@ -252,6 +252,8 @@ export interface CachedProtocolRules {

View File

@@ -65,6 +65,8 @@ export default async function healthRoutes(fastify: FastifyInstance) {

View File

@@ -143,6 +143,8 @@ Content-Type: application/json

View File

@@ -128,6 +128,8 @@ Write-Host " - 删除任务: DELETE $BaseUrl/api/v1/rvw/tasks/{taskId}" -Foregr

View File

@@ -42,6 +42,8 @@ export * from './services/utils.js';

View File

@@ -133,6 +133,8 @@ export function validateAgentSelection(agents: string[]): void {

View File

@@ -433,6 +433,8 @@ SET session_replication_role = 'origin';

View File

@@ -114,3 +114,5 @@ testCrossLanguageSearch();

View File

@@ -176,3 +176,5 @@ testQueryRewrite();

View File

@@ -122,3 +122,5 @@ testRerank();

View File

@@ -135,6 +135,8 @@ WHERE key = 'verify_test';

View File

@@ -278,6 +278,8 @@ verifyDatabase()

View File

@@ -68,6 +68,8 @@ export {}