feat(aia): Implement Protocol Agent MVP with reusable Agent framework

Sprint 1-3 Completed (Backend + Frontend):

Backend (Sprint 1-2):
- Implement 5-layer Agent framework (Query->Planner->Executor->Tools->Reflection)
- Create agent_schema with 6 tables (agent_definitions, stages, prompts, sessions, traces, reflexion_rules)
- Create protocol_schema with 2 tables (protocol_contexts, protocol_generations)
- Implement Protocol Agent core services (Orchestrator, ContextService, PromptBuilder)
- Integrate LLM service adapter (DeepSeek/Qwen/GPT-5/Claude)
- 6 API endpoints with full authentication
- 10/10 API tests passed

Frontend (Sprint 3):
- Add Protocol Agent entry in AgentHub (indigo theme card)
- Implement ProtocolAgentPage with 3-column layout
- Collapsible sidebar (Gemini style, 48px <-> 280px)
- StatePanel with 5 stage cards (scientific_question, pico, study_design, sample_size, endpoints)
- ChatArea with sync button and action cards integration
- 100% prototype design restoration (608 lines CSS)
- Detailed endpoints structure: baseline, exposure, outcomes, confounders

Features:
- 5-stage dialogue flow for research protocol design
- Conversation-driven interaction with sync-to-protocol button
- Real-time context state management
- One-click protocol generation button (UI ready, backend pending)

Database:
- agent_schema: 6 tables for reusable Agent framework
- protocol_schema: 2 tables for Protocol Agent
- Seed data: 1 agent + 5 stages + 9 prompts + 4 reflexion rules

Code Stats:
- Backend: 13 files, 4338 lines
- Frontend: 14 files, 2071 lines
- Total: 27 files, 6409 lines

Status: MVP core functionality completed, pending frontend-backend integration testing

Next: Sprint 4 - One-click protocol generation + Word export
This commit is contained in:
2026-01-24 17:29:24 +08:00
parent 61cdc97eeb
commit 96290d2f76
345 changed files with 13945 additions and 47 deletions

View File

@@ -58,6 +58,8 @@ Status: Day 1 complete (11/11 tasks), ready for Day 2

View File

@@ -288,6 +288,8 @@

View File

@@ -52,3 +52,5 @@ COPY docker-init-extensions.sql /docker-entrypoint-initdb.d/
# 暴露端口
EXPOSE 5432

View File

@@ -234,6 +234,8 @@ https://iit.xunzhengyixue.com/api/v1/iit/health

View File

@@ -163,6 +163,8 @@ https://iit.xunzhengyixue.com/api/v1/iit/health

View File

@@ -64,6 +64,8 @@

View File

@@ -324,6 +324,8 @@ npx tsx src/modules/iit-manager/test-patient-wechat-url-verify.ts

View File

@@ -186,6 +186,8 @@ npm run dev

View File

@@ -64,5 +64,7 @@ main()

View File

@@ -58,5 +58,7 @@ main()

View File

@@ -53,5 +53,7 @@ main()

View File

@@ -85,5 +85,7 @@ main()

View File

@@ -48,5 +48,7 @@ main()

View File

@@ -89,5 +89,7 @@ main()

View File

@@ -36,5 +36,7 @@ main()

View File

@@ -124,5 +124,7 @@ main()

View File

@@ -95,5 +95,7 @@ main()

View File

@@ -81,5 +81,7 @@ main()

View File

@@ -123,5 +123,7 @@ main()

View File

@@ -34,5 +34,7 @@ ON CONFLICT (id) DO NOTHING;

View File

@@ -66,5 +66,7 @@ ON CONFLICT (id) DO NOTHING;

View File

