feat(ssa): SSA Agent mode MVP - prompt management + Phase 5A guardrails + UX enhancements

Backend:
- Agent core prompts (Planner + Coder) now loaded from PromptService with 3-tier fallback (DB -> cache -> hardcoded)
- Seed script (seed-ssa-agent-prompts.ts) for idempotent SSA_AGENT_PLANNER + SSA_AGENT_CODER setup
- SSA fallback prompts added to prompt.fallbacks.ts
- Phase 5A: XML tag extraction, defensive programming prompt, high-fidelity schema injection, AST pre-check
- Default agent mode migration + session CRUD (rename/delete) APIs
- R Docker: structured error handling (20+ patterns) + AST syntax pre-check

Frontend:
- Default agent mode (QPER toggle removed), view code fix, analysis result cards in chat
- Session history sidebar with inline rename/delete, robust plan parsing from reviewResult
- R code export wrapper for local reproducibility (package checks + data loader + polyfills)
- SSA workspace CSS updates for sidebar actions and plan display

Docs:
- SSA module doc v4.2: Prompt inventory (2 Agent active / 11 QPER archived), dev progress updated
- System overview doc v6.8: SSA Agent MVP milestone
- Deployment checklist: DB-5 (seed script) + BE-10 (prompt management)

Made-with: Cursor
@@ -0,0 +1,8 @@
+-- 1. Change the column default: new sessions use agent mode by default
+ALTER TABLE "ssa_schema"."ssa_sessions"
+  ALTER COLUMN "execution_mode" SET DEFAULT 'agent';
+
+-- 2. Update all existing sessions from qper to agent
+UPDATE "ssa_schema"."ssa_sessions"
+SET "execution_mode" = 'agent'
+WHERE "execution_mode" = 'qper';
@@ -2481,7 +2481,7 @@ model SsaSession {
   dataPayload   Json?    @map("data_payload") /// Real data (visible to R only)
   dataOssKey    String?  @map("data_oss_key") /// OSS storage key (large datasets)
   dataProfile   Json?    @map("data_profile") /// 🆕 DataProfile generated by Python (Phase 2A)
-  executionMode String   @default("qper") @map("execution_mode") /// qper | agent
+  executionMode String   @default("agent") @map("execution_mode") /// qper | agent
   status        String   @default("active") /// active | consult | completed | error
   createdAt     DateTime @default(now()) @map("created_at")
   updatedAt     DateTime @updatedAt @map("updated_at")
238
backend/prisma/seed-ssa-agent-prompts.ts
Normal file
@@ -0,0 +1,238 @@
+/**
+ * SSA Agent prompt seed script
+ *
+ * Writes the PlannerAgent / CoderAgent system prompts into prompt_templates + prompt_versions
+ * so that they can be edited online, previewed in a canary (gray) release, and versioned from the operations admin console.
+ *
+ * Usage:
+ *   npx tsx prisma/seed-ssa-agent-prompts.ts
+ *
+ * Idempotent by design: uses upsert, so it is safe to run repeatedly.
+ */
+
+import { PrismaClient } from '@prisma/client';
+
+const prisma = new PrismaClient();
+
+/* ------------------------------------------------------------------ */
+/*  Prompt content                                                    */
+/* ------------------------------------------------------------------ */
+
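+const SSA_AGENT_PLANNER_CONTENT = `You are a senior statistical analysis planner (Planner Agent). Your job is to draw up a rigorous statistical analysis plan based on the user's research needs and the characteristics of the data.
+
+## Data context
+{{{dataContext}}}
+
+## Planning rules (hard rules)
+1. You must state the study design type (cross-sectional / cohort / case-control / RCT / pre-post comparison, etc.)
+2. You must make the variable roles explicit: outcome, predictors, grouping, confounders
+3. Every choice of statistical method must be justified (data type, distribution, sample size, etc.)
+4. Consider normality for continuous variables: normal → parametric methods, non-normal → non-parametric methods
+5. When a categorical variable has expected frequencies < 5, choose Fisher's exact test rather than the chi-square test
+6. Multivariable analyses must consider collinearity and EPV (Events Per Variable)
+7. Never fabricate any data or predict analysis results
+
+## Output format
+Output the analysis plan as JSON with the following structure:
+\`\`\`json
+{
+  "title": "Analysis plan title",
+  "designType": "Study design type",
+  "variables": {
+    "outcome": ["outcome variable name"],
+    "predictors": ["predictor variable name"],
+    "grouping": "grouping variable name or null",
+    "confounders": ["confounder"]
+  },
+  "steps": [
+    {
+      "order": 1,
+      "method": "Name of the statistical method",
+      "description": "What this step does",
+      "rationale": "Why this method was chosen"
+    }
+  ],
+  "assumptions": ["Statistical assumptions to verify"]
+}
+\`\`\`
+
+After the JSON code block, you may add explanatory notes in natural language.`;
+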
+const SSA_AGENT_CODER_CONTENT = `You are an R statistical programming expert (Coder Agent). Your job is to turn the analysis plan into R code that can run in the R Docker sandbox.
+
+## Data context
+{{{dataContext}}}
+
+## R coding rules (hard rules)
+
+### Data loading (important!)
+The execution environment has **already loaded** the data into the variable \`df\` (a data.frame).
+Do **not** call \`load_input_data()\` yourself; use \`df\` directly.
+
+\`\`\`r
+# df already exists; use it directly
+str(df)  # inspect the structure
+\`\`\`
+
+### Output rules
+The code must end by returning a list that contains a report_blocks field:
+\`\`\`r
+# Build blocks with the functions from block_helpers.R
+blocks <- list()
+blocks[[length(blocks) + 1]] <- make_markdown_block("## Analysis results\\n...")
+blocks[[length(blocks) + 1]] <- make_table_block_from_df(result_df, title = "Table 1. Statistical results")
+blocks[[length(blocks) + 1]] <- make_image_block(base64_data, title = "Figure 1. Visualization")
+blocks[[length(blocks) + 1]] <- make_kv_block(list("P value" = "0.023", "Effect size" = "0.45"))
+
+# Must be returned in this format
+list(
+  status = "success",
+  method = "statistical method used",
+  report_blocks = blocks
+)
+\`\`\`
+
+### Available helper functions (provided by block_helpers.R)
+- \`make_markdown_block(content, title)\`: Markdown text block
+- \`make_table_block(headers, rows, title, footnote)\`: table block
+- \`make_table_block_from_df(df_arg, title, footnote, digits)\`: table block built from a data.frame (take care that the argument name does not clash with the df variable)
+- \`make_image_block(base64_data, title, alt)\`: image block
+- \`make_kv_block(items, title)\`: key-value block
+
+### Chart generation
+\`\`\`r
+library(base64enc)
+tmp_file <- tempfile(fileext = ".png")
+png(tmp_file, width = 800, height = 600, res = 120)
+# ... plotting code ...
+dev.off()
+base64_data <- paste0("data:image/png;base64,", base64encode(tmp_file))
+unlink(tmp_file)
+\`\`\`
+
+### Preinstalled packages (only these may be used; all other packages are forbidden)
+base, stats, utils, graphics, grDevices,
+ggplot2, dplyr, tidyr, broom, gtsummary, gt, scales, gridExtra,
+car, lmtest, survival, meta, base64enc, glue, jsonlite, cowplot
+
+### Defensive programming (mandatory!)
+1. **Factor conversion**: grouping/categorical variables must go through as.factor() before use; never assume they are already factors
+2. **Missing values**: statistical functions must pass na.rm = TRUE or be preceded by na.omit()
+3. **Safe test wrapping**: every test such as t.test / wilcox.test / chisq.test must be wrapped in tryCatch
+4. **Sample size check**: before a group comparison, check that every group has n >= 2; otherwise skip it and explain why
+5. **Variable existence check**: before using a column, check it with if ("col" %in% names(df))
+6. **Numeric safety**: check that the denominator != 0 before dividing, and filter Inf/NaN results with is.finite()
+7. **Chart fault tolerance**: wrap plotting code in tryCatch; on failure return a text explanation instead of crashing
+
+### Prohibited
+1. No install.packages(): only the preinstalled packages listed above may be used
+2. No calls to load_input_data(): the data is already loaded into df
+3. No external network access: no httr/curl network requests
+4. No reading or writing files outside the sandbox: use tempfile() only
+5. No system() / shell() commands
+6. Do not use packages that are not installed, such as pROC, nortest, or exact2x2
+
+## Output format (hard rules! Any violation counts as a failure)
+1. **The complete R code must be placed between the <r_code> and </r_code> tags**
+2. Outside the <r_code> tags, only a brief explanation is allowed (1-3 sentences)
+3. Inside the <r_code> tags, **only pure R code is allowed**; mixing in Chinese explanatory text or natural-language paragraphs is strictly forbidden
+4. The code must be a directly executable R script, with no pseudocode or placeholders
+5. The code must end by returning a list that contains report_blocks
+6. Chinese comments may only be written on code lines after a leading #; Chinese text without a leading # is forbidden
+
+Example output format:
+Brief explanation...
+
+<r_code>
+library(ggplot2)
+# Data processing
+df$group <- as.factor(df$group)
+# ... complete R code ...
+list(status = "success", method = "t_test", report_blocks = blocks)
+</r_code>`;
+
+/* ------------------------------------------------------------------ */
+/*  Seed logic                                                        */
+/* ------------------------------------------------------------------ */
+
+interface PromptSeed {
+  code: string;
+  name: string;
+  description: string;
+  variables: string[];
+  content: string;
+}
+
+const PROMPTS: PromptSeed[] = [
+  {
+    code: 'SSA_AGENT_PLANNER',
+    name: 'SSA Agent Planner system prompt',
+    description: 'Smart Statistical Analysis: system prompt for the Planner Agent, which draws up the statistical analysis plan. Template variable: dataContext (data context)',
+    variables: ['dataContext'],
+    content: SSA_AGENT_PLANNER_CONTENT,
+  },
+  {
+    code: 'SSA_AGENT_CODER',
+    name: 'SSA Agent Coder system prompt',
+    description: 'Smart Statistical Analysis: system prompt for the Coder Agent, which generates executable R code. Template variable: dataContext (data context)',
+    variables: ['dataContext'],
+    content: SSA_AGENT_CODER_CONTENT,
+  },
+];
+
+async function seedSSAAgentPrompts() {
+  console.log('🌱 Seeding SSA Agent prompts...\n');
+
+  for (const p of PROMPTS) {
+    // 1. Upsert the template (safe to re-run)
+    const template = await prisma.prompt_templates.upsert({
+      where: { code: p.code },
+      update: {
+        name: p.name,
+        description: p.description,
+        variables: p.variables,
+      },
+      create: {
+        code: p.code,
+        name: p.name,
+        module: 'SSA',
+        description: p.description,
+        variables: p.variables,
+      },
+    });
+    console.log(`  ✅ Template: ${p.code} (id=${template.id})`);
+
+    // 2. Check whether an ACTIVE version already exists
+    const existing = await prisma.prompt_versions.findFirst({
+      where: { template_id: template.id, status: 'ACTIVE' },
+    });
+
+    if (existing) {
+      console.log(`  ⏭  ACTIVE v${existing.version} already exists, skipping version creation`);
+      continue;
+    }
+
+    // 3. Create version 1 as ACTIVE
+    const version = await prisma.prompt_versions.create({
+      data: {
+        template_id: template.id,
+        version: 1,
+        content: p.content,
+        model_config: { model: 'deepseek-v3', temperature: 0.3 },
+        status: 'ACTIVE',
+        changelog: 'Initial seed, migrated from hardcoded prompt',
+        created_by: 'system-seed',
+      },
+    });
+    console.log(`  ✅ Version v${version.version} created (ACTIVE)`);
+  }
+
+  console.log('\n🎉 SSA Agent prompt seeding complete!');
+}
+
+seedSSAAgentPrompts()
+  .catch((e) => {
+    console.error('❌ Seed failed:', e);
+    process.exit(1);
+  })
+  .finally(() => prisma.$disconnect());
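The seed script is re-runnable because the template row is upserted while a version row is only written when no ACTIVE version exists yet. The sketch below illustrates that pattern with an in-memory stand-in for the database. `Store`, `Template`, `Version`, and `seedPrompt` are hypothetical names chosen for illustration; they are not part of the commit or of Prisma.

```typescript
// Hypothetical in-memory stand-ins for prompt_templates / prompt_versions.
interface Template {
  id: number;
  code: string;
  name: string;
}

interface Version {
  templateId: number;
  version: number;
  status: 'ACTIVE' | 'ARCHIVED';
  content: string;
}

class Store {
  templates = new Map<string, Template>();
  versions: Version[] = [];
  private nextId = 1;

  // Upsert: update the row when the code already exists, otherwise create it.
  upsertTemplate(code: string, name: string): Template {
    const existing = this.templates.get(code);
    if (existing) {
      existing.name = name;
      return existing;
    }
    const created: Template = { id: this.nextId++, code, name };
    this.templates.set(code, created);
    return created;
  }
}

// Seeds one prompt; returns true only when a new version row was written.
function seedPrompt(store: Store, code: string, name: string, content: string): boolean {
  const template = store.upsertTemplate(code, name);
  const hasActive = store.versions.some(
    (v) => v.templateId === template.id && v.status === 'ACTIVE',
  );
  if (hasActive) return false; // keep whatever was edited after the first seed
  store.versions.push({ templateId: template.id, version: 1, status: 'ACTIVE', content });
  return true;
}

const store = new Store();
const firstRun = seedPrompt(store, 'SSA_AGENT_PLANNER', 'Planner', 'v1 content');
const secondRun = seedPrompt(store, 'SSA_AGENT_PLANNER', 'Planner (renamed)', 'changed content');
```

On the second run the template metadata is refreshed but the stored prompt content is left alone, which is the behavior that makes repeated seeding safe for prompts that have since been edited online.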
@@ -290,6 +290,150 @@ Please provide precise, actionable suggestions.`,
   },
 };
 
+/**
+ * Fallback prompts for the SSA (Smart Statistical Analysis) module
+ */
+const SSA_FALLBACKS: Record<string, FallbackPrompt> = {
+  SSA_AGENT_PLANNER: {
+    content: `You are a senior statistical analysis planner (Planner Agent). Your job is to draw up a rigorous statistical analysis plan based on the user's research needs and the characteristics of the data.
+
+## Data context
+{{{dataContext}}}
+
+## Planning rules (hard rules)
+1. You must state the study design type (cross-sectional / cohort / case-control / RCT / pre-post comparison, etc.)
+2. You must make the variable roles explicit: outcome, predictors, grouping, confounders
+3. Every choice of statistical method must be justified (data type, distribution, sample size, etc.)
+4. Consider normality for continuous variables: normal → parametric methods, non-normal → non-parametric methods
+5. When a categorical variable has expected frequencies < 5, choose Fisher's exact test rather than the chi-square test
+6. Multivariable analyses must consider collinearity and EPV (Events Per Variable)
+7. Never fabricate any data or predict analysis results
+
+## Output format
+Output the analysis plan as JSON with the following structure:
+\`\`\`json
+{
+  "title": "Analysis plan title",
+  "designType": "Study design type",
+  "variables": {
+    "outcome": ["outcome variable name"],
+    "predictors": ["predictor variable name"],
+    "grouping": "grouping variable name or null",
+    "confounders": ["confounder"]
+  },
+  "steps": [
+    {
+      "order": 1,
+      "method": "Name of the statistical method",
+      "description": "What this step does",
+      "rationale": "Why this method was chosen"
+    }
+  ],
+  "assumptions": ["Statistical assumptions to verify"]
+}
+\`\`\`
+
+After the JSON code block, you may add explanatory notes in natural language.`,
+    modelConfig: { model: 'deepseek-v3', temperature: 0.3 },
+  },
+
+  SSA_AGENT_CODER: {
+    content: `You are an R statistical programming expert (Coder Agent). Your job is to turn the analysis plan into R code that can run in the R Docker sandbox.
+
+## Data context
+{{{dataContext}}}
+
+## R coding rules (hard rules)
+
+### Data loading (important!)
+The execution environment has **already loaded** the data into the variable \`df\` (a data.frame).
+Do **not** call \`load_input_data()\` yourself; use \`df\` directly.
+
+\`\`\`r
+# df already exists; use it directly
+str(df)  # inspect the structure
+\`\`\`
+
+### Output rules
+The code must end by returning a list that contains a report_blocks field:
+\`\`\`r
+# Build blocks with the functions from block_helpers.R
+blocks <- list()
+blocks[[length(blocks) + 1]] <- make_markdown_block("## Analysis results\\n...")
+blocks[[length(blocks) + 1]] <- make_table_block_from_df(result_df, title = "Table 1. Statistical results")
+blocks[[length(blocks) + 1]] <- make_image_block(base64_data, title = "Figure 1. Visualization")
+blocks[[length(blocks) + 1]] <- make_kv_block(list("P value" = "0.023", "Effect size" = "0.45"))
+
+# Must be returned in this format
+list(
+  status = "success",
+  method = "statistical method used",
+  report_blocks = blocks
+)
+\`\`\`
+
+### Available helper functions (provided by block_helpers.R)
+- \`make_markdown_block(content, title)\`: Markdown text block
+- \`make_table_block(headers, rows, title, footnote)\`: table block
+- \`make_table_block_from_df(df_arg, title, footnote, digits)\`: table block built from a data.frame (take care that the argument name does not clash with the df variable)
+- \`make_image_block(base64_data, title, alt)\`: image block
+- \`make_kv_block(items, title)\`: key-value block
+
+### Chart generation
+\`\`\`r
+library(base64enc)
+tmp_file <- tempfile(fileext = ".png")
+png(tmp_file, width = 800, height = 600, res = 120)
+# ... plotting code ...
+dev.off()
+base64_data <- paste0("data:image/png;base64,", base64encode(tmp_file))
+unlink(tmp_file)
+\`\`\`
+
+### Preinstalled packages (only these may be used; all other packages are forbidden)
+base, stats, utils, graphics, grDevices,
+ggplot2, dplyr, tidyr, broom, gtsummary, gt, scales, gridExtra,
+car, lmtest, survival, meta, base64enc, glue, jsonlite, cowplot
+
+### Defensive programming (mandatory!)
+1. **Factor conversion**: grouping/categorical variables must go through as.factor() before use; never assume they are already factors
+2. **Missing values**: statistical functions must pass na.rm = TRUE or be preceded by na.omit()
+3. **Safe test wrapping**: every test such as t.test / wilcox.test / chisq.test must be wrapped in tryCatch
+4. **Sample size check**: before a group comparison, check that every group has n >= 2; otherwise skip it and explain why
+5. **Variable existence check**: before using a column, check it with if ("col" %in% names(df))
+6. **Numeric safety**: check that the denominator != 0 before dividing, and filter Inf/NaN results with is.finite()
+7. **Chart fault tolerance**: wrap plotting code in tryCatch; on failure return a text explanation instead of crashing
+
+### Prohibited
+1. No install.packages(): only the preinstalled packages listed above may be used
+2. No calls to load_input_data(): the data is already loaded into df
+3. No external network access: no httr/curl network requests
+4. No reading or writing files outside the sandbox: use tempfile() only
+5. No system() / shell() commands
+6. Do not use packages that are not installed, such as pROC, nortest, or exact2x2
+
+## Output format (hard rules! Any violation counts as a failure)
+1. **The complete R code must be placed between the <r_code> and </r_code> tags**
+2. Outside the <r_code> tags, only a brief explanation is allowed (1-3 sentences)
+3. Inside the <r_code> tags, **only pure R code is allowed**; mixing in Chinese explanatory text or natural-language paragraphs is strictly forbidden
+4. The code must be a directly executable R script, with no pseudocode or placeholders
+5. The code must end by returning a list that contains report_blocks
+6. Chinese comments may only be written on code lines after a leading #; Chinese text without a leading # is forbidden
+
+Example output format:
+Brief explanation...
+
+<r_code>
+library(ggplot2)
+# Data processing
+df$group <- as.factor(df$group)
+# ... complete R code ...
+list(status = "success", method = "t_test", report_blocks = blocks)
+</r_code>`,
+    modelConfig: { model: 'deepseek-v3', temperature: 0.3 },
+  },
+};
+
 /**
  * Aggregated fallback prompts for all modules
  */
@@ -297,6 +441,7 @@ export const FALLBACK_PROMPTS: Record<string, FallbackPrompt> = {
   ...RVW_FALLBACKS,
   ...ASL_FALLBACKS,
   ...AIA_FALLBACKS,
+  ...SSA_FALLBACKS,
 };
 
 /**
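The commit message describes a 3-tier fallback (DB -> cache -> hardcoded) for these prompts, and the templates carry a `{{{dataContext}}}` placeholder. The sketch below shows one way such a resolution order and placeholder substitution could work. It is an illustration under stated assumptions, not the real `PromptService`: `resolvePrompt`, `render`, `PromptRecord`, and the synchronous loader are hypothetical, and the actual service is asynchronous.

```typescript
// Hypothetical record shape; the real FallbackPrompt also carries modelConfig.
interface PromptRecord {
  content: string;
}

type PromptSource = 'db' | 'cache' | 'fallback';

// Replace {{{name}}} placeholders; names without a string value are left as-is.
function render(template: string, vars: Record<string, string>): string {
  return template.replace(/\{\{\{(\w+)\}\}\}/g, (whole: string, name: string): string => {
    const value = vars[name];
    return typeof value === 'string' ? value : whole;
  });
}

// Try the database first, then the in-process cache, then the hardcoded table.
function resolvePrompt(
  code: string,
  vars: Record<string, string>,
  loadFromDb: (code: string) => PromptRecord | undefined,
  cache: Map<string, PromptRecord>,
  fallbacks: Record<string, PromptRecord>,
): { content: string; source: PromptSource } {
  let record: PromptRecord | undefined;
  let source: PromptSource = 'db';
  try {
    record = loadFromDb(code);
    if (record) cache.set(code, record); // remember the last good DB copy
  } catch {
    record = undefined; // database unavailable: fall through to the next tier
  }
  if (!record) {
    record = cache.get(code);
    source = 'cache';
  }
  if (!record) {
    record = fallbacks[code];
    source = 'fallback';
  }
  if (!record) throw new Error(`No prompt registered for ${code}`);
  return { content: render(record.content, vars), source };
}

const cache = new Map<string, PromptRecord>();
const fallbacks: Record<string, PromptRecord> = {
  SSA_AGENT_PLANNER: { content: 'fallback: {{{dataContext}}}' },
};
const vars = { dataContext: 'n=120' };

const fromDb = resolvePrompt(
  'SSA_AGENT_PLANNER', vars, () => ({ content: 'db: {{{dataContext}}}' }), cache, fallbacks,
);
const fromCache = resolvePrompt(
  'SSA_AGENT_PLANNER', vars, () => { throw new Error('db down'); }, cache, fallbacks,
);
const fromFallback = resolvePrompt(
  'SSA_AGENT_PLANNER', vars, () => undefined, new Map<string, PromptRecord>(), fallbacks,
);
```

Keeping the hardcoded table as the last tier means the agents can still build a system prompt when both the database and the cache are empty, for example on a fresh deployment where the seed script has not run yet.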
@@ -115,12 +115,8 @@ export default async function chatRoutes(app: FastifyInstance) {
       }
       // ── End of H1 ──
 
-      // 3. Read the session's execution mode
-      const session = await (prisma.ssaSession as any).findUnique({
-        where: { id: sessionId },
-        select: { executionMode: true },
-      });
-      const executionMode = (session?.executionMode as string) || 'qper';
+      // 3. Execution mode: always use the Agent channel (the QPER UI entry point is deprecated)
+      const executionMode = 'agent';
 
       // ── Agent channel dispatch ──
       if (executionMode === 'agent') {
@@ -50,7 +50,9 @@ export default async function sessionRoutes(app: FastifyInstance) {
     if (data) {
       const buffer = await data.toBuffer();
       const filename = data.filename;
-      title = filename;
+      const baseName = filename.replace(/\.(csv|xlsx?)$/i, '') || '数据'; // fallback title "数据" means "data"
+      const now = new Date();
+      title = `${baseName} ${now.getMonth() + 1}月${now.getDate()}日`; // e.g. "trial 3月5日" (month 3, day 5)
 
       // Generate the storage key (follows the OSS directory layout convention)
       const uuid = crypto.randomUUID().replace(/-/g, '').substring(0, 16);
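The hunk above replaces the raw filename with a default session title: the file's base name plus the upload date. As a minimal sketch, the same logic can be pulled into a pure function so that its edge cases are easy to check. `deriveSessionTitle` is a hypothetical helper name, not something the commit defines; the regex and the title format are copied from the diff.

```typescript
// Hypothetical helper wrapping the title logic from the hunk above.
function deriveSessionTitle(filename: string, now: Date): string {
  // Strip a trailing .csv / .xls / .xlsx (case-insensitive); fall back to "数据" ("data").
  const baseName = filename.replace(/\.(csv|xlsx?)$/i, '') || '数据';
  // getMonth() is zero-based, hence the +1.
  return `${baseName} ${now.getMonth() + 1}月${now.getDate()}日`;
}

const march5 = new Date(2025, 2, 5); // local time, month index 2 = March
const spreadsheetTitle = deriveSessionTitle('trial.XLSX', march5);
const emptyBaseTitle = deriveSessionTitle('.csv', march5);
const otherExtensionTitle = deriveSessionTitle('notes.txt', march5);
```

Only the three listed extensions are stripped, so a file such as `notes.txt` keeps its extension in the title.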
@@ -113,20 +115,45 @@ export default async function sessionRoutes(app: FastifyInstance) {
 
     const sessions = await prisma.ssaSession.findMany({
       where: { userId },
-      orderBy: { createdAt: 'desc' },
-      take: 20
+      orderBy: { updatedAt: 'desc' },
+      take: 30,
+      select: {
+        id: true,
+        title: true,
+        status: true,
+        executionMode: true,
+        createdAt: true,
+        updatedAt: true,
+      },
     });
 
-    return reply.send(sessions);
+    return reply.send({ sessions });
   });
 
-  // Get session details
+  // Get session details (including Agent execution history)
   app.get('/:id', async (req, reply) => {
     const { id } = req.params as { id: string };
 
     const session = await prisma.ssaSession.findUnique({
       where: { id },
-      include: { messages: true }
+      include: {
+        agentExecutions: {
+          orderBy: { createdAt: 'asc' },
+          select: {
+            id: true,
+            query: true,
+            planText: true,
+            reviewResult: true,
+            generatedCode: true,
+            reportBlocks: true,
+            retryCount: true,
+            status: true,
+            errorMessage: true,
+            durationMs: true,
+            createdAt: true,
+          },
+        },
+      },
     });
 
     if (!session) {
@@ -136,6 +163,60 @@ export default async function sessionRoutes(app: FastifyInstance) {
     return reply.send(session);
   });
 
+  /**
+   * PATCH /sessions/:id
+   * Update a session (currently only title is supported)
+   */
+  app.patch('/:id', async (req, reply) => {
+    const userId = getUserId(req);
+    const { id } = req.params as { id: string };
+    const body = req.body as { title?: string };
+
+    const session = await prisma.ssaSession.findUnique({ where: { id } });
+    if (!session) {
+      return reply.status(404).send({ error: 'Session not found' });
+    }
+    if (session.userId !== userId) {
+      return reply.status(403).send({ error: 'Forbidden' });
+    }
+
+    const data: { title?: string } = {};
+    if (typeof body.title === 'string' && body.title.trim()) {
+      data.title = body.title.trim();
+    }
+    if (Object.keys(data).length === 0) {
+      return reply.send(session);
+    }
+
+    const updated = await prisma.ssaSession.update({
+      where: { id },
+      data,
+    });
+    logger.info('[SSA:Session] Session updated', { sessionId: id, title: data.title });
+    return reply.send(updated);
+  });
+
+  /**
+   * DELETE /sessions/:id
+   * Delete a session (cascades to messages, execution records, etc.)
+   */
+  app.delete('/:id', async (req, reply) => {
+    const userId = getUserId(req);
+    const { id } = req.params as { id: string };
+
+    const session = await prisma.ssaSession.findUnique({ where: { id } });
+    if (!session) {
+      return reply.status(404).send({ error: 'Session not found' });
+    }
+    if (session.userId !== userId) {
+      return reply.status(403).send({ error: 'Forbidden' });
+    }
+
+    await prisma.ssaSession.delete({ where: { id } });
+    logger.info('[SSA:Session] Session deleted', { sessionId: id });
+    return reply.send({ success: true });
+  });
+
   /**
    * PATCH /sessions/:id/execution-mode
    * Switch the dual-channel execution mode (qper / agent)
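Both new handlers follow the same order of checks: 404 when the session does not exist, 403 when it belongs to another user, and only then the change itself. The sketch below restates the PATCH decision as a pure function so that the order and the whitespace handling of `title` can be exercised without Fastify or Prisma. `applyTitlePatch`, `SessionRow`, and `PatchResult` are hypothetical names introduced for illustration.

```typescript
// Hypothetical minimal shapes; the real row has more columns.
interface SessionRow {
  id: string;
  userId: string;
  title: string;
}

type PatchResult =
  | { status: 404 | 403; error: string }
  | { status: 200; session: SessionRow };

// Mirrors the PATCH handler's checks: existence, ownership, then title update.
function applyTitlePatch(
  session: SessionRow | undefined,
  userId: string,
  body: { title?: unknown },
): PatchResult {
  if (!session) return { status: 404, error: 'Session not found' };
  if (session.userId !== userId) return { status: 403, error: 'Forbidden' };

  const title = typeof body.title === 'string' ? body.title.trim() : '';
  if (!title) return { status: 200, session }; // nothing to update: return the row unchanged
  return { status: 200, session: { ...session, title } };
}

const row: SessionRow = { id: 's1', userId: 'u1', title: 'old' };
const missing = applyTitlePatch(undefined, 'u1', { title: 'x' });
const foreign = applyTitlePatch(row, 'u2', { title: 'x' });
const renamed = applyTitlePatch(row, 'u1', { title: '  new  ' });
const blank = applyTitlePatch(row, 'u1', { title: '   ' });
```

A whitespace-only title is treated as "no change" rather than as an error, which matches the handler returning the existing session when `data` ends up empty.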
@@ -14,6 +14,8 @@
 import { LLMFactory } from '../../../common/llm/adapters/LLMFactory.js';
 import type { Message as LLMMessage } from '../../../common/llm/adapters/types.js';
 import { logger } from '../../../common/logging/index.js';
+import { getPromptService } from '../../../common/prompt/index.js';
+import { prisma } from '../../../config/database.js';
 import { sessionBlackboardService } from './SessionBlackboardService.js';
 import { tokenTruncationService } from './TokenTruncationService.js';
 import type { AgentPlan } from './AgentPlannerService.js';
@@ -38,7 +40,7 @@ export class AgentCoderService {
     previousCode?: string,
   ): Promise<GeneratedCode> {
     const dataContext = await this.buildDataContext(sessionId);
-    const systemPrompt = this.buildSystemPrompt(dataContext);
+    const systemPrompt = await this.buildSystemPrompt(dataContext);
 
     const userMessage = errorFeedback
       ? this.buildRetryMessage(plan, errorFeedback, previousCode)
@@ -84,7 +86,7 @@ export class AgentCoderService {
     previousCode?: string,
   ): Promise<GeneratedCode> {
     const dataContext = await this.buildDataContext(sessionId);
-    const systemPrompt = this.buildSystemPrompt(dataContext);
+    const systemPrompt = await this.buildSystemPrompt(dataContext);
 
     const userMessage = errorFeedback
       ? this.buildRetryMessage(plan, errorFeedback, previousCode)
@@ -135,13 +137,24 @@ export class AgentCoderService {
     if (!blackboard) return '(no data context)';
 
     const truncated = tokenTruncationService.truncate(blackboard, {
-      maxTokens: 1500,
+      maxTokens: 2500,
       strategy: 'balanced',
     });
     return tokenTruncationService.toPromptString(truncated);
   }
 
-  private buildSystemPrompt(dataContext: string): string {
+  private async buildSystemPrompt(dataContext: string): Promise<string> {
+    try {
+      const promptService = getPromptService(prisma);
+      const rendered = await promptService.get('SSA_AGENT_CODER', { dataContext });
+      return rendered.content;
+    } catch (err) {
+      logger.warn('[AgentCoder] Failed to load prompt from DB, using fallback', { error: (err as Error).message });
+    }
+    return this.fallbackSystemPrompt(dataContext);
+  }
+
+  private fallbackSystemPrompt(dataContext: string): string {
     return `You are an R statistical programming expert (Coder Agent). Your job is to turn the analysis plan into R code that can run in the R Docker sandbox.
 
 ## Data context
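@@ -199,21 +212,41 @@ base, stats, utils, graphics, grDevices,
 ggplot2, dplyr, tidyr, broom, gtsummary, gt, scales, gridExtra,
 car, lmtest, survival, meta, base64enc, glue, jsonlite, cowplot
 
+### Defensive programming (mandatory!)
+1. **Factor conversion**: grouping/categorical variables must go through as.factor() before use; never assume they are already factors
+2. **Missing values**: statistical functions must pass na.rm = TRUE or be preceded by na.omit()
+3. **Safe test wrapping**: every test such as t.test / wilcox.test / chisq.test must be wrapped in tryCatch
+4. **Sample size check**: before a group comparison, check that every group has n >= 2; otherwise skip it and explain why
+5. **Variable existence check**: before using a column, check it with if ("col" %in% names(df))
+6. **Numeric safety**: check that the denominator != 0 before dividing, and filter Inf/NaN results with is.finite()
+7. **Chart fault tolerance**: wrap plotting code in tryCatch; on failure return a text explanation instead of crashing
+
 ### Prohibited
 1. No install.packages(): only the preinstalled packages listed above may be used
 2. No calls to load_input_data(): the data is already loaded into df
 3. No external network access: no httr/curl network requests
 4. No reading or writing files outside the sandbox: use tempfile() only
 5. No system() / shell() commands
-6. All numeric results must be wrapped in tryCatch to prevent crashes caused by NA/NaN
-7. Do not use packages that are not installed, such as pROC, nortest, or exact2x2
+6. Do not use packages that are not installed, such as pROC, nortest, or exact2x2
 
 ## Output format (hard rules! Any violation counts as a failure)
-1. The complete R code must be output inside a \`\`\`r ... \`\`\` code block
-2. Outside the code block, only a brief explanation is allowed (1-3 sentences)
-3. Mixing Chinese explanatory text or natural-language paragraphs into the code block is **strictly forbidden**
+1. **The complete R code must be placed between the <r_code> and </r_code> tags**
+2. Outside the <r_code> tags, only a brief explanation is allowed (1-3 sentences)
+3. Inside the <r_code> tags, **only pure R code is allowed**; mixing in Chinese explanatory text or natural-language paragraphs is strictly forbidden
 4. The code must be a directly executable R script, with no pseudocode or placeholders
-5. The code must end by returning a list that contains report_blocks`;
+5. The code must end by returning a list that contains report_blocks
+6. Chinese comments may only be written on code lines after a leading #; Chinese text without a leading # is forbidden
+
+Example output format:
+Brief explanation...
+
+<r_code>
+library(ggplot2)
+# Data processing
+df$group <- as.factor(df$group)
+# ... complete R code ...
+list(status = "success", method = "t_test", report_blocks = blocks)
+</r_code>`;
   }
 
   private buildFirstMessage(plan: AgentPlan): string {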
@@ -242,9 +275,9 @@ ${plan.assumptions.join('\n') || 'No special assumptions'}
   private buildRetryMessage(plan: AgentPlan, errorFeedback: string, previousCode?: string): string {
     const codeSection = previousCode
       ? `## Full code from the last failed attempt (for reference; fix it and output the complete new code)
-\`\`\`r
+<previous_code>
 ${previousCode}
-\`\`\``
+</previous_code>`
       : '';
 
     return `The previously generated R code failed to execute.
@@ -252,9 +285,9 @@ ${previousCode}
 ${codeSection}
 
 ## Error message
-\`\`\`
+<error>
 ${errorFeedback}
-\`\`\`
+</error>
 
 ## Analysis plan (unchanged)
 - Title: ${plan.title}
@@ -263,23 +296,33 @@ ${errorFeedback}
 - Grouping variable: ${plan.variables.grouping || 'none'}
 
 ## Fix requirements
-1. **Analyze the error message above carefully** and find the root cause of the failure
+1. **Analyze the error message in <error> carefully** and find the root cause of the failure
 2. Apply a precise fix for that cause and output complete, directly executable R code
 3. Wrap key steps that may fail in tryCatch
 4. Wrap statistical tests in the safe_test pattern to handle NA/NaN/Inf
 5. Check that every library() call is in the preinstalled package list
-6. Keep the report_blocks output format unchanged`;
+6. Keep the report_blocks output format unchanged
+7. **The corrected, complete code must be placed inside <r_code>...</r_code> tags**`;
   }
 
   private parseCode(content: string): GeneratedCode {
-    const codeMatch = content.match(/```r\s*([\s\S]*?)```/)
+    // Three-tier extraction: XML tag > Markdown code block > heuristic inference
+    const xmlMatch = content.match(/<r_code>([\s\S]*?)<\/r_code>/);
+    const mdMatch = content.match(/```r\s*([\s\S]*?)```/)
       || content.match(/```R\s*([\s\S]*?)```/)
       || content.match(/```\s*([\s\S]*?)```/);
 
     let code: string;
-    if (codeMatch) {
-      code = codeMatch[1].trim();
+    let extractMethod: string;
+
+    if (xmlMatch) {
+      code = xmlMatch[1].trim();
+      extractMethod = 'xml_tag';
+    } else if (mdMatch) {
+      code = mdMatch[1].trim();
+      extractMethod = 'markdown_block';
     } else {
+      // Heuristic: check whether enough lines look like R code
       const lines = content.split('\n');
       const rLines = lines.filter(l => {
         const t = l.trim();
@@ -289,9 +332,10 @@ ${errorFeedback}
       });
       if (rLines.length >= 3) {
         code = content.trim();
+        extractMethod = 'heuristic';
       } else {
         throw new Error(
-          'No valid R code block found in the LLM response. Make sure the code is output inside ```r ... ```.'
+          'No valid R code block found in the LLM response. Make sure the code is output inside <r_code>...</r_code> tags.'
+          + ` (received ${content.length} characters, first 100: ${content.slice(0, 100)})`
         );
       }
@@ -301,6 +345,8 @@ ${errorFeedback}
       throw new Error(`Parsed R code is too short (${code.length} characters); generation may have failed`);
     }
 
+    logger.debug('[AgentCoder] Code extracted', { extractMethod, codeLength: code.length });
+
     const packageRegex = /library\((\w+)\)/g;
     const packages: string[] = [];
     let match;
@@ -309,6 +355,7 @@ ${errorFeedback}
     }
 
     const explanation = content
+      .replace(/<r_code>[\s\S]*?<\/r_code>/g, '')
       .replace(/```r[\s\S]*?```/gi, '')
       .replace(/```[\s\S]*?```/g, '')
       .trim()
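The `parseCode` change tries three extraction strategies in priority order: an `<r_code>` tag, then a Markdown fence, then a line-based heuristic. The sketch below shows that cascade in isolation. It is a simplified illustration, not the service's code: `extractRCode` and `looksLikeR` are hypothetical names, the heuristic predicate is an assumption (the diff does not show the real one), and the fence is built from a `FENCE` constant only to avoid writing a literal triple backtick inside this example.

```typescript
type ExtractMethod = 'xml_tag' | 'markdown_block' | 'heuristic';

// Three backticks, built at runtime so this example contains no literal fence.
const FENCE = '`'.repeat(3);
const FENCED_BLOCK = new RegExp(FENCE + '[rR]?\\s*([\\s\\S]*?)' + FENCE);

// Assumed heuristic: an assignment arrow or a function call marks an R-like line.
function looksLikeR(line: string): boolean {
  return /(<-|\w+\()/.test(line.trim());
}

function extractRCode(content: string): { code: string; method: ExtractMethod } {
  // Tier 1: explicit XML-style tag, the format the prompt now demands.
  const xml = content.match(/<r_code>([\s\S]*?)<\/r_code>/);
  if (xml) return { code: xml[1].trim(), method: 'xml_tag' };

  // Tier 2: Markdown fence, for models that ignore the tag instruction.
  const md = content.match(FENCED_BLOCK);
  if (md) return { code: md[1].trim(), method: 'markdown_block' };

  // Tier 3: accept the whole reply when at least three lines look like R.
  const rLines = content.split('\n').filter(looksLikeR);
  if (rLines.length >= 3) return { code: content.trim(), method: 'heuristic' };

  throw new Error('No R code found in the response');
}

const tagged = extractRCode('Short note.\n<r_code>\nx <- 1\n</r_code>');
const fenced = extractRCode('Short note.\n' + FENCE + 'r\ny <- 2\n' + FENCE);
const bare = extractRCode('library(ggplot2)\nx <- 1\ny <- mean(x)');
```

Checking the tag first matters when a reply contains both forms, for example a tagged script whose comments mention a fenced snippet: the tag match wins and the fence is never consulted.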
@@ -14,6 +14,8 @@
 import { LLMFactory } from '../../../common/llm/adapters/LLMFactory.js';
 import type { Message as LLMMessage } from '../../../common/llm/adapters/types.js';
 import { logger } from '../../../common/logging/index.js';
+import { getPromptService } from '../../../common/prompt/index.js';
+import { prisma } from '../../../config/database.js';
 import { sessionBlackboardService } from './SessionBlackboardService.js';
 import { tokenTruncationService } from './TokenTruncationService.js';
 
@@ -47,7 +49,7 @@ export class AgentPlannerService {
   ): Promise<AgentPlan> {
     const dataContext = await this.buildDataContext(sessionId);
 
-    const systemPrompt = this.buildSystemPrompt(dataContext);
+    const systemPrompt = await this.buildSystemPrompt(dataContext);
 
     const messages: LLMMessage[] = [
       { role: 'system', content: systemPrompt },
@@ -90,7 +92,18 @@ export class AgentPlannerService {
     return tokenTruncationService.toPromptString(truncated);
   }
 
-  private buildSystemPrompt(dataContext: string): string {
+  private async buildSystemPrompt(dataContext: string): Promise<string> {
+    try {
+      const promptService = getPromptService(prisma);
+      const rendered = await promptService.get('SSA_AGENT_PLANNER', { dataContext });
+      return rendered.content;
+    } catch (err) {
+      logger.warn('[AgentPlanner] Failed to load prompt from DB, using fallback', { error: (err as Error).message });
+    }
+    return this.fallbackSystemPrompt(dataContext);
+  }
+
+  private fallbackSystemPrompt(dataContext: string): string {
     return `You are a senior statistical analysis planner (Planner Agent). Your job is to draw up a rigorous statistical analysis plan based on the user's research needs and the characteristics of the data.
 
 ## Data context
@@ -34,6 +34,7 @@ interface TruncatedContext {
   variables: string;
   pico: string;
   report: string;
+  highFidelitySchema: string;
   estimatedTokens: number;
 }
 
@@ -61,12 +62,17 @@ export class TokenTruncationService {
     const overview = this.formatOverview(blackboard.dataOverview, strategy);
     const variables = this.formatVariables(blackboard.variableDictionary, strategy);
     const report = this.formatReport(blackboard, strategy);
+    const highFidelitySchema = this.formatHighFidelitySchema(
+      blackboard.dataOverview,
+      blackboard.variableDictionary,
+    );
 
     let ctx: TruncatedContext = {
       pico,
       overview,
       variables,
       report,
+      highFidelitySchema,
       estimatedTokens: 0,
     };
 
@@ -92,6 +98,7 @@ export class TokenTruncationService {
 
     if (ctx.pico) parts.push(`## PICO structure\n${ctx.pico}`);
     if (ctx.overview) parts.push(`## Data overview\n${ctx.overview}`);
+    if (ctx.highFidelitySchema) parts.push(`## Data schema (high fidelity)\n${ctx.highFidelitySchema}`);
     if (ctx.variables) parts.push(`## Variable list\n${ctx.variables}`);
     if (ctx.report) parts.push(`## Data diagnostics summary\n${ctx.report}`);
 
@@ -139,6 +146,47 @@ export class TokenTruncationService {
     }).join('\n');
   }
 
+  /**
+   * High-fidelity schema: builds a detailed schema for CoderAgent with column types, sample values, and missing rates.
+   * One row per column, so the LLM can apply as.factor() / as.numeric() precisely.
+   */
+  formatHighFidelitySchema(overview: DataOverview | null, dict: VariableDictEntry[]): string {
+    if (!overview?.profile?.columns?.length) return '';
+
+    const cols = overview.profile.columns as any[];
+    const dictMap = new Map(dict.map(v => [v.name, v]));
+
+    const lines: string[] = ['Column | R type | Missing rate | Details'];
+    lines.push('---|---|---|---');
+
+    for (const col of cols) {
+      if (col.isIdLike) continue;
+
+      const dictEntry = dictMap.get(col.name);
+      const confirmedType = dictEntry?.confirmedType ?? dictEntry?.inferredType ?? col.type;
+      const picoTag = dictEntry?.picoRole ? ` [${dictEntry.picoRole}]` : '';
+      const missingPct = col.missingRate != null ? `${(col.missingRate * 100).toFixed(1)}%` : '0%';
+
+      let detail = '';
+      if (col.type === 'numeric') {
+        const parts: string[] = [];
+        if (col.mean != null) parts.push(`M=${Number(col.mean).toFixed(2)}`);
+        if (col.std != null) parts.push(`SD=${Number(col.std).toFixed(2)}`);
+        if (col.min != null && col.max != null) parts.push(`[${col.min}, ${col.max}]`);
+        detail = parts.join(', ');
+      } else if (col.type === 'categorical' && col.topValues?.length) {
+        const levels = col.topValues.slice(0, 5).map((v: any) => `"${v.value}"(${v.count})`).join(', ');
+        detail = `${col.totalLevels ?? col.topValues.length} levels: ${levels}`;
+      } else if (col.type === 'datetime') {
+        detail = col.dateRange || (col.minDate && col.maxDate ? `${col.minDate}~${col.maxDate}` : '');
+      }
+
+      lines.push(`${col.name}${picoTag} | ${confirmedType} | ${missingPct} | ${detail}`);
+    }
+
+    return lines.join('\n');
+  }
+
   private formatReport(bb: SessionBlackboard, strategy: string): string {
     const report = bb.dataOverview
       ? this.buildReportSummary(bb.dataOverview)
@@ -172,7 +220,8 @@ export class TokenTruncationService {
   }
 
   private estimateTokens(ctx: TruncatedContext): number {
-    const total = ctx.pico.length + ctx.overview.length + ctx.variables.length + ctx.report.length;
+    const total = ctx.pico.length + ctx.overview.length + ctx.variables.length
+      + ctx.report.length + ctx.highFidelitySchema.length;
     return Math.ceil(total / 2);
   }
 
@@ -185,6 +234,7 @@ export class TokenTruncationService {
 
     result.report = result.report.length > 300 ? result.report.slice(0, 300) + '...' : result.report;
 
+    // In aggressive mode, highFidelitySchema already carries the type information, so variables can be simplified
     let vars = bb.variableDictionary.filter(v => !v.isIdLike);
     if (vars.length > 10) {
       const picoVars = vars.filter(v => v.picoRole);
@@ -196,6 +246,12 @@ export class TokenTruncationService {
       return `- ${v.name}: ${type}`;
     }).join('\n');
 
+    // If still over the limit, truncate the high-fidelity schema (keep the first 15 columns)
+    if (this.estimateTokens(result) > maxTokens && result.highFidelitySchema) {
+      const schemaLines = result.highFidelitySchema.split('\n');
+      result.highFidelitySchema = schemaLines.slice(0, 17).join('\n') + '\n...';
+    }
+
     result.estimatedTokens = this.estimateTokens(result);
     return result;
   }
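The truncation hunks rely on two small pieces of arithmetic: tokens are estimated as half the character count, and the schema table is cut to 17 lines, which is the 2 header lines plus 15 column rows. The sketch below reproduces both under stated assumptions. `estimateTokensFor` and `truncateSchema` are hypothetical helpers, and unlike the diff this version appends the `...` marker only when rows were actually dropped.

```typescript
// Rough estimate used in the diff: about two characters per token.
function estimateTokensFor(parts: string[]): number {
  const totalChars = parts.reduce((sum, part) => sum + part.length, 0);
  return Math.ceil(totalChars / 2);
}

// Keep the 2 table header lines plus the first `maxColumns` column rows.
function truncateSchema(schema: string, maxColumns: number): string {
  const lines = schema.split('\n');
  const keep = maxColumns + 2; // header row + separator row + column rows
  if (lines.length <= keep) return schema; // nothing to drop
  return lines.slice(0, keep).join('\n') + '\n...';
}

// Build a 20-column schema in the same "name | type | missing | detail" layout.
const header = ['Column | R type | Missing rate | Details', '---|---|---|---'];
const rows = Array.from({ length: 20 }, (_, i) => `col${i + 1} | numeric | 0.0% | M=1.00`);
const fullSchema = [...header, ...rows].join('\n');

const shortened = truncateSchema(fullSchema, 15);
const shortenedLines = shortened.split('\n');
const untouched = truncateSchema(header.concat(rows.slice(0, 3)).join('\n'), 15);
const estimate = estimateTokensFor(['abc', 'de']);
```

Dividing by two is a deliberately conservative guess for mixed Chinese and English text; a real tokenizer would give different counts, so the limit should be read as a budget rather than an exact bound.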