@@ -78,3 +78,5 @@ OSS_SIGNED_URL_EXPIRES=3600
```

View File

@@ -83,6 +83,8 @@ WHERE table_schema = 'dc_schema'

View File

@@ -121,6 +121,8 @@ ORDER BY ordinal_position;

View File

@@ -134,6 +134,8 @@ runMigration()

View File

@@ -68,6 +68,8 @@ COMMENT ON COLUMN "dc_schema"."dc_tool_c_sessions"."column_mapping" IS '列名

View File

@@ -95,6 +95,8 @@ COMMENT ON COLUMN dc_schema.dc_tool_c_sessions.expires_at IS '过期时间(创

View File

@@ -66,3 +66,5 @@ USING gin (metadata jsonb_path_ops);

View File

@@ -33,3 +33,5 @@ USING gin (tags);

View File

@@ -6,7 +6,7 @@ generator client {
datasource db {
provider = "postgresql"
url = env("DATABASE_URL")
schemas = ["admin_schema", "aia_schema", "asl_schema", "capability_schema", "common_schema", "dc_schema", "ekb_schema", "iit_schema", "pkb_schema", "platform_schema", "public", "rvw_schema", "ssa_schema", "st_schema"]
schemas = ["admin_schema", "agent_schema", "aia_schema", "asl_schema", "capability_schema", "common_schema", "dc_schema", "ekb_schema", "iit_schema", "pkb_schema", "platform_schema", "protocol_schema", "public", "rvw_schema", "ssa_schema", "st_schema"]
}
/// 应用缓存表 - Postgres-Only架构
@@ -1393,3 +1393,303 @@ model EkbChunk {
@@map("ekb_chunk")
@@schema("ekb_schema")
}
// ============================================================
// Agent Framework Schema (agent_schema)
// 通用Agent框架 - 可复用于多种Agent类型
// ============================================================
/// Agent定义表 - 存储Agent的基本配置
model AgentDefinition {
id String @id @default(uuid())
code String @unique /// 唯一标识: protocol_agent, stat_agent
name String /// 显示名称
description String? /// 描述
version String @default("1.0.0") /// 版本号
/// Agent配置
config Json? @db.JsonB /// 全局配置 { defaultModel, maxTurns, timeout }
/// 状态
isActive Boolean @default(true) @map("is_active")
/// 关联
stages AgentStage[]
prompts AgentPrompt[]
sessions AgentSession[]
reflexionRules ReflexionRule[]
createdAt DateTime @default(now()) @map("created_at")
updatedAt DateTime @updatedAt @map("updated_at")
@@index([code], map: "idx_agent_def_code")
@@index([isActive], map: "idx_agent_def_active")
@@map("agent_definitions")
@@schema("agent_schema")
}
/// Agent阶段配置表 - 定义Agent的工作流阶段
model AgentStage {
id String @id @default(uuid())
agentId String @map("agent_id") /// 关联的Agent
/// 阶段标识
stageCode String @map("stage_code") /// 阶段代码: scientific_question, pico
stageName String @map("stage_name") /// 阶段名称: 科学问题梳理
sortOrder Int @map("sort_order") /// 排序顺序
/// 阶段配置
config Json? @db.JsonB /// 阶段特定配置
/// 状态机配置
nextStages String[] @map("next_stages") /// 可转换的下一阶段列表
isInitial Boolean @default(false) @map("is_initial") /// 是否为起始阶段
isFinal Boolean @default(false) @map("is_final") /// 是否为结束阶段
/// 关联
agent AgentDefinition @relation(fields: [agentId], references: [id], onDelete: Cascade)
prompts AgentPrompt[]
createdAt DateTime @default(now()) @map("created_at")
updatedAt DateTime @updatedAt @map("updated_at")
@@unique([agentId, stageCode], map: "unique_agent_stage")
@@index([agentId], map: "idx_agent_stage_agent")
@@index([sortOrder], map: "idx_agent_stage_order")
@@map("agent_stages")
@@schema("agent_schema")
}
/// Agent Prompt模板表 - 存储各阶段的Prompt
model AgentPrompt {
id String @id @default(uuid())
agentId String @map("agent_id") /// 关联的Agent
stageId String? @map("stage_id") /// 关联的阶段可选null表示通用Prompt
/// Prompt标识
promptType String @map("prompt_type") /// system, stage, extraction, reflexion
promptCode String @map("prompt_code") /// 唯一代码
/// Prompt内容
content String @db.Text /// Prompt模板内容支持变量
variables String[] /// 预期变量列表
/// 版本控制
version Int @default(1) /// 版本号
isActive Boolean @default(true) @map("is_active")
/// 关联
agent AgentDefinition @relation(fields: [agentId], references: [id], onDelete: Cascade)
stage AgentStage? @relation(fields: [stageId], references: [id], onDelete: SetNull)
createdAt DateTime @default(now()) @map("created_at")
updatedAt DateTime @updatedAt @map("updated_at")
@@unique([agentId, promptCode, version], map: "unique_agent_prompt_version")
@@index([agentId, promptType], map: "idx_agent_prompt_type")
@@index([stageId], map: "idx_agent_prompt_stage")
@@map("agent_prompts")
@@schema("agent_schema")
}
/// Agent会话表 - 存储Agent的运行时会话状态
model AgentSession {
id String @id @default(uuid())
agentId String @map("agent_id") /// 关联的Agent定义
conversationId String @map("conversation_id") /// 关联的对话IDaia_schema.conversations
userId String @map("user_id") /// 用户ID
/// 当前状态
currentStage String @map("current_stage") /// 当前阶段代码
status String @default("active") /// active, completed, paused, error
/// 上下文数据具体Agent的上下文存储在对应schema
contextRef String? @map("context_ref") /// 上下文引用ID如protocol_schema.protocol_contexts.id
/// 统计信息
turnCount Int @default(0) @map("turn_count") /// 对话轮数
totalTokens Int @default(0) @map("total_tokens") /// 总Token数
/// 关联
agent AgentDefinition @relation(fields: [agentId], references: [id])
traces AgentTrace[]
createdAt DateTime @default(now()) @map("created_at")
updatedAt DateTime @updatedAt @map("updated_at")
@@unique([conversationId], map: "unique_agent_session_conv")
@@index([agentId], map: "idx_agent_session_agent")
@@index([userId], map: "idx_agent_session_user")
@@index([status], map: "idx_agent_session_status")
@@map("agent_sessions")
@@schema("agent_schema")
}
/// Agent执行追踪表 - 记录每一步的执行详情
model AgentTrace {
id String @id @default(uuid())
sessionId String @map("session_id") /// 关联的会话
/// 追踪信息
traceId String @map("trace_id") /// 请求追踪ID用于关联日志
stepIndex Int @map("step_index") /// 步骤序号
stepType String @map("step_type") /// query, plan, execute, reflect, tool_call
/// 输入输出
input Json? @db.JsonB /// 步骤输入
output Json? @db.JsonB /// 步骤输出
/// 执行信息
stageCode String? @map("stage_code") /// 执行时的阶段
modelUsed String? @map("model_used") /// 使用的模型
tokensUsed Int? @map("tokens_used") /// 消耗的Token
durationMs Int? @map("duration_ms") /// 执行时长(毫秒)
/// 错误信息
errorType String? @map("error_type") /// 错误类型
errorMsg String? @map("error_msg") /// 错误信息
/// 关联
session AgentSession @relation(fields: [sessionId], references: [id], onDelete: Cascade)
createdAt DateTime @default(now()) @map("created_at")
@@index([sessionId, stepIndex], map: "idx_agent_trace_session_step")
@@index([traceId], map: "idx_agent_trace_trace_id")
@@index([stepType], map: "idx_agent_trace_step_type")
@@map("agent_traces")
@@schema("agent_schema")
}
/// Reflexion规则表 - 定义质量检查规则
model ReflexionRule {
id String @id @default(uuid())
agentId String @map("agent_id") /// 关联的Agent
/// 规则标识
ruleCode String @map("rule_code") /// 规则代码
ruleName String @map("rule_name") /// 规则名称
/// 触发条件
triggerStage String? @map("trigger_stage") /// 触发阶段null表示全局
triggerTiming String @map("trigger_timing") /// on_sync, on_stage_complete, on_generate
/// 规则类型
ruleType String @map("rule_type") /// rule_based, prompt_based
/// 规则内容
conditions Json? @db.JsonB /// 规则条件rule_based时使用
promptTemplate String? @map("prompt_template") @db.Text /// Prompt模板prompt_based时使用
/// 行为配置
severity String @default("warning") /// error, warning, info
failureAction String @default("warn") @map("failure_action") /// block, warn, log
/// 状态
isActive Boolean @default(true) @map("is_active")
sortOrder Int @default(0) @map("sort_order")
/// 关联
agent AgentDefinition @relation(fields: [agentId], references: [id], onDelete: Cascade)
createdAt DateTime @default(now()) @map("created_at")
updatedAt DateTime @updatedAt @map("updated_at")
@@unique([agentId, ruleCode], map: "unique_agent_rule")
@@index([agentId, triggerStage], map: "idx_reflexion_rule_agent_stage")
@@index([isActive], map: "idx_reflexion_rule_active")
@@map("reflexion_rules")
@@schema("agent_schema")
}
// ============================================================
// Protocol Agent Schema (protocol_schema)
// Protocol Agent专用 - 研究方案制定
// ============================================================
/// Protocol Context表 - 存储研究方案的核心上下文数据
model ProtocolContext {
id String @id @default(uuid())
conversationId String @unique @map("conversation_id") /// 关联的对话ID
userId String @map("user_id")
/// 当前状态
currentStage String @default("scientific_question") @map("current_stage")
status String @default("in_progress") /// in_progress, completed, abandoned
/// ===== 5个核心阶段数据 =====
/// 阶段1: 科学问题
scientificQuestion Json? @map("scientific_question") @db.JsonB
/// { content, background, significance, confirmed, confirmedAt }
/// 阶段2: PICO
pico Json? @db.JsonB
/// { P: {value, details}, I: {}, C: {}, O: {}, confirmed, confirmedAt }
/// 阶段3: 研究设计
studyDesign Json? @map("study_design") @db.JsonB
/// { type, blinding, randomization, duration, multiCenter, confirmed }
/// 阶段4: 样本量
sampleSize Json? @map("sample_size") @db.JsonB
/// { total, perGroup, alpha, power, effectSize, dropoutRate, justification, confirmed }
/// 阶段5: 观察指标(终点指标)
endpoints Json? @db.JsonB
/// { primary: [{name, definition, method, timePoint}], secondary: [], safety: [], confirmed }
/// ===== 元数据 =====
completedStages String[] @default([]) @map("completed_stages") /// 已完成的阶段列表
lastActiveAt DateTime @default(now()) @map("last_active_at")
/// 关联
generations ProtocolGeneration[]
createdAt DateTime @default(now()) @map("created_at")
updatedAt DateTime @updatedAt @map("updated_at")
@@index([userId], map: "idx_protocol_context_user")
@@index([status], map: "idx_protocol_context_status")
@@index([currentStage], map: "idx_protocol_context_stage")
@@map("protocol_contexts")
@@schema("protocol_schema")
}
/// Protocol生成记录表 - 存储一键生成的研究方案
model ProtocolGeneration {
id String @id @default(uuid())
contextId String @map("context_id") /// 关联的Context
userId String @map("user_id")
/// 生成内容
generatedContent String @map("generated_content") @db.Text /// 生成的研究方案全文Markdown
contentVersion Int @default(1) @map("content_version") /// 版本号
/// 使用的Prompt
promptUsed String @map("prompt_used") @db.Text /// 实际使用的Prompt
/// 生成参数
modelUsed String @map("model_used") /// 使用的模型
tokensUsed Int? @map("tokens_used") /// 消耗的Token
durationMs Int? @map("duration_ms") /// 生成耗时(毫秒)
/// 导出记录
wordFileKey String? @map("word_file_key") /// Word文件OSS Key
lastExportedAt DateTime? @map("last_exported_at")
/// 状态
status String @default("completed") /// generating, completed, failed
errorMessage String? @map("error_message")
/// 关联
context ProtocolContext @relation(fields: [contextId], references: [id], onDelete: Cascade)
createdAt DateTime @default(now()) @map("created_at")
updatedAt DateTime @updatedAt @map("updated_at")
@@index([contextId], map: "idx_protocol_gen_context")
@@index([userId, createdAt], map: "idx_protocol_gen_user_time")
@@map("protocol_generations")
@@schema("protocol_schema")
}

View File

@@ -0,0 +1,560 @@
/**
* Protocol Agent 初始配置数据种子
*
* 运行方式: npx tsx prisma/seeds/protocol-agent-seed.ts
*/
import { PrismaClient } from '@prisma/client';
const prisma = new PrismaClient();
async function main() {
console.log('🌱 Seeding Protocol Agent configuration...');
// 1. 创建Agent定义
const agentDefinition = await prisma.agentDefinition.upsert({
where: { code: 'protocol_agent' },
update: {},
create: {
code: 'protocol_agent',
name: '研究方案制定助手',
description: '帮助研究者系统地制定临床研究方案覆盖科学问题、PICO、研究设计、样本量和观察指标5个核心阶段',
version: '1.0.0',
config: {
defaultModel: 'deepseek-v3',
maxTurns: 100,
timeout: 60000,
enableTrace: true,
enableReflexion: true,
},
isActive: true,
},
});
console.log('✅ Created Agent Definition:', agentDefinition.code);
// 2. 创建5个阶段
const stages = [
{
stageCode: 'scientific_question',
stageName: '科学问题梳理',
sortOrder: 1,
isInitial: true,
isFinal: false,
nextStages: ['pico'],
config: {
requiredFields: ['content'],
minContentLength: 10,
},
},
{
stageCode: 'pico',
stageName: 'PICO要素',
sortOrder: 2,
isInitial: false,
isFinal: false,
nextStages: ['study_design'],
config: {
requiredFields: ['P', 'I', 'C', 'O'],
},
},
{
stageCode: 'study_design',
stageName: '研究设计',
sortOrder: 3,
isInitial: false,
isFinal: false,
nextStages: ['sample_size'],
config: {
requiredFields: ['type'],
},
},
{
stageCode: 'sample_size',
stageName: '样本量计算',
sortOrder: 4,
isInitial: false,
isFinal: false,
nextStages: ['endpoints'],
config: {
requiredFields: ['total'],
},
},
{
stageCode: 'endpoints',
stageName: '观察指标',
sortOrder: 5,
isInitial: false,
isFinal: true,
nextStages: [],
config: {
requiredFields: ['primary'],
},
},
];
for (const stage of stages) {
await prisma.agentStage.upsert({
where: {
agentId_stageCode: {
agentId: agentDefinition.id,
stageCode: stage.stageCode,
},
},
update: stage,
create: {
agentId: agentDefinition.id,
...stage,
},
});
console.log(` ✅ Stage: ${stage.stageName}`);
}
// 3. 创建Prompt模板
const prompts = [
// 系统Prompt
{
promptType: 'system',
promptCode: 'protocol_system',
content: `你是一位经验丰富的临床研究方法学专家,正在帮助研究者制定研究方案。
你的职责:
1. 系统引导用户完成研究方案的5个核心要素科学问题、PICO、研究设计、样本量、观察指标
2. 提供专业、准确的方法学建议
3. 确保研究设计的科学性和可行性
4. 使用通俗易懂的语言,同时保持学术严谨性
当前阶段: {{context.currentStage}}
已完成阶段: {{context.completedStages}}
请根据用户的输入,提供专业指导。`,
variables: ['context'],
},
// 科学问题阶段Prompt
{
promptType: 'stage',
promptCode: 'stage_scientific_question',
stageCode: 'scientific_question',
content: `【科学问题梳理阶段】
你正在帮助用户梳理研究的科学问题。一个好的科学问题应该:
- 明确、具体、可操作
- 有实际的临床或学术意义
- 可通过研究方法验证
{{#if context.scientificQuestion}}
用户当前的科学问题草稿:
{{context.scientificQuestion.content}}
{{/if}}
请引导用户:
1. 描述研究背景和动机
2. 明确想要解决的核心问题
3. 阐述研究的潜在意义
当用户表达清晰后,帮助整理成规范的科学问题陈述,并提供"同步到方案"按钮。`,
variables: ['context'],
},
// PICO阶段Prompt
{
promptType: 'stage',
promptCode: 'stage_pico',
stageCode: 'pico',
content: `【PICO要素梳理阶段】
PICO是临床研究问题结构化的核心框架
- P (Population): 研究人群
- I (Intervention): 干预措施
- C (Comparison): 对照措施
- O (Outcome): 结局指标
{{#if context.pico}}
当前PICO
- P: {{context.pico.P.value}}
- I: {{context.pico.I.value}}
- C: {{context.pico.C.value}}
- O: {{context.pico.O.value}}
{{/if}}
请引导用户逐一明确四个要素,确保:
1. P: 纳入标准、排除标准清晰
2. I: 干预措施具体可操作
3. C: 对照组设置合理
4. O: 结局指标可测量、有临床意义
当四要素都明确后,提供"同步到方案"按钮。`,
variables: ['context'],
},
// 研究设计阶段Prompt
{
promptType: 'stage',
promptCode: 'stage_study_design',
stageCode: 'study_design',
content: `【研究设计阶段】
根据科学问题和PICO需要确定合适的研究设计
科学问题:{{context.scientificQuestion.content}}
PICO
- P: {{context.pico.P.value}}
- I: {{context.pico.I.value}}
常见研究类型:
- 随机对照试验(RCT):最高证据等级,适合验证干预效果
- 队列研究:适合观察性研究,探索风险因素
- 病例对照研究:适合罕见疾病研究
- 横断面研究:描述性研究
请引导用户确定:
1. 研究类型
2. 盲法设计(如适用)
3. 随机化方法(如适用)
4. 研究周期
5. 是否多中心
设计确定后,提供"同步到方案"按钮。`,
variables: ['context'],
},
// 样本量阶段Prompt
{
promptType: 'stage',
promptCode: 'stage_sample_size',
stageCode: 'sample_size',
content: `【样本量计算阶段】
样本量计算需要考虑:
- α (显著性水平): 通常0.05
- β (统计效力): 通常0.8-0.9
- 预期效应量
- 预计脱落率
研究设计:{{context.studyDesign.type}}
主要结局:{{context.pico.O.value}}
请引导用户:
1. 确定检验类型(优效、非劣效、等效)
2. 估计预期效应量(基于文献或预试验)
3. 设定显著性水平和统计效力
4. 考虑脱落率调整
可以使用样本量计算工具辅助计算。
样本量确定后,提供"同步到方案"按钮。`,
variables: ['context'],
},
// 观察指标阶段Prompt
{
promptType: 'stage',
promptCode: 'stage_endpoints',
stageCode: 'endpoints',
content: `【观察指标设计阶段】
观察指标是评价研究结果的关键:
研究类型:{{context.studyDesign.type}}
PICO-O{{context.pico.O.value}}
需要明确的指标类型:
1. **主要结局指标(Primary Endpoint)**
- 与科学问题直接相关
- 用于样本量计算
- 每个研究通常只有1-2个
2. **次要结局指标(Secondary Endpoints)**
- 支持主要结局的补充指标
- 可以有多个
3. **安全性指标(Safety Endpoints)**
- 不良事件、实验室检查等
4. **探索性指标(Exploratory)**
- 为未来研究提供线索
请引导用户定义每个指标的:
- 名称
- 操作定义
- 测量方法
- 评价时点
所有指标确定后,提供"同步到方案"按钮。
🎉 完成观察指标后,您可以点击"一键生成研究方案"生成完整方案文档!`,
variables: ['context'],
},
// 数据提取Prompt
{
promptType: 'extraction',
promptCode: 'extraction_scientific_question',
stageCode: 'scientific_question',
content: `请从以下对话中提取科学问题信息:
用户消息:{{userMessage}}
请以JSON格式输出
{
"content": "完整的科学问题陈述",
"background": "研究背景",
"significance": "研究意义",
"readyToSync": true/false
}
如果信息不完整readyToSync设为false。`,
variables: ['userMessage'],
},
{
promptType: 'extraction',
promptCode: 'extraction_pico',
stageCode: 'pico',
content: `请从以下对话中提取PICO要素
用户消息:{{userMessage}}
当前PICO{{currentPico}}
请以JSON格式输出
{
"P": { "value": "研究人群", "details": "详细描述" },
"I": { "value": "干预措施", "details": "详细描述" },
"C": { "value": "对照措施", "details": "详细描述" },
"O": { "value": "结局指标", "details": "详细描述" },
"readyToSync": true/false
}
只更新用户提到的字段,保留其他字段不变。
如果PICO四要素都已完整readyToSync设为true。`,
variables: ['userMessage', 'currentPico'],
},
// 研究方案生成Prompt
{
promptType: 'generation',
promptCode: 'generate_protocol',
content: `你是一位资深的临床研究方法学专家,请基于以下核心要素生成一份完整、规范的临床研究方案。
## 核心要素
### 科学问题
{{scientificQuestion.content}}
{{#if scientificQuestion.background}}背景:{{scientificQuestion.background}}{{/if}}
### PICO要素
- **研究人群(P)**: {{pico.P.value}}
{{pico.P.details}}
- **干预措施(I)**: {{pico.I.value}}
{{pico.I.details}}
- **对照措施(C)**: {{pico.C.value}}
{{pico.C.details}}
- **结局指标(O)**: {{pico.O.value}}
{{pico.O.details}}
### 研究设计
- 研究类型: {{studyDesign.type}}
- 盲法设计: {{studyDesign.blinding}}
- 随机化方法: {{studyDesign.randomization}}
- 研究周期: {{studyDesign.duration}}
{{#if studyDesign.multiCenter}}- 多中心: 是,{{studyDesign.centerCount}}个中心{{/if}}
### 样本量
- 总样本量: {{sampleSize.total}}
- 每组样本量: {{sampleSize.perGroup}}
- 计算依据: {{sampleSize.justification}}
### 观察指标
**主要结局指标:**
{{#each endpoints.primary}}
- {{name}}: {{definition}} ({{method}}, {{timePoint}})
{{/each}}
**次要结局指标:**
{{#each endpoints.secondary}}
- {{name}}: {{definition}}
{{/each}}
**安全性指标:**
{{#each endpoints.safety}}
- {{name}}: {{definition}}
{{/each}}
---
## 生成要求
请生成包含以下章节的完整研究方案:
1. **研究背景与立题依据**
- 疾病/问题背景
- 国内外研究现状
- 研究的必要性和意义
2. **研究目的**
- 主要目的
- 次要目的
3. **研究方法**
- 研究类型与设计
- 研究对象
- 干预措施
- 对照设置
- 随机化与盲法
4. **受试者选择**
- 入选标准
- 排除标准
- 退出/剔除标准
5. **观察指标与评价标准**
- 主要疗效指标
- 次要疗效指标
- 安全性指标
- 评价时点
6. **统计分析计划**
- 样本量估算
- 分析数据集定义
- 统计方法
7. **质量控制**
- 数据管理
- 质量保证措施
8. **伦理考虑**
- 伦理审查
- 知情同意
- 受试者保护
9. **研究进度安排**
- 时间节点
- 里程碑
请使用专业、规范的学术语言,确保内容完整、逻辑清晰、符合临床研究规范。`,
variables: ['scientificQuestion', 'pico', 'studyDesign', 'sampleSize', 'endpoints'],
},
];
// 获取阶段ID映射
const stageIdMap = new Map<string, string>();
const savedStages = await prisma.agentStage.findMany({
where: { agentId: agentDefinition.id },
});
for (const stage of savedStages) {
stageIdMap.set(stage.stageCode, stage.id);
}
// 创建Prompts
for (const prompt of prompts) {
const stageId = prompt.stageCode ? stageIdMap.get(prompt.stageCode) : null;
await prisma.agentPrompt.upsert({
where: {
agentId_promptCode_version: {
agentId: agentDefinition.id,
promptCode: prompt.promptCode,
version: 1,
},
},
update: {
content: prompt.content,
variables: prompt.variables,
},
create: {
agentId: agentDefinition.id,
stageId: stageId,
promptType: prompt.promptType,
promptCode: prompt.promptCode,
content: prompt.content,
variables: prompt.variables,
version: 1,
isActive: true,
},
});
console.log(` ✅ Prompt: ${prompt.promptCode}`);
}
// 4. 创建Reflexion规则
const reflexionRules = [
{
ruleCode: 'scientific_question_completeness',
ruleName: '科学问题完整性检查',
triggerStage: 'scientific_question',
triggerTiming: 'on_sync',
ruleType: 'rule_based',
conditions: {
content: { required: true, minLength: 10 },
},
severity: 'warning',
failureAction: 'warn',
sortOrder: 1,
},
{
ruleCode: 'pico_completeness',
ruleName: 'PICO要素完整性检查',
triggerStage: 'pico',
triggerTiming: 'on_sync',
ruleType: 'rule_based',
conditions: {
P: 'required',
I: 'required',
C: 'required',
O: 'required',
},
severity: 'error',
failureAction: 'warn',
sortOrder: 2,
},
{
ruleCode: 'sample_size_validity',
ruleName: '样本量有效性检查',
triggerStage: 'sample_size',
triggerTiming: 'on_sync',
ruleType: 'rule_based',
conditions: {
total: { required: true, min: 1 },
},
severity: 'error',
failureAction: 'warn',
sortOrder: 3,
},
{
ruleCode: 'endpoints_primary_required',
ruleName: '主要终点指标必填',
triggerStage: 'endpoints',
triggerTiming: 'on_sync',
ruleType: 'rule_based',
conditions: {
primary: { notEmpty: true },
},
severity: 'error',
failureAction: 'warn',
sortOrder: 4,
},
];
for (const rule of reflexionRules) {
await prisma.reflexionRule.upsert({
where: {
agentId_ruleCode: {
agentId: agentDefinition.id,
ruleCode: rule.ruleCode,
},
},
update: rule,
create: {
agentId: agentDefinition.id,
...rule,
isActive: true,
},
});
console.log(` ✅ Rule: ${rule.ruleName}`);
}
console.log('\n🎉 Protocol Agent configuration seeded successfully!');
}
main()
.catch((e) => {
console.error('❌ Seed failed:', e);
process.exit(1);
})
.finally(async () => {
await prisma.$disconnect();
});

View File

@@ -135,6 +135,8 @@ Write-Host ""

View File

@@ -245,6 +245,8 @@ function extractCodeBlocks(obj, blocks = []) {

View File

@@ -43,5 +43,7 @@ CREATE TABLE IF NOT EXISTS platform_schema.job_common (

View File

@@ -117,5 +117,7 @@ CREATE OR REPLACE FUNCTION platform_schema.delete_queue(queue_name text) RETURNS

View File

@@ -264,6 +264,8 @@ checkDCTables();

View File

@@ -18,5 +18,7 @@ CREATE SCHEMA IF NOT EXISTS capability_schema;

View File

@@ -216,6 +216,8 @@ createAiHistoryTable()

View File

@@ -203,6 +203,8 @@ createToolCTable()

View File

@@ -200,6 +200,8 @@ createToolCTable()

View File

@@ -323,3 +323,5 @@ main()

View File

@@ -128,5 +128,7 @@ main()

View File

@@ -347,6 +347,8 @@ runTests().catch(error => {

View File

@@ -94,5 +94,7 @@ testAPI().catch(console.error);

View File

@@ -126,3 +126,5 @@ testDeepSearch().catch(console.error);

View File

@@ -312,6 +312,8 @@ verifySchemas()

View File

@@ -201,5 +201,7 @@ export const jwtService = new JWTService();

View File

@@ -332,6 +332,8 @@ export function getBatchItems<T>(

View File

@@ -84,5 +84,7 @@ export interface VariableValidation {

View File

@@ -356,3 +356,5 @@ export default ChunkService;

View File

@@ -52,3 +52,5 @@ export const DifyClient = DeprecatedDifyClient;

View File

@@ -207,3 +207,5 @@ export function createOpenAIStreamAdapter(

View File

@@ -213,3 +213,5 @@ export async function streamChat(

View File

@@ -31,3 +31,5 @@ export { THINKING_TAGS } from './types';

View File

@@ -106,3 +106,5 @@ export type SSEEventType =

View File

@@ -159,6 +159,15 @@ logger.info('✅ PKB个人知识库路由已注册: /api/v1/pkb');
await fastify.register(aiaRoutes, { prefix: '/api/v1/aia' });
logger.info('✅ AIA智能问答路由已注册: /api/v1/aia');
// ============================================
// 【业务模块】Protocol Agent - 研究方案制定Agent
// ============================================
import { protocolAgentRoutes } from './modules/agent/protocol/index.js';
await fastify.register((instance, opts, done) => {
protocolAgentRoutes(instance, { prisma, ...opts }).then(() => done()).catch(done);
}, { prefix: '/api/v1/aia/protocol-agent' });
logger.info('✅ Protocol Agent路由已注册: /api/v1/aia/protocol-agent');
// ============================================
// 【业务模块】ASL - AI智能文献筛选
// ============================================

View File

@@ -92,3 +92,5 @@ export async function moduleRoutes(fastify: FastifyInstance) {

View File

@@ -122,3 +122,5 @@ export interface PaginatedResponse<T> {

View File

@@ -169,3 +169,5 @@ export const ROLE_DISPLAY_NAMES: Record<UserRole, string> = {

View File

@@ -0,0 +1,25 @@
/**
* Agent Framework Module
*
* 通用Agent框架入口
*
* @module agent
*/
// Types
export * from './types/index.js';
// Services
export {
ConfigLoader,
BaseAgentOrchestrator,
QueryAnalyzer,
StageManager,
TraceLogger,
type OrchestratorDependencies,
type LLMServiceInterface,
type IntentType,
type TransitionCondition,
type TraceLogInput,
} from './services/index.js';

View File

@@ -0,0 +1,287 @@
/**
* Protocol Agent Controller
* 处理Protocol Agent的HTTP请求
*
* @module agent/protocol/controllers/ProtocolAgentController
*/
import { FastifyRequest, FastifyReply } from 'fastify';
import { PrismaClient } from '@prisma/client';
import { ProtocolOrchestrator } from '../services/ProtocolOrchestrator.js';
import { LLMServiceInterface } from '../../services/BaseAgentOrchestrator.js';
import { ProtocolStageCode } from '../../types/index.js';
// 请求类型定义
interface SendMessageBody {
conversationId: string;
content: string;
messageId?: string;
}
interface SyncDataBody {
conversationId: string;
stageCode: ProtocolStageCode;
data: Record<string, unknown>;
}
interface GenerateProtocolBody {
conversationId: string;
options?: {
sections?: string[];
style?: 'academic' | 'concise';
};
}
interface GetContextParams {
conversationId: string;
}
export class ProtocolAgentController {
private orchestrator: ProtocolOrchestrator;
constructor(prisma: PrismaClient, llmService: LLMServiceInterface) {
this.orchestrator = new ProtocolOrchestrator({ prisma, llmService });
}
/**
* 发送消息
* POST /api/aia/protocol-agent/message
*/
async sendMessage(
request: FastifyRequest<{ Body: SendMessageBody }>,
reply: FastifyReply
): Promise<void> {
try {
const { conversationId, content, messageId } = request.body;
const userId = (request as any).user?.userId;
if (!userId) {
reply.code(401).send({ error: 'Unauthorized' });
return;
}
if (!conversationId || !content) {
reply.code(400).send({ error: 'Missing required fields: conversationId, content' });
return;
}
const response = await this.orchestrator.handleMessage({
conversationId,
userId,
content,
messageId,
});
reply.send({
success: true,
data: response,
});
} catch (error) {
console.error('[ProtocolAgentController] sendMessage error:', error);
reply.code(500).send({
success: false,
error: error instanceof Error ? error.message : 'Internal server error',
});
}
}
/**
* 同步阶段数据
* POST /api/aia/protocol-agent/sync
*/
async syncData(
request: FastifyRequest<{ Body: SyncDataBody }>,
reply: FastifyReply
): Promise<void> {
try {
const { conversationId, stageCode, data } = request.body;
const userId = (request as any).user?.userId;
if (!userId) {
reply.code(401).send({ error: 'Unauthorized' });
return;
}
if (!conversationId || !stageCode) {
reply.code(400).send({ error: 'Missing required fields: conversationId, stageCode' });
return;
}
const result = await this.orchestrator.handleProtocolSync(
conversationId,
userId,
stageCode,
data
);
reply.send({
success: true,
data: result,
});
} catch (error) {
console.error('[ProtocolAgentController] syncData error:', error);
reply.code(500).send({
success: false,
error: error instanceof Error ? error.message : 'Internal server error',
});
}
}
/**
* 获取上下文状态
* GET /api/aia/protocol-agent/context/:conversationId
*/
async getContext(
request: FastifyRequest<{ Params: GetContextParams }>,
reply: FastifyReply
): Promise<void> {
try {
const { conversationId } = request.params;
if (!conversationId) {
reply.code(400).send({ error: 'Missing conversationId' });
return;
}
const summary = await this.orchestrator.getContextSummary(conversationId);
if (!summary) {
reply.code(404).send({ error: 'Context not found' });
return;
}
reply.send({
success: true,
data: summary,
});
} catch (error) {
console.error('[ProtocolAgentController] getContext error:', error);
reply.code(500).send({
success: false,
error: error instanceof Error ? error.message : 'Internal server error',
});
}
}
/**
* 一键生成研究方案
* POST /api/aia/protocol-agent/generate
*/
async generateProtocol(
request: FastifyRequest<{ Body: GenerateProtocolBody }>,
reply: FastifyReply
): Promise<void> {
try {
const { conversationId, options } = request.body;
const userId = (request as any).user?.userId;
if (!userId) {
reply.code(401).send({ error: 'Unauthorized' });
return;
}
if (!conversationId) {
reply.code(400).send({ error: 'Missing conversationId' });
return;
}
// 获取上下文
const contextService = this.orchestrator.getContextService();
const context = await contextService.getContext(conversationId);
if (!context) {
reply.code(404).send({ error: 'Context not found' });
return;
}
// 检查是否所有阶段都已完成
if (!contextService.isAllStagesCompleted(context)) {
reply.code(400).send({
error: '请先完成所有5个阶段科学问题、PICO、研究设计、样本量、观察指标'
});
return;
}
// TODO: 实现方案生成逻辑
// 这里先返回占位响应实际应该调用LLM生成完整方案
reply.send({
success: true,
data: {
generationId: 'placeholder',
status: 'generating',
message: '研究方案生成中...',
estimatedTime: 30,
},
});
} catch (error) {
console.error('[ProtocolAgentController] generateProtocol error:', error);
reply.code(500).send({
success: false,
error: error instanceof Error ? error.message : 'Internal server error',
});
}
}
/**
* 获取生成的方案
* GET /api/aia/protocol-agent/generation/:generationId
*/
async getGeneration(
request: FastifyRequest<{ Params: { generationId: string } }>,
reply: FastifyReply
): Promise<void> {
try {
const { generationId } = request.params;
// TODO: 实现获取生成结果逻辑
reply.send({
success: true,
data: {
id: generationId,
status: 'completed',
content: '# 研究方案\n\n生成中...',
contentVersion: 1,
},
});
} catch (error) {
console.error('[ProtocolAgentController] getGeneration error:', error);
reply.code(500).send({
success: false,
error: error instanceof Error ? error.message : 'Internal server error',
});
}
}
/**
* 导出Word文档
* POST /api/aia/protocol-agent/generation/:generationId/export
*/
async exportWord(
request: FastifyRequest<{
Params: { generationId: string };
Body: { format: 'docx' | 'pdf' };
}>,
reply: FastifyReply
): Promise<void> {
try {
const { generationId } = request.params;
const { format } = request.body;
// TODO: 实现导出逻辑
reply.send({
success: true,
data: {
downloadUrl: `/api/aia/protocol-agent/download/${generationId}.${format}`,
expiresAt: new Date(Date.now() + 3600 * 1000).toISOString(),
},
});
} catch (error) {
console.error('[ProtocolAgentController] exportWord error:', error);
reply.code(500).send({
success: false,
error: error instanceof Error ? error.message : 'Internal server error',
});
}
}
}

View File

@@ -0,0 +1,22 @@
/**
* Protocol Agent Module
* 研究方案制定Agent
*
* @module agent/protocol
*/
// Services
export {
ProtocolContextService,
ProtocolOrchestrator,
PromptBuilder,
LLMServiceAdapter,
createLLMServiceAdapter,
} from './services/index.js';
// Routes
export { protocolAgentRoutes } from './routes/index.js';
// Controllers
export { ProtocolAgentController } from './controllers/ProtocolAgentController.js';

View File

@@ -0,0 +1,153 @@
/**
* Protocol Agent Routes
* Protocol Agent的API路由定义
*
* @module agent/protocol/routes
*/
import { FastifyInstance } from 'fastify';
import { PrismaClient } from '@prisma/client';
import { ProtocolAgentController } from '../controllers/ProtocolAgentController.js';
import { createLLMServiceAdapter } from '../services/LLMServiceAdapter.js';
import { authenticate } from '../../../../common/auth/auth.middleware.js';
export async function protocolAgentRoutes(
fastify: FastifyInstance,
options: { prisma: PrismaClient }
): Promise<void> {
const { prisma } = options;
// 创建LLM服务使用真实的LLM适配器
const llmService = createLLMServiceAdapter();
// 创建控制器
const controller = new ProtocolAgentController(prisma, llmService);
// ============================================
// Protocol Agent API Routes (需要认证)
// ============================================
// 发送消息
fastify.post<{
Body: {
conversationId: string;
content: string;
messageId?: string;
};
}>('/message', {
preHandler: [authenticate],
schema: {
body: {
type: 'object',
required: ['conversationId', 'content'],
properties: {
conversationId: { type: 'string' },
content: { type: 'string' },
messageId: { type: 'string' },
},
},
},
}, (request, reply) => controller.sendMessage(request, reply));
// 同步阶段数据
fastify.post('/sync', {
preHandler: [authenticate],
schema: {
body: {
type: 'object',
required: ['conversationId', 'stageCode'],
properties: {
conversationId: { type: 'string' },
stageCode: { type: 'string' },
data: { type: 'object' },
},
},
},
}, (request, reply) => controller.syncData(request as never, reply));
// 获取上下文状态
fastify.get<{
Params: { conversationId: string };
}>('/context/:conversationId', {
preHandler: [authenticate],
schema: {
params: {
type: 'object',
required: ['conversationId'],
properties: {
conversationId: { type: 'string' },
},
},
},
}, (request, reply) => controller.getContext(request, reply));
// 一键生成研究方案
fastify.post<{
Body: {
conversationId: string;
options?: {
sections?: string[];
style?: 'academic' | 'concise';
};
};
}>('/generate', {
preHandler: [authenticate],
schema: {
body: {
type: 'object',
required: ['conversationId'],
properties: {
conversationId: { type: 'string' },
options: {
type: 'object',
properties: {
sections: { type: 'array', items: { type: 'string' } },
style: { type: 'string', enum: ['academic', 'concise'] },
},
},
},
},
},
}, (request, reply) => controller.generateProtocol(request, reply));
// 获取生成的方案
fastify.get<{
Params: { generationId: string };
}>('/generation/:generationId', {
preHandler: [authenticate],
schema: {
params: {
type: 'object',
required: ['generationId'],
properties: {
generationId: { type: 'string' },
},
},
},
}, (request, reply) => controller.getGeneration(request, reply));
// 导出Word文档
fastify.post<{
Params: { generationId: string };
Body: { format: 'docx' | 'pdf' };
}>('/generation/:generationId/export', {
preHandler: [authenticate],
schema: {
params: {
type: 'object',
required: ['generationId'],
properties: {
generationId: { type: 'string' },
},
},
body: {
type: 'object',
required: ['format'],
properties: {
format: { type: 'string', enum: ['docx', 'pdf'] },
},
},
},
}, (request, reply) => controller.exportWord(request, reply));
}

View File

@@ -0,0 +1,184 @@
/**
* LLM Service Adapter
* 将现有的LLM服务适配为Protocol Agent所需的接口
*
* @module agent/protocol/services/LLMServiceAdapter
*/
import { LLMFactory } from '../../../../common/llm/adapters/LLMFactory.js';
import { ModelType, Message as LLMMessage, LLMOptions } from '../../../../common/llm/adapters/types.js';
import { LLMServiceInterface } from '../../services/BaseAgentOrchestrator.js';
export class LLMServiceAdapter implements LLMServiceInterface {
private defaultModel: ModelType;
private defaultTemperature: number;
private defaultMaxTokens: number;
constructor(options?: {
defaultModel?: ModelType;
defaultTemperature?: number;
defaultMaxTokens?: number;
}) {
this.defaultModel = options?.defaultModel ?? 'deepseek-v3';
this.defaultTemperature = options?.defaultTemperature ?? 0.7;
this.defaultMaxTokens = options?.defaultMaxTokens ?? 4096;
}
/**
* 调用LLM进行对话
*/
async chat(params: {
messages: Array<{ role: string; content: string }>;
model?: string;
temperature?: number;
maxTokens?: number;
}): Promise<{
content: string;
thinkingContent?: string;
tokensUsed: number;
model: string;
}> {
// 获取模型类型
const modelType = this.parseModelType(params.model);
// 获取LLM适配器
const adapter = LLMFactory.getAdapter(modelType);
// 转换消息格式
const messages: LLMMessage[] = params.messages.map(m => ({
role: m.role as 'system' | 'user' | 'assistant',
content: m.content,
}));
// 调用LLM
const options: LLMOptions = {
temperature: params.temperature ?? this.defaultTemperature,
maxTokens: params.maxTokens ?? this.defaultMaxTokens,
};
try {
const response = await adapter.chat(messages, options);
// 提取思考内容(如果有)
const { content, thinkingContent } = this.extractThinkingContent(response.content);
return {
content,
thinkingContent,
tokensUsed: response.usage?.totalTokens ?? 0,
model: response.model,
};
} catch (error) {
console.error('[LLMServiceAdapter] chat error:', error);
throw error;
}
}
/**
* 流式调用LLM返回AsyncGenerator
*/
async *chatStream(params: {
messages: Array<{ role: string; content: string }>;
model?: string;
temperature?: number;
maxTokens?: number;
}): AsyncGenerator<{
content: string;
done: boolean;
tokensUsed?: number;
}> {
const modelType = this.parseModelType(params.model);
const adapter = LLMFactory.getAdapter(modelType);
const messages: LLMMessage[] = params.messages.map(m => ({
role: m.role as 'system' | 'user' | 'assistant',
content: m.content,
}));
const options: LLMOptions = {
temperature: params.temperature ?? this.defaultTemperature,
maxTokens: params.maxTokens ?? this.defaultMaxTokens,
stream: true,
};
try {
for await (const chunk of adapter.chatStream(messages, options)) {
yield {
content: chunk.content,
done: chunk.done,
tokensUsed: chunk.usage?.totalTokens,
};
}
} catch (error) {
console.error('[LLMServiceAdapter] chatStream error:', error);
throw error;
}
}
/**
* 解析模型类型
*/
private parseModelType(model?: string): ModelType {
if (!model) return this.defaultModel;
// 映射模型名称到ModelType
const modelMap: Record<string, ModelType> = {
'deepseek-v3': 'deepseek-v3',
'deepseek-chat': 'deepseek-v3',
'qwen-max': 'qwen3-72b',
'qwen3-72b': 'qwen3-72b',
'qwen-long': 'qwen-long',
'gpt-5': 'gpt-5',
'gpt-5-pro': 'gpt-5',
'claude-4.5': 'claude-4.5',
'claude-sonnet': 'claude-4.5',
};
return modelMap[model.toLowerCase()] ?? this.defaultModel;
}
/**
* 从响应中提取思考内容(<think>...</think>
*/
private extractThinkingContent(content: string): {
content: string;
thinkingContent?: string;
} {
// 匹配 <think>...</think> 或 <thinking>...</thinking>
const thinkingPattern = /<(?:think|thinking)>([\s\S]*?)<\/(?:think|thinking)>/gi;
const matches = content.matchAll(thinkingPattern);
let thinkingContent = '';
let cleanContent = content;
for (const match of matches) {
thinkingContent += match[1].trim() + '\n';
cleanContent = cleanContent.replace(match[0], '').trim();
}
return {
content: cleanContent,
thinkingContent: thinkingContent ? thinkingContent.trim() : undefined,
};
}
/**
* 获取支持的模型列表
*/
getSupportedModels(): string[] {
return LLMFactory.getSupportedModels();
}
}
/**
* 创建默认的LLM服务适配器
*/
export function createLLMServiceAdapter(): LLMServiceInterface {
return new LLMServiceAdapter({
defaultModel: 'deepseek-v3',
defaultTemperature: 0.7,
defaultMaxTokens: 4096,
});
}

View File

@@ -0,0 +1,286 @@
/**
* Prompt Builder
* 构建和渲染Protocol Agent的Prompt
*
* @module agent/protocol/services/PromptBuilder
*/
import { ConfigLoader } from '../../services/ConfigLoader.js';
import {
AgentPrompt,
ProtocolContextData,
PromptRenderContext,
} from '../../types/index.js';
export class PromptBuilder {
private configLoader: ConfigLoader;
constructor(configLoader: ConfigLoader) {
this.configLoader = configLoader;
}
/**
* 构建完整的消息列表
*/
async buildMessages(
context: ProtocolContextData,
userMessage: string,
conversationHistory?: Array<{ role: string; content: string }>
): Promise<Array<{ role: string; content: string }>> {
const messages: Array<{ role: string; content: string }> = [];
// 1. 系统Prompt
const systemPrompt = await this.buildSystemPrompt(context);
if (systemPrompt) {
messages.push({ role: 'system', content: systemPrompt });
}
// 2. 阶段Prompt
const stagePrompt = await this.buildStagePrompt(context);
if (stagePrompt) {
messages.push({ role: 'system', content: stagePrompt });
}
// 3. 对话历史最近5轮
if (conversationHistory?.length) {
const recentHistory = conversationHistory.slice(-10); // 最近10条消息5轮
messages.push(...recentHistory);
}
// 4. 用户消息
messages.push({ role: 'user', content: userMessage });
return messages;
}
/**
* 构建系统Prompt
*/
async buildSystemPrompt(context: ProtocolContextData): Promise<string | null> {
const prompt = await this.configLoader.getSystemPrompt('protocol_agent');
if (!prompt) return null;
return this.renderTemplate(prompt.content, { context });
}
/**
* 构建当前阶段Prompt
*/
async buildStagePrompt(context: ProtocolContextData): Promise<string | null> {
const prompt = await this.configLoader.getStagePrompt(
'protocol_agent',
context.currentStage
);
if (!prompt) return null;
return this.renderTemplate(prompt.content, { context });
}
/**
* 构建数据提取Prompt
*/
async buildExtractionPrompt(
context: ProtocolContextData,
userMessage: string
): Promise<string | null> {
const prompt = await this.configLoader.getExtractionPrompt(
'protocol_agent',
context.currentStage
);
if (!prompt) return null;
return this.renderTemplate(prompt.content, {
userMessage,
context,
currentPico: context.pico,
});
}
/**
* 构建方案生成Prompt
*/
async buildGenerationPrompt(context: ProtocolContextData): Promise<string | null> {
const config = await this.configLoader.loadAgentConfig('protocol_agent');
const prompt = config.prompts.find(p => p.promptCode === 'generate_protocol');
if (!prompt) return null;
return this.renderTemplate(prompt.content, {
scientificQuestion: context.scientificQuestion,
pico: context.pico,
studyDesign: context.studyDesign,
sampleSize: context.sampleSize,
endpoints: context.endpoints,
});
}
/**
* 渲染模板
* 支持 {{variable}} 和 {{#if variable}}...{{/if}} 语法
*/
private renderTemplate(
template: string,
variables: Record<string, unknown>
): string {
let result = template;
// 处理 {{#if variable}}...{{/if}} 条件块
result = this.processConditionals(result, variables);
// 处理 {{#each array}}...{{/each}} 循环块
result = this.processEachBlocks(result, variables);
// 处理简单变量替换 {{variable}}
result = this.processVariables(result, variables);
return result;
}
/**
* 处理条件块
*/
private processConditionals(
template: string,
variables: Record<string, unknown>
): string {
const ifPattern = /\{\{#if\s+(\S+)\}\}([\s\S]*?)\{\{\/if\}\}/g;
return template.replace(ifPattern, (match, condition, content) => {
const value = this.getNestedValue(variables, condition);
if (value && value !== false && value !== null && value !== undefined) {
return content;
}
return '';
});
}
/**
* 处理循环块
*/
private processEachBlocks(
template: string,
variables: Record<string, unknown>
): string {
const eachPattern = /\{\{#each\s+(\S+)\}\}([\s\S]*?)\{\{\/each\}\}/g;
return template.replace(eachPattern, (match, arrayPath, content) => {
const array = this.getNestedValue(variables, arrayPath);
if (!Array.isArray(array)) return '';
return array.map((item, index) => {
let itemContent = content;
// 替换 {{this}} 为当前项
itemContent = itemContent.replace(/\{\{this\}\}/g, String(item));
// 替换 {{@index}} 为索引
itemContent = itemContent.replace(/\{\{@index\}\}/g, String(index));
// 替换项属性 {{name}}, {{definition}} 等
if (typeof item === 'object' && item !== null) {
for (const [key, value] of Object.entries(item)) {
const regex = new RegExp(`\\{\\{${key}\\}\\}`, 'g');
itemContent = itemContent.replace(regex, String(value ?? ''));
}
}
return itemContent;
}).join('\n');
});
}
/**
* 处理变量替换
*/
private processVariables(
template: string,
variables: Record<string, unknown>
): string {
const varPattern = /\{\{([^#/][^}]*)\}\}/g;
return template.replace(varPattern, (match, varPath) => {
const value = this.getNestedValue(variables, varPath.trim());
if (value === undefined || value === null) {
return '';
}
if (typeof value === 'object') {
return JSON.stringify(value, null, 2);
}
return String(value);
});
}
/**
* 获取嵌套属性值
*/
private getNestedValue(
obj: Record<string, unknown>,
path: string
): unknown {
const parts = path.split('.');
let current: unknown = obj;
for (const part of parts) {
if (current === null || current === undefined) {
return undefined;
}
if (typeof current !== 'object') {
return undefined;
}
current = (current as Record<string, unknown>)[part];
}
return current;
}
/**
* 构建欢迎消息
*/
buildWelcomeMessage(): string {
return `您好!我是研究方案制定助手,将帮助您系统地完成临床研究方案的核心要素设计。
我们将一起完成以下5个关键步骤
1⃣ **科学问题梳理** - 明确研究要解决的核心问题
2⃣ **PICO要素** - 确定研究人群、干预、对照和结局
3⃣ **研究设计** - 选择合适的研究类型和方法
4⃣ **样本量计算** - 估算所需的样本量
5⃣ **观察指标** - 定义主要和次要结局指标
完成这5个要素后您可以**一键生成完整的研究方案**并下载为Word文档。
让我们开始吧!请先告诉我,您想研究什么问题?或者描述一下您的研究背景和想法。`;
}
/**
* 构建阶段完成消息
*/
buildStageCompleteMessage(
stageName: string,
nextStageName?: string,
isAllCompleted: boolean = false
): string {
if (isAllCompleted) {
return `${stageName}已同步到方案!
🎉 **恭喜您已完成所有5个核心要素的梳理**
您现在可以:
- 点击「🚀 一键生成研究方案」生成完整方案
- 或者回顾修改任何阶段的内容
需要我帮您生成研究方案吗?`;
}
return `${stageName}已同步到方案!
接下来我们进入**${nextStageName}**阶段。准备好了吗?
说"继续"我们就开始,或者您也可以先问我任何问题。`;
}
}

View File

@@ -0,0 +1,289 @@
/**
* Protocol Context Service
* 管理研究方案的上下文数据
*
* @module agent/protocol/services/ProtocolContextService
*/
import { PrismaClient } from '@prisma/client';
import {
ProtocolContextData,
ProtocolStageCode,
ScientificQuestionData,
PICOData,
StudyDesignData,
SampleSizeData,
EndpointsData,
} from '../../types/index.js';
export class ProtocolContextService {
private prisma: PrismaClient;
constructor(prisma: PrismaClient) {
this.prisma = prisma;
}
/**
* 获取或创建上下文
*/
async getOrCreateContext(
conversationId: string,
userId: string
): Promise<ProtocolContextData> {
let context = await this.getContext(conversationId);
if (!context) {
context = await this.createContext(conversationId, userId);
}
return context;
}
/**
* 获取上下文
*/
async getContext(conversationId: string): Promise<ProtocolContextData | null> {
const result = await this.prisma.protocolContext.findUnique({
where: { conversationId },
});
if (!result) return null;
return this.mapToContextData(result);
}
/**
* 创建新上下文
*/
async createContext(
conversationId: string,
userId: string
): Promise<ProtocolContextData> {
const result = await this.prisma.protocolContext.create({
data: {
conversationId,
userId,
currentStage: 'scientific_question',
status: 'in_progress',
completedStages: [],
},
});
return this.mapToContextData(result);
}
/**
* 更新阶段数据
*/
async updateStageData(
conversationId: string,
stageCode: ProtocolStageCode,
data: Record<string, unknown>
): Promise<ProtocolContextData> {
const updateData: Record<string, unknown> = {};
switch (stageCode) {
case 'scientific_question':
updateData.scientificQuestion = data;
break;
case 'pico':
updateData.pico = data;
break;
case 'study_design':
updateData.studyDesign = data;
break;
case 'sample_size':
updateData.sampleSize = data;
break;
case 'endpoints':
updateData.endpoints = data;
break;
}
const result = await this.prisma.protocolContext.update({
where: { conversationId },
data: {
...updateData,
lastActiveAt: new Date(),
},
});
return this.mapToContextData(result);
}
/**
* 标记阶段完成并更新当前阶段
*/
async completeStage(
conversationId: string,
stageCode: ProtocolStageCode,
nextStage?: ProtocolStageCode
): Promise<ProtocolContextData> {
const context = await this.getContext(conversationId);
if (!context) {
throw new Error('Context not found');
}
const completedStages = [...context.completedStages];
if (!completedStages.includes(stageCode)) {
completedStages.push(stageCode);
}
const result = await this.prisma.protocolContext.update({
where: { conversationId },
data: {
completedStages,
currentStage: nextStage ?? context.currentStage,
lastActiveAt: new Date(),
},
});
return this.mapToContextData(result);
}
/**
* 更新当前阶段
*/
async updateCurrentStage(
conversationId: string,
stageCode: ProtocolStageCode
): Promise<ProtocolContextData> {
const result = await this.prisma.protocolContext.update({
where: { conversationId },
data: {
currentStage: stageCode,
lastActiveAt: new Date(),
},
});
return this.mapToContextData(result);
}
/**
* 标记方案完成
*/
async markCompleted(conversationId: string): Promise<ProtocolContextData> {
const result = await this.prisma.protocolContext.update({
where: { conversationId },
data: {
status: 'completed',
lastActiveAt: new Date(),
},
});
return this.mapToContextData(result);
}
/**
* 检查是否所有阶段都已完成
*/
isAllStagesCompleted(context: ProtocolContextData): boolean {
const requiredStages: ProtocolStageCode[] = [
'scientific_question',
'pico',
'study_design',
'sample_size',
'endpoints',
];
return requiredStages.every(stage => context.completedStages.includes(stage));
}
/**
* 获取进度百分比
*/
getProgress(context: ProtocolContextData): number {
const totalStages = 5;
return Math.round((context.completedStages.length / totalStages) * 100);
}
/**
* 获取阶段状态列表
*/
getStagesStatus(context: ProtocolContextData): Array<{
stageCode: ProtocolStageCode;
stageName: string;
status: 'completed' | 'current' | 'pending';
data: Record<string, unknown> | null;
}> {
const stages: Array<{
code: ProtocolStageCode;
name: string;
dataKey: keyof ProtocolContextData;
}> = [
{ code: 'scientific_question', name: '科学问题梳理', dataKey: 'scientificQuestion' },
{ code: 'pico', name: 'PICO要素', dataKey: 'pico' },
{ code: 'study_design', name: '研究设计', dataKey: 'studyDesign' },
{ code: 'sample_size', name: '样本量计算', dataKey: 'sampleSize' },
{ code: 'endpoints', name: '观察指标', dataKey: 'endpoints' },
];
return stages.map(stage => ({
stageCode: stage.code,
stageName: stage.name,
status: context.completedStages.includes(stage.code)
? 'completed' as const
: context.currentStage === stage.code
? 'current' as const
: 'pending' as const,
data: context[stage.dataKey] as unknown as Record<string, unknown> | null ?? null,
}));
}
/**
* 获取用于生成方案的完整数据
*/
getGenerationData(context: ProtocolContextData): {
scientificQuestion: ScientificQuestionData | null;
pico: PICOData | null;
studyDesign: StudyDesignData | null;
sampleSize: SampleSizeData | null;
endpoints: EndpointsData | null;
} {
return {
scientificQuestion: context.scientificQuestion ?? null,
pico: context.pico ?? null,
studyDesign: context.studyDesign ?? null,
sampleSize: context.sampleSize ?? null,
endpoints: context.endpoints ?? null,
};
}
/**
* 将数据库结果映射为上下文数据
*/
private mapToContextData(result: {
id: string;
conversationId: string;
userId: string;
currentStage: string;
status: string;
scientificQuestion: unknown;
pico: unknown;
studyDesign: unknown;
sampleSize: unknown;
endpoints: unknown;
completedStages: string[];
lastActiveAt: Date;
createdAt: Date;
updatedAt: Date;
}): ProtocolContextData {
return {
id: result.id,
conversationId: result.conversationId,
userId: result.userId,
currentStage: result.currentStage as ProtocolStageCode,
status: result.status as ProtocolContextData['status'],
scientificQuestion: result.scientificQuestion as ScientificQuestionData | undefined,
pico: result.pico as PICOData | undefined,
studyDesign: result.studyDesign as StudyDesignData | undefined,
sampleSize: result.sampleSize as SampleSizeData | undefined,
endpoints: result.endpoints as EndpointsData | undefined,
completedStages: result.completedStages as ProtocolStageCode[],
lastActiveAt: result.lastActiveAt,
createdAt: result.createdAt,
updatedAt: result.updatedAt,
};
}
}

View File

@@ -0,0 +1,321 @@
/**
* Protocol Orchestrator
* Protocol Agent的具体实现
*
* @module agent/protocol/services/ProtocolOrchestrator
*/
import { PrismaClient } from '@prisma/client';
import {
BaseAgentOrchestrator,
OrchestratorDependencies,
} from '../../services/BaseAgentOrchestrator.js';
import {
AgentSession,
AgentResponse,
ProtocolContextData,
ProtocolStageCode,
SyncButtonData,
ActionCard,
UserMessageInput,
} from '../../types/index.js';
import { ProtocolContextService } from './ProtocolContextService.js';
import { PromptBuilder } from './PromptBuilder.js';
/** 阶段名称映射 */
const STAGE_NAMES: Record<ProtocolStageCode, string> = {
scientific_question: '科学问题梳理',
pico: 'PICO要素',
study_design: '研究设计',
sample_size: '样本量计算',
endpoints: '观察指标',
};
/** 阶段顺序 */
const STAGE_ORDER: ProtocolStageCode[] = [
'scientific_question',
'pico',
'study_design',
'sample_size',
'endpoints',
];
export class ProtocolOrchestrator extends BaseAgentOrchestrator {
private contextService: ProtocolContextService;
private promptBuilder: PromptBuilder;
constructor(deps: OrchestratorDependencies) {
super(deps);
this.contextService = new ProtocolContextService(deps.prisma);
this.promptBuilder = new PromptBuilder(this.configLoader);
}
/**
* 获取Agent代码
*/
getAgentCode(): string {
return 'protocol_agent';
}
/**
* 覆盖父类handleMessage确保上下文在处理消息前创建
*/
async handleMessage(input: UserMessageInput): Promise<AgentResponse> {
// 确保上下文存在
await this.contextService.getOrCreateContext(input.conversationId, input.userId);
// 调用父类方法处理消息
return super.handleMessage(input);
}
/**
* 获取上下文数据(如果不存在则返回默认结构)
*/
async getContext(conversationId: string): Promise<Record<string, unknown> | null> {
const context = await this.contextService.getContext(conversationId);
if (!context) {
// 返回默认上下文结构,避免 undefined 错误
return {
currentStage: 'scientific_question',
completedStages: [],
status: 'in_progress',
};
}
return context as unknown as Record<string, unknown>;
}
/**
* 保存上下文数据
*/
async saveContext(
conversationId: string,
userId: string,
data: Record<string, unknown>
): Promise<void> {
const context = await this.contextService.getOrCreateContext(conversationId, userId);
const stageCode = context.currentStage;
await this.contextService.updateStageData(conversationId, stageCode, data);
}
/**
* 构建阶段响应
*/
async buildStageResponse(
session: AgentSession,
llmResponse: string,
contextData: Record<string, unknown>
): Promise<AgentResponse> {
const context = contextData as unknown as ProtocolContextData;
const stageCode = session.currentStage as ProtocolStageCode;
const stageName = STAGE_NAMES[stageCode] || session.currentStage;
// 检测是否应该显示同步按钮
const syncButton = this.buildSyncButton(llmResponse, stageCode, context);
// 构建动作卡片
const actionCards = this.buildActionCards(stageCode, context);
return {
content: llmResponse,
stage: stageCode,
stageName,
syncButton,
actionCards,
};
}
/**
* 处理Protocol同步请求
*/
async handleProtocolSync(
conversationId: string,
userId: string,
stageCode: string,
data: Record<string, unknown>
): Promise<{
success: boolean;
context: ProtocolContextData;
nextStage?: ProtocolStageCode;
message?: string;
}> {
const stage = stageCode as ProtocolStageCode;
// 保存阶段数据
await this.contextService.updateStageData(conversationId, stage, {
...data,
confirmed: true,
confirmedAt: new Date(),
});
// 获取下一阶段
const currentIndex = STAGE_ORDER.indexOf(stage);
const nextStage = currentIndex < STAGE_ORDER.length - 1
? STAGE_ORDER[currentIndex + 1]
: undefined;
// 标记当前阶段完成,更新到下一阶段
const context = await this.contextService.completeStage(
conversationId,
stage,
nextStage
);
// 检查是否所有阶段都已完成
const allCompleted = this.contextService.isAllStagesCompleted(context);
return {
success: true,
context,
nextStage,
message: allCompleted
? '🎉 所有核心要素已完成!您可以点击「一键生成研究方案」生成完整方案。'
: nextStage
? `已同步${STAGE_NAMES[stage]},进入${STAGE_NAMES[nextStage]}阶段`
: `已同步${STAGE_NAMES[stage]}`,
};
}
/**
* 获取Protocol上下文服务
*/
getContextService(): ProtocolContextService {
return this.contextService;
}
/**
* 构建同步按钮数据
*/
private buildSyncButton(
llmResponse: string,
stageCode: ProtocolStageCode,
context: ProtocolContextData
): SyncButtonData | undefined {
// 检测LLM响应中是否有已整理好的数据
// 这里用简单的关键词检测,实际可以用更复杂的方式
const readyPatterns = [
'整理',
'总结',
'您的科学问题',
'您的PICO',
'您的研究设计',
'样本量',
'观察指标',
'同步到方案',
];
const hasReadyData = readyPatterns.some(p => llmResponse.includes(p));
if (!hasReadyData) {
return undefined;
}
// 确保 completedStages 存在
const completedStages = context.completedStages || [];
// 检查当前阶段是否已完成
if (completedStages.includes(stageCode)) {
return {
stageCode,
extractedData: {},
label: '已同步',
disabled: true,
};
}
return {
stageCode,
extractedData: this.extractDataFromResponse(llmResponse, stageCode),
label: '✅ 同步到方案',
disabled: false,
};
}
/**
* 从LLM响应中提取结构化数据
*/
private extractDataFromResponse(
response: string,
stageCode: ProtocolStageCode
): Record<string, unknown> {
// 尝试从响应中提取JSON格式的数据
const jsonMatch = response.match(/<extracted_data>([\s\S]*?)<\/extracted_data>/);
if (jsonMatch) {
try {
return JSON.parse(jsonMatch[1]);
} catch {
// 解析失败,继续使用默认逻辑
}
}
// 简单提取实际应该用LLM来提取
switch (stageCode) {
case 'scientific_question':
return { content: response.substring(0, 500), readyToSync: true };
default:
return { readyToSync: true };
}
}
/**
* 构建动作卡片
*/
private buildActionCards(
stageCode: ProtocolStageCode,
context: ProtocolContextData
): ActionCard[] {
const cards: ActionCard[] = [];
// 样本量阶段:添加样本量计算器卡片
if (stageCode === 'sample_size') {
cards.push({
id: 'sample_size_calculator',
type: 'tool',
title: '📊 样本量计算器',
description: '使用专业计算器进行样本量估算',
actionUrl: '/tools/sample-size-calculator',
});
}
// 所有阶段完成后添加一键生成按钮确保context有效
if (context.completedStages && this.contextService.isAllStagesCompleted(context)) {
cards.push({
id: 'generate_protocol',
type: 'action',
title: '🚀 一键生成研究方案',
description: '基于5个核心要素生成完整研究方案',
actionUrl: '/api/aia/protocol-agent/generate',
});
}
return cards;
}
/**
* 获取上下文状态摘要用于前端State Panel
*/
async getContextSummary(conversationId: string): Promise<{
currentStage: string;
stageName: string;
progress: number;
stages: Array<{
stageCode: string;
stageName: string;
status: 'completed' | 'current' | 'pending';
data: Record<string, unknown> | null;
}>;
canGenerate: boolean;
} | null> {
const context = await this.contextService.getContext(conversationId);
if (!context) return null;
return {
currentStage: context.currentStage,
stageName: STAGE_NAMES[context.currentStage] || context.currentStage,
progress: this.contextService.getProgress(context),
stages: this.contextService.getStagesStatus(context),
canGenerate: this.contextService.isAllStagesCompleted(context),
};
}
}

View File

@@ -0,0 +1,11 @@
/**
* Protocol Agent Services Export
*
* @module agent/protocol/services
*/
export { ProtocolContextService } from './ProtocolContextService.js';
export { ProtocolOrchestrator } from './ProtocolOrchestrator.js';
export { PromptBuilder } from './PromptBuilder.js';
export { LLMServiceAdapter, createLLMServiceAdapter } from './LLMServiceAdapter.js';

View File

@@ -0,0 +1,561 @@
/**
* Base Agent Orchestrator
* Agent框架的核心编排器抽象类
*
* @module agent/services/BaseAgentOrchestrator
*/
import { PrismaClient } from '@prisma/client';
import { v4 as uuidv4 } from 'uuid';
import {
AgentSession,
AgentResponse,
UserMessageInput,
StageTransitionResult,
IntentAnalysis,
AgentFullConfig,
AgentStage,
} from '../types/index.js';
import { ConfigLoader } from './ConfigLoader.js';
import { QueryAnalyzer } from './QueryAnalyzer.js';
import { StageManager } from './StageManager.js';
import { TraceLogger } from './TraceLogger.js';
/** Orchestrator依赖注入 */
export interface OrchestratorDependencies {
prisma: PrismaClient;
llmService: LLMServiceInterface;
}
/** LLM服务接口由具体实现提供 */
export interface LLMServiceInterface {
chat(params: {
messages: Array<{ role: string; content: string }>;
model?: string;
temperature?: number;
maxTokens?: number;
}): Promise<{
content: string;
thinkingContent?: string;
tokensUsed: number;
model: string;
}>;
}
/**
* 抽象基类Agent编排器
* 子类需要实现具体Agent的逻辑
*/
export abstract class BaseAgentOrchestrator {
protected prisma: PrismaClient;
protected llmService: LLMServiceInterface;
protected configLoader: ConfigLoader;
protected queryAnalyzer: QueryAnalyzer;
protected stageManager: StageManager;
protected traceLogger: TraceLogger;
protected config: AgentFullConfig | null = null;
constructor(deps: OrchestratorDependencies) {
this.prisma = deps.prisma;
this.llmService = deps.llmService;
this.configLoader = new ConfigLoader(deps.prisma);
this.queryAnalyzer = new QueryAnalyzer(deps.llmService);
this.stageManager = new StageManager(deps.prisma);
this.traceLogger = new TraceLogger(deps.prisma);
}
/**
* 获取Agent唯一标识子类必须实现
*/
abstract getAgentCode(): string;
/**
* 获取上下文数据(子类必须实现)
*/
abstract getContext(conversationId: string): Promise<Record<string, unknown> | null>;
/**
* 保存上下文数据(子类必须实现)
*/
abstract saveContext(
conversationId: string,
userId: string,
data: Record<string, unknown>
): Promise<void>;
/**
* 构建阶段响应(子类必须实现)
*/
abstract buildStageResponse(
session: AgentSession,
llmResponse: string,
context: Record<string, unknown>
): Promise<AgentResponse>;
/**
* 初始化配置
*/
protected async ensureConfig(): Promise<AgentFullConfig> {
if (!this.config) {
this.config = await this.configLoader.loadAgentConfig(this.getAgentCode());
}
return this.config;
}
/**
* 处理用户消息 - 主入口
*/
async handleMessage(input: UserMessageInput): Promise<AgentResponse> {
const traceId = uuidv4();
let stepIndex = 0;
try {
await this.ensureConfig();
// 1. 获取或创建会话
const session = await this.getOrCreateSession(input.conversationId, input.userId);
// 2. 记录输入追踪
await this.traceLogger.log({
sessionId: session.id,
traceId,
stepIndex: stepIndex++,
stepType: 'query',
input: { content: input.content },
stageCode: session.currentStage,
});
// 3. 意图识别
const intent = await this.analyzeIntent(input.content, session);
await this.traceLogger.log({
sessionId: session.id,
traceId,
stepIndex: stepIndex++,
stepType: 'plan',
input: { userMessage: input.content },
output: { intent },
stageCode: session.currentStage,
});
// 4. 根据意图处理
const context = await this.getContext(input.conversationId) || {};
// 5. 执行LLM对话
const llmResponse = await this.executeDialogue(session, input.content, context, intent);
await this.traceLogger.log({
sessionId: session.id,
traceId,
stepIndex: stepIndex++,
stepType: 'execute',
input: { intent, context },
output: { response: llmResponse.content.substring(0, 500) },
stageCode: session.currentStage,
modelUsed: llmResponse.model,
tokensUsed: llmResponse.tokensUsed,
});
// 6. 更新会话统计
await this.updateSessionStats(session.id, llmResponse.tokensUsed);
// 7. 构建响应
const response = await this.buildStageResponse(session, llmResponse.content, context);
response.thinkingContent = llmResponse.thinkingContent;
response.tokensUsed = llmResponse.tokensUsed;
response.modelUsed = llmResponse.model;
return response;
} catch (error) {
// 记录错误
const session = await this.getSession(input.conversationId);
if (session) {
await this.traceLogger.log({
sessionId: session.id,
traceId,
stepIndex,
stepType: 'execute',
stageCode: session.currentStage,
errorType: 'execution_error',
errorMsg: error instanceof Error ? error.message : String(error),
});
}
throw error;
}
}
/**
* 处理同步请求
*/
async handleSync(
conversationId: string,
userId: string,
stageCode: string,
data: Record<string, unknown>
): Promise<StageTransitionResult> {
const traceId = uuidv4();
await this.ensureConfig();
const session = await this.getSession(conversationId);
if (!session) {
throw new Error('Session not found');
}
// 记录同步操作
await this.traceLogger.log({
sessionId: session.id,
traceId,
stepIndex: 0,
stepType: 'sync',
input: { stageCode, data },
stageCode: session.currentStage,
});
// 保存上下文数据
await this.saveContext(conversationId, userId, {
...data,
confirmed: true,
confirmedAt: new Date(),
});
// 执行Reflexion检查如果启用
const reflexionResults = await this.runReflexion(session, stageCode, 'on_sync', data);
// 检查是否有阻塞性错误
const blockingError = reflexionResults.find(r => !r.passed && r.severity === 'error');
if (blockingError) {
return {
success: false,
fromStage: session.currentStage,
toStage: session.currentStage,
message: blockingError.message,
reflexionResults,
};
}
// 获取下一阶段
const nextStage = await this.configLoader.getNextStage(this.getAgentCode(), stageCode);
if (nextStage) {
// 更新会话状态
await this.updateSessionStage(session.id, nextStage.stageCode);
}
return {
success: true,
fromStage: session.currentStage,
toStage: nextStage?.stageCode ?? session.currentStage,
message: nextStage ? `进入${nextStage.stageName}阶段` : '所有阶段已完成',
reflexionResults,
};
}
/**
* 获取会话
*/
protected async getSession(conversationId: string): Promise<AgentSession | null> {
const result = await this.prisma.agentSession.findUnique({
where: { conversationId },
});
if (!result) return null;
return {
id: result.id,
agentId: result.agentId,
conversationId: result.conversationId,
userId: result.userId,
currentStage: result.currentStage,
status: result.status as AgentSession['status'],
contextRef: result.contextRef ?? undefined,
turnCount: result.turnCount,
totalTokens: result.totalTokens,
createdAt: result.createdAt,
updatedAt: result.updatedAt,
};
}
/**
* 获取或创建会话
*/
protected async getOrCreateSession(
conversationId: string,
userId: string
): Promise<AgentSession> {
let session = await this.getSession(conversationId);
if (!session) {
const config = await this.ensureConfig();
const initialStage = config.stages.find((s: AgentStage) => s.isInitial);
const result = await this.prisma.agentSession.create({
data: {
agentId: config.definition.id,
conversationId,
userId,
currentStage: initialStage?.stageCode ?? config.stages[0]?.stageCode ?? 'unknown',
status: 'active',
},
});
session = {
id: result.id,
agentId: result.agentId,
conversationId: result.conversationId,
userId: result.userId,
currentStage: result.currentStage,
status: result.status as AgentSession['status'],
contextRef: result.contextRef ?? undefined,
turnCount: result.turnCount,
totalTokens: result.totalTokens,
createdAt: result.createdAt,
updatedAt: result.updatedAt,
};
}
return session;
}
/**
* 分析用户意图
*/
protected async analyzeIntent(
userMessage: string,
session: AgentSession
): Promise<IntentAnalysis> {
// 简单的意图识别(后续可增强)
const lowerMessage = userMessage.toLowerCase();
// 检测阶段切换意图
if (lowerMessage.includes('继续') ||
lowerMessage.includes('下一步') ||
lowerMessage.includes('next')) {
return {
intent: 'proceed_next_stage',
confidence: 0.9,
entities: {},
suggestedAction: 'stage_transition',
};
}
// 检测同步意图
if (lowerMessage.includes('确认') ||
lowerMessage.includes('同步') ||
lowerMessage.includes('保存')) {
return {
intent: 'confirm_sync',
confidence: 0.85,
entities: {},
suggestedAction: 'sync_data',
};
}
// 默认为对话意图
return {
intent: 'dialogue',
confidence: 0.7,
entities: {},
suggestedAction: 'continue_dialogue',
};
}
/**
* 执行对话
*/
protected async executeDialogue(
session: AgentSession,
userMessage: string,
context: Record<string, unknown>,
intent: IntentAnalysis
): Promise<{
content: string;
thinkingContent?: string;
tokensUsed: number;
model: string;
}> {
const config = await this.ensureConfig();
// 获取系统Prompt
const systemPrompt = await this.configLoader.getSystemPrompt(this.getAgentCode());
// 获取当前阶段Prompt
const stagePrompt = await this.configLoader.getStagePrompt(
this.getAgentCode(),
session.currentStage
);
// 构建消息
const messages: Array<{ role: string; content: string }> = [];
if (systemPrompt) {
messages.push({
role: 'system',
content: this.renderPrompt(systemPrompt.content, { context, intent }),
});
}
if (stagePrompt) {
messages.push({
role: 'system',
content: this.renderPrompt(stagePrompt.content, { context, intent }),
});
}
messages.push({
role: 'user',
content: userMessage,
});
// 调用LLM
return await this.llmService.chat({
messages,
model: config.definition.config?.defaultModel,
temperature: 0.7,
});
}
/**
* 渲染Prompt模板
*/
protected renderPrompt(
template: string,
variables: Record<string, unknown>
): string {
let result = template;
// 简单的变量替换 {{variable}}
for (const [key, value] of Object.entries(variables)) {
const regex = new RegExp(`\\{\\{${key}\\}\\}`, 'g');
result = result.replace(regex, JSON.stringify(value, null, 2));
}
return result;
}
/**
* 运行Reflexion检查
*/
protected async runReflexion(
session: AgentSession,
stageCode: string,
timing: 'on_sync' | 'on_stage_complete' | 'on_generate',
data: Record<string, unknown>
): Promise<Array<{
ruleCode: string;
ruleName: string;
passed: boolean;
severity: 'error' | 'warning' | 'info';
message?: string;
}>> {
const rules = await this.configLoader.getReflexionRules(
this.getAgentCode(),
stageCode,
timing
);
const results = [];
for (const rule of rules) {
try {
let passed = true;
let message: string | undefined;
if (rule.ruleType === 'rule_based' && rule.conditions) {
// 基于规则的检查
passed = this.evaluateRuleConditions(rule.conditions, data);
if (!passed) {
message = `${rule.ruleName}检查未通过`;
}
} else if (rule.ruleType === 'prompt_based' && rule.promptTemplate) {
// 基于Prompt的检查简化版实际应调用LLM
// TODO: 实现LLM-based检查
passed = true;
}
results.push({
ruleCode: rule.ruleCode,
ruleName: rule.ruleName,
passed,
severity: rule.severity,
message,
});
} catch (error) {
results.push({
ruleCode: rule.ruleCode,
ruleName: rule.ruleName,
passed: false,
severity: 'warning' as const,
message: `规则执行出错: ${error instanceof Error ? error.message : String(error)}`,
});
}
}
return results;
}
/**
* 评估规则条件
*/
protected evaluateRuleConditions(
conditions: Record<string, unknown>,
data: Record<string, unknown>
): boolean {
// 简单的规则评估逻辑
for (const [field, requirement] of Object.entries(conditions)) {
const value = data[field];
if (requirement === 'required' && !value) {
return false;
}
if (typeof requirement === 'object' && requirement !== null) {
const req = requirement as Record<string, unknown>;
if (req.minLength && typeof value === 'string' && value.length < (req.minLength as number)) {
return false;
}
if (req.notEmpty && (!value || (Array.isArray(value) && value.length === 0))) {
return false;
}
}
}
return true;
}
/**
* 更新会话统计
*/
protected async updateSessionStats(sessionId: string, tokensUsed: number): Promise<void> {
await this.prisma.agentSession.update({
where: { id: sessionId },
data: {
turnCount: { increment: 1 },
totalTokens: { increment: tokensUsed },
},
});
}
/**
* 更新会话阶段
*/
protected async updateSessionStage(sessionId: string, stageCode: string): Promise<void> {
await this.prisma.agentSession.update({
where: { id: sessionId },
data: { currentStage: stageCode },
});
}
/**
* 获取当前阶段信息
*/
protected async getCurrentStageInfo(stageCode: string): Promise<AgentStage | null> {
const config = await this.ensureConfig();
return config.stages.find((s: AgentStage) => s.stageCode === stageCode) ?? null;
}
}

View File

@@ -0,0 +1,266 @@
/**
* Agent Configuration Loader
* 从数据库加载Agent配置
*
* @module agent/services/ConfigLoader
*/
import { PrismaClient } from '@prisma/client';
import {
AgentDefinition,
AgentStage,
AgentPrompt,
ReflexionRule,
AgentFullConfig,
AgentConfig,
} from '../types/index.js';
export class ConfigLoader {
private prisma: PrismaClient;
// 配置缓存
private configCache: Map<string, { config: AgentFullConfig; loadedAt: Date }> = new Map();
private cacheTTL: number = 5 * 60 * 1000; // 5分钟缓存
constructor(prisma: PrismaClient) {
this.prisma = prisma;
}
/**
* 加载Agent完整配置
* @param agentCode Agent代码如 'protocol_agent'
* @param useCache 是否使用缓存
*/
async loadAgentConfig(agentCode: string, useCache: boolean = true): Promise<AgentFullConfig> {
// 检查缓存
if (useCache) {
const cached = this.configCache.get(agentCode);
if (cached && Date.now() - cached.loadedAt.getTime() < this.cacheTTL) {
return cached.config;
}
}
// 从数据库加载
const definition = await this.loadDefinition(agentCode);
if (!definition) {
throw new Error(`Agent not found: ${agentCode}`);
}
const [stages, prompts, reflexionRules] = await Promise.all([
this.loadStages(definition.id),
this.loadPrompts(definition.id),
this.loadReflexionRules(definition.id),
]);
const config: AgentFullConfig = {
definition,
stages,
prompts,
reflexionRules,
};
// 更新缓存
this.configCache.set(agentCode, { config, loadedAt: new Date() });
return config;
}
/**
* 加载Agent定义
*/
private async loadDefinition(agentCode: string): Promise<AgentDefinition | null> {
const result = await this.prisma.agentDefinition.findUnique({
where: { code: agentCode },
});
if (!result) return null;
return {
id: result.id,
code: result.code,
name: result.name,
description: result.description ?? undefined,
version: result.version,
config: result.config as unknown as AgentConfig | undefined,
isActive: result.isActive,
createdAt: result.createdAt,
updatedAt: result.updatedAt,
};
}
/**
* 加载Agent阶段配置
*/
private async loadStages(agentId: string): Promise<AgentStage[]> {
const results = await this.prisma.agentStage.findMany({
where: { agentId },
orderBy: { sortOrder: 'asc' },
});
return results.map((r) => ({
id: r.id,
agentId: r.agentId,
stageCode: r.stageCode,
stageName: r.stageName,
sortOrder: r.sortOrder,
config: r.config as Record<string, unknown> | undefined,
nextStages: r.nextStages,
isInitial: r.isInitial,
isFinal: r.isFinal,
}));
}
/**
* 加载Agent Prompt模板
*/
private async loadPrompts(agentId: string): Promise<AgentPrompt[]> {
const results = await this.prisma.agentPrompt.findMany({
where: {
agentId,
isActive: true,
},
});
return results.map((r) => ({
id: r.id,
agentId: r.agentId,
stageId: r.stageId ?? undefined,
promptType: r.promptType as AgentPrompt['promptType'],
promptCode: r.promptCode,
content: r.content,
variables: r.variables,
version: r.version,
isActive: r.isActive,
}));
}
/**
* 加载Reflexion规则
*/
private async loadReflexionRules(agentId: string): Promise<ReflexionRule[]> {
const results = await this.prisma.reflexionRule.findMany({
where: {
agentId,
isActive: true,
},
orderBy: { sortOrder: 'asc' },
});
return results.map((r) => ({
id: r.id,
agentId: r.agentId,
ruleCode: r.ruleCode,
ruleName: r.ruleName,
triggerStage: r.triggerStage ?? undefined,
triggerTiming: r.triggerTiming as ReflexionRule['triggerTiming'],
ruleType: r.ruleType as ReflexionRule['ruleType'],
conditions: r.conditions as Record<string, unknown> | undefined,
promptTemplate: r.promptTemplate ?? undefined,
severity: r.severity as ReflexionRule['severity'],
failureAction: r.failureAction as ReflexionRule['failureAction'],
isActive: r.isActive,
sortOrder: r.sortOrder,
}));
}
/**
* 获取指定阶段的Prompt
*/
async getStagePrompt(agentCode: string, stageCode: string): Promise<AgentPrompt | null> {
const config = await this.loadAgentConfig(agentCode);
// 先找阶段特定的Prompt
const stage = config.stages.find((s: AgentStage) => s.stageCode === stageCode);
if (stage) {
const stagePrompt = config.prompts.find(
(p: AgentPrompt) => p.stageId === stage.id && p.promptType === 'stage'
);
if (stagePrompt) return stagePrompt;
}
return null;
}
/**
* 获取系统Prompt
*/
async getSystemPrompt(agentCode: string): Promise<AgentPrompt | null> {
const config = await this.loadAgentConfig(agentCode);
return config.prompts.find(p => p.promptType === 'system' && !p.stageId) ?? null;
}
/**
* 获取提取Prompt
*/
async getExtractionPrompt(agentCode: string, stageCode: string): Promise<AgentPrompt | null> {
const config = await this.loadAgentConfig(agentCode);
const stage = config.stages.find(s => s.stageCode === stageCode);
if (stage) {
return config.prompts.find(
p => p.stageId === stage.id && p.promptType === 'extraction'
) ?? null;
}
return null;
}
/**
* 获取阶段的Reflexion规则
*/
async getReflexionRules(
agentCode: string,
stageCode: string,
timing: ReflexionRule['triggerTiming']
): Promise<ReflexionRule[]> {
const config = await this.loadAgentConfig(agentCode);
return config.reflexionRules.filter(
r => (r.triggerStage === stageCode || r.triggerStage === null) &&
r.triggerTiming === timing
);
}
/**
* 清除缓存
*/
clearCache(agentCode?: string): void {
if (agentCode) {
this.configCache.delete(agentCode);
} else {
this.configCache.clear();
}
}
/**
* 获取阶段流程顺序
*/
async getStageFlow(agentCode: string): Promise<AgentStage[]> {
const config = await this.loadAgentConfig(agentCode);
return config.stages.sort((a, b) => a.sortOrder - b.sortOrder);
}
/**
* 获取下一个阶段
*/
async getNextStage(agentCode: string, currentStageCode: string): Promise<AgentStage | null> {
const stages = await this.getStageFlow(agentCode);
const currentIndex = stages.findIndex(s => s.stageCode === currentStageCode);
if (currentIndex >= 0 && currentIndex < stages.length - 1) {
return stages[currentIndex + 1];
}
return null;
}
/**
* 检查是否为最终阶段
*/
async isFinalStage(agentCode: string, stageCode: string): Promise<boolean> {
const config = await this.loadAgentConfig(agentCode);
const stage = config.stages.find(s => s.stageCode === stageCode);
return stage?.isFinal ?? false;
}
}

View File

@@ -0,0 +1,270 @@
/**
* Query Analyzer
* 用户意图识别和查询分析
*
* @module agent/services/QueryAnalyzer
*/
import { IntentAnalysis } from '../types/index.js';
import { LLMServiceInterface } from './BaseAgentOrchestrator.js';
/** 意图类型定义 */
export type IntentType =
| 'dialogue' // 普通对话
| 'proceed_next_stage' // 进入下一阶段
| 'confirm_sync' // 确认同步
| 'edit_data' // 编辑已有数据
| 'ask_question' // 提问
| 'clarification' // 澄清/补充信息
| 'generate_protocol' // 生成研究方案
| 'export_document'; // 导出文档
/** 意图识别规则 */
interface IntentRule {
type: IntentType;
keywords: string[];
patterns?: RegExp[];
confidence: number;
}
export class QueryAnalyzer {
private llmService: LLMServiceInterface;
// 预定义的意图规则
private intentRules: IntentRule[] = [
{
type: 'proceed_next_stage',
keywords: ['继续', '下一步', '下一个', 'next', '进入下一', '开始下一'],
confidence: 0.9,
},
{
type: 'confirm_sync',
keywords: ['确认', '同步', '保存', '确定', 'confirm', 'save'],
confidence: 0.85,
},
{
type: 'edit_data',
keywords: ['修改', '编辑', '更改', '调整', '更新', 'edit', 'modify'],
confidence: 0.8,
},
{
type: 'generate_protocol',
keywords: ['生成方案', '生成研究方案', '一键生成', '生成全文'],
confidence: 0.95,
},
{
type: 'export_document',
keywords: ['导出', '下载', '保存word', '保存文档', 'export', 'download'],
confidence: 0.9,
},
{
type: 'ask_question',
keywords: ['什么是', '如何', '怎么', '为什么', '是否', '?', ''],
patterns: [/^(什么|如何|怎么|为什么|是否|能不能|可以|请问)/],
confidence: 0.7,
},
{
type: 'clarification',
keywords: ['补充', '另外', '还有', '再说', '顺便'],
confidence: 0.6,
},
];
constructor(llmService: LLMServiceInterface) {
this.llmService = llmService;
}
/**
* 分析用户输入的意图
*/
async analyze(
userMessage: string,
context?: {
currentStage?: string;
completedStages?: string[];
conversationHistory?: Array<{ role: string; content: string }>;
}
): Promise<IntentAnalysis> {
// 1. 基于规则的快速识别
const ruleBasedResult = this.ruleBasedAnalysis(userMessage);
if (ruleBasedResult.confidence >= 0.85) {
return ruleBasedResult;
}
// 2. 上下文增强
const contextEnhanced = this.contextEnhancedAnalysis(
userMessage,
ruleBasedResult,
context
);
// 3. 如果信心度仍然较低可以考虑使用LLM此处简化
// TODO: 实现LLM-based意图识别
return contextEnhanced;
}
/**
* 基于规则的意图分析
*/
private ruleBasedAnalysis(userMessage: string): IntentAnalysis {
const lowerMessage = userMessage.toLowerCase();
const entities: Record<string, unknown> = {};
// 遍历规则匹配
for (const rule of this.intentRules) {
// 关键词匹配
const keywordMatch = rule.keywords.some(keyword =>
lowerMessage.includes(keyword.toLowerCase())
);
// 正则匹配
const patternMatch = rule.patterns?.some(pattern =>
pattern.test(userMessage)
);
if (keywordMatch || patternMatch) {
// 提取实体
this.extractEntities(userMessage, rule.type, entities);
return {
intent: rule.type,
confidence: keywordMatch && patternMatch ?
Math.min(rule.confidence + 0.1, 1) : rule.confidence,
entities,
suggestedAction: this.getSuggestedAction(rule.type),
};
}
}
// 默认为对话意图
return {
intent: 'dialogue',
confidence: 0.5,
entities,
suggestedAction: 'continue_dialogue',
};
}
/**
* 上下文增强分析
*/
private contextEnhancedAnalysis(
userMessage: string,
baseResult: IntentAnalysis,
context?: {
currentStage?: string;
completedStages?: string[];
conversationHistory?: Array<{ role: string; content: string }>;
}
): IntentAnalysis {
if (!context) return baseResult;
const result = { ...baseResult };
// 根据当前阶段调整意图
if (context.currentStage) {
// 如果已完成所有阶段,倾向于生成方案
if (context.completedStages?.length === 5) {
if (userMessage.includes('方案') || userMessage.includes('生成')) {
result.intent = 'generate_protocol';
result.confidence = Math.max(result.confidence, 0.8);
result.suggestedAction = 'generate_protocol';
}
}
}
// 根据对话历史调整
if (context.conversationHistory?.length) {
const lastAssistantMsg = context.conversationHistory
.filter(m => m.role === 'assistant')
.pop();
// 如果上一条AI消息询问是否继续
if (lastAssistantMsg?.content.includes('继续') ||
lastAssistantMsg?.content.includes('下一阶段')) {
if (userMessage.match(/^(好|是|对|可以|行|ok|yes)/i)) {
result.intent = 'proceed_next_stage';
result.confidence = Math.max(result.confidence, 0.85);
result.suggestedAction = 'stage_transition';
}
}
}
return result;
}
/**
* 提取实体
*/
private extractEntities(
message: string,
intentType: IntentType,
entities: Record<string, unknown>
): void {
// 阶段名称实体
const stagePatterns = [
{ pattern: /科学问题/, entity: 'scientific_question' },
{ pattern: /PICO|pico/, entity: 'pico' },
{ pattern: /研究设计/, entity: 'study_design' },
{ pattern: /样本量/, entity: 'sample_size' },
{ pattern: /(终点|结局|观察)指标/, entity: 'endpoints' },
];
for (const { pattern, entity } of stagePatterns) {
if (pattern.test(message)) {
entities['targetStage'] = entity;
break;
}
}
// 数字实体
const numberMatch = message.match(/(\d+)/g);
if (numberMatch) {
entities['numbers'] = numberMatch.map(n => parseInt(n, 10));
}
}
/**
* 获取建议的动作
*/
private getSuggestedAction(intentType: IntentType): string {
const actionMap: Record<IntentType, string> = {
'dialogue': 'continue_dialogue',
'proceed_next_stage': 'stage_transition',
'confirm_sync': 'sync_data',
'edit_data': 'edit_context',
'ask_question': 'answer_question',
'clarification': 'request_clarification',
'generate_protocol': 'generate_protocol',
'export_document': 'export_document',
};
return actionMap[intentType] ?? 'continue_dialogue';
}
/**
* 检测用户是否表达了同意/肯定
*/
isAffirmative(message: string): boolean {
const affirmativePatterns = [
/^(好|是|对|可以|行|ok|yes|yeah|sure|当然|没问题|同意|确认)/i,
/^(嗯|恩|en|um)/i,
];
return affirmativePatterns.some(p => p.test(message.trim()));
}
/**
* 检测用户是否表达了否定
*/
isNegative(message: string): boolean {
const negativePatterns = [
/^(不|否|no|nope|算了|取消|等等|暂时不|还没)/i,
];
return negativePatterns.some(p => p.test(message.trim()));
}
}

View File

@@ -0,0 +1,283 @@
/**
* Stage Manager
* Agent阶段状态管理
*
* @module agent/services/StageManager
*/
import { PrismaClient } from '@prisma/client';
import {
AgentSession,
AgentStage,
StageTransitionResult,
ProtocolStageCode,
} from '../types/index.js';
/** 阶段转换条件 */
export interface TransitionCondition {
fromStage: string;
toStage: string;
conditions: Record<string, unknown>;
}
export class StageManager {
private prisma: PrismaClient;
constructor(prisma: PrismaClient) {
this.prisma = prisma;
}
/**
* 执行阶段转换
*/
async transition(
session: AgentSession,
targetStage: string,
stages: AgentStage[]
): Promise<StageTransitionResult> {
const currentStageConfig = stages.find(s => s.stageCode === session.currentStage);
const targetStageConfig = stages.find(s => s.stageCode === targetStage);
if (!targetStageConfig) {
return {
success: false,
fromStage: session.currentStage,
toStage: targetStage,
message: `目标阶段 ${targetStage} 不存在`,
};
}
// 检查是否允许转换
if (currentStageConfig && !currentStageConfig.nextStages.includes(targetStage)) {
// 检查是否是返回之前的阶段(允许回退)
const targetOrder = targetStageConfig.sortOrder;
const currentOrder = currentStageConfig.sortOrder;
if (targetOrder > currentOrder) {
return {
success: false,
fromStage: session.currentStage,
toStage: targetStage,
message: `不能从 ${currentStageConfig.stageName} 直接跳转到 ${targetStageConfig.stageName}`,
};
}
}
// 更新会话
await this.prisma.agentSession.update({
where: { id: session.id },
data: { currentStage: targetStage },
});
return {
success: true,
fromStage: session.currentStage,
toStage: targetStage,
message: `已进入${targetStageConfig.stageName}阶段`,
};
}
/**
* 获取下一个阶段
*/
getNextStage(
currentStageCode: string,
stages: AgentStage[]
): AgentStage | null {
const sortedStages = [...stages].sort((a, b) => a.sortOrder - b.sortOrder);
const currentIndex = sortedStages.findIndex(s => s.stageCode === currentStageCode);
if (currentIndex >= 0 && currentIndex < sortedStages.length - 1) {
return sortedStages[currentIndex + 1];
}
return null;
}
/**
* 获取上一个阶段
*/
getPreviousStage(
currentStageCode: string,
stages: AgentStage[]
): AgentStage | null {
const sortedStages = [...stages].sort((a, b) => a.sortOrder - b.sortOrder);
const currentIndex = sortedStages.findIndex(s => s.stageCode === currentStageCode);
if (currentIndex > 0) {
return sortedStages[currentIndex - 1];
}
return null;
}
/**
* 检查阶段是否完成
*/
isStageCompleted(
stageCode: ProtocolStageCode,
completedStages: ProtocolStageCode[]
): boolean {
return completedStages.includes(stageCode);
}
/**
* 检查是否所有阶段都已完成
*/
areAllStagesCompleted(
stages: AgentStage[],
completedStages: string[]
): boolean {
const requiredStages = stages
.filter(s => !s.isFinal)
.map(s => s.stageCode);
return requiredStages.every(code => completedStages.includes(code));
}
/**
* 获取当前进度百分比
*/
getProgressPercentage(
stages: AgentStage[],
completedStages: string[]
): number {
const totalStages = stages.filter(s => !s.isFinal).length;
const completed = completedStages.length;
return totalStages > 0 ? Math.round((completed / totalStages) * 100) : 0;
}
/**
* 获取阶段状态摘要
*/
getStagesSummary(
stages: AgentStage[],
currentStageCode: string,
completedStages: string[]
): Array<{
stageCode: string;
stageName: string;
status: 'completed' | 'current' | 'pending';
order: number;
}> {
return stages
.sort((a, b) => a.sortOrder - b.sortOrder)
.map(stage => ({
stageCode: stage.stageCode,
stageName: stage.stageName,
status: completedStages.includes(stage.stageCode)
? 'completed' as const
: stage.stageCode === currentStageCode
? 'current' as const
: 'pending' as const,
order: stage.sortOrder,
}));
}
/**
* 验证阶段转换是否有效
*/
validateTransition(
fromStage: string,
toStage: string,
stages: AgentStage[],
completedStages: string[]
): { valid: boolean; reason?: string } {
const fromStageConfig = stages.find(s => s.stageCode === fromStage);
const toStageConfig = stages.find(s => s.stageCode === toStage);
if (!fromStageConfig || !toStageConfig) {
return { valid: false, reason: '阶段配置不存在' };
}
// 允许回退到已完成的阶段
if (completedStages.includes(toStage)) {
return { valid: true };
}
// 前向转换:检查当前阶段是否已完成
if (toStageConfig.sortOrder > fromStageConfig.sortOrder) {
if (!completedStages.includes(fromStage)) {
return {
valid: false,
reason: `请先完成${fromStageConfig.stageName}阶段`
};
}
}
// 检查是否是允许的转换
if (!fromStageConfig.nextStages.includes(toStage)) {
// 如果是回退,允许
if (toStageConfig.sortOrder < fromStageConfig.sortOrder) {
return { valid: true };
}
return {
valid: false,
reason: `不能从${fromStageConfig.stageName}转换到${toStageConfig.stageName}`
};
}
return { valid: true };
}
/**
* 获取可用的下一步操作
*/
getAvailableActions(
currentStageCode: string,
stages: AgentStage[],
completedStages: string[]
): Array<{
action: string;
label: string;
targetStage?: string;
enabled: boolean;
}> {
const currentStage = stages.find(s => s.stageCode === currentStageCode);
const actions = [];
// 同步当前阶段数据
if (currentStage && !completedStages.includes(currentStageCode)) {
actions.push({
action: 'sync',
label: '同步到方案',
enabled: true,
});
}
// 进入下一阶段
const nextStage = this.getNextStage(currentStageCode, stages);
if (nextStage) {
actions.push({
action: 'next_stage',
label: `进入${nextStage.stageName}`,
targetStage: nextStage.stageCode,
enabled: completedStages.includes(currentStageCode),
});
}
// 生成研究方案(仅当所有阶段完成时)
if (this.areAllStagesCompleted(stages, completedStages)) {
actions.push({
action: 'generate_protocol',
label: '一键生成研究方案',
enabled: true,
});
}
// 返回上一阶段
const prevStage = this.getPreviousStage(currentStageCode, stages);
if (prevStage) {
actions.push({
action: 'prev_stage',
label: `返回${prevStage.stageName}`,
targetStage: prevStage.stageCode,
enabled: true,
});
}
return actions;
}
}

View File

@@ -0,0 +1,349 @@
/**
* Trace Logger
* Agent执行追踪日志记录
*
* @module agent/services/TraceLogger
*/
import { PrismaClient } from '@prisma/client';
import { AgentTraceRecord, StepType } from '../types/index.js';
/** 追踪日志输入 */
export interface TraceLogInput {
sessionId: string;
traceId: string;
stepIndex: number;
stepType: StepType;
input?: Record<string, unknown>;
output?: Record<string, unknown>;
stageCode?: string;
modelUsed?: string;
tokensUsed?: number;
durationMs?: number;
errorType?: string;
errorMsg?: string;
}
export class TraceLogger {
private prisma: PrismaClient;
private enabled: boolean = true;
constructor(prisma: PrismaClient) {
this.prisma = prisma;
}
/**
* 启用/禁用追踪
*/
setEnabled(enabled: boolean): void {
this.enabled = enabled;
}
/**
* 记录追踪日志
*/
async log(input: TraceLogInput): Promise<AgentTraceRecord | null> {
if (!this.enabled) {
return null;
}
try {
const result = await this.prisma.agentTrace.create({
data: {
sessionId: input.sessionId,
traceId: input.traceId,
stepIndex: input.stepIndex,
stepType: input.stepType,
input: input.input ? JSON.parse(JSON.stringify(input.input)) : undefined,
output: input.output ? JSON.parse(JSON.stringify(input.output)) : undefined,
stageCode: input.stageCode ?? undefined,
modelUsed: input.modelUsed ?? undefined,
tokensUsed: input.tokensUsed ?? undefined,
durationMs: input.durationMs ?? undefined,
errorType: input.errorType ?? undefined,
errorMsg: input.errorMsg ?? undefined,
},
});
return {
id: result.id,
sessionId: result.sessionId,
traceId: result.traceId,
stepIndex: result.stepIndex,
stepType: result.stepType as StepType,
input: result.input as Record<string, unknown> | undefined,
output: result.output as Record<string, unknown> | undefined,
stageCode: result.stageCode ?? undefined,
modelUsed: result.modelUsed ?? undefined,
tokensUsed: result.tokensUsed ?? undefined,
durationMs: result.durationMs ?? undefined,
errorType: result.errorType ?? undefined,
errorMsg: result.errorMsg ?? undefined,
createdAt: result.createdAt,
};
} catch (error) {
console.error('[TraceLogger] Failed to log trace:', error);
return null;
}
}
/**
* 批量记录追踪日志
*/
async logBatch(inputs: TraceLogInput[]): Promise<number> {
if (!this.enabled || inputs.length === 0) {
return 0;
}
try {
const result = await this.prisma.agentTrace.createMany({
data: inputs.map(input => ({
sessionId: input.sessionId,
traceId: input.traceId,
stepIndex: input.stepIndex,
stepType: input.stepType,
input: input.input ? JSON.parse(JSON.stringify(input.input)) : undefined,
output: input.output ? JSON.parse(JSON.stringify(input.output)) : undefined,
stageCode: input.stageCode ?? undefined,
modelUsed: input.modelUsed ?? undefined,
tokensUsed: input.tokensUsed ?? undefined,
durationMs: input.durationMs ?? undefined,
errorType: input.errorType ?? undefined,
errorMsg: input.errorMsg ?? undefined,
})),
});
return result.count;
} catch (error) {
console.error('[TraceLogger] Failed to batch log traces:', error);
return 0;
}
}
/**
* 获取会话的所有追踪记录
*/
async getSessionTraces(sessionId: string): Promise<AgentTraceRecord[]> {
const results = await this.prisma.agentTrace.findMany({
where: { sessionId },
orderBy: [
{ traceId: 'asc' },
{ stepIndex: 'asc' },
],
});
return results.map(r => ({
id: r.id,
sessionId: r.sessionId,
traceId: r.traceId,
stepIndex: r.stepIndex,
stepType: r.stepType as StepType,
input: r.input as Record<string, unknown> | undefined,
output: r.output as Record<string, unknown> | undefined,
stageCode: r.stageCode ?? undefined,
modelUsed: r.modelUsed ?? undefined,
tokensUsed: r.tokensUsed ?? undefined,
durationMs: r.durationMs ?? undefined,
errorType: r.errorType ?? undefined,
errorMsg: r.errorMsg ?? undefined,
createdAt: r.createdAt,
}));
}
/**
* 获取单个请求的追踪记录
*/
async getTraceById(traceId: string): Promise<AgentTraceRecord[]> {
const results = await this.prisma.agentTrace.findMany({
where: { traceId },
orderBy: { stepIndex: 'asc' },
});
return results.map(r => ({
id: r.id,
sessionId: r.sessionId,
traceId: r.traceId,
stepIndex: r.stepIndex,
stepType: r.stepType as StepType,
input: r.input as Record<string, unknown> | undefined,
output: r.output as Record<string, unknown> | undefined,
stageCode: r.stageCode ?? undefined,
modelUsed: r.modelUsed ?? undefined,
tokensUsed: r.tokensUsed ?? undefined,
durationMs: r.durationMs ?? undefined,
errorType: r.errorType ?? undefined,
errorMsg: r.errorMsg ?? undefined,
createdAt: r.createdAt,
}));
}
/**
* 获取会话的错误追踪
*/
async getSessionErrors(sessionId: string): Promise<AgentTraceRecord[]> {
const results = await this.prisma.agentTrace.findMany({
where: {
sessionId,
errorType: { not: null },
},
orderBy: { createdAt: 'desc' },
});
return results.map(r => ({
id: r.id,
sessionId: r.sessionId,
traceId: r.traceId,
stepIndex: r.stepIndex,
stepType: r.stepType as StepType,
input: r.input as Record<string, unknown> | undefined,
output: r.output as Record<string, unknown> | undefined,
stageCode: r.stageCode ?? undefined,
modelUsed: r.modelUsed ?? undefined,
tokensUsed: r.tokensUsed ?? undefined,
durationMs: r.durationMs ?? undefined,
errorType: r.errorType ?? undefined,
errorMsg: r.errorMsg ?? undefined,
createdAt: r.createdAt,
}));
}
/**
* 计算会话的Token统计
*/
async getSessionTokenStats(sessionId: string): Promise<{
totalTokens: number;
byModel: Record<string, number>;
byStepType: Record<string, number>;
}> {
const traces = await this.prisma.agentTrace.findMany({
where: { sessionId },
select: {
tokensUsed: true,
modelUsed: true,
stepType: true,
},
});
const stats = {
totalTokens: 0,
byModel: {} as Record<string, number>,
byStepType: {} as Record<string, number>,
};
for (const trace of traces) {
const tokens = trace.tokensUsed ?? 0;
stats.totalTokens += tokens;
if (trace.modelUsed) {
stats.byModel[trace.modelUsed] = (stats.byModel[trace.modelUsed] ?? 0) + tokens;
}
stats.byStepType[trace.stepType] = (stats.byStepType[trace.stepType] ?? 0) + tokens;
}
return stats;
}
/**
* 计算平均响应时间
*/
async getAverageResponseTime(sessionId: string): Promise<number> {
const result = await this.prisma.agentTrace.aggregate({
where: {
sessionId,
durationMs: { not: null },
},
_avg: {
durationMs: true,
},
});
return result._avg.durationMs ?? 0;
}
/**
* 删除旧的追踪记录(清理)
*/
async cleanupOldTraces(daysToKeep: number = 30): Promise<number> {
const cutoffDate = new Date();
cutoffDate.setDate(cutoffDate.getDate() - daysToKeep);
const result = await this.prisma.agentTrace.deleteMany({
where: {
createdAt: { lt: cutoffDate },
},
});
return result.count;
}
/**
* 生成追踪报告
*/
async generateTraceReport(sessionId: string): Promise<{
sessionId: string;
totalSteps: number;
totalTokens: number;
totalDuration: number;
stageBreakdown: Array<{
stage: string;
steps: number;
tokens: number;
duration: number;
}>;
errors: Array<{
traceId: string;
stepType: string;
errorType: string;
errorMsg: string;
}>;
}> {
const traces = await this.getSessionTraces(sessionId);
const stageMap = new Map<string, { steps: number; tokens: number; duration: number }>();
let totalTokens = 0;
let totalDuration = 0;
const errors: Array<{
traceId: string;
stepType: string;
errorType: string;
errorMsg: string;
}> = [];
for (const trace of traces) {
totalTokens += trace.tokensUsed ?? 0;
totalDuration += trace.durationMs ?? 0;
if (trace.stageCode) {
const existing = stageMap.get(trace.stageCode) ?? { steps: 0, tokens: 0, duration: 0 };
stageMap.set(trace.stageCode, {
steps: existing.steps + 1,
tokens: existing.tokens + (trace.tokensUsed ?? 0),
duration: existing.duration + (trace.durationMs ?? 0),
});
}
if (trace.errorType) {
errors.push({
traceId: trace.traceId,
stepType: trace.stepType,
errorType: trace.errorType,
errorMsg: trace.errorMsg ?? '',
});
}
}
return {
sessionId,
totalSteps: traces.length,
totalTokens,
totalDuration,
stageBreakdown: Array.from(stageMap.entries()).map(([stage, stats]) => ({
stage,
...stats,
})),
errors,
};
}
}

View File

@@ -0,0 +1,12 @@
/**
* Agent Services Export
*
* @module agent/services
*/
export { ConfigLoader } from './ConfigLoader.js';
export { BaseAgentOrchestrator, type OrchestratorDependencies, type LLMServiceInterface } from './BaseAgentOrchestrator.js';
export { QueryAnalyzer, type IntentType } from './QueryAnalyzer.js';
export { StageManager, type TransitionCondition } from './StageManager.js';
export { TraceLogger, type TraceLogInput } from './TraceLogger.js';

View File

@@ -0,0 +1,382 @@
/**
* Protocol Agent Framework - Core Types
* 通用Agent框架类型定义
*
* @module agent/types
*/
// ============================================================
// 基础类型
// ============================================================
/** Agent状态枚举 */
export type AgentSessionStatus = 'active' | 'completed' | 'paused' | 'error';
/** 阶段状态枚举 */
export type StageStatus = 'pending' | 'in_progress' | 'completed' | 'skipped';
/** 执行步骤类型 */
export type StepType = 'query' | 'plan' | 'execute' | 'reflect' | 'tool_call' | 'sync';
/** Prompt类型 */
export type PromptType = 'system' | 'stage' | 'extraction' | 'reflexion' | 'generation';
/** Reflexion规则类型 */
export type ReflexionRuleType = 'rule_based' | 'prompt_based';
/** 规则严重性 */
export type ReflexionSeverity = 'error' | 'warning' | 'info';
/** 失败处理动作 */
export type FailureAction = 'block' | 'warn' | 'log';
// ============================================================
// Agent定义相关
// ============================================================
/** Agent全局配置 */
export interface AgentConfig {
defaultModel: string;
maxTurns: number;
timeout: number;
enableTrace: boolean;
enableReflexion: boolean;
}
/** Agent定义 */
export interface AgentDefinition {
id: string;
code: string;
name: string;
description?: string;
version: string;
config?: AgentConfig;
isActive: boolean;
createdAt: Date;
updatedAt: Date;
}
/** Agent阶段定义 */
export interface AgentStage {
id: string;
agentId: string;
stageCode: string;
stageName: string;
sortOrder: number;
config?: Record<string, unknown>;
nextStages: string[];
isInitial: boolean;
isFinal: boolean;
}
/** Agent Prompt定义 */
export interface AgentPrompt {
id: string;
agentId: string;
stageId?: string;
promptType: PromptType;
promptCode: string;
content: string;
variables: string[];
version: number;
isActive: boolean;
}
// ============================================================
// 会话与追踪
// ============================================================
/** Agent会话 */
export interface AgentSession {
id: string;
agentId: string;
conversationId: string;
userId: string;
currentStage: string;
status: AgentSessionStatus;
contextRef?: string;
turnCount: number;
totalTokens: number;
createdAt: Date;
updatedAt: Date;
}
/** 执行追踪记录 */
export interface AgentTraceRecord {
id: string;
sessionId: string;
traceId: string;
stepIndex: number;
stepType: StepType;
input?: Record<string, unknown>;
output?: Record<string, unknown>;
stageCode?: string;
modelUsed?: string;
tokensUsed?: number;
durationMs?: number;
errorType?: string;
errorMsg?: string;
createdAt: Date;
}
// ============================================================
// Reflexion规则
// ============================================================
/** Reflexion规则定义 */
export interface ReflexionRule {
id: string;
agentId: string;
ruleCode: string;
ruleName: string;
triggerStage?: string;
triggerTiming: 'on_sync' | 'on_stage_complete' | 'on_generate';
ruleType: ReflexionRuleType;
conditions?: Record<string, unknown>;
promptTemplate?: string;
severity: ReflexionSeverity;
failureAction: FailureAction;
isActive: boolean;
sortOrder: number;
}
/** Reflexion检查结果 */
export interface ReflexionResult {
ruleCode: string;
ruleName: string;
passed: boolean;
severity: ReflexionSeverity;
message?: string;
details?: Record<string, unknown>;
}
// ============================================================
// Protocol Agent专用类型
// ============================================================
/** Protocol Context状态 */
export type ProtocolContextStatus = 'in_progress' | 'completed' | 'abandoned';
/** Protocol阶段代码 */
export type ProtocolStageCode =
| 'scientific_question'
| 'pico'
| 'study_design'
| 'sample_size'
| 'endpoints';
/** 科学问题数据 */
export interface ScientificQuestionData {
content: string;
background?: string;
significance?: string;
confirmed: boolean;
confirmedAt?: Date;
}
/** PICO要素 */
export interface PICOElement {
value: string;
details?: string;
}
/** PICO数据 */
export interface PICOData {
P: PICOElement; // Population
I: PICOElement; // Intervention
C: PICOElement; // Comparison
O: PICOElement; // Outcome
confirmed: boolean;
confirmedAt?: Date;
}
/** 研究设计数据 */
export interface StudyDesignData {
type: string; // RCT, Cohort, Case-Control, etc.
blinding?: string; // Open, Single-blind, Double-blind
randomization?: string; // Simple, Block, Stratified
duration?: string; // 研究周期
multiCenter?: boolean;
centerCount?: number;
confirmed: boolean;
confirmedAt?: Date;
}
/** 样本量数据 */
export interface SampleSizeData {
total: number;
perGroup?: number;
alpha?: number; // 显著性水平
power?: number; // 统计效力
effectSize?: number; // 效应量
dropoutRate?: number; // 脱落率
justification?: string; // 计算依据
confirmed: boolean;
confirmedAt?: Date;
}
/** 终点指标项 */
export interface EndpointItem {
name: string;
definition?: string;
method?: string;
timePoint?: string;
}
/** 终点指标数据 */
export interface EndpointsData {
primary: EndpointItem[];
secondary: EndpointItem[];
safety: EndpointItem[];
exploratory?: EndpointItem[];
confirmed: boolean;
confirmedAt?: Date;
}
/** Protocol上下文完整数据 */
export interface ProtocolContextData {
id: string;
conversationId: string;
userId: string;
currentStage: ProtocolStageCode;
status: ProtocolContextStatus;
// 5个核心阶段数据
scientificQuestion?: ScientificQuestionData;
pico?: PICOData;
studyDesign?: StudyDesignData;
sampleSize?: SampleSizeData;
endpoints?: EndpointsData;
// 元数据
completedStages: ProtocolStageCode[];
lastActiveAt: Date;
createdAt: Date;
updatedAt: Date;
}
/** Protocol生成记录 */
export interface ProtocolGenerationRecord {
id: string;
contextId: string;
userId: string;
generatedContent: string;
contentVersion: number;
promptUsed: string;
modelUsed: string;
tokensUsed?: number;
durationMs?: number;
wordFileKey?: string;
lastExportedAt?: Date;
status: 'generating' | 'completed' | 'failed';
errorMessage?: string;
createdAt: Date;
updatedAt: Date;
}
// ============================================================
// 服务接口类型
// ============================================================
/** 用户消息输入 */
export interface UserMessageInput {
conversationId: string;
userId: string;
content: string;
messageId?: string;
attachments?: unknown[];
}
/** Agent响应 */
export interface AgentResponse {
content: string;
thinkingContent?: string;
// 元数据
stage: string;
stageName: string;
// 动作卡片
actionCards?: ActionCard[];
// 同步按钮
syncButton?: SyncButtonData;
// Token统计
tokensUsed?: number;
modelUsed?: string;
}
/** 动作卡片 */
export interface ActionCard {
id: string;
type: string;
title: string;
description?: string;
actionUrl?: string;
actionParams?: Record<string, unknown>;
}
/** 同步按钮数据 */
export interface SyncButtonData {
stageCode: string;
extractedData: Record<string, unknown>;
label: string;
disabled?: boolean;
}
/** 同步请求 */
export interface SyncRequest {
conversationId: string;
userId: string;
stageCode: ProtocolStageCode;
data: Record<string, unknown>;
}
/** 同步响应 */
export interface SyncResponse {
success: boolean;
context: ProtocolContextData;
nextStage?: ProtocolStageCode;
message?: string;
}
/** 意图识别结果 */
export interface IntentAnalysis {
intent: string;
confidence: number;
entities: Record<string, unknown>;
suggestedAction?: string;
}
/** 阶段转换结果 */
export interface StageTransitionResult {
success: boolean;
fromStage: string;
toStage: string;
message?: string;
reflexionResults?: ReflexionResult[];
}
// ============================================================
// 配置加载相关
// ============================================================
/** Agent完整配置含所有关联数据 */
export interface AgentFullConfig {
definition: AgentDefinition;
stages: AgentStage[];
prompts: AgentPrompt[];
reflexionRules: ReflexionRule[];
}
/** Prompt渲染上下文 */
export interface PromptRenderContext {
stage: string;
context: ProtocolContextData;
conversationHistory?: Array<{ role: string; content: string }>;
userMessage: string;
additionalVariables?: Record<string, unknown>;
}

View File

@@ -244,3 +244,5 @@ async function matchIntent(query: string): Promise<{

View File

@@ -98,3 +98,5 @@ export async function uploadAttachment(

View File

@@ -27,3 +27,5 @@ export { aiaRoutes };

View File

@@ -368,6 +368,8 @@ runTests().catch((error) => {

View File

@@ -347,6 +347,8 @@ Content-Type: application/json

View File

@@ -283,6 +283,8 @@ export const conflictDetectionService = new ConflictDetectionService();

View File

@@ -233,6 +233,8 @@ curl -X POST http://localhost:3000/api/v1/dc/tool-c/test/execute \

View File

@@ -287,6 +287,8 @@ export const streamAIController = new StreamAIController();

View File

@@ -196,6 +196,8 @@ logger.info('[SessionMemory] 会话记忆管理器已启动', {

View File

@@ -130,6 +130,8 @@ checkTableStructure();

View File

@@ -117,6 +117,8 @@ checkProjectConfig().catch(console.error);

View File

@@ -99,6 +99,8 @@ main();

View File

@@ -556,6 +556,8 @@ URL: https://iit.xunzhengyixue.com/api/v1/iit/patient-wechat/callback

View File

@@ -191,6 +191,8 @@ console.log('');

View File

@@ -508,6 +508,8 @@ export const patientWechatService = new PatientWechatService();

View File

@@ -153,6 +153,8 @@ testDifyIntegration().catch(error => {

View File

@@ -182,6 +182,8 @@ testIitDatabase()

View File

@@ -168,6 +168,8 @@ if (hasError) {

View File

@@ -194,6 +194,8 @@ async function testUrlVerification() {

View File

@@ -275,6 +275,8 @@ main().catch((error) => {

View File

@@ -159,6 +159,8 @@ Write-Host ""

View File

@@ -252,6 +252,8 @@ export interface CachedProtocolRules {

View File

@@ -65,6 +65,8 @@ export default async function healthRoutes(fastify: FastifyInstance) {

View File

@@ -143,6 +143,8 @@ Content-Type: application/json

Some files were not shown because too many files have changed in this diff Show